from main import *
import time
import gc

from time import perf_counter, perf_counter_ns
import numpy as np
from tqdm import tqdm
from contextlib import contextmanager
# Replicates the %timeit magic function from IPython.
def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True):
    """Time ``func`` in the manner of IPython's ``%timeit`` and report the best run.

    ``func`` is called ``n`` times in each of ``r`` runs. The mean and
    standard deviation of the per-call wall-clock time are computed for every
    run, and the run with the lowest mean is reported.

    Args:
        func: Zero-argument callable to benchmark.
        r: Number of runs (must be >= 1).
        n: Number of calls per run (must be >= 1, or >= 2 when ``first`` is
            false, so at least one sample is left after discarding).
        quiet: When true, nothing is printed.
        function: Optional object whose ``__name__`` is used as the label in
            the printed report instead of the name of ``func``.
        do_tqdm: When true, wrap the run loop in a ``tqdm`` progress bar.
        first: When false, the first sample of every run is discarded before
            the statistics are computed (e.g. to ignore warm-up cost).

    Returns:
        ``(mean_ns, std_ns)`` of the best run, as floats in nanoseconds.

    Raises:
        ValueError: If ``r`` or ``n`` is too small to produce any sample.
    """
    if r < 1:
        raise ValueError(f"r must be at least 1, got {r}")
    min_loops = 1 if first else 2
    if n < min_loops:
        raise ValueError(
            f"n must be at least {min_loops} when first={first}, got {n}"
        )

    # Resolve the label without touching the caller's callable; fall back to
    # repr() for callables that carry no __name__ (e.g. functools.partial).
    label_source = function if function is not None else func
    label = getattr(label_source, "__name__", None) or repr(func)

    stats = np.empty([2, r])  # row 0 = mean, row 1 = std. dev. (nanoseconds)
    runs = tqdm(range(r)) if do_tqdm else range(r)
    for run in runs:
        samples = np.empty(n)
        for loop in range(n):
            start = perf_counter_ns()
            func()
            samples[loop] = perf_counter_ns() - start

        if not first:
            # Drop the first measurement of this run.
            samples = samples[1:]

        stats[0, run] = np.mean(samples)
        stats[1, run] = np.std(samples)

    best_run = np.argmin(stats[0])
    mean_ns = float(stats[0, best_run])
    std_ns = float(stats[1, best_run])

    if not quiet:
        # (lower bound in ns, scale factor, unit suffix, decimals), largest first.
        units = (
            (1e9, 1e-9, 's', 4),
            (1e6, 1e-6, 'ms', 2),
            (1e3, 1e-3, 'μs', 2),
        )
        scale, unit, digits = 1, 'ns', 0
        for lower_bound, unit_scale, unit_name, unit_digits in units:
            if mean_ns >= lower_bound:
                scale, unit, digits = unit_scale, unit_name, unit_digits
                break

        mean = mean_ns * scale
        std = std_ns * scale
        print(
            f"{label}: {mean:.{digits}f}{unit} ± {std:.{digits}f}{unit} per loop "
            f"(mean ± std. dev. of {r} runs, {n} loops each)"
        )

    return mean_ns, std_ns


def rndinput(shape):
    """Return a CUDA int64 tensor of ``shape`` filled with random integers.

    Values are drawn by ``torch.randint`` from the half-open range
    [0, 50256) on the CPU and then moved to the current CUDA device.
    """
    ids = torch.randint(low=0, high=50256, size=shape)
    ids = ids.long()
    return ids.cuda()

@torch.no_grad()
def forward(model, x, hypernetwork=None):
    """Run ``model.get_logits`` on ``x`` with autograd disabled and print the output shape.

    Args:
        model: Object exposing ``get_logits(x, hypernetwork=..., act_ck=...)``.
        x: Input passed straight through to ``get_logits``.
        hypernetwork: Optional module forwarded to ``get_logits``.

    Returns:
        None; the shape of the result is printed to stdout.
    """
    # act_ck is always enabled here (presumably activation checkpointing;
    # confirm against get_logits).
    logits = model.get_logits(x, hypernetwork=hypernetwork, act_ck=True)
    result_shape = logits.shape
    print(result_shape)

class HyperNetwork(nn.Module):
    """One square linear layer followed by an ``x * sigmoid(x)`` gate.

    The layer's weight and bias are created with ``nn.Linear``'s default
    initialisation and then divided by ``sqrt(2 * num_layers)``.
    """

    def __init__(self, hidden_size, num_layers):
        """Build the ``hidden_size`` x ``hidden_size`` layer and rescale its parameters.

        Args:
            hidden_size: Input and output width of the linear layer.
            num_layers: Used only to compute the ``sqrt(2 * num_layers)`` divisor.
        """
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size, bias=True)

        # Shrink every entry of the state dict (weight and bias alike) and
        # load the result back into the module.
        divisor = math.sqrt(2 * num_layers)
        rescaled = {
            name: tensor / divisor for name, tensor in self.state_dict().items()
        }
        self.load_state_dict(rescaled)

    def forward(self, hidden_states):
        """Return ``h * sigmoid(h)`` where ``h = linear(hidden_states)``."""
        projected = self.linear(hidden_states)
        return projected * torch.sigmoid(projected)

def main():
    """Time one eager-mode forward pass of the model built by ``init_6b``.

    The model and a freshly created ``HyperNetwork`` are moved to CUDA in
    half precision, and the forward pass is timed with ``timeit`` on a
    random input of shape (1, 1).
    """
    model = init_6b().cuda().half()

    # Freeze every parameter, then re-enable gradients for the vocabulary
    # embedding and each layer's pre-attention layer norm.
    for param in model.parameters():
        param.requires_grad = False

    for param in model.vocab_embed.parameters():
        param.requires_grad = True

    for layer in model.layers:
        for param in layer.ln_preattn.parameters():
            param.requires_grad = True

    # HyperNetwork(hidden_size, num_layers)
    hypernetwork = HyperNetwork(4096, 28).cuda().half()
    hypernetwork.train()

    shape = (1, 1)

    def eager_forward():
        forward(model, rndinput(shape), hypernetwork)
        # CUDA kernels are launched asynchronously; block until the GPU has
        # finished so the timer covers execution and not only kernel launch.
        torch.cuda.synchronize()

    print("PyTorch Eager")
    timeit(r=1, n=2, func=eager_forward, do_tqdm=False, first=False)

# Run the benchmark only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
