import gc
import glob
import hashlib
import random
from typing import Callable, Union, Generator, Sequence, List

import numpy as np
from dotmap import DotMap
from torch import nn
from typing_extensions import Self

import transformers.modeling_utils
from transformers import (
    AutoModelForCausalLM,
    GPTNeoForCausalLM,
    AutoConfig, Adafactor,
)
import os
import torch
from lm_node import utils, prefix
import time
from mmappickle import mmapdict
import lm_node.unitrim
import struct
import base64
import json
import websockets
import math
from icecream import ic
from transformers.modeling_utils import SplitCheckpoint


def gelu_new(x):
    """Tanh approximation of GELU, applied element-wise to tensor ``x``.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``.
    """
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
    return 0.5 * x * (1.0 + torch.tanh(inner))

def _init_weights(module):
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.bias is not None:
            module.bias.data.zero_()

    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=0.02)

    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

class HyperNetworkSingle(nn.Module):
    """Single-layer hypernetwork: a square linear map gated by x * sigmoid(x).

    ``config`` must provide ``hidden_dim`` (input and output width of the
    linear layer) and ``n_layer`` (used only to scale the initial weights).
    """

    def __init__(self, config):
        super().__init__()
        width = config["hidden_dim"]
        self.linear = nn.Linear(width, width, bias=True)
        # Stored on the module but not called by forward(), which gates with
        # x * sigmoid(x) instead.
        self.activation = gelu_new
        for submodule in self.modules():
            _init_weights(submodule)

        # Re-draw both weight and bias from a depth-scaled normal, overriding
        # the generic initialisation applied just above.
        scaled_std = 0.02 / math.sqrt(2 * config["n_layer"])
        for parameter in self.linear.parameters():
            parameter.data.normal_(mean=0.0, std=scaled_std)

    def forward(self, x):
        """Project ``x`` in half precision and gate it with its own sigmoid."""
        projected = self.linear(x.half())
        gated = projected.mul(torch.sigmoid(projected))
        return gated.half()

def no_init(
        loading_code: Callable[[], "transformers.modeling_utils.PreTrainedModel"],
) -> "transformers.modeling_utils.PreTrainedModel":
    """Run ``loading_code`` with parameter initialisation disabled.

    While ``loading_code`` runs, ``reset_parameters`` of ``torch.nn.Linear``,
    ``torch.nn.Embedding`` and ``torch.nn.LayerNorm`` is replaced by a no-op,
    so modules constructed inside it skip their random initialisation.

    Args:
        loading_code: zero-argument callable that builds/loads the model.

    Returns:
        Whatever ``loading_code`` returns.

    Raises:
        Anything raised by ``loading_code``. The original ``reset_parameters``
        methods are restored in every case.
    """
    def dummy(self):
        return

    modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
    original = {mod: mod.reset_parameters for mod in modules}
    for mod in modules:
        mod.reset_parameters = dummy

    try:
        return loading_code()
    finally:
        # Restore in `finally`: if loading fails, leaving the no-op in place
        # would make every module built afterwards in this process skip
        # initialisation.
        for mod, reset in original.items():
            mod.reset_parameters = reset


def _init_model(nodeconf) -> transformers.modeling_utils.PreTrainedModel:
    """Load the causal LM at ``nodeconf.model_path`` for fp16 GPU inference.

    The model is loaded with ``AutoModelForCausalLM.from_pretrained`` inside
    ``no_init``, which disables ``reset_parameters`` of Linear / Embedding /
    LayerNorm during construction. Files are cached under ``./cache``. The
    model is then cast to half precision, moved to CUDA and put in eval mode.
    When ``nodeconf.deepspeed_enabled`` is truthy, it is additionally passed
    through ``deepspeed.init_inference`` and the resulting module is returned.

    Args:
        nodeconf: node configuration providing ``model_path`` and
            ``deepspeed_enabled``.

    Returns:
        The ready model, or the DeepSpeed inference engine's module.
    """
    model: transformers.modeling_utils.PreTrainedModel = no_init(
        lambda: AutoModelForCausalLM.from_pretrained(
            nodeconf.model_path, cache_dir="./cache"))
    model = model.half().cuda().eval()

    if nodeconf.deepspeed_enabled:
        # Imported here so deepspeed is only needed when it is enabled.
        import deepspeed
        ds_engine = deepspeed.init_inference(model,
                                             mp_size=1,
                                             dtype=torch.half,
                                             replace_method='auto')
        return ds_engine.module

    return model


def _init_modules(model_device, model_dtype, nodeconf):
    default_modules = {}
    if nodeconf.prefix_path is None:
        return {}
    try:
        prefixes = os.listdir(nodeconf.prefix_path)
    except FileNotFoundError:
        nodeconf.logger.info(
            f"WARNING: No modules loaded from {nodeconf.prefix_path}!")
        return None
    for module in prefixes:
        with open(f"{nodeconf.prefix_path}/{module}", "r",
                  encoding='utf-8') as fh:
            if not module.endswith(".json"):
                continue
            key = module.replace(".json", "")
            contents = fh.read()
            default_modules[key] = torch.tensor(json.loads(contents)).to(
                model_device).to(model_dtype)
            module_size = default_modules[key].shape
            module_hash = hashlib.md5(contents.encode('utf-8')).hexdigest()
            nodeconf.logger.info(
                f"MODULE: loaded [{module_hash}] {key}, {module_size}")
    return default_modules

def _init_hypernet(model_device, model_dtype, nodeconf: DotMap):
    default_hypernets = {}
    if nodeconf.hyper_path is None:
        return {}
    try:
        hypers = os.listdir(nodeconf.hyper_path)
    except FileNotFoundError:
        nodeconf.logger.warn(
            f"WARNING: No hypernets loaded from {nodeconf.hyper_path}!")
        return None
    for hyper in hypers:
        hyper_path = f"{nodeconf.hyper_path}/{hyper}"
        if not hyper.endswith(".hyper"):
            continue
        key = hyper.replace(".hyper", "")
        hyper_conf = {"hidden_dim": nodeconf.hidden_dim,
                      "n_layer": nodeconf.n_layer}
        hypernetwork = HyperNetworkSingle(hyper_conf).to(model_device).to(
            model_dtype)
        hypernetwork.load_state_dict(torch.load(hyper_path))
        default_hypernets[key] = hypernetwork
        hyper_size = os.path.getsize(hyper_path)
        hyper_hash = hashlib.md5(
            open(hyper_path, 'rb').read()).hexdigest()
        nodeconf.logger.info(
            f"HYPERNET: loaded [{hyper_hash}] {key}, {hyper_size}")
    return default_hypernets



def _get_hidden_states(model, ids):
    """Run ``utils.text.build_engram`` over ``model.forward`` without autograd.

    Returns:
        ``(engram_as_python_list, [])``. The empty list fills the scores slot
        used by the other request handlers in this module.
    """
    with torch.no_grad():
        engram = utils.text.build_engram(model.forward, ids)
        engram_list = engram.tolist()
    return engram_list, []


def _next_token_probabilities(model, tokenizer, ids):
    """Query ``utils.text.get_next_words`` for ``ids`` without autograd.

    Returns:
        ``(next_words, [])``. The empty list fills the scores slot used by
        the other request handlers in this module.
    """
    with torch.no_grad():
        next_words = utils.text.get_next_words(model.forward, tokenizer, ids)
    return next_words, []


def _token_probabilities(model, ids, n=0, idx=0):
    """Return ``utils.text.get_token_probabilities(model, ids, n, idx)``.

    The call is made under ``torch.no_grad()``; ``n`` and ``idx`` are
    forwarded unchanged as keyword arguments.
    """
    with torch.no_grad():
        probabilities = utils.text.get_token_probabilities(
            model, ids, n=n, idx=idx)
    return probabilities


def params_report(self, ids, req_params):
    """Render a one-line, human-readable summary of a generation request.

    Args:
        self: unused. Kept so existing ``params_report(self, ...)`` call
            sites keep working.
        ids: prompt token ids; ``len(ids[0])`` is reported as the prompt
            length.
        req_params: generation parameters. ``max_length`` is required. The
            list-valued entries (``logit_bias``, ``logit_bias_exp``,
            ``bad_words_ids``, ``stop_sequences``, ``embs``) may be missing or
            ``None``; both are reported as empty. Each ``stop_sequences``
            item must support ``.cpu()``.

    Returns:
        The summary string.
    """
    def _list_param(key):
        # Missing key and explicit None both mean "nothing supplied". The
        # `is None` test avoids truth-testing values that may be tensors.
        value = req_params.get(key)
        return [] if value is None else value

    logit_bias = _list_param("logit_bias")
    biases = _list_param("logit_bias_exp")
    bans = _list_param("bad_words_ids")
    stops = _list_param("stop_sequences")
    embeddings = _list_param("embs")
    num_logprobs = req_params.get("num_logprobs", None)

    prompt_len = len(ids[0])
    max_length = req_params['max_length']
    # Total number of elements across all banned / biased sequences.
    flat_bans = sum(1 for seq in bans for _ in seq)
    flat_biases = sum(1 for seq in biases for _ in seq)
    cpu_stops = [item.cpu() for item in stops]

    return (
        f"{prompt_len} tokens, {max_length - prompt_len} desired, "
        f"{max_length} total -- "
        f"{len(bans)} bans (flat: {flat_bans}), "
        f"{len(biases)} phrase biases (flat: {flat_biases}), "
        f"{len(logit_bias)} logit biases, "
        f"{len(stops)} stop_sequences (seqs: {cpu_stops}), "
        f"{num_logprobs} num_logprobs, {len(embeddings)} embeddings -- "
        f"order={req_params.get('order', [])}, "
        f"temp={req_params.get('temperature')}, "
        f"top_p={req_params.get('top_p')}, "
        f"top_k={req_params.get('top_k')}, "
        f"top_a={req_params.get('top_a')}, "
        f"typical_p={req_params.get('typical_p')}, "
        f"tfs={req_params.get('tail_free_sampling')}"
    )


def _generate(self, model, tokenizer, ids, req_params, use_string):
    """Run a blocking generation and return the complete continuation.

    Args:
        self: the node. Provides ``logger``, ``getGPUram()`` and ``unitrim``.
        model: object whose ``generate(ids, **req_params)`` yields
            ``(tokens, is_finished, scores_before, scores_after)`` per step.
        tokenizer: used to decode the output when ``use_string`` is true.
            It is also stored into ``req_params["tokenizer"]`` (mutating the
            caller's dict) for ``model.generate``.
        ids: prompt token ids tensor; passed to ``generate`` as int64 on CUDA.
        req_params: generation parameters. ``request_id`` (optional) prefixes
            log lines; ``num_logprobs`` (optional) enables score collection.
        use_string: if true, return the decoded text. Otherwise return base64
            of the tokens packed as little-endian unsigned 16-bit integers.

    Returns:
        ``(output, scores)``. ``scores`` holds one
        ``{"chosen", "before", "after"}`` dict per step when ``num_logprobs``
        is set, and is empty otherwise.

    Raises:
        Any exception from generation or decoding propagates unchanged. In
        addition, ``struct.error`` is raised if a token does not fit in 16
        bits.
    """
    request_id_repr = req_params.get("request_id", "")
    if request_id_repr:
        request_id_repr = f"[{request_id_repr}] "
    with torch.no_grad():
        self.logger.info("{0}generate() params: {1}".format(
                         request_id_repr, params_report(self, ids, req_params)))
        self.logger.info("{0}generate() starting: {1}".format(
                         request_id_repr, self.getGPUram()))
        try:
            req_params["tokenizer"] = tokenizer
            generated_tokens = []
            scores = []
            for tokens, is_finished, scores_before, scores_after in \
                    model.generate(ids.long().cuda(), **req_params):
                # Only element 0 of each step's tokens is collected.
                generated_tokens.append(int(tokens[0]))
                if req_params.get("num_logprobs", None) is not None:
                    tok_probs, before_probs, after_probs = \
                        utils.text.process_scores(
                            tokens[0],
                            scores_before,
                            scores_after,
                            num_logprobs=req_params['num_logprobs'],
                            filter_inf=True)
                    scores.append({"chosen": tok_probs,
                                   "before": before_probs,
                                   "after": after_probs})
                # Drop per-step references eagerly so tensors can be freed.
                del is_finished
                del tokens
                del scores_before
                del scores_after

            generated_tokens = self.unitrim.trim(generated_tokens)
            if use_string:
                output = tokenizer.decode(generated_tokens)
            else:
                # A single join is linear; repeated bytes concatenation copied
                # the whole buffer on every token.
                packed_tokens = b''.join(
                    struct.pack("<H", token) for token in generated_tokens)
                output = base64.b64encode(packed_tokens).decode("utf-8")
            return output, scores
        finally:
            self.logger.info(
                "{0}generate() done: {1}".format(request_id_repr,
                                                 self.getGPUram()))
            utils.memory.cleanup_variables(locals(),
                                           "ids",
                                           "tokens",
                                           "scores_before",
                                           "scores_after")
            torch.cuda.synchronize()
            self.logger.info(
                "{0}generate() cleanup: {1}".format(request_id_repr,
                                                    self.getGPUram()))


async def _generate_stream(self, model, tokenizer, ids, req_params, use_string,
                           generation_id):
    """Generate tokens and push each one to the node-pipe websocket.

    For every step yielded by ``model.generate``, one JSON message per batch
    row is sent:
    ``{"event": "token", "data": {"uuid", "token", "ptr", "final", "logprobs"}}``.
    ``ptr`` is the step index, shared by all rows of that step.

    Args:
        self: the node. Provides ``logger``, ``getGPUram()`` and ``packB64``.
        model: object whose ``generate(ids, **req_params)`` yields
            ``(tokens, is_finished, scores_before, scores_after)`` per step.
        tokenizer: decodes tokens when ``use_string`` is true. It is also
            stored into ``req_params["tokenizer"]``.
        ids: prompt token ids tensor; passed to ``generate`` as int64 on CUDA.
        req_params: generation parameters. This function mutates it, setting
            ``use_cache`` to True and adding ``tokenizer``.
        use_string: send decoded text instead of ``self.packB64`` output.
        generation_id: echoed as ``uuid`` in every message.

    Raises:
        Any generation/transport error, after logging the step it failed at.
    """
    with torch.no_grad():
        # DEV == "True" targets staging; anything else targets production.
        # An unset DEV is treated as production rather than raising KeyError.
        if os.environ.get('DEV') == "True":
            uri = "wss://staging.novelai.net/ai/internal/node-pipe"
        else:
            uri = "wss://api.novelai.net/ai/internal/node-pipe"

        params_str = params_report(self, ids, req_params)
        gpu_str = self.getGPUram()
        self.logger.info(f"generate_stream() params: {params_str}")
        self.logger.info(f"generate_stream() starting: {gpu_str}")

        async with websockets.connect(uri, ping_interval=None) as websocket:
            req_params["use_cache"] = True
            req_params["tokenizer"] = tokenizer
            ptr = 0
            try:
                for tokens, is_finished, scores_before, scores_after in \
                        model.generate(ids.long().cuda(), **req_params):
                    # Per-step work happens once, outside the batch loop. The
                    # per-step tensors must stay alive until every row has
                    # been serialised; deleting them inside the row loop made
                    # any batch larger than 1 fail on its second row.
                    logprobs = None
                    if req_params.get("num_logprobs", None) is not None:
                        tok_probs, before_probs, after_probs = \
                            utils.text.process_scores(
                                tokens,
                                scores_before,
                                scores_after,
                                req_params['num_logprobs'],
                                filter_inf=True)
                        logprobs = {'chosen': tok_probs,
                                    'before': before_probs,
                                    'after': after_probs}
                    tokens = tokens.cpu().detach().numpy()
                    is_finished = is_finished is not False

                    messages = []
                    for y in range(tokens.shape[0]):  # batching support
                        if not use_string:
                            output = self.packB64(tokens[y])
                        else:
                            output = tokenizer.decode([tokens[y]])
                        msg = {"event": "token", "data":
                               {"uuid": generation_id,
                                "token": output,
                                "ptr": ptr,
                                "final": is_finished,
                                "logprobs": logprobs}}
                        messages.append(json.dumps(msg))
                        del msg
                        del output

                    # Release this step's tensors before sending.
                    del tokens
                    del scores_before
                    del scores_after
                    torch.cuda.synchronize()
                    for message in messages:
                        await websocket.send(message)
                    ptr += 1
            except Exception:
                self.logger.error("failure at token ct {0}".format(ptr))
                raise
            finally:
                self.logger.info(
                    "generate_stream() done: {0}".format(self.getGPUram()))
                utils.memory.cleanup_variables(locals(),
                                               "ids",
                                               "msg",
                                               "output",
                                               "tokens",
                                               "scores_before",
                                               "scores_after")
                gc.collect()
                torch.cuda.synchronize()
                self.logger.info(
                    "generate_stream() cleanup: {0}".format(self.getGPUram()))


class SoftEmbedding(nn.Module):
    """Embedding wrapper that replaces the first ``n_tokens`` positions with a
    learned soft prompt.

    ``forward`` embeds ``tokens[:, n_tokens:]`` with the wrapped ``wte``. It
    then concatenates the learned embedding (repeated over the batch) in
    front of that result, along dimension 1.
    """

    def __init__(self,
                 wte: "Union[nn.Embedding, Self]",
                 n_tokens: int = 10,
                 random_range: float = 0.5,
                 initialize_from_vocab: bool = True,
                 init_tokens=None) -> None:
        """Wrap ``wte`` and create the learned soft-prompt embedding.

        Args:
            wte: original transformer word embedding (or another
                SoftEmbedding).
            n_tokens: number of leading positions of the input that the
                learned embedding replaces. Defaults to 10.
            random_range: half-width of the uniform range used when not
                initialising from the vocabulary. Defaults to 0.5.
            initialize_from_vocab: copy rows of ``wte.weight`` instead of
                drawing random values. Defaults to True.
            init_tokens: optional indices of ``wte.weight`` rows to copy when
                ``initialize_from_vocab`` is true. NOTE(review): the number
                of rows then follows ``len(init_tokens)``, not ``n_tokens``;
                ``forward`` still drops ``n_tokens`` input positions.
        """
        super(SoftEmbedding, self).__init__()
        self.wte = wte
        self.n_tokens = n_tokens
        self.learned_embedding = nn.parameter.Parameter(
            self.initialize_embedding(wte,
                                      n_tokens,
                                      random_range,
                                      initialize_from_vocab,
                                      init_tokens))

    def initialize_embedding(self,
                             wte: nn.Embedding,
                             n_tokens: int = 10,
                             random_range: float = 0.5,
                             initialize_from_vocab: bool = True,
                             init_tokens=None):
        """Build the initial learned embedding.

        Args:
            same as ``__init__``.

        Returns:
            A detached tensor with one row per prompt position and
            ``wte.weight.size(1)`` columns. It is copied from ``wte.weight``
            (``init_tokens`` rows or the first ``n_tokens`` rows), or drawn
            from a float32 uniform distribution in
            ``[-random_range, random_range]``.
        """
        if initialize_from_vocab and init_tokens is not None:
            return self.wte.weight[init_tokens].clone().detach()
        if initialize_from_vocab:
            return self.wte.weight[:n_tokens].clone().detach()
        # Shape is (n_tokens, embed_dim), matching the vocab branches and the
        # dim-1 concatenation in forward(). The previous (embed_dim, n_tokens)
        # shape was transposed.
        return torch.FloatTensor(n_tokens, wte.weight.size(1)).uniform_(
            -random_range, random_range)

    def forward(self, tokens):
        """Embed ``tokens`` with the learned prompt prepended.

        Args:
            tokens: integer token ids indexed as ``tokens[:, position]``. The
                first ``n_tokens`` positions are placeholders whose ids are
                ignored.

        Returns:
            ``cat([learned_embedding repeated per batch row,
            wte(tokens[:, n_tokens:])], dim=1)``.
        """
        input_embedding = self.wte(tokens[:, self.n_tokens:])
        learned_embedding = self.learned_embedding.repeat(
            input_embedding.size(0), 1, 1)
        return torch.cat([learned_embedding, input_embedding], 1)


