import random
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from urllib.parse import urlparse, parse_qs
from lame import Lame
import sys
import socket
import signal

import numpy as np
import time
import os
import json
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader

import multiprocessing

import vits.commons as commons
import vits.utils as utils
from vits.models import SynthesizerTrn
from vits.text.symbols import symbols
from vits.text import text_to_sequence
import nltk
import voices
import logging
import sentry_sdk
from sentry_sdk import capture_exception
from sentry_sdk import capture_message
from sentry_sdk import start_transaction
from sentry_sdk.integrations.threading import ThreadingIntegration

# Module-wide logger: INFO and above, emitted through a single StreamHandler
# whose lines carry timestamp, level, source file and process id.
fh_formatter = logging.Formatter(
    "%(asctime)s %(levelname)s %(filename)s(%(process)d) - %(message)s"
)
fh = logging.StreamHandler()
fh.setFormatter(fh_formatter)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(fh)

# Runtime configuration, read once from the process environment.
# SECRET: when set, requests must present it as a bearer token.
# PORT: listening port, kept as a string here (default "80").
secret = os.getenv("SECRET")
port = os.getenv("PORT", "80")

# DEV picks the environment label reported to Sentry:
# "True" -> staging, "False" -> production, anything else (including
# unset) -> development.
environment = {"True": "staging", "False": "production"}.get(
    os.getenv("DEV"), "development"
)
    

logger.info("Starting server on port %s", port)

# Error/trace reporting: events are tagged with this host's name and the
# environment label derived from DEV; 0.2% of transactions are traced.
# ThreadingIntegration(propagate_hub=True) is enabled so that threads started
# by this process report through the same Sentry hub.
sentry_sdk.init(
    dsn="https://0dcc09626e2b405d8f006aa1a54800fa@o846434.ingest.sentry.io/6253166",
    environment=environment,
    server_name=socket.gethostname(),
    traces_sample_rate=0.002,
    integrations=[ThreadingIntegration(propagate_hub=True)],
)

# NOTE(review): max_timeout and used_tickets are not referenced by any code
# visible in this file -- confirm whether they are still needed.
max_timeout = 3
used_tickets = {}

# Display names for selected speaker ids. infer_to_mp3 uses the name as the
# "artist" tag for ids 0..108 and falls back to "seed.no<sid>" for ids that
# are not listed. The -1 entry is never looked up there, because sid == -1 is
# tagged as "seed.<seed>" instead.
named_voices = {
    -1: "seed",
    17: "Cyllene",
    95: "Leucosia",
    44: "Crina",
    80: "Hespe",
    106: "Ida",
    6: "Alseid",
    10: "Daphnis",
    16: "Echo",
    41: "Thel",
    77: "Nomios",
}

max_text_len = 1000  # the /audio handler truncates the "text" parameter to this many characters
phonemizer_procs = 6  # number of worker processes in phonemizer_pool
rosa_procs = 6  # number of worker processes in rosa_pool
sample_rate = 22050  # samples per second; used to size the silence buffers below
# NOTE(review): default_text and default_sid are not referenced by any code
# visible in this file.
default_text = "Text to speech"
default_sid = 6

def get_text(text, hps):
    """Turn ``text`` into the id tensor that is fed to the synthesizer.

    Args:
        text: Text to convert.
        hps: Hyper-parameters; ``hps.data.text_cleaners`` is passed to
            ``text_to_sequence`` and ``hps.data.add_blank`` decides whether
            a 0 is interspersed between the resulting ids.

    Returns:
        A ``torch.LongTensor`` built from the id sequence.
    """
    ids = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        ids = commons.intersperse(ids, 0)
    return torch.LongTensor(ids)

# int16 silence appended to the PCM stream by infer_to_mp3: the short one
# (0.115 s) after every sentence fragment, the long one (0.305 s) after every
# input line.
short_pause, long_pause = (
    np.zeros(int(sample_rate * seconds), dtype=np.int16)
    for seconds in (0.115, 0.305)
)

def gpu_thread(queue, proc_name):
    """Inference worker: load the model on the GPU, then serve requests from ``queue``.

    Used as the target of a ``multiprocessing.Process``. It reads the
    module-level ``hps``, which is only assigned under the ``__main__`` guard,
    so the worker depends on inheriting that global from the parent process
    (NOTE(review): i.e. a fork start method -- confirm for the deployment).

    Each request is a tuple ``(stn_tst, sid, ret_queue, seed)``:

    * ``stn_tst`` -- id tensor from ``get_text``, or ``""`` as a readiness
      ping (answered with a one-sample int16 zero array).
    * ``sid`` -- speaker id; a negative value makes the worker obtain the
      speaker embedding from ``voices.get_emb(seed, net_g)``.
    * ``ret_queue`` -- queue on which exactly one reply is posted.
    * ``seed`` -- only used when ``sid`` is negative.

    A ``None`` request ends the loop. If loading the checkpoint fails the
    process exits with status 1.

    Args:
        queue: Queue the requests are read from.
        proc_name: Label used in this worker's log lines.
    """
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model).cuda()
    net_g.eval()

    load_started = time.perf_counter()
    try:
        utils.load_checkpoint("pretrained_vctk.pth", net_g, None)
    except Exception as e:
        capture_exception(e)
        logger.error("Failed to load model, restarting.")
        sys.exit(1)

    logger.info(f"Loaded pretrained model on {proc_name} in {time.perf_counter() - load_started:.2f} seconds.")

    while True:
        req = queue.get()
        if req is None:
            break
        stn_tst, sid, ret_queue, seed = req

        # Readiness ping. Checking the type first avoids comparing a tensor
        # with a string.
        if isinstance(stn_tst, str) and stn_tst == "":
            ret_queue.put(np.array([0], dtype=np.int16))
            continue

        emb = None
        try:
            with torch.no_grad():  # save VRAM by not including gradients
                x_tst = stn_tst.cuda().unsqueeze(0)
                x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()
                sid_t = torch.LongTensor([sid]).cuda()
                if sid < 0:
                    emb = voices.get_emb(seed, net_g)
                audio = net_g.infer(x_tst, x_tst_lengths, sid=sid_t, noise_scale=.667, emb=emb, noise_scale_w=0.8, length_scale=1)[0][0,0].data.detach().float()
                audio = audio.cpu().numpy()
        except Exception as e:
            # The requester blocks on ret_queue until it gets a reply, and an
            # uncaught error would also take this worker down for good. Report
            # the failure and answer with one sample of float silence, which
            # the caller can scale and encode like any other result.
            capture_exception(e)
            logger.exception(f"Inference failed on {proc_name}; replying with silence.")
            audio = np.zeros(1, dtype=np.float32)
        finally:
            # Release the embedding reference before blocking on the queue again.
            del emb
        ret_queue.put(audio)

def _stream_sentences(lame, text, sid, seed):
    """Synthesize ``text`` piece by piece and feed the resulting PCM to ``lame``.

    The text is split into non-empty lines, each line into sentences with
    ``nltk.sent_tokenize`` and each sentence on ``;``. Every piece is
    phonemized in ``phonemizer_pool``, sent to the GPU workers through
    ``gpu_queue`` and followed by ``short_pause``; every line is followed by
    ``long_pause``. Returns immediately once ``lame.finished`` is true, so no
    further text is processed and nothing more is fed to the encoder.
    """
    # Requests are strictly sequential (one put, one get), so a single reply
    # queue serves the whole call.
    ret_queue = m.Queue()
    int16 = np.iinfo(np.int16)
    for line in text.split("\n"):
        if not line:
            continue
        for big_sent in nltk.sent_tokenize(line):
            for sent in big_sent.split(";"):
                if lame.finished:
                    return
                stn_tst = phonemizer_pool.apply(get_text, [sent.replace('"', ''), hps])
                gpu_queue.put((stn_tst, sid, ret_queue, seed))
                audio = ret_queue.get()
                # Saturate instead of letting out-of-range samples wrap around
                # in the int16 conversion.
                pcm = np.clip(audio * hps.data.max_wav_value, int16.min, int16.max)
                lame.add_pcm(pcm.astype(np.int16))
                lame.add_pcm(short_pause)
        lame.add_pcm(long_pause)


def infer_to_mp3(text, sid, fh, seed, opus):
    """Synthesize ``text`` and write the encoded audio to ``fh``.

    Args:
        text: Text to speak; also stored as the title tag.
        sid: Speaker id. 0..108 is tagged with its ``named_voices`` name, or
            ``seed.no<sid>`` when it has none; -1 is tagged ``seed.<seed>``;
            any other value leaves the artist tag unset.
        fh: Output stream handed to ``Lame`` as ``ofile``.
        seed: Voice seed, forwarded to the GPU worker with every request.
        opus: Forwarded to ``Lame``.
    """
    artist = None
    if sid == -1:
        artist = f"seed.{seed}"
    elif 0 <= sid <= 108:
        artist = named_voices.get(sid, f"seed.no{sid}")

    lame = Lame(ofile=fh, chunked=False, opus=opus, album="NovelAI", artist=artist, title=text)
    lame.start()
    _stream_sentences(lame, text, sid, seed)
    lame.finish()
    print("Finished")

class TTSServer(BaseHTTPRequestHandler):
    """HTTP front end of the text-to-speech service (GET only).

    Routes:

    * ``/audio?text=...&sid=...[&seed=...][&opus=true]`` -- streams the audio
      written by ``infer_to_mp3``. ``text`` is cut to ``max_text_len``
      characters, ``seed`` to 256 and ``sid`` must be in -1..108.
    * ``/decode?seed=...`` -- answers ``json.dumps(voices.get_props(seed))``.
    * ``/`` -- answers ``OK``.

    When the module-level ``secret`` is set, every request must carry the
    header ``Authorization: Bearer <secret>``.

    Status codes: 200 on success, 401 for a wrong or missing token, 400 for
    missing or invalid query parameters, 500 when ``/decode`` fails
    internally, 404 for any other path.
    """

    def log_message(self, format, *args):
        """Suppress the base class's per-request log output."""
        return

    def _send_text(self, status, body):
        """Send a complete ``text/html`` response with ``status`` and the bytes ``body``."""
        self.send_response(status)
        self.send_header("Content-type", "text/html")
        self.end_headers()
        self.wfile.write(body)

    def _report_failure(self, exc, status):
        """Record ``exc`` in Sentry and the log, then answer ``Error`` with ``status``."""
        capture_exception(exc)
        logger.error(exc)
        self._send_text(status, b"Error")

    def _handle_audio(self, raw_query):
        """Validate the ``/audio`` parameters and stream the synthesized audio."""
        response_started = False
        try:
            query = parse_qs(raw_query)
            text = query['text'][0][:max_text_len]
            sid = int(query['sid'][0])
            seed = query.get('seed', ["none"])[0][:256]
            opus = query.get("opus", ["false"])[0] == "true"
            # Explicit check rather than ``assert``, which is removed under -O.
            if not -1 <= sid <= 108:
                raise ValueError(f"sid must be between -1 and 108, got {sid}")

            response_started = True
            self.send_response(200)
            self.send_header("Content-type", "audio/webm" if opus else "audio/mpeg")
            self.send_header("Cache-control", "no-cache")
            self.send_header("Access-Control-Allow-Origin", "*")
            self.end_headers()
            s = time.perf_counter()
            infer_to_mp3(text, sid, self.wfile, seed, opus)
            logger.info(f"Request took {str(time.perf_counter() - s)} seconds for {str(len(text))} characters.")
        except Exception as e:
            if response_started:
                # Status line and headers (and possibly audio) are already on
                # the wire; writing an error page now would only corrupt the
                # stream, so just record the failure.
                capture_exception(e)
                logger.error(e)
            else:
                # Nothing sent yet: the failure came from reading or
                # validating the query parameters.
                self._report_failure(e, 400)

    def _handle_decode(self, raw_query):
        """Answer the properties of the voice identified by the ``seed`` parameter."""
        try:
            seed = parse_qs(raw_query)['seed'][0][:256]
        except (KeyError, ValueError) as e:
            self._report_failure(e, 400)
            return
        try:
            body = json.dumps(voices.get_props(seed)).encode()
        except Exception as e:
            self._report_failure(e, 500)
            return
        self._send_text(200, body)

    def do_GET(self):
        """Authenticate the request and dispatch it by path."""
        parsed = urlparse(self.path)
        logger.info(f"GET {parsed.path}")
        authtoken = self.headers.get('Authorization')
        if secret is not None and authtoken != "Bearer " + secret:
            self.send_response(401)
            self.send_header("WWW-Authenticate", "Bearer")
            self.send_header('Content-type', 'text/html')
            self.end_headers()
            self.wfile.write(bytes("Invalid authtoken", "utf-8"))
            return

        if parsed.path == "/audio":
            self._handle_audio(parsed.query)
        elif parsed.path == "/decode":
            self._handle_decode(parsed.query)
        elif parsed.path == "/":
            self._send_text(200, b"OK")
        else:
            # Always complete the response; without end_headers() the client
            # would receive nothing at all.
            self._send_text(404, b"Not found")

def simple_handle_exit(signal, frame):
    """Signal handler: print the received signal number and exit with status 0.

    Args:
        signal: Number of the signal being handled.
        frame: Interrupted stack frame (unused).
    """
    message = f"Termination signal ({signal:d}) received, dying instantly"
    print(message)
    sys.exit(0)

if __name__ == "__main__":
    from queue import Empty

    # need these before forking the process
    signal.signal(signal.SIGINT, simple_handle_exit)
    signal.signal(signal.SIGTERM, simple_handle_exit)

    phonemizer_pool = multiprocessing.Pool(phonemizer_procs)
    rosa_pool = multiprocessing.Pool(rosa_procs)

    hps = utils.get_hparams_from_file("vits/configs/vctk_base.json")
    m = multiprocessing.Manager()
    gpu_queue = multiprocessing.Queue()
    gpu_proc = multiprocessing.Process(target=gpu_thread, args=(gpu_queue, "proc1"))
    gpu_proc.start()

    # Ping the first worker and wait for its answer before starting the
    # second one. Poll instead of blocking indefinitely: if the worker exits
    # (it does so when the checkpoint cannot be loaded) no answer will ever
    # arrive, so stop with a failure status rather than hang.
    wait_queue = m.Queue()
    gpu_queue.put(("", 0, wait_queue, "none"))
    while True:
        try:
            wait_queue.get(timeout=1)
            break
        except Empty:
            if not gpu_proc.is_alive():
                logger.error("GPU worker proc1 exited before becoming ready, aborting.")
                sys.exit(1)

    gpu_proc2 = multiprocessing.Process(target=gpu_thread, args=(gpu_queue, "proc2"))
    gpu_proc2.start()

    # Only touch CUDA in this process after the workers have been started.
    cudadev = torch.cuda.current_device()
    gb_gpu = int(torch.cuda.get_device_properties(0).total_memory /
                (1000 * 1000 * 1000))
    logger.info("GPU: " + torch.cuda.get_device_name(cudadev))
    logger.info("GPU RAM: " + str(gb_gpu) + "gb")

    webserver = ThreadingHTTPServer(("0.0.0.0", int(port)), TTSServer)
    try:
        webserver.serve_forever()
    except Exception as e:
        capture_exception(e)
        logger.error(e)
    finally:
        # Runs for the SystemExit raised by simple_handle_exit as well, which
        # "except Exception" does not catch.
        webserver.server_close()
        logger.info("Exiting")

        # The GPU workers are non-daemonic, so interpreter exit would wait for
        # them forever. Ask each one to leave its loop with the None sentinel
        # and escalate if it does not stop in time.
        gpu_workers = (gpu_proc, gpu_proc2)
        for _ in gpu_workers:
            gpu_queue.put(None)
        for worker in gpu_workers:
            worker.join(timeout=5)
            if worker.is_alive():
                worker.terminate()
                worker.join(timeout=5)
            if worker.is_alive():
                worker.kill()
                worker.join()

        phonemizer_pool.terminate()
        rosa_pool.terminate()
        phonemizer_pool.join()
        rosa_pool.join()
        logger.info("Pools closed.")
        logger.info("Everything closed, exiting.")

