import io

import json
import os
from typing import Sequence

# Absolute directory containing this module. abspath() is applied first
# because __file__ can be a bare relative name (for example a script launched
# from its own directory on Python versions before 3.9); dirname() alone would
# then return "" and appending "/<name>" to it would point at the filesystem
# root instead of at this directory.
curr_path = os.path.dirname(os.path.abspath(__file__))

# Unitrimmer checks whether a token sequence is ready to be sent to the
# frontend, and trims tokens that would leave the sequence incomplete.
class Unitrimmer:
    """Decide where a token sequence can be cut without leaving it incomplete.

    The instance holds ``unicode_req``, a table loaded from JSON that is
    indexed by token id and yields an integer "requirement" per token:

    * a positive value raises the number of outstanding units,
    * a negative value lowers it (but never below zero),
    * zero leaves it untouched.

    A sequence is complete at every position where the outstanding count is
    zero. NOTE(review): the requirement presumably counts the remaining bytes
    of a multi-byte UTF-8 character in a byte-level BPE vocabulary -- confirm
    against the generator of ``unitrim.json``.
    """

    # Built with os.path.join on an absolute directory so the default never
    # collapses to "/unitrim.json" when __file__ is a bare relative name.
    _DEFAULT_PATH = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "unitrim.json"
    )

    def __init__(self, path_or_file=_DEFAULT_PATH):
        """Load the requirement table.

        Args:
            path_or_file: Either a filesystem path to a JSON file, or an
                already-open file-like object (anything derived from
                ``io.IOBase`` or exposing ``read()``) positioned at the JSON
                document. Defaults to ``unitrim.json`` next to this module.
        """
        if isinstance(path_or_file, io.IOBase) or hasattr(path_or_file, "read"):
            self.unicode_req = json.load(path_or_file)
        else:
            # JSON text is UTF-8; do not depend on the platform locale.
            with open(path_or_file, "r", encoding="utf-8") as fh:
                self.unicode_req = json.load(fh)

    def serialize(self) -> str:
        """Return the requirement table encoded as a JSON string."""
        return json.dumps(self.unicode_req)

    def get_token(self, tkn: int) -> int:
        """Return the requirement stored for ``tkn``, or 0 if it has none.

        Lookup failures (id out of range, missing key, unusable index type)
        are treated as "no requirement". Only those errors are caught, so
        KeyboardInterrupt, SystemExit and unrelated bugs still propagate.
        """
        try:
            return self.unicode_req[tkn]
        except (IndexError, KeyError, TypeError):
            return 0

    def send_ready(self, tokens: Sequence[int]) -> bool:
        """Return True when ``tokens`` ends with nothing outstanding.

        An empty sequence is ready. A token whose requirement is zero clears
        any outstanding count, so a malformed run cannot block sending
        forever.
        """
        need = 0
        for tkn in tokens:
            req = self.get_token(tkn)
            if req == 0:
                # Reset to avoid being stuck when invalid unicode is being
                # generated.
                need = 0
            elif need + req >= 0:
                # Requirements that would drive the count negative are
                # ignored rather than applied.
                need += req
        return need == 0

    # trim_front should be off for nodes
    def trim(self, tkns: Sequence[int], trim_front=False) -> Sequence[int]:
        """Return the longest prefix of ``tkns`` that ends with nothing outstanding.

        Args:
            tkns: Token ids; must support ``len()``, indexing and slicing.
            trim_front: When true, first drop the leading run of tokens whose
                requirement is negative.

        Returns:
            A slice of ``tkns`` (same container type as the input).

        NOTE(review): unlike ``send_ready``, a zero requirement does not clear
        an outstanding count here, so everything after an unfinished run is
        dropped. Kept as-is because callers may rely on it -- confirm intent.
        """
        if trim_front:
            start = 0
            total = len(tkns)
            while start < total and self.get_token(tkns[start]) < 0:
                start += 1
            tkns = tkns[start:]

        good = 0  # length of the longest complete prefix seen so far
        need = 0
        for idx, tkn in enumerate(tkns):
            req = self.get_token(tkn)
            if need + req >= 0:
                need += req
            if need == 0:
                good = idx + 1
        return tkns[:good]

if __name__ == "__main__":
    # Manual smoke test: prints trimming results for a few inputs.
    demo = Unitrimmer()

    # Raw token ids, trimmed before the tokenizer is even imported.
    print(demo.trim([251, 447, 251, 447], trim_front=True))

    # Imported here so the module itself does not depend on transformers.
    from transformers import AutoTokenizer

    gpt2 = AutoTokenizer.from_pretrained("gpt2")
    sample = "猫です。"
    print("reference", gpt2.encode(sample))

    # (label, slice applied to the encoded sample, trim_front flag).
    # "corrupt end" drops the last token; "corrupt both" drops first and last.
    scenarios = (
        ("corrupt end, true", slice(None, -1), True),
        ("corrupt both, true", slice(1, -1), True),
        ("corrupt end, false", slice(None, -1), False),
        ("corrupt both, false", slice(1, -1), False),
    )
    for label, window, front in scenarios:
        kept = demo.trim(gpt2.encode(sample)[window], trim_front=front)
        print(label, kept, gpt2.decode(kept))