import torch
import argparse
from transformers import AutoConfig

# Command-line interface. All three options are mandatory:
#   -l / --layers           number of transformer layers to process
#   -d / --checkpoints_dir  directory holding the per-layer NeoX shard files
#   -o / --output           directory holding the converted model (m.pt + config)
parser = argparse.ArgumentParser()
parser.add_argument(
    "-l",
    "--layers",
    type=int,
    required=True,
    help="layers in model",
)
parser.add_argument(
    "-d",
    "--checkpoints_dir",
    type=str,
    required=True,
    help="folder with mp2 neox model training output pt files",
)
parser.add_argument(
    "-o",
    "--output",
    type=str,
    required=True,
    help="output folder with converted hf mp1 model",
)
args = parser.parse_args()

# Load the existing mapping from the output folder onto the CPU and write an
# untouched copy next to it (m_mp1.pt) before the entries are rewritten below.
original_path = f"{args.output}/m.pt"
backup_path = f"{args.output}/m_mp1.pt"
m = torch.load(original_path, map_location="cpu")
torch.save(m, backup_path)

# Flag the saved model config as using split layer norms and write it back
# into the same folder it was read from.
config = AutoConfig.from_pretrained(args.output)
config.split_layernorm = True
config.save_pretrained(args.output)

# For every transformer layer, replace the single ln_1 / ln_2 entries of `m`
# with separate "a" and "b" entries, one per model-parallel shard file.
for layer in range(args.layers):
    print(f"layer {layer}")
    base = f"transformer.h.{layer}.ln_"

    # Remove the merged layer-norm entries (ln_1 first, weight before bias).
    for ln_idx in (1, 2):
        for param in ("weight", "bias"):
            del m[f"{base}{ln_idx}.{param}"]

    # Layer N of the model is read from the checkpoint files numbered N + 2;
    # model_00 and model_01 are the two shard files for that layer.
    shard_stem = f"{args.checkpoints_dir}/layer_{layer+2:02d}-model_"
    a = torch.load(f"{shard_stem}00-model_states.pt", map_location="cpu")
    b = torch.load(f"{shard_stem}01-model_states.pt", map_location="cpu")

    # Each shard's tensor is written to its own file inside the output folder,
    # and `m` records just the bare file name under the new ".a."/".b." key.
    for ln_idx, neox_name in ((1, "input_layernorm"), (2, "post_attention_layernorm")):
        for param in ("weight", "bias"):
            for tag, shard in (("a", a), ("b", b)):
                filename = f"ln_{ln_idx}_{param}_{tag}_{layer}.pt"
                torch.save(shard[f"{neox_name}.{param}"], f"{args.output}/{filename}")
                m[f"{base}{ln_idx}.{tag}.{param}"] = filename


# Same treatment for the final layer norm: drop the merged entries ...
for param in ("weight", "bias"):
    del m[f"transformer.ln_f.{param}"]

# ... read the two shard files numbered (layers + 3) ...
final_stem = f"{args.checkpoints_dir}/layer_{args.layers+3:02d}-model_"
a = torch.load(f"{final_stem}00-model_states.pt", map_location="cpu")
b = torch.load(f"{final_stem}01-model_states.pt", map_location="cpu")

# ... and store each shard's "norm" tensor in its own file, recording the bare
# file name in `m` under the matching ".a."/".b." key.
for param in ("weight", "bias"):
    for tag, shard in (("a", a), ("b", b)):
        filename = f"ln_f_{param}_{tag}.pt"
        torch.save(shard[f"norm.{param}"], f"{args.output}/{filename}")
        m[f"transformer.ln_f.{tag}.{param}"] = filename

# Overwrite the original mapping file with the rewritten one.
torch.save(m, f"{args.output}/m.pt")
