from dotmap import DotMap
import base64
import torch
import os
import functools
import lm_node.prefix
import logging
from icecream import ic

# Module-level logging: records at INFO and above go to a stream handler,
# prefixed with timestamp, level, source file name and process id.
fh_formatter = logging.Formatter(
    "%(asctime)s %(levelname)s %(filename)s(%(process)d) - %(message)s"
)
fh = logging.StreamHandler()
fh.setFormatter(fh_formatter)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(fh)

# Fallback generation parameters. These are not enforced: a value is only
# filled in when the request leaves that parameter unset (or None).
default_values = DotMap({
    "top_p": 1.0,
    "top_k": 0,
    "temperature": 1.0,
    "repetition_penalty": 1.0,
    "use_cache": True,
    "do_sample": True,
    "pad_token_id": 50256,
    "max_length": 20,
    "min_length": 1,
    "prefix": "vanilla",
    "generate_until_sentence": False,
})

# Request parameters that are let through; every parameter that is not
# listed here is dropped.
whitelist = [
    "top_p",
    "top_k",
    "top_a",
    "typical_p",
    "temperature",
    "min_length",
    "max_length",
    "do_sample",
    "repetition_penalty",
    "eos_token_id",
    "pad_token_id",
    "use_cache",
    "tail_free_sampling",
    "tfs",
    "repetition_penalty_slope",
    "repetition_penalty_range",
    "bad_words_ids",
    "repetition_penalty_frequency",
    "repetition_penalty_presence",
    "prefix",
    "hypernetwork",
    "logit_bias",
    "order",
    "request_id",
    "repetition_penalty_whitelist",
    "output_nonzero_probs",
    "generate_until_sentence",
    "embs",
    "logit_bias_exp",
    "num_logprobs",
    "stop_sequences",
]


def tokenize(req_dict, use_string, tokenizer):
    """Turn ``req_dict.input`` into a tensor of token ids.

    Three input forms are handled, in this order:

    * a ``torch.Tensor`` is passed through untouched;
    * when ``use_string`` is truthy the input is run through ``tokenizer``
      and its ``input_ids`` are wrapped in a one-row tensor;
    * otherwise the input is decoded with ``base64_to_list``.

    Returns ``(ids,)`` on success. When ``base64_to_list`` reports a
    failure, its ``(False, message)`` tuple is returned unchanged.
    """
    payload = req_dict.input

    if isinstance(payload, torch.Tensor):
        return (payload,)

    if use_string:
        return (torch.tensor([tokenizer(payload).input_ids]),)

    decoded = base64_to_list(tokenizer, payload)
    if not decoded[0]:
        # Hand the failure tuple straight back to the caller.
        return decoded
    return (torch.tensor([decoded[1]]),)


def list_to_base64(input):
    """Encode a sequence of token ids as base64.

    Each id is written as two bytes, low byte first, and the concatenated
    bytes are base64-encoded.

    Returns the encoded ``bytes`` on success. If the ids cannot be packed
    (for example ``input`` is not iterable, or an id does not fit in two
    bytes) the error is not raised; ``(False, message)`` is returned
    instead.
    """
    try:
        # One join instead of repeated ``bytes +=`` keeps this linear in the
        # number of ids.
        packed = b"".join(bytes([tok % 256, tok // 256]) for tok in input)
    except Exception as exc:
        return (False, str(exc))

    return base64.b64encode(packed)


def base64_to_list(tokenizer, input):
    """Decode a base64 payload into a list of token ids.

    The decoded bytes are read as consecutive two-byte ids, low byte first.
    A trailing unpaired byte is ignored.

    ``tokenizer`` is only used for ``len(tokenizer)``, which bounds the
    accepted ids: every id must be strictly smaller than it.

    Returns ``(True, ids)`` on success, ``(False, message)`` when the
    payload cannot be base64-decoded, and ``(False, "Wrong tokenization.")``
    when an id is out of range.
    """
    try:
        raw = base64.b64decode(input)
    except Exception as exc:
        return (False, str(exc))

    tokens = [raw[i] | (raw[i + 1] << 8) for i in range(0, len(raw) - 1, 2)]

    # Valid ids are 0 .. len(tokenizer) - 1, so an id equal to
    # len(tokenizer) is already out of range.
    if tokens and max(tokens) >= len(tokenizer):
        return (False, "Wrong tokenization.")

    return (True, tokens)


def load_user_module(req_dict, default_modules, default_hypernets,
                     config, model_obj):
    """Resolve the ``prefix`` request parameter into embeddings/hypernetwork.

    ``req_dict.parameters.hypernetwork`` is always reset to ``None`` first.
    The ``prefix`` parameter, when present, is split on ``:``:

    * three parts (``<model>:<id>:<key>``) select a user module stored at
      ``config.user_module_path/<id>``; it is loaded with ``cache_module``
      and stored as ``parameters.embs = [(0, emb)]``;
    * a single part is looked up in ``default_hypernets`` first (setting
      ``parameters.hypernetwork``), then in ``default_modules`` (setting
      ``parameters.embs``). A name found in neither is an error unless it
      is ``"vanilla"``;
    * any other number of parts is ignored.

    Side effects: sets ``req_dict.n_embs`` (first dimension of the selected
    embedding, or 0) and ``req_dict.n_hypers`` (1 if a hypernetwork was
    selected, else 0).

    Returns ``(True,)`` on success or ``(False, message)`` on failure.
    """
    req_params = req_dict.parameters
    req_params.hypernetwork = None

    if "prefix" in req_params:
        prefixstr = req_params.prefix
        prefixlist = prefixstr.split(':')

        if len(prefixlist) == 3:
            # The first part (model name) is not used here.
            _, prefix_id, prefix_key = prefixlist
            # NOTE(review): prefix_id comes from the request and is joined
            # into a filesystem path unchecked - confirm that callers
            # restrict it (e.g. no path separators or "..").
            if not os.path.isfile(config.user_module_path + "/" + prefix_id):
                return (False, f"ERROR: User module {prefix_id} doesn't exist.")

            emb = cache_module(config.user_module_path, prefix_id,
                               prefix_key)
            model_obj.logger.info(
                "Loaded module: {0}/{1} - {2}".format(
                    config.user_module_path, prefix_id,
                    model_obj.getGPUram()))
            req_params.embs = [(0, emb)]

        elif len(prefixlist) == 1:
            if (default_hypernets is not None and
                    prefixstr in default_hypernets):
                req_params.hypernetwork = default_hypernets[prefixstr]
            elif default_modules is not None:
                if prefixstr in default_modules:
                    req_params.embs = [(0, default_modules[prefixstr])]
                elif prefixstr != "vanilla":
                    return (False, f"ERROR: {prefixstr} doesn't exist.")

    n_embs = req_params.embs[0][1].shape[0] if "embs" in req_params else 0
    req_dict.n_embs = n_embs
    req_dict.n_hypers = 1 if req_params.hypernetwork is not None else 0
    return (True,)


def cache_module(path, idx, key):
    """Load the user module stored at ``{path}/{idx}``.

    ``key`` may be a hex string (converted with ``bytes.fromhex``) or is
    otherwise passed on as given. The file contents are decrypted with
    ``lm_node.prefix.self_decrypt_prefix``, decoded with
    ``lm_node.prefix.decode_prefix``, and the decoded module is returned
    after calling ``.cuda()`` on it. The version returned by the decoder is
    not used.

    NOTE(review): despite the name, nothing is cached here - every call
    reads and decodes the file again.
    """
    if isinstance(key, str):
        key = bytes.fromhex(key)

    # Only the read needs the file; close it before decrypting/decoding.
    with open(f"{path}/{idx}", "rb") as module_file:
        encrypted = module_file.read()

    decoded = lm_node.prefix.self_decrypt_prefix(encrypted, key)
    module, _version = lm_node.prefix.decode_prefix(decoded)
    return module.cuda()


def resolve_order(req_dict):
    """Normalise the ``order`` parameter and disable unused logit stages.

    Logit ids map to parameters as follows (with the value that counts as
    "disabled"): 0 temperature (1.0), 1 top_k (0), 2 top_p (1.0),
    3 tfs (1.0), 4 top_a (1.0), 5 typical_p (1.0).

    A parameter is considered enabled when it is present in ``req_dict``,
    is not ``None`` and differs from its disabled value.

    * If ``order`` is missing or empty, it is built from the enabled
      parameters in ascending logit-id order.
    * Otherwise the given ids are kept in the given order, minus those
      whose parameter is not enabled.
    * Every id that did not make it into the result has its parameter set
      to ``None`` and is appended at the end.

    ``req_dict`` is modified in place and also returned.

    Raises:
        Exception: ``InvalidLogitId`` for an id outside 0-5, or
            ``RepeatedLogitId`` when an already accepted id appears again.
    """
    default_order = {
        0: ["temperature", 1.0],
        1: ["top_k", 0],
        2: ["top_p", 1.0],
        3: ["tfs", 1.0],
        4: ["top_a", 1.0],
        5: ["typical_p", 1.0]
    }

    def is_enabled(param, disabled_value):
        # Present, not None, and not equal to the "off" value.
        if param not in req_dict:
            return False
        value = req_dict[param]
        return value is not None and value != disabled_value

    # No order value was provided, so construct one from the parameters
    # that are set, following the default order.
    order = req_dict.get('order', None)
    if order is None or len(order) == 0:
        order = [logit_id
                 for logit_id, (param, disabled_value) in default_order.items()
                 if is_enabled(param, disabled_value)]

    # Keep the requested ids whose parameter is enabled; reject ids that
    # are unknown or that were already accepted.
    new_order = []
    for item in order:
        param, disabled_value = default_order.get(item, [None, None])
        if param is None:
            raise Exception(
                f"InvalidLogitId: {item} is not a valid logit id")
        if item in new_order:
            raise Exception(
                f"RepeatedLogitId: '{param}' was repeated in 'order'")
        if is_enabled(param, disabled_value):
            new_order.append(item)

    # Everything that is not in the order sequence gets disabled and is
    # appended, in ascending id order. The list is computed up front so the
    # appends below do not affect it.
    remaining = [logit_id for logit_id in default_order
                 if logit_id not in new_order]
    for to_disable in remaining:
        param, _ = default_order[to_disable]
        req_dict[param] = None
        new_order.append(to_disable)

    req_dict['order'] = new_order
    return req_dict


# TODO: Enforce and mark types.
def sanitize(tokenizer, req_dict, config):
    """Validate a request and normalise its generation parameters.

    Works on a copy of ``req_dict.parameters``: merges whitelisted
    top-level request keys into it, fills in ``default_values`` for
    parameters that are unset or ``None``, drops everything that is not in
    ``whitelist``, clamps the requested lengths, and resolves the sampling
    order with ``resolve_order``.

    Side effects on ``req_dict``:

    * ``req_dict.n_embs`` is set from the first entry of ``embs`` (0 when
      there is none);
    * when ``n_embs > 0``, that many ``<|endoftext|>`` token ids are
      prepended to every row of ``req_dict.input``;
    * every sanitised parameter is also copied onto ``req_dict`` itself.

    Args:
        tokenizer: used to encode ``<|endoftext|>``; also stored in the
            returned parameters under ``"tokenizer"``.
        req_dict: the request; ``input`` is expected to be a 2-D tensor at
            this point (``shape[1]`` is read).
        config: mapping with ``model_name`` and, optionally,
            ``model_max_tokens`` (2028 when absent).

    Returns:
        ``(req_params, warning)`` on success (``warning`` is always
        ``None`` here), or ``(False, "ERROR: Empty request.")`` for an
        empty request.

    Raises:
        Exception: ``ExceedMaxBiases`` (more than 1024 ``logit_bias_exp``
            entries), ``ExceedMaxBans`` (more than 2048 ``bad_words_ids``
            entries), ``ModelMaxLength`` (requested context exceeds
            ``model_max_tokens``), plus anything ``resolve_order`` raises.
    """
    warning = None

    # Reject empty requests before anything else: copying the parameters
    # below would fail on a request whose parameters are None.
    if req_dict.input is None or req_dict.parameters is None or \
            req_dict.input.shape[1] == 0:
        return (False, "ERROR: Empty request.")

    ids = req_dict.input
    req_params = req_dict.parameters.copy()

    model_name = config['model_name']
    model_max_tokens = config.get('model_max_tokens', 2028)

    # "tail_free_sampling" is accepted as an alias for "tfs".
    if "tail_free_sampling" in req_params:
        req_params["tfs"] = req_params["tail_free_sampling"]

    # Read the tier now; it is not whitelisted and is removed below.
    subscription_tier = req_params.subscription_tier
    req_params = req_params.toDict()

    # Whitelisted keys given at the top level of the request override the
    # ones in "parameters".
    for key in req_dict:
        if key in whitelist:
            req_params[key] = req_dict[key]

    for x in default_values:
        if x not in req_params or req_params[x] is None:
            req_params[x] = default_values[x]

    for key in list(req_params):
        if key not in whitelist:
            del req_params[key]

    req_params = DotMap(req_params)
    max_length = req_params.max_length
    min_length = req_params.min_length

    if type(req_params.repetition_penalty) == int:
        req_params.repetition_penalty = float(req_params.repetition_penalty)

    # Upper bound on the number of requested tokens, by subscription tier.
    if subscription_tier == 3:
        max_length_factor = 150
    elif subscription_tier in [1, 2]:
        max_length_factor = 100
    else:
        max_length_factor = model_max_tokens

    logit_biases = req_params.get('logit_bias_exp', [])
    if logit_biases is None:
        logit_biases = []
    num_logit_biases = len(logit_biases)
    if num_logit_biases > 1024:
        raise Exception(
            f"ExceedMaxBiases: {num_logit_biases} biases > 1024 maximum")

    bans = req_params.get('bad_words_ids', [])
    if bans is None:
        bans = []
    num_bans = len(bans)
    if num_bans > 2048:
        raise Exception(
            f"ExceedMaxBans: {num_bans} bans > 2048 maximum")

    def tensorize_stop(s):
        # Anything that is not already a tensor is wrapped in a list first.
        if type(s) != torch.Tensor:
            s = [s]
        return torch.LongTensor(s).cuda()

    if "stop_sequences" in req_params:
        req_params.stop_sequences = [tensorize_stop(s) for s in
                                     req_params.stop_sequences]

    # The repetition penalty range is never allowed below 3.
    if "repetition_penalty_range" in req_params:
        if req_params.repetition_penalty_range <= 3:
            req_params.repetition_penalty_range = 3

    if "tail_free_sampling" in req_dict:
        req_params.tfs = req_dict["tail_free_sampling"]

    # A penalty of exactly 1 is stored as None.
    if "repetition_penalty" in req_params and req_params[
        "repetition_penalty"] == 1:
        req_params.repetition_penalty = None

    # Clamp both lengths into [0, max_length_factor] and keep min <= max.
    max_length, min_length = list(map(lambda x: min(max_length_factor,
                                                    max(0, x)),
                                      [max_length, min_length]))
    min_length = min(min_length, max_length)
    req_params.temperature = float(req_params.temperature)
    n_embs = req_params.embs[0][1].shape[0] if "embs" in req_params else 0
    req_dict.n_embs = n_embs

    # The stored lengths include the length of the first input row and the
    # embedding slots, capped at the model maximum.
    req_params.max_length = min(max_length + len(ids[0]) + req_dict.n_embs,
                                model_max_tokens)
    req_params.min_length = min(min_length + len(ids[0]) + req_dict.n_embs,
                                model_max_tokens)

    # Prepend one <|endoftext|> id per embedding slot to every input row.
    if req_dict.n_embs > 0:
        eot = tokenizer.encode("<|endoftext|>")[0]
        req_dict.input = torch.cat((torch.full((req_dict.input.shape[0],
                                                req_dict.n_embs),
                                               eot).long(),
                                    req_dict.input), dim=1)

    # ``ids`` still refers to the input as it was before the prepend above.
    curr_used = max_length + len(ids[0]) + req_dict.n_embs
    if curr_used > model_max_tokens:
        raise Exception(
            f"ModelMaxLength: model '{model_name}' has {model_max_tokens} " +
            f"tokens maximum, {curr_used} tokens context requested")

    req_params = resolve_order(req_params)
    req_params.use_cache = True

    req_params["tokenizer"] = tokenizer

    for key in list(req_params):
        req_dict[key] = req_params[key]

    return (req_params, warning)


def process_payload(req_dict, config,
                    use_string, model):
    """Run a request through tokenisation, module loading and sanitising.

    Steps, each of which may abort the pipeline:

    1. ``tokenize`` - its result replaces ``req_dict.input``;
    2. ``load_user_module`` - only when ``model.modules`` or
       ``model.hypernets`` is truthy;
    3. ``sanitize`` - its result replaces ``req_dict.parameters``.

    A step signals failure by returning ``False`` as its first element; the
    accompanying message is passed on.

    Returns:
        ``(req_dict, warning)`` on success, where ``warning`` is whatever
        ``sanitize`` returned as its second element, or
        ``(False, message)`` on failure.
    """
    output = tokenize(req_dict, use_string, model.tokenizer)
    if output[0] is False:
        return (False, output[1])
    req_dict.input = output[0]

    if model.modules or model.hypernets:
        output = load_user_module(req_dict, model.modules, model.hypernets,
                                  config, model)
        if output[0] is False:
            return (False, output[1])

    output = sanitize(model.tokenizer, req_dict, config)
    if output[0] is False:
        return (False, output[1])
    req_dict.parameters = output[0]

    # Pass the sanitiser's warning on instead of dropping it.
    return (req_dict, output[1])
