# coding=utf-8
# Copyright 2021 The rinna Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import yaml
import os
import argparse

import torch

from transformers import GPTNeoXForCausalLM
from transformers import GPTNeoXConfig, AutoConfig
from transformers.modeling_utils import no_init_weights
def no_init(loading_code):
    """Run ``loading_code`` with parameter initialization disabled.

    Temporarily replaces ``reset_parameters`` on ``torch.nn.Linear``,
    ``torch.nn.Embedding`` and ``torch.nn.LayerNorm`` with a no-op. It then
    calls ``loading_code`` inside transformers' ``no_init_weights()`` context.
    The original ``reset_parameters`` methods are always restored afterwards,
    even if ``loading_code`` raises. This matters because the patch is applied
    to the torch classes themselves and would otherwise leak into the rest of
    the process.

    Args:
        loading_code: Zero-argument callable to run. Presumably it builds or
            loads a model, given this helper's purpose; confirm at call sites.

    Returns:
        Whatever ``loading_code`` returns.
    """
    def dummy(self):
        return

    modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
    original = {}
    try:
        for mod in modules:
            original[mod] = mod.reset_parameters
            mod.reset_parameters = dummy

        with no_init_weights():
            return loading_code()
    finally:
        # Undo exactly the patches that were applied, on success or failure.
        for mod, method in original.items():
            mod.reset_parameters = method


# Module-level conversion state:
#   checkpoint - not read or written by any code visible in this file.
#   ckmap      - target parameter name -> shard file name ("b0.pt", "b1.pt", ...);
#                filled in place by `save`.
#   ckid       - number of the next shard file; `save` increments it via `global`.
checkpoint, ckmap, ckid = {}, {}, 0

def save(name, params, save_dir=None):
    """Write one tensor to its own shard file and update the shard index.

    Each call takes the next sequential shard name (``b0.pt``, ``b1.pt``, ...)
    from the module-level counter ``ckid`` and records ``name -> shard`` in the
    module-level dict ``ckmap``. It then saves ``params`` to that shard and
    rewrites the index file ``m.pt`` with the current ``ckmap``.

    Args:
        name: Target parameter name; used as the key in the index.
        params: Object passed to ``torch.save``; its ``.shape`` is printed,
            so it is expected to be a tensor.
        save_dir: Output directory. Defaults to ``args.hf_save_dir`` (the parsed
            command-line arguments) so existing callers keep working. Pass it
            explicitly to use this function without the CLI.
    """
    global ckid
    if save_dir is None:
        # `args` is only defined when this module runs as a script.
        save_dir = args.hf_save_dir

    shard_name = f"b{ckid}.pt"
    ckmap[name] = shard_name
    ckid += 1
    torch.save(params, os.path.join(save_dir, shard_name))
    # Rewriting the whole index after every shard costs time quadratic in the
    # number of tensors. In exchange, a partially converted output directory
    # always carries an index matching the shards already written.
    torch.save(ckmap, os.path.join(save_dir, "m.pt"))
    print(f"{name}: {params.shape}")

# (target name suffix, GPT-NeoX source key) for every parameter of one
# transformer layer, in the order they are saved. The order determines shard
# numbering in `save`. attn.bias / attn.masked_bias are intentionally not emitted.
_LAYER_PARAM_MAP = (
    ("ln_1.weight", "input_layernorm.weight"),
    ("ln_1.bias", "input_layernorm.bias"),
    ("attn.qkv_proj.weight", "attention.query_key_value.weight"),
    ("attn.qkv_proj.bias", "attention.query_key_value.bias"),
    ("attn.out_proj.weight", "attention.dense.weight"),
    ("attn.out_proj.bias", "attention.dense.bias"),
    ("ln_2.weight", "post_attention_layernorm.weight"),
    ("ln_2.bias", "post_attention_layernorm.bias"),
    ("mlp.fc_in.weight", "mlp.dense_h_to_4h.weight"),
    ("mlp.fc_in.bias", "mlp.dense_h_to_4h.bias"),
    ("mlp.fc_out.weight", "mlp.dense_4h_to_h.weight"),
    ("mlp.fc_out.bias", "mlp.dense_4h_to_h.bias"),
)


def _load_layer_file(checkpoint_dir, file_idx):
    """Load ``layer_{file_idx:02}-model_00-model_states.pt`` from ``checkpoint_dir`` onto the CPU."""
    path = os.path.join(checkpoint_dir, f"layer_{file_idx:02}-model_00-model_states.pt")
    return torch.load(path, map_location="cpu")


def get_state_dict_from_checkpoint_dir(checkpoint_dir, num_layers):
    """Convert GPT-NeoX per-layer checkpoint files into renamed tensors on disk.

    Reads the per-layer ``.pt`` files in ``checkpoint_dir`` and passes every
    tensor to ``save`` under its target name. File indices used:

    * ``0``                        -> ``transformer.wte.weight``
    * ``2`` .. ``num_layers + 1``  -> ``transformer.h.{i}.*`` for layer ``i``
    * ``num_layers + 3``           -> ``transformer.ln_f.{weight,bias}``
    * ``num_layers + 4``           -> ``lm_head.weight``

    File indices ``1`` and ``num_layers + 2`` are not read here.

    Args:
        checkpoint_dir: Directory holding the ``layer_XX-model_00-model_states.pt`` files.
        num_layers: Number of transformer layers in the model.

    Returns:
        None. Despite the name, nothing is returned; output is written via ``save``.
    """
    # word embedding
    src = _load_layer_file(checkpoint_dir, 0)
    save("transformer.wte.weight", src["word_embeddings.weight"])
    # Drop each source dict before loading the next, to keep peak memory at one file.
    del src

    # transformer layers
    for layer_idx in range(num_layers):
        src = _load_layer_file(checkpoint_dir, layer_idx + 2)
        for tgt_suffix, src_key in _LAYER_PARAM_MAP:
            save(f"transformer.h.{layer_idx}.{tgt_suffix}", src[src_key])
        del src

    # final norm
    src = _load_layer_file(checkpoint_dir, num_layers + 3)
    save("transformer.ln_f.weight", src["norm.weight"])
    save("transformer.ln_f.bias", src["norm.bias"])
    del src

    # output layer
    src = _load_layer_file(checkpoint_dir, num_layers + 4)
    save("lm_head.weight", src["final_linear.weight"])
    del src

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_dir", type=str, required=True, help="directory that contains state dict pt files and a config directory generated by gpt-neox")
    parser.add_argument("--hf_config_path", type=str, required=True, help="path to Huggingface GPT-NeoX configuration file")
    parser.add_argument("--hf_save_dir", type=str, required=True, help="directory to save Huggingface GPT-NeoX model weights and configuration")
    # `args` stays a module-level global on purpose: `save` falls back to
    # `args.hf_save_dir` when no directory is passed.
    args = parser.parse_args()

    # Create the output directory, including missing parents. An existing
    # directory is fine. Any other failure (permissions, a file already at that
    # path) is raised instead of being silently swallowed.
    os.makedirs(args.hf_save_dir, exist_ok=True)

    with open(args.hf_config_path, "r", encoding="utf-8") as fh:
        config = json.load(fh)

    # Accept the GPT-2/GPT-J style key as before, and also the key that the
    # Hugging Face GPTNeoXConfig uses.
    if "n_layer" in config:
        num_layers = config["n_layer"]
    elif "num_hidden_layers" in config:
        num_layers = config["num_hidden_layers"]
    else:
        parser.error(f"{args.hf_config_path} has neither 'n_layer' nor 'num_hidden_layers'")

    get_state_dict_from_checkpoint_dir(args.checkpoint_dir, num_layers)
