# -*- coding: utf-8 -*-
"""TFDNE Editor v1 - TF (2).ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1O5XbpMRU9i83mLAiTrMufCqmImgTRI7A

# This Fursona Does Not Exist - Fursona Editor (Tensorflow Version)

For a faster, easier-to-use editor (albeit with fewer customization options), try [Artbreeder](https://artbreeder.com) and select a furry image to start customizing.

### To run:
**Click "Open in Playground" above (if you see that option), then choose Runtime > Run All and wait for everything to load (about 4 minutes).**

Feel free to watch the demo video or read the code (by clicking the triangle in the next section) while you wait.

---


### Credits

Customizable furries using [@arfa](https://twitter.com/arfafax)'s furry StyleGAN2 model from [TFDNE](https://thisfursonadoesnotexist.com/).

Latent directions were discovered using [@harskish](https://twitter.com/harskish)'s [Ganspace](https://github.com/harskish/ganspace) [notebook](https://colab.research.google.com/github/harskish/ganspace/blob/master/notebooks/Ganspace_colab.ipynb) put together by [@realmeatyhuman](https://twitter.com/realmeatyhuman), with help from [@KeyKitsune](https://twitter.com/KitsuneKey) for providing code to load the saved latent directions.



---



### Shameless Plug
If you enjoyed This Fursona Does Not Exist and want to see more projects like it in the future, consider supporting me on Ko-fi or Patreon.

-arfa


<div>
<a href="https://www.twitter.com/arfafax">
<img src="https://thisfursonadoesnotexist.com/arfa_sig.png" width="350"/>
</a>
</div>
<div>
<a href="https://ko-fi.com/arfafax">
<img src="https://cdn.ko-fi.com/cdn/kofi3.png?v=2" width="220"/>
</a>
<a href="https://www.patreon.com/arfafax">
<img src="https://c5.patreon.com/external/logo/become_a_patron_button.png" width="235"/>
</a>
</div>
"""

#@title Demo video
# Embed the YouTube demo. The call is left as a bare expression because the
# notebook renders the value of a cell's last expression.
from IPython.display import YouTubeVideo
YouTubeVideo('dyV_URuA0yE', width=1280, height=720)

"""## <- Click the triangle to view the code while you wait for it to load

"""

# Fetch the StyleGAN2 code (shawwn's "estimator" branch), the pretrained e621
# network pickle, and the zipped latent directions into /content.
!git clone https://github.com/shawwn/stylegan2 -b estimator /content/stylegan2

import gdown  # NOTE(review): not referenced anywhere else in this file
!wget  -O /content/network-e621.pkl https://thisfursonadoesnotexist.com/model/network-e621-r-512-3194880.pkl
!wget  -O /content/directions.zip https://thisfursonadoesnotexist.com/directions.zip

# Commented out IPython magic to ensure Python compatibility.
# %tensorflow_version 1.x
# %cd /content/stylegan2

import os
import pickle
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
import scipy
import tensorflow as tf
import tflex

# On a Colab TPU runtime, expose the TPU address as TPU_NAME.
# NOTE(review): TPU_NAME / NOISY are presumably read by tflex or TF's TPU
# resolver; nothing in this file reads them directly.
if 'COLAB_TPU_ADDR' in os.environ:
    os.environ['TPU_NAME'] = 'grpc://' + os.environ['COLAB_TPU_ADDR']
    os.environ['NOISY'] = '1'

tflib.init_tf()
sess = tf.get_default_session()
sess.list_devices()  # return value is discarded here
cores = tflex.get_cores()
tflex.set_override_cores(cores)

# Load the pretrained network snapshot. The `with` block closes the file
# handle as soon as unpickling finishes.
# SECURITY: pickle.load can execute arbitrary code embedded in the file, and
# this pickle is downloaded from the network by the wget command above, so
# only load it from a source you trust.
with open("/content/network-e621.pkl", "rb") as network_pkl:
    _G, _D, Gs = pickle.load(network_pkl)
# _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.
# _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.
# Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.

def generate_image_from_w(w, truncation_psi):
    """Render one image from dlatents ``w`` and display it inline.

    Args:
        w: dlatent array passed straight to ``Gs.components.synthesis.run``.
            ``display_sample`` passes the (edited) output of
            ``Gs.components.mapping.run``, laid out as
            [minibatch, layer, component]; only ``images[0]`` is shown.
        truncation_psi: forwarded to ``synthesis.run`` as a keyword argument.
            NOTE(review): ``display_sample`` already applies truncation to
            ``w`` itself. Whether the synthesis network also consumes this
            kwarg depends on the stylegan2 fork, so confirm before relying on it.

    Returns:
        None. The image is shown via IPython ``display``.
    """
    # Explicit import so this works outside a notebook, where `display` is not
    # an injected builtin.
    from IPython.display import display

    with tflex.device('/gpu:0'):
        synthesis_kwargs = dict(
            # Convert to uint8 and NHWC so PIL can consume the array directly.
            output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
            truncation_psi=truncation_psi,
            minibatch_size=1,
        )
        # Fixed (non-random) noise keeps the output deterministic for a given w.
        images = Gs.components.synthesis.run(w, randomize_noise=False, **synthesis_kwargs)
        display(PIL.Image.fromarray(images[0], 'RGB'))

# Commented out IPython magic to ensure Python compatibility.
# %cd "/content"

# Unpack the downloaded directions. unzip extracts into the current directory,
# which the %cd above sets to /content in the notebook.
!unzip /content/directions.zip

# Remove entries the loader below cannot handle:
# - a .pkl direction without the "_layers_<start>_<end>" name suffix;
# - notebook checkpoints;
# - a nested directions/ folder.
# NOTE: re-running this cell makes rm print "No such file" errors; they are harmless.
!rm /content/directions/StyleGAN2-Light_direction-ffhq-ipca-w-style-comp15-range8-9.pkl
!rm -r /content/directions/.ipynb_checkpoints
#!mv /content/directions/directions/* /content/ganspace/directions
!rm -r /content/directions/directions

import os
# Latent directions (found with GANSpace), keyed by slider name, in the format
#   {'name': [direction_index, start_layer, end_layer]}
# Load a direction's array with latent_dirs[named_directions[name][0]].
# display_sample applies it to layers range(start_layer, end_layer), so the
# end layer is exclusive there.
named_directions = {}
latent_dirs = []
starts = []  # NOTE(review): never filled or read in this file; kept in case other cells use it
ends = []    # NOTE(review): never filled or read in this file; kept in case other cells use it
path_to_directions = "/content/directions"

for file_name in sorted(os.listdir(path_to_directions)):
    file_path = os.path.join(path_to_directions, file_name)
    # Only regular files hold directions. Skip stray folders (e.g.
    # .ipynb_checkpoints, or a nested directions/ left over if the cleanup
    # commands were skipped) instead of crashing in np.load.
    if not os.path.isfile(file_path):
        continue

    # File names look like "<name>_layers_<start>_<end>.<ext>".
    name = file_name.split("_layers_")[0]
    name_parts = file_name.split('_')
    dir_start = int(name_parts[-2])
    dir_end = int(name_parts[-1].split('.')[0])

    # SECURITY: allow_pickle=True lets np.load unpickle objects from a
    # downloaded archive; only use trusted direction files.
    direction_array = np.load(file_path, allow_pickle=True)

    # Use the current list length as the index so it always points at the
    # matching latent_dirs entry, even when directory entries are skipped.
    named_directions[name] = [len(latent_dirs), dir_start, dir_end]
    latent_dirs.append(direction_array)

named_directions

"""# UI"""

#@title Run UI (make sure you've done Runtime > Run All first or it won't work)
from ipywidgets import fixed
import PIL
import numpy as np
import ipywidgets as widgets
from PIL import Image
from IPython.display import clear_output

def display_sample(seed, truncation, direction, distance, scale, start, end, update, disp=True, save=None, noise_spec=None, **args):
    """Generate the sample for ``seed``, apply the slider edits, and display it.

    Used as the ``widgets.interactive_output`` callback: every key in ``params``
    arrives as a keyword argument. Any keyword that is not a named parameter
    is treated as a direction-slider value.

    Args:
        seed: seed for the ``RandomState`` that draws the input latent z.
        truncation: when != 1, the edited dlatents are pulled toward
            ``dlatent_avg`` as ``w_avg + (w - w_avg) * truncation``.
        direction, distance, start, end, disp, save, noise_spec: accepted so
            the widget parameter dict can be passed through; not used here.
        scale: global multiplier applied to every direction offset.
        update: only triggers a debug print when False.
        **args: ``{direction_name: slider_value}``. Each name must be a key of
            ``named_directions``.
    """
    if not update:
        print("False")

    rng = np.random.RandomState(seed)
    # Passing the full shape as `size` yields the same draws as the flat
    # standard_normal(n) + reshape for a 1-D latent. Unlike unpacking the shape
    # as positional args, it also works when the latent has several dims.
    latent_shape = tuple(Gs.input_shape[1:])
    z = rng.standard_normal(size=(1, *latent_shape))

    all_w = Gs.components.mapping.run(z, None, randomize_noise=False)  # [minibatch, layer, component]

    for direction_name, value in args.items():
        dir_index, start_l, end_l = named_directions[direction_name]
        # The offset is the same for every affected layer; compute it once.
        offset = latent_dirs[dir_index] * value * scale
        for layer in range(start_l, end_l):
            all_w[0][layer] = all_w[0][layer] + offset

    if truncation != 1:
        w_avg = Gs.get_var('dlatent_avg')
        all_w = w_avg + (all_w - w_avg) * truncation  # [minibatch, layer, component]

    generate_image_from_w(all_w, truncation)

# Shared layout settings for the control panel.
style = {'description_width': '110px'}
row_length = 5  # direction sliders per row (desktop layout)

# Global controls. Only seed and truncation are placed in `ui`. The others
# exist so display_sample receives every parameter it declares.
seed = widgets.IntSlider(min=0, max=100000, step=1, value=np.random.randint(0, 100000), description='Seed: ', continuous_update=False)
truncation = widgets.FloatSlider(min=0, max=2, step=0.1, value=0.4, description='Truncation: ', continuous_update=False)
distance = widgets.FloatSlider(min=-10, max=10, step=0.1, value=0, description='Distance: ', continuous_update=False, style=style)
scale = widgets.FloatSlider(min=0, max=10, step=0.05, value=1, description='Scale: ', continuous_update=False)
start_layer = widgets.IntSlider(min=0, max=18, step=1, value=18, description='start layer: ', continuous_update=False)
end_layer = widgets.IntSlider(min=0, max=18, step=1, value=18, description='end layer: ', continuous_update=False)
update = widgets.Checkbox(value=True, description="update")

# Keyword arguments for display_sample: fixed controls first, then one
# slider per named direction (all starting at 0).
params = {
    'seed': seed,
    'truncation': truncation,
    'direction': fixed(0),
    'distance': distance,
    'scale': scale,
    'start': start_layer,
    'end': end_layer,
    'update': update,
}
directions_list = [
    widgets.FloatSlider(min=-20, max=20, step=0.1, value=0, description=name + ': ',
                        continuous_update=False, style=style, layout={'width': '350px'})
    for name in named_directions
]
for name, slider in zip(named_directions, directions_list):
    params[name] = slider

bot_box = widgets.HBox([seed, truncation])
ui = widgets.VBox([bot_box])

# row_length sliders per HBox. The final row holds the remainder, and is
# empty when the count divides evenly.
ui2 = widgets.VBox([
    widgets.HBox(directions_list[k:k + row_length])
    for k in range(0, len(directions_list) + 1, row_length)
])

# Control buttons, created in the same order as their labels below.
random, reset, mobile, desktop = [
    widgets.Button(description=label)
    for label in ("Randomize Sliders", "Reset Sliders", "Mobile Mode", "Desktop Mode")
]
def reset_sliders(b):
    """Button callback: rebuild every direction slider at 0 and redraw the UI.

    Args:
        b: the clicked ``widgets.Button`` (unused).
    """
    # Publish the rebuilt slider dict globally. Otherwise mobile_mode /
    # desktop_mode copy values from the stale dict and undo the reset.
    global params
    directions_list = []
    params_new = {'seed': seed, 'truncation': truncation, 'direction': fixed(0), 'distance': distance, 'scale': scale, 'start': start_layer, 'end': end_layer, 'update' : update}
    for item in named_directions:
        widget = widgets.FloatSlider(min=-20, max=20, step=0.1, value=0, description=item + ': ', continuous_update=False, style=style, layout={'width' : '350px'})
        directions_list.append(widget)
        params_new[item] = widget
    params = params_new

    # row_length sliders per row; the last row holds the (possibly empty) remainder.
    rows = [widgets.HBox(directions_list[k:k + row_length])
            for k in range(0, len(directions_list) + 1, row_length)]
    ui2 = widgets.VBox(rows)

    clear_output()
    # NOTE(review): interactive_output presumably observes every control in
    # params, including the shared seed/truncation sliders. Observers from
    # earlier rebuilds are never removed here, so hidden outputs may keep
    # re-rendering on every change. Confirm.
    out = widgets.interactive_output(display_sample, params)
    last_button = desktop if row_length == 1 else mobile
    display(ui, out, ui2, reset, random, last_button)

def random_sliders(b):
    """Button callback: rebuild every direction slider at a random value.

    Each slider starts at a draw from ``np.random.normal(scale=2.5)``.

    Args:
        b: the clicked ``widgets.Button`` (unused).
    """
    # Publish the rebuilt slider dict globally. Otherwise mobile_mode /
    # desktop_mode copy values from the stale dict and lose the random values.
    global params
    directions_list = []
    params_new = {'seed': seed, 'truncation': truncation, 'direction': fixed(0), 'distance': distance, 'scale': scale, 'start': start_layer, 'end': end_layer, 'update' : update}
    for item in named_directions:
        widget = widgets.FloatSlider(min=-20, max=20, step=0.1, value=np.random.normal(scale=2.5), description=item + ': ', continuous_update=False, style=style, layout={'width' : '350px'})
        directions_list.append(widget)
        params_new[item] = widget
    params = params_new

    # row_length sliders per row; the last row holds the (possibly empty) remainder.
    rows = [widgets.HBox(directions_list[k:k + row_length])
            for k in range(0, len(directions_list) + 1, row_length)]
    ui2 = widgets.VBox(rows)

    clear_output()
    # NOTE(review): interactive_output presumably observes every control in
    # params, including the shared seed/truncation sliders. Observers from
    # earlier rebuilds are never removed here, so hidden outputs may keep
    # re-rendering on every change. Confirm.
    out = widgets.interactive_output(display_sample, params)
    last_button = desktop if row_length == 1 else mobile
    display(ui, out, ui2, reset, random, last_button)

def mobile_mode(b):
    """Button callback: switch to a one-slider-per-row layout.

    Slider values are copied from the current global ``params``, and the
    rebuilt sliders replace it. Afterwards the "Desktop Mode" button is shown.

    Args:
        b: the clicked ``widgets.Button`` (unused).
    """
    global row_length, params
    row_length = 1

    new_params = {
        'seed': seed,
        'truncation': truncation,
        'direction': fixed(0),
        'distance': distance,
        'scale': scale,
        'start': start_layer,
        'end': end_layer,
        'update': update,
    }
    sliders = []
    for direction_name in named_directions:
        slider = widgets.FloatSlider(
            min=-20, max=20, step=0.1,
            value=params[direction_name].value,
            description=direction_name + ': ',
            continuous_update=False, style=style, layout={'width': '350px'},
        )
        sliders.append(slider)
        new_params[direction_name] = slider
    params = new_params

    # row_length sliders per HBox; the last HBox holds the (possibly empty) remainder.
    slider_rows = widgets.VBox([
        widgets.HBox(sliders[k:k + row_length])
        for k in range(0, len(sliders) + 1, row_length)
    ])
    clear_output()
    out = widgets.interactive_output(display_sample, params)
    display(ui, out, slider_rows, reset, random, desktop)

def desktop_mode(b):
    """Button callback: switch back to the five-sliders-per-row layout.

    Slider values are copied from the current global ``params``, and the
    rebuilt sliders replace it. Afterwards the "Mobile Mode" button is shown.

    Args:
        b: the clicked ``widgets.Button`` (unused).
    """
    global row_length, params
    row_length = 5

    rebuilt = {
        'seed': seed,
        'truncation': truncation,
        'direction': fixed(0),
        'distance': distance,
        'scale': scale,
        'start': start_layer,
        'end': end_layer,
        'update': update,
    }
    ordered_sliders = []
    for key in named_directions:
        current_value = params[key].value
        new_slider = widgets.FloatSlider(
            min=-20, max=20, step=0.1, value=current_value,
            description=key + ': ', continuous_update=False,
            style=style, layout={'width': '350px'},
        )
        ordered_sliders.append(new_slider)
        rebuilt[key] = new_slider
    params = rebuilt

    # row_length sliders per HBox; the last HBox holds the (possibly empty) remainder.
    grid = widgets.VBox([
        widgets.HBox(ordered_sliders[pos:pos + row_length])
        for pos in range(0, len(ordered_sliders) + 1, row_length)
    ])
    clear_output()
    out = widgets.interactive_output(display_sample, params)
    display(ui, out, grid, reset, random, mobile)

# Attach each button to its callback (same registration order as before).
for _button, _handler in (
    (random, random_sliders),
    (reset, reset_sliders),
    (mobile, mobile_mode),
    (desktop, desktop_mode),
):
    _button.on_click(_handler)

# Bind display_sample to the controls in params and show the initial panel.
out = widgets.interactive_output(display_sample, params)

display(ui, out, ui2, reset, random, mobile)