# The required pip packages are listed in bobross_server.ipynb.

# Download the VQGAN ImageNet f16/1024 config and checkpoint:
#curl -L 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' > vqgan_imagenet_f16_1024.yaml
#curl -L 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fckpts%2Flast.ckpt&dl=1' > vqgan_imagenet_f16_1024.ckpt

import math
import io
from pathlib import Path
import sys

# taming-transformers must be cloned into the working directory (see sys.path below):
#git clone https://github.com/CompVis/taming-transformers
sys.path.append('./taming-transformers')

from omegaconf import OmegaConf
from PIL import Image
import requests
from taming.models import cond_transformer, vqgan
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm
import numpy as np
import collections
import threading
import json
import cv2
import random
import base64
from dotmap import DotMap

#from CLIP import clip
import clip

from lm_node import utils

def _init_model(nodeconf):
    return

# Default generation settings. Wrapped in a DotMap below so values can be read
# as attributes (``args.cutn``) as well as by key.
args = {
    "prompts": ['air conditioning for dry banana'],
    "image_prompts": [],
    "noise_prompt_weights": [],
    "noise_prompt_seeds": [],
    "init_image": None,
    "init_weight": 0.5,
    "clip_model": 'ViT-B/32',
    "vqgan_config": 'vqgan_imagenet_f16_1024.yaml',
    "vqgan_checkpoint": 'vqgan_imagenet_f16_1024.ckpt',
    "step_size": 0.05,
    "cutn": 32,
    "cut_pow": 1.,
    "display_freq": 50,
    "seed": 0,

    # display
    "use_augs": True,
    "noise_fac": 0.1,
    "ema_val": 0.99,  # decay passed to EMATensor in reset()
    "record_generation": True,

    # noise and other constraints
    "use_noise": None,
    "constraint_regions": False,

    # MSE regularisation settings
    "mse_withzeros": True,
    "mse_decay_rate": 4,  # ascend_txt decays the MSE weight every N steps
    "mse_epoches": 25,    # per-decay step is init_weight / mse_epoches
    "mse_quantize": False,

    # end iteration (key spelling kept for existing readers)
    "max_itter": -1,

    # width x height of the latent grid produced by Declipper
    "blocks_width": 54,
    "blocks_height": 28,

    # input dimension of the Declipper linear layer
    "encoder_dims": 1536,

}

args = DotMap(args)

# Crop window left free of ascend_txt's masked MSE; _generate resets these.
# NOTE(review): 244 differs from the 256 _generate uses (and from CLIP's
# usual 224) — confirm the intended value.
offsetx = 0
offsety = 0
csize = 244
do_random_cuts = False
edims = args.encoder_dims
resw = args.blocks_width # * 8
resh = args.blocks_height
mse_weight = args.init_weight
mse_decay = 0

# Mutable training state that ascend_txt/reset/_generate access through
# ``global`` statements. Defining it here lets the first call work:
# ascend_txt tests ``lout is not None`` and reset does ``del opt``, both of
# which raised NameError before anything had assigned these names.
lout = None
opt = None
pMs = []

replace_grad = utils.image.ReplaceGrad.apply
clamp_with_grad = utils.image.ClampWithGrad.apply

class Declipper(nn.Module):
    """Project a flat ``(batch, edims)`` vector onto a 256-channel latent grid.

    A single bias-free linear layer produces ``256 * resh * resw`` features,
    which ``forward`` reshapes to ``(batch, 256, resh, resw)``. The sizes are
    read from the module-level ``edims``, ``resw`` and ``resh`` globals. Any
    keyword arguments passed to the constructor are accepted and ignored.
    """

    def __init__(self, **ignorekwargs):
        super().__init__()
        out_features = 256 * resw * resh
        self.linear = torch.nn.Linear(edims, out_features, bias=False)

    def forward(self, x):
        projected = self.linear(x)
        grid_shape = (projected.shape[0], 256, resh, resw)
        return torch.reshape(projected, grid_shape)

def _init_model(nodeconf):
    """Load the VQGAN model, the CLIP perceptor and a fresh ``Declipper``.

    Args:
        nodeconf: object exposing ``config_path`` (a config file loadable by
            ``OmegaConf.load``) and ``model_path`` (the checkpoint passed to
            ``VQModel.init_from_ckpt``).

    Returns:
        ``(model, perceptor, declipper)``:
        - model: the VQGAN in eval mode with gradients disabled and its
          training ``loss`` removed.
        - perceptor: the frozen CLIP model named by ``args.clip_model``.
        - declipper: an untrained ``Declipper``.
        All three are placed on the module-level ``device``.

    Raises:
        ValueError: if ``config.model.target`` is not
            ``taming.models.vqgan.VQModel`` (the only supported model class).
    """
    config = OmegaConf.load(nodeconf.config_path)
    target = config.model.target
    if target != 'taming.models.vqgan.VQModel':
        # Any other target would leave ``model`` unbound; fail with the real cause.
        raise ValueError(
            f"unsupported model target {target!r}; "
            "expected 'taming.models.vqgan.VQModel'"
        )
    model = vqgan.VQModel(**config.model.params)
    model.eval().requires_grad_(False)
    model.init_from_ckpt(nodeconf.model_path)
    # The VQGAN training loss is not needed for generation.
    del model.loss
    # Keep all three modules on the same device (``device`` falls back to CPU),
    # so declipper output can be fed straight into ``model.quantize``.
    model.to(device)
    declipper = Declipper().to(device)
    perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)

    return model, perceptor, declipper

# Prefer the first CUDA device; fall back to the CPU when CUDA is unavailable.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Amount subtracted from mse_weight at each decay point in ascend_txt. When
# init_weight is falsy, the value assigned earlier in this module is kept.
mse_decay = args.init_weight / args.mse_epoches if args.init_weight else mse_decay

# NOTE(review): not used by the code shown here; presumably meant to guard the
# shared module-level generation state for callers elsewhere — confirm.
lock = threading.Lock()

def ascend_txt(model, perceptor, declipper, i, set_lout=False, constrain_z=False,
               make_cutouts=None, normalize=None):
    """Decode the current latent and return the list of loss terms for step ``i``.

    The image is produced from the module-level ``baseZ`` through
    ``declipper`` -> ``model.quantize`` -> ``utils.image.synth``, and scored by:

    * a masked MSE (x80) against the reference image ``lout``. It is skipped
      when ``lout`` is unset or ``do_random_cuts`` is true. The window
      ``[offsety:offsety+csize, offsetx:offsetx+csize]`` is excluded;
    * an MSE pulling ``baseZ.tensor`` towards the anchor ``baseZc``, scaled by
      ``mse_weight / 2``;
    * one term per prompt in ``pMs`` applied to the CLIP cutout embeddings.

    Every ``args.mse_decay_rate`` steps (for ``i > 0``), ``baseZc`` is
    refreshed from ``baseZ.average`` and ``mse_weight`` is lowered by
    ``mse_decay``, or set to 0 once another full decrement would not fit.

    Args:
        model: VQGAN model providing ``quantize``.
        perceptor: CLIP model providing ``encode_image``.
        declipper: module mapping ``baseZ.average`` to the quantizer input.
        i: current step index; drives the MSE decay schedule.
        set_lout: if true, store this step's image as the new ``lout``.
        constrain_z: kept for interface compatibility; not used here.
        make_cutouts: callable producing CLIP-sized cutouts. Defaults to a
            module-level ``make_cutouts`` if one exists, otherwise
            ``utils.image.MakeCutouts(perceptor.visual.input_resolution,
            args.cutn, cut_pow=args.cut_pow)``.
        normalize: CLIP input normalisation. Defaults to a module-level
            ``normalize`` if one exists, otherwise CLIP's mean/std
            ``transforms.Normalize``.

    Returns:
        list of loss tensors; the caller sums them before ``backward()``.
    """
    global mse_weight
    global baseZc
    global lout

    module_globals = globals()
    # These used to be bare names that only existed as locals of _generate.
    if make_cutouts is None:
        make_cutouts = module_globals.get("make_cutouts")
    if make_cutouts is None:
        make_cutouts = utils.image.MakeCutouts(
            perceptor.visual.input_resolution, args.cutn, cut_pow=args.cut_pow
        )
    if normalize is None:
        normalize = module_globals.get("normalize")
    if normalize is None:
        normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                         std=[0.26862954, 0.26130258, 0.27577711])

    # ``lout`` may not have been assigned yet; treat that like None.
    prev_out = module_globals.get("lout")

    z, _, _ = model.quantize(declipper(baseZ.average))
    out = utils.image.synth(model, z)
    iii = perceptor.encode_image(normalize(make_cutouts(out))).float()

    result = []

    if prev_out is not None and not do_random_cuts:
        # Pin everything outside the crop window to the reference image.
        mask = torch.ones_like(out)
        mask[:, :, offsety:offsety + csize, offsetx:offsetx + csize] = 0
        result.append(F.mse_loss(out * mask, prev_out * mask) * 80)

    if set_lout or prev_out is None:
        lout = out.clone().detach()

    result.append(F.mse_loss(baseZ.tensor, baseZc) * mse_weight / 2)

    with torch.no_grad():
        if i > 0 and i % args.mse_decay_rate == 0:
            if args.mse_quantize:
                # NOTE(review): vector_quantize is not defined in the code
                # shown, and reset() builds baseZ as (1, edims), for which
                # movedim(1, 3) looks out of range — confirm this path.
                baseZc = vector_quantize(baseZ.average.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1)
            else:
                baseZc = baseZ.average.clone()

            # Keep decaying only while at least one more full decrement fits;
            # otherwise switch the anchor MSE term off.
            remaining = mse_weight - mse_decay
            mse_weight = remaining if remaining > 0 and remaining >= mse_decay else 0

    result.extend(prompt(iii) for prompt in pMs)

    return result

def reset(reset_z=True, declipper=None):
    """Reset the optimisation state shared through module globals.

    Args:
        reset_z: if true, do three things:
            - draw a fresh random latent ``baseZ`` of shape ``(1, edims)``,
              wrapped in ``EMATensor`` with decay ``args.ema_val``;
            - zero the MSE anchor ``baseZc``;
            - re-initialise ``declipper.linear``.
        declipper: the ``Declipper`` being trained. When omitted, a
            module-level ``declipper`` is used (the original behaviour).

    Side effects:
        Rebinds several globals:
        - ``baseZ``/``baseZc`` (only when ``reset_z``);
        - ``mse_weight``, back to ``args.init_weight``;
        - ``opt``, a new RMSprop over the declipper and ``baseZ`` parameters
          with momentum 0.5;
        - ``pMs``, an empty list.

    Raises:
        NameError: if no ``declipper`` is passed and none exists at module level.
    """
    global baseZ
    global baseZc
    global pMs
    global opt
    global mse_weight

    if declipper is None:
        try:
            declipper = globals()["declipper"]
        except KeyError:
            raise NameError(
                "reset() needs a Declipper: pass declipper=... or define a "
                "module-level 'declipper'"
            ) from None

    if reset_z:
        # NOTE(review): EMATensor is not defined in the code shown; presumably
        # provided elsewhere — confirm.
        baseZ = EMATensor(torch.randn((1, edims)).to(device), args.ema_val)
        baseZc = torch.zeros_like(baseZ.tensor)
        declipper.linear.reset_parameters()
    mse_weight = args.init_weight

    # Rebinding drops any previous optimiser; no ``del`` is needed, which also
    # lets this work on a first call where ``opt`` does not exist yet.
    opt = optim.RMSprop(
        [
            {'params': declipper.parameters(), 'lr': 0.00005},
            {'params': baseZ.parameters(), 'lr': 0.07},
        ],
        momentum=0.5,
    )

    pMs = []

def train(i, set_lout=False, constrain_z=False, *, model=None, perceptor=None, declipper=None):
    """Run one optimisation step with the module-level optimiser ``opt``.

    ``ascend_txt`` takes the VQGAN model, the CLIP perceptor and the declipper
    as its first three arguments, followed by ``i``. Pass those objects as
    keywords; any left as ``None`` are looked up as module-level globals of
    the same name.

    Args:
        i: step index forwarded to ``ascend_txt``.
        set_lout: forwarded to ``ascend_txt``.
        constrain_z: forwarded to ``ascend_txt``.
        model, perceptor, declipper: objects forwarded to ``ascend_txt``.

    Raises:
        NameError: if an object is neither passed nor defined at module level.
    """
    objects = {"model": model, "perceptor": perceptor, "declipper": declipper}
    module_globals = globals()
    for name in objects:
        if objects[name] is None:
            if name not in module_globals:
                raise NameError(
                    f"train() needs {name!r}: pass it as a keyword argument "
                    "or define it at module level"
                )
            objects[name] = module_globals[name]

    opt.zero_grad()
    losses = ascend_txt(objects["model"], objects["perceptor"], objects["declipper"], i,
                        set_lout=set_lout, constrain_z=constrain_z)

    # Accumulate left to right, starting from the first loss term.
    total = sum(losses[1:], losses[0])
    total.backward()
    opt.step()
    
# Module-level step counter and random-cut switch. _generate uses its own loop
# index and sets do_random_cuts back to False before training.
i, do_random_cuts = 0, True

def does_whatever(pMs, texts, perceptor, append="") -> None:
    """Refill ``pMs`` in place with one CLIP text prompt per entry of ``texts``.

    Each text is processed as follows:
    - prefixed with ``append`` (for example ``"Blurry "``);
    - tokenised with ``clip.tokenize`` and encoded by
      ``perceptor.encode_text``;
    - wrapped as ``Prompt(embedding, 1.0, -inf)`` and moved to ``device``.

    The list object itself is reused, so other references to it see the new
    prompts.
    """
    pMs.clear()
    for text in texts:
        tokens = clip.tokenize(append + text).to(device)
        embedding = perceptor.encode_text(tokens).float()
        pMs.append(Prompt(embedding, 1.0, float("-inf")).to(device))

def _reset_generation_state(declipper, reset_z):
    """Re-initialise the module-level state that ascend_txt reads.

    Mirrors ``reset`` except that no optimiser is built (each _generate phase
    creates its own) and ``declipper`` is passed in explicitly.
    """
    global baseZ, baseZc, mse_weight, pMs
    if reset_z:
        # NOTE(review): EMATensor is not defined in the code shown; presumably
        # provided elsewhere, as ``reset`` relies on it too — confirm.
        baseZ = EMATensor(torch.randn((1, edims)).to(device), args.ema_val)
        baseZc = torch.zeros_like(baseZ.tensor)
        declipper.linear.reset_parameters()
    mse_weight = args.init_weight
    pMs = []


def _rmsprop_for(declipper, momentum):
    """Return an RMSprop optimiser over the declipper weights and the global ``baseZ``."""
    return optim.RMSprop(
        [
            {'params': declipper.parameters(), 'lr': 0.00005},
            {'params': baseZ.parameters(), 'lr': 0.07},
        ],
        momentum=momentum,
    )


def _train_step(model, perceptor, declipper, i, constrain_z):
    """Sum ascend_txt's losses for step ``i``, back-propagate and step the global ``opt``."""
    opt.zero_grad()
    losses = ascend_txt(model, perceptor, declipper, i, constrain_z=constrain_z)
    total = sum(losses[1:], losses[0])
    total.backward()
    opt.step()


def _generate(model, perceptor, declipper, texts, constrain_z=False, do_blur_trick=True):
    """Optimise ``declipper`` and the latent ``baseZ`` towards ``texts``, then render.

    Training runs in three phases of 5, 15 and 25 steps. The step index
    restarts at 0 in each phase, which drives ascend_txt's MSE schedule.
    Phase 3 uses a fresh optimiser with momentum 0.2; the first two share one
    with momentum 0.

    With ``do_blur_trick``, the prompts are prefixed as follows:
    - phase 1: ``"Very blurry "``;
    - phase 2: ``"Blurry "``;
    - phase 3: used verbatim.
    Without it, the verbatim prompts are used throughout.

    Args:
        model: VQGAN model (``quantize`` is used for decoding).
        perceptor: CLIP model (``encode_text``/``encode_image`` and
            ``visual.input_resolution``).
        declipper: ``Declipper`` mapping ``baseZ`` to the VQGAN latent.
        texts: non-empty sequence of prompt strings; ``texts[0]`` is echoed in
            the result.
        constrain_z: if true, keep the existing global ``baseZ``/``baseZc`` and
            declipper weights instead of re-initialising them (they must
            already exist).
        do_blur_trick: enable the blurry-to-sharp prompt schedule.

    Returns:
        JSON string ``{"image": <image_to_b64 output>, "text": texts[0]}``.

    Raises:
        ValueError: if ``texts`` is empty (checked before any training).
    """
    global opt
    global do_random_cuts
    global offsetx
    global offsety
    global csize
    # ascend_txt looks these up as module globals.
    global make_cutouts
    global normalize

    if not texts:
        raise ValueError("_generate() needs at least one prompt text")

    make_cutouts = utils.image.MakeCutouts(
        perceptor.visual.input_resolution, args.cutn, cut_pow=args.cut_pow
    )
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])

    if args.seed is not None:
        torch.manual_seed(args.seed)

    model.requires_grad_(False)
    perceptor.requires_grad_(False)
    declipper.requires_grad_(True)

    # ascend_txt reads ``lout``. Create it on first use, but keep any
    # reference image left over from an earlier call.
    globals().setdefault("lout", None)

    # Work on the shared globals (baseZ, pMs, ...) that ascend_txt reads,
    # rather than on shadowing locals.
    _reset_generation_state(declipper, reset_z=not constrain_z)

    # Enable ascend_txt's masked MSE, leaving the top-left csize x csize
    # window free.
    do_random_cuts = False
    offsetx = 0
    offsety = 0
    csize = 256

    # Phase 1: 5 steps, no momentum.
    does_whatever(pMs, texts, perceptor, "Very blurry " if do_blur_trick else "")
    opt = _rmsprop_for(declipper, momentum=0)
    for i in range(5):
        _train_step(model, perceptor, declipper, i, constrain_z)

    # Phase 2: 15 more steps with the same optimiser.
    if do_blur_trick:
        does_whatever(pMs, texts, perceptor, append="Blurry ")
    for i in range(15):
        _train_step(model, perceptor, declipper, i, constrain_z)

    # Phase 3: 25 steps with a fresh optimiser and momentum 0.2.
    opt = _rmsprop_for(declipper, momentum=0.2)
    if do_blur_trick:
        does_whatever(pMs, texts, perceptor)
    for i in range(25):
        _train_step(model, perceptor, declipper, i, constrain_z)

    if do_blur_trick:
        # NOTE(review): these prompts are installed but no training steps
        # follow before rendering — confirm whether a final phase is missing.
        does_whatever(pMs, texts, perceptor, append="High resolution ")

    # The final render needs no autograd graph.
    with torch.no_grad():
        z, _, _ = model.quantize(declipper(baseZ.average))
        out = utils.image.synth(model, z)
    pim = TF.to_pil_image(out[0].cpu())
    # [:, :, ::-1] reverses the channel order before encoding.
    # NOTE(review): image_to_b64 is not defined in the code shown — confirm
    # where it comes from.
    ret = image_to_b64(np.array(pim)[:, :, ::-1])
    return json.dumps({"image": ret, "text": texts[0]})