import asyncio
import base64
import functools
import gc
import json
import logging
import os
import random
import struct
import time
from threading import Lock
from typing import Any, Union, Sequence, Generator, List

import numpy as np
import pynvml
import torch
from torch import autocast
from torch.cuda.amp import GradScaler
from torch.cuda.amp.grad_scaler import OptState

import transformers
from transformers import AutoTokenizer, Adafactor
from tokenizers import Tokenizer

from lm_node import models, unitrim, prefix
from lm_node.models.GPT import SoftEmbedding
from lm_node.utils.chunk import split_chunks

curr_path = os.path.dirname(__file__)


class Model:
    """Base class holding the state shared by all model wrappers.

    ``config`` must offer a ``get(key[, default])`` method. Any of the
    optional entries ``logger``, ``model``, ``tokenizer``, ``unitrim`` and
    ``word_tokens`` found in it are adopted as-is; missing ones default to
    ``None`` (the logger falls back to this module's logger).

    The inference entry points defined here only raise
    ``NotImplementedError``; subclasses override the ones they support.
    """

    def __init__(self, config):
        fallback_logger = logging.getLogger(__name__)
        self.config = config
        self.logger = config.get('logger', fallback_logger)

        # Components that may be injected through the config.
        self.model: Union[
            transformers.modeling_utils.PreTrainedModel, None
        ] = config.get("model", None)
        self.tokenizer: Union[Tokenizer, None] = config.get("tokenizer", None)
        self.unitrim: Union[unitrim.Unitrimmer, None] = config.get(
            "unitrim", None)
        self.word_tokens: Union[List[int], None] = config.get("word_tokens")

        # Components that subclasses fill in themselves.
        self.clip = None
        self.declipper = None
        self.modules = None

        # Lock available to subclasses for serialising access to the model.
        self.mutex = Lock()
        # Flag polled by long-running work so it can stop early.
        self.running = True

    def generate(self, ids, req_params, use_string, generation_id=None):
        """Text generation hook; not implemented in the base class."""
        raise NotImplementedError

    def generate_image(self):
        """Image generation hook; not implemented in the base class."""
        raise NotImplementedError

    def get_hidden_states(self, ids):
        """Hidden-state extraction hook; not implemented in the base class."""
        raise NotImplementedError

    def next_token_probabilities(self, ids):
        """Next-token probability hook; not implemented in the base class."""
        raise NotImplementedError


class GPTModel(Model):
    def __init__(self, config):
        """Load (or adopt) the language model and its supporting data.

        ``config`` must support ``get(key, default)`` as well as attribute
        access: ``model_path`` is read from it, and ``n_layer`` and
        ``hidden_dim`` are written back onto it from the model's own
        configuration.

        The model, tokenizer, unitrim data and word-token table are each
        taken from ``config`` when provided. Otherwise they are loaded from
        ``config.model_path``, with a bundled/default ``gpt2`` fallback for
        the tokenizer, unitrim data and word-token table.

        Raises:
            RuntimeError: ``"ModelNotAvailable"`` when no model was provided
                and ``models`` exposes no ``GPT`` attribute.
        """
        super().__init__(config)
        pynvml.nvmlInit()
        cudadev = torch.cuda.current_device()
        self.nvml_device = pynvml.nvmlDeviceGetHandleByIndex(cudadev)
        # Report memory for the same device whose name is logged below,
        # rather than always querying device 0.
        total_mb = int(
            torch.cuda.get_device_properties(cudadev).total_memory /
            (1000 * 1000))
        self.logger.info("GPU: " + torch.cuda.get_device_name(cudadev))
        self.logger.info("GPU RAM: " + '{:,}'.format(total_mb) + "mb")

        if self.model is None:
            # Device memory already in use before the model is loaded.
            self.kernel_memory = pynvml.nvmlDeviceGetMemoryInfo(
                self.nvml_device).used
            if not hasattr(models, "GPT"):
                raise RuntimeError("ModelNotAvailable")
            self.logger.info(self.getGPUram())
            self.logger.info("Loading model '" + self.config.model_path + "'")
            self.model = models.GPT._init_model(self.config)
            self.logger.info("Model done loading: {0}".format(self.getGPUram()))
        else:
            self.logger.info("Using provided model.")

        model_config = self.model.config.to_dict()
        # The layer count is stored under either `n_layer` or `num_layers`.
        self.config.n_layer = model_config.get('n_layer',
                                               model_config.get('num_layers'))
        self.config.hidden_dim = model_config['hidden_size']

        if self.tokenizer is None:
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    config.model_path, cache_dir="./cache")
                self.logger.info(
                    f"Using tokenizer data from {config.model_path}.")
            except Exception:
                # Any loading failure falls back to the stock tokenizer, but
                # KeyboardInterrupt / SystemExit are no longer swallowed.
                self.logger.info("Falling back to default `gpt2` tokenizer.")
                self.tokenizer = AutoTokenizer.from_pretrained(
                    "gpt2", cache_dir="./cache")
        else:
            self.logger.info("Using provided tokenizer.")

        if self.unitrim is None:
            unitrimModelPath = os.path.join(self.config.model_path,
                                            "unitrim.json")
            if os.path.exists(unitrimModelPath):
                self.logger.info(
                    f"Loading unitrim data from {unitrimModelPath}")
                self.unitrim = unitrim.Unitrimmer(unitrimModelPath)
            else:
                self.logger.info("Loading unitrim data from default (gpt2)")
                self.unitrim = unitrim.Unitrimmer()
        else:
            self.logger.info("Using provided unitrim.")

        if self.word_tokens is None:
            word_tokens_path = os.path.join(self.config.model_path,
                                            "wordtokens.json")
            if os.path.exists(word_tokens_path):
                self.logger.info(
                    f"Loading word_tokens data from {word_tokens_path}")
            else:
                # Fall back to the table shipped next to this module.
                word_tokens_path = os.path.join(curr_path, "wordtokens.json")
                self.logger.info("Loading word_tokens data from default (gpt2)")
            with open(word_tokens_path, 'r') as word_tokens_file:
                self.word_tokens = json.load(word_tokens_file)
        else:
            self.logger.info("Using provided wordtokens.")

        self.eot_token = self.tokenizer.encode("<|endoftext|>")[0]

        self.modules = models.GPT._init_modules(self.model.device,
                                                self.model.dtype,
                                                self.config)
        gc.collect()
        torch.cuda.synchronize()
        self.logger.info("Modules done loading: {0}".format(self.getGPUram()))
        self.hypernets = models.GPT._init_hypernet(self.model.device,
                                                   self.model.dtype,
                                                   self.config)
        gc.collect()
        torch.cuda.synchronize()
        self.logger.info("Hypernets done loading: {0}".format(self.getGPUram()))

        if isinstance(self.model, transformers.GPTNeoXForCausalLM):
            self.tokenizer_id = "pile"
        else:
            self.tokenizer_id = "gpt2"


    def getGPUram(self):
        """Return a one-line, human-readable GPU memory summary.

        The first half reports used / free / total memory as seen by NVML
        for ``self.nvml_device``; the second half reports torch's reserved
        and allocated memory, each as ``current/peak``. All figures are
        bytes divided by 1e6 and truncated to whole numbers.
        """
        def to_mb(n_bytes):
            return int(n_bytes / 1E6)

        nvml_mem = pynvml.nvmlDeviceGetMemoryInfo(self.nvml_device)
        device_part = "gpu: (U: {:,}mb F: {:,}mb T: {:,}mb) ".format(
            to_mb(nvml_mem.used),
            to_mb(nvml_mem.free),
            to_mb(nvml_mem.total))

        torch_part = "torch: (R: {:,}mb/{:,}mb, A: {:,}mb/{:,}mb)".format(
            to_mb(torch.cuda.memory.memory_reserved()),
            to_mb(torch.cuda.memory.max_memory_reserved()),
            to_mb(torch.cuda.memory_allocated()),
            to_mb(torch.cuda.max_memory_allocated()))

        return device_part + torch_part

    @functools.lru_cache(maxsize=1024)
    def packB64(self, token):
        packed_token = struct.pack("<H", token)
        return base64.b64encode(packed_token).decode("utf-8")

    def generate(self, ids, req_params, use_string, generation_id=None):
        """Run a generation request while holding the model mutex.

        Args:
            ids: prompt passed through to the ``models.GPT`` generator.
            req_params: request parameters passed through unchanged.
            use_string: passed through unchanged.
            generation_id: when truthy, the request is run through
                ``models.GPT._generate_stream`` on a private event loop and
                nothing is returned; otherwise ``models.GPT._generate`` is
                called and its result returned.

        Returns:
            The output of ``models.GPT._generate`` for non-streamed
            requests, ``None`` for streamed ones.

        Raises:
            RuntimeError: ``"ModelNotLoaded"`` / ``"TokenizerNotLoaded"``
                when the corresponding component is missing.
        """
        if self.model is None:
            raise RuntimeError("ModelNotLoaded")

        if self.tokenizer is None:
            raise RuntimeError("TokenizerNotLoaded")

        gc.collect()
        torch.cuda.synchronize()

        with self.mutex:
            if not generation_id:
                return models.GPT._generate(self,
                                            self.model,
                                            self.tokenizer,
                                            ids,
                                            req_params,
                                            use_string)

            # Each streamed request gets its own event loop. It must be
            # closed afterwards, otherwise every request leaks the loop's
            # selector and self-pipe file descriptors.
            loop = asyncio.new_event_loop()
            try:
                asyncio.set_event_loop(loop)
                loop.run_until_complete(
                    models.GPT._generate_stream(self,
                                                self.model,
                                                self.tokenizer,
                                                ids,
                                                req_params,
                                                use_string,
                                                generation_id))
            finally:
                # Do not leave a closed loop installed as the current one.
                asyncio.set_event_loop(None)
                loop.close()
            return None

    def get_hidden_states(self, ids):
        """Return ``models.GPT._get_hidden_states`` for ``ids``.

        The call is made while holding ``self.mutex``.
        """
        with self.mutex:
            hidden_states = models.GPT._get_hidden_states(self.model, ids)
        return hidden_states

    def get_token_probabilities(self, ids, n=0, idx=0):
        """Return ``models.GPT._token_probabilities`` for ``ids``.

        ``n`` and ``idx`` are forwarded unchanged as keyword arguments. The
        call is made while holding ``self.mutex``.
        """
        with self.mutex:
            probabilities = models.GPT._token_probabilities(
                self.model, ids, n=n, idx=idx)
        return probabilities

    def next_token_probabilities(self, ids):
        """Return ``models.GPT._next_token_probabilities`` for ``ids``.

        The model and tokenizer of this instance are passed along. The call
        is made while holding ``self.mutex``.
        """
        with self.mutex:
            probabilities = models.GPT._next_token_probabilities(
                self.model, self.tokenizer, ids)
        return probabilities

    def split_chunks(self, text_seq: Union[Sequence[str],
                                           Generator[str, Any, None]],
                     size=2048, yield_tokens=True, boundary: int = 198):
        """Delegate to the module-level ``split_chunks`` helper.

        All arguments are forwarded unchanged, together with this
        instance's tokenizer and unitrim data.

        Args:
            text_seq: a sequence or generator of strings.
            size: forwarded as the chunk size argument.
            yield_tokens: forwarded unchanged.
            boundary: forwarded unchanged (presumably the token id to split
                on; 198 looks like the gpt2 newline token, to be confirmed
                against the helper).

        Returns:
            Whatever the ``split_chunks`` helper returns.
        """
        return split_chunks(text_seq, size, yield_tokens, boundary,
                            tokenizer=self.tokenizer,
                            unitrim=self.unitrim)

    def get_init_tokens(self, tokens, n_tokens):
        hist = np.bincount(tokens)
        hist_s = np.argsort(hist)[::-1]

        got = 0
        i = 0
        relevant = []
        while got < n_tokens and i < len(self.word_tokens):
            t = hist_s[i]
            if self.word_tokens[t] > 0:
                got += 1
                relevant.append(t)
            i += 1

        if got < n_tokens:
            return hist_s[0:n_tokens][::-1]

        init_tokens = relevant[::-1]

        if os.getenv('INIT_NEWLINE', None) == "1":
            init_tokens[0] = self.tokenizer.encode("\n")[0]

        return init_tokens

    def train(self, tokens, model_version, session_id, steps=3000, bs=10,
              tokens_per_chunk=256, prefix_len=20, report_steps=10,
              steps_arr=None, seed=None):
        if not (isinstance(tokens, np.ndarray) and tokens.dtype in [
            np.uint16] and len(tokens.shape) == 1 and tokens.shape[
                    0] < 2048 * 1024 * 50 and tokens.shape[0] > 1):
            return {"ok": False, "reason": "invalid input"}
        init_tokens = self.get_init_tokens(tokens, prefix_len)

        # take care of small token chunks
        tiny_size = tokens.shape[0] % tokens_per_chunk
        if tiny_size > 0:
            batches = [
                torch.tensor(
                    tokens[-tiny_size:].astype(np.int32)).long().unsqueeze(
                    0)]
            tokens = tokens[:-tiny_size]
        else:
            batches = []

        def appendBatch(tokens, k):
            batch.append(
                torch.tensor(tokens[k * tokens_per_chunk:
                                    (k + 1) * tokens_per_chunk].astype(
                    np.int32)).long())

        # shuffle samples
        if seed is not None:
            random.seed(seed)
        n_samples = tokens.shape[0] // tokens_per_chunk
        if n_samples > 0:
            indexes = list(range(n_samples))
            random.shuffle(indexes)

            n_big_batches = n_samples // bs
            n_partial_batch = n_samples % bs

            for i in range(n_big_batches):
                batch = []
                for j in range(bs):
                    k = indexes.pop()
                    appendBatch(tokens, k)

                batches.append(torch.stack(batch))

            batch = []
            for i in range(n_partial_batch):
                k = indexes.pop()
                appendBatch(tokens, k)
            if n_partial_batch > 0:
                batches.append(torch.stack(batch))
        del tokens
        gc.collect()

        # prep model
        self.model.train()
        for param in self.model.parameters():
            param.requires_grad = False

        old_wte = self.model.transformer.wte
        s_wte = SoftEmbedding(old_wte, n_tokens=prefix_len,
                              initialize_from_vocab=True,
                              init_tokens=init_tokens).to("cuda")
        self.model.transformer.wte = s_wte
        params = [self.model.transformer.wte.learned_embedding]
        optimizer = Adafactor(params=params)

        indexes = []
        finished = []
        n_batches = len(batches)
        avg_loss = 0.0
        total_iter_time = 0.0
        losses = []
        # Scale loss
        scaler = GradScaler()

        #
        # Monkey patched version of scaler/unscale_ to allow fp16. `optimizer`
        # and `scaler` are closured from the context.
        #
        def step(*args, **kwargs):
            if (not scaler._enabled):
                return optimizer.step(*args, **kwargs)

            if "closure" in kwargs:
                raise RuntimeError(
                    "Closure use is not currently supported if GradScaler is"
                    "enabled.")

            scaler._check_scale_growth_tracker("step")

            optimizer_state = scaler._per_optimizer_states[id(optimizer)]

            if optimizer_state["stage"] is OptState.STEPPED:
                raise RuntimeError(
                    "step() has already been called since the last update().")

            if (hasattr(optimizer,
                        "_step_supports_amp_scaling") and
                    optimizer._step_supports_amp_scaling):
                retval = optimizer.step(*args, **dict(kwargs, grad_scaler=self))
                optimizer_state["stage"] = OptState.STEPPED
                return retval

            if optimizer_state["stage"] is OptState.READY:
                unscale_()

            assert len(optimizer_state[
                           "found_inf_per_device"]) > 0, \
                "No inf checks were recorded for this optimizer."

            retval = scaler._maybe_opt_step(optimizer, optimizer_state, *args,
                                          **kwargs)

            optimizer_state["stage"] = OptState.STEPPED

            return retval

        def unscale_():
            scaler._check_scale_growth_tracker("unscale_")

            optimizer_state = scaler._per_optimizer_states[id(optimizer)]

            if optimizer_state["stage"] is OptState.UNSCALED:
                raise RuntimeError(
                    "unscale_() has already been called on this optimizer"
                    "since the last update().")
            elif optimizer_state["stage"] is OptState.STEPPED:
                raise RuntimeError("unscale_() is being called after step().")

            # FP32 division can be imprecise for certain compile options, so we
            # carry out the reciprocal in FP64.
            assert scaler._scale is not None
            inv_scale = scaler._scale.double().reciprocal().float()
            found_inf = torch.full((1,), 0.0, dtype=torch.float32,
                                   device=scaler._scale.device)

            optimizer_state["found_inf_per_device"] = scaler._unscale_grads_(
                optimizer, inv_scale, found_inf, True)
            optimizer_state["stage"] = OptState.UNSCALED

        #
        # End monkey patch
        #
        last_reported_embedding = None
        last_reported_step = 0

        for i in range(steps):
            if not self.running:
                return {"ok": False,
                        "reason": "shutdown"}

            curr_time = time.perf_counter()
            # shuffle batches
            if len(indexes) < 1:
                indexes = list(range(n_batches))
                random.shuffle(indexes)
            j = indexes.pop()
            batch = batches[j]
            # train
            optimizer.zero_grad()

            inputs = {}
            inputs['input_ids'] = torch.cat(
                [torch.full((batch.shape[0], prefix_len), self.eot_token),
                 batch],
                1).cuda()
            inputs['attention_mask'] = torch.full(
                (batch.shape[0], prefix_len + batch.shape[1]), 1).cuda()
            labels = torch.cat(
                [torch.full((batch.shape[0], prefix_len), -100), batch],
                1).cuda()

            output = self.model(**inputs, labels=labels)

            loss = output.loss
            loss.backward()
            optimizer.step()
            #scaler.scale(loss).backward()

            #with autocast("cuda"):
            #    step()
            #    scaler.update()

            del labels
            del inputs['input_ids']
            del inputs['attention_mask']

            curr_loss = loss.detach().cpu().item()

            def data_report(status="training"):
                return {"id": session_id,
                        "status": status,
                        "data": json.dumps(
                            {"step": i,
                             "loss": avg_loss / float(i + 1),
                             "losses": losses,
                             "curr_loss": curr_loss,
                             "percentage": (i / steps) * 100})}

            if np.isnan(curr_loss):
                # unprep model
                self.model.transformer.wte = old_wte
                s_wte.wte = None
                del s_wte.learned_embedding
                del s_wte
                del optimizer
                self.model.eval()
                indexes.reverse()
                yield {"event": "loss_nan",
                       "ok": False,
                       "last_reported_step": last_reported_step,
                       "remaining_steps": indexes,
                       "seed": seed,
                       "data": data_report("failed")}
            avg_loss += curr_loss
            losses.append(curr_loss)
            total_iter_time += time.perf_counter() - curr_time
            if (i % report_steps == 0 and i != 0) or i == steps - 1:
                encoded_embedding = prefix.encode_prefix(
                    {"embs": self.model.transformer.wte.learned_embedding.data,
                     "model_version": model_version})
                last_reported_embedding = encoded_embedding
                last_reported_step = i
                yield {"event": "training_update",
                       "encoded_embedding": encoded_embedding,
                       "data": data_report()}
                losses.clear()

            if i % 50 == 0 and i != 0:
                self.logger.info("step: " + str(i) + ", step_iter: " + str(
                    1 / (float(total_iter_time) / i)) + ", contexts/s: " + str(
                    (1 / (float(total_iter_time) / i)) * bs))

        encoded_embedding = prefix.encode_prefix(
            {"embs": self.model.transformer.wte.learned_embedding.data,
             "model_version": model_version})

        for i in range(n_batches):
            del batches[n_batches - i - 1]
        del batches

        # unprep model
        self.model.transformer.wte = old_wte
        s_wte.wte = None
        del s_wte.learned_embedding
        del s_wte
        del optimizer
        self.model.eval()
        yield {"event": "complete",
               "ok": True,
               "encoded_embedding": encoded_embedding,
               "data": data_report(),
               "step": i,
               "loss": avg_loss / float(steps)}


class VQGANModel(Model):
    """Image-generation wrapper backed by ``models.VQGAN_CLIP``."""

    def __init__(self, config):
        """Initialise the base state and load the VQGAN/CLIP components.

        Raises:
            RuntimeError: ``"ModelNotAvailable"`` when ``models`` exposes no
                ``VQGAN_CLIP`` attribute.
        """
        super().__init__(config)
        if not hasattr(models, "VQGAN_CLIP"):
            raise RuntimeError("ModelNotAvailable")
        loaded = models.VQGAN_CLIP._init_model(self.config)
        self.model, self.clip, self.declipper = loaded

    def generate_image(self):
        """Return the result of ``models.VQGAN_CLIP._generate``.

        Raises:
            RuntimeError: ``"VqganModelNotAvailable"`` when ``models``
                exposes no ``VQGAN_CLIP`` attribute.
        """
        if not hasattr(models, "VQGAN_CLIP"):
            raise RuntimeError("VqganModelNotAvailable")
        return models.VQGAN_CLIP._generate(self.model, self.tokenizer)
