import glob
import logging
import os
import re
import shutil

import imgbased

from subprocess import check_call, check_output
from rpmUtils.miscutils import splitFilename
from imgbased.bootloader import BootConfiguration

# Root logger (getLogger() with no name); the __main__ block attaches its
# stream handler to it.
log = logging.getLogger()

# Bootloader configuration helper; check_new_kernel() reads the default boot
# entry from it via b.get_default().
b = BootConfiguration()
# NOTE(review): this rebinds the name ``imgbased`` from the imported module to
# an Application instance, so the module itself is no longer reachable under
# that name after this line. Functions below rely on the instance
# (imgbased.imgbase.current_layer()).
imgbased = imgbased.Application()


def fix_new_kernel_boot(new_kernel_version):
    """Restore layer kernels in /boot and copy the new kernel into the layer.

    new-kernel-pkg erases our kernels from /boot, so the current layer's
    kernels and initrds are put back for now so virt-v2v and friends still
    work. Every regular file in /boot whose name contains the new kernel's
    version-release is then copied into the current layer's boot directory,
    and the GRUB configuration is regenerated.

    :param new_kernel_version: kernel package version string with the
        leading ``kernel-`` removed, as produced by check_new_kernel()
        (presumably something like ``3.10.0-514.el7.x86_64`` — confirm).
    :raises subprocess.CalledProcessError: if grub2-mkconfig fails.
    """
    current_layer = imgbased.imgbase.current_layer()
    layer_boot = "/boot/{}".format(current_layer)

    # Collect both lists before copying anything; kernels first, then initrds.
    old_files = (glob.glob("{}/vmlinuz*".format(layer_boot)) +
                 glob.glob("{}/init*".format(layer_boot)))
    for path in old_files:
        log.info("Copying %s to %s", path, "/boot/")
        shutil.copy2(path, "/boot")

    # Only the version and release fields are needed to match boot files.
    _, version, release, _, _ = splitFilename(new_kernel_version)
    verrel = "{}-{}".format(version, release)

    for path in glob.glob("/boot/*{}*".format(verrel)):
        # The wildcard can also match directories, which shutil.copy2
        # cannot copy; only regular files belong in the layer directory.
        if not os.path.isfile(path):
            continue
        log.info("Copying %s to %s", path, layer_boot)
        shutil.copy2(path, "{}/".format(layer_boot))

    # Prefer /etc/grub2-efi.cfg when it exists; otherwise write /etc/grub2.cfg.
    if os.path.exists("/etc/grub2-efi.cfg"):
        grub_cfg = "/etc/grub2-efi.cfg"
    else:
        grub_cfg = "/etc/grub2.cfg"
    check_call(["grub2-mkconfig", "-o", grub_cfg])


def _query_kernels(extra_args):
    """Return the kernel versions reported by ``rpm -q [extra_args] kernel``.

    :param extra_args: extra ``rpm`` options placed before the package name
        (e.g. ``["--dbpath", path]``).
    :returns: list of version strings with the leading ``kernel-`` removed,
        one per output line, in the order ``rpm`` printed them.
    :raises subprocess.CalledProcessError: if ``rpm`` exits non-zero.
    """
    output = check_output(["rpm", "-q"] + list(extra_args) + ["kernel"],
                          universal_newlines=True)
    # rpm prints one line per installed package, so several installed
    # kernels produce several lines; each must be handled separately.
    return [re.sub(r'^kernel-', '', line.strip())
            for line in output.splitlines() if line.strip()]


def check_new_kernel():
    """Detect a non-factory kernel that is referenced by the default boot entry.

    Compares the kernels in the live RPM database with those in the factory
    database at /usr/share/factory/var/lib/rpm. A kernel counts as new only
    if it is absent from the factory database and also appears in
    b.get_default(). The second check makes sure the user has not reverted
    the changes from new-kernel-pkg.

    :returns: ``(True, version)`` for the first matching kernel, checking
        from the last one ``rpm`` listed. Otherwise ``(False, version)``,
        where ``version`` is the last kernel ``rpm`` listed (``""`` if none).
    """
    current_kernels = _query_kernels([])
    stock_kernels = _query_kernels(["--dbpath",
                                    "/usr/share/factory/var/lib/rpm"])

    candidates = [k for k in current_kernels if k not in stock_kernels]
    if candidates:
        # Only consult the bootloader when a non-factory kernel exists.
        default_entry = b.get_default()
        for kernel in reversed(candidates):
            if kernel in default_entry:
                return (True, kernel)

    return (False, current_kernels[-1] if current_kernels else "")

if __name__ == "__main__":
    # Emit INFO-and-above records through a stream handler as
    # "[LEVEL] message".
    level = logging.INFO
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
    handler.setLevel(level)
    log.setLevel(level)
    log.addHandler(handler)

    # Only repair /boot when check_new_kernel() reports a new default kernel.
    found, version = check_new_kernel()
    if found:
        fix_new_kernel_boot(version)
