import pathlib
import signal
import sys
import traceback
import platform
import os
import dotmap

# Set before the lm_node imports below so the value is already in the
# environment when they run.
# NOTE(review): presumably this disables parallelism in the HuggingFace
# `tokenizers` library -- confirm against that library's documentation.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from lm_node import utils
from lm_node.base import GPTModel
from lm_node.config import init_config
from lm_node.prefix import bytes_to_prefix
from lm_node.utils.text import process_scores

# Absolute directory containing this file; the import-path additions below
# are all relative to it.
thisPath = str(pathlib.Path(__file__).parent.resolve())
# Append the gooseai interface directories and the local transformers/src
# tree to the module search path before the `interfaces.*` imports further
# down.
# NOTE(review): presumably the interface directories are needed so the
# generated *_pb2 modules can import each other by bare module name --
# confirm. These are appended (not inserted), so an already-installed
# package of the same name would still take precedence.
sys.path.append(thisPath + "/interfaces/gooseai/completion")
sys.path.append(thisPath + "/interfaces/gooseai/engines")
sys.path.append(thisPath + "/transformers/src")
import uuid
from concurrent import futures
from typing import List, Dict, Tuple, Any
import gc

import grpc
from lm_node.sanitize import *
from lm_node.base import *
from lm_node.models.GPT import params_report
import interfaces.gooseai.completion.completion_pb2 as completion
import interfaces.gooseai.engines.engines_pb2 as engine
from sentry_sdk import capture_exception

import interfaces.gooseai.completion.completion_pb2_grpc as completion_grpc
import interfaces.gooseai.engines.engines_pb2_grpc as engine_grpc


class EngineServicer(engine_grpc.EnginesServiceServicer):
    """gRPC engines service that advertises the single model on this node.

    The engine listing is built once in ``__init__`` and the same message is
    returned for every ``ListEngines`` call.
    """

    def __init__(self, model: GPTModel, config: dotmap.DotMap):
        """Precompute the engine listing.

        Args:
            model: the loaded model; only ``model.tokenizer_id`` is read.
            config: node configuration; ``config.model_name`` is used as the
                id, name, owner and description of the engine entry.
        """
        super().__init__()
        model_name = config.model_name
        # Any tokenizer id other than "pile" is advertised as GPT2.
        is_pile = model.tokenizer_id == "pile"
        info = engine.EngineInfo(
            id=model_name,
            name=model_name,
            owner=model_name,
            description=model_name,
            ready=True,
            type=engine.TEXT,
            tokenizer=engine.PILE if is_pile else engine.GPT2)
        self.engine_data = engine.Engines(engine=[info])

    def ListEngines(self, request, context):
        """Return the precomputed listing; ``request``/``context`` are unused."""
        return self.engine_data


class CompletionServicer(completion_grpc.CompletionServiceServicer):
    def __init__(self, model: GPTModel, config: dotmap.DotMap,
                 server: Any):
        """Store collaborators, read debug switches and hook SIGTERM.

        Args:
            model: the model used to tokenize prompts and generate output.
            config: node configuration; ``cpu_id``, ``gpu_id`` and
                ``node_id`` are copied into the metadata attached to answers.
            server: the server object whose ``stop`` method is called by
                ``handle_exit``.

        Environment variables read here:
            PAUSE: if set, kept as a string; ``Completion`` sleeps for
                ``int(PAUSE)`` seconds before handling each request.
            DUMMY: "true" (any casing) makes ``Completion`` return canned
                answers without touching the model.
            DUMP_REQUESTS: "true" (any casing) makes ``Completion``
                pretty-print each request's parameters.
        """
        super().__init__()

        def env_flag(name: str) -> bool:
            # Unset variables default to "False"; only "true" enables a flag.
            return os.getenv(name, "False").lower() == "true"

        self.answer_meta = completion.AnswerMeta(
            cpu_id=config.cpu_id,
            gpu_id=config.gpu_id,
            node_id=config.node_id)
        self.model: GPTModel = model
        self.config = config
        self.server = server
        self.pause = os.getenv("PAUSE")
        self.dummy = env_flag("DUMMY")
        self.dumpRequests = env_flag("DUMP_REQUESTS")
        # Replaces any SIGTERM handler installed earlier in the process.
        signal.signal(signal.SIGTERM, self.handle_exit)

    def handle_exit(self, sig, frame):
        """Log the shutdown request and ask the server to stop.

        Installed as the SIGTERM handler by ``__init__``; ``Completion`` also
        calls it directly with ``(None, None)`` after a GPU memory error.

        Args:
            sig: signal number, or ``None`` when called directly.
            frame: interrupted stack frame, or ``None`` when called directly.
        """
        # Value handed to ``server.stop``.
        # NOTE(review): for a grpc.Server this is the grace period in
        # seconds -- confirm.
        stop_grace = 60
        logger.warning(f"Received signal {sig} in {frame}! Preparing to exit.")
        self.server.stop(stop_grace)

    def Completion(self, request, context):
        """Stream ``completion.Answer`` messages for every prompt in a request.

        This is a generator. For each prompt it optionally yields one answer
        echoing the prompt, then yields generated output in chunks while the
        requested ``max_length`` exceeds the prompt length. Every answer is
        stamped with ``inference_received``, the millisecond timestamp taken
        when this call started.

        Two switches read in ``__init__`` change the behaviour:
        ``self.pause`` sleeps that many seconds before any work is done, and
        ``self.dummy`` bypasses the model and returns the first prompt's text
        twenty times, one second apart.

        Args:
            request: the completion request; ``request.prompt`` is iterated
                and each prompt supplies either ``text`` or pre-tokenized
                ``tokens``.
            context: gRPC call context (unused).

        Yields:
            ``completion.Answer`` messages.

        Raises:
            Exception: any error raised while serving is printed, passed to
                ``capture_exception`` and re-raised. When the error message
                mentions a CUDA out-of-memory or illegal memory access,
                ``handle_exit`` is called first to begin stopping the server.
        """
        received = int(time.time() * 1000)
        if self.pause is not None:
            logger.info(f"Pausing for {self.pause} seconds")
            time.sleep(int(self.pause))
        if self.dummy:
            # Dummy mode: canned answers, no model involvement.
            for _ in range(20):
                yield completion.Answer(
                    answer_id=str(uuid.uuid4()),
                    created=int(time.time() * 1000),
                    model=self.config.model_name,
                    choices=[
                        completion.Completion(
                            text=request.prompt[0].text,
                            index=0,
                            token_index=0,
                            logprobs=completion.LogProbs(),
                            started=int(time.time() * 1000))],
                    meta=self.answer_meta,
                    inference_received=received)
                time.sleep(1)
            return

        try:
            with torch.no_grad():
                req_params = self.fill_params(request)
                request_id_repr = f"[{req_params['request_id']}]"

                # Index of the prompt being served; becomes the `index` of
                # the choices produced for it.
                target_completion = 0
                for prompt in request.prompt:
                    started = int(time.time() * 1000)
                    answer_id = str(uuid.uuid4())
                    # Fresh per-prompt copy of the request parameters.
                    req_dict = DotMap(req_params)
                    if prompt.text:
                        tkns = \
                            [self.model.tokenizer(prompt.text).input_ids]
                    else:
                        tkns = [[tkn.id for tkn in prompt.tokens.tokens]]
                    req_dict.input = torch.tensor(tkns)
                    output = process_payload(req_dict,
                                             self.config,
                                             True,
                                             self.model)
                    logger.info("{0} generate() request: {1}".format(
                        request_id_repr,
                        params_report(self.model, req_dict.input, req_dict)))
                    if self.dumpRequests:
                        req_dict.pprint()
                    logger.info("{0} generate() start: {1}".format(
                        request_id_repr,
                        self.model.getGPUram()))
                    if output[0]:  # success
                        req_dict, warning = output
                    else:
                        # NOTE(review): on failure the error is only printed
                        # and processing continues with the req_dict built
                        # above -- confirm this is intended.
                        ic(output[1])
                    last_offset = 0
                    last_idx = 0
                    if req_dict.echo is not None and \
                            req_dict.echo.index <= len(req_dict.input[0]):
                        # Echo back the token probabilities of the prompt(s) if
                        # requested.
                        echo_response = self.generate_echoes(
                            req_dict, req_dict.num_logprobs,
                            target_completion,
                            answer_id, started)
                        # generate_echoes returns None when it has no tokens
                        # to echo; skip the echo instead of failing on None.
                        if echo_response is not None:
                            echo_response.inference_received = received
                            first_choice = echo_response.choices[0]
                            # Generation continues after the echoed text and
                            # tokens.
                            last_offset = len(first_choice.text)
                            last_idx = len(
                                first_choice.logprobs.tokens.logprobs)
                            yield echo_response
                    if req_dict.max_length > len(req_dict.input[0]):
                        # Our requested length is more than the length of the
                        # input, so we begin generating tokens.
                        for answer in self.generate(req_dict,
                                                    last_offset,
                                                    answer_id,
                                                    req_params[
                                                        "num_logprobs"],
                                                    target_completion,
                                                    started,
                                                    last_idx):
                            answer.inference_received = received
                            yield answer
                    target_completion += 1
                    logger.info(
                        "{0} generate() done: {1}".format(
                            request_id_repr,
                            self.model.getGPUram()))
        except Exception as e:
            print(''.join(traceback.format_tb(e.__traceback__)))
            ic(e)
            capture_exception(e)
            e_s = str(e)
            if "CUDA out of memory" in e_s or \
                    "an illegal memory access" in e_s:
                logger.error("GPU OOM or memory error, committing seppuku.")
                logger.error(self.model.getGPUram())
                self.handle_exit(None, None)
            # Bare raise re-raises the active exception unchanged.
            raise
        finally:
            # The strings below are the names of locals in this function;
            # keep them in sync if those variables are ever renamed.
            utils.memory.cleanup_variables(locals(),
                                           "output",
                                           "req_dict",
                                           "answer")
            gc.collect()
            torch.cuda.synchronize()

    def fill_params(self, request: completion.Request) -> Dict:
        """Translate a completion request into a flat parameter dict.

        Starts from a set of defaults and overrides them with whichever
        optional fields are present on ``request``. Most numeric sampling and
        penalty fields are copied only when they are both present and truthy
        (so an explicit 0 is ignored); ``top_k``, ``max_tokens``,
        ``min_tokens`` and ``logprobs`` are copied whenever present.

        Args:
            request: the incoming completion request.

        Returns:
            The parameter dict. It always contains ``request_id`` (taken from
            the request, or a fresh UUID string) and ``parameters`` (a dict
            holding only the ``prefix`` entry).

        Raises:
            Exception: if a logit bias entry does not name exactly one token.
        """
        params = {
            "use_cache": True,
            "pad_token_id": self.model.tokenizer.eos_token_id,
            "top_k": 0,
            "repetition_penalty": 1.0,
            "max_length": 16,
            "min_length": 0,
            "num_logprobs": -1,
            "echo": None,
            "output_scores": True,
            "do_sample": True,
            "prefix": None,
        }

        params['request_id'] = (request.request_id
                                if request.HasField("request_id")
                                else str(uuid.uuid4()))

        if request.HasField("model_params"):
            model_params: completion.ModelParams = request.model_params

            if model_params.HasField("sampling_params"):
                sampling: completion.SamplingParams = \
                    model_params.sampling_params
                if sampling.HasField("top_k"):
                    params["top_k"] = sampling.top_k
                if sampling.HasField("top_p") and sampling.top_p:
                    params["top_p"] = sampling.top_p
                if sampling.HasField("temperature") and \
                        sampling.temperature is not None:
                    if sampling.temperature == 0:
                        # Temperature 0 is expressed as "no temperature,
                        # top_k of 1", overriding any top_k set above.
                        params["temperature"] = None
                        params["top_k"] = 1
                    else:
                        params["temperature"] = sampling.temperature
                # Remaining samplers share a name between request and dict.
                for field in ("tail_free_sampling", "typical_p", "top_a"):
                    if sampling.HasField(field) and getattr(sampling, field):
                        params[field] = getattr(sampling, field)
                if len(sampling.order):
                    # Zero entries are dropped; the rest are shifted down by
                    # one.
                    params["order"] = [processor_id - 1
                                       for processor_id in sampling.order
                                       if processor_id != 0]

            if model_params.HasField("frequency_params"):
                freq: completion.FrequencyParams = \
                    model_params.frequency_params
                # (request field, parameter key) pairs, in insertion order.
                penalty_fields = (
                    ("presence_penalty", "repetition_penalty_presence"),
                    ("frequency_penalty", "repetition_penalty_frequency"),
                    ("repetition_penalty", "repetition_penalty"),
                    ("repetition_penalty_range", "repetition_penalty_range"),
                    ("repetition_penalty_slope", "repetition_penalty_slope"),
                )
                for field, key in penalty_fields:
                    if freq.HasField(field) and getattr(freq, field):
                        params[key] = getattr(freq, field)

            if model_params.HasField("logit_bias"):
                bias_pairs = []
                for entry in model_params.logit_bias.biases:
                    bias_tokens = entry.tokens.tokens
                    if len(bias_tokens) == 0:
                        raise Exception("InvalidLogitBias: No token provided "
                                        "for bias.")
                    if len(bias_tokens) != 1:
                        raise Exception("InvalidLogitBias: Only one token per "
                                        "bias is permitted at this time.")
                    bias_pairs.append([bias_tokens[0].id, entry.bias])
                if bias_pairs:
                    params["logit_bias"] = bias_pairs

        if request.HasField("engine_params"):
            engine_params: completion.EngineParams = request.engine_params
            engine_fields = (
                ("max_tokens", "max_length"),
                ("min_tokens", "min_length"),
                ("logprobs", "num_logprobs"),
            )
            for field, key in engine_fields:
                if engine_params.HasField(field):
                    params[key] = getattr(engine_params, field)
            if len(engine_params.stop) > 0:
                # Each stop sequence is stored as its encoded "pt" tensor.
                params["stop_sequences"] = [
                    self.model.tokenizer.encode(stop.text,
                                                return_tensors="pt")
                    for stop in engine_params.stop]
            if engine_params.HasField("echo"):
                # An echo request without an index echoes from position 0.
                echo_index = (engine_params.echo.index
                              if engine_params.echo.HasField("index")
                              else 0)
                params["echo"] = completion.Echo(index=echo_index)

        if len(request.embeddings):
            prefixes = []
            for embedding in request.embeddings:
                # Only embeddings carrying a raw tensor are used.
                if not embedding.HasField("raw"):
                    continue
                raw: completion.Tensor = embedding.raw
                dtype = np.float16 if raw.typ == completion.FP16 \
                    else np.float32
                prefixes.append(
                    (embedding.pos,
                     bytes_to_prefix(raw.dims, raw.data, dtype)))
            if prefixes:
                params['embs'] = prefixes

        params['parameters'] = {
            'prefix': params.get('prefix', None)
        }

        return params

    def to_logprobs(self, arr):
        """Convert raw log-probability entries into ``completion.LogProb``.

        Args:
            arr: ``None`` or an iterable of entries. For each entry,
                ``entry[0]`` is a sequence of token ids (decoded as a whole
                for the text, with its first element used as the id) and
                ``entry[1]`` is indexable as ``(logprob_before, logprob)``.

        Returns:
            ``None`` when ``arr`` is ``None``; otherwise a list with one
            ``completion.LogProb`` per entry, in order.
        """
        if arr is None:
            return None

        converted = []
        for entry in arr:
            token_ids = entry[0]
            token = completion.Token(
                text=self.model.tokenizer.decode(token_ids),
                id=token_ids[0])
            converted.append(completion.LogProb(
                token=token,
                logprob=entry[1][1],
                logprob_before=entry[1][0]))
        return converted

    def process_scores(self, chosen, before: torch.Tensor, after: torch.Tensor,
                       num_logprobs: int) -> \
            Tuple[completion.TokenLogProbs,
                  Union[Tuple[completion.TokenLogProbs], None],
                  Union[Tuple[completion.TokenLogProbs], None]]:
        """Wrap the module-level ``process_scores`` results in protobufs.

        Inside this method the bare name ``process_scores`` refers to the
        function imported from ``lm_node.utils.text``, not to this method.

        Args:
            chosen: the chosen token, forwarded unchanged.
            before: scores forwarded as the helper's second argument.
            after: scores forwarded as the helper's third argument.
            num_logprobs: forwarded unchanged to the helper.

        Returns:
            A 3-tuple ``(chosen, after, before)``. Note that "after" comes
            BEFORE "before" in the returned tuple. ``chosen`` is a
            ``TokenLogProbs``; the other two are each either ``None`` or a
            one-element tuple holding a ``TokenLogProbs``.
        """
        def as_token_logprobs(raw):
            return completion.TokenLogProbs(logprobs=self.to_logprobs(raw))

        raw_chosen, raw_before, raw_after = \
            process_scores(chosen, before, after, num_logprobs)

        wrapped_before = None if raw_before is None \
            else (as_token_logprobs(raw_before),)
        wrapped_after = None if raw_after is None \
            else (as_token_logprobs(raw_after),)

        # Callers unpack positionally and rely on this exact order.
        return as_token_logprobs(raw_chosen), wrapped_after, wrapped_before

    def generate(self, req_dict: DotMap, begin_offset: int,
                 answer_id: str, num_logprobs: int,
                 target_completion: int,
                 started: int,
                 start_token_idx: Union[int, None] = None) \
            -> completion.Answer:
        """Run the model and yield answers as decodable chunks become ready.

        This is a generator. Tokens are buffered until
        ``self.model.unitrim.send_ready`` accepts the buffer, then the buffer
        is decoded, packaged into one ``completion.Answer`` and yielded. Any
        tokens still buffered when generation ends are flushed in a final
        answer. ``self.model.mutex`` is held for the whole generation,
        including across yields.

        Args:
            req_dict: request parameters; ``input`` holds the prompt tensor
                and ``parameters`` is expanded into the model's ``generate``
                call.
            begin_offset: text offset at which this output starts.
            answer_id: id stamped on every yielded answer.
            num_logprobs: forwarded to ``process_scores``.
            target_completion: ``index`` assigned to each choice.
            started: start timestamp stamped on each choice.
            start_token_idx: token index of the first generated token;
                defaults to the prompt length when ``None``.

        Yields:
            ``completion.Answer`` messages, each carrying one choice.
        """
        offset = begin_offset
        token_idx: int = (len(req_dict.input[0])
                          if start_token_idx is None
                          else start_token_idx)

        # Buffers for the chunk currently being accumulated; all of them are
        # emptied every time an answer is emitted.
        pending_ids = []
        pending_offsets = []
        pending_chosen = []
        pending_top_before = []  # ends up in LogProbs.top_before
        pending_top_after = []   # ends up in LogProbs.top

        def emit(tkn_idx, tkn_offset, reason):
            # Package the buffered tokens into an Answer; returns the answer
            # plus the advanced token index and text offset.
            text = ""
            chunk = []

            def decode_chunk():
                # Every token of a chunk shares the chunk's start offset.
                nonlocal text
                chunk_offset = len(text) + tkn_offset
                text += self.model.tokenizer.decode(chunk)
                pending_offsets.extend([chunk_offset] * len(chunk))

            for tkn in pending_ids:
                chunk.append(tkn)
                if self.model.unitrim.send_ready(chunk):
                    decode_chunk()
                    chunk.clear()
            if chunk:
                decode_chunk()

            logprobs = completion.LogProbs(
                tokens=completion.TokenLogProbs(logprobs=pending_chosen),
                text_offset=pending_offsets)
            if pending_top_before:
                logprobs.top_before.extend(pending_top_before)
            if pending_top_after:
                logprobs.top.extend(pending_top_after)

            choice = completion.Completion(
                text=text,
                index=target_completion,
                token_index=tkn_idx,
                logprobs=logprobs,
                started=started)

            # A falsy reason means generation is still in progress.
            if reason:
                if reason == "MaxLengthCriteria":
                    choice.finish_reason = completion.FinishReason.LENGTH
                elif reason == "SequenceStoppingCriteria":
                    choice.finish_reason = completion.FinishReason.STOP
                else:
                    choice.finish_reason = completion.FinishReason.NULL

            packaged = completion.Answer(
                request_id=req_dict['request_id'],
                answer_id=answer_id,
                created=int(time.time() * 1000),
                model=self.config.model_name,
                choices=[choice],
                meta=self.answer_meta)

            next_idx = tkn_idx + len(pending_ids)
            next_offset = tkn_offset + len(text)

            for buffer in (pending_ids, pending_offsets, pending_top_before,
                           pending_top_after, pending_chosen):
                buffer.clear()

            return packaged, next_idx, next_offset

        with torch.no_grad(), self.model.mutex:
            stream = self.model.model.generate(req_dict.input.long().cuda(),
                                               **req_dict.parameters)
            for token, is_finished, scores_before, scores_after in stream:
                # self.process_scores returns (chosen, after, before).
                chosen, top_after, top_before = \
                    self.process_scores(token, scores_before, scores_after,
                                        num_logprobs)
                chosen_logprob = chosen.logprobs[0]
                pending_chosen.append(chosen_logprob)
                pending_ids.append(chosen_logprob.token.id)
                if top_after is not None:
                    pending_top_after.append(top_after[0])
                if top_before is not None:
                    pending_top_before.append(top_before[0])
                if self.model.unitrim.send_ready(pending_ids):
                    answer, token_idx, offset = emit(token_idx,
                                                     offset,
                                                     is_finished)
                    yield answer
                # Drop references to per-step score data promptly.
                del scores_before, scores_after, top_after, top_before
            if pending_ids:
                # Flush whatever never became "send ready".
                answer, token_idx, offset = emit(token_idx,
                                                 offset,
                                                 is_finished)
                yield answer

    def generate_echoes(self, req_dict: DotMap, num_logprobs: int,
                        target_completion: int, answer_id: str,
                        started: int) -> \
            Union[completion.Answer, None]:
        """Build one answer that echoes prompt tokens back to the client.

        Args:
            req_dict: request parameters; ``input`` holds the prompt tensor
                and ``echo.index`` selects where the echo starts. A negative
                index counts back from the end of the prompt. The resulting
                start position is clamped to ``[0, len(prompt)]``.
            num_logprobs: ``-1`` echoes the tokens without probabilities;
                any other value asks the model for the prompt token
                probabilities, and a value above 0 also includes the top
                alternatives for each token.
            target_completion: ``index`` assigned to the echoed choice.
            answer_id: id stamped on the answer.
            started: start timestamp stamped on the choice.

        Returns:
            The echo ``completion.Answer``, or ``None`` when there are no
            tokens to echo.
        """
        token_logprobs: List[completion.LogProb] = []
        before_logprobs: List[completion.TokenLogProbs] = []
        text_offsets: List[int] = []
        tokens = []

        input_len = len(req_dict.input[0])
        echo_index = req_dict.echo.index
        if echo_index < 0:
            # Count back from the end of the prompt. (Previously the index
            # value itself was never used here, so every negative index
            # resolved to len(prompt).)
            target_idx = input_len + echo_index
        else:
            target_idx = echo_index
        target_idx = max(0, min(target_idx, input_len))

        def serialize_answer():
            # Decode the collected tokens chunk by chunk; every token in a
            # chunk shares the text offset at which the chunk starts.
            acc = []
            text = ""

            for tkn in tokens:
                acc.append(tkn)
                if self.model.unitrim.send_ready(acc):
                    curr_offset = len(text)
                    text += self.model.tokenizer.decode(acc)
                    text_offsets.extend([curr_offset] * len(acc))
                    acc.clear()
            if len(acc):
                curr_offset = len(text)
                text += self.model.tokenizer.decode(acc)
                text_offsets.extend([curr_offset] * len(acc))

            compl = completion.Completion(
                text=text,
                index=target_completion,
                token_index=target_idx,
                logprobs=completion.LogProbs(
                    tokens=completion.TokenLogProbs(
                        logprobs=token_logprobs),
                    top_before=before_logprobs,
                    text_offset=text_offsets),
                started=started)

            # Nothing will be generated after the echo, so the echo itself
            # reports the length limit.
            if req_dict.max_length == len(req_dict.input[0]):
                compl.finish_reason = completion.FinishReason.LENGTH

            answer = completion.Answer(
                request_id=req_dict['request_id'],
                answer_id=answer_id,
                created=int(time.time() * 1000),
                model=self.config.model_name,
                choices=[compl],
                meta=self.answer_meta)

            tokens.clear()
            text_offsets.clear()
            before_logprobs.clear()
            token_logprobs.clear()

            return answer

        if num_logprobs == -1:
            # No probabilities requested: echo the raw prompt tokens.
            # NOTE(review): the echo starts one token before target_idx
            # whenever target_idx > 0 -- confirm this lookback is intended.
            input_begin = target_idx
            if input_begin > 0:
                input_begin -= 1
            for token in req_dict.input[0][input_begin:]:
                tokens.append(token)
                token_logprob = self.to_logprobs([[[token],
                                                   [None, None]]])[0]
                token_logprobs.append(token_logprob)
        else:
            for prob in self.model.get_token_probabilities(req_dict.input,
                                                           n=num_logprobs,
                                                           idx=target_idx):
                # Get the data for the token in the prompt.
                chosen_prob = prob["chosen"]
                tokens.append(chosen_prob[0][0])
                token_logprob = self.to_logprobs([chosen_prob])[0]
                token_logprobs.append(token_logprob)
                top_prob = prob.get("choices", [])
                if num_logprobs > 0:
                    before_logprobs.append(completion.TokenLogProbs(
                        logprobs=self.to_logprobs(top_prob)))

        if not tokens:
            return None
        return serialize_answer()


def serve(model: GPTModel, config: dotmap.DotMap) -> grpc.Server:
    """Create, configure and start the gRPC server without blocking.

    Environment variables:
        GRPC_MAX_WORKERS: thread pool size (default "16").
        PORT: port used for the default bind address (default "50051").
        GRPC_BIND: full bind address; defaults to "0.0.0.0:" plus PORT.

    Args:
        model: the model handed to both servicers.
        config: the node configuration handed to both servicers.

    Returns:
        The started server. The caller is responsible for waiting on it.
    """
    worker_count = int(os.getenv("GRPC_MAX_WORKERS", "16"))
    executor = futures.ThreadPoolExecutor(max_workers=worker_count)
    server = grpc.server(executor)

    # The completion servicer keeps the server so it can stop it on exit.
    completion_servicer = CompletionServicer(model, config, server)
    completion_grpc.add_CompletionServiceServicer_to_server(
        completion_servicer, server)
    engine_servicer = EngineServicer(model, config)
    engine_grpc.add_EnginesServiceServicer_to_server(
        engine_servicer, server)

    default_bind = "0.0.0.0:" + os.getenv("PORT", "50051")
    bind = os.getenv("GRPC_BIND", default_bind)
    server.add_insecure_port(bind)
    server.start()
    logger.info("GRPC server started on " + bind)
    return server


def simple_handle_exit(signal, frame):
    """Announce the received signal and exit the process with status 0.

    Used as the signal handler before the gRPC server is up.

    Args:
        signal: the signal number (an integer).
        frame: the interrupted stack frame (unused).

    Raises:
        SystemExit: always, with exit code 0.
    """
    notice = "Termination signal (%d) received, dying instantly" % signal
    print(notice)
    raise SystemExit(0)


def main():
    """Load the model, start the gRPC server and block until it terminates.

    SIGTERM and SIGINT are first bound to ``simple_handle_exit`` so the
    process can be stopped while ``init_config`` is still running.
    """
    # Setup pre-serve signal handlers
    for signum in (signal.SIGTERM, signal.SIGINT):
        signal.signal(signum, simple_handle_exit)

    model, config = init_config()

    # And serve it up via GRPC.
    logger.info("Starting GRPC server")
    serve(model, config).wait_for_termination()
    logger.info("Exiting")


# Start the server only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
