#!/usr/bin/python
# Copyright 2011-2014 Red Hat, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301 USA
#
# Refer to the README and COPYING files for full details of the license
#
from pwd import getpwnam
import platform
import sys
import os
import stat
import errno
from functools import wraps
import threading
import re
import getopt
import resource
import signal
import logging
import logging.config

from contextlib import closing
from vdsm.infra import sigutils
from vdsm.infra import zombiereaper

from caps import Architecture
import numaUtils

LOG_CONF_PATH = "/etc/vdsm/svdsm.logger.conf"

try:
    logging.config.fileConfig(LOG_CONF_PATH, disable_existing_loggers=False)
except Exception:
    # If the logger configuration cannot be loaded for any reason, fall
    # back to DEBUG-level logging on stdout and record why. Catching
    # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
    # propagate out of module import.
    logging.basicConfig(filename='/dev/stdout', filemode='w+',
                        level=logging.DEBUG)
    log = logging.getLogger("SuperVdsm.Server")
    log.warning("Could not init proper logging", exc_info=True)

from storage import fuser
from multiprocessing import Pipe, Process
try:
    from gluster import listPublicFunctions
    _glusterEnabled = True
except ImportError:
    _glusterEnabled = False

from vdsm import udevadm
from vdsm import utils
from vdsm import sysctl
from vdsm.tool import restore_nets
from parted_utils import getDevicePartedInfo as _getDevicePartedInfo

from network import sourceroutethread
from network.api import editNetwork, setupNetworks, setSafeNetworkConfig,\
    change_numvfs

from network.tc import setPortMirroring, unsetPortMirroring
from storage.multipath import getScsiSerial as _getScsiSerial
from storage.iscsi import getDevIscsiInfo as _getdeviSCSIinfo
from storage.iscsi import readSessionInfo as _readSessionInfo
from supervdsm import _SuperVdsmManager
from storage import hba
from storage import multipath
from storage.fileUtils import chown, resolveGid, resolveUid
from storage.fileUtils import validateAccess as _validateAccess
from vdsm.constants import METADATA_GROUP, EXT_CHOWN, \
    DISKIMAGE_USER, DISKIMAGE_GROUP, \
    P_LIBVIRT_VMCHANNELS, P_OVIRT_VMCONSOLES, \
    VDSM_USER, QEMU_PROCESS_USER, QEMU_PROCESS_GROUP, RHEV_ENABLED
from storage.devicemapper import _removeMapping, _getPathsStatus
from vdsm.config import config
import mkimage

# udev rule files written by supervdsm: all live in this directory and
# share the "99-vdsm-" prefix and ".rules" extension.
_UDEV_RULE_FILE_DIR = "/etc/udev/rules.d/"
_UDEV_RULE_FILE_PREFIX = "99-vdsm-"
_UDEV_RULE_FILE_EXT = ".rules"
# Multipath device rule file; formatted with (guid, thiefId).
_UDEV_RULE_FILE_NAME = os.path.join(
    _UDEV_RULE_FILE_DIR, _UDEV_RULE_FILE_PREFIX + '%s-%s' +
    _UDEV_RULE_FILE_EXT)
# VFIO rule file; formatted with the IOMMU group.
_UDEV_RULE_FILE_NAME_VFIO = os.path.join(
    _UDEV_RULE_FILE_DIR, _UDEV_RULE_FILE_PREFIX + "iommu_group_%s" +
    _UDEV_RULE_FILE_EXT)
# USB device rule file; formatted with (bus, device).
_UDEV_RULE_FILE_NAME_USB = os.path.join(
    _UDEV_RULE_FILE_DIR, _UDEV_RULE_FILE_PREFIX + "usb_%s_%s" +
    _UDEV_RULE_FILE_EXT)
# Device node of a USB device; formatted with (bus, device) as ints,
# zero-padded to three digits each.
_USB_DEVICE_PATH = '/dev/bus/usb/%03d/%03d'

# Seconds _SuperVdsm._runAs waits (via safe_poll) for its child to reply.
RUN_AS_TIMEOUT = config.getint("irs", "process_pool_timeout")

# main() keeps serving while this is True; terminate() (installed for
# SIGTERM/SIGINT) clears it.
_running = True


class Timeout(RuntimeError):
    """
    Raised by _SuperVdsm._runAs when its child process does not reply
    within RUN_AS_TIMEOUT seconds.
    """


def logDecorator(func):
    """
    Wrap a _SuperVdsm method so that every call is logged.

    The wrapper logs at DEBUG level the call's positional arguments (all
    but the first, i.e. self) and keyword arguments, then the returned
    value. If the call raises, it logs an ERROR record with the traceback
    and re-raises the same exception. Records go to the
    "SuperVdsm.ServerCallback" logger.
    """
    callbackLogger = logging.getLogger("SuperVdsm.ServerCallback")

    @wraps(func)
    def wrapper(*args, **kwargs):
        name = func.__name__
        callbackLogger.debug('call %s with %s %s', name, args[1:], kwargs)
        try:
            outcome = func(*args, **kwargs)
        except:
            # Deliberately catch everything: log it, then re-raise as-is.
            callbackLogger.error("Error in %s", name, exc_info=True)
            raise
        else:
            callbackLogger.debug('return %s with %s', name, outcome)
            return outcome

    return wrapper


def safe_poll(mp_connection, timeout):
    """
    Poll a multiprocessing connection for up to `timeout` seconds.

    Workaround until the PEP-475 EINTR fix is available: whenever poll()
    comes back empty before the deadline (for instance because a signal
    interrupted it), poll again for whatever time is left.

    :returns: True as soon as poll() reports data to read or a closed
              pipe; False once the deadline has passed.
    """
    deadline = utils.monotonic_time() + timeout
    wait = timeout
    while True:
        if mp_connection.poll(wait):
            return True
        wait = deadline - utils.monotonic_time()
        if wait <= 0:
            return False


class _SuperVdsm(object):
    """
    Privileged operations exposed to vdsm through the supervdsm manager.

    main() registers this class with _SuperVdsmManager under the name
    'instance'. Public methods are wrapped with logDecorator, so their
    calls, results and errors are logged to "SuperVdsm.ServerCallback".
    """

    log = logging.getLogger("SuperVdsm.ServerCallback")

    @logDecorator
    def ping(self, *args, **kwargs):
        """Return True. This method exists for testing purposes."""
        return True

    @logDecorator
    def getHardwareInfo(self, *args, **kwargs):
        """
        Return the hardware information structure for this host.

        x86 (x86_64/i686) and POWER hosts use a lazily imported
        platform-specific helper; other architectures are not implemented
        and get an empty dict.
        """
        machine = platform.machine()
        if machine in ('x86_64', 'i686'):
            from dmidecodeUtil import getHardwareInfoStructure
            return getHardwareInfoStructure()
        elif machine in Architecture.POWER:
            from ppc64HardwareInfo import getHardwareInfoStructure
            return getHardwareInfoStructure()
        else:
            #  not implemented over other architecture
            return {}

    @logDecorator
    def getDevicePartedInfo(self, *args, **kwargs):
        """Delegate to parted_utils.getDevicePartedInfo."""
        return _getDevicePartedInfo(*args, **kwargs)

    @logDecorator
    def getScsiSerial(self, *args, **kwargs):
        """Delegate to storage.multipath.getScsiSerial."""
        return _getScsiSerial(*args, **kwargs)

    @logDecorator
    def resizeMap(self, devName):
        """Delegate to storage.multipath._resize_map for devName."""
        return multipath._resize_map(devName)

    @logDecorator
    def removeDeviceMapping(self, devName):
        """Delegate to storage.devicemapper._removeMapping for devName."""
        return _removeMapping(devName)

    @logDecorator
    def getdeviSCSIinfo(self, *args, **kwargs):
        """Delegate to storage.iscsi.getDevIscsiInfo."""
        return _getdeviSCSIinfo(*args, **kwargs)

    @logDecorator
    def readSessionInfo(self, sessionID):
        """Delegate to storage.iscsi.readSessionInfo for sessionID."""
        return _readSessionInfo(sessionID)

    @logDecorator
    def getPathsStatus(self):
        """Delegate to storage.devicemapper._getPathsStatus."""
        return _getPathsStatus()

    @logDecorator
    def getVmPid(self, vmName):
        """Return the raw contents of libvirt's qemu pid file for vmName."""
        pidFile = "/var/run/libvirt/qemu/%s.pid" % vmName
        with open(pidFile) as pid:
            return pid.read()

    @logDecorator
    def getVcpuNumaMemoryMapping(self, vmName):
        """
        Map each vCPU index of vmName to the NUMA nodes it has memory on.

        For every vCPU thread reported by numaUtils.getVcpuPid, parse
        /proc/<vm pid>/task/<vcpu thread>/numa_maps and collect the node
        numbers from its "N<node>=<count>" fields. vCPUs whose numa_maps
        cannot be read are skipped.

        :returns: dict of vCPU index -> list of distinct node ids (ints)
        """
        vmPid = self.getVmPid(vmName)
        vCpuPids = numaUtils.getVcpuPid(vmName)
        vCpuIdxToNode = {}
        for vCpuIndex, vCpuPid in vCpuPids.iteritems():
            numaMapsFile = "/proc/%s/task/%s/numa_maps" % (vmPid, vCpuPid)
            try:
                with open(numaMapsFile, 'r') as f:
                    # Raw string: '\d' in a plain literal is an invalid
                    # escape sequence.
                    mappingNodes = map(
                        int, re.findall(r'N(\d+)=\d+', f.read()))
                    vCpuIdxToNode[vCpuIndex] = list(set(mappingNodes))
            except IOError:
                continue
        return vCpuIdxToNode

    @logDecorator
    def prepareVmChannel(self, socketFile, group=None):
        """
        Make a VM channel/console socket group-writable.

        Only paths under P_LIBVIRT_VMCHANNELS or P_OVIRT_VMCONSOLES are
        accepted; anything else raises Exception. If group is given, the
        socket's group is changed to it while its owner is kept.
        """
        if (socketFile.startswith(P_LIBVIRT_VMCHANNELS) or
           socketFile.startswith(P_OVIRT_VMCONSOLES)):
            fsinfo = os.stat(socketFile)
            mode = fsinfo.st_mode | stat.S_IWGRP
            os.chmod(socketFile, mode)
            if group is not None:
                os.chown(socketFile,
                         fsinfo.st_uid,
                         resolveGid(group))
        else:
            raise Exception("Incorporate socketFile")

    @logDecorator
    def restoreNetworks(self):
        """Delegate to vdsm.tool.restore_nets.restore."""
        return restore_nets.restore()

    @logDecorator
    def editNetwork(self, oldBridge, newBridge, options):
        """Delegate to network.api.editNetwork, expanding options."""
        return editNetwork(oldBridge, newBridge, **options)

    @logDecorator
    def setupNetworks(self, networks, bondings, options):
        """Delegate to network.api.setupNetworks, expanding options."""
        return setupNetworks(networks, bondings, **options)

    @logDecorator
    def changeNumvfs(self, pci_path, numvfs, net_name):
        """Delegate to network.api.change_numvfs."""
        return change_numvfs(pci_path, numvfs, net_name)

    def _runAs(self, user, groups, func, args=(), kwargs=None):
        """
        Call func(*args, **kwargs) in a child process running as user.

        The child resolves user with resolveUid and, if groups is
        non-empty, resolves each entry with resolveGid, makes the first
        one its primary gid and sets all of them as supplementary groups.
        It then calls func and sends (result, exception) back over a pipe.

        :param user: user to run as (resolved with resolveUid)
        :param groups: groups to run with; a falsy value keeps the groups
                       inherited from this process
        :returns: func's return value
        :raises: the exception func raised in the child, or Timeout if
                 the child does not reply within RUN_AS_TIMEOUT seconds
        """
        # Avoid a shared mutable default argument.
        if kwargs is None:
            kwargs = {}

        def child(pipe):
            res = ex = None
            try:
                uid = resolveUid(user)
                if groups:
                    # A list (not map()) so gids[0] also works on Py3.
                    gids = [resolveGid(g) for g in groups]

                    os.setgid(gids[0])
                    os.setgroups(gids)
                os.setuid(uid)

                res = func(*args, **kwargs)
            except BaseException as e:
                ex = e

            pipe.send((res, ex))
            # Block until the parent acknowledges receiving the reply.
            pipe.recv()

        pipe, hisPipe = Pipe()
        with closing(pipe), closing(hisPipe):
            proc = Process(target=child, args=(hisPipe,))
            proc.start()

            needReaping = True
            try:
                if not safe_poll(pipe, RUN_AS_TIMEOUT):
                    try:

                        os.kill(proc.pid, signal.SIGKILL)
                    except OSError as e:
                        # Don't add to zombiereaper if PID no longer exists
                        if e.errno == errno.ESRCH:
                            needReaping = False
                        else:
                            raise

                    raise Timeout()

                res, err = pipe.recv()
                # Release the child, which is waiting in pipe.recv().
                pipe.send("Bye")
                proc.terminate()

                if err is not None:
                    raise err

                return res

            finally:
                # Add to zombiereaper if process has not been waited upon
                if proc.exitcode is None and needReaping:
                    zombiereaper.autoReapPID(proc.pid)

    @logDecorator
    def validateAccess(self, user, groups, *args, **kwargs):
        """Run storage.fileUtils.validateAccess as user/groups (_runAs)."""
        return self._runAs(user, groups, _validateAccess, args=args,
                           kwargs=kwargs)

    @logDecorator
    def setSafeNetworkConfig(self):
        """Delegate to network.api.setSafeNetworkConfig."""
        return setSafeNetworkConfig()

    @logDecorator
    def udevTriggerMultipath(self, guid):
        """Trigger udev for the device whose DM_NAME property is guid."""
        self._udevTrigger(property_matches=(('DM_NAME', guid),))

    @logDecorator
    def appropriateMultipathDevice(self, guid, thiefId):
        """
        Write a udev rule that chowns mapper/<guid> to
        DISKIMAGE_USER:DISKIMAGE_GROUP via EXT_CHOWN.

        The rule file name embeds both guid and thiefId, so
        rmAppropriateMultipathRules(thiefId) can find it later.
        """
        ruleFile = _UDEV_RULE_FILE_NAME % (guid, thiefId)
        # WARNING: we cannot use USER, GROUP and MODE since using any of them
        # will change the selinux label to the default, causing vms to pause.
        # See https://bugzilla.redhat.com/1147910
        rule = 'SYMLINK=="mapper/%s", RUN+="%s %s:%s $env{DEVNAME}"\n' % (
            guid, EXT_CHOWN, DISKIMAGE_USER, DISKIMAGE_GROUP)
        with open(ruleFile, "w") as rf:
            self.log.debug("Creating rule %s: %r", ruleFile, rule)
            rf.write(rule)

    @logDecorator
    def rmAppropriateMultipathRules(self, thiefId):
        """
        Remove every multipath rule file created for thiefId.

        :returns: list of rule paths that could not be removed
        """
        # Escape the literal parts: thiefId is caller-supplied and the
        # '.' in ".rules" would otherwise match any character.
        apprDevRule = re.compile(
            "^" + re.escape(_UDEV_RULE_FILE_PREFIX) + ".*?-" +
            re.escape(thiefId) + re.escape(_UDEV_RULE_FILE_EXT) + "$")
        rules = [os.path.join(_UDEV_RULE_FILE_DIR, r) for r in
                 os.listdir(_UDEV_RULE_FILE_DIR)
                 if apprDevRule.match(r)]
        fails = []
        for r in rules:
            try:
                self.log.debug("Removing rule %s", r)
                os.remove(r)
            except OSError:
                fails.append(r)
        return fails

    @logDecorator
    def appropriateIommuGroup(self, iommu_group):
        """
        Create udev rule in /etc/udev/rules.d/ to change ownership
        of /dev/vfio/$iommu_group to qemu:qemu. This method should be called
        when detaching a device from the host.
        """
        rule_file = _UDEV_RULE_FILE_NAME_VFIO % iommu_group

        if not os.path.isfile(rule_file):
            # If the file exists, different device from the same group has
            # already been detached and we therefore can skip overwriting the
            # file. Also, this file should only be created/removed via the
            # means of supervdsm.

            # Comma-separated keys and a trailing newline, like the other
            # rules this class writes.
            rule = ('KERNEL=="{}", SUBSYSTEM=="vfio", RUN+="{} {}:{} '
                    '/dev/vfio/{}"\n').format(iommu_group, EXT_CHOWN,
                                              QEMU_PROCESS_USER,
                                              QEMU_PROCESS_GROUP,
                                              iommu_group)

            with open(rule_file, "w") as rf:
                self.log.debug("Creating rule %s: %r", rule_file, rule)
                rf.write(rule)

            self._udevTrigger(subsystem_matches=('vfio',))

    @logDecorator
    def rmAppropriateIommuGroup(self, iommu_group):
        """
        Remove the udev rule in /etc/udev/rules.d/ created by
        appropriateIommuGroup and re-trigger the vfio subsystem.

        A missing rule file is expected (several devices of one IOMMU
        group may have been passed through) and is silently ignored.
        """
        # Same path construction as appropriateIommuGroup, which also
        # works when iommu_group is an int.
        rule_file = _UDEV_RULE_FILE_NAME_VFIO % iommu_group

        try:
            os.remove(rule_file)
        except OSError as e:
            if e.errno != errno.ENOENT:
                raise
            # OSError with ENOENT errno here means that the rule file does
            # not exist - this is expected when multiple devices in one
            # iommu group were passed through.
            return

        self.log.debug("Removing rule %s", rule_file)
        self._udevTrigger(subsystem_matches=('vfio',))

    @logDecorator
    def appropriateUSBDevice(self, bus, device):
        """
        Write a udev rule giving QEMU_PROCESS_USER:QEMU_PROCESS_GROUP
        ownership of USB device (bus, device), then re-trigger udev for it.
        """
        rule_file = _UDEV_RULE_FILE_NAME_USB % (bus, device)
        rule = ('SUBSYSTEM=="usb", ATTRS{{busnum}}=="{}", '
                'ATTRS{{devnum}}=="{}", OWNER:="{}", GROUP:="{}"\n').format(
            bus, device, QEMU_PROCESS_USER, QEMU_PROCESS_GROUP)

        self.log.debug("Creating rule %s: %r", rule_file, rule)

        with open(rule_file, "w") as rf:
            rf.write(rule)

        self._udevTrigger(attr_matches=(('busnum', int(bus)),
                                        ('devnum', int(device))))

    @logDecorator
    def rmAppropriateUSBDevice(self, bus, device):
        """
        Undo appropriateUSBDevice: remove the rule file (a missing file
        is only logged), chown the device node back to root:root and
        re-trigger udev for the device.

        :raises OSError: EINVAL if the chown command wrote to stderr
        """
        rule_file = _UDEV_RULE_FILE_NAME_USB % (bus, device)
        self.log.debug("Removing rule %s", rule_file)
        try:
            os.remove(rule_file)
        except OSError as e:
            if e.errno != errno.ENOENT:
                raise

            self.log.warning('Rule %s missing', rule_file)

        self.log.debug('Changing ownership (to root:root) of device '
                       'bus: %s, device: %s', bus, device)
        device_path = _USB_DEVICE_PATH % (int(bus), int(device))
        cmd = [EXT_CHOWN, 'root:root', device_path]
        rc, out, err = utils.execCmd(cmd)
        # NOTE(review): failure is detected from stderr output, not rc.
        if err:
            raise OSError(errno.EINVAL, 'Could not change ownership '
                          'out %s\nerr %s' % (out, err))

        # It's possible that the device was in input class or had rule
        # matched against it, trigger to make sure everything is fine.
        self._udevTrigger(attr_matches=(('busnum', int(bus)),
                                        ('devnum', int(device))))

    @logDecorator
    def ksmTune(self, tuningParams):
        '''
        Set KSM tuning parameters for MOM, which runs without root privilege
        when it's launched by vdsm. So it needs supervdsm's assistance to tune
        KSM's parameters.

        Each key must be one of KSM_PARAMS and its integer value must lie
        in [0, KSM_PARAMS[key]); otherwise Exception is raised. Values are
        written to /sys/kernel/mm/ksm/<key> in iteration order, so params
        processed before an invalid one have already been applied.
        '''
        # key -> exclusive upper bound for its value
        KSM_PARAMS = {'run': 3, 'merge_across_nodes': 3,
                      'sleep_millisecs': 0x100000000,
                      'pages_to_scan': 0x100000000}
        for k, v in tuningParams.iteritems():
            if k not in KSM_PARAMS:
                raise Exception('Invalid key in KSM parameter: %s=%s' % (k, v))
            value = int(v)
            if value < 0 or value >= KSM_PARAMS[k]:
                raise Exception('Invalid value in KSM parameter: %s=%s' %
                                (k, v))
            with open('/sys/kernel/mm/ksm/%s' % k, 'w') as f:
                f.write(str(v))

    @logDecorator
    def setPortMirroring(self, networkName, ifaceName):
        '''
        Copy networkName traffic of a bridge to an interface

        :param networkName: networkName bridge name to capture the traffic from
        :type networkName: string

        :param ifaceName: ifaceName to copy (mirror) the traffic to
        :type ifaceName: string

        this commands mirror all 'networkName' traffic to 'ifaceName'
        '''
        setPortMirroring(networkName, ifaceName)

    @logDecorator
    def unsetPortMirroring(self, networkName, target):
        '''
        Release captured mirror networkName traffic from networkName bridge

        :param networkName: networkName to release the traffic capture
        :type networkName: string
        :param target: target device to release
        :type target: string
        '''
        unsetPortMirroring(networkName, target)

    @logDecorator
    def mkFloppyFs(self, vmId, files, volId):
        """Delegate to mkimage.mkFloppyFs."""
        return mkimage.mkFloppyFs(vmId, files, volId)

    @logDecorator
    def mkIsoFs(self, vmId, files, volId):
        """Delegate to mkimage.mkIsoFs."""
        return mkimage.mkIsoFs(vmId, files, volId)

    @logDecorator
    def removeFs(self, path):
        """Delegate to mkimage.removeFs."""
        return mkimage.removeFs(path)

    @logDecorator
    def fuser(self, *args, **kwargs):
        """Delegate to storage.fuser.fuser."""
        return fuser.fuser(*args, **kwargs)

    @logDecorator
    def hbaRescan(self):
        """Delegate to storage.hba._rescan."""
        return hba._rescan()

    @logDecorator
    def set_rp_filter_loose(self, dev):
        """Delegate to vdsm.sysctl.set_rp_filter_loose for dev."""
        sysctl.set_rp_filter_loose(dev)

    @logDecorator
    def set_rp_filter_strict(self, dev):
        """Delegate to vdsm.sysctl.set_rp_filter_strict for dev."""
        sysctl.set_rp_filter_strict(dev)

    def _udevTrigger(self, *args, **kwargs):
        """
        Call udevadm.trigger, translating udevadm.Error into
        OSError(EINVAL) carrying the command's output.
        """
        try:
            udevadm.trigger(*args, **kwargs)
        except udevadm.Error as e:
            raise OSError(errno.EINVAL, 'Could not trigger change '
                          'out %s\nerr %s' % (e.out, e.err))


def terminate(signo, frame):
    """
    Signal handler (SIGTERM/SIGINT): ask main()'s serving loop to stop
    by clearing the module-level _running flag. Both arguments are
    ignored.
    """
    global _running
    _running = False


def main(sockfile, pidfile=None):
    """
    Serve a _SuperVdsm instance until SIGTERM or SIGINT is received.

    :param sockfile: path of the manager's socket file. Any existing file
                     at that path is removed before serving and again on
                     exit; the socket is chowned to VDSM_USER and
                     METADATA_GROUP.
    :param pidfile: optional path to write this process's pid to

    Exits with errno.EPERM when not running as root, and with status 1
    when an exception escapes startup or serving.
    """
    log = logging.getLogger("SuperVdsm.Server")
    if not config.getboolean('vars', 'core_dump_enable'):
        resource.setrlimit(resource.RLIMIT_CORE, (0, 0))
    sigutils.register()
    zombiereaper.registerSignalHandler()

    def bind(func):
        # Turn a plain function into a method by dropping self. wraps()
        # keeps func's __name__, so logDecorator logs the real gluster
        # verb instead of "wrapper".
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            return func(*args, **kwargs)
        return wrapper

    if _glusterEnabled:
        for name, func in listPublicFunctions(RHEV_ENABLED):
            setattr(_SuperVdsm, name, logDecorator(bind(func)))

    try:
        log.debug("Making sure I'm root - SuperVdsm")
        if os.geteuid() != 0:
            sys.exit(errno.EPERM)

        if pidfile:
            pid = str(os.getpid())
            with open(pidfile, 'w') as f:
                f.write(pid + "\n")

        log.debug("Parsing cmd args")
        address = sockfile

        log.debug("Cleaning old socket %s", address)
        if os.path.exists(address):
            os.unlink(address)

        log.debug("Setting up keep alive thread")

        try:
            signal.signal(signal.SIGTERM, terminate)
            signal.signal(signal.SIGINT, terminate)

            log.debug("Creating remote object manager")
            manager = _SuperVdsmManager(address=address, authkey='')
            manager.register('instance', callable=_SuperVdsm)

            server = manager.get_server()
            servThread = threading.Thread(target=server.serve_forever)
            servThread.daemon = True
            servThread.start()

            chown(address, getpwnam(VDSM_USER).pw_uid, METADATA_GROUP)

            log.debug("Started serving super vdsm object")

            sourceroutethread.start()

            # terminate() clears _running from the signal handler.
            while _running:
                sigutils.wait_for_signal()

            log.debug("Terminated normally")
        finally:
            if os.path.exists(address):
                utils.rmFile(address)

    except Exception:
        log.error("Could not start Super Vdsm", exc_info=True)
        sys.exit(1)


def _usage():
    print "Usage:  supervdsmServer --sockfile=fullPath [--pidfile=fullPath]"


def _parse_args():
    argDict = {}
    opts, args = getopt.getopt(sys.argv[1:], "h", ["sockfile=", "pidfile="])
    for o, v in opts:
        o = o.lower()
        if o == "--sockfile":
            argDict['sockfile'] = v
        elif o == "--pidfile":
            argDict['pidfile'] = v
        else:
            _usage()
            sys.exit(1)
    if 'sockfile' not in argDict:
        _usage()
        sys.exit(1)

    return argDict


if __name__ == '__main__':
    # Hand the parsed --sockfile/--pidfile options straight to main().
    main(**_parse_args())
