# requirements: the gpt-neo-localattention3-rp-b branch of transformers and pynacl==1.4.0
# map files are a flat array of uint16 tokens with an EOT token between source files; create one with:
# with open("literature.map", "wb") as fh:
#     fh.write(np.array(tokens, dtype=np.uint16).tobytes())
# optional: run with -r and prepend a single fake token whose value is the number of training steps
import random
import json
import os
import numpy as np
import traceback
import torch
from transformers import GPTNeoForCausalLM, AutoConfig, GPT2TokenizerFast
from transformers.optimization import Adafactor
import transformers
import prefix # https://gist.github.com/finetuneanon/99eb04b35f4b4776399d3a4a1fbfcef2 (requires the usual and pynacl==1.4.0)
import torch
import torch.nn as nn
import os
try:
    from collections.abc import MutableMapping
except ImportError:
    from collections import MutableMapping
from pathlib import Path
from tqdm import tqdm

class SoftEmbedding(nn.Module):
    """Word-embedding wrapper whose first ``n_tokens`` positions use a learned prefix.

    The ids at the first ``n_tokens`` positions of the input are ignored: their
    embeddings are replaced by the trainable ``learned_embedding``. The remaining
    ids are embedded by the wrapped ``wte`` unchanged.
    """

    def __init__(self,
                 wte: nn.Embedding,
                 n_tokens: int = 10,
                 random_range: float = 0.5,
                 initialize_from_vocab: bool = True, init_tokens=None):
        """Create the wrapper and its learned prefix.

        Args:
            wte (nn.Embedding): original transformer word embedding.
            n_tokens (int, optional): number of prefix positions. Defaults to 10.
            random_range (float, optional): half-width of the uniform init range,
                used only when ``initialize_from_vocab`` is False. Defaults to 0.5.
            initialize_from_vocab (bool, optional): copy the initial prefix from rows
                of ``wte`` instead of sampling it. Defaults to True.
            init_tokens (optional): token ids (list, tensor or numpy array) whose
                embedding rows seed the prefix when ``initialize_from_vocab`` is True.
                When None, ids ``0 .. n_tokens-1`` are used. Expected to hold exactly
                ``n_tokens`` ids (not checked here).
        """
        super().__init__()
        self.wte = wte
        self.n_tokens = n_tokens
        self.learned_embedding = nn.parameter.Parameter(
            self.initialize_embedding(wte, n_tokens, random_range, initialize_from_vocab, init_tokens))

    def initialize_embedding(self,
                             wte: nn.Embedding,
                             n_tokens: int = 10,
                             random_range: float = 0.5,
                             initialize_from_vocab: bool = True, init_tokens=None):
        """Build the initial prefix tensor.

        Args:
            same as ``__init__``.
        Returns:
            torch.Tensor: detached tensor of shape ``(n_prefix, embedding_dim)``.
            ``n_prefix`` is ``len(init_tokens)`` when given, else ``n_tokens``.
            It has the dtype and device of ``wte.weight``.
        """
        weight = wte.weight
        if initialize_from_vocab:
            if init_tokens is not None:
                # Convert explicitly, so lists, tensors and numpy arrays
                # (including negatively-strided views such as ``x[::-1]``) all work.
                index = torch.tensor([int(t) for t in init_tokens], dtype=torch.long, device=weight.device)
                return weight[index].clone().detach()
            return weight[:n_tokens].clone().detach()
        # Rows are prefix positions: forward() concatenates this with the
        # (batch, seq, embedding_dim) input embeddings along dim 1.
        init = torch.empty(n_tokens, weight.size(1), dtype=torch.float32).uniform_(-random_range, random_range)
        # Match the embedding's dtype/device so torch.cat in forward() sees one dtype.
        return init.to(dtype=weight.dtype, device=weight.device)

    def forward(self, tokens):
        """Embed ``tokens`` with the learned prefix substituted for the first ``n_tokens`` positions.

        Args:
            tokens (torch.LongTensor): ``(batch, seq)`` ids; the first ``n_tokens``
                columns are placeholders whose values are ignored.
        Returns:
            torch.Tensor: ``(batch, seq, embedding_dim)`` embeddings when the prefix
            has ``n_tokens`` rows.
        """
        input_embedding = self.wte(tokens[:, self.n_tokens:])
        learned_embedding = self.learned_embedding.repeat(input_embedding.size(0), 1, 1)
        return torch.cat([learned_embedding, input_embedding], 1)

def no_init(loading_code):
    """Call ``loading_code()`` with layer parameter initialisation disabled.

    Temporarily replaces ``reset_parameters`` on ``torch.nn.Linear``,
    ``torch.nn.Embedding`` and ``torch.nn.LayerNorm`` with a no-op. Those
    modules call it from their constructors, so building a large model skips
    the (useless, slow) random init when weights are loaded afterwards. The
    patch is applied to the classes, i.e. process-wide.

    Args:
        loading_code: zero-argument callable, e.g. a model loader.
    Returns:
        Whatever ``loading_code()`` returns.
    Raises:
        Anything ``loading_code`` raises; the original methods are restored first.
    """
    def dummy(self):
        return

    modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
    original = {mod: mod.reset_parameters for mod in modules}
    for mod in modules:
        mod.reset_parameters = dummy
    try:
        return loading_code()
    finally:
        # Always undo the class-level patch, otherwise a failed load would leave
        # every later Linear/Embedding/LayerNorm uninitialised.
        for mod in modules:
            mod.reset_parameters = original[mod]

class Checkpoint(MutableMapping):
    """Lazy, read-only state-dict view over a split checkpoint directory.

    ``<chkpt_dir>/m.pt`` is loaded eagerly and maps each parameter name to a
    file path. Only the final path component is used, resolved against
    ``chkpt_dir``. The tensor behind a key is loaded from disk (onto
    ``device``) each time it is accessed. Writes and deletions are silently
    ignored, so the mapping always reflects the files on disk.
    """

    def __init__(self, chkpt_dir, device="cpu"):
        self.device = device
        self.chkpt_dir = Path(chkpt_dir)
        self.checkpoint = torch.load(str(self.chkpt_dir / "m.pt"))

    def __len__(self):
        return len(self.checkpoint)

    def __getitem__(self, key):
        path = self.chkpt_dir / Path(self.checkpoint[key]).name
        return torch.load(str(path), map_location=self.device)

    def __contains__(self, key):
        # Answer from the index; the Mapping default would load the tensor.
        return key in self.checkpoint

    def __setitem__(self, key, value):
        # Read-only by design: writes are ignored.
        return

    def __delitem__(self, key):
        # Read-only by design: deletions are ignored.
        return

    def keys(self):
        return self.checkpoint.keys()

    def __iter__(self):
        # Mapping protocol: iterate over keys (items()/values() build on this).
        return iter(self.checkpoint)

    def __copy__(self):
        return Checkpoint(self.chkpt_dir, device=self.device)

    def copy(self):
        return Checkpoint(self.chkpt_dir, device=self.device)

model = None  # GPTNeoForCausalLM instance, populated by init()
tokenizer = None  # GPT2TokenizerFast instance, populated by init()


def init(model_folder):
    """Load the model and tokenizer into the module-level ``model``/``tokenizer``.

    The config starts from EleutherAI/gpt-neo-2.7B and is overridden to a
    28-layer, 16-head, 4096-wide, 50400-token architecture with global
    attention in every layer (presumably a GPT-J-6B style checkpoint; confirm
    against the weights). Weights come from the split checkpoint in
    ``model_folder`` through ``Checkpoint``. ``no_init`` skips random layer
    initialisation, since the weights are expected to come from the checkpoint.

    Args:
        model_folder: directory containing ``m.pt`` and the per-tensor files.
    """
    global model, tokenizer
    layers = 28

    config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-2.7B")
    # attention layout: every layer is global
    config.attention_layers = ["global"] * layers
    config.attention_types = [["global"], layers]
    config.num_layers = layers
    config.num_heads = 16
    config.hidden_size = 256 * config.num_heads
    config.vocab_size = 50400
    # The options below are read by the patched transformers branch named in
    # the file header; their exact semantics are not visible in this file.
    config.rotary = True
    config.rotary_dim = 64
    config.jax = True
    config.model_dtype = "fp16"  # alternative: "bf16"
    config.model_device = "cuda"
    config.full_bf16 = False  # alternative: True

    def load_model():
        return GPTNeoForCausalLM.from_pretrained(
            pretrained_model_name_or_path=None,
            config=config,
            state_dict=Checkpoint(model_folder),
        )

    model = no_init(load_model)
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# 50256 is the GPT-2 tokenizer's end-of-text id. It fills the soft-prompt
# positions of every input; SoftEmbedding replaces those positions' embeddings.
_PREFIX_PLACEHOLDER_ID = 50256


def _make_batches(tokens, bs, tokens_per_sample):
    """Cut a flat token array into shuffled ``(rows, seq)`` LongTensor batches.

    A trailing remainder shorter than ``tokens_per_sample`` becomes its own
    one-row batch. The full-length samples are shuffled and grouped ``bs`` at
    a time; the last group may hold fewer than ``bs`` rows.
    """
    def as_long(chunk):
        return torch.tensor(chunk.astype(np.int32)).long()

    batches = []
    tiny_size = tokens.shape[0] % tokens_per_sample
    if tiny_size > 0:
        batches.append(as_long(tokens[-tiny_size:]).unsqueeze(0))
        tokens = tokens[:-tiny_size]

    n_samples = tokens.shape[0] // tokens_per_sample
    if n_samples > 0:
        order = list(range(n_samples))
        random.shuffle(order)
        for start in range(0, n_samples, bs):
            rows = [
                as_long(tokens[k * tokens_per_sample:(k + 1) * tokens_per_sample])
                for k in order[start:start + bs]
            ]
            batches.append(torch.stack(rows))
    return batches


def _run_prefix_training(batches, steps, prefix_len, optimizer):
    """Run ``steps`` optimiser steps over ``batches`` and return the summed loss.

    Batches are visited in a freshly shuffled order each time all of them
    have been used. The global ``model`` must already carry the SoftEmbedding.
    """
    indexes = []
    loss_sum = 0.0
    for i in tqdm(range(steps)):
        if not indexes:
            indexes = list(range(len(batches)))
            random.shuffle(indexes)
        batch = batches[indexes.pop()]
        optimizer.zero_grad()

        rows = batch.shape[0]
        # Placeholder ids in front of the text; their embeddings are replaced
        # by the learned prefix.
        input_ids = torch.cat(
            [torch.full((rows, prefix_len), _PREFIX_PLACEHOLDER_ID, dtype=torch.long), batch], 1).cuda()
        attention_mask = torch.full((rows, prefix_len + batch.shape[1]), 1, dtype=torch.long).cuda()
        # -100 labels are ignored by the loss, so only real text tokens are scored.
        labels = torch.cat([torch.full((rows, prefix_len), -100, dtype=torch.long), batch], 1).cuda()

        output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        del input_ids, attention_mask, labels

        loss = output.loss
        loss.backward()
        optimizer.step()
        loss_sum += loss.detach().cpu().item()
        if i % 10 == 0:
            print(f"{i}: Loss: {loss} Avg: {loss_sum / float(i+1)}")
    return loss_sum


def _generate_samples(embs, prompts, gen_samples, gen_len):
    """Generate ``gen_samples`` continuations per prompt with prefix ``embs``; return one text blob."""
    n_embs = embs.shape[0]
    parts = []
    for prompt in prompts:
        ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cpu")
        if ids.shape[1] < 1:
            # An empty prompt still needs one token to condition on.
            ids = torch.tensor([[tokenizer.eos_token_id]])
        if n_embs > 0:
            ids = torch.cat(
                (torch.full((ids.shape[0], n_embs), _PREFIX_PLACEHOLDER_ID, dtype=torch.long), ids), dim=1)
        max_length = ids.shape[1] + gen_len

        ids = ids.long().cuda()
        for i in range(gen_samples):
            parts.append(("-" * 30) + f" SAMPLE {i+1} " + ("-" * 30) + f"\nPrompt: {prompt}\n\n")
            basic_output = model.generate(
                ids,
                do_sample=True,
                min_length=max_length,
                max_length=max_length,
                temperature=0.7,
                tfs = None,
                top_k = 50,
                top_p = 0.9,
                repetition_penalty = 1.08,
                repetition_penalty_range = 2048,
                repetition_penalty_slope = 3.33,
                repetition_penalty_frequency = 0.0,
                repetition_penalty_presence = 0.0,
                use_cache=True,
                pad_token_id=tokenizer.eos_token_id,
                embs=[(0, embs)],
            ).long().to("cpu")
            # Drop the placeholder positions from the decoded text.
            parts.append(tokenizer.decode(basic_output[0][n_embs:]))
            parts.append("\n\n")
        del ids
    return "".join(parts)


def train(tokens, model_version, steps=3000, bs=1, tokens_per_sample=256, prefix_len=20, gen_samples=5, gen_len=200, prompts=("The", "He", "She")):
    """Train a soft prompt on ``tokens`` with the global model, then sample from it.

    The global model's word embedding is temporarily wrapped in a
    ``SoftEmbedding``. Only its learned prefix is optimised (Adafactor). All
    model parameters are set to ``requires_grad=False`` and are left that way.
    The original embedding is restored and the model put in eval mode even if
    training raises, so the caller can keep using the model for other inputs.

    Args:
        tokens: 1-D ``np.uint16`` array (an ``np.memmap`` works) with more than
            1 and fewer than ``2048 * 1024 * 50`` tokens.
        model_version: stored alongside the embedding by ``prefix.encode_prefix``.
        steps: number of optimiser steps.
        bs: samples per batch (>= 1).
        tokens_per_sample: sequence length of each training sample (>= 1).
        prefix_len: number of soft-prompt positions.
        gen_samples: samples generated per prompt.
        gen_len: tokens generated per sample.
        prompts: iterable of prompt strings used for sampling.

    Returns:
        ``{"ok": False, "reason": "invalid input"}`` for rejected arguments.
        Otherwise ``{"ok": True, "encoded_embedding": ..., "samples": str,
        "loss": float}``, where ``loss`` is the mean per-step training loss
        (0.0 when ``steps`` is 0).
    """
    if not (isinstance(tokens, np.ndarray) and tokens.dtype in [np.uint16] and len(tokens.shape) == 1 and tokens.shape[0] < 2048 * 1024 * 50 and tokens.shape[0] > 1):
        return {"ok": False, "reason": "invalid input"}
    if bs < 1 or tokens_per_sample < 1:
        return {"ok": False, "reason": "invalid input"}

    init_tokens = get_init_tokens(tokens, prefix_len)
    batches = _make_batches(tokens, bs, tokens_per_sample)

    old_wte = model.transformer.wte
    s_wte = None
    try:
        model.train()
        for param in model.parameters():
            param.requires_grad = False
        s_wte = SoftEmbedding(old_wte, n_tokens=prefix_len, initialize_from_vocab=True, init_tokens=init_tokens).to("cuda")
        model.transformer.wte = s_wte
        optimizer = Adafactor(params=[s_wte.learned_embedding])

        loss_sum = _run_prefix_training(batches, steps, prefix_len, optimizer)
        encoded_embedding = prefix.encode_prefix({"embs": s_wte.learned_embedding.data, "model_version": model_version})
        del optimizer
    finally:
        # Always undo the model surgery: the model is a module-level global
        # reused for the next input even when this call fails.
        model.transformer.wte = old_wte
        if s_wte is not None:
            s_wte.wte = None
        model.eval()
    del batches, s_wte

    # Sample with the round-tripped embedding, i.e. exactly what gets saved.
    embs = prefix.decode_prefix(encoded_embedding)["embs"].cuda()
    print("generating samples")
    samples = _generate_samples(embs, prompts, gen_samples, gen_len)
    del embs

    avg_loss = loss_sum / float(steps) if steps > 0 else 0.0
    return {"ok": True, "encoded_embedding": encoded_embedding, "samples": samples, "loss": avg_loss}

def get_init_tokens(tokens, n_tokens, word_tokens_path="wordtokens.json"):
    """Pick token ids whose embeddings seed the soft prompt.

    Token ids are ranked by how often they occur in ``tokens``. The
    ``n_tokens`` most frequent ids below 50256 that are flagged as word tokens
    (a positive entry in the JSON list at ``word_tokens_path``) are chosen.
    If fewer than ``n_tokens`` such ids exist, the ``n_tokens`` most frequent
    ids are used regardless of the flag. In that case fewer ids can come back
    if ``max(tokens) + 1 < n_tokens``.

    Args:
        tokens: 1-D array of non-negative integer token ids.
        n_tokens: number of ids to pick.
        word_tokens_path: JSON list indexed by token id; > 0 marks a word token.
    Returns:
        list[int]: the chosen ids ordered from least to most frequent.
    """
    hist = np.bincount(tokens)
    ranked = np.argsort(hist)[::-1]  # token ids by descending count

    with open(word_tokens_path, "r") as fh:
        word_tokens = json.load(fh)

    # ``hist`` only covers ids up to max(tokens), so never scan past its end.
    limit = min(len(ranked), 50256)
    relevant = []
    for t in ranked[:limit]:
        if len(relevant) >= n_tokens:
            break
        if word_tokens[t] > 0:
            relevant.append(int(t))

    if len(relevant) < n_tokens:
        return [int(t) for t in ranked[:n_tokens][::-1]]

    return relevant[::-1]

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model-folder", help="folder with split checkpoint", type=str, required=True)
    parser.add_argument("-i", "--input-folder", help="folder with input numpy np.uint16 memmaps", type=str, required=True)
    parser.add_argument("-t", "--tokens-per-sample", help="context size per training sample", type=int, default=256)
    parser.add_argument("-v", "--model-version", help="version id of the model (sigurdv3 = 3)", type=int, required=True)
    parser.add_argument("-s", "--steps", help="number of training steps per map", type=int, default=3000)
    parser.add_argument("-p", "--prefix-len", help="prefix length", type=int, default=20)
    parser.add_argument("-g", "--gen-samples", help="number of samples to generate for each prompt", type=int, default=5)
    parser.add_argument("-l", "--gen-len", help="length of each sample", type=int, default=200)
    parser.add_argument("-r", "--read-steps", help="read steps from first map element", action='store_true')
    parser.add_argument("-b", "--batch-size", help="training batch size", type=int, default=1)
    parser.add_argument("-j", "--json-embeddings", help="output embeddings as json", action='store_true')
    args = parser.parse_args()

    init(args.model_folder)

    # Recursively collect input files. Only a real ".map" extension matches:
    # every matched file is deleted after processing, so a loose "ends with
    # map" test (e.g. "roadmap") could consume unrelated files.
    maps = []
    for root, subdirs, files in os.walk(args.input_folder):
        for file in files:
            path = Path(root) / file
            if path.suffix.lower() == ".map":
                maps.append(str(path))

    for mmap in maps:
        print(f"processing {mmap}")
        map_path = Path(mmap)
        base_path = str(map_path.parent / map_path.stem)
        result = {}
        tokens = None
        succeeded = False
        try:
            # Opened inside the try so an unreadable or empty map file is
            # recorded as a failure instead of aborting the whole run.
            tokens = np.memmap(mmap, mode="r", dtype="uint16")
            steps = args.steps
            if args.read_steps:
                # With -r the first element stores the step count, not text.
                steps = int(tokens[0])
                tokens = tokens[1:]
            result = train(tokens, args.model_version, steps=steps, bs=args.batch_size, prefix_len=args.prefix_len, gen_samples=args.gen_samples, gen_len=args.gen_len, tokens_per_sample=args.tokens_per_sample)
            if result["ok"]:
                if args.json_embeddings:
                    decoded = prefix.decode_prefix(result["encoded_embedding"])
                    prefix.write_json_embs(decoded["embs"], f"{base_path}.emb")
                else:
                    with open(f"{base_path}.emb", "w") as fh:
                        fh.write(result["encoded_embedding"])
                with open(f"{base_path}.txt", "wb") as fh:
                    fh.write(result["samples"].encode('utf-8', 'surrogateescape'))
                succeeded = True
        except Exception as e:
            print(e)
            print(traceback.format_exc())
            result["exception"] = str(e)

        # Drop our reference to the memory map before deleting its file.
        del tokens
        if succeeded:
            print("success")
        else:
            print("failure")
            with open(f"{base_path}.fail", "w") as fh:
                fh.write(json.dumps(result))
        # Each map is consumed either way; failures leave a .fail file behind.
        os.unlink(mmap)
