import numpy as np
import os
import time  
import concurrent.futures
from tqdm import tqdm
import sys
import cProfile

def chunkit(text, n_chunk=8, max_seek=100):
    """Split *text* into ``n_chunk`` contiguous chunks, cutting at spaces when possible.

    Every cut starts at its ideal position (previous cut plus
    ``len(text) // n_chunk``) and is then moved forward, by at most
    ``max_seek`` characters, until it lands on a space. The space itself
    becomes the first character of the following chunk. The final chunk
    takes whatever text remains, so joining the chunks reproduces *text*
    exactly.

    Args:
        text: The string to split.
        n_chunk: Number of chunks to produce; must be at least 1.
        max_seek: Maximum number of characters a cut may be moved forward
            while looking for a space. ``0`` disables the search.

    Returns:
        A list of exactly ``n_chunk`` strings (some may be empty when the
        text is short or a cut was pushed to the end of the text).

    Raises:
        ValueError: If ``n_chunk`` is smaller than 1.
    """
    if n_chunk < 1:
        raise ValueError("n_chunk must be at least 1")

    total = len(text)
    print("Text size: " + str(total))
    chunk_size = total // n_chunk

    chunks = []
    start = 0
    for x in range(n_chunk):
        if x == n_chunk - 1:
            # The last chunk always runs to the very end so no trailing
            # character is lost.
            end = total
        else:
            # Clamp to the text length: earlier cuts may have been pushed
            # forward far enough that the ideal position is past the end.
            ideal = min(start + chunk_size, total)
            limit = min(ideal + max_seek, total)
            end = ideal
            # Slide forward to the next space, but never further than
            # max_seek characters, so chunk sizes stay roughly balanced.
            while end < limit and text[end] != " ":
                end += 1
        chunks.append(text[start:end])
        start = end

    return chunks
    
def tokenizePart(chunks, iterx):
    """Tokenize one piece of text with the local GPT-2 tokenizer.

    Args:
        chunks: The text handed to the tokenizer. ``run`` calls this
            function through ``executor.map``, so each call receives one
            element of the chunk list.
        iterx: Index paired with the chunk by the caller; not used here.

    Returns:
        The ``input_ids`` produced by the tokenizer, converted to a NumPy
        array and cast to ``uint16``.
    """
    # Imported lazily, inside the function, exactly as the caller expects.
    from transformers import AutoTokenizer

    gpt2_tokenizer = AutoTokenizer.from_pretrained(
        "tokenizer/gpt2",
        local_files_only=True,
    )
    # Presumably raised so long chunks are not flagged as exceeding the
    # model's maximum length -- TODO confirm.
    gpt2_tokenizer.model_max_length = 99999999

    encoded = gpt2_tokenizer(chunks, return_tensors="pt")
    token_ids = encoded.input_ids.numpy()
    return token_ids.astype("uint16")

def run(f, iterx, chunks):
    """Apply *f* to every (chunk, index) pair on a thread pool.

    Args:
        f: Callable invoked as ``f(chunk, index)`` for each pair.
        iterx: Sized iterable of indices, paired element-wise with
            *chunks*; its length is used as the progress-bar total.
        chunks: Iterable of chunks to process.

    Returns:
        A list with the result of each call, in the order of the inputs.
    """
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Use the callable that was passed in rather than a hard-coded
        # function, so the ``f`` argument is actually honoured.
        # executor.map yields results in input order regardless of which
        # worker finishes first.
        outcomes = executor.map(f, chunks, iterx)
        results = list(tqdm(outcomes, total=len(iterx)))
    return results

def main():
    """Tokenize ``10mb.txt`` in chunks and write the token ids to ``test.gay``.

    Reads the whole input file, splits it into 8 chunks with ``chunkit``,
    tokenizes the chunks through ``run``/``tokenizePart``, horizontally
    stacks the resulting arrays, prints the elapsed time and final shape,
    writes the stacked array's raw bytes to ``test.gay`` and exits with
    status 0.
    """
    # Context manager so the input handle is closed as soon as the text
    # has been read instead of staying open for the rest of the run.
    with open("10mb.txt", "r", encoding="utf8") as f:
        text = f.read()

    n_chunk = 8
    chunks = chunkit(text, n_chunk)
    iterx = range(n_chunk)

    s = time.perf_counter()
    tokens_array = run(tokenizePart, iterx, chunks)
    tokens_out = np.hstack(tokens_array)
    # NOTE: n_chunk is the number of chunks; the executor created in run()
    # uses its default worker count, which may differ from this number.
    print("Took " + str(time.perf_counter() - s) + " seconds at " + str(n_chunk) + " threads")
    print("Train complete with map array size of:")
    print(tokens_out.shape)

    with open("test.gay", "wb") as fh:
        fh.write(tokens_out.tobytes())
    sys.exit(0)

# Script entry point: run the pipeline only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()