from typing import Union, Sequence, Generator, List
import gc
import sys
import pathlib

import lm_node
from lm_node.config import init_config
from lm_node.models.GPT import SoftEmbedding
from lm_node.utils.lock import AtomicLockFileWithTimeout

thisPath = str(pathlib.Path(__file__).parent.resolve())
sys.path.append(thisPath + "/transformerfork/src")

from transformers import (
    AutoTokenizer,
    GPTNeoForCausalLM,
)

import signal

import os
import time
import logging
import json
import torch
import socket
from multiprocessing import Process, Queue
import multiprocessing as mp
import numpy as np
import boto3
import multitokenizer
import base64
import prefix
import asyncio
import websockets
import zlib
from nacl import pwhash, secret
import random
from transformers.optimization import Adafactor
import torch.nn as nn
from dotmap import DotMap

import sentry_sdk
from sentry_sdk import capture_exception, capture_message

from sentry_sdk.integrations.threading import ThreadingIntegration
from unitrim import trim, send_ready

# Stop flag polled by the long-running loops (workSource, streaming_ws).
# It is wrapped in a DotMap so the nested signal handler in
# addStopRunningHook can clear it with ``running.value = False`` without a
# ``global`` declaration. It is an ordinary in-memory object, so each forked
# process has its own independent copy.
running = DotMap({'value': True})



def addStopRunningHook(sig):
    """Register a handler for signal *sig* that asks the node to stop.

    When the signal arrives, the handler logs a warning naming the signal and
    sets ``running.value`` to False. Loops that poll that flag then finish on
    their next check.
    """
    def _request_stop(signum, _frame):
        logger.warning(f"Received signal {signum}!")
        running.value = False

    signal.signal(sig, _request_stop)

# Install the graceful-stop handler for SIGTERM, then for SIGINT.
for _stop_signal in (signal.SIGTERM, signal.SIGINT):
    addStopRunningHook(_stop_signal)
del _stop_signal

# NOTE(review): hidden_dim is not referenced anywhere else in the visible part
# of this file.
hidden_dim = 4096
# Inputs for deriving a NaCl SecretBox key with Argon2i below.
# NOTE(review): the password and salt are hard-coded in source, so anyone with
# this file can re-derive the key; consider loading them from the environment
# or a secret store.
password = b'novelai_16YQDi0u8DDQLDCvTZJPYuVTcJNLP7MG'
salt = b'__novelai_salt__'
kdf = pwhash.argon2i.kdf

# no need for super secure hashing in this case
# (the minimum Argon2i cost limits keep this import-time derivation cheap)
ops = pwhash.argon2i.OPSLIMIT_MIN
mem = pwhash.argon2i.MEMLIMIT_MIN

# Derive a key of exactly SecretBox.KEY_SIZE bytes and build the box from it.
key = kdf(secret.SecretBox.KEY_SIZE, password, salt, opslimit=ops, memlimit=mem)
# NOTE(review): ``box`` is not used in the visible part of this file; trained
# modules are encrypted via ``prefix.self_encrypt_prefix`` instead.
box = secret.SecretBox(key)

# Multiprocessing queue of event dicts. perform_work puts training updates on
# it; streaming_ws, in the separately started process, reads them and sends
# them over a websocket.
token_queue = Queue()

# NOTE(review): model_path is not referenced elsewhere in the visible part of
# this file.
model_path = os.environ['MODEL_PATH']
# Default training batch size; main() reassigns it from detected GPU memory.
batch_size = 1

# Chunking options passed to perform_work as tokens_per_chunk / overlap_chunks.
tokens_per_sample = int(os.getenv('CHUNK_SIZE', '256'))
overlap_samples = os.getenv('OVERLAP_CHUNKS', 'True').lower() == 'true'

# DEV == "True" selects staging: Sentry environment "staging", "_dev" queue
# suffix, and (in workSource/getTimestamp) "dev_"-prefixed job keys.
# NOTE(review): DEV must be set; a missing variable raises KeyError here.
is_dev = ""
environment = "production"
if os.environ['DEV'] == "True":
    environment = "staging"
    is_dev = "_dev"

model_name = os.environ['MODEL']
# Expected values include: 2.7B, 6B, 6B-v3

# Directory where perform_work writes encrypted trained modules.
modules_path = os.environ['MODULES']

version = "0.0.0.2smhs"
# Only logged at startup in the visible part of this file.
queue_name = "training_jobs_" + model_name + is_dev

# init sentry
# Error reporting, tagged with this host's name and the environment chosen
# above; 10% of performance traces are sampled.
# NOTE(review): the DSN is hard-coded; consider reading it from the environment.
sentry_sdk.init(
    "https://58fd055d1f1449328b5131a7dfddf4e4@o846434.ingest.sentry.io/5987512",
    server_name=socket.gethostname(),
    traces_sample_rate=0.1,
    environment=environment,
    integrations=[ThreadingIntegration(propagate_hub=True)],
)

# Setup logger
# INFO-level module logger with one StreamHandler (stderr by default). Each
# line carries the timestamp, level, file name and PID.
# NOTE(review): the signal handlers installed above call ``logger``; a signal
# delivered before this line runs would raise NameError inside the handler.
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
fh = logging.StreamHandler()  # a stream handler despite the "fh" name
fh_formatter = logging.Formatter(
    "%(asctime)s %(levelname)s %(filename)s(%(process)d) - %(message)s"
)
fh.setFormatter(fh_formatter)
logger.addHandler(fh)

logger.info("version: " + version)
logger.info("Node started")
logger.info("queue: " + queue_name)

# NOTE(review): ``q`` is not read or written anywhere else in the visible part
# of this file.
q = Queue()

# S3-compatible client for the Wasabi us-east-2 endpoint. workSource and
# perform_work use it to list, fetch and delete job objects. Credentials come
# from S3_ACCESS_KEY / S3_SECRET_KEY (KeyError if unset).
aws_access_key_id = os.environ['S3_ACCESS_KEY']
aws_secret_access_key = os.environ['S3_SECRET_KEY']
s3 = boto3.client('s3',
                  endpoint_url='https://s3.us-east-2.wasabisys.com',
                  aws_access_key_id=aws_access_key_id,
                  aws_secret_access_key=aws_secret_access_key,
                  region_name="us-east-2")

# NOTE(review): these counters are never updated in the visible part of this
# file.
total_sent = 0
processing_time = 0

def node_ok():
    """Do nothing and return ``None``.

    NOTE(review): not called anywhere in the visible part of this file;
    presumably a placeholder health-check hook.
    """
    return None


async def streaming_ws():
    """Forward events from ``token_queue`` to the API's node-pipe websocket.

    Intended to run in its own process (see the ``__main__`` block). Loops
    while ``running.value`` is true. Each queued event is JSON-encoded and
    sent over a fresh websocket connection.

    Queue reads use a 1-second timeout, so a cleared stop flag is noticed
    even when the node is idle. When a send fails, the same message is
    retried after a 1-second back-off, up to ``max_attempts`` times, before
    it is dropped. A failure that happens after the send completed can
    therefore cause a duplicate delivery. Events that cannot be
    JSON-encoded are logged and dropped.
    """
    # multiprocessing.Queue.get(timeout=...) signals a timeout with queue.Empty.
    from queue import Empty

    # uri = "ws://localhost:8765"
    if os.environ['DEV'] == "True":
        uri = "wss://staging.novelai.net/ai/internal/node-pipe"
    else:
        uri = "wss://api.novelai.net/ai/internal/node-pipe"

    print("Started streaming thread")
    max_attempts = 5
    pending = None  # JSON text of the message currently being delivered
    attempts = 0
    while running.value:
        if pending is None:
            try:
                # Bounded wait so the stop flag is re-checked while idle. This
                # briefly blocks the event loop, which has no other tasks.
                item = token_queue.get(timeout=1)
            except Empty:
                continue
            try:
                pending = json.dumps(item)
            except (TypeError, ValueError) as e:
                # Can never be sent; drop it instead of retrying forever.
                logger.error(e)
                capture_exception(e)
                continue
            attempts = 0

        try:
            async with websockets.connect(uri) as websocket:
                await websocket.send(pending)
            pending = None
        except Exception as e:
            logger.error(e)
            capture_exception(e)
            attempts += 1
            if attempts >= max_attempts:
                logger.error(
                    f"Dropping message after {attempts} failed send attempts")
                pending = None
            else:
                # Back off so an unreachable endpoint is not hammered.
                await asyncio.sleep(1)
    logger.warning("Websocket streamer exiting!")


import subprocess as sp
import os


def get_gpu_memory():
    """Return the total memory of each visible GPU, as reported by nvidia-smi.

    Runs ``nvidia-smi --query-gpu=memory.total --format=csv`` and skips the
    CSV header row. Each remaining row looks like ``"40536 MiB"``; its
    leading integer is parsed. Blank rows are ignored, so a missing or extra
    trailing newline cannot drop or break a row. The list is also printed.

    Returns:
        list[int]: one total-memory value per GPU, in nvidia-smi's output order.

    Raises:
        OSError: if ``nvidia-smi`` cannot be executed.
        subprocess.CalledProcessError: if ``nvidia-smi`` exits non-zero.
    """
    output = sp.check_output(
        ["nvidia-smi", "--query-gpu=memory.total", "--format=csv"])
    rows = output.decode('ascii').splitlines()[1:]  # drop the CSV header
    total_memory = [int(row.split()[0]) for row in rows if row.strip()]
    print(total_memory)
    return total_memory


def getTimestamp(d):
    """Return the timestamp component of an S3 listing entry's key.

    workSource uses this as its sort key. Keys are ``_``-separated, and
    staging keys carry an extra leading ``dev`` component. The timestamp is
    therefore component 3 on staging and component 2 on production.

    NOTE(review): the result is compared as a string, which matches numeric
    order only if every timestamp has the same number of digits — confirm
    the key format.

    Args:
        d: an entry from an S3 ``list_objects`` ``Contents`` list (needs ``"Key"``).

    Returns:
        str: the timestamp part of the key.

    Raises:
        IndexError: if the key has too few ``_``-separated components.
    """
    timestamp_loc = 3 if environment == "staging" else 2
    return d['Key'].split('_')[timestamp_loc]


def _put_training_update(session_id, status, payload):
    """Queue a ``training_update`` event for *session_id* on ``token_queue``.

    *payload* is JSON-encoded into the event's inner ``data`` field. Every
    status message this module emits has this shape.
    """
    token_queue.put({"event": "training_update",
                     "data": {"id": session_id, "status": status,
                              "data": json.dumps(payload)}})


def perform_work(model: lm_node.models.GPT,
                 data_key, bucket_name,
                 overlap_chunks=True,
                 tokens_per_chunk=256):
    """Run the prefix-training job stored at *data_key* in *bucket_name*.

    Steps:
      1. Fetch and decode the job from S3.
      2. Split the training text into token chunks.
      3. Run ``model.train`` and relay its progress as ``training_update``
         events on ``token_queue``.

    On success the trained embedding is encrypted with
    ``prefix.self_encrypt_prefix``, written under ``modules_path``, and the
    S3 object is deleted. The S3 object is also deleted when the job can
    never succeed (undecodable payload, no training chunks, or a non-shutdown
    training failure). It is kept when the fetch fails or when training
    stops for a shutdown.

    Args:
        model: model wrapper providing ``tokenizer``, ``split_chunks`` and ``train``.
        data_key: S3 key of the job. Its ``_``-separated parts contain the
            session id, shifted by one position for ``dev``-prefixed keys.
        bucket_name: S3 bucket holding the job.
        overlap_chunks: if true, chunking uses the tokenized newline as the
            delimiter; otherwise ``-1`` is passed as the delimiter.
        tokens_per_chunk: number of tokens per training chunk.
    """
    logger.info("Got 1 request")
    logger.info(f"GOT KEY: {data_key}")
    key_parts = data_key.split('_')
    # "dev"-prefixed keys have one extra leading component.
    session_id = key_parts[2] if key_parts[0] == "dev" else key_parts[1]

    try:
        raw_job = s3.get_object(Bucket=bucket_name, Key=data_key)[
            "Body"].read().decode("utf-8")
    except Exception as e:
        # Report the failure; the S3 object is not deleted here.
        _put_training_update(session_id, "error", {
            "message": "unable to fetch training data from shared storage"})
        logger.error(e)
        capture_exception(e)
        return

    try:
        job = json.loads(raw_job)
        steps = int(job["steps"])
        # Training text is base64 of raw DEFLATE data (wbits=-15: no zlib
        # header/trailer), UTF-8 once inflated.
        train_text = zlib.decompress(base64.b64decode(job["data"]), -15).decode()
    except (ValueError, TypeError, KeyError, zlib.error) as e:
        # Bad JSON, missing/invalid fields, or bad base64/DEFLATE/UTF-8: this
        # job can never succeed, so report it and remove it from the bucket
        # rather than crashing the node on it repeatedly.
        _put_training_update(session_id, "error", {
            "message": "invalid training data archive"})
        logger.error(e)
        capture_exception(e)
        s3.delete_object(Bucket=bucket_name,
                         Key=data_key)
        return
    # Normalize "\r\n" and lone "\r" line endings to "\n".
    train_text = train_text.replace("\r\n", "\n").replace("\r", "\n")

    start = time.perf_counter()

    # NOTE: a leading "CONFIG:" line is not parsed or stripped; it is trained
    # on like the rest of the text.
    if overlap_chunks:
        delimiter = model.tokenizer.tokenize("\n")[0]
    else:
        delimiter = -1

    tokens_array = list(model.split_chunks([train_text],
                                           size=tokens_per_chunk,
                                           delimiter=delimiter))
    logger.info(f"Got {len(tokens_array)} chunks")
    if not tokens_array:
        # np.hstack([]) raises; there is nothing to train on.
        _put_training_update(session_id, "error", {
            "message": "training data is empty"})
        logger.error(f"No training chunks produced for {data_key}")
        s3.delete_object(Bucket=bucket_name,
                         Key=data_key)
        return
    tokens_out = np.hstack(tokens_array).astype(np.uint16)

    tokenize_time = time.perf_counter() - start
    logger.info(f"Tokenized in {tokenize_time:0.2f}s, starting training...")
    logger.info(f"Training request with id {session_id}, steps:{steps}, " +
                f"{tokens_out.shape[0]} tokens," +
                f" tokens_per_sample={tokens_per_chunk}," +
                f" overlap={overlap_chunks}")

    for msg in model.train(tokens=tokens_out,
                           model_version=5,
                           session_id=session_id,
                           steps=steps // batch_size,
                           bs=batch_size,
                           tokens_per_chunk=tokens_per_chunk):
        if msg["event"] == "training_update":
            # Progress event from the trainer: forward it unchanged.
            token_queue.put(msg)
        elif msg["ok"]:
            encrypted, prefix_id = prefix.self_encrypt_prefix(
                msg["encoded_embedding"])
            # Context manager closes the file even if the write fails.
            with open(os.path.join(modules_path, prefix_id), "wb") as f:
                f.write(encrypted)
            _put_training_update(session_id, "ready", {
                "encoded_emb": msg["encoded_embedding"],
                "loss": msg["loss"]})
            s3.delete_object(Bucket=bucket_name,
                             Key=data_key)
            logger.info("Training done!")
        elif msg["reason"] == "shutdown":
            # The job object is left in the bucket.
            logger.warning("Shutting down mid-job!")
        else:
            reason = msg["reason"]
            fail_str = f"failed by result ok false, sending fail event {reason}"
            capture_message(fail_str)
            logger.error(fail_str)
            _put_training_update(session_id, "error", {
                "message": "unknown error while training"})
            s3.delete_object(Bucket=bucket_name,
                             Key=data_key)


def workSource(lockstore="/locks", bucket="novelailm-modtraindata",
               poll_interval=1.0):
    """Yield job keys from *bucket* for this node, one at a time.

    Keys are yielded in ascending ``getTimestamp`` order and filtered to:
      * this environment: ``dev_`` keys on staging, non-``dev_`` keys on
        production;
      * keys containing ``model_name``;
      * excluding keys containing ``.lock`` or ``training_data``.

    Each key is yielded while an ``AtomicLockFileWithTimeout`` on
    ``<lockstore>/<key>.lock`` is held. Because the generator is suspended
    inside the ``with``, the lock covers the caller's processing of that key.

    The bucket is re-listed after every yielded job. A job can run long
    enough for the old listing to go stale: other nodes may have processed
    and deleted the remaining keys. When a pass yields nothing, or listing
    fails, the loop sleeps *poll_interval* seconds rather than re-polling S3
    immediately. Iteration ends once ``running.value`` is false.

    Args:
        lockstore: directory holding the per-key lock files.
        bucket: S3 bucket to poll for jobs.
        poll_interval: seconds to wait between polls that found no work.
    """
    while running.value:
        try:
            # S3 omits "Contents" from the response when there are no keys.
            object_list = s3.list_objects(Bucket=bucket).get("Contents", [])
        except Exception as e:
            logger.warning(f"Listing bucket {bucket} failed: {e}")
            time.sleep(poll_interval)
            continue

        candidates = []
        for obj in object_list:
            key = obj["Key"]
            dev_key = key.startswith('dev_')
            env_match = ((environment == "staging" and dev_key)
                         or (environment == "production" and not dev_key))
            if (env_match
                    and model_name in key
                    and ".lock" not in key
                    and "training_data" not in key):
                candidates.append(obj)

        yielded = False
        for obj in sorted(candidates, key=getTimestamp):
            key = obj["Key"]
            lockfile_path = os.path.join(lockstore, key + ".lock")
            with AtomicLockFileWithTimeout(lockfile_path, 3, 'w') as lockfile:
                if lockfile is None:
                    # Lock not acquired — presumably another worker holds
                    # this job; try the next one.
                    time.sleep(1)
                    continue
                yield key
                yielded = True
            # Re-list after every job; the remaining entries may be stale.
            break

        if not yielded:
            time.sleep(poll_interval)

    logger.warning("workSource exiting!")


def main():
    """Load the model, choose a batch size, then process jobs until stopped.

    The batch size (the module-level ``batch_size``) is 10 only when exactly
    one GPU is reported with more than 38000 of total memory and the model
    name contains "6B". Every other case, including multiple GPUs, uses 1.
    Jobs come from ``workSource`` and are passed to ``perform_work`` one at a
    time, until ``workSource`` stops yielding.
    """
    global batch_size
    bucket = "novelailm-modtraindata"

    model, _config = init_config()

    gpu_totals = get_gpu_memory()
    logger.info(gpu_totals)
    single_large_gpu = len(gpu_totals) == 1 and gpu_totals[0] > 38000
    batch_size = 10 if single_large_gpu and "6B" in model_name else 1

    logger.info(f"batch_size set to {str(batch_size)}")
    logger.info(f"tokens_per_chunk set to {str(tokens_per_sample)}")
    logger.info(f"overlap_chunks set to {str(overlap_samples)}")

    for job_key in workSource(bucket=bucket):
        perform_work(model,
                     job_key,
                     bucket_name=bucket,
                     overlap_chunks=overlap_samples,
                     tokens_per_chunk=tokens_per_sample)


if __name__ == "__main__":
    # 'fork' lets the streaming process share the module-level token_queue
    # and inherit the signal handlers installed at import time.
    mp.set_start_method('fork', force=True)
    # asyncio.run creates, runs and closes a fresh event loop in the child.
    # It avoids the deprecated get_event_loop() call made without a running loop.
    streamer = Process(target=lambda: asyncio.run(streaming_ws()))
    streamer.start()
    try:
        main()
    finally:
        # First give the streamer a moment to flush events queued by the last
        # job, then ask it to stop. It inherited the SIGTERM handler, which
        # only clears its own ``running`` flag, so terminate() is a polite
        # request and kill() is the fallback. Without this, interpreter exit
        # may block indefinitely joining the non-daemon streamer process.
        streamer.join(timeout=5)
        if streamer.is_alive():
            streamer.terminate()
            streamer.join(timeout=5)
        if streamer.is_alive():
            streamer.kill()
            streamer.join()
