import os
import struct
import sys
from typing import Generator, Any

from multiprocessing import Pool

import numpy as np

from lm_node import unitrim
from lm_node.base import split_chunks
from lm_node.utils.text import read_data_files
from transformers import AutoTokenizer


def mmap_serializer(path: str, token_seq: Generator[np.ndarray, Any, None],
                    size=2048):
    """Write fixed-length token chunks to *path* as a flat unsigned-16-bit array.

    Every chunk yielded by *token_seq* must hold exactly *size* token ids.
    Chunks are packed back to back with struct format ``H`` (unsigned short,
    no byte-order prefix, so native byte order), with no header.  The total
    number of tokens written is printed when done.

    :param path: output file; created or truncated.
    :param token_seq: iterable of token-id sequences, each of length *size*.
    :param size: number of tokens per chunk.
    :raises ValueError: if a chunk does not contain exactly *size* tokens.
    :raises struct.error: if a token id does not fit in an unsigned short.
    """
    # NOTE(review): native byte order means readers must use the same
    # endianness as the writing machine — confirm the consumer expects this.
    # Compile the format once rather than re-parsing it for every chunk.
    packer = struct.Struct(f"{size}H")
    total = 0
    with open(path, 'wb') as f:
        for tokens in token_seq:
            num_tokens = len(tokens)
            # Explicit check instead of `assert`, which `python -O` removes.
            if num_tokens != size:
                raise ValueError(
                    f"expected chunks of {size} tokens, got {num_tokens}")
            total += num_tokens
            f.write(packer.pack(*tokens))
    print(f"{total} tokens serialized")


if len(sys.argv) < 4:
    sys.exit(f"usage: {sys.argv[0]} ROOT_DIR MODEL_PATH CHUNK_SIZE")

# Command-line arguments: the directory whose subdirectories are serialized,
# the model location holding tokenizer/unitrim data, and tokens per chunk.
root_dir = sys.argv[1]
model_path = sys.argv[2]
chunk_size = int(sys.argv[3])

# Prefer the tokenizer stored at model_path; fall back to stock GPT-2 when it
# cannot be loaded.  `local_files_only=True` is the `from_pretrained` keyword
# that keeps the lookup from going to the network.
try:
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, cache_dir="./cache", local_files_only=True)
    print(f"Using tokenizer data from {model_path}.")
except Exception as err:  # best-effort: any load failure means fallback
    print(f"Could not load tokenizer from {model_path}: {err}")
    print("Falling back to default `gpt2` tokenizer.")
    tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir="./cache")

# NOTE: this deliberately rebinds the imported module name `unitrim` to a
# Unitrimmer instance; do_dir passes the global `unitrim` to split_chunks and
# relies on it being the instance.
unitrim_model_path = os.path.join(model_path, "unitrim.json")
if os.path.exists(unitrim_model_path):
    print(f"Loading unitrim data from {unitrim_model_path}")
    unitrim = unitrim.Unitrimmer(unitrim_model_path)
else:
    print("Loading unitrim data from default (gpt2)")
    unitrim = unitrim.Unitrimmer()

# Translation table for output file names: drop spaces, commas and periods,
# and turn hyphens into underscores.
trans = str.maketrans({' ': '', '-': '_', ",": '', '.': ''})


def do_dir(dir):
    """Serialize the data read from *dir* into ``<name>.map`` beside it.

    The output file is placed in the parent directory of *dir*.  Its name is
    the directory's basename, lower-cased and passed through the module-level
    ``trans`` table.  The data from ``read_data_files`` is chunked by
    ``split_chunks`` using the module-level ``chunk_size``, ``tokenizer`` and
    ``unitrim``, then written with ``mmap_serializer``.
    """
    # Normalise first: with a trailing separator, basename() would be empty
    # and the output would be a file named just ".map".
    path = os.path.normpath(dir)
    name = os.path.basename(path)
    print(f"==> {name}")
    output_path = os.path.join(os.path.dirname(path),
                               name.lower().translate(trans) + ".map")
    chunks = split_chunks(read_data_files(path),
                          size=chunk_size,
                          tokenizer=tokenizer,
                          unitrim=unitrim,
                          boundary="\n",
                          preamble="<|endoftext|>")
    mmap_serializer(output_path, chunks, size=chunk_size)

if __name__ == "__main__":
    # Only the parent process lists directories and starts the pool.  Under
    # the "spawn"/"forkserver" start methods each worker re-imports this
    # module; without the guard it would try to start a nested Pool.
    # Workers still get tokenizer/unitrim from the unguarded setup above.
    candidates = (os.path.join(root_dir, entry)
                  for entry in sorted(os.listdir(root_dir)))
    todo = [path for path in candidates if os.path.isdir(path)]

    # Up to 16 workers, but no more than there are directories
    # (Pool() rejects 0 processes, hence the lower bound of 1).
    with Pool(max(1, min(16, len(todo)))) as pool:
        pool.map(do_dir, todo)
