import os
import pathlib
import sys
import unittest
from typing import List

import grpc

# Directory containing this test module, used to locate fixtures and the
# generated protobuf packages.
thisPath = str(pathlib.Path(__file__).parent.resolve())
# Make the generated *_pb2 modules importable by their bare module names.
sys.path.extend([thisPath + "/../interfaces/gooseai/completion",
                 thisPath + "/../interfaces/gooseai/engines"])

import interfaces.gooseai.engines.engines_pb2 as engine
import interfaces.gooseai.engines.engines_pb2_grpc as engines_grpc
import interfaces.gooseai.completion.completion_pb2 as completion
import interfaces.gooseai.completion.completion_pb2_grpc as completions_grpc

# Address of the completion service under test. Defaults to a local server;
# set GOOSEAI_TEST_ENDPOINT to target another deployment, e.g.
# "gpt-neo-20b-predictor-default.tenant-gooseprod-1.knative.chi.coreweave.com:80".
endpoint = os.environ.get("GOOSEAI_TEST_ENDPOINT", "localhost:50051")
# One plaintext (insecure) channel is shared by both service stubs.
channel = grpc.insecure_channel(endpoint)
completion_grpc = completions_grpc.CompletionServiceStub(channel)
engine_grpc = engines_grpc.EnginesServiceStub(channel)

def toTokens(tokens: List[int]) -> completion.Tokens:
    """Wrap raw token ids in a ``completion.Tokens`` message.

    :param tokens: token ids, in the order they should be sent.
    :return: a ``completion.Tokens`` holding one ``completion.Token`` per id,
        in the same order.
    """
    # ``token_id`` rather than ``id`` so the builtin is not shadowed.
    return completion.Tokens(
        tokens=[completion.Token(id=token_id) for token_id in tokens])

# Two plain-text prompts; most tests send only the first one.
prompts: List[completion.Prompt] = [
    completion.Prompt(text="The mercurial and beautiful witch laughed"),
    completion.Prompt(text="The warlock cackled and cast a spell"),
]

# A single prompt supplied as raw token ids instead of text.
token_prompts: List[completion.Prompt] = [
    completion.Prompt(tokens=toTokens(
        [464, 11991, 333, 498, 290, 4950, 16365, 13818])),
]

# Large prompt read from a fixture file next to this module. read_text()
# closes the file right away, and the explicit encoding keeps the result
# independent of the platform's locale default.
promptStr = pathlib.Path(thisPath, "mary_shelley_2.txt").read_text(
    encoding="utf-8")

largePrompt: List[completion.Prompt] = [completion.Prompt(text=promptStr)]


class TestGRPCBasicRequest(unittest.TestCase):
    """A single text prompt streams back non-empty text ending on LENGTH."""

    def test_basic_request(self):
        request = completion.Request(prompt=[prompts[0]])
        pieces = []
        for resp in completion_grpc.Completion(request):
            pieces.append(resp.choices[0].text)
        assert "".join(pieces) != ""
        # ``resp`` still refers to the last streamed response here.
        assert resp.choices[0].finish_reason == completion.FinishReason.LENGTH

class TestEnginesInfo(unittest.TestCase):
    """ListEngines reports exactly one described, ready text engine."""

    def test_engines_info(self):
        engines = engine_grpc.ListEngines(engine.ListEnginesRequest())
        self.assertEqual(len(engines.engine), 1)
        info = engines.engine[0]
        # Descriptive string fields must all be populated; the field name is
        # passed as the failure message so a failure says which one was empty.
        for field in ("id", "owner", "name", "description"):
            self.assertNotEqual(getattr(info, field), "", field)
        self.assertTrue(info.ready)
        self.assertEqual(info.type, engine.TEXT)
        self.assertIn(info.tokenizer, (engine.GPT2, engine.PILE))

class TestGRPCTokenRequest(unittest.TestCase):
    """A pre-tokenized prompt streams back non-empty text ending on LENGTH."""

    def test_token_request(self):
        responses = list(completion_grpc.Completion(
            completion.Request(prompt=token_prompts)))
        generated = "".join(r.choices[0].text for r in responses)
        assert generated != ""
        # Only the final response's finish reason is checked.
        assert responses[-1].choices[0].finish_reason == \
            completion.FinishReason.LENGTH

class TestGRPCMultiPrompt(unittest.TestCase):
    """Two prompts in one request yield two distinct, non-empty completions."""

    def test_multi_request(self):
        rq = completion.Request(prompt=prompts)
        # Stream chunks are grouped by answer_id to rebuild each completion.
        completions = {}
        for resp in completion_grpc.Completion(rq):
            completions[resp.answer_id] = (
                completions.get(resp.answer_id, "") + resp.choices[0].text)
        self.assertEqual(len(completions), 2)
        # Check the accumulated text itself. The length of an items() tuple
        # is always 2, so measuring that could never fail.
        for answer_id, text in completions.items():
            self.assertNotEqual(text, "", f"empty completion for {answer_id}")


class TestGRPCEcho(unittest.TestCase):
    """Echo from index 0 yields output that starts with the full prompt."""

    def test_echo(self):
        echo_params = completion.EngineParams(echo=completion.Echo(index=0))
        rq = completion.Request(prompt=[prompts[0]], engine_params=echo_params)
        text = "".join(resp.choices[0].text
                       for resp in completion_grpc.Completion(rq))
        assert text.startswith(prompts[0].text)

class TestGRPCMetadata(unittest.TestCase):
    """Every streamed response carries timing values and host identifiers."""

    def test_grpc_metadata(self):
        rq = completion.Request(prompt=[prompts[0]],
                                engine_params=completion.EngineParams(
                                    echo=completion.Echo(index=0)))
        resp_ct = 0
        for resp in completion_grpc.Completion(rq):
            resp_ct += 1
            self.assertGreater(resp.inference_received, 0)
            self.assertGreater(resp.choices[0].started, 0)
            # Protobuf field access never returns None (unset fields read as
            # their default), so only emptiness needs checking.
            self.assertNotEqual(resp.meta.cpu_id, "")
            self.assertNotEqual(resp.meta.gpu_id, "")
            self.assertNotEqual(resp.meta.node_id, "")
        # Without this, an empty stream would pass having checked nothing.
        self.assertGreater(resp_ct, 0)


class TestGRPCEchoNoIndex(unittest.TestCase):
    """Echo with no index set also yields output starting with the prompt."""

    def test_echo_noindex(self):
        prompt = prompts[0]
        stream = completion_grpc.Completion(completion.Request(
            prompt=[prompt],
            engine_params=completion.EngineParams(echo=completion.Echo())))
        echoed = ""
        for chunk in stream:
            echoed = echoed + chunk.choices[0].text
        assert echoed.startswith(prompt.text)


class TestGRPCEchoOnly(unittest.TestCase):
    """Echo with max_tokens=0 yields exactly one response, beginning with the
    prompt and finishing on LENGTH."""

    def test_echo_only(self):
        params = completion.EngineParams(echo=completion.Echo(index=0),
                                         max_tokens=0)
        rq = completion.Request(prompt=[prompts[0]], engine_params=params)
        chunks = []
        for resp in completion_grpc.Completion(rq):
            chunks.append(resp.choices[0].text)
            assert resp.choices[0].finish_reason == \
                completion.FinishReason.LENGTH
        assert "".join(chunks).startswith(prompts[0].text)
        assert len(chunks) == 1


class TestGRPCEchoOnlyLogprobs(unittest.TestCase):
    """Echo-only request (max_tokens=0) with logprobs yields one response
    beginning with the prompt.

    NOTE(review): the returned logprobs themselves are not inspected here.
    """

    def test_echo_only(self):
        num_logprobs = 10
        params = completion.EngineParams(echo=completion.Echo(index=0),
                                         max_tokens=0,
                                         logprobs=num_logprobs)
        texts = [resp.choices[0].text
                 for resp in completion_grpc.Completion(
                     completion.Request(prompt=[prompts[0]],
                                        engine_params=params))]
        assert "".join(texts).startswith(prompts[0].text)
        assert len(texts) == 1


class TestGRPCEchoIndex(unittest.TestCase):
    """Echo from index 3 of " a b c d e f" yields text starting " c d e f"."""

    def test_echo_index(self):
        rq = completion.Request(prompt=[completion.Prompt(text=" a b c d e f")],
                                engine_params=completion.EngineParams(
                                    echo=completion.Echo(index=3)))
        acc = "".join(resp.choices[0].text
                      for resp in completion_grpc.Completion(rq))
        # The actual text is passed as the message to make failures readable.
        self.assertTrue(acc.startswith(" c d e f"), acc)


class TestGRPCEchoIndexLogprobs(unittest.TestCase):
    """Echo from index 3 with logprobs requested still yields text starting
    " c d e f" for the prompt " a b c d e f"."""

    def test_echo_index_logprobs(self):
        rq = completion.Request(prompt=[completion.Prompt(text=" a b c d e f")],
                                engine_params=completion.EngineParams(
                                    echo=completion.Echo(index=3),
                                    logprobs=10))
        acc = "".join(resp.choices[0].text
                      for resp in completion_grpc.Completion(rq))
        # The actual text is passed as the message to make failures readable.
        self.assertTrue(acc.startswith(" c d e f"), acc)


class TestGRPCLogprobs(unittest.TestCase):
    """With logprobs requested, each chosen token has logprob and
    logprob_before set, and top/top_before line up with the chosen tokens."""

    def test_logprobs(self):
        num_logprobs = 10
        params = completion.EngineParams(logprobs=num_logprobs)
        rq = completion.Request(prompt=[prompts[0]], engine_params=params)
        for resp in completion_grpc.Completion(rq):
            lp = resp.choices[0].logprobs
            chosen = lp.tokens.logprobs
            assert all(c.HasField("logprob") and c.HasField("logprob_before")
                       for c in chosen)
            # One top/top_before entry per chosen token.
            assert len(lp.top) == len(chosen)
            assert len(lp.top_before) == len(chosen)
            for entry in lp.top_before:
                assert len(entry.logprobs) == num_logprobs


class TestGRPCLogprobsEcho(unittest.TestCase):
    """With echo and logprobs, every token after the first carries
    logprob_before and a full top_before list.

    The very first token of the stream is exempt; presumably it has no
    preceding context to be scored against (TODO confirm with the server).
    """

    def test_logprobs_echo(self):
        num_logprobs = 10
        rq = completion.Request(prompt=[prompts[0]],
                                engine_params=completion.EngineParams(
                                    logprobs=num_logprobs,
                                    echo=completion.Echo(index=0)))
        token_index = 0  # position of the token across the whole stream
        for resp in completion_grpc.Completion(rq):
            logprobs = resp.choices[0].logprobs
            chosen = logprobs.tokens.logprobs
            # top_before has one entry per chosen token.
            self.assertEqual(len(logprobs.top_before), len(chosen))
            # Skip by global position, not per response, so tokens arriving
            # in later responses are still checked against num_logprobs.
            for chosen_logprob, before_logprob in zip(chosen,
                                                      logprobs.top_before):
                if token_index > 0:
                    self.assertTrue(chosen_logprob.HasField("logprob_before"))
                    self.assertEqual(len(before_logprob.logprobs),
                                     num_logprobs)
                token_index += 1
        self.assertGreater(token_index, 0)


class TestGRPCUnicodeEcho(unittest.TestCase):
    """Text offsets 4, 5 and 6 of the echoed prompt must coincide.

    Presumably those positions are the byte-level tokens that together spell
    the first "⁂" — TODO confirm against the tokenizer.
    NOTE(review): every response is indexed, so each one must report at
    least 7 offsets; confirm the server streams offsets that way.
    """

    def test_unicode_echo(self):
        params = completion.EngineParams(echo=completion.Echo(index=0))
        rq = completion.Request(
            prompt=[completion.Prompt(text="This is a ⁂ handy asterism ⁂")],
            engine_params=params)
        for resp in completion_grpc.Completion(rq):
            offsets = resp.choices[0].logprobs.text_offset
            assert offsets[4] == offsets[5]
            assert offsets[5] == offsets[6]


class TestGRPCUnicodeResponse(unittest.TestCase):
    """Echoing a prompt containing "⁂" reproduces it at the start."""

    def test_unicode_response(self):
        asterismPrompt = "⁂ foo ⁂ foo ⁂ foo"
        echo = completion.EngineParams(echo=completion.Echo(index=0))
        stream = completion_grpc.Completion(completion.Request(
            prompt=[completion.Prompt(text=asterismPrompt)],
            engine_params=echo))
        text = "".join(resp.choices[0].text for resp in stream)
        assert text.startswith(asterismPrompt)


class TestGRPCLargePrompts(unittest.TestCase):
    """The large fixture prompt, with logprobs requested, streams back
    without an RPC error."""

    def test_largeprompt(self):
        rq = completion.Request(prompt=largePrompt,
                                engine_params=completion.EngineParams(
                                    logprobs=10))
        # The generated text is not checked. The test passes if the stream
        # completes and yields at least one response.
        resp_ct = sum(1 for _ in completion_grpc.Completion(rq))
        self.assertGreater(resp_ct, 0)


class TestGRPCTypicalP(unittest.TestCase):
    """Typical-p sampling with a frequency penalty yields two non-empty
    completions for the two prompts; the texts are printed for inspection."""

    def test_typical_p(self):
        rq = completion.Request(
            prompt=prompts,
            engine_params=completion.EngineParams(
                echo=completion.Echo(index=0),
                max_tokens=500),
            model_params=completion.ModelParams(
                sampling_params=completion.SamplingParams(
                    typical_p=0.1),
                frequency_params=completion.FrequencyParams(
                    frequency_penalty=1.4)))
        # Stream chunks are grouped by answer_id to rebuild each completion.
        completions = {}
        for resp in completion_grpc.Completion(rq):
            completions[resp.answer_id] = (
                completions.get(resp.answer_id, "") + resp.choices[0].text)
        for answer_id, text in completions.items():
            print("\n" + answer_id, "-", text)
        self.assertEqual(len(completions), 2)
        # Check the accumulated text itself. The length of an items() tuple
        # is always 2, so measuring that could never fail.
        for answer_id, text in completions.items():
            self.assertNotEqual(text, "", f"empty completion for {answer_id}")


class TestGRPCLogitOrder(unittest.TestCase):
    """Smoke test: an explicit sampler order (TEMPERATURE then TOP_K, with
    top_k=1) is accepted and streams responses."""

    def test_logit_order(self):
        sampling = completion.SamplingParams(
            order=[completion.TEMPERATURE, completion.TOP_K],
            top_k=1)
        rq = completion.Request(
            prompt=prompts,
            engine_params=completion.EngineParams(logprobs=10),
            model_params=completion.ModelParams(sampling_params=sampling))
        # Nothing is asserted; responses are printed for manual inspection.
        for resp in completion_grpc.Completion(rq):
            print(resp)


class TestGRPCLogitBiasBan(unittest.TestCase):
    """A -100 bias on token 37246 must give any " goose" candidate in the
    first top_before list a logprob of -inf.

    Presumably 37246 is the " goose" token (TestGRPCLogitBias relies on the
    same id) — TODO confirm. If " goose" is absent from the list, nothing is
    asserted.
    """

    def test_logit_ban(self):
        banned = completion.LogitBias(tokens=toTokens([37246]), bias=-100.0)
        rq = completion.Request(
            engine_params=completion.EngineParams(logprobs=10, max_tokens=1),
            prompt=[completion.Prompt(text="goose goose goose goose goose "
                                           "goose goose goose goose goose")],
            model_params=completion.ModelParams(
                logit_bias=completion.LogitBiases(biases=[banned])))
        for answer in completion_grpc.Completion(rq):
            candidates = answer.choices[0].logprobs.top_before[0].logprobs
            geese = [c for c in candidates if c.token.text == " goose"]
            assert all(c.logprob == float('-inf') for c in geese)


class TestGRPCLogitBias(unittest.TestCase):
    """A +100 bias on token 37246 forces " goose" as the single generated
    token, with logprob 0.0 wherever it appears in the first top_before list."""

    def test_logit_bias(self):
        boost = completion.LogitBias(tokens=toTokens([37246]), bias=100.0)
        rq = completion.Request(
            engine_params=completion.EngineParams(logprobs=10, max_tokens=1),
            prompt=[completion.Prompt(text="woof woof woof woof woof the ")],
            model_params=completion.ModelParams(
                logit_bias=completion.LogitBiases(biases=[boost])))
        for answer in completion_grpc.Completion(rq):
            choice = answer.choices[0]
            candidates = choice.logprobs.top_before[0].logprobs
            assert choice.text == " goose"
            assert all(c.logprob == 0.0
                       for c in candidates if c.token.text == " goose")


class TestGRPCStopSequences(unittest.TestCase):
    """When a "⁂" chunk is produced it must finish with STOP, and no further
    responses may follow it."""

    def test_stop_sequences(self):
        rq = completion.Request(
            prompt=[completion.Prompt(text="⁂⁂⁂⁂⁂⁂⁂⁂⁂⁂")],
            model_params=completion.ModelParams(
                sampling_params=completion.SamplingParams(temperature=1,
                                                          top_k=1),
                logit_bias=completion.LogitBiases(
                    biases=[completion.LogitBias(tokens=toTokens([46256]),
                                                 bias=5.0)])),
            engine_params=completion.EngineParams(
                best_of=0, logprobs=8, completions=1, max_tokens=25,
                stop=[completion.Prompt(text="⁂"),
                      completion.Prompt(text="\n")]))
        stopped = False
        for answer in completion_grpc.Completion(rq):
            # Nothing may arrive after the response that hit the stop sequence.
            assert not stopped
            choice = answer.choices[0]
            if choice.text == "⁂":
                assert choice.finish_reason == completion.FinishReason.STOP
                stopped = True


class TestBMKCase(unittest.TestCase):
    """Greedy (temperature 0.0) completion of a math word problem, dumped
    for manual inspection.

    Nothing is asserted. For each response, this prints the first chosen
    token id and the text, then every top_before candidate. Finally it prints
    the whole generated text with newlines escaped.
    """

    def test_bmk_case(self):
        # An alternative form of this prompt separated the sentences with
        # three spaces instead of one.
        rq = completion.Request(
            prompt=[completion.Prompt(
                text=" Problem: Tara rolls three standard dice once. What is the probability that the sum of the numbers rolled will be three or more? Express your answer as a percent.\nAnswer:")],
            engine_params=completion.EngineParams(
                logprobs=10,
                max_tokens=32,
                stop=[completion.Prompt(text="\n\n"),
                      completion.Prompt(text="\n")]),
            model_params=completion.ModelParams(
                sampling_params=completion.SamplingParams(
                    temperature=0.0)))
        pieces = []
        for resp in completion_grpc.Completion(rq):
            choice = resp.choices[0]
            first_token_id = choice.logprobs.tokens.logprobs[0].token.id
            print("CHOICE: [" + str(first_token_id) + "] |" +
                  choice.text + "|")
            pieces.append(choice.text)
            for logprobs in choice.logprobs.top_before:
                for logprob in logprobs.logprobs:
                    # Escape newlines so each candidate stays on one line.
                    print("LOGPROB: [" + str(logprob.token.id) + "] |" +
                          logprob.token.text.replace("\n", "\\n") + "|",
                          logprob.logprob_before)
            print("====")
        print("".join(pieces).replace("\n", "\\n"))

# To run one test case directly, uncomment and pick a class, e.g.:
# if __name__ == '__main__':
#    suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestGRPCBasicRequest)
#    unittest.TextTestRunner().run(suite)
