import torch
import re
import contextlib
import gender_guesser.detector as gender
import numpy as np
from hashlib import blake2b

# Shared gender_guesser detector, constructed once at import time and used by
# get_gender() to classify the first-name token extracted from a seed.
detector = gender.Detector()

# Speaker ids (rows of the model's speaker embedding table, see get_emb) that
# are drawn from when get_gender() returns 0 ("male").
sids_male = [
    2, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 16,
    25, 26, 27, 28, 29, 30, 31, 32, 36, 38, 39, 40,
    41, 43, 46, 53, 55, 56, 59, 60, 61, 63, 68, 69,
    71, 72, 77, 81, 82, 85, 86, 88, 90, 96,
]

# Speaker ids drawn from when get_gender() returns 1 ("female"). Together with
# sids_male this covers ids 0..107 exactly once.
sids_female = [
    0, 1, 3, 12, 15, 17, 18, 19, 20, 21, 22, 23,
    24, 33, 34, 35, 37, 42, 44, 45, 47, 48, 49, 50,
    51, 52, 54, 57, 58, 62, 64, 65, 66, 67, 70, 73,
    74, 75, 76, 78, 79, 80, 83, 84, 87, 89, 91, 92,
    93, 94, 95, 97, 98, 99, 100, 101, 102, 103, 104, 105,
    106, 107,
]

# Indexed by the gender index returned from get_gender(): 0 -> male pool,
# 1 -> female pool.
sids = [sids_male, sids_female]

# Prefix mixed into every hash key in seeded(); changing it changes the voice
# derived from every seed.
base_seed = "NovelAI_SID"

@contextlib.contextmanager
def seeded(seed, purpose):
    """Yield a NumPy ``Generator`` determined entirely by *seed* and *purpose*.

    The string ``base_seed + purpose + seed`` is hashed with BLAKE2b. The
    digest, read as a big-endian unsigned integer, seeds a PCG64 bit
    generator. The same (seed, purpose) pair therefore always reproduces the
    same random stream. Different purposes hash to different keys, so each
    purpose gets its own stream for a given seed.

    Raises TypeError (on entering the context) if *seed* or *purpose* is not
    a ``str``, because the key is built by string concatenation.
    """
    key = base_seed + purpose + seed
    digest = blake2b(key.encode()).digest()
    # from_bytes on the raw digest equals int(hexdigest, 16) and is never
    # negative.
    bit_generator = np.random.PCG64(int.from_bytes(digest, "big"))
    yield np.random.Generator(bit_generator)

def get_gender(seed):
    """Guess a gender index for *seed*: 0 for male, 1 for female.

    The seed is first reduced to a name: digits are removed, surrounding
    whitespace is stripped, the first whitespace-separated token is kept, and
    that token is title-cased. The name is then classified with ``detector``:

    * ``"male"`` -> 0, ``"female"`` -> 1;
    * ``"mostly_male"`` -> 0 with probability 0.9, otherwise 1;
    * ``"mostly_female"`` -> 1 with probability 0.9, otherwise 0;
    * any other result -> 1 with probability 0.8 if the name ends in a
      lowercase ``a``, ``i`` or ``u``, otherwise a 50/50 draw.

    An empty name (e.g. a seed made only of digits) is also a 50/50 draw.

    All random draws come from the ``seeded(seed, "gender")`` stream. The
    result is therefore reproducible for a given seed, assuming ``detector``
    is deterministic.

    Any exception, such as a non-string seed, falls back to 1. This is
    best-effort and deliberately never raises for ordinary errors.
    """
    try:
        with seeded(seed, "gender") as rng:
            without_digits = re.sub(r"[0-9]", "", seed).strip()
            # title() upper-cases the first letter, so a one-letter name is
            # uppercase and never matches the a/i/u ending check below.
            name = re.split(r"\s", without_digits)[0].strip().title()
            if not name:
                return 1 if rng.random(1).item() < 0.5 else 0
            # Named `guess` so the imported `gender` module is not shadowed.
            guess = detector.get_gender(name)
            if guess == "male":
                return 0
            if guess == "female":
                return 1
            if guess == "mostly_male":
                return 0 if rng.random(1).item() < 0.9 else 1
            if guess == "mostly_female":
                return 1 if rng.random(1).item() < 0.9 else 0
            if name[-1] in ["a", "i", "u"]:
                return 1 if rng.random(1).item() < 0.80 else 0
            return 1 if rng.random(1).item() < 0.5 else 0
    except Exception:
        # Keep the original best-effort fallback. Unlike a bare `except:`,
        # this no longer swallows KeyboardInterrupt / SystemExit.
        return 1

def get_props(seed):
    """Derive the deterministic voice recipe for *seed*.

    Returns a dict with these keys:

    * ``"gender"`` -- ``"male"`` or ``"female"``, from get_gender().
    * ``"voices"`` -- list of ``(speaker_id, weight)`` pairs with distinct
      speaker ids. There are usually 2 (p=0.4) or 3 ids; rarely 1 or 4.
    * ``"factor"`` -- extra divisor applied when normalising the weights.
      It is 1.0 except with probability 0.0005, when it is drawn uniformly
      from [0.75, 1.25). The weights therefore sum to ``1 / factor``.
    * ``"full_mix"`` -- True when speakers were drawn from both gender
      pools (probability 0.00005) instead of the guessed gender's pool.
    * ``"early"`` -- True when the weight-retry loop accepted a draw in which
      some weight was below 0.05. This happens through its 0.001-per-try
      early-accept.

    All randomness comes from ``seeded`` streams keyed on *seed*, so the
    result is fully reproducible.
    """
    gender_idx = get_gender(seed)

    with seeded(seed, "count") as count_rng:
        voice_count = 2 if count_rng.random(1).item() < 0.4 else 3
        if count_rng.random(1).item() < 0.00001:
            voice_count = 1
        elif count_rng.random(1).item() < 0.00005:
            voice_count = 4

        # Speaker selection continues drawing from the "count" stream rather
        # than a dedicated "selection" stream. Previously a
        # seeded(seed, "selection") generator was opened here but never bound,
        # so selection silently reused this stream. Existing seeds presumably
        # rely on that, and switching streams would change their voices. The
        # reuse is now explicit.
        choices = sids[gender_idx]
        full_mix = False
        if count_rng.random(1).item() < 0.00005:
            choices = sids[0] + sids[1]
            full_mix = True
        voices = []
        # Rejection-sample distinct speaker ids. Each pool has far more than
        # 4 entries, so this terminates.
        while len(voices) < voice_count:
            pick = choices[count_rng.integers(0, len(choices), 1).item()]
            if pick not in voices:
                voices.append(pick)

    early = False
    with seeded(seed, "weights") as rng:
        factor = 1.0
        if rng.random(1).item() < 0.0005:
            factor = 1.0 + (rng.random(1).item() * 2.0 - 1.0) * 0.25
        # Up to 5 attempts to draw weights that are all >= 0.05. If every
        # attempt fails, the last draw is used as-is.
        for _ in range(5):
            raw = [rng.random(1).item() for _ in range(voice_count)]
            total_weight = sum(raw)
            weights = [w / (total_weight * factor) for w in raw]
            good = not any(w < 0.05 for w in weights)
            # Rare early-accept, drawn on every attempt regardless of `good`
            # so the stream consumption is unchanged.
            if rng.random(1).item() < 0.001:
                if not good:
                    early = True
                good = True
            if good:
                break

    return {
        "gender": "male" if gender_idx == 0 else "female",
        "voices": list(zip(voices, weights)),
        "factor": factor,
        "full_mix": full_mix,
        "early": early,
    }

def get_emb(seed, net_g):
    props = get_props(seed)
    try:
        if seed[0:2] == 'no':
            sid = int(seed[2:])
            if sid >= 0 and sid <= 108:
                props = {"voices": [(sid, 1.0)]}
    except:
        pass
    mixed_emb = None
    device = net_g.emb_g.weight.data.device
    for sid, weight in props["voices"]:
        emb = net_g.emb_g(torch.LongTensor([sid]).to(device)) * weight
        if mixed_emb is None:
            mixed_emb = emb.clone()
        else:
            mixed_emb += emb
    return mixed_emb.unsqueeze(-1)
