import sys
import pathlib

thisPath = str(pathlib.Path(__file__).parent.resolve())
sys.path.append(thisPath + "/transformerfork/src")

from transformers import (
    AutoTokenizer,
    GPTNeoForCausalLM,
)

import errno
import signal
from contextlib import contextmanager

import time
import logging
import json
import torch
import socket
from multiprocessing import Process, Queue
import multiprocessing as mp
import numpy as np
import boto3
import multitokenizer
import base64
import prefix
import asyncio
import websockets
import zlib
from nacl import pwhash, secret
import random
from transformers.optimization import Adafactor
import torch.nn as nn
from dotmap import DotMap

import sentry_sdk
from sentry_sdk import capture_exception, capture_message

from sentry_sdk.integrations.threading import ThreadingIntegration

# Process-wide "keep running" flag.  The signal handlers installed by
# addStopRunningHook set `running.value` to False, and the long-running
# loops in this file poll it so they can stop.  The flag lives in a DotMap so
# handlers mutate an attribute instead of rebinding the module-level name.
# NOTE(review): a `global` statement at module scope has no effect.
global running
running = DotMap({'value': True})


def addStopRunningHook(sig):
    """Register a handler for ``sig`` that asks the worker loops to stop.

    When the signal arrives the handler logs a warning naming the received
    signal and sets ``running.value`` to ``False``.

    Args:
        sig: signal number to hook (e.g. ``signal.SIGTERM``).
    """
    def _request_stop(received, _frame):
        logger.warning(f"Received signal {received}!")
        running.value = False

    signal.signal(sig, _request_stop)

# Request a graceful stop on both SIGTERM and SIGINT.
addStopRunningHook(signal.SIGTERM)
addStopRunningHook(signal.SIGINT)

# Advisory file-locking helpers.  Blocking calls such as `fcntl.lockf` are
# bounded elsewhere in this file by an alarm-signal based timeout.
try:
    # POSIX file locking (Linux, Ubuntu, macOS, etc.).
    #   Only writable files are locked; a read-only handle is left untouched,
    #   which might cause strange results for readers.
    import fcntl
    import os

    def lock_file(f):
        """Take an exclusive lock on ``f`` if it is writable."""
        if not f.writable():
            return
        fcntl.lockf(f, fcntl.LOCK_EX)

    def unlock_file(f):
        """Release the lock on ``f`` if it is writable."""
        if not f.writable():
            return
        fcntl.lockf(f, fcntl.LOCK_UN)
except ModuleNotFoundError:
    # Windows file locking
    import msvcrt
    import os

    def file_size(f):
        """Return the on-disk size of the file behind ``f``."""
        real_path = os.path.realpath(f.name)
        return os.path.getsize(real_path)

    def lock_file(f):
        """Lock the whole current extent of ``f``."""
        msvcrt.locking(f.fileno(), msvcrt.LK_RLCK, file_size(f))

    def unlock_file(f):
        """Unlock the whole current extent of ``f``."""
        msvcrt.locking(f.fileno(), msvcrt.LK_UNLCK, file_size(f))


@contextmanager
def timeout(seconds):
    """Bound the ``with`` body to ``seconds`` using SIGALRM.

    A temporary SIGALRM handler that raises ``InterruptedError`` is installed
    and an alarm is armed before the body runs.  On exit, whether normal or
    through an exception, the alarm is cancelled and the previously installed
    handler is put back.

    Args:
        seconds: alarm delay handed to ``signal.alarm``.
    """
    def _on_alarm(_signum, _frame):
        raise InterruptedError()

    previous_handler = signal.signal(signal.SIGALRM, _on_alarm)
    try:
        signal.alarm(seconds)
        yield
    finally:
        # Disarm first so a pending alarm cannot hit the restored handler.
        signal.alarm(0)
        signal.signal(signal.SIGALRM, previous_handler)


# Lock-holding replacement for `open`: construction opens the file and tries
# to take a lock on it within a time limit.  This opener *must* be used in a
# "with" block, and the value bound by `as` must be checked for `None`.
class AtomicLockFileWithTimeout:
    """Open ``path`` and hold a lock on it for the duration of a ``with``.

    WARNING: the lock is advisory; it only excludes other cooperating users
    of the same locking helpers.
    """

    def __init__(self, path, wait=3, *args, **kwargs):
        """Open ``path`` and try to lock it within ``wait`` seconds.

        Args:
            path: file to open and lock.
            wait: seconds to wait for the lock before giving up.
            *args, **kwargs: forwarded to the builtin ``open``.

        After construction ``self.file`` is the locked file object, or
        ``None`` when the attempt timed out.  Errors other than the timeout
        (for example a failing ``open``) propagate to the caller.
        """
        self.path = path
        self.file = None
        handle = None
        try:
            # The whole attempt sits inside the try so that an alarm firing
            # at any point (not only inside lock_file) counts as a timeout.
            with timeout(wait):
                handle = open(path, *args, **kwargs)
                print(f"Trying to lock {self.path}")
                lock_file(handle)
                print(f"Got lock on {self.path}")
        except InterruptedError:
            print(f"{path} lock attempt timed out")
            # Close explicitly rather than leaving the descriptor to the
            # garbage collector.
            if handle is not None:
                handle.close()
        except BaseException:
            # Do not leak the descriptor when locking fails for any other
            # reason; the error itself is still reported to the caller.
            if handle is not None:
                handle.close()
            raise
        else:
            self.file = handle

    def __enter__(self, *args, **kwargs):
        """Return the locked file object, or ``None`` if no lock was taken."""
        return self.file

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        """Flush, delete the lock file, release the lock and close.

        Returns ``True`` only when the body raised nothing, so an exception
        raised inside the ``with`` body is never suppressed.
        """
        if self.file is not None:
            try:
                # Flush to make sure all buffered contents reach the disk.
                self.file.flush()
                os.fsync(self.file.fileno())
                # The lock file is removed while the lock is still held.
                try:
                    os.remove(self.path)
                except FileNotFoundError:
                    pass
            finally:
                # Always release and close, even if flushing failed.
                try:
                    unlock_file(self.file)
                finally:
                    self.file.close()
        return exc_type is None


# NOTE(review): presumably the hidden size of the served model; it is not
# referenced anywhere else in this file view -- confirm before relying on it.
hidden_dim = 4096
# NOTE(review): the password and salt are hard-coded in source; consider
# loading them from the environment like the other settings below.
password = b'novelai_16YQDi0u8DDQLDCvTZJPYuVTcJNLP7MG'
salt = b'__novelai_salt__'
kdf = pwhash.argon2i.kdf
# NOTE(review): looks like leftover debug output.
print("test")

# no need for super secure hashing in this case
ops = pwhash.argon2i.OPSLIMIT_MIN
mem = pwhash.argon2i.MEMLIMIT_MIN

# Derive a SecretBox key from the password and salt with argon2i.
key = kdf(secret.SecretBox.KEY_SIZE, password, salt, opslimit=ops, memlimit=mem)
box = secret.SecretBox(key)

# Events put on this queue are read and sent out by streaming_ws.
token_queue = Queue()

# Required environment variables: a missing one raises KeyError at import.
model_path = os.environ['MODEL_PATH']
# Default batch size; main() overwrites it after inspecting the GPU.
batch_size = 1

# DEV == "True" switches the node to the staging environment and adds the
# "_dev" suffix to the queue name.
is_dev = ""
environment = "production"
if os.environ['DEV'] == "True":
    environment = "staging"
    is_dev = "_dev"

model_name = os.environ['MODEL']
# Example values: 2.7B, 6B, 6B-v3

# Directory that perform_work writes encrypted modules into.
modules_path = os.environ['MODULES']

version = "0.0.0.2smhs"
queue_name = "training_jobs_" + model_name + is_dev

# init sentry
# NOTE(review): the DSN is hard-coded in source.
sentry_sdk.init(
    "https://58fd055d1f1449328b5131a7dfddf4e4@o846434.ingest.sentry.io/5987512",
    server_name=socket.gethostname(),
    traces_sample_rate=0.1,
    environment=environment,
    integrations=[ThreadingIntegration(propagate_hub=True)],
)

# Setup logger: INFO level, written through a StreamHandler with a
# timestamp / level / file(pid) prefix.
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
fh = logging.StreamHandler()
fh_formatter = logging.Formatter(
    "%(asctime)s %(levelname)s %(filename)s(%(process)d) - %(message)s"
)
fh.setFormatter(fh_formatter)
logger.addHandler(fh)

logger.info("version: " + version)
logger.info("Node started")
logger.info("queue: " + queue_name)

# NOTE(review): `q` is only named in a `global` statement in this file view.
q = Queue()
# Tokenizer is loaded from the local "tokenizer/gpt2" directory only.
tokenizer = AutoTokenizer.from_pretrained("tokenizer/gpt2",
                                          local_files_only=True)
# S3 client pointed at the Wasabi endpoint, with credentials taken from the
# environment.
aws_access_key_id = os.environ['S3_ACCESS_KEY']
aws_secret_access_key = os.environ['S3_SECRET_KEY']
s3 = boto3.client('s3',
                  endpoint_url='https://s3.us-east-2.wasabisys.com',
                  aws_access_key_id=aws_access_key_id,
                  aws_secret_access_key=aws_secret_access_key,
                  region_name="us-east-2")

# Start of the model-load timing that is logged once the model is loaded.
load_time = time.time()


def no_init(loading_code):
    """Run ``loading_code`` with default parameter initialisation disabled.

    While ``loading_code`` runs, ``reset_parameters`` on ``torch.nn.Linear``,
    ``torch.nn.Embedding`` and ``torch.nn.LayerNorm`` is replaced by a no-op,
    so layers built inside it skip their default weight initialisation.

    Args:
        loading_code: zero-argument callable, typically one that constructs
            or loads a model.

    Returns:
        Whatever ``loading_code`` returns.

    The original ``reset_parameters`` methods are restored even when
    ``loading_code`` raises, so a failed load cannot leave torch patched.
    """
    def _skip_reset(self):
        return

    patched = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm)
    saved = {cls: cls.reset_parameters for cls in patched}
    for cls in patched:
        cls.reset_parameters = _skip_reset

    try:
        return loading_code()
    finally:
        for cls, original_reset in saved.items():
            cls.reset_parameters = original_reset


# Load the model once at import time, with default weight initialisation
# skipped via no_init.
model = no_init(lambda: GPTNeoForCausalLM.from_pretrained(model_path))

logger.info("Models loaded in " + str(time.time() - load_time) + "seconds")
# NOTE(review): the two counters below are only named in `global` statements
# in this file view and are never updated here.

total_sent = 0

processing_time = 0


def get_init_tokens(tokens, n_tokens):
    """Pick ``n_tokens`` frequent token ids to initialise a prefix from.

    Token ids are ranked by how often they occur in ``tokens``.  Walking the
    ranking from most to least frequent (at most the first 50256 entries),
    ids whose entry in ``wordtokens.json`` is positive are collected until
    ``n_tokens`` have been found.

    Args:
        tokens: 1-D array of non-negative integer token ids.
        n_tokens: number of ids wanted.

    Returns:
        The selected ids ordered from least to most frequent.  When fewer
        than ``n_tokens`` qualifying ids exist, the ``n_tokens`` most
        frequent ids are returned instead, regardless of their
        ``wordtokens.json`` entry (a NumPy array in that case).

    NOTE(review): assumes ``wordtokens.json`` holds a sequence indexable by
    token id -- confirm against the file.
    """
    hist = np.bincount(tokens)
    hist_s = np.argsort(hist)[::-1]

    with open("wordtokens.json", "r") as fh:
        word_tokens = json.load(fh)

    relevant = []
    # The histogram only has max(tokens) + 1 entries, so slicing (instead of
    # indexing up to 50256) keeps the scan inside the ranking.
    for t in hist_s[:50256]:
        if len(relevant) >= n_tokens:
            break
        if word_tokens[t] > 0:
            # print(t, tokenizer.decode([t])) # output these to the user somehow?
            relevant.append(t)

    if len(relevant) < n_tokens:
        return hist_s[0:n_tokens][::-1]

    return relevant[::-1]


def node_ok():
    """No-op placeholder; takes no arguments and always returns ``None``."""
    return None


def train(tokens, model_version, session_id, steps=3000, bs=10,
          tokens_per_sample=256, prefix_len=20):
    """Train a soft-prompt prefix for the global ``model`` on ``tokens``.

    The model's parameters are frozen, its word embedding is temporarily
    wrapped in a ``SoftEmbedding`` and only the learned prefix is optimised
    with Adafactor.  Progress events are put on ``token_queue`` every ten
    steps and on the last step.

    Args:
        tokens: 1-D ``np.uint16`` array of token ids holding more than one
            and fewer than 2048 * 1024 * 50 entries.
        model_version: value stored next to the embedding by
            ``prefix.encode_prefix``.
        session_id: identifier echoed in the progress events.
        steps: number of optimizer steps (must be at least 1).
        bs: samples per batch (must be at least 1).
        tokens_per_sample: tokens per training sample (must be at least 1).
        prefix_len: number of learned prefix embeddings.

    Returns:
        ``{"ok": False, "reason": "invalid input"}`` for rejected arguments,
        ``{"ok": False, "reason": "shutdown"}`` when ``running.value`` went
        False mid-training, otherwise ``{"ok": True, "encoded_embedding",
        "step", "loss"}`` where ``step`` is the index of the last training
        step and ``loss`` is the mean loss over all steps.

    The original word embedding is put back and the model returned to eval
    mode on every exit path, including shutdown and exceptions.
    """
    if not (isinstance(tokens, np.ndarray)
            and tokens.dtype in [np.uint16]
            and len(tokens.shape) == 1
            and 1 < tokens.shape[0] < 2048 * 1024 * 50):
        return {"ok": False, "reason": "invalid input"}
    # Zero or negative values would divide by zero further down.
    if steps < 1 or bs < 1 or tokens_per_sample < 1:
        return {"ok": False, "reason": "invalid input"}

    init_tokens = get_init_tokens(tokens, prefix_len)
    if os.environ['INIT_NEWLINE'] == "1":
        init_tokens[0] = 198

    # take care of small token chunks: the tail that does not fill a whole
    # sample becomes its own single-sample batch.
    tiny_size = tokens.shape[0] % tokens_per_sample
    if tiny_size > 0:
        batches = [
            torch.tensor(tokens[-tiny_size:].astype(np.int32)).long().unsqueeze(
                0)]
        tokens = tokens[:-tiny_size]
    else:
        batches = []

    # shuffle samples, then group them into batches of up to `bs`; only the
    # last batch may be smaller.
    n_samples = tokens.shape[0] // tokens_per_sample
    indexes = list(range(n_samples))
    random.shuffle(indexes)
    while indexes:
        batch = []
        for _ in range(min(bs, len(indexes))):
            k = indexes.pop()
            batch.append(
                torch.tensor(tokens[k * tokens_per_sample:
                                    (k + 1) * tokens_per_sample].astype(
                    np.int32)).long())
        batches.append(torch.stack(batch))
    del tokens

    # prep model: freeze everything, then swap in the trainable prefix.
    model.train()
    for param in model.parameters():
        param.requires_grad = False

    old_wte = model.transformer.wte
    s_wte = SoftEmbedding(old_wte, n_tokens=prefix_len,
                          initialize_from_vocab=True,
                          init_tokens=init_tokens).to("cuda")
    model.transformer.wte = s_wte
    params = [model.transformer.wte.learned_embedding]
    optimizer = Adafactor(params=params)

    try:
        indexes = []
        n_batches = len(batches)
        avg_loss = 0.0
        total_iter_time = 0.0
        for i in range(steps):
            if not running.value:
                return {"ok": False,
                        "reason": "shutdown"}

            curr_time = time.perf_counter()
            # shuffle batches: draw a fresh random order whenever the
            # previous one is used up.
            if len(indexes) < 1:
                indexes = list(range(n_batches))
                random.shuffle(indexes)
            batch = batches[indexes.pop()]

            # train
            optimizer.zero_grad()

            # The first prefix_len positions are placeholders (id 50256);
            # SoftEmbedding.forward drops them and inserts the learned
            # prefix.  Their labels are -100, presumably the loss's ignore
            # index -- confirm against the model implementation.
            inputs = {}
            inputs['input_ids'] = torch.cat(
                [torch.full((batch.shape[0], prefix_len), 50256), batch],
                1).cuda()
            inputs['attention_mask'] = torch.full(
                (batch.shape[0], prefix_len + batch.shape[1]), 1).cuda()
            labels = torch.cat(
                [torch.full((batch.shape[0], prefix_len), -100), batch],
                1).cuda()

            output = model(**inputs, labels=labels)
            del labels
            del inputs['input_ids']
            del inputs['attention_mask']

            loss = output.loss
            loss.backward()
            optimizer.step()
            avg_loss += loss.detach().cpu().item()
            total_iter_time += time.perf_counter() - curr_time
            if i % 10 == 0 or i == steps - 1:
                token_queue.put({"event": "training_update",
                                 "data": {"id": session_id,
                                          "status": "training",
                                          "data": json.dumps(
                                              {"step": i,
                                               "loss": avg_loss / float(i + 1),
                                               "percentage":
                                                   (i / steps) * 100})}})

            if i % 50 == 0 and i != 0:
                logger.info("step: " + str(i) + ", step_iter: " + str(
                    1 / (float(total_iter_time) / i)) + ", contexts/s: " + str(
                    (1 / (float(total_iter_time) / i)) * bs))

        encoded_embedding = prefix.encode_prefix(
            {"embs": model.transformer.wte.learned_embedding.data,
             "model_version": model_version})
    finally:
        # unprep model: runs on success, shutdown and errors alike so the
        # shared model never keeps the soft embedding.
        batches.clear()
        model.transformer.wte = old_wte
        s_wte.wte = None
        del s_wte.learned_embedding
        del s_wte
        del optimizer
        model.eval()

    return {"ok": True, "encoded_embedding": encoded_embedding,
            "step": steps - 1,
            "loss": avg_loss / float(steps)}


class SoftEmbedding(nn.Module):
    """Word embedding wrapper that prepends a trainable soft prompt."""

    def __init__(self,
                 wte: nn.Embedding,
                 n_tokens: int = 10,
                 random_range: float = 0.5,
                 initialize_from_vocab: bool = True, init_tokens=None):
        """Wrap ``wte`` and create the learned prefix embedding.

        Args:
            wte (nn.Embedding): original transformer word embedding
            n_tokens (int, optional): number of learned prefix positions.
                Defaults to 10.
            random_range (float, optional): half-width of the uniform range
                used when not initializing from the vocabulary.
                Defaults to 0.5.
            initialize_from_vocab (bool, optional): copy rows of ``wte`` as
                the starting point. Defaults to True.
            init_tokens (optional): token ids whose embeddings seed the
                prefix; when None the first ``n_tokens`` rows are used.
        """
        super(SoftEmbedding, self).__init__()
        self.wte = wte
        self.n_tokens = n_tokens
        self.learned_embedding = nn.parameter.Parameter(
            self.initialize_embedding(wte,
                                      n_tokens,
                                      random_range,
                                      initialize_from_vocab, init_tokens))

    def initialize_embedding(self,
                             wte: nn.Embedding,
                             n_tokens: int = 10,
                             random_range: float = 0.5,
                             initialize_from_vocab: bool = True,
                             init_tokens=None):
        """Build the initial value of the learned prefix.

        Args:
            same as __init__
        Returns:
            torch.float: detached tensor with one row per prefix position and
            one column per embedding dimension.
        """
        if initialize_from_vocab and init_tokens is not None:
            return wte.weight[init_tokens].clone().detach()
        if initialize_from_vocab:
            return wte.weight[:n_tokens].clone().detach()
        # One row per prefix token, matching the vocabulary-based branches
        # and the layout forward() concatenates along the sequence axis.
        return torch.FloatTensor(n_tokens, wte.weight.size(1)).uniform_(
            -random_range, random_range)

    def forward(self, tokens):
        """Embed ``tokens`` with the learned prefix in front.

        The first ``n_tokens`` positions of ``tokens`` are treated as
        placeholders: they are dropped and replaced by the learned prefix.

        Args:
            tokens (torch.long): input ids, indexed as (batch, sequence)
        Returns:
            torch.float: learned prefix embeddings concatenated with the
            embeddings of the remaining tokens along dimension 1
        """
        input_embedding = self.wte(tokens[:, self.n_tokens:])
        learned_embedding = self.learned_embedding.repeat(
            input_embedding.size(0), 1, 1)
        return torch.cat([learned_embedding, input_embedding], 1)


async def streaming_ws():
    """Forward events from ``token_queue`` to the node-pipe websocket.

    Each event is JSON-encoded and sent over its own websocket connection.
    The queue is polled with a short timeout so that ``running.value`` is
    re-checked regularly, and the connection is only opened once there is an
    event to send.  The original opened the socket first and then blocked the
    event loop in ``token_queue.get()``, so the idle connection could not be
    serviced while waiting.

    Failures are logged and reported to Sentry; the affected event is not
    retried.
    """
    # Local import: only the Empty exception of the stdlib module is needed.
    import queue

    # uri = "ws://localhost:8765"
    if os.environ['DEV'] == "True":
        uri = "wss://staging.novelai.net/ai/internal/node-pipe"
    else:
        uri = "wss://api.novelai.net/ai/internal/node-pipe"

    print("Started streaming thread")
    while running.value:
        try:
            try:
                event = token_queue.get(timeout=1.0)
            except queue.Empty:
                # Nothing to send; loop round to re-check the stop flag.
                continue

            message = json.dumps(event)
            async with websockets.connect(uri) as websocket:
                await websocket.send(message)

        except Exception as e:
            logger.error(e)
            capture_exception(e)
    logger.warning("Websocket streamer exiting!")


import subprocess as sp
import os


def get_gpu_memory():
    """Return the total memory of each GPU as reported by ``nvidia-smi``.

    Runs ``nvidia-smi --query-gpu=memory.total --format=csv``, skips the CSV
    header row and parses the leading integer of every remaining row.  The
    resulting list is printed and returned, one entry per reported GPU.
    """
    raw = sp.check_output(
        "nvidia-smi --query-gpu=memory.total --format=csv".split())
    # Drop the piece after the final newline, then the header row.
    rows = raw.decode('ascii').split('\n')[:-1]
    totals = []
    for row in rows[1:]:
        totals.append(int(row.split()[0]))
    print(totals)
    return totals


def getTimestamp(d):
    """Sort key for S3 listing entries: the timestamp field of ``d['Key']``.

    The key is split on ``'_'``; the timestamp is field 3 in the staging
    environment and field 2 otherwise.  The field is returned as a string
    wrapped in a one-element tuple.
    """
    timestamp_loc = 3 if environment == "staging" else 2
    fields = d['Key'].split('_')
    return (fields[timestamp_loc],)


def _post_training_update(session_id, status, payload):
    """Queue a ``training_update`` event carrying ``payload`` as JSON."""
    token_queue.put({"event": "training_update",
                     "data": {"id": session_id, "status": status,
                              "data": json.dumps(payload)}})


def perform_work(data_key, bucket_name):
    """Fetch one training job from S3, train on it and publish the result.

    Args:
        data_key: S3 object key of the job.  The session id is the third
            ``'_'``-separated field for keys starting with ``dev`` and the
            second field otherwise.
        bucket_name: bucket that holds ``data_key``.

    The job object is JSON with ``steps`` and ``data`` entries, where
    ``data`` is base64-encoded raw-deflate text.  Outcomes:

    * fetch failure: an ``error`` event is queued and the object is kept;
    * undecodable payload: an ``error`` event is queued and the object is
      deleted;
    * training succeeded: the encrypted module is written under
      ``modules_path``, a ``ready`` event is queued and the object deleted;
    * shutdown mid-training: only a warning is logged, the object is kept;
    * any other training failure: an ``error`` event is queued and the
      object is deleted.
    """
    logger.info("Got 1 request")
    logger.info(f"GOT KEY: {data_key}")
    split_list = data_key.split('_')
    session_id = split_list[2] if split_list[0] == "dev" else split_list[1]

    try:
        data = s3.get_object(Bucket=bucket_name, Key=data_key)[
            "Body"].read().decode("utf-8")
    except Exception as e:
        _post_training_update(
            session_id, "error",
            {"message": "unable to fetch training data from shared storage"})
        logger.error(e)
        capture_exception(e)
        return

    # Decode the job: JSON envelope, base64 payload, raw deflate (wbits=-15),
    # then normalise line endings.  Any malformed layer is reported the same
    # way instead of crashing the node on a job it would pick up again.
    try:
        data = json.loads(data)
        steps = data["steps"]
        train_text = base64.b64decode(data["data"])
        train_text = zlib.decompress(train_text, -15).decode().replace(
            "\r\n", "\n").replace("\r", "\n")
    except (zlib.error, ValueError, KeyError, TypeError) as e:
        _post_training_update(
            session_id, "error",
            {"message": "invalid training data archive"})
        logger.error(e)
        capture_exception(e)
        s3.delete_object(Bucket=bucket_name,
                         Key=data_key)
        return

    # Tokenize in parallel, one chunk per CPU.
    n_chunk = int(os.cpu_count())
    chunks = multitokenizer.chunkit(train_text, n_chunk)
    iterx = range(n_chunk)
    tokens_array = multitokenizer.run(multitokenizer.tokenizePart, iterx,
                                      chunks)
    # NOTE(review): the trailing [0] presumably strips a leading axis from
    # the stacked tokenizer output -- confirm against multitokenizer.
    tokens_out = np.hstack(tokens_array).astype(np.uint16)[0]
    logger.info("Tokenized, starting training...")
    logger.info(f"Training request with id {session_id}, steps:{str(steps)}")

    result = train(tokens=tokens_out, model_version=5, session_id=session_id,
                   steps=int(steps // batch_size), bs=batch_size)

    if result["ok"]:
        encrypted, prefix_id = prefix.self_encrypt_prefix(
            result["encoded_embedding"])
        with open(modules_path + "/" + prefix_id, "wb") as f:
            f.write(encrypted)
        _post_training_update(
            session_id, "ready",
            {"encoded_emb": result["encoded_embedding"],
             "loss": result["loss"]})
        s3.delete_object(Bucket=bucket_name,
                         Key=data_key)
        logger.info("Training done!")
    elif result["reason"] == "shutdown":
        # The job object is kept in the bucket.
        logger.warning("Shutting down mid-job!")
    else:
        reason = result["reason"]
        fail_str = f"failed by result ok false, sending fail event {reason}"
        capture_message(fail_str)
        logger.error(fail_str)
        _post_training_update(
            session_id, "error",
            {"message": "unknown error while training"})
        s3.delete_object(Bucket=bucket_name,
                         Key=data_key)


def workSource(lockstore="/locks", bucket="novelailm-modtraindata"):
    """Yield S3 keys of training jobs this node may work on.

    The bucket is listed repeatedly while ``running.value`` is true.  A key
    qualifies when it matches the current environment (``dev_`` prefix for
    staging, no such prefix for production), contains ``model_name`` and
    contains neither ``.lock`` nor ``training_data``.  Candidates are taken
    in ascending timestamp order, and a key is yielded only after its lock
    file under ``lockstore`` has been locked.  The lock is held while the
    consumer processes the key and is released when the generator resumes.

    Args:
        lockstore: directory holding the per-key lock files.
        bucket: S3 bucket to poll.

    Yields:
        str: key of a job whose lock this node currently holds.

    NOTE(review): the listing can be stale by the time a later key from the
    same pass is yielded; confirm the consumer tolerates a missing object.
    """
    while running.value:
        try:
            # An empty bucket has no "Contents" entry in the response.
            object_list = s3.list_objects(Bucket=bucket).get("Contents", [])
        except Exception as e:
            # Report the failure and back off instead of retrying at once.
            logger.error(e)
            time.sleep(1)
            continue

        sanitized = []
        for entry in object_list:
            key = entry["Key"]
            has_dev_prefix = key.startswith('dev_')
            env_matches = (environment == "staging" and has_dev_prefix) or \
                          (environment == "production"
                           and not has_dev_prefix)
            if env_matches \
                    and model_name in key \
                    and ".lock" not in key \
                    and "training_data" not in key:
                sanitized.append(entry)

        if not sanitized:
            # Nothing to do: pause so an idle node does not hammer S3.
            time.sleep(1)
            continue

        training_list = sorted(sanitized,
                               key=getTimestamp,
                               reverse=False)

        for x in training_list:
            if not running.value:
                # Do not start another job once a stop was requested.
                break
            key = x["Key"]
            lockfile_path = os.path.join(lockstore, key + ".lock")

            with AtomicLockFileWithTimeout(lockfile_path, 3, 'w') as lockfile:
                if lockfile is None:
                    # The lock attempt timed out, so the job is skipped for
                    # this pass.
                    time.sleep(1)
                    continue
                else:
                    yield key

    logger.warning("workSource exiting!")


def main():
    """Choose the batch size from the GPU, then process jobs until stopped.

    ``batch_size`` becomes 10 only when exactly one GPU is reported, its
    total memory exceeds 38000 and ``model_name`` contains ``"6B"``; in every
    other case, including an unexpected GPU count, it is 1.
    """
    global batch_size
    total_gpu_mem = get_gpu_memory()
    logger.info(total_gpu_mem)

    single_large_gpu = len(total_gpu_mem) == 1 and total_gpu_mem[0] > 38000
    batch_size = 10 if single_large_gpu and "6B" in model_name else 1

    logger.info(f"batch_size set to {str(batch_size)}")

    jobs = workSource(bucket="novelailm-modtraindata")
    for work in jobs:
        perform_work(work, bucket_name="novelailm-modtraindata")


if __name__ == "__main__":
    # Force the 'fork' start method, presumably so the child inherits the
    # module-level state (queue, logger, signal handlers) -- confirm.
    mp.set_start_method('fork', force=True)
    # Despite its name, `thread2` is a separate process; it runs the
    # websocket streamer on its own event loop.
    thread2 = Process(
        target=lambda: asyncio.get_event_loop().run_until_complete(
            streaming_ws()))
    thread2.start()
    # The job loop runs in the parent process.
    main()
