import signal
import sys
import pathlib

from lm_node.config import init_config

# Extend the import path with directories bundled next to this file; the
# gRPC server module imported below presumably lives in one of them.
thisPath = str(pathlib.Path(__file__).parent.resolve())
sys.path.extend([
    thisPath + "/interfaces/gooseai/completion",
    thisPath + "/transformers/src",
])
import grpc_server
import socket
from multiprocessing import Process, Queue

from prometheus_client import CollectorRegistry, Gauge, push_to_gateway
import pika
import websockets

from lm_node.utils.text import *
from lm_node.sanitize import *
from lm_node.base import *

from sentry_sdk import capture_exception
from sentry_sdk import capture_message
from sentry_sdk import start_transaction

# Queue of (processing_seconds, prep_seconds, priority) tuples produced by
# on_request and drained once a minute by the spy metrics process.
q = Queue()
# Module-wide logger; main() replaces it with config.logger.
logger = None

def main():
    """Run the LM node: start gRPC, then serve RabbitMQ jobs if configured.

    Loads the model and config with ``init_config`` and starts the gRPC
    server. SIGTERM/SIGINT hooks set ``running.value = False``, close the
    current RabbitMQ channel if it is still open, and chain to any
    previously installed Python-level handler.

    With ``config.rmq_host`` set, a ``spy`` metrics process is started and
    this process consumes ``config.queue_name`` (prefetch 1). After a failure
    it waits a short back-off and reconnects, until a signal clears
    ``running``. Without a RabbitMQ host it waits for the gRPC server to stop
    or for a stop signal. In both cases it then calls
    ``grpcServer.stop(30)`` and exits with status 0.
    """
    import time  # local: only needed for the reconnect back-off

    # Seconds to wait before retrying after a RabbitMQ failure, so a broker
    # outage does not turn into a hot reconnect loop.
    reconnect_delay = 5.0

    global logger
    model, config = init_config()
    logger = config.logger

    grpcServer = grpc_server.serve(model, config)
    channel = DotMap({'value': None, 'queue': None})
    running = DotMap({'value': True})

    def addStopRunningHook(sig):
        # Remember whatever handler was installed before ours so we can chain.
        fn = signal.getsignal(sig)

        def handleSignal(sig_param, frame):
            logger.warning(f"Received signal {sig_param}!")
            running.value = False
            # Closing the channel stops start_consuming(); skip the close if
            # the channel is already closed (e.g. on a repeated signal).
            if channel.value is not None and channel.value.is_open:
                channel.value.close(reply_text='Terminating on signal')
            # getsignal() may return SIG_DFL / SIG_IGN (not callable) or None.
            # Calling those raised TypeError inside this handler.
            if callable(fn):
                fn(sig_param, frame)

        signal.signal(sig, handleSignal)

    addStopRunningHook(signal.SIGTERM)
    addStopRunningHook(signal.SIGINT)

    if config.rmq_host is not None and config.rmq_host != "":
        credentials = pika.credentials.PlainCredentials(config.rmq_username,
                                                        config.rmq_password)
        connection_params = pika.ConnectionParameters(config.rmq_host,
                                                      credentials=credentials)

        pid = os.getpid()
        th2 = Process(target=spy, args=(pid, config, running, ))
        th2.start()

        try:
            # Reconnect if the RabbitMQ connection fails.
            while running.value:
                try:
                    connection = pika.BlockingConnection(connection_params)
                    channel.value = connection.channel()

                    logger.info("RabbitMQ queue: " + config.queue_name)
                    channel.queue = channel.value.queue_declare(
                        queue=config.queue_name,
                        durable=True)
                    channel.value.basic_qos(prefetch_count=1)
                    threads = []
                    on_message_callback = functools.partial(on_message,
                                                            args=(connection,
                                                                  threads,
                                                                  config,
                                                                  model))
                    # TODO: maybe kill the threads if the connection dies?
                    channel.value.basic_consume(
                        queue=config.queue_name,
                        on_message_callback=on_message_callback,
                        arguments={"x-max-priority": 11}
                    )
                    logger.info("Started consuming")
                    channel.value.start_consuming()
                    break
                except Exception as e:
                    logger.error(e)
                    capture_exception(e)
                    if running.value:
                        time.sleep(reconnect_delay)
            logger.warning("RabbitMQ broker exiting.")
        finally:
            # Always stop the metrics process, even if a BaseException such
            # as KeyboardInterrupt escapes the consume loop.
            os.kill(th2.pid, signal.SIGINT)
    else:
        # Poll with a timeout instead of blocking indefinitely, so a stop
        # signal (which only clears ``running``) can end the wait.
        # wait_for_termination() returns True when the timeout elapsed and
        # False once the server has terminated.
        while running.value and grpcServer.wait_for_termination(timeout=1.0):
            pass
        logger.info("GRPC server exited!")
    grpcServer.stop(30)
    sys.exit(0)


def on_request(ch, method, props, body, connection, config, model):
    """Handle one RabbitMQ delivery: run the model and answer the caller.

    The body is decoded as UTF-8 JSON into a ``DotMap`` and validated by
    ``process_payload``. The request is then routed according to its
    ``parameters``:

    - ``get_hidden_states`` -> ``model.get_hidden_states``
    - ``next_word`` -> ``model.next_token_probabilities``
    - otherwise -> ``model.generate``

    A request whose ``parameters`` contain ``generation_id`` is a streaming
    request: errors go out over the websocket pipe and the delivery is only
    acked here. For other requests the result is published to
    ``props.reply_to`` via ``send_message``.

    Any exception is reported to the caller and to Sentry. Errors whose text
    mentions CUDA or an illegal memory access also nack the delivery and
    SIGTERM this process.

    Args:
        ch: channel the delivery arrived on.
        method: delivery method frame (``delivery_tag`` is read).
        props: message properties (``reply_to`` is read).
        body: raw message body bytes.
        connection: connection used to schedule thread-safe callbacks.
        config: node config (``model_type`` is read).
        model: the loaded model.
    """
    global sent_first_message

    hostname = socket.gethostname()
    curr_time = time.time()
    ctx = [ch, method, props, connection]
    # Pre-initialised: the except handler reads some of these even when the
    # failure happened before they were assigned.
    stream = None
    generation_id = None
    next_word = None
    get_hidden_states = None
    priority = None
    prep_time = 0
    try:
        req_dict = DotMap(json.loads(body.decode("utf-8")))
        if method.delivery_tag is None:
            logger.warning("deliverytag none, returning")
            return

        # Use .get() instead of attribute access. With DotMap's default
        # dynamic mode, ``parameters.generation_id`` on a missing key creates
        # and returns an empty DotMap (never None), so this "nowhere to
        # reply" case was never detected.
        if (props.reply_to is None
                and req_dict.parameters.get("generation_id") is None):
            logger.warning("reply_to none, returning")
            cb = functools.partial(nack_message, ch, method.delivery_tag)
            connection.add_callback_threadsafe(cb)
            return

        if config.model_type != "GPT":
            # Only GPT requests are implemented. Raising here routes the
            # request through the error-reply/ack path below, instead of
            # failing after the try block with ``priority``/``output`` unset.
            raise ValueError(f"Unsupported model_type: {config.model_type}")

        logger.info("Got request, processing...")
        # TODO: have clients send these flags outside of ``parameters`` so
        # they don't need to be fished out here.
        params = req_dict.parameters
        get_hidden_states = params.get("get_hidden_states", False)
        next_word = params.get("next_word", False)
        use_string = params.get("use_string", True)
        priority = params.get("priority", 10)
        stream = "generation_id" in params
        generation_id = params.generation_id if stream else None
        # Plain generation unless one of the special modes was requested.
        sample = not next_word and not get_hidden_states

        output = process_payload(req_dict, config, use_string, model)
        if not output[0]:
            # process_payload signalled failure; output[1] carries the error.
            ic()
            error_msg = output[1]
            logger.error(error_msg)
            capture_message(error_msg)
            if not stream:
                send_message(ctx, {"error": error_msg})
            else:
                send_message(ctx, {"event": "token",
                                   "data": {"uuid": generation_id,
                                            "error": error_msg,
                                            "final": True}},
                             stream=True)
            return
        req_dict, _warning = output

        req_params = req_dict.parameters.toDict()
        ids = req_dict.input

        task_name = "generate-stream" if generation_id else "generate-normal"

        prep_time = time.time() - curr_time
        logger.info(f"Request ready in {prep_time:0.3f}s")

        if sample:
            with start_transaction(op=task_name, name=hostname):
                output = model.generate(ids, req_params, use_string,
                                        generation_id=generation_id)
        elif get_hidden_states:
            with start_transaction(op="get_hidden_states", name=hostname):
                output = model.get_hidden_states(ids)
            logger.info(f"get_hidden_states call done in {time.time()-curr_time}")
        elif next_word:
            with start_transaction(op="next_word", name=hostname):
                output = model.next_token_probabilities(ids)
            logger.info(f"next_word call done in {time.time()-curr_time}")

    except Exception as e:
        ic()
        if not stream:
            send_message(ctx,
                         {"error": str(e)})
        else:
            send_message(ctx,
                         {"event": "token",
                          "data": {"uuid": generation_id,
                                   "error": str(e),
                                   "final": True}}, stream=True)
        logger.error(str(e))
        capture_exception(e)
        e_s = str(e)
        gc.collect()
        # "CUDA" already covers "CUDA out of memory".
        if "CUDA" in e_s or "an illegal memory access" in e_s:
            logger.error("GPU error, committing seppuku.")
            # NOTE(review): send_message above already queued an ack for this
            # delivery tag via add_callback_threadsafe. This direct nack may
            # therefore double-acknowledge the tag; confirm the intent.
            nack_message(ch, method.delivery_tag)
            os.kill(os.getpid(), signal.SIGTERM)
        return

    # Only generation requests with a positive priority feed the metrics queue.
    if not get_hidden_states and not next_word and priority > 0:
        q.put((time.time() - curr_time, prep_time, priority))

    logger.info(f"Request took {str(time.time() - curr_time)} seconds")
    if stream:
        # Streamed output was delivered elsewhere; just ack the delivery.
        cb = functools.partial(ack_message, ch, method.delivery_tag)
        connection.add_callback_threadsafe(cb)
    else:
        payload = {'output': output[0]}
        try:
            if len(output[1]) > 0:
                payload['logprobs'] = output[1]
        except Exception:
            logger.error("Not logprobs: {0}".format(output))
        send_message(ctx, payload)
        del payload
    # Drop the reference to the model output as early as possible.
    del output

    sent_first_message = True

    with open("/tmp/health_readiness", "w"):
        pass

    # .get(): an unset DEV previously raised KeyError after the reply was sent.
    if os.environ.get('DEV', '').lower() == "false":
        with open("/tmp/healthy", "w"):
            pass


def send_message(ctx, message, stream=False):
    """Deliver *message* to the requester and ack the originating delivery.

    Args:
        ctx: ``[ch, method, props, connection]`` for the delivery being
            answered.
        message: JSON-serialisable reply.
        stream: if False, publish *message* as JSON to ``props.reply_to`` on
            the default exchange, tagged with ``props.correlation_id``. If
            True, send it through ``send_message_async``, blocking until it
            has been sent.

    In both modes the ack is scheduled with ``add_callback_threadsafe``. In
    non-stream mode it is queued after the publish callback.
    """
    ch, method, props, connection = ctx
    if not stream:
        def safepublish():
            ch.basic_publish(
                exchange="",
                routing_key=props.reply_to,
                properties=pika.BasicProperties(
                    correlation_id=props.correlation_id,
                    content_type="application/json"),
                body=json.dumps(message).encode("utf-8")
            )
        connection.add_callback_threadsafe(safepublish)
    else:
        # asyncio.run creates a fresh event loop and closes it afterwards.
        # Previously a new loop was created per streamed message and never
        # closed, leaving garbage collection to release its selector and
        # file descriptors.
        asyncio.run(send_message_async(ctx, message))
    connection.add_callback_threadsafe(
        functools.partial(ack_message, ch, method.delivery_tag))


async def send_message_async(ctx, message):
    """Send *message* as JSON over the node-pipe websocket.

    The staging pipe is used when the ``DEV`` environment variable equals
    "true", compared case-insensitively as ``spy`` does. Otherwise the
    production pipe is used. An unset ``DEV`` raises KeyError.

    Args:
        ctx: delivery context; kept for signature compatibility, not used.
        message: JSON-serialisable payload, e.g. a ``{"event": "token", ...}``
            dict.
    """
    if os.environ['DEV'].lower() == "true":
        uri = "wss://staging.novelai.net/ai/internal/node-pipe"
    else:
        uri = "wss://api.novelai.net/ai/internal/node-pipe"

    sender = json.dumps(message)
    async with websockets.connect(uri, ping_interval=None) as websocket:
        # Debug trace of every outgoing frame (was an unconditional print).
        logger.debug(sender)
        await websocket.send(sender)


def on_message(channel, method_frame, header_frame, body, args):
    """Consumer callback: unpack the bound context and delegate to ``on_request``.

    Args:
        channel, method_frame, header_frame, body: as passed by the consumer.
        args: ``(connection, threads, config, model)`` tuple bound with
            ``functools.partial`` in ``main``; ``threads`` is not used here.
    """
    connection, _threads, config, model = args
    on_request(channel, method_frame, header_frame, body, connection, config,
               model)


def ack_message(channel, delivery_tag):
    """Acknowledge *delivery_tag* on *channel*; do nothing if the channel is closed."""
    if not channel.is_open:
        return
    channel.basic_ack(delivery_tag)


def nack_message(channel, delivery_tag):
    """Reject *delivery_tag* without requeueing; do nothing if the channel is closed."""
    if not channel.is_open:
        return
    channel.basic_nack(delivery_tag, requeue=False)


def spy(pid, config, running):
    """Aggregate per-minute request metrics from ``q``, log them and push them.

    Intended to run in the child process started by ``main``. Each cycle it
    sleeps to the next 60 s boundary relative to its start time. It then
    drains the ``(processing_time, prep_time, priority)`` tuples that
    ``on_request`` put on ``q`` and logs a summary. When
    ``config.prom_addr`` is set, it also pushes the gauges to that push
    gateway.

    Args:
        pid: parent PID (not used).
        config: node config; ``prom_addr`` and ``model_name`` are read.
        running: loop flag. NOTE: this is the child's own copy, so the
            parent's signal hook does not stop this loop; ``main`` sends
            SIGINT to end the process instead.
    """
    if config.prom_addr is None:
        logger.info("Prometheus address not set, will only report to terminal.")

    machine_id = socket.gethostname()
    job_id = config.model_name
    # `or ''`: an unset DEV previously crashed here with None.lower().
    if (os.getenv('DEV') or '').lower() == "true":
        job_id = "dev-" + job_id
    starttime = time.time()
    total_sent = processing_time = priority_time = prep_time = 0

    while running.value:
        try:
            time.sleep(60.0 - ((time.time() - starttime) % 60.0))
            for _ in range(q.qsize()):
                total_sent += 1
                proc_time, prep_t, priority_value = q.get()
                prep_time += prep_t
                processing_time += proc_time
                # Weight compute time by priority (priority 10 == weight 1).
                priority_time += proc_time * priority_value / 10
            time_waited = time.time() - starttime
            starttime = time.time()

            # A fresh registry per push avoids duplicate-metric registration.
            registry = CollectorRegistry()
            compute_gauge = Gauge('compute_time', 'Compute time used in a minute',
                                  registry=registry)
            answered_gauge = Gauge('answered_per_min', 'Answered requests in a minute',
                                   registry=registry)
            priority_gauge = Gauge('priority_time', 'Derived priority time in a minute',
                                   registry=registry)
            prep_gauge = Gauge('prep_time', 'Preparation time in a minute',
                               registry=registry)
            # NOTE(review): queue_depth is registered but never set here, so
            # it reports the Gauge's initial value.
            Gauge('queue_depth', 'The depth of the queue',
                  registry=registry)
            compute_gauge.set(processing_time)
            answered_gauge.set(total_sent)
            priority_gauge.set(priority_time)
            prep_gauge.set(prep_time)

            logger.info(
                f"{job_id}: answered {total_sent} requests in "
                f"{time_waited:.3f}s, "
                f"{prep_time:.3f}s in preparation "
                f"using {processing_time:.3f}s processing seconds;")

            if config.prom_addr:
                push_to_gateway(config.prom_addr,
                                grouping_key={"instance": machine_id},
                                job=job_id, registry=registry,
                                timeout=55)
        except Exception as ex:
            logger.error(f"spy exception: {str(ex)}")
            capture_exception(ex)
        finally:
            # Every window starts from zero, whether or not the push succeeded.
            total_sent = processing_time = priority_time = prep_time = 0


if __name__ == "__main__":
    # Script entry point; importing this module does not start the node.
    main()
