import pytest
from dotmap import DotMap
from icecream import ic

import lm_node.base
from lm_node.sanitize import *

req_params = {
    "top_k": 1,
    "top_p": 0.1,
    "temperature": 0.1,
    "min_length": 1,
    'max_length': 1,
    'repetition_penalty': 1.105,
    "repetition_penalty_range": 2,
    'tail_free_sampling': 1,
    "pad_token_id": 50256,
    "use_cache": True,
    "do_sample": True,
    "generate_until_sentence": False,
}

config = DotMap()
config.model_type = "GPT"
config.model_path =  "/models/j6b_ckpt_14001"
config.model_name = "6B-v3" #2.7B, 6B, 6B-v3
config.prefix_path = "/models/defaultmodules"
config.user_module_path = "/usermodules"
config.deepspeed_enabled = False

model = lm_node.base.GPTModel(config)

text2 = "9AR8dRkAPQGgEU0RdQEfARwFRgA2Cw0AxgAuClIcAw8ZAHgCdQEBAQkXDwUNADhtdhQwAYMjQAMhBocRHgHUKWZMCwAfAiwBsBgiAeUF4MAfAQEBOChAJKeOoAGAAg0APQF5BigJ+wTmCgYBVBKHHhMPCwAGAU0RpRkNAHgCdQEVMOIDGgABAQQGLgL+AwYBDwVGAXUBIgFGAR8C4QEzAQsAAQEdCh4B3wNUBB8BBgE1Ph4BfBgNAH8BTRF1AZcMCwDbAZcBCwJxA1QBdQHcFBwB/gINAOmOCAP2AXUBBgEhJwsAeFMiAWQBviFfAa8JCwDVA6McKRuAd30BxwFSAJoKAQEmCZED5WINAD0BDAioSDwYY4E8AekFBAQGAdsGbhoLAJcCHRVcDQAFHAFWHgEBegUDGw0Auw5IEh4FGQi8K+IDFAJJAVsOBgGXd3UBAQEVFwsA3zUcEkoADQB/AbMC7QUAHXUBBgETDyoJCwAiAQYBegMLAD4xkyBGAaZKCgULAAoFWw4NAMYA4h/GAPQEfHUZADoBv2cfAaoE4wMNAMYALgpSHAMPGQDDBdozxJagAR8BXAHpXtVvCwCyFyIBUYUcGz019gEfAQEBV3NYBPpOiwENANsH/zELAGgCTQOTEg0DflT+AxkBkgOwKz8EDwMiAXa6tTwNAJwBrALkEpAClAOjBAsA+AWQAjaXCwBJB1QB1AGFAiAODQAgKdQBIA4NACYFPgGPAp4eDQCsDAsAcw0LAEwCAQFxGotVYgkZAScZVzsLAN4FPgGNKg0AfwEUDB4Bji4+AQUCe21tAs0HHAECBVkMDQCsn4UBaRALALJGPgEZGw0AeAI+AY26ewHNAgsAHgZoAtAHPgHaBRwBfDMGAUsnjDEcAVAYDQA5BGgCshehBdoDJSAeAOEhCwCIEDoBzAILAAkDhQFoAs0CFBwLAEkBYwhJmGgCdiZbDoQCOgHyC/UaDQDqVgsAZAIPAwoIAQGeHiIBAQGyRmMBAQFXApUhuxFAA/YBCwBuA2gC20StCKMdKQUfAYhaDQDGAOIfxgD0BHx1GQCZAzUDAQG+KR8BAQEQDdUDAQG2g1Up7AHGAC4KUhwDDxkAmQOFAcoatjYeAVJC0BINAH8BUy54Dx4BAQF5J1wBORkSAUVCCwC9TyMDHwEBAaJsFSFGAQAy7Bg5f7UHBgGoNA0AfwFQCIp2QAMWAmgsYwFZAS6EYgFtCLoLCwDXCPwtpgEGAWGDDQBODZ4Eex9fARYDcAsaAFQBUgFMAlII1AEdCAEBwjwcAQYBUQoeARYCrzgNAG8B4lDaCHkGxRy8AwYBSHhWKNQGFAIRIyIBYx0KAXYCwAJfAXpADQBZJVQB3QE+AQEB5BwoCRoAAQG/Ak8FNAHoA0YB6U4cAWotWQGmARYCKQ4NAIcXvU9aAl029wENAAIDgAILAAYBxy4+AUMWDQBZJQEBKwTrAgwA7EwWAYi49QEeAesB2C3sIxwBBgHLHQsAWQH+AgEBygGWCR4BYx3kHAsAVAMMRSsNBgFmCfpOiwENAD0BmgRpECAGCwAGARMJB3UAZQ0AeAJSASgGHAHfM1kB4AIcARYC20QNAMYA4h/GAPQEfHUZAHgCUgEBAYoiVQYiAhkBFC6mAaoE4wMNAH8BKSiFAZIilQL8AYURHAEGAUsRDQDGAC4KUhwDDxkAfwFuGmITHAF6Iw0AHASeCYWaGgH6TosBK3K1BwYBCjsLACIBBgErB6EBeo+tCBwBSQgNAK0RGQRbHgSc5AHdcpACBgG2EQsA/wFwRymPbwZDOcUDdQVjAeQBAwUcARILBgFNA+MD5A5eAg0A9g1aBhcBNhIZAcI3Jg3PCRoOCwAWCmQPoAHoiu8BrR8eAXojIgF6QA0A9w7CYwd1BgHjA2MBBgE6DR4Bmh3vdvsJHAF/Bg0AziWmAQYBWHJfitQBKiQcAaofQAMaAAYBwhLmHx4BQTwLABapqIseAbAaIgHoSwsAIgEGAbeEhREeAbVo82cfAREQKgNkAnUBhQImJB4BdAPbBG4D2gMeGWQC0Q5fCw0AGQljAR8BBgGECNkbCwCsAYGb1AEVB4URHAEBAYkClhQeASkoFAKqBSIBugX8AacBHB0GAWYJeiMLACIBniBfAd0B/wE0ChwBcwK3AlQBXwEcBcMPeQYNACA/eQcLAEkBVAE+AbMCXwEGASkXahAeAaICCT1GAWQB5wJJB7oerAHVDp2eCwAiAdIEGQF2AxwBmh3vdlIB8QKFXmIJogLjAw0AxgDiH8YA9AR8dRkAgwsMAPMBs25VBiIClQIfAaAp/AGnARkBDm9JAbwaDQDGAC4KUhwDDxkAfwGVAh4BoCk7CDwCBgHQrR4BBgEyCeMDDQCtEWMDPgdqDgYBvLgLAKAp1AFLAuUILwWgAQYBJg+0Dg0AKhIeAXMD/AEbLsws/RwcAf4ZBgHvjx4BSkP+B6OWDQCTMSMCMh/ZBaABgwYtDxYBIVgiAaZ10yqIMg0Alx4+BwoeHAFcCtIECwAtHkYBmQ5pAoizHgGfDbQDMQKhBRwB/wEEUDIWDQDeGCENRAQWAQkD8lOFAf8BvBoLABkE4b6tHx4B+6emSA0ACAmqVZoGUxJjATEJVxILACIBIA1jAat4SQGgKVIBXVINAD4d/wFqECIBThGNHXAapgHaDBwB2gwLAOQBNQMBAfQjfAIeAQYBBwgUAuAVCwBxAx4BoClSAYAPVJVABF4CDQDGAOIfxgD0BHx1GQBWAy4vXQLXBlIBVQYiAsnBowfyA7cGAQGDiw0AxgAuClIcAw8ZAHgCdQEBAVYTVgQfASEg7g5QXgsAKgPdAQYB/B5XEjECAgUcAccCDQCTCWW4ugFXAaQjsQ/ocEVCCwAiAQEBGA7Zb9MZFAFAAwYB7hsNAHIjDQAkfEEQPnD3AaYBBgElkgsAXwLWHgoBFkABRw0AbgECBD4BAQGaA1YECwB1BykCoQKmYeN9CwAcAR8CAQF6BeEOCCIZAG4BAQZNCNYBVAG4BaUGciMNACR8dQH/GRwB/gLqBg0AeKZlAkJJ9wGmAS45AQFobw0AbgFnBjUDCgWmAQYBU0ELAOoGDQB4pqoJciMNACR8cxkNAOoGDQB4puU0CwAiAVsYQAMGAe4bDQBuAWcGLhHWAaYLHgCfBQYBVxKFAbcGAQGDi6wBVgSqCYkqKAYLAGoCciMNACR8VAtfAp4ECwChAsEJAQEZBA9IqB/3AbUHBgGYQw0AnwUGAVcSIwLcK2QCCwAiAaECsQt1BzECMwEgCpoDAADGAOIfxgD0BHx1GQAuKAwAMykNAH8BRwGLAU0BmgL2RAsAAQHQD4wKHgHfQgsApwEBASwaDACmMukBGwEeAQYBeiz7HlwgcwxJAQEB8LkNAH8B+x5cIIUBoAijFgsAIgH4ArMCBgEBBwwAXD1SARoLDQDGAC4KUhwDDxkAWClgBlEFPBh7AQYBXA2AHKMaCwAiAWQB3QE8GFkTDQB/AWEGdQGSGV8B9l0UAuwSHgECAiYCcgQSAh4BBgH7HlwgbgMLAB0FbA0BAdkzLBoMAKYy6QEbAQ0AVzgyAe4CVQKnAQYBLwscAQQDtApfAV4CCwAiAUkHuWfSA04coQEmAq8d1ySsARICHAFzDAICSQEBAfC5DQBeBaEC1TKsAToB5wLXBPgCDQDGACAAFksmAYaEw1amAVsOBgFcDQ0AfRUeAQICuCsUAf4BCwAiAVgpYAZRBWkLAgISAgUCAmckA2MBYwIcAY4DbgFGBUY3hGl3Bz0BHQogBgsABgFcDXESDQAmBVMBdQELAAEBDg4MAAoXHxgeAQEBRgINABIeDAC3AdYCFAGIICqV0QFUJhAFCwAVxB8BAQFKAxoBSxA1Xg0AXwkCZ4cLT8R7AWMEHgECAh8BugQLABQDZQIsBCICBgFHCAsAYwFjApUYegRrXgYBxQcNAKwMCwBYKWAGUQUmFA0AxgABANEoWQGTBgsCcQNJAc8lAgILAGURLBoMAKYy6QEbAa4dmgIjEqIBKQKhAogLEwJfAQEBWyULACIBYwQeAQIC7gILAmMBcAMNAG4B0AFHAYsBTQGaAvZEhQEWUlsOogJNERwBpwFZAV8BAgKsAdsGDgLGAOIfxgD0BHx1GQA9AR8YRrBqDQEBvjALBVYA2gYNAMYALgpSHAMPGQB/AeoPUgGeKcWHehgcAdEBVhULACIB5AGTjxwBiwIGAfK8YgMNAH8BHxhGsAsAbggLAD4BkGMNANsY6KB3t/cBCwB/X7UoHwEGAVxcIgEwHCx1tBwNAFkHCKnSOxwBsBXRBh4BAQGDD3doCwDbAQYBRrDJKqsCYxUNAA=="

honkers = [
    [True, "Nestled between the winding dunes", " of", [' of', 17.65625]],
    [False, text2, 'fwGeKT4BKRAcAdca6g8LACIB', [' The', 17.890625]]
]

def test_gptj_6b_generate():
    req_dict = DotMap()
    req_dict.parameters = DotMap(req_params)

    for gen in honkers:
        req_dict.input = gen[1]
        output = process_payload(req_dict, config, gen[0], model)
        assert output[0]

        gen_text = model.generate(req_dict.input, req_params=req_dict.parameters, use_string=gen[0])
        ic(gen_text)
        assert gen_text == gen[2]


def test_gptj_6b_hidden_states():
    req_dict = DotMap()
    req_dict.parameters = DotMap(req_params)

    for gen in honkers:
        req_dict.input = gen[1]
        output = process_payload(req_dict, config, gen[0], model)
        assert output[0]

        gen_text = model.get_hidden_states(req_dict.input)
        assert gen_text


def test_gptj_next_token():
    req_dict = DotMap()
    req_dict.parameters = DotMap(req_params)

    for gen in honkers:
        req_dict.input = gen[1]
        output = process_payload(req_dict, config, gen[0], model)
        assert output[0]

        gen_text = model.next_token_probabilities(req_dict.input)
        assert gen_text[0][0] == gen[3][0]
        assert gen_text[0][1] == pytest.approx(gen[3][1])


def test_tokenizer():
    target_text = """Lorem ipsum dolor sit amet, consectetur adipiscing elit. Quisque pellentesque dui ut mauris vehicula, sed eleifend elit porttitor. Vivamus facilisis sapien et massa tempor, sed euismod erat lacinia. Quisque iaculis pulvinar blandit. Sed quis nunc erat. Aliquam erat erat, pretium at rhoncus convallis, faucibus vitae arcu. Duis at rutrum ligula. Donec augue neque, sollicitudin non vehicula vel, eleifend eget augue. Aenean id enim aliquam, dictum mauris suscipit, egestas elit. Etiam gravida, enim ut convallis sagittis, sem urna pellentesque mi, sit amet hendrerit nisl magna vitae neque. Nunc varius purus sed risus vestibulum laoreet. Aliquam vitae ante lectus. Mauris imperdiet a justo non finibus. Pellentesque congue mi a ligula condimentum bibendum. Quisque vitae turpis interdum lorem dignissim bibendum congue sed justo. Proin ligula est, eleifend eu tellus id, tempus molestie purus. Aliquam imperdiet neque quis purus euismod euismod. Curabitur molestie tempus felis, in suscipit est tristique ac. Fusce euismod, ipsum non fermentum lacinia, mi dolor viverra lectus, id rhoncus augue leo nec lacus. Aenean pharetra fringilla est eu suscipit. Nulla sed turpis convallis, vestibulum nisl et, posuere tellus. Pellentesque nec eleifend neque. Aenean est sem, tempor quis purus et, varius interdum justo. Fusce sit amet ex pulvinar, imperdiet urna a, hendrerit massa."""
    target_tokens = [43,29625,220,2419,388,288,45621,1650,716,316,11,369,8831,316,333,31659,271,2259,1288,270,13,2264,271,4188,613,297,298,28939,7043,72,3384,285,2899,271,2844,291,4712,11,10081,9766,361,437,1288,270,2493,83,2072,13,25313,25509,1777,346,271,271,31841,2013,2123,2347,64,10042,11,10081,304,84,1042,375,1931,265,31123,43168,13,2264,271,4188,1312,330,377,271,17472,7114,283,34377,270,13,22710,627,271,299,19524,1931,265,13,978,1557,321,1931,265,1931,265,11,2181,1505,379,9529,261,9042,3063,439,271,11,277,14272,26333,9090,3609,10389,84,13,10343,271,379,374,315,6582,26106,4712,13,24429,66,16339,518,497,4188,11,523,297,3628,463,259,1729,2844,291,4712,11555,11,9766,361,437,304,1136,16339,518,13,317,1734,272,4686,551,320,435,1557,321,11,8633,388,285,2899,271,2341,66,541,270,11,304,3495,292,1288,270,13,17906,1789,9067,3755,11,551,320,3384,3063,439,271,45229,715,271,11,5026,220,700,64,613,297,298,28939,21504,11,1650,716,316,339,358,260,799,299,3044,2153,2616,9090,3609,497,4188,13,399,19524,1401,3754,1308,385,10081,6106,385,19750,571,14452,8591,382,316,13,978,1557,321,9090,3609,29692,11042,385,13,18867,271,11071,67,1155,257,655,78,1729,957,26333,13,43113,298,28939,369,18701,21504,257,26106,4712,1779,3681,388,275,571,43755,13,2264,271,4188,9090,3609,7858,79,271,987,67,388,24044,76,13469,747,320,275,571,43755,369,18701,10081,655,78,13,1041,259,26106,4712,1556,11,9766,361,437,304,84,1560,385,4686,11,20218,385,18605,395,494,1308,385,13,978,1557,321,11071,67,1155,497,4188,627,271,1308,385,304,84,1042,375,304,84,1042,375,13,4424,29968,333,18605,395,494,20218,385,10756,271,11,287,2341,66,541,270,1556,491,396,2350,936,13,376,385,344,304,84,1042,375,11,220,2419,388,1729,38797,388,31123,43168,11,21504,288,45621,410,1428,430,11042,385,11,4686,9529,261,9042,16339,518,443,78,27576,31123,385,13,317,1734,272,872,8984,430,1216,278,5049,1556,304,84,2341,66,541,270,13,35886,64,10081,7858,79,271,3063,439,271,11,19750,571,14452,299,3044,2123,11,1426,518,260,1560,385,13,43113,298,28939,27576,9766,361,437,497,4188,13,317,1734,272,1556,5026,11,10042,627,271,1308,385,2123,11,1401,3754,987,67,388,655,78,13,376,385,344,1650,716,316,409,17472,7114,283,11,11071,67,1155,220,700,64,257,11,339,358,260,799,2347,64,13]
    tokens = model.tokenizer(target_text).input_ids
    assert tokens == target_tokens

    text = model.tokenizer.decode(target_tokens)
    assert text == target_text

if __name__ == '__main__':
    test_tokenizer()