import argparse
import time
from collections.abc import MutableMapping
from pathlib import Path

import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    GPTNeoForCausalLM,
)

def _str2bool(value):
    """Convert a command-line token such as ``true``/``false``/``1``/``0`` to bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including ``"False"``), so an explicit parser is needed.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean word.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    if token in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description="Inference tests for GPT models.")
parser.add_argument(
    "--model",
    help="Model you want to test. Use neo-2.7b or neo-6b",
    type=str,
    required=True,
    # Reject unknown names up front instead of failing later when `model`
    # is never assigned.
    choices=["neo-2.7b", "neo-6b"],
)
parser.add_argument(
    "--batch",
    dest="batch",
    # `--batch` alone means True; `--batch true/false` is still accepted so
    # existing invocations such as `--batch True` keep working.
    type=_str2bool,
    nargs="?",
    const=True,
    default=False,
    help="Run the batching benchmark",
)
args = parser.parse_args()
print("loading model: " + args.model)

tokenizer = AutoTokenizer.from_pretrained("gpt2")


class Checkpoint(MutableMapping):
    """Read-only mapping over a checkpoint split into one file per tensor.

    ``<chkpt_dir>/m.pt`` is an index mapping each state-dict key to the path of
    the file holding that tensor. Only the base name of the stored path is
    used; it is resolved inside ``chkpt_dir``. Tensors are read from disk on
    every lookup (nothing is cached). Writes and deletes are silently ignored.

    Args:
        chkpt_dir: directory containing ``m.pt`` and the per-tensor files.
        device: ``map_location`` used when loading each tensor.
    """

    def __init__(self, chkpt_dir, device="cpu"):
        self.device = device
        self.chkpt_dir = Path(chkpt_dir)
        self.checkpoint = torch.load(str(self.chkpt_dir / "m.pt"))

    def __len__(self):
        return len(self.checkpoint)

    def __getitem__(self, key):
        # Only the file name is taken from the index; the stored directory is
        # presumably where the checkpoint was written, not where it lives now.
        path = self.chkpt_dir / Path(self.checkpoint[key]).name
        return torch.load(str(path), map_location=self.device)

    def __contains__(self, key):
        # The inherited Mapping.__contains__ would call __getitem__ and load the
        # tensor from disk just to answer a membership test.
        return key in self.checkpoint

    def __setitem__(self, key, value):
        # Writes are ignored: the mapping always reflects the files on disk.
        return

    def __delitem__(self, key):
        # Deletes are ignored as well. `del m[k]` passes only the key.
        return

    def keys(self):
        return self.checkpoint.keys()

    def __iter__(self):
        # Mapping protocol: iterate over keys only, so the inherited items()
        # and values() work and nothing is loaded until a value is requested.
        return iter(self.checkpoint)

    def __copy__(self):
        # A "copy" is a fresh view of the same directory (re-reads the index).
        return Checkpoint(self.chkpt_dir, device=self.device)

    def copy(self):
        return self.__copy__()


load_time = time.perf_counter()

if args.model == "neo-2.7b":
    model_path = "/root/models/gpt-neo-lit-2.7B-fp16"
    model = AutoModelForCausalLM.from_pretrained(model_path).half().cuda().eval()

elif args.model == "neo-6b":
    model_path = "/root/models/gpt-neo-lit-6B-fp16/j6b_ckpt"
    n_layers = 28
    # Start from the GPT-Neo 2.7B config and override it for the 6B checkpoint
    # (presumably GPT-J-6B, judging by the `j6b_ckpt` path).
    config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-2.7B")
    config.attention_layers = ["global"] * n_layers
    # GPT-Neo's format is a list of [pattern, repeat] pairs; keep it consistent
    # with attention_layers so the config can be re-instantiated.
    config.attention_types = [[["global"], n_layers]]
    config.num_layers = n_layers
    config.num_heads = 16
    config.hidden_size = 256 * config.num_heads
    config.vocab_size = 50400
    # NOTE(review): `rotary`, `rotary_dim` and `jax` are not stock GPT-Neo
    # config fields; presumably read by a patched GPT-Neo implementation.
    config.rotary = True
    config.rotary_dim = 64
    config.jax = True
    model = GPTNeoForCausalLM.from_pretrained(
        pretrained_model_name_or_path=None,
        config=config,
        state_dict=Checkpoint(model_path),
    )
    model = model.half().to("cuda").eval()

else:
    # Fail fast with a usage error instead of a NameError on `model` below.
    parser.error(f"unknown model {args.model!r}; use neo-2.7b or neo-6b")

print(f"Models loaded in {time.perf_counter() - load_time} seconds")

N_RUNS = 10  # generate() calls averaged per benchmark configuration


def _mean_generate_time(batch_size, seq_len, max_length, n_runs=N_RUNS):
    """Return the mean wall-clock seconds of ``model.generate`` over *n_runs*.

    The prompt is ``batch_size`` rows of ``seq_len`` zero token ids. Passing
    ``min_length == max_length`` is meant to make every run produce the same
    number of tokens, so that timings are comparable.

    Reads the module-level ``model`` and ``tokenizer``.
    """
    ids = torch.zeros(batch_size, seq_len, dtype=torch.long, device="cuda")
    total = 0.0
    for _ in range(n_runs):
        # CUDA kernels run asynchronously; synchronize on both sides so the
        # timer brackets exactly the GPU work issued by this generate() call.
        torch.cuda.synchronize()
        start = time.perf_counter()
        # Adding remove_invalid_values=True may help if fp16 sampling ever
        # produces NaN/inf logits.
        model.generate(
            ids,
            use_cache=True,
            do_sample=True,
            min_length=max_length,
            max_length=max_length,
            pad_token_id=tokenizer.eos_token_id,
        )
        torch.cuda.synchronize()
        total += time.perf_counter() - start
    return total / n_runs


def _max_length(seq_len):
    """Target total length: prompt + 40 tokens, capped at 2049.

    For the longest (2048-token) prompt the cap leaves room for one new token.
    """
    return min(2049, seq_len + 40)


print(model.dtype, model.device)

if not args.batch:
    print("seq_len\tmax_len\truntime")
    for seq_len in range(128, 2049, 128):
        max_length = _max_length(seq_len)
        runtime = _mean_generate_time(1, seq_len, max_length)
        print(f"{seq_len}\t{max_length}\t{runtime}s")
else:
    print("bs\tseq_len\tmax_len\truntime")
    for bs in range(1, 11):
        for seq_len in (950, 1920):
            max_length = _max_length(seq_len)
            runtime = _mean_generate_time(bs, seq_len, max_length)
            print(f"{bs}\t{seq_len}\t{max_length}\t{runtime}s")

