import torch
import torch.functional as F
from torch import nn
import kornia.augmentation as K


def resample(input, size=(224, 224)):
    """Resize a batch of images with nearest-neighbour interpolation.

    Args:
        input: Tensor handed to ``interpolate``; with a two-element ``size``
            this is expected to be a 4-D ``(N, C, H, W)`` batch.
        size: Target ``(height, width)``. Defaults to ``(224, 224)``, the
            value that used to be hard-coded, so one-argument calls behave
            as before while callers that pass an explicit size now work.

    Returns:
        The resized tensor.
    """
    # ``nn.functional`` is spelled out on purpose: the module-level alias
    # ``F`` is bound to ``torch.functional``, which has no ``interpolate``.
    return nn.functional.interpolate(input, size, mode='nearest')


def resize_image(image, out_size):
    """Resize ``image`` keeping its aspect ratio, within a pixel-area budget.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize(size, filter)`` method (presumably a PIL image --
            confirm against callers).
        out_size: ``(width, height)`` whose product caps the output area.

    Returns:
        Whatever ``image.resize`` returns for the computed size.
    """
    width, height = image.size[0], image.size[1]
    aspect = width / height
    # The budget is the smaller of the source area and the requested area,
    # so the output area never exceeds that of the source image.
    pixel_budget = min(width * height, out_size[0] * out_size[1])
    new_width = round((pixel_budget * aspect) ** 0.5)
    new_height = round((pixel_budget / aspect) ** 0.5)
    # NOTE(review): `Image` is not imported in the visible part of this
    # file; presumably `from PIL import Image` is intended -- confirm.
    return image.resize((new_width, new_height), Image.LANCZOS)


class ReplaceGrad(torch.autograd.Function):
    """Return one tensor in the forward pass, route gradients to another.

    ``forward`` outputs ``x_forward`` unchanged.  ``backward`` gives
    ``x_forward`` no gradient and hands the incoming gradient to
    ``x_backward``, reduced to ``x_backward``'s shape.
    """

    @staticmethod
    def forward(ctx, x_forward, x_backward):
        # Only the shape of ``x_backward`` is needed later, to undo any
        # broadcasting between the two tensors.
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        grad_for_backward = grad_in.sum_to_size(ctx.shape)
        # One entry per forward input: (x_forward, x_backward).
        return None, grad_for_backward


class ClampWithGrad(torch.autograd.Function):
    """Clamp a tensor while keeping selected gradients at clamped elements.

    The forward pass is ``input.clamp(min, max)``.  In the backward pass the
    incoming gradient is kept wherever ``grad * (input - clamped) >= 0`` and
    zeroed elsewhere: in-range elements always pass their gradient, and
    out-of-range elements pass it only when its sign matches the sign of
    the overshoot.
    """

    @staticmethod
    def forward(ctx, input, min, max):
        ctx.save_for_backward(input)
        ctx.min, ctx.max = min, max
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        (saved,) = ctx.saved_tensors
        # Zero inside the range, signed distance past the bound outside it.
        overshoot = saved - saved.clamp(ctx.min, ctx.max)
        keep = grad_in * overshoot >= 0
        # ``min`` and ``max`` are not differentiable inputs.
        return grad_in * keep, None, None


# Function-style aliases for the custom autograd functions defined above:
# replace_grad(x_forward, x_backward) and clamp_with_grad(input, min, max).
replace_grad = ReplaceGrad.apply
clamp_with_grad = ClampWithGrad.apply


def vector_quantize(x, codebook):
    """Snap every vector in ``x`` to its nearest row of ``codebook``.

    Args:
        x: Tensor whose last dimension matches ``codebook.shape[1]``.
        codebook: 2-D tensor with one code vector per row.

    Returns:
        A tensor holding, for each vector of ``x``, the closest codebook
        row.  Through ``replace_grad`` the backward pass sends the gradient
        straight to ``x``, so the non-differentiable ``argmin`` is skipped.
    """
    # Squared Euclidean distance expanded as |x|^2 + |c|^2 - 2 x.c.
    d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T
    indices = d.argmin(-1)
    # ``nn.functional`` is spelled out on purpose: the module-level alias
    # ``F`` is bound to ``torch.functional``, which has no ``one_hot``.
    one_hot = nn.functional.one_hot(indices, codebook.shape[0]).to(d.dtype)
    x_q = one_hot @ codebook
    return replace_grad(x_q, x)


def synth(model, z):
    """Decode ``z`` with ``model`` and squash the result into ``[0, 1]``.

    Args:
        model: Object exposing ``decode(z)`` that returns a tensor.
        z: Latent passed to ``model.decode`` unchanged.

    Returns:
        ``(decoded + 1) / 2`` clamped to ``[0, 1]`` via ``clamp_with_grad``.
    """
    decoded = model.decode(z)
    # Shift and scale by (x + 1) / 2 -- presumably the decoder emits values
    # in [-1, 1]; confirm against the model.
    rescaled = decoded.add(1).div(2)
    return clamp_with_grad(rescaled, 0, 1)


class Prompt(nn.Module):
    """Weighted distance between input embeddings and a stored target embedding.

    Args:
        embed: Target embedding tensor; stored as the buffer ``embed``.
        weight: Scalar weight.  Its sign flips the distance and its
            magnitude scales the returned loss.
        stop: Scalar threshold.  Distances below it still contribute to the
            forward value but receive no gradient.
    """

    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))
        self.register_buffer('stop', torch.as_tensor(stop))

    def _forward(self, input):
        """Return the plain MSE between ``input.unsqueeze(1)`` and ``embed``.

        Unlike ``forward`` this applies no normalisation, weight or stop.
        """
        # ``nn.functional`` is used because the module-level alias ``F`` is
        # bound to ``torch.functional``, which has no ``mse_loss``.
        return nn.functional.mse_loss(input.unsqueeze(1), self.embed)

    def forward(self, input):
        """Return the weighted mean spherical distance to the target.

        Args:
            input: 2-D tensor of embeddings, one per row.

        Returns:
            A scalar tensor: ``|weight|`` times the mean signed distance.
        """
        # ``nn.functional`` is used because the module-level alias ``F`` is
        # bound to ``torch.functional``, which has no ``normalize``.
        normalize = nn.functional.normalize
        input_normed = normalize(input.unsqueeze(1), dim=2)
        embed_normed = normalize(self.embed.unsqueeze(0), dim=2)
        # For unit vectors the chord length is 2*sin(theta/2), so this
        # evaluates to theta**2 / 2, with theta the angle between them.
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        # Forward value is ``dists``; the gradient flows through
        # max(dists, stop), i.e. it is cut where dists is below ``stop``.
        return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()


class EMATensor(nn.Module):
    """Trainable tensor with a bias-corrected exponential moving average.

    Implemented by Katherine Crowson.

    In training mode the module yields the raw parameter; in eval mode it
    yields the running average maintained by ``update()``.

    Args:
        tensor: Initial value; wrapped in an ``nn.Parameter``.
        decay: EMA decay factor applied on every ``update()``.
    """

    def __init__(self, tensor, decay):
        super().__init__()
        self.tensor = nn.Parameter(tensor)
        self.decay = decay
        for buffer_name in ('biased', 'average'):
            self.register_buffer(buffer_name, torch.zeros_like(tensor))
        # ``accum`` holds decay**n after n updates and drives the bias
        # correction in ``update``.
        self.register_buffer('accum', torch.tensor(1.))
        # Seed the average so it is usable straight after construction.
        self.update()

    @torch.no_grad()
    def update(self):
        """Fold the current parameter value into the running average.

        Raises:
            RuntimeError: If the module is not in training mode.
        """
        if not self.training:
            raise RuntimeError('update() should only be called during training')

        keep = self.decay
        self.accum.mul_(keep)
        self.biased.mul_(keep).add_((1 - keep) * self.tensor)
        # Dividing by (1 - decay**n) undoes the bias from the zero start.
        self.average.copy_(self.biased).div_(1 - self.accum)

    def forward(self):
        """Return the parameter while training, the EMA otherwise."""
        return self.tensor if self.training else self.average


class MakeCutouts(nn.Module):
    """Produce a batch of square crops of an image batch, resized to ``cut_size``.

    Two modes, chosen by ``do_random_cuts``:

    * fixed (default): one ``csize`` x ``csize`` window at
      ``(offsetx, offsety)`` is resized, repeated ``csize // 8`` times and
      passed through the random augmentations;
    * random: ``cutn`` windows of random size and position are resized and
      concatenated, without augmentation.

    Either way the result is clamped to ``[0, 1]`` with ``clamp_with_grad``.

    Args:
        cut_size: Side length every crop is resized to.
        cutn: Number of crops in random mode.
        cut_pow: Exponent applied to the uniform sample that picks the
            random crop size.
        offsetx: Left edge of the fixed window.
        offsety: Top edge of the fixed window.
        do_random_cuts: Select random mode when true.
        csize: Side length of the fixed window.
    """

    def __init__(self, cut_size, cutn, cut_pow=1., offsetx=0, offsety=0, do_random_cuts=False, csize=244):
        super().__init__()
        self.cut_size, self.cutn, self.cut_pow = cut_size, cutn, cut_pow
        self.offsetx, self.offsety = offsetx, offsety
        self.do_random_cuts = do_random_cuts
        self.csize = csize

        affine = K.RandomAffine(degrees=25, translate=0.1, padding_mode='reflection')
        erasing = K.RandomErasing((.1, .4), (.3, 1/.3), same_on_batch=True, p=0.7)
        self.augs = nn.Sequential(affine, erasing)

    def forward(self, input):
        """Return the cutout batch for ``input`` (indexed as N, C, H, W)."""
        if self.do_random_cuts:
            return self._forward(input)

        rows = slice(self.offsety, self.offsety + self.csize)
        cols = slice(self.offsetx, self.offsetx + self.csize)
        window = input[:, :, rows, cols]
        target = (self.cut_size, self.cut_size)
        # NOTE(review): the repeat count is csize // 8, not cutn -- confirm
        # this is intended.
        copies = [resample(window, target)] * (self.csize // 8)
        return clamp_with_grad(self.augs(torch.cat(copies, dim=0)), 0, 1)

    def _forward(self, input):
        """Return ``cutn`` randomly sized and placed crops, resized and clamped."""
        sideY, sideX = input.shape[2:4]
        largest = min(sideX, sideY)
        smallest = min(sideX, sideY, self.cut_size)
        target = (self.cut_size, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            # Draw order (size, x offset, y offset) is kept so seeded runs
            # reproduce the same crops.
            scale = torch.rand([]) ** self.cut_pow
            size = int(scale * (largest - smallest) + smallest)
            left = torch.randint(0, sideX - size + 1, ())
            top = torch.randint(0, sideY - size + 1, ())
            crop = input[:, :, top:top + size, left:left + size]
            cutouts.append(resample(crop, target))
        return clamp_with_grad(torch.cat(cutouts, dim=0), 0, 1)
