import torch
from torch.utils import data
import numpy as np
import sys

class FbDataset(data.Dataset):
    """Serve fixed-length token windows from a flat uint16 memmap file.

    The file is read as a flat uint16 array, truncated to a multiple of 2048
    values and viewed as rows of 2048. When ``block_size`` is given and is
    below 2048, every row is instead exposed as two 1024-value halves (the
    half width is always 1024, whatever smaller ``block_size`` was passed).

    Args:
        block_size: Requested window length; ``None`` or >= 2048 selects
            full 2048-value rows.
        map_file: Path of the uint16 memmap file.
        max_samples: If given, replaces the computed length outright (it is
            not clamped to the number of windows actually available).
    """

    def __init__(self, block_size, map_file, max_samples=None):
        self.half_blocks = block_size is not None and int(block_size) < 2048

        flat = np.memmap(map_file, mode="r", dtype="uint16")
        # Drop any trailing partial row so the data reshapes cleanly.
        whole = flat.shape[0] - flat.shape[0] % 2048
        self.npz = flat[:whole].reshape((-1, 2048))

        if max_samples is None:
            rows = self.npz.shape[0]
            self.samples = 2 * rows if self.half_blocks else rows
        else:
            self.samples = int(max_samples)

        # Request bookkeeping maintained by __getitem__.
        self.skip = 0
        self.mapping = {}
        self.proc = 0
        self.max_done = 0

    def __len__(self):
        return self.samples

    def __getitem__(self, _id):
        # The first time an index is requested it is given the next sequential
        # slot, so full-row reads follow request order, not the index value.
        if _id not in self.mapping:
            self.mapping[_id] = self.proc
            self.proc += 1
            self.max_done = self.proc + self.skip + 1

        # Indices greater than the reported length bypass the slot mapping.
        slot = _id if _id > self.samples else self.mapping[_id]
        row, start, width = slot + self.skip, 0, 2048

        if self.half_blocks:
            # Half-block mode addresses rows directly from the index,
            # ignoring both the slot mapping and ``skip``.
            row = _id // 2
            start, width = 1024 * (_id % 2), 1024

        tokens = torch.tensor(self.npz[row][start:start + width].astype(np.int64))
        # Inputs and labels are the very same tensor object.
        return {'input_ids': tokens, 'labels': tokens}

# Evaluation data: full 2048-token rows from the memmap file given as argv[1].
dataset = FbDataset(block_size=2048, map_file=sys.argv[1])

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, GPTNeoForCausalLM
from transformers.modeling_utils import SplitCheckpoint

def no_init(loading_code):
    """Run ``loading_code`` with layer ``reset_parameters`` patched to a no-op.

    While ``loading_code`` runs, ``reset_parameters`` on ``torch.nn.Linear``,
    ``torch.nn.Embedding`` and ``torch.nn.LayerNorm`` does nothing, so layers
    built during the call skip their default random initialisation (PyTorch
    invokes ``reset_parameters`` from these layers' constructors). The
    original methods are restored afterwards, even if ``loading_code`` raises.

    Args:
        loading_code: Zero-argument callable to execute.

    Returns:
        Whatever ``loading_code`` returns.
    """
    def dummy(self):
        return

    modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
    original = {mod: mod.reset_parameters for mod in modules}
    for mod in modules:
        mod.reset_parameters = dummy
    try:
        return loading_code()
    finally:
        # Always undo the patch: a failed load must not leave every layer
        # created later in this process silently uninitialised.
        for mod, method in original.items():
            mod.reset_parameters = method

def _load_checkpoint():
    """Build the GPT-Neo model from the checkpoint path passed as argv[2]."""
    return GPTNeoForCausalLM.from_pretrained(sys.argv[2])


print("loading model")
# no_init makes layer reset_parameters a no-op while loading; the checkpoint
# presumably supplies every weight afterwards.
model = no_init(_load_checkpoint)
print("loaded")

from transformers import TrainingArguments, Trainer, default_data_collator, set_seed

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# The seed and eval batch size are passed to TrainingArguments directly.
# Trainer re-seeds from ``args.seed`` (default 42) when it is constructed,
# so calling set_seed(5) on its own would be overridden.
training_args = TrainingArguments(
    output_dir="out",
    per_device_eval_batch_size=1,
    seed=5,
)
set_seed(5)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=None,
    eval_dataset=dataset,
    tokenizer=tokenizer,
    # Data collator will default to DataCollatorWithPadding, so we change it.
    data_collator=default_data_collator,
)

import math
with torch.no_grad():
    print("*** Evaluate ***")

    metrics = trainer.evaluate()

    max_val_samples = len(dataset)
    metrics["eval_samples"] = max_val_samples
    # A very large loss makes math.exp overflow. Report infinity instead of
    # crashing after the (expensive) evaluation has already completed.
    try:
        perplexity = math.exp(metrics["eval_loss"])
    except OverflowError:
        perplexity = float("inf")
    metrics["perplexity"] = perplexity

    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)
    print(metrics)
