import yaml
from termcolor import colored
import subprocess
from pyfra import *
import time
import argparse
from enum import Enum, unique

@unique
class GPU(Enum):
    """GPU classes that ``KubeConfig.set_gpu`` accepts in place of a raw string.

    ``set_gpu`` uses the member *name* (e.g. ``"RTX_A6000"``) as the value of
    the ``gpu.nvidia.com/class`` node-selector expression.

    The values look like per-hour prices — NOTE(review): confirm before relying
    on them. Some values carry tiny offsets (2.06000001, 2.06000002, 1.2801)
    purely to keep them distinct: Enum turns members with equal values into
    aliases of the first such member, which would make e.g.
    ``GPU.A100_PCIE_40GB.name`` return ``"A100_NVLINK"``. ``@unique`` turns any
    such accidental alias into an import-time ``ValueError``.
    """

    A100_NVLINK = 2.06
    A100_PCIE_40GB = 2.06000001  # offset keeps it distinct from A100_NVLINK
    A100_PCIE_80GB = 2.06000002  # offset keeps it distinct from A100_NVLINK
    Tesla_V100_NVLINK = 0.80
    A40 = 1.2801  # offset keeps it distinct from RTX_A6000
    Quadro_RTX_5000 = 0.57
    Quadro_RTX_4000 = 0.24
    GeForce_RTX_2080_Ti = 0.40
    Tesla_V100 = 0.47
    RTX_A6000 = 1.28
    RTX_A5000 = 0.77
    RTX_A4000 = 0.61

class KubeConfig():
    """Build, inspect and apply a Kubernetes Deployment plus Service for an SSH pod.

    ``self.config`` holds the Deployment manifest and ``self.serviceconfig`` the
    Service manifest. Both are plain dicts that are serialised with ``yaml.dump``
    and piped to ``kubectl apply``.

    The convenience attributes (``spec``, ``template``, ``temp_spec``,
    ``containers``, ``first_container``, ``resources``, ``limits``, ``ports``,
    ``volume_mounts``, ``volumes``, ``spec_match_expressions``, ``env``) alias
    sections of ``self.config``. The setters mutate them in place, so they edit
    the manifest that will actually be applied.
    """

    # Node-selector key that pins the GPU class in the node affinity.
    _GPU_CLASS_KEY = "gpu.nvidia.com/class"

    # (mountPath, PVC name) pairs mounted when ``cw`` is enabled; the PVC name
    # is also used as the pod volume name.
    _CW_MOUNTS = (
        ("/home/xuser/pssd", "jlab-ssd"),
        ("/home/xuser/bigssd", "jlab-bigssd"),
        ("/home/xuser/hugessd", "jlab-hugessd"),
        ("/home/xuser/models", "model-storage2"),
        ("/home/xuser/prodbigmodels", "prod-bigmodels"),
        ("/home/xuser/diffusionstorage", "permstorage"),
    )

    def __init__(self, load_default_presets=True, cw=True) -> None:
        """Create a config, optionally pre-populated with the default presets.

        Args:
            load_default_presets: build the default Deployment and the ClusterIP
                Service.
            cw: in the default Deployment, mount the shared PVCs listed in
                ``_CW_MOUNTS`` and add the GPU-class / region node affinity.
        """
        self.config = None
        self.serviceconfig = None
        self.dry = False
        self.cw = cw
        self.name = None
        if load_default_presets:
            self.init_default_deploy_config()
            self.init_cluster_ip_config()

    # ------------------------------------------------------------------ helpers

    def _match_expressions(self):
        """Return the first nodeSelectorTerm's ``matchExpressions`` list in ``self.config``.

        The returned list is the one stored inside the manifest. Returns ``None``
        when the deployment has no required node affinity.
        """
        try:
            affinity = self.config["spec"]["template"]["spec"]["affinity"]
            terms = affinity["nodeAffinity"]["requiredDuringSchedulingIgnoredDuringExecution"]["nodeSelectorTerms"]
            return terms[0]["matchExpressions"]
        except (KeyError, IndexError, TypeError):
            return None

    def _gpu_class(self):
        """Return the pinned GPU class (first value of the GPU-class expression), or ``None``."""
        for expression in self._match_expressions() or []:
            if expression.get("key") == self._GPU_CLASS_KEY and expression.get("values"):
                return expression["values"][0]
        return None

    def _bind_references(self):
        """Re-point the convenience attributes at the sections of ``self.config``.

        This must run whenever ``self.config`` is built or replaced. Otherwise
        setters such as ``set_cpu`` or ``add_volume`` would edit a stale dict
        that is never applied. Missing sections are created empty with
        ``setdefault`` so the setters always have somewhere to write.
        """
        config = self.config
        self.spec = config.setdefault("spec", {})
        self.template = self.spec.setdefault("template", {})
        self.temp_spec = self.template.setdefault("spec", {})
        self.containers = self.temp_spec.setdefault("containers", [])
        if not self.containers:
            self.containers.append({})
        self.first_container = self.containers[0]
        self.resources = self.first_container.setdefault("resources", {})
        self.limits = self.resources.setdefault("limits", {})
        self.ports = self.first_container.setdefault("ports", [])
        self.volume_mounts = self.first_container.setdefault("volumeMounts", [])
        self.volumes = self.temp_spec.setdefault("volumes", [])
        self.env = self.first_container.get("env")
        expressions = self._match_expressions()
        # Without a node affinity there is nothing to edit; keep an empty
        # (detached) list so callers iterating it still work.
        self.spec_match_expressions = expressions if expressions is not None else []
        self.name = config.get("metadata", {}).get("name")

    @staticmethod
    def _exists(kind, resource_name):
        """Return True if ``kubectl get <kind> <resource_name>`` finds the resource.

        An exact name lookup is used rather than grepping ``kubectl get`` output,
        so ``foo`` is not reported as existing just because ``foo-2`` exists.
        Failure to run kubectl at all is treated as "does not exist".
        """
        try:
            result = subprocess.run(['kubectl', 'get', kind, resource_name],
                                    stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        except OSError:
            return False
        return result.returncode == 0

    def _kubectl_apply(self, manifest):
        """Pipe ``manifest`` as YAML to ``kubectl apply -f -`` and print kubectl's stdout.

        In dry-run mode ``--dry-run=server`` is added, which asks the API server
        to process the request without persisting it.
        """
        command = ['kubectl', 'apply']
        if self.dry:
            command.append('--dry-run=server')
        command += ['-f', '-']
        result = subprocess.run(command, input=yaml.dump(manifest), encoding='utf-8', stdout=subprocess.PIPE)
        print(result.stdout)

    def _apply_unless_exists(self, kind, resource_name, manifest, overwrite, label, color):
        """Apply ``manifest`` unless ``resource_name`` already exists and ``overwrite`` is false."""
        if not overwrite and self._exists(kind, resource_name):
            print(colored(f"{label} already exists, use overwrite=True to overwrite.", "red"))
            return
        print(colored(f"Applying the {label.lower()}...", color))
        self._kubectl_apply(manifest)

    # --------------------------------------------------------- manifest builders

    def init_default_deploy_config(self, name="default-name", rep_amount=1, max_surge="100%", max_unavailable="100%", default_settings=True):
        """Build the default SSH Deployment manifest and store it in ``self.config``.

        The single container runs ``/usr/sbin/sshd -D`` from
        ``novelai/kube-ssh:13`` and exposes port 22 as ``sshd``. Its limits are
        4 CPUs, 16Gi memory, 100Gi ephemeral storage and one GPU, and it has a
        memory-backed ``/dev/shm``. When ``self.cw`` is true it also mounts the
        PVCs in ``_CW_MOUNTS`` and requires ``RTX_A6000`` nodes in region
        ``ORD1``.

        Args:
            name: deployment name, container name and ``app.kubernetes.io/name``
                label value.
            rep_amount: number of replicas.
            max_surge: rolling-update ``maxSurge``.
            max_unavailable: rolling-update ``maxUnavailable``.
            default_settings: accepted for compatibility; currently ignored.
        """
        # NOTE(review): the original author marked this method "not working yet";
        # ``default_settings`` has never been used.
        volume_mounts = []
        volumes = []
        if self.cw:
            for mount_path, claim_name in self._CW_MOUNTS:
                volume_mounts.append({"mountPath": mount_path, "name": claim_name})
                volumes.append({"name": claim_name, "persistentVolumeClaim": {"claimName": claim_name}})
        # /dev/shm backed by an in-memory emptyDir volume.
        volume_mounts.append({"mountPath": "/dev/shm", "name": "dshm"})
        volumes.append({"name": "dshm", "emptyDir": {"medium": "Memory"}})

        container = {
            "name": name,
            "image": "novelai/kube-ssh:13",
            "imagePullPolicy": "Always",
            "command": ['/usr/sbin/sshd'],
            "args": ['-D'],
            "tty": True,
            "ports": [{'containerPort': 22, 'name': 'sshd', 'protocol': 'TCP'}],
            "resources": {
                "limits": {
                    "ephemeral-storage": "100Gi",
                    "cpu": "4",
                    "memory": "16Gi",
                    "nvidia.com/gpu": "1",
                },
            },
            "volumeMounts": volume_mounts,
        }

        pod_spec = {
            "containers": [container],
            "dnsPolicy": "ClusterFirst",
            "imagePullSecrets": [{"name": "regcred"}],
            "restartPolicy": "Always",
            "schedulerName": "default-scheduler",
            "terminationGracePeriodSeconds": 10,
            "volumes": volumes,
        }
        if self.cw:
            pod_spec["affinity"] = {
                "nodeAffinity": {
                    "requiredDuringSchedulingIgnoredDuringExecution": {
                        "nodeSelectorTerms": [{
                            "matchExpressions": [
                                {'key': self._GPU_CLASS_KEY, 'operator': 'In', 'values': ['RTX_A6000']},
                                {'key': 'topology.kubernetes.io/region', 'operator': 'In', 'values': ['ORD1']},
                            ],
                        }],
                    },
                },
            }

        # The two label dicts are deliberately separate objects: yaml.dump emits
        # &anchor/*alias references for a dict that appears twice.
        self.config = {
            "apiVersion": "apps/v1",
            "kind": "Deployment",
            "metadata": {"name": name},
            "spec": {
                "replicas": rep_amount,
                "selector": {"matchLabels": {"app.kubernetes.io/name": name}},
                "strategy": {
                    "type": "RollingUpdate",
                    "rollingUpdate": {"maxSurge": max_surge, "maxUnavailable": max_unavailable},
                },
                "template": {
                    "metadata": {"labels": {"app.kubernetes.io/name": name}},
                    "spec": pod_spec,
                },
            },
        }
        self._bind_references()

    def init_default_service_config(self):
        """Store a MetalLB ``LoadBalancer`` Service template (SSH on port 22) in ``self.serviceconfig``.

        The ``deploy-name`` placeholders are filled in by ``create_service``.
        """
        self.serviceconfig = {
            "apiVersion": "v1",
            "kind": "Service",
            "metadata": {
                "name": "deploy-name-service",
                "annotations": {
                    "metallb.universe.tf/address-pool": "public-ord1",
                    "metallb.universe.tf/allow-shared-ip": "deploy-name",
                },
            },
            "spec": {
                "type": "LoadBalancer",
                "externalTrafficPolicy": "Local",
                "selector": {"app.kubernetes.io/name": "deploy-name"},
                "ports": [{"name": "sshd", "port": 22, "protocol": "TCP", "targetPort": "sshd"}],
            },
        }

    def init_cluster_ip_config(self):
        """Store a ``ClusterIP`` Service template (SSH on port 22) in ``self.serviceconfig``.

        The ``deploy-name`` placeholders are filled in by ``create_cluster_ip``.
        """
        self.serviceconfig = {
            "apiVersion": "v1",
            "kind": "Service",
            "metadata": {"name": "deploy-name-service"},
            "spec": {
                "type": "ClusterIP",
                "selector": {"app.kubernetes.io/name": "deploy-name"},
                "ports": [{"name": "sshd", "port": 22, "protocol": "TCP", "targetPort": "sshd"}],
            },
        }

    def load_deploy_config(self, path, name=None):
        """Load a Deployment manifest from the YAML file at ``path``.

        The convenience attributes are re-bound to the loaded manifest, so later
        setters edit it. If ``name`` is given, the deployment is renamed.

        Raises:
            ValueError: if the file does not contain a YAML mapping.
        """
        with open(path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
        if not isinstance(config, dict):
            raise ValueError(f"{path} does not contain a YAML mapping")
        self.config = config
        self._bind_references()
        if name:
            self.set_name(name)

    def load_service_config(self, path):
        """Load a Service manifest from the YAML file at ``path`` into ``self.serviceconfig``."""
        with open(path, "r", encoding="utf-8") as f:
            self.serviceconfig = yaml.safe_load(f)

    # ------------------------------------------------------------------ setters

    def set_name(self, name):
        """Rename the deployment: metadata name, selector label, pod label and container name.

        Raises:
            RuntimeError: if no deployment config has been loaded or built.
        """
        if not self.config:
            raise RuntimeError("No config, either load one with load_deploy_config() or init one with init_default_deploy_config()")
        spec = self.config["spec"]
        self.config["metadata"]["name"] = name
        spec["selector"]["matchLabels"]["app.kubernetes.io/name"] = name
        spec["template"]["metadata"]["labels"]["app.kubernetes.io/name"] = name
        spec["template"]["spec"]["containers"][0]["name"] = name
        self.name = name

    def set_gpu(self, gpu_name=None, amount=1):
        """Request ``amount`` GPUs and optionally pin the GPU class.

        Args:
            gpu_name: a ``GPU`` member or class-name string. When given, the
                ``gpu.nvidia.com/class`` node-selector expression is set to
                exactly this class; it is re-added if ``set_cpu_only`` removed it.
            amount: GPU limit written to ``nvidia.com/gpu``.

        Raises:
            KeyError: if ``gpu_name`` is given but the deployment has no required
                node affinity to put it in.
        """
        if gpu_name:
            if isinstance(gpu_name, GPU):
                gpu_name = gpu_name.name
            expressions = self._match_expressions()
            if expressions is None:
                raise KeyError("Deployment has no nodeAffinity matchExpressions to pin a GPU class on.")
            # Look the expression up by key: its position is not fixed
            # (set_cpu_only removes it, loaded manifests may order it differently).
            gpu_expression = next((e for e in expressions if e.get("key") == self._GPU_CLASS_KEY), None)
            if gpu_expression is None:
                expressions.insert(0, {'key': self._GPU_CLASS_KEY, 'operator': 'In', 'values': [gpu_name]})
            else:
                gpu_expression["values"] = [gpu_name]

        self.limits["nvidia.com/gpu"] = str(amount)

    def set_cpu(self, cores: int):
        """Set the container CPU limit to ``cores``."""
        self.limits["cpu"] = str(cores)

    def set_cpu_only(self):
        """Drop the GPU limit and the GPU-class node-selector expression (idempotent)."""
        self.limits.pop("nvidia.com/gpu", None)
        expressions = self._match_expressions()
        if expressions is not None:
            # Slice-assign so the list stored inside the manifest is filtered
            # in place, instead of being mutated while it is iterated.
            expressions[:] = [e for e in expressions if e.get("key") != self._GPU_CLASS_KEY]

    def set_ram(self, ram: int):
        """Set the container memory limit to ``ram`` GiB."""
        self.limits["memory"] = str(ram) + "Gi"

    def set_local_storage(self, amount: int):
        """Set the container ephemeral-storage limit to ``amount`` GiB."""
        self.limits["ephemeral-storage"] = str(amount) + "Gi"

    def set_image(self, image_str: str):
        """Set the container image."""
        self.first_container["image"] = image_str

    def reset_mounts(self):
        """Remove every volume and volume mount from the manifest."""
        self.reset_volume_mounts()
        self.reset_volumes()

    def add_mount(self, mount_path, claim_name):
        """Mount PVC ``claim_name`` at ``mount_path``; the PVC name is used as the volume name."""
        self.add_volume(claim_name, claim_name)
        self.add_volume_mount(mount_path, claim_name)

    def reset_volumes(self):
        """Remove every pod volume from the manifest (clears the list in place)."""
        self.volumes.clear()

    def add_volume(self, volume_name, claim_name):
        """Add a pod volume ``volume_name`` backed by the PVC ``claim_name``."""
        self.volumes.append({"name": volume_name, "persistentVolumeClaim": {"claimName": claim_name}})

    def reset_volume_mounts(self):
        """Remove every container volume mount from the manifest (clears the list in place)."""
        self.volume_mounts.clear()

    def add_volume_mount(self, mount_path, name):
        """Mount the pod volume ``name`` at ``mount_path`` in the container."""
        self.volume_mounts.append({"mountPath": mount_path, "name": name})

    def add_port(self, container_port: int, name: str, protocol: str = "TCP"):
        """Expose ``container_port`` on the container and as a same-numbered Service port."""
        self.ports.append({'containerPort': container_port, 'name': name, 'protocol': protocol})
        self.serviceconfig["spec"].setdefault("ports", []).append(
            {'name': name, "port": container_port, "protocol": protocol, "targetPort": name})

    def add_env(self, name, value):
        """Add environment variable ``name`` to the container.

        ``value`` is converted with ``str()`` because Kubernetes EnvVar values
        must be strings.
        """
        self.env = self.first_container.setdefault("env", [])
        self.env.append({'name': name, 'value': str(value)})

    # ----------------------------------------------------------------- reporting

    def print_information(self):
        """Print a human-readable summary of the deployment (and service type) to stdout."""
        name = self.config["metadata"]["name"]
        container = self.config["spec"]["template"]["spec"]["containers"][0]
        limits = container.get("resources", {}).get("limits", {})
        image = container.get("image")
        gpu_amount = limits.get("nvidia.com/gpu")
        gpu_class = self._gpu_class()
        cpu_cores = limits.get("cpu")

        if self.dry:
            print(colored("Dry runing the deployment...", "green"))

        print("Deployment name:", colored(name, "green"))
        if gpu_amount:
            print("GPU:", colored(f"{gpu_amount} x", "red"), colored(gpu_class or "unspecified", "green"))
        else:
            print("GPU:", colored("CPU ONLY", "red"))
        print("CPU:", colored(f"{cpu_cores} cores" if cpu_cores is not None else "not set", "green"))
        print("RAM:", colored(str(limits.get("memory", "not set")), "green"))
        print("Local Storage:", colored(str(limits.get("ephemeral-storage", "not set")), "green"))
        print("Docker image:", colored(str(image), "green"))

        if self.serviceconfig:
            # Kubernetes treats a Service without an explicit type as ClusterIP.
            service_type = self.serviceconfig.get("spec", {}).get("type", "ClusterIP")
            print("Service type:", colored(service_type, "green"))
        else:
            print(colored("Service file not initialized.", "red"))

    def dry_run(self, enable=True):
        """Toggle server-side dry-run mode for the ``create_*`` methods."""
        self.dry = enable

    # ---------------------------------------------------------------- kubectl ops

    def create_service(self, overwrite=False, override=True):
        """Apply ``self.serviceconfig`` as ``<deployment name>-service``.

        Args:
            overwrite: apply even if a service with that name already exists.
            override: when true (the default), delegate to ``create_cluster_ip``
                so the pod is only reachable from inside the cluster.
        """
        if override:
            print("Force using ClusterIP, this makes so you can't access the pod outside of the K8s cluster.")
            print("Set 'override'=False to disable this behavior.")
            self.create_cluster_ip(overwrite=overwrite)
            return

        name = self.config["metadata"]["name"]
        service_name = name + "-service"
        metadata = self.serviceconfig["metadata"]
        metadata["name"] = service_name
        # A ClusterIP template has no annotations section; create it if absent.
        metadata.setdefault("annotations", {})["metallb.universe.tf/allow-shared-ip"] = name
        self.serviceconfig["spec"]["selector"]["app.kubernetes.io/name"] = name
        self._apply_unless_exists("service", service_name, self.serviceconfig, overwrite, "Service", "blue")

    def create_cluster_ip(self, overwrite=False):
        """Apply ``self.serviceconfig`` as ``<deployment name>-service``, selecting the deployment's pods."""
        name = self.config["metadata"]["name"]
        service_name = name + "-service"
        self.serviceconfig["metadata"]["name"] = service_name
        self.serviceconfig["spec"]["selector"]["app.kubernetes.io/name"] = name
        self._apply_unless_exists("service", service_name, self.serviceconfig, overwrite, "Service", "blue")

    def create_deployment(self, overwrite=False):
        """Apply ``self.config``, unless a deployment of that name exists and ``overwrite`` is false."""
        name = self.config["metadata"]["name"]
        self._apply_unless_exists("deployment", name, self.config, overwrite, "Deployment", "green")

    def get_pyfra_remote(self, user='root', check_every=5, quiet=True, timeout=None):
        '''
        Checks until the pod becomes running and after that returns a pyfra remote instance.

        The pod is found through the Deployment's ``spec.selector.matchLabels``.
        Only the first matching pod is used, so this assumes a single replica.
        ``contrib.kube_remote`` (which copies ssh keys and builds the remote) is
        retried every ``check_every`` seconds until it succeeds, or until
        ``timeout`` seconds have passed (``None`` waits forever).

        Returns the remote, or ``None`` if no pod was found or the timeout expired.
        '''
        name = self.config["metadata"]["name"]
        match_labels = self.config["spec"].get("selector", {}).get("matchLabels", {})
        selector = ",".join(f"{key}={value}" for key, value in match_labels.items())
        pod = ""
        if selector:
            try:
                lookup = subprocess.run(
                    ['kubectl', 'get', 'pods', '-l', selector, '-o', 'jsonpath={.items[0].metadata.name}'],
                    stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, encoding='utf-8',
                )
            except OSError:
                lookup = None
            if lookup is not None and lookup.returncode == 0:
                pod = lookup.stdout.strip()
        if not pod:
            print(colored("Pod not found.", "red"))
            return None

        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            print(colored(".", "green"), end='', flush=True)
            try:
                remote = contrib.kube_remote(pod=pod, service_name=name + "-service", user=user, quiet=quiet)
            except Exception:
                # Not reachable yet (pod starting, sshd not up, ...): retry.
                if deadline is not None and time.monotonic() >= deadline:
                    print(colored("\nTimed out waiting for pod:", "red"), colored(pod, "blue"))
                    return None
                time.sleep(check_every)
                continue
            print(colored("\nConnected to pod:", "green"), colored(pod, "blue"))
            return remote

    def kill_deployment(self, name=None):
        """Delete deployment ``name`` (default: this config's deployment)."""
        if not name:
            name = self.config["metadata"]["name"]
        sh("kubectl delete deploy " + name)

    def kill_service(self, name=None):
        """Delete service ``<name>-service`` (default name: this config's deployment)."""
        if not name:
            name = self.config["metadata"]["name"]
        sh("kubectl delete service " + name + "-service")

