import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
    from collections.abc import MutableMapping
except ImportError:
    from collections import MutableMapping
import os
from pathlib import Path
import math
from basedformer.models import base_lm
from typing import Optional, Any
from icecream import ic


def _attn(query, key, value, causal_mask, masked_bias,
            attention_mask=None, scale_attn=None, fp32_attn=True):

    if fp32_attn:
        attn_weights = torch.matmul(query.float(), key.transpose(-1, -2).float())
    else:
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

    attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
    if scale_attn:
        attn_weights = attn_weights / scale_attn.to(attn_weights.dtype)

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights.to(value.dtype)

    attn_output = torch.matmul(attn_weights, value).to(value.dtype)

    return attn_output

class SelfAttention(nn.Module):
    """Causal multi-head self-attention with an optional local window.

    ``attn_type == "local"`` restricts every query to the most recent
    ``config.window_size`` positions; any other value uses the plain
    lower-triangular (global causal) mask.

    Reads from ``config``: ``hidden_dim``, ``n_head``, ``window_size``
    (local attention only), ``device``, ``dtype`` and ``fp32_attn``.
    """
    # Code copied from HF, might want to sanity check later.
    def __init__(self, config, attn_type):
        nn.Module.__init__(self)
        self.config = config
        max_positions = 2049
        # Lower-triangular boolean mask, shaped (1, 1, max_positions, max_positions).
        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
            1, 1, max_positions, max_positions).bool()

        if attn_type == "local":
            # XOR with the mask shifted down by window_size clears every
            # position that lies window_size or more steps in the past.
            bias = bias ^ torch.tril(bias, -config.window_size)
        self.register_buffer("bias", bias)

        self.head_dim = config.hidden_dim // config.n_head
        self.rotary_dim = self.head_dim // 4
        self.hidden_dim = config.hidden_dim
        self.n_head = config.n_head
        device = config.device
        dtype = config.dtype

        # No score scaling is applied by this module.
        self.scale_attn = None
        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
        attn_bias = False #fairseq has attn_bias
        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=True, device=device, dtype=dtype)

    def forward(self, x, kv=None, cache=False):
        """Attend over ``x`` and, optionally, previously cached keys/values.

        Args:
            x: input of shape (batch, sequence, hidden_dim).
            kv: optional ``(key, value)`` pair from an earlier call, each
                shaped (batch, head, past_sequence, head_dim).
            cache: when True, also return the full ``(key, value)`` pair
                (cached plus new positions) for the next call.

        Returns:
            ``(output, (key, value))`` when ``cache`` is True, otherwise
            ``(output, None)``; ``output`` has the same shape as ``x``.
        """
        B, S, H = x.shape # batch, sequence, hidden_dim
        # split heads into: [batch, head, sequence, head_dim]
        query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
        key = self.k_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
        value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)

        if kv is not None:
            past_key, past_value = kv
            # torch.cat returns a new tensor, so the result has to be bound
            # back to key/value; otherwise the query only ever sees the
            # newly added positions and the returned cache never grows.
            key = torch.cat([past_key, key], dim=-2)
            value = torch.cat([past_value, value], dim=-2)

        query_length, key_length = query.size(-2), key.size(-2)
        # Rows for the query positions (the last query_length of key_length),
        # columns for every key position.
        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]

        x = _attn(
            query, key, value, causal_mask, self.masked_bias, None, self.scale_attn, self.config.fp32_attn
        )

        # merge heads back into: [batch, sequence, hidden_dim]
        x = x.transpose(1, 2).contiguous().view(B, S, H)
        x = self.out_proj(x)
        if cache:
            return x, (key, value)
        else:
            return x, None

class FeedForward(nn.Module):
    """Two-layer MLP: widen to 4 * hidden_dim, activate, project back.

    Reads ``hidden_dim``, ``device``, ``dtype`` and ``activation`` from
    ``config``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_dim
        inner_width = width * 4
        placement = {"device": config.device, "dtype": config.dtype}
        self.ff1 = nn.Linear(width, inner_width, **placement)
        self.ff2 = nn.Linear(inner_width, width, **placement)
        self.activation = config.activation

    def forward(self, x, act_ck=False):
        """Run the MLP on ``x``.

        With ``act_ck`` set, the activation is routed through
        ``torch.utils.checkpoint.checkpoint`` instead of being called
        directly.
        """
        hidden = self.ff1(x)
        hidden = ck(self.activation, hidden) if act_ck else self.activation(hidden)
        return self.ff2(hidden)

class GPTNeoLayer(nn.Module):
    """One pre-norm transformer block: attention and MLP, each with a residual.

    Args:
        attn: callable invoked as ``attn(config, attn_type)`` to build the
            attention module.
        ff: callable invoked as ``ff(config)`` to build the feed-forward module.
        config: supplies ``hidden_dim``, ``eps``, ``device``, ``dtype`` and
            ``layer_idx``. Even ``layer_idx`` selects "global" attention,
            odd selects "local".
    """
    def __init__(self, attn, ff, config):
        nn.Module.__init__(self)
        self.hidden_dim = config.hidden_dim
        self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
        self.ln_postattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
        self.ff = ff(config)
        attn_type = "global" if config.layer_idx % 2 == 0 else "local"
        self.attn = attn(config, attn_type)

    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, cache=False, kv=None):
        """Apply the block to ``x``.

        Args:
            x: hidden states fed to the layer.
            layer_id: accepted for interface compatibility; unused here.
            hypernetwork: accepted for interface compatibility; unused here.
            act_ck: route the pre-attention norm, the attention call and
                the MLP activation through activation checkpointing.
            cache: forwarded to the attention module.
            kv: cached value forwarded to the attention module.

        Returns:
            ``(hidden_states, kv)`` where ``kv`` is whatever the attention
            module returned as its second output.
        """
        residual = x

        if act_ck:
            x = ck(self.ln_preattn, x)

            # The reentrant checkpoint implementation rejects keyword
            # arguments ("Unexpected keyword arguments"), so the non-tensor
            # options are bound in a closure and only the hidden states are
            # passed through checkpoint(). The result is stored under a new
            # name so a recomputation sees the same kv input.
            def run_attn(hidden):
                return self.attn(hidden, kv=kv, cache=cache)

            attn_out, new_kv = ck(run_attn, x)
        else:
            x = self.ln_preattn(x)
            attn_out, new_kv = self.attn(x, kv=kv, cache=cache)

        x = residual + attn_out
        residual = x
        x = self.ln_postattn(x)
        ff_out = self.ff(x, act_ck)
        x = residual + ff_out

        return x, new_kv

class GPTNeoModel(base_lm.BaseModel):
    """GPT-Neo style language model built on ``base_lm.BaseModel``.

    Adds a learned position embedding and an unbiased LM head on top of
    whatever the base class constructs (presumably ``vocab_embed``,
    ``layers``, ``ln_final``, ``config`` and ``n_layer``, all of which are
    used below -- confirm against ``base_lm``).
    """

    def __init__(self, user_config, **kwargs):
        # Defaults are set before the base constructor is invoked.
        self.default_config = dict(
            n_layer=6,
            n_head=8,
            n_tokens=2049,
            hidden_dim=512,
            vocab_dim=50257,
            fp32_attn=False,
            eps=1e-5,
            device=torch.device('cuda'),
            dtype=torch.float16,
            Layer=GPTNeoLayer,
            activation=gelu_new,
            SelfAttention=SelfAttention,
            FeedForward=FeedForward,
            window_size=256,
        )
        base_lm.BaseModel.__init__(self, user_config, **kwargs)
        cfg = self.config
        self.pos_embed = nn.Embedding(cfg.n_tokens, cfg.hidden_dim)
        # neo models use an LM head without bias
        self.lm_head = nn.Linear(cfg.hidden_dim, cfg.vocab_dim, bias=False)

    def get_embeds(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
        """Embed token ids and run them through every layer and the final norm.

        Args:
            x: token ids; the last dimension is the sequence length.
            hypernetwork: forwarded unchanged to each layer.
            act_ck: forwarded unchanged to each layer.
            kv: optional per-layer list of cached ``(key, value)`` pairs;
                the cached key length offsets the position ids.
            cache: when True, return the per-layer kv outputs as well.

        Returns:
            ``(hidden_states, kv_list)`` if ``cache`` else
            ``(hidden_states, None)``.
        """
        if kv is not None:
            # sequence dim of the first layer's cached key
            past_length = kv[0][0].size(-2)
        else:
            past_length = 0
            kv = [None] * self.n_layer

        seq_len = x.shape[-1]
        position_ids = torch.arange(
            past_length, seq_len + past_length, dtype=torch.long, device=x.device
        )
        position_ids = position_ids.unsqueeze(0).view(-1, seq_len)

        hidden = self.vocab_embed(x)
        hidden = hidden + self.pos_embed(position_ids)

        kv_new = []
        for layer_id, layer in enumerate(self.layers):
            hidden, layer_kv = layer(
                hidden,
                layer_id=layer_id,
                hypernetwork=hypernetwork,
                act_ck=act_ck,
                kv=kv[layer_id],
                cache=cache,
            )
            kv_new.append(layer_kv)

        hidden = self.ln_final(hidden)
        return hidden, (kv_new if cache else None)