import json
import torch
from fairseq.models.transformer_lm import TransformerLanguageModel
import sys
import os

# Flags read by hack_embs(): when set, the end-of-text row (id 50256) is
# copied over the newline row (id 198) and/or the other way round.  With both
# set the two rows are swapped.
copy_eot_to_newline = True
copy_newline_to_eot = True
model_dir = 'en_dense_lm_125m' # path to a small model checkpoint; its dictionary is used to undo the tokenizer shuffle

# NOTE(review): `checkpoint` is not referenced anywhere in the visible code.
checkpoint = {}
# Maps each saved tensor name to the shard file ("b<N>.pt") that holds it.
ckmap = {}
# Running counter used by save() to number the shard files.
ckid = 0

def save(params, name):
    """Write one tensor to its own shard file and refresh the shard index.

    The tensor is stored as ``b<N>.pt`` inside the output directory given by
    ``sys.argv[2]``, where ``N`` is the current value of the global counter
    ``ckid``.  The name -> file mapping ``ckmap`` is re-written to ``m.pt``
    after every call, and the tensor name and shape are printed.

    Args:
        params: tensor to store.
        name: key under which the tensor is registered in ``ckmap``.
    """
    global ckid
    out_dir = sys.argv[2]
    shard_file = f"b{ckid}.pt"

    # Register the shard before writing, then advance the counter.
    ckmap[name] = shard_file
    ckid += 1

    torch.save(params, f"{out_dir}/{shard_file}")
    torch.save(ckmap, f"{out_dir}/m.pt")
    print(f"{name}: {params.shape}")

def no_init(loading_code):
    """Run ``loading_code`` with parameter initialisation switched off.

    ``reset_parameters`` of ``torch.nn.Linear``, ``torch.nn.Embedding`` and
    ``torch.nn.LayerNorm`` is temporarily replaced by a no-op, so modules
    constructed inside ``loading_code`` skip their weight initialisation.

    Args:
        loading_code: zero-argument callable to execute.

    Returns:
        Whatever ``loading_code()`` returns.

    Raises:
        Any exception raised by ``loading_code``; the original
        ``reset_parameters`` methods are restored before it propagates.
    """
    def dummy(self):
        return

    modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
    original = {mod: mod.reset_parameters for mod in modules}
    for mod in modules:
        mod.reset_parameters = dummy

    try:
        return loading_code()
    finally:
        # Always undo the monkey-patch, even when loading fails; otherwise
        # every module created later in the process would stay uninitialised.
        for mod, reset in original.items():
            mod.reset_parameters = reset

# Load the reference model from `model_dir` with weight initialisation
# disabled; only its dictionary (lm.tgt_dict) is used further down.
lm = no_init(lambda: TransformerLanguageModel.from_pretrained(model_dir, bpe='gpt2').eval().cpu())
# NOTE(review): torch.load unpickles the file, so only run this script on
# checkpoints from a trusted source.
fairdict = torch.load(f"{sys.argv[1]}", map_location="cpu")

# Create the output directory (and any missing parents).  An existing
# directory is fine; any other failure (permissions, path is a file, ...) is
# reported here instead of surfacing later as a confusing write error.
os.makedirs(sys.argv[2], exist_ok=True)

# Architecture hyper-parameters stored alongside the weights in the checkpoint.
_model_cfg = fairdict["cfg"]["model"]
hidden_dim = _model_cfg["decoder_embed_dim"]
num_heads = _model_cfg["decoder_attention_heads"]
num_layers = _model_cfg["decoder_layers"]

# From here on only the weight dictionary is needed.
fairdict = fairdict["model"]

# Model configuration written next to the converted weights.  Only the hidden
# size, head count and layer count come from the source checkpoint; every
# other entry is a fixed value.
config = {
    "activation_function": "gelu",
    "architectures": ["GPTNeoForCausalLM"],
    "attention_dropout": 0,
    "attention_layers": ["global"] * num_layers,
    "attention_types": [[["global"], num_layers]],
    "bos_token_id": 50256,
    "embed_dropout": 0,
    "eos_token_id": 50256,
    "gradient_checkpointing": False,
    "hidden_size": hidden_dim,
    "initializer_range": 0.02,
    "intermediate_size": None,
    "fair": True,
    "layer_norm_epsilon": 1e-05,
    "max_position_embeddings": 2048,
    "model_type": "gpt_neo",
    "num_heads": num_heads,
    "num_layers": num_layers,
    "resid_dropout": 0,
    "rotary": False,
    "summary_activation": None,
    "summary_first_dropout": 0.1,
    "summary_proj_to_labels": True,
    "summary_type": "cls_index",
    "summary_use_proj": True,
    "model_dtype": "fp16",
    "model_device": "cuda",
    "transformers_version": "4.6.0.dev0",
    "use_cache": True,
    "vocab_size": 51200,
    "window_size": 256,
    "tokenizer_class": "GPT2Tokenizer",
    "task_specific_params": {
        "text-generation": {
            "do_sample": True,
            "temperature": 1.0,
            "max_length": 50,
        },
    },
}

# Serialise the configuration as compact JSON into the output directory.
with open(f"{sys.argv[2]}/config.json", "w") as config_file:
    json.dump(config, config_file)

#print(lm)

def hack_embs(embs):
    """Exchange the end-of-text and newline rows of ``embs`` in place.

    Row 50256 (end-of-text) and row 198 (newline) are both snapshotted first,
    then each copy direction is applied only if the matching module-level flag
    (``copy_eot_to_newline`` / ``copy_newline_to_eot``) is set.  With both
    flags on, the two rows end up swapped.

    Args:
        embs: indexable weight tensor, modified in place.
    """
    eot_id, newline_id = 50256, 198

    # Snapshot both rows up front so the second copy still sees the original.
    eot_row = embs[eot_id].clone()
    newline_row = embs[newline_id].clone()

    if copy_eot_to_newline:
        embs[newline_id] = eot_row
    if copy_newline_to_eot:
        embs[eot_id] = newline_row

# gpt2 compatible input/output embedding layers
#
# The source dictionary stores GPT-2 token ids as decimal strings in its own
# order.  l1[k] is a row index in that dictionary and l2[k] is the GPT-2 row
# the same entry must occupy in the converted matrices.
l1 = []
l2 = []

# GPT-2 token ids that have not been found in the dictionary yet.
check = {}
for i in range(50256):
    check[i] = True

for fair_idx, symbol in enumerate(lm.tgt_dict.symbols):
    try:
        token_id = int(symbol)
    except ValueError:
        # Not a numeric token id (special symbol, filler word, ...): skip it.
        continue
    # Only canonical decimal spellings count; id 50256 is handled together
    # with the special symbols below.
    if str(token_id) == symbol and symbol != '50256':
        l2.append(token_id)
        l1.append(fair_idx)
        check.pop(token_id, None)

# The four special symbols occupy rows 50256..50259, in this order.
special_words = [lm.tgt_dict.eos_word, lm.tgt_dict.pad_word, lm.tgt_dict.bos_word, lm.tgt_dict.unk_word]
for offset, word in enumerate(special_words):
    l2.append(50256 + offset)
    l1.append(lm.tgt_dict.indices[word])

# Fail with an explanation instead of an IndexError while building `mapping`.
if len(l1) < 50260:
    raise ValueError(
        f"expected at least 50260 dictionary entries (50256 token ids + 4 "
        f"special symbols) but found {len(l1)}; {len(check)} token ids are "
        f"missing from the dictionary"
    )

# Source dictionary row -> GPT-2 row, for the first 50260 collected entries.
mapping = dict(zip(l1[:50260], l2[:50260]))


def _remap_rows(weight):
    """Return a copy of ``weight`` with the first 50260 rows moved to their GPT-2 positions."""
    remapped = weight.clone()
    for src_row in range(50260):
        remapped[mapping[src_row]] = weight[src_row]
    return remapped


with torch.no_grad():
    wte = _remap_rows(fairdict["decoder.embed_tokens.weight"])
    hack_embs(wte)
    save(wte.half(), "transformer.wte.weight")
    lm_head = _remap_rows(fairdict["decoder.output_projection.weight"])
    hack_embs(lm_head)
    save(lm_head.half(), "lm_head.weight")

# NOTE(review): torch.FloatTensor(1) is a one-element tensor whose value is
# never set here; presumably only a placeholder entry -- confirm.
save(torch.FloatTensor(1), "transformer.wpe_sin._float_tensor")

new_state_dict = {}
# Walk every entry of the source state dict, translate its key to the target
# naming scheme and store the tensor in half precision under the new name.
for y in fairdict:
    dotlist = y.split(".")

    # Start every key from the "skip" sentinel.  Without this reset, a key
    # that matches no rule below would reuse the previous key's target name
    # and overwrite that tensor's shard entry (or raise NameError if it came
    # first).
    trans_to = "Passed"

    if y == "decoder.version":
        # Intentionally not exported; keeps the sentinel.
        pass

    elif y == "decoder.embed_tokens.weight":
        # Already saved as "transformer.wte.weight".
        continue

    elif len(dotlist) >= 2 and dotlist[1] == "layers":
        layer_id = dotlist[2]

        # Attention projections keep their own names under attn.attention.
        if dotlist[-2] in ["k_proj", "v_proj", "q_proj", "out_proj"]:
            trans_to = f"transformer.h.{layer_id}.attn.attention.{dotlist[-2]}.{dotlist[-1]}"

        # Layer norm in front of self-attention.
        if dotlist[-2] == "self_attn_layer_norm":
            trans_to = f"transformer.h.{layer_id}.ln_1.{dotlist[-1]}"

        # Feed-forward block.
        if dotlist[3] == "fc1":
            trans_to = f"transformer.h.{layer_id}.mlp.c_fc.{dotlist[-1]}"

        if dotlist[3] == "fc2":
            trans_to = f"transformer.h.{layer_id}.mlp.c_proj.{dotlist[-1]}"

        # Second layer norm of the block.
        if dotlist[3] == "final_layer_norm":
            trans_to = f"transformer.h.{layer_id}.ln_2.{dotlist[-1]}"

    elif len(dotlist) >= 2 and dotlist[1] == "layer_norm":
        # Layer norm after the last decoder layer.
        trans_to = f"transformer.ln_f.{dotlist[-1]}"

    elif y == "decoder.output_projection.weight":
        # Already saved as "lm_head.weight".
        continue

    if trans_to != "Passed":
        save(fairdict[y].half(), trans_to)
    print(f"{trans_to} < {y}")
