#!/usr/bin/env python

import os
import subprocess
import sys
import re
import argparse
from enum import Enum
from pathlib import Path, PurePosixPath
from natsort import natsorted


class Profile(Enum):
    """
    The hardware profiles known to the tool.

    The member name is the identifier used on the command line (e.g. ``--install nvidia_open``) and when naming
    dracut conf files; the value is a descriptive label.
    """
    # Intel GPUs, chosen by the kernel driver in use (xe or i915)
    intel_xe = "Intel Xe"
    intel_i915 = "Intel i915"
    # Virtual machine guests
    vm_vmware = "VMware"
    vm_virtualbox = "Virtual Box"
    vm_qemu = "QEMU"
    # NVIDIA GPUs
    nvidia_open = "NVIDIA nvidia-open"
    nvidia_580xx = "NVIDIA nvidia 580.xx"
    nvidia_legacy = "NVIDIA nouveau"
    # AMD GPUs, chosen by the kernel driver in use (amdgpu or radeon)
    amd_amdgpu = "AMD amdgpu"
    amd_legacy = "AMD legacy"
    # Fallback profile that any device can use
    generic = "Generic"
    # NOTE(review): "x" has no entry in ProfileData.data, so installing or removing it is a no-op; purpose unclear
    x = "x"


class GPU:
    def __init__(self):
        self.devices = list()
        try:
            result = subprocess.run(["bash", "-c", "lspci -kmmnnv -d '*:*:03xx'"], capture_output=True, text=True)

            if result.returncode == 0:
                in_slot = False
                info = dict()
                for line in result.stdout.split('\n'):
                    str_parts = line.split(':', 1)
                    if len(str_parts) > 1:
                        match str_parts[0].strip():
                            case "Slot":
                                # When we hit slot we are starting a new device
                                if in_slot:
                                    # If we were already reading a device, add it to the device list and clear the dict
                                    self.devices.append(info.copy())
                                    info.clear()
                                else:
                                    in_slot = True
                            case "Device":
                                parts = split_pci_string(str_parts[1].strip())
                                info["device_name"] = parts[0]
                                info["device_id"] = parts[1]
                            case "Vendor":
                                parts = split_pci_string(str_parts[1].strip())
                                info["vendor_name"] = parts[0]
                                info["vendor_id"] = parts[1]
                            case "Driver":
                                info["driver"] = str_parts[1].strip()
                            case "Module":
                                info["module"] = str_parts[1].strip()

                if in_slot:
                    # Unless there were no devices, there should be one device that needs to be added to the list
                    self.devices.append(info.copy())

                return
        except:
            # We can ignore any exceptions because after this point we are going to report failure
            pass

        print('Failed to get data PCI data', file=sys.stderr)
        sys.exit(1)


def split_pci_string(input_string):
    """
    Split an ``lspci -nn`` style "Name [id]" string into the name and the id from the trailing brackets.

    Only the last bracketed group counts as the id, so names that contain brackets themselves stay intact,
    e.g. "TU104 [GeForce RTX 2080] [1e82]" -> ("TU104 [GeForce RTX 2080]", "1e82").

    :param input_string: The string to split
    :return: A (name, id) tuple. The name has surrounding whitespace removed; id is '' when the string does not
             end with a bracketed id
    """
    text = input_string.strip()
    pos = text.rfind("[")

    if pos == -1 or not text.endswith("]"):
        # No trailing "[id]" group, so the whole string is the name
        return text, ''

    # rstrip drops the space that separates the name from "[id]"
    return text[:pos].rstrip(), text[pos + 1:-1]


class ProfileData:
    """
    Static table describing what each hardware profile installs and configures.

    Each entry maps a Profile to:
      packages      - packages installed for the profile
      packages32    - 32-bit (lib32) packages, added unless no32 is requested
      services      - systemd units enabled/disabled together with the profile
      dracut_config - dracut option line for the profile ('' when the profile needs none)

    The table is shared class state; the accessors return copies so callers can never modify it.
    """
    data = {
        Profile.intel_xe: {
            'packages': {"mesa", "vulkan-intel", "intel-media-driver", "gst-plugin-va", "vpl-gpu-rt"},
            'packages32': {"lib32-mesa", "lib32-vulkan-intel"},
            'services': set(),
            'dracut_config': 'force_drivers+=" xe "'},
        Profile.intel_i915: {
            'packages': {"mesa", "vulkan-intel", "intel-media-driver", "gst-plugin-va", "libva-intel-driver"},
            'packages32': {"lib32-mesa", "lib32-vulkan-intel", "lib32-libva-intel-driver"},
            'services': set(),
            'dracut_config': 'force_drivers+=" i915 "'},
        Profile.vm_vmware: {
            'packages': {"mesa", "open-vm-tools", "xf86-input-vmmouse", "gtkmm3", "vulkan-swrast"},
            'packages32': {"lib32-mesa"},
            'services': {'vmtoolsd.service', 'vmware-vmblock-fuse.service'},
            'dracut_config': ''},
        Profile.vm_virtualbox: {
            'packages': {"mesa", "virtualbox-guest-utils", "vulkan-swrast"},
            'packages32': {"lib32-mesa"},
            'services': {'vboxservice.service'},
            'dracut_config': ''},
        Profile.vm_qemu: {
            'packages': {"mesa", "spice-vdagent", "qemu-guest-agent", "vulkan-virtio"},
            'packages32': {"lib32-mesa"},
            'services': set(),
            'dracut_config': ''},
        Profile.nvidia_open: {
            'packages': {"nvidia-open-dkms", "nvidia-utils", "egl-wayland", "nvidia-settings", "opencl-nvidia",
                         "libva-nvidia-driver", "libxnvctrl"},
            'packages32': {"lib32-nvidia-utils", "lib32-opencl-nvidia"},
            'services': set(),
            'dracut_config': 'force_drivers+=" nvidia nvidia_modeset nvidia_uvm nvidia_drm "'},
        # NOTE(review): this entry is identical to nvidia_legacy (nouveau packages and dracut config); confirm
        # whether the 580.xx driver packages should be listed here instead
        Profile.nvidia_580xx: {
            'packages': {"mesa", "vulkan-nouveau"},
            'packages32': {"lib32-mesa", "lib32-vulkan-nouveau"},
            'services': set(),
            'dracut_config': 'force_drivers+=" nouveau "'},
        Profile.nvidia_legacy: {
            'packages': {"mesa", "vulkan-nouveau"},
            'packages32': {"lib32-mesa", "lib32-vulkan-nouveau"},
            'services': set(),
            'dracut_config': 'force_drivers+=" nouveau "'},
        Profile.amd_amdgpu: {
            'packages': {"mesa", "vulkan-radeon", "gst-plugin-va"},
            'packages32': {"lib32-mesa", "lib32-vulkan-radeon"},
            'services': set(),
            'dracut_config': 'force_drivers+=" amdgpu "'},
        Profile.amd_legacy: {
            'packages': {"mesa", "vulkan-swrast"},
            'packages32': {"lib32-mesa"},
            'services': set(),
            'dracut_config': ''},
        Profile.generic: {
            'packages': {"mesa", "vulkan-swrast"},
            'packages32': {"lib32-mesa"},
            'services': set(),
            'dracut_config': ''}
    }

    @staticmethod
    def get_packages(profile, no32, is_iso):
        """
        Return the packages to install for a profile.

        :param profile: A Profile member
        :param no32: When True, the 32-bit packages are left out
        :param is_iso: When True, nvidia-open-dkms is replaced by nvidia-open (plus nvidia-open-lts when
                       linux-lts is installed)
        :return: A new set of package names, or None if the profile has no entry in the table
        """
        if profile not in ProfileData.data:
            return None

        profile_data = ProfileData.data[profile]
        # Copy, so neither the ISO substitution below nor callers can modify the shared table
        packages = set(profile_data["packages"])
        if is_iso and "nvidia-open-dkms" in packages:
            packages.discard("nvidia-open-dkms")
            packages.add("nvidia-open")
            if is_installed('linux-lts'):
                packages.add('nvidia-open-lts')
        if not no32:
            packages |= profile_data["packages32"]
        return packages

    @staticmethod
    def get_dracut_config(profile):
        """
        :param profile: A Profile member
        :return: The dracut option line for the profile ('' when it has none), or None for an unknown profile
        """
        if profile in ProfileData.data:
            return ProfileData.data[profile]["dracut_config"]
        else:
            return None

    @staticmethod
    def get_services(profile):
        """
        :param profile: A Profile member
        :return: A new set of systemd unit names for the profile, or None for an unknown profile
        """
        if profile in ProfileData.data:
            return set(ProfileData.data[profile]["services"])
        else:
            return None


def get_vm():
    """
    Detects if we are running in a virtual machine

    :return: A list of device dicts
    """
    try:
        device = dict()
        match subprocess.run(["systemd-detect-virt"], text=True, capture_output=True).stdout.strip():
            case "vmware":
                device["name"] = "Broadcom VMware"
                device["profiles"] = [Profile.vm_vmware]
            case "qemu" | "kvm":
                device["name"] = "Qemu"
                device["profiles"] = [Profile.vm_qemu]
            case "oracle":
                device["name"] = "Oracle VirtualBox"
                device["profiles"] = [Profile.vm_virtualbox]
            case "none":
                return list()
            case _:
                device["name"] = "Other VM guest"
                device["profiles"] = [Profile.generic]
    except subprocess.SubprocessError:
        return list()

    return [device]


def manage_services(profile, action):
    """
    Enable or disable (with --now) the systemd services that belong to a profile.

    Failures are reported as warnings; they do not stop processing of the remaining services.

    :param profile: The profile whose services should be managed
    :param action: The systemctl action to take, either 'enable' or 'disable'
    :return: N/A
    """
    progress_labels = {'enable': 'Enabling', 'disable': 'Disabling'}
    if action not in progress_labels:
        print(f'ERROR: action {action} is not supported', file=sys.stderr)
        return

    services = ProfileData.get_services(profile)
    if not services:
        return

    # Sorted so that systemctl is called in a deterministic order
    for service in sorted(services):
        print(f'INFO: {progress_labels[action]} {service}')
        try:
            result = subprocess.run(['systemctl', action, '--now', service])
        except OSError:
            # systemctl could not be started at all
            result = None

        if result is None or result.returncode != 0:
            print(f'WARNING: Failed to {action} {service}, please {action} it manually', file=sys.stderr)


def create_dracut_conf(profile):
    """
    Write the profile's dracut options to /etc/dracut.conf.d/eos_<profile name>.conf, replacing any existing file.

    Nothing is written when the profile has no dracut options or when the dracut configuration directory is
    missing. Write errors are reported as a warning.

    :param profile: The profile whose dracut options should be written
    :return: True if the conf file was written, False otherwise
    """
    dracut_options = ProfileData.get_dracut_config(profile)
    if not dracut_options:
        return False

    dracut_dir = '/etc/dracut.conf.d'
    conf_path = PurePosixPath(dracut_dir) / f'eos_{profile.name}.conf'

    try:
        # The directory is checked, not created
        if not Path(dracut_dir).is_dir():
            print("Dracut configuration directory doesn't exist, skipping dracut config")
            return False

        print(f'Installing dracut conf to {conf_path}')
        # Mode "w" truncates, so a conf left by an earlier run is overwritten
        with open(conf_path, "w", encoding="utf-8") as conf_file:
            conf_file.write(f'{dracut_options}\n')
    except OSError:
        print(f'Warning: Failed to write dracut config to {conf_path}', file=sys.stderr)
        return False

    return True


def remove_dracut_conf(profile):
    """
    Delete /etc/dracut.conf.d/eos_<profile name>.conf if it exists.

    When the file is missing but the profile does define dracut options, an informational message is printed.

    :param profile: The profile whose conf file should be removed
    :return: True if the file was deleted, False otherwise
    """
    conf_name = PurePosixPath('/etc/dracut.conf.d') / f'eos_{profile.name}.conf'
    conf_file = Path(conf_name)

    if not conf_file.exists():
        # Only mention the missing file when this profile would actually have created one
        if ProfileData.get_dracut_config(profile):
            print(f'INFO: No dracut conf file found to remove for profile: {profile.name}')
        return False

    try:
        print(f'Removing dracut conf at {conf_name}')
        conf_file.unlink()
    except OSError:
        print(f'Failed to delete dracut config at {conf_name}. Please remove it manually', file=sys.stderr)
        return False

    return True


def get_nvidia_profile(device_info):
    """
    Find the optimal NVIDIA profile for a device from the chip-family prefix of its lspci device name
    (e.g. "TU104 [GeForce RTX 2080]" -> "TU").

    :param device_info: A device dict as produced by GPU; only "device_name" is used
    :return: A Profile enum member
    """
    device_name = device_info.get("device_name", "").strip()
    # The leading letters of the chip name identify the architecture family; normalise case for the lookup
    family = re.match(r"[A-Z]*", device_name, flags=re.IGNORECASE).group().upper()

    # Turing and newer (Turing, Ampere, Hopper, Ada Lovelace, Blackwell) use the open kernel modules
    if family in {"TU", "GA", "GH", "AD", "GB"}:
        return Profile.nvidia_open
    # Maxwell, Pascal and Volta are covered by the 580.xx driver branch
    elif family in {"GM", "GP", "GV"}:
        return Profile.nvidia_580xx
    else:
        # NOTE(review): when pci.ids does not know the chip, lspci may print a generic name without a family
        # prefix, which also ends up here
        return Profile.nvidia_legacy


def get_file_list(package_list, package_directory):
    """
    Find the newest package file for each requested package in a directory of pacman package files.

    Packages without a matching file are reported on stderr and skipped.

    :param package_list: A list of strings containing package names
    :param package_directory: A filesystem path which holds pacman package files
    :return: A list of package file paths (one per package found), or None if package_directory is not a directory
    """
    # Matches <name>-<version>-<release>-<arch>.pkg.tar.<ext>
    package_re = re.compile(
        r"""^
        (?P<name>.+)          # package name (greedy)
        -
        (?P<version>[^-]+)    # version (no dashes)
        -
        (?P<release>[^-]+)    # release (no dashes)
        -
        (?P<arch>[^.]+)       # arch up to the first dot
        \.pkg\.tar\.[^.]+$    # .pkg.tar.<ext>
        """,
        re.X,
    )

    path = Path(package_directory)
    if not path.is_dir():
        print(f"ERROR: Invalid directory {package_directory}", file=sys.stderr)
        return None

    file_list = list()
    for package in package_list:
        # Reverse natural sort so that, by file name, the highest version is tried first
        candidates = natsorted(path.glob(package + "-*"), reverse=True)
        for candidate in candidates:
            match = package_re.match(candidate.name)
            # The glob also matches other packages whose names start with this one (e.g. "mesa-utils" for "mesa"),
            # so require an exact name match
            if match and match.group("name") == package:
                # Keep the directory in the path so pacman -U does not depend on the current working directory
                file_list.append(str(candidate))
                break
        else:
            print(f"WARNING: No package file found for {package} in {package_directory}", file=sys.stderr)

    return file_list


def install_packages(package_list, is_iso, package_directory):
    """
    Install the listed packages with pacman.

    If package_directory is set, the newest matching package files in that directory are installed with
    pacman -U instead of installing from the repos.

    :param package_list: A list of strings containing package names
    :param is_iso: Boolean. When true, repo installs use --noconfirm and plain -S since we are running from the iso
    :param package_directory: A filesystem path which holds pacman package files, or a falsy value to use the repos
    :return: True on success
    """
    if package_directory:
        # We are installing from a directory
        filename_list = get_file_list(package_list, package_directory)
        if not filename_list:
            return False
        # NOTE(review): --noconfirm is not passed here even when is_iso is set (an unconditional
        # "--noconfirm -U" variant used to be commented out at this spot); confirm this is intended
        command = ["pacman", "-U"] + filename_list
    elif is_iso:
        command = ["pacman", "--needed", "--noconfirm", "-S"] + package_list
    else:
        # -Syu refreshes the package databases and upgrades the system together with the install
        command = ["pacman", "--needed", "-Syu"] + package_list

    try:
        result = subprocess.run(command)
    except OSError as error:
        # pacman could not be started at all
        print(f'Pacman failed with error {error}', file=sys.stderr)
        return False

    return result.returncode == 0


def is_installed(package):
    """
    Use pacman to determine whether the given package is installed.

    :param package: The name of a package to check (surrounding whitespace is ignored)
    :return: True if the package is installed; False if not, if the name is empty, or if pacman cannot be run
    """
    name = package.strip() if package else ''
    if not name:
        return False

    try:
        result = subprocess.run(["pacman", "-Qq", name], capture_output=True, text=True)
    except OSError:
        # pacman is not available, so nothing can be reported as installed
        return False

    return result.returncode == 0


def remove_packages(package_list, is_iso):
    """
    Use pacman to remove the listed packages.

    Packages that are not installed, and packages on the internal removal blacklist, are skipped.

    :param package_list: A list of strings containing package names
    :param is_iso: Boolean, True when the `--iso` argument is passed; adds --noconfirm
    :return: True on success, including when there is nothing to remove
    """
    # Some packages are needed for other purposes and should not be removed along with the drivers
    removal_blacklist = {"mesa", "lib32-mesa", "gtkmm3"}

    # Check the blacklist first so pacman is only queried for packages that may actually be removed
    filtered_package_list = [package for package in package_list
                             if package not in removal_blacklist and is_installed(package)]

    if not filtered_package_list:
        print('INFO: No packages need to be removed')
        return True

    # NOTE: -Rc (--cascade) also removes every installed package that depends on the targets
    command = ["pacman", "-Rc"] + filtered_package_list
    if is_iso:
        command.insert(1, "--noconfirm")

    try:
        result = subprocess.run(command)
    except OSError as error:
        # pacman could not be started at all
        print(f'Pacman failed with error {error}', file=sys.stderr)
        return False

    return result.returncode == 0


def get_device_info(input_gpu_info):
    """
    Populates the device structure

    :param input_gpu_info: An object containing all the PCI data for GPUs
    :return: A list of dicts containing the profile information for each device
    """
    device_list = list()
    # Iterate over the devices
    for gpu_device in input_gpu_info.devices:
        device = dict()
        device["name"] = gpu_device["vendor_name"] + " - " + gpu_device["device_name"]
        device["profiles"] = set()

        # The best match is the driver but if that is not available, try to match using the module name
        if "driver" in gpu_device:
            match_string = gpu_device["driver"]
        elif "module" in gpu_device:
            match_string = gpu_device["module"]
        else:
            match_string = "none"


        match match_string:
            case "amdgpu":
                device["profiles"].add(Profile.amd_amdgpu)
            case "radeon":
                device["profiles"].add(Profile.amd_legacy)
            case "nouveau" | "nvidia":

                nvidia_profile = get_nvidia_profile(gpu_device)
                if nvidia_profile in (Profile.nvidia_open, Profile.nvidia_580xx):
                    if not args.free:
                        device["profiles"].add(nvidia_profile)

                    if args.free or not only_recommended:
                        device["profiles"].add(Profile.nvidia_legacy)
                else:
                    device["profiles"].add(Profile.nvidia_legacy)
            case "xe":
                device["profiles"].add(Profile.intel_xe)
            case "i915":
                device["profiles"].add(Profile.intel_i915)
            case _:
                device["profiles"].add(Profile.generic)

        # Every device can use the generic profile
        if not only_recommended:
            device["profiles"].add(Profile.generic)

        device_list.append(device)

    return device_list

def do_install(arg_list, device_list):
    """
    Install the requested profiles, plus the recommended profiles when --install-recommended was selected.

    With --nvidia-only, every profile except nvidia_open and nvidia_580xx is skipped. After a successful package
    install, the selected profiles' services are enabled and their dracut configs are written. If any dracut config
    was written, the initramfs is rebuilt.

    :param arg_list: A Namespace object from parse_args
    :param device_list: A list of the profiles to be considered for each GPU device
    :return: N/A
    """
    profile_set = set()

    if arg_list.install:
        # When using append, we end up with an array of arrays in the arg list
        for inner_array in arg_list.install:
            for profile_name in inner_array:
                # Look up by member name; unlike getattr this rejects non-member attributes such as "mro"
                try:
                    profile_set.add(Profile[profile_name])
                except KeyError:
                    print(f'{profile_name} is not a valid profile name', file=sys.stderr)
                    sys.exit(1)

    if arg_list.install_recommended:
        for device in device_list:
            profile_set.update(device["profiles"])

    selected_profiles = list()
    package_set = set()
    for profile in profile_set:
        # If this is an nvidia only install, skip everything except the proprietary nvidia drivers
        if arg_list.nvidia_only and profile not in (Profile.nvidia_open, Profile.nvidia_580xx):
            continue

        selected_profiles.append(profile)
        packages = ProfileData.get_packages(profile, arg_list.no32, arg_list.iso)
        if packages:
            package_set.update(packages)

    if not package_set:
        return

    if not install_packages(list(package_set), arg_list.iso, arg_list.packagedir):
        return

    # Enable services and create dracut configs only for the profiles whose packages were installed
    is_dracut_conf_created = False
    for profile in selected_profiles:
        manage_services(profile, 'enable')

        if create_dracut_conf(profile):
            is_dracut_conf_created = True

    if not is_dracut_conf_created:
        return

    if not Path('/usr/bin/dracut-rebuild').is_file():
        print('Warning: Cannot rebuild initramfs, please rebuild manually', file=sys.stderr)
        return

    try:
        rebuilt = subprocess.run(["dracut-rebuild"], capture_output=False, text=True).returncode == 0
    except OSError:
        rebuilt = False

    if not rebuilt:
        print('Warning: Rebuilding initramfs failed, please rebuild manually', file=sys.stderr)



def do_remove(arg_list):
    """
    Remove the requested profiles.

    The profiles' services are disabled first, while their unit files are still installed. Then the packages are
    removed. If removal fails, the services are re-enabled. On success, the profiles' dracut configs are deleted,
    and the initramfs is rebuilt if any dracut config was removed.

    :param arg_list: A Namespace object from parse_args
    :return: N/A
    """
    profile_list = list()

    # When using append, we end up with an array of arrays in the arg list
    for inner_array in arg_list.remove:
        for profile_name in inner_array:
            # Look up by member name; unlike getattr this rejects non-member attributes such as "mro"
            try:
                profile_list.append(Profile[profile_name])
            except KeyError:
                print(f'{profile_name} is not a valid profile name', file=sys.stderr)
                sys.exit(1)

    package_set = set()
    for profile in profile_list:
        packages = ProfileData.get_packages(profile, arg_list.no32, arg_list.iso)
        if packages:
            package_set.update(packages)

    # Disable services before their packages (and unit files) are removed
    for profile in profile_list:
        manage_services(profile, 'disable')

    if not remove_packages(list(package_set), arg_list.iso):
        # Removal failed or was declined; restore the services that were just disabled
        for profile in profile_list:
            manage_services(profile, 'enable')
        return

    # Package removal succeeded, so remove the dracut configs and recreate the initramfs
    is_dracut_conf_removed = False
    for profile in profile_list:
        if remove_dracut_conf(profile):
            is_dracut_conf_removed = True

    if not is_dracut_conf_removed:
        return

    if not Path('/usr/bin/dracut-rebuild').is_file():
        print('Warning: Cannot rebuild initramfs, please rebuild manually', file=sys.stderr)
        return

    try:
        rebuilt = subprocess.run(["dracut-rebuild"], capture_output=False, text=True).returncode == 0
    except OSError:
        rebuilt = False

    if not rebuilt:
        print('Warning: Rebuilding initramfs failed, please rebuild manually', file=sys.stderr)


def do_purge(arg_list, device_list):
    """
    Remove every package that belongs to some profile but not to any profile listed for the detected devices.

    Intended for use only in the installer; the option is hidden from --help.

    :param arg_list: A Namespace object from parse_args (no32 and iso are used)
    :param device_list: A list of device dicts whose "profiles" are the profiles to keep
    :return: N/A
    """
    def profile_packages(profile):
        # get_packages returns None for profiles without data; treat that as "no packages"
        return ProfileData.get_packages(profile, arg_list.no32, arg_list.iso) or ()

    # Packages needed by the devices' profiles are always kept
    keep = set()
    for device in device_list:
        for profile in device["profiles"]:
            keep.update(profile_packages(profile))

    # Any other package that some profile would install is a purge candidate
    purge = {package for profile in Profile for package in profile_packages(profile) if package not in keep}

    remove_packages(list(purge), arg_list.iso)


def do_enable_services(device_list):
    """
    Enable the services of the profiles listed for each device. Used by the ISO.

    Devices without a "profiles" entry, or with an empty one, are skipped. The device data is not modified.

    :param device_list: A list of the profiles to be considered for each GPU device
    :return: N/A
    """
    for device in device_list:
        # Iterate instead of pop() so the caller's device data is left intact
        for profile in device.get("profiles") or ():
            manage_services(profile, "enable")


def do_check_nvidia(device_list):
    """
    Print which NVIDIA driver should be used: "nouveau", "nvidia-580xx", "nvidia-open" or "none".

    When several NVIDIA profiles are present across the devices, the first in the order nouveau, nvidia-580xx,
    nvidia-open wins.

    :param device_list: A list of the profiles to be considered for each GPU device
    :return: N/A
    """
    present = {profile for device in device_list for profile in device["profiles"]}

    # Checked in priority order; the first profile found decides the output
    priority = (
        (Profile.nvidia_legacy, "nouveau"),
        (Profile.nvidia_580xx, "nvidia-580xx"),
        (Profile.nvidia_open, "nvidia-open"),
    )
    print(next((label for profile, label in priority if profile in present), "none"))


if __name__ == '__main__':
    # Set up the available command line arguments
    parser = argparse.ArgumentParser(description='The EndeavourOS hardware management tool is used to manage VM and GPU '
                                                 'drivers.',
                                     epilog='example: eos-hwtool --install nvidia_open'
                                     )
    parser.add_argument('--list', action="store_true", help='Show all the profiles valid for your devices')
    parser.add_argument('--free', action="store_true",
                        help='Exclude proprietary drivers (currently only nvidia-open)')
    parser.add_argument('--no32', action="store_true", help='Exclude 32-bit libraries')
    parser.add_argument('--recommended', action="store_true", help='Show the recommended profiles for each device')
    parser.add_argument('--check-nvidia', action="store_true",
                        help='Print which NVIDIA driver (nouveau, nvidia-580xx, nvidia-open or none) should be '
                             'loaded on the ISO')
    parser.add_argument('--install-recommended', action="store_true",
                        help='Install or repair the recommended profiles for your devices')
    parser.add_argument('--install', '--repair', type=str, nargs='*', action="append", metavar='Profile',
                        help='Install or repair the selected profile')
    parser.add_argument('--remove', type=str, nargs='*', action="append", metavar='Profile',
                        help='Remove the selected profile')
    # Hidden options (not shown in --help)
    parser.add_argument('--purge', action="store_true", help=argparse.SUPPRESS)
    parser.add_argument('--iso', action="store_true", help=argparse.SUPPRESS)
    parser.add_argument('--nvidia-only', action="store_true", help=argparse.SUPPRESS)
    parser.add_argument('--enable-services', action="store_true", help=argparse.SUPPRESS)
    parser.add_argument('--packagedir', type=str, metavar='Path', help=argparse.SUPPRESS)
    # Running without any arguments behaves like --recommended
    args = parser.parse_args(args=None if sys.argv[1:] else ['--recommended'])

    # We use only_recommended to limit the profile selection in various places to only the best profile for the card.
    # Note: get_device_info reads args and only_recommended as module-level globals.
    only_recommended = (args.recommended or args.install_recommended or args.purge or args.check_nvidia
                        or args.enable_services)

    # First we check if we are in a VM
    devices = get_vm()

    # If we found a VM, we can stop looking for other devices
    if not devices:
        # Scan the PCI bus for GPUs
        gpu_info = GPU()

        # There really should be at least one GPU
        if not gpu_info.devices:
            print("No devices found", file=sys.stderr)
            sys.exit(1)

        devices = get_device_info(gpu_info)

    if args.check_nvidia:
        do_check_nvidia(devices)
        sys.exit(0)

    if args.list or args.recommended:
        for device_item in devices:
            print(f'Device: {device_item["name"]}')
            print('Profiles:')
            for profile_item in device_item["profiles"]:
                print(f'\t{profile_item.name}')

        sys.exit(0)

    # Operations below this point require root privileges
    if os.getuid() != 0:
        print('This operation requires root privileges, please re-run with sudo or as root', file=sys.stderr)
        sys.exit(1)

    if args.install or args.install_recommended:
        do_install(args, devices)
        sys.exit(0)

    if args.remove:
        do_remove(args)
        sys.exit(0)

    if args.purge:
        do_purge(args, devices)
        sys.exit(0)

    if args.enable_services:
        do_enable_services(devices)
        sys.exit(0)

    sys.exit(0)

