from os import XATTR_REPLACE
import time
import subprocess
import shlex
import json
import requests
from requests.auth import HTTPBasicAuth
import time
import vast
import datacrunchwrap
import scaler
import sys
import math
import logging
import pika
import threading
from datetime import datetime

# --- Shared module state ---------------------------------------------------
# These names live at module level; functions that rebind them declare them
# with `global` inside their own bodies (a `global` statement at module level
# has no effect, so none is used here).

# Number of "buffer_used" messages counted by the listener callback since
# main() last reset it to 0.
buffer_usage = 0
# Set to True by update_all() while it is publishing update commands.
updating_now = False
msg_ready_t1 = 0
# Instances currently considered running; rebound by main().
launched_instances = []
# Node ids appended by the listener callback for the "sent_first_message",
# "died" and "started" actions respectively.
get_first_message = []
get_died = []
get_started = []

# --- Logging ---------------------------------------------------------------
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.DEBUG)
# StreamHandler with no arguments writes to stderr.
fh = logging.StreamHandler()
fh_formatter = logging.Formatter('%(asctime)s %(levelname)s %(filename)s(%(process)d) - %(message)s')
fh.setFormatter(fh_formatter)
logger.addHandler(fh)

# --- GPU ratings -----------------------------------------------------------
# TFLOPs rating per GPU model name; multiplied by an offer's "num_gpus" to get
# the rating of a whole machine.
gpu_tflops = {
    "RTX 2080 Ti": 26,
    "Tesla V100": 28,
    "A100 PCIE": 78,
    "Tesla P100": 20,
    "RTX A6000": 40,
    "RTX 2080": 20,
    "RTX 3090": 35,
    # Placeholder entry with a minimal rating (not a real GPU model).
    "fuckitab": 1,
}


def _settle_pending_launches(pending, running):
    """Move finished launches out of *pending* and return the TFLOPs settled.

    An instance whose ``dcname`` appears in ``get_first_message`` is moved to
    *running*; an instance that neither reported in nor is still launching is
    dropped. Both lists are modified in place.
    """
    settled_tflops = 0
    # Iterate over a snapshot: removing from the list being iterated would
    # skip the element that follows each removal.
    for instance in list(pending):
        logger.debug(f"Checking if {instance.dcname} is launched")
        if instance.dcname in get_first_message:  # get_started or get_first_message?
            logger.debug(f"{instance.dcname} has launched, adding it to launched instances.")
            pending.remove(instance)
            running.append(instance)
            settled_tflops += instance.tflops
            # TODO: also track running TFLOPs.
        elif instance.is_launching():
            # TODO: doesn't work, keeps saying it's launching even though it launched.
            logger.debug(f"{instance.dcname} is still provisioning.")
        else:
            logger.debug(f"{instance.dcname} couldn't launch, removing from launch list.")
            pending.remove(instance)
            settled_tflops += instance.tflops
    return settled_tflops


def _scale_down_candidate(instances):
    """Return the instance with the fewest TFLOPs per dollar, then fewest GPUs."""
    return min(
        instances,
        key=lambda inst: (
            float(inst.tflops) / float(inst.price),
            float(inst.instance["num_gpus"]),
        ),
    )


def main():
    """Run the scaling control loop forever, ticking once per second.

    Each tick settles pending launches, tops up the buffer GPUs and samples
    the HF load, the queue length and the buffer usage counter. Every ten
    samples the averages and momentum are evaluated to decide on scale-down.

    Machine offer fields used throughout this file:
    dph_total: $ hour rental cost. min_bid: if interruptable.
    num_gpus: int, gpu_name: string,
    """
    # TODO: How to shut down the gpus? -> Just shut them down if huggingface is not growing larger
    # every 30 seconds or so? We should be smart about not shutting down rare gpus with good pricing.
    # This needs good code. A lot of optimization possibilities.
    # TODO: Prioritize smaller num_gpu instances while renting?
    # Might make sense to have some smaller num_gpu instances for easier up and down scaling,
    # probably shouldn't prioritize them too strongly though.
    # TODO: Error handling!
    # TODO: Keep some gpus as the buffer.
    # TODO: Watch the buffer gpus to scale, don't do momentum for now, immediately scale as soon as
    # there are any users on the buffer gpu. This needs distinction between normal nodes and buffers.
    # TODO: Check node health regularly, take measures if a buffer node or any other node fails.
    # Don't close nodes just because no data arrives from buffer nodes (which makes usage 0);
    # scale according to the queue too.

    global buffer_usage, launched_instances
    window = 10  # samples per evaluation window
    qloadlist = []
    jointlist = []
    bufferlist = []
    tflops_will_launch = 0

    buffer_gpus_to_launch = []
    buffer_gpus_launched = []
    debug = True
    buffer_gpus = 1
    zero_times = 0

    if not debug:
        create_price_table()
        sys.exit(0)

    launched_instances, gpus_to_launch = restore_instances_data()  # TODO: this doesn't work.
    # TODO: Write a hack here to not get buffer gpus in. Launch them yourself for the alpha.
    # TODO: Implement reading buffer instances (disabled sketch):
    #   buffer_instances = None
    #   if len(buffer_instances) < num_buffer:
    #       scaler.launch_accordinggops(num_buffer-len(buffer_instances), buffer=True)
    # TODO: should check if gpus are living or not every 10 seconds or so.

    while True:
        tflops_will_launch -= _settle_pending_launches(gpus_to_launch, launched_instances)

        # TODO: Launch buffer gpus properly with their own list, name it as buffer_tflops instead?
        # Compare counts: the lists themselves cannot be compared with an int.
        buffer_count = len(buffer_gpus_launched) + len(buffer_gpus_to_launch)
        if buffer_count < buffer_gpus:
            logger.info("We don't have enough buffer GPUs, launching more.")
            # hack: request 28 TFLOPs for every buffer GPU that is still missing.
            tflops = 28 * (buffer_gpus - buffer_count)
            new_buffer_gpus, tflopsx = launch_given_tflops(tflops, "buffer")
            buffer_gpus_to_launch += new_buffer_gpus
            tflops_will_launch += tflopsx
            logger.info(f"Will launch {tflopsx} tflops and {str(len(new_buffer_gpus))} GPUs.")

        hfload = int(listen_hf_load())
        qload = int(listen_q_load())
        qloadlist.append(qload)

        # buffer_usage is incremented by the listener thread between resets.
        buffer_usage_sum = buffer_usage
        bufferlist.append(int(buffer_usage_sum))
        jointlist.append(int(buffer_usage_sum) + int(qload))

        # TODO: destroy vastai for now, do stop/start handling in the future.
        # Need to scale down slowly, start cutting the most expensive tflops and smaller instances first.

        if qload > 10 and hfload < 10:
            logger.error("HFAPI is probably not working.")

        # Evaluate once exactly `window` samples are collected, so the
        # averages below always divide by the real sample count.
        if len(bufferlist) == window:
            buffer_usage = 0
            momentum = calculate_momentum(bufferlist)
            bufferavg = sum(bufferlist) / window
            queueavg = sum(qloadlist) / window
            loadavg = sum(jointlist) / window
            user_needed = bufferavg + momentum

            # TODO: might check if we lost any instances here.
            logger.info("Launched instances: " + str(len(launched_instances)) + ", Instance about to launch: " + str(len(gpus_to_launch)))
            # TODO: handle gpus_to_launch better, make proper checks and make sure they launch in
            # 3 minutes or so, if they don't, remove them from the list.

            logger.info(f'bufferavg + momentum: {str(user_needed)}, bufferavg: {str(bufferavg)}, momentum: {str(momentum)}')
            logger.debug(f"queueavg: {queueavg}, loadavg: {loadavg}")

            # NOTE(review): zero_times counts idle windows cumulatively; it is
            # not reset when usage returns. Confirm whether "consecutive" was meant.
            if user_needed == 0:
                logger.info("Zero Times: " + str(zero_times))
                zero_times += 1

            if zero_times == 30:
                # Always keep at least one launched instance.
                if len(launched_instances) > 1:
                    logger.info("Killing the most expensive $/TFLOP and smallest instance because no usage.")
                    victim = _scale_down_candidate(launched_instances)
                    if victim.provider in ("datacrunch", "vastai"):
                        victim.destroy_instance()
                        launched_instances.remove(victim)
                zero_times = 0

            if user_needed >= 1 and len(gpus_to_launch) == 0:
                # TODO: need better logic here. Definitely look at momentum and do automagic gpu
                # buffer scaling too.
                logger.info("We don't have enough GPUs, launching more.")
                # Scale-up is currently disabled; the intended logic was:
                #   tflops = 28
                #   # tflops = (user_needed * 8) - tflops_will_launch
                #   # 8 tflops per user. TODO: TUNE THIS!
                #   gpus_to_launch_local, tflopsx = launch_given_tflops(tflops)
                #   gpus_to_launch = gpus_to_launch + gpus_to_launch_local
                #   logger.info(f"Will launch {tflopsx} tflops and {str(len(gpus_to_launch_local))} GPUs.")  # TODO: instances or GPUs?
                #   tflops_will_launch += tflopsx

            jointlist = []
            qloadlist = []
            bufferlist = []

        time.sleep(1)


def restore_instances_data():
    """Rebuild instance objects from the provider's current instance list.

    Only datacrunch instances are listed (the vast listing is disabled).
    Instances whose hostname contains "buffer" are left out of both results.

    Returns:
        (running, launching): two lists of ``scaler.Instance`` objects.

    TODO: need to do this by pinging all of them or something.
    """
    dc_maclist = datacrunchwrap.list_my_instances()
    # vast_maclist = vast.list_my_instances()
    maclist = dc_maclist  # + vast_maclist

    running = []
    launching = []
    for record in maclist:
        restored = scaler.Instance(record['provider'])
        restored.restore_instance(record)
        if restored.is_running() and "buffer" not in restored.dcinstance.hostname:
            running.append(restored)
        elif restored.is_launching() and "buffer" not in restored.dcinstance.hostname:
            launching.append(restored)
    return running, launching

def launch_given_tflops(tflops, nameprefix=None):
    """Launch the cheapest available machines until *tflops* is covered.

    Args:
        tflops: amount of TFLOPs to acquire.
        nameprefix: currently unused; optional so single-argument calls work.
            TODO: use it (plus a global instance count) to name instances.

    Returns:
        (instances, tflops_found): the ``scaler.Instance`` objects that were
        launched and the total TFLOPs they add up to.
    """
    # 19 is the minimum tflops we can get (rtx 2080), so we should at least
    # look for that many offers, plus some slack.
    run_find_times = int(tflops / 19) + 5
    candidates = find_cheapest(datacrunchwrap.list_available_instances(), run_find_times)

    tflops_found = 0
    # Local result list; deliberately not named like the module-level
    # `launched_instances` so that global is never shadowed.
    new_instances = []

    for machine in candidates:
        if tflops <= tflops_found:
            break

        machine_tflops = gpu_tflops[machine["gpu_name"]] * machine["num_gpus"]
        tflops_needed = tflops - tflops_found

        # Skip offers that are at least 1.5x larger than what is still missing.
        if tflops_needed + tflops_needed * 0.5 <= machine_tflops:
            continue

        # .get(): an offer without a provider is skipped instead of raising.
        provider = machine.get("provider")
        if provider == "datacrunch":
            # TODO: need better logic here.
            how_many_gpus = math.ceil(tflops_needed / machine_tflops)
            for _ in range(how_many_gpus):
                instance = scaler.Instance("datacrunch")
                if not instance.launch_instance(machine):
                    # Launch failed: stop trying this offer and move on.
                    # TODO: handle errors here!
                    break
                # Count the TFLOPs only once the launch has succeeded.
                new_instances.append(instance)
                tflops_found += machine_tflops

        elif provider == "vastai":
            logger.debug(f"vastai launch, tflops still needed: {tflops_needed}")
            instance = scaler.Instance(provider)
            # NOTE(review): the launch result is not checked for vastai,
            # as in the original code — confirm what launch_instance returns.
            instance.launch_instance(machine)
            new_instances.append(instance)
            tflops_found += machine_tflops

    return new_instances, tflops_found


def calculate_momentum(users):
    """Return a weighted trend of the samples in *users*.

    Each difference between consecutive samples is multiplied by a weight
    that starts at 0.1 and grows by ``1 / len(users)`` per step, so later
    changes count more. The weighted sum is divided by ``len(users)``.

    Args:
        users: list of numeric samples, oldest first.

    Returns:
        The momentum value; 0 for an empty list.
    """
    if not users:
        # Nothing sampled yet: no trend (and avoids dividing by zero).
        return 0

    sample_count = len(users)
    momentum = 0
    coeff = 0.1
    for previous, current in zip(users, users[1:]):
        momentum = momentum + (current - previous) * coeff
        coeff = coeff + 1 / sample_count

    return momentum / sample_count


def listen_hf_load():
    """Return the contents of the ``hfreqs`` file as a stripped string.

    Get people going through hf every second. Surrounding newlines are
    stripped first, then surrounding spaces. The caller converts the value
    to a number.
    """
    with open("hfreqs", "r") as handle:
        raw = handle.read()

    without_newlines = raw.strip("\n")
    return without_newlines.strip(" ")



def listen_q_load():
    """Return ``messages_ready`` for the ``generation_jobs`` queue.

    Get people waiting in the queue because hf is doomed or something.
    Queries the broker's HTTP ``/api/queues`` endpoint. Returns 0 when no
    queue named ``generation_jobs`` is listed.

    Raises:
        requests.RequestException: on connection failure or timeout.
    """
    # NOTE(security): broker credentials are hard-coded here; they should be
    # moved to configuration or environment variables and rotated.
    response = requests.get(
        "http://104.248.82.249:15672/api/queues",
        auth=HTTPBasicAuth("kurumuz", "IX0zuEY6mLqsqDN0xS90nI8cFDCrr47o"),
        # Without a timeout a stalled broker would block the caller forever.
        timeout=10,
    )

    generation_queue = None
    for queue in response.json():
        if queue["name"] == "generation_jobs":
            # Keep scanning: the last matching entry wins, as before.
            generation_queue = queue

    if generation_queue is None:
        return 0

    # should look at messages_ready.
    return generation_queue["messages_ready"]


# TODO: Have one cheapest function for all providers or separate them? DOOMP EET
def find_cheapest(maclist, count, tflops_table=None):
    """Return up to *count* offers with the most TFLOPs per dollar, best first.

    Args:
        maclist: iterable of offer dicts with "gpu_name", "num_gpus" and
            "dph_total" keys.
        count: maximum number of offers to return; values <= 0 give [].
        tflops_table: optional mapping of GPU model name to TFLOPs rating;
            defaults to the module-level ``gpu_tflops``.

    Returns:
        A new list of the selected offer dicts. Only real offers from
        *maclist* are returned; *maclist* itself is not modified. Offers with
        equal value keep their order from *maclist*.
    """
    table = gpu_tflops if tflops_table is None else tflops_table
    allowed_size = range(1, 20)

    def tflops_per_dollar(mac):
        return (table[mac["gpu_name"]] * mac["num_gpus"]) / mac["dph_total"]

    allowed_ones = [
        mac
        for mac in maclist
        if mac["gpu_name"] in table
        and mac["num_gpus"] in allowed_size
        # A non-positive price cannot be ranked (division by zero) and is
        # treated as a bogus listing.
        and mac["dph_total"] > 0
    ]

    # sorted() is stable even with reverse=True, so ties keep input order.
    ranked = sorted(allowed_ones, key=tflops_per_dollar, reverse=True)
    return ranked[:max(count, 0)]


def create_price_table():
    """Print one line per instance obtained for 2000 TFLOPs, then the total.

    NOTE: this goes through ``launch_given_tflops``, which launches the
    instances it returns; it does not merely list offers.
    """
    # launch_given_tflops returns (instances, tflops_found); only the
    # instance list is iterated here.
    instances, _ = launch_given_tflops(2000, "pricetable")
    total_tflops = 0
    for instance in instances:
        total_tflops += int(instance.tflops)
        print(
            f'{instance.instance["gpu_name"]}x{instance.instance["num_gpus"]} TFLOPS: {int(instance.tflops)} | TFLOP/$: {int(instance.tflops/instance.instance["dph_total"])} | {instance.provider}'
        )
    print("TOTAL TFLOPS: " + str(total_tflops))


def listen_to_nodes():
    """Consume node messages from the ``buffergpu`` queue; blocks forever.

    Each message is a JSON object with an "id" and an "action". Handled
    actions: "need_update", "info", "started", "died",
    "sent_first_message" and "buffer_used". Other actions are ignored.
    """
    # NOTE(security): broker credentials are hard-coded here; they should be
    # moved to configuration or environment variables and rotated.
    credentials = pika.credentials.PlainCredentials("kurumuz", "IX0zuEY6mLqsqDN0xS90nI8cFDCrr47o")
    connection = pika.BlockingConnection(pika.ConnectionParameters('104.248.82.249', credentials=credentials))
    channel = connection.channel()

    channel.queue_declare(queue='buffergpu')

    def callback(ch, method, properties, body):
        global buffer_usage, updating_now
        try:
            req_dict = json.loads(body.decode("utf-8"))
            idx = req_dict["id"]
            action = req_dict["action"]
        except (ValueError, KeyError, TypeError) as err:
            # A malformed message must not kill the consumer thread.
            logger.warning(f"Ignoring malformed node message: {err!r}")
            return

        if action == "need_update" and updating_now is False:
            logger.info(f"{idx} said there is a new update for the nodes, updating all the nodes slowly.")
            # Claim the flag before the thread starts so a second
            # "need_update" arriving right away cannot start another rollout.
            updating_now = True
            thread3 = threading.Thread(target=update_all)
            thread3.start()
            # TODO: implement stats for now.
            # Get into the updating mode, updating nodes by sending messages until every one of them is done.

        elif action == "info":
            load_rate = str(req_dict['load_rate_minute'])
            processing_time = str(req_dict['total_processing_time'])
            logger.info(f"{idx} is alive and answered {load_rate} requests, total processing time: {processing_time}")
            date = datetime.now().strftime("%m-%d-%Y %H:%M:%S|")
            stringx = f"\n{date}id:{idx},load_rate_minute:{load_rate},total_processing_time:{processing_time}"
            # Context manager closes the log file even if the write fails.
            with open("usage_log", "a") as usage_log:
                usage_log.write(stringx)
        elif action == "started":
            logger.info(f"{idx} started running")
            get_started.append(idx)
        elif action == "died":
            logger.info(f"{idx} died")
            get_died.append(idx)
        elif action == "sent_first_message":
            logger.info(f"{idx} sent it's first message.")
            get_first_message.append(idx)
        elif action == "buffer_used":
            buffer_usage += 1

    channel.basic_consume(queue='buffergpu', on_message_callback=callback, auto_ack=True)
    channel.start_consuming()


def update_all():
    """Publish "update" commands to the ``scaler_to_nodes`` queue.

    Sends ``len(launched_instances) + 2`` messages, one every 10 seconds.
    ``updating_now`` is True while this runs and is always reset afterwards,
    even when connecting or publishing fails.
    """
    global updating_now
    updating_now = True
    try:
        # NOTE(security): broker credentials are hard-coded here; they should
        # be moved to configuration or environment variables and rotated.
        credentials = pika.credentials.PlainCredentials("kurumuz", "IX0zuEY6mLqsqDN0xS90nI8cFDCrr47o")
        connection = pika.BlockingConnection(pika.ConnectionParameters('104.248.82.249', credentials=credentials))
        try:
            channel = connection.channel()
            # send to the queue with the id
            channel.queue_declare(queue="scaler_to_nodes")
            # NOTE(review): the extra 2 presumably covers nodes not tracked in
            # launched_instances (e.g. buffer nodes) — TODO confirm.
            num_instances = len(launched_instances) + 2
            body = json.dumps({"id": "any", "command": "update"}).encode("utf-8")
            # send one update message every 10 seconds.
            for _ in range(num_instances):
                logger.info("sent update to 1 node.")
                channel.basic_publish(exchange="", routing_key="scaler_to_nodes", body=body)
                time.sleep(10)
        finally:
            # Closing an already-closed connection raises, so check first.
            if connection.is_open:
                connection.close()
    finally:
        # Without this reset a single failure would block all later updates.
        updating_now = False


#def run_on_node(command):


if __name__ == "__main__":
    # The listener blocks forever in its consume loop. Running it as a daemon
    # thread lets the process exit when main() raises or calls sys.exit(),
    # instead of lingering with only the listener alive.
    th = threading.Thread(target=listen_to_nodes, daemon=True)
    th.start()
    main()