import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp
import torch.optim as optim
from pathlib import Path
from torch.utils import data
from basedformer import optimizer, utils, models, lm_utils
import yaml
import sys
from tqdm import tqdm
import time
import wandb
import numpy as np
import os

def softmax_activation(x):
    """Normalize ``x`` over its last dimension and return log-probabilities.

    Note: despite the name, the result is the *log*-softmax, not the softmax.
    """
    return x.log_softmax(dim=-1)

# Architecture hyperparameters for the GPTJModel built further down; the dict is
# also merged into the wandb run config.
model_config = dict(
    n_layer=12,
    n_head=12,
    hidden_dim=768,
    vocab_dim=50400,
    eps=1e-5,
    q_only=True,
    activation=torch.nn.GELU(),
)

# Training hyperparameters and I/O locations.
# NOTE(review): an earlier note here said a batch size of 250 is needed to train
# the small GPT, but bs=16 with gas=1 is used below — confirm this is intentional.
# NOTE(review): save_path mentions "3L-16A-1024H", which does not match
# model_config (12 layers, 12 heads, 768 hidden) — confirm it is the intended folder.
# Alternative datasets used previously:
#   /home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map
#   /home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map
train_config = dict(
    data_path="/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
    save_path="/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/owt2gptj-nopos-tokenshift-superhighlr-residualgate-3L-16A-1024H",
    do_save=False,
    run_name="gptj-owt2-512ctx-12L-12H-768H-16bs-1e-4lr-q-only-smallattneveryotherlayer",
    lr=1e-4,
    end_lr=1e-4,
    warmup_steps=100,
    bs=16,
    gas=1,
    seed=69,
    save_every=500,
    amp=True,
    loss_scale=True,
)
# Seed before the model is constructed so parameter initialization is reproducible.
torch.manual_seed(train_config["seed"])

# Micro-batch size and gradient-accumulation steps, used by the loader and loop.
bs, gas = train_config["bs"], train_config["gas"]

# The checkpoint directory has to exist before it is listed for resumption.
Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)

model = lm_utils.init(models.fast.GPTJModel, model_config)
model = model.cuda().float()
utils.print_parameters(model)
model.train()

def _checkpoint_step(name):
    """Return the integer after the last ``_`` in ``name``, or None if it is not one.

    Checkpoints are saved as ``step_<N>`` folders by the training loop; any other
    entry in the save directory is ignored rather than crashing the sort.
    """
    try:
        return int(name.rsplit("_", 1)[-1])
    except ValueError:
        return None


# Locate the most recent checkpoint folder. Entries that are not directories, or
# whose name does not end in "_<integer>", are skipped.
_save_dir = Path(train_config["save_path"])
cp_list = sorted(
    (
        entry.name
        for entry in _save_dir.iterdir()
        if entry.is_dir() and _checkpoint_step(entry.name) is not None
    ),
    key=_checkpoint_step,
)
last_cp = _save_dir / cp_list[-1] if cp_list else None
print(last_cp)

if last_cp is not None:
    print("Loading from step {}".format(cp_list[-1].split("_")[-1]))
    model.load(model_config, last_cp / "lm", strict=True)
    opt = optimizer.BasedOptimizer.load(model.parameters(), last_cp / "opt")
else:
    opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")

# TODO: add evals, Data Parallel, and outputting hidden states from the get_logits
# function. (Checkpoint loading and FP16 AMP are already wired up in this script.)
print(opt.curr_step)
train_dataset = utils.FbDataset(2049, train_config["data_path"])
if last_cp is not None:
    # Resume the data stream past what was already trained on: every optimizer
    # step consumes one loader batch of bs * gas rows.
    # NOTE(review): assumes FbDataset.skip counts rows/samples — confirm.
    train_dataset.skip = opt.curr_step * bs * gas

train_loader = data.DataLoader(train_dataset, batch_size=bs * gas, shuffle=False, num_workers=0)
wandb.init(project="basedformer-tests", name=train_config["run_name"], config={**train_config, **model_config})

# Continue step numbering from the restored optimizer when resuming.
curr_step = opt.curr_step if last_cp is not None else 0

t = tqdm(train_loader, initial=curr_step)

# Enable the scaler only when loss scaling is actually used, so the logged
# "train/loss_scale" reads 1.0 instead of an unused initial scale when it is off.
scaler = torch.cuda.amp.GradScaler(enabled=train_config["loss_scale"])

# Tokens per sequence fed to the model; dataset rows are sliced to this length.
context_length = 512

for input_ids, labels in t:
    timex = time.perf_counter()
    input_ids = input_ids.cuda()
    labels = labels.cuda()
    loss = 0
    # Gradient accumulation: each loader batch holds bs * gas rows, processed as
    # `gas` micro-batches of `bs` rows.
    for x in range(gas):
        micro = slice(x * bs, (x + 1) * bs)
        with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
            logits = model(input_ids[micro, :context_length], act_ck=False)
            # Flatten to (rows, classes) / (rows,) as F.cross_entropy expects.
            # NOTE(review): labels are assumed to already be the next-token targets
            # aligned with input_ids (shifted by the dataset) — confirm in FbDataset.
            logits = logits.view(-1, logits.shape[-1])
            gas_labels = labels[micro, :context_length].contiguous().view(-1)
            gas_loss = F.cross_entropy(logits, gas_labels)

        # Divide by `gas` so the accumulated gradient is the mean over micro-batches
        # (matching the averaged loss logged below) instead of a sum that grows
        # with the accumulation count.
        micro_loss = gas_loss / gas
        if train_config["loss_scale"]:
            scaler.scale(micro_loss).backward()
        else:
            micro_loss.backward()

        loss += gas_loss.item()

    loss = loss / gas
    if train_config["loss_scale"]:
        # Unscale first so gradient clipping operates on the true gradient values.
        scaler.unscale_(opt.optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
    if train_config["loss_scale"]:
        opt.step(scaler=scaler)
        scaler.update()
    else:
        opt.step()

    opt.zero_grad()
    sec_per_step = time.perf_counter() - timex
    step_per_sec = 1.0 / sec_per_step
    tokens_per_sec = step_per_sec * context_length * bs * gas
    t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
    wandb.log(
        {
            "train/loss": loss,
            "train/tokens_per_sec": tokens_per_sec,
            "train/sec_per_step": sec_per_step,
            "train/step_per_sec": step_per_sec,
            "train/lr": opt.curr_lr,
            "train/loss_scale": scaler.get_scale(),
        },
        step=curr_step,
    )

    if train_config["do_save"] and curr_step % train_config["save_every"] == 0:
        save_folder = Path(train_config["save_path"]) / f"step_{curr_step}"
        save_folder.mkdir(parents=True, exist_ok=True)
        model.save(save_folder / "lm")
        opt.save(save_folder / "opt")
        print(f"Saved model at step {curr_step}")

    curr_step += 1