/* eslint-disable array-callback-return */
// This file includes code which was modified from https://github.com/openai/gpt-2
// ... and was then further modified from https://github.com/latitudegames/GPT-3-Encoder

import { normalizeModel, TextGenerationModel } from '../../shared/data/request/model'
import { fetchWithTimeout } from '../../shared/util/general'
import pile_tokenizer from './files/pile_tokenizer.json'
import genji_tokenizer from './files/genji_tokenizer.json'
import gpt2_tokenizer from './files/gpt2_tokenizer.json'

/**
 * Identifies which tokenizer flavour a text-generation model uses.
 * Numeric values are spelled out so they stay stable if members are reordered.
 */
export enum EncoderType {
    /** Fallback returned by getModelEncoderType for models without a dedicated case. */
    GPT2 = 0,
    PileNAI = 1,
    /** Returned by getModelEncoderType for the genjijp6bv2 model. */
    Genji = 2,
    /** Returned by getModelEncoderType for the krakev2 model. */
    Pile = 3,
    /** Returned by getModelEncoderType for the infill model. */
    NAIInline = 4,
}

/**
 * Picks the tokenizer flavour for a text-generation model.
 *
 * @param model - model identifier; it is passed through `normalizeModel` before matching.
 * @returns the dedicated encoder type for Genji, Krake v2 and infill, otherwise `EncoderType.GPT2`.
 */
export function getModelEncoderType(model: TextGenerationModel): EncoderType {
    const normalized = normalizeModel(model)
    if (normalized === TextGenerationModel.genjijp6bv2) {
        return EncoderType.Genji
    }
    if (normalized === TextGenerationModel.krakev2) {
        return EncoderType.Pile
    }
    if (normalized === TextGenerationModel.infill) {
        return EncoderType.NAIInline
    }
    return EncoderType.GPT2
}

const textEncoder = new TextEncoder()
/**
 * UTF-8 encodes `str` and returns each byte as its decimal string
 * (e.g. 'hi' -> ['104', '105']).
 */
const encodeStr = (str: string) => Array.from(textEncoder.encode(str), (byte) => byte.toString())

const textDecoder = new TextDecoder('utf8')
/** Interprets the given byte values as UTF-8 and returns the decoded string. */
const decodeStr = (arr: Iterable<number>) => {
    const bytes = Uint8Array.from(arr)
    return textDecoder.decode(bytes)
}

/**
 * Builds an object that maps x[i] to y[i] for every index of `x`.
 * Keys go through normal property-key coercion, so an array key such as
 * ['a', 'b'] is stored under the string 'a,b'.
 */
const dictZip = (x: any, y: any) => {
    const zipped: any = {}
    x.forEach((key: any, index: any) => {
        zipped[key] = y[index]
    })
    return zipped
}

/**
 * Returns the integers 0..y-1 with the first `x` entries sliced off,
 * i.e. [x, x + 1, ..., y - 1] for 0 <= x <= y.
 */
const range = (x: number | undefined, y: any) => {
    const counted = Array.from({ length: y }, (_, index) => index)
    return counted.slice(x)
}

/** Returns the UTF-16 code unit of the first character of `x` (NaN for an empty string). */
const ord = (x: string) =>
    // eslint-disable-next-line unicorn/prefer-code-point
    x.charCodeAt(0)

/** Returns the one-code-unit string for the UTF-16 code unit `x`. */
const chr = (x: number) =>
    // eslint-disable-next-line unicorn/prefer-code-point
    String.fromCharCode(x)

/**
 * Collects every adjacent [left, right] symbol pair of `word`, in order.
 *
 * Each pair is a fresh array, and a Set compares arrays by identity, so
 * repeated pairs are NOT collapsed; the Set only serves as an ordered container.
 * A word with fewer than two symbols yields an empty Set.
 */
function get_pairs(word: any[]) {
    const adjacent = new Set<any>()
    for (let right = 1; right < word.length; right++) {
        adjacent.add([word[right - 1], word[right]])
    }
    return adjacent
}

/**
 * Builds the byte -> character table used by the byte-level BPE.
 *
 * Bytes in the ranges '!'..'~', '¡'..'¬' and '®'..'ÿ' map to the character
 * with the same code. Every other byte (0..255) is assigned, in ascending
 * byte order, the next unused code starting at 256.
 *
 * @returns an object keyed by byte value whose values are single-character strings.
 */
const bytes_to_unicode = () => {
    const selfMapped = [
        ...range(ord('!'), ord('~') + 1),
        ...range(ord('¡'), ord('¬') + 1),
        ...range(ord('®'), ord('ÿ') + 1),
    ]

    const table: any = {}
    for (const byte of selfMapped) {
        table[byte] = chr(byte)
    }

    // Remaining bytes are shifted above the single-byte range, in ascending order.
    let shifted = 0
    for (let byte = 0; byte < 2 ** 8; byte++) {
        if (!(byte in table)) {
            table[byte] = chr(2 ** 8 + shifted)
            shifted += 1
        }
    }
    return table
}

// byte value -> stand-in character
const byte_encoder = bytes_to_unicode()
// stand-in character -> byte value (kept as the decimal string key of byte_encoder)
const byte_decoder: any = {}
for (const [byte, symbol] of Object.entries(byte_encoder)) {
    byte_decoder[symbol as string] = byte
}

// Pre-tokenization pattern, equivalent to the literal
//   /'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+/gu
// It splits text into English contractions, runs of letters, runs of digits,
// runs of other non-space characters (each optionally preceded by one space),
// and runs of whitespace.
// Because the pattern is passed as a string, every regex backslash must be
// doubled; a single backslash would be consumed by the string literal and
// `\p`, `\s`, `\S` would silently turn into the plain letters p, s and S.
// prettier-ignore
const pat = new RegExp("'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+", "gu")

/**
 * Byte-level BPE tokenizer.
 *
 * Text is first cut at "added tokens" (literal strings mapped straight to an
 * id), the remaining text is split with `pat`, every piece is UTF-8 encoded and
 * mapped through `byte_encoder`, and the result is merged according to the
 * ranked merge list before being looked up in the vocabulary.
 */
export default class Encoder {
    /** Vocabulary: BPE symbol string -> token id. */
    encoder: any
    /** Merge priorities: 'left,right' -> index of that merge in the merges list (lower wins). */
    bpe_ranks: any
    /** Reverse vocabulary: token id -> symbol string (added tokens included). */
    decoder: any
    /** Literal strings that map directly to a token id without going through BPE. */
    addedTokens: any
    /** Memoizes `bpe` results per byte-encoded token. */
    private cache = new Map<string, string>()

    /**
     * @param encoder - vocabulary object mapping symbol strings to token ids.
     * @param bpeArr - merge rules; each entry holds the symbols of one merge separated by whitespace.
     * @param addedTokens - optional map of literal string -> token id; treated as empty when omitted.
     */
    // eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
    constructor(encoder: any, bpeArr: string[], addedTokens?: any) {
        // The parameter is optional, so fall back to an empty map instead of
        // letting Object.keys(undefined) throw here and in `encode`.
        this.addedTokens = addedTokens ?? {}
        this.encoder = encoder

        const bpe_merges = bpeArr.map((line) => line.split(/(\s+)/).filter((piece) => piece.trim().length > 0))
        // dictZip coerces each [left, right] array key to the string 'left,right'.
        this.bpe_ranks = dictZip(bpe_merges, range(0, bpe_merges.length))

        this.decoder = {}
        for (const symbol of Object.keys(encoder)) {
            this.decoder[encoder[symbol]] = symbol
        }
        for (const literal of Object.keys(this.addedTokens)) {
            this.decoder[this.addedTokens[literal]] = literal
        }
    }

    /**
     * Applies the ranked merges to one byte-encoded token.
     *
     * @param token - token whose bytes were already mapped through `byte_encoder`.
     * @returns the merged symbols joined by single spaces.
     */
    private bpe(token: string): string {
        const cached = this.cache.get(token)
        if (cached !== undefined) {
            return cached
        }

        let word: string[] = [...token]
        let pairs = get_pairs(word)

        // A Set is always truthy, so emptiness must be checked through `size`.
        if (pairs.size === 0) {
            return token
        }

        for (;;) {
            // Pick the adjacent pair with the lowest merge rank. A plain loop is
            // used so long tokens never spread a huge list into Math.min.
            let bigram: string[] | undefined
            let lowestRank = Number.POSITIVE_INFINITY
            for (const pair of pairs) {
                // Same key form as the array-to-string coercion used when building bpe_ranks.
                const rank = this.bpe_ranks[pair.join(',')]
                if (typeof rank === 'number' && rank < lowestRank) {
                    lowestRank = rank
                    bigram = pair
                }
            }

            // No remaining pair is a known merge.
            if (bigram === undefined) {
                break
            }

            const [first, second] = bigram
            const merged: string[] = []
            let i = 0

            while (i < word.length) {
                const j = word.indexOf(first, i)
                // Copy everything up to the next `first` (or to the end when there is none).
                const stop = j === -1 ? word.length : j
                for (let k = i; k < stop; k++) {
                    merged.push(word[k])
                }
                if (j === -1) {
                    break
                }

                if (j < word.length - 1 && word[j + 1] === second) {
                    merged.push(first + second)
                    i = j + 2
                } else {
                    merged.push(word[j])
                    i = j + 1
                }
            }

            word = merged
            if (word.length === 1) {
                break
            }
            pairs = get_pairs(word)
        }

        const joined = word.join(' ')
        this.cache.set(token, joined)

        return joined
    }

    /**
     * Converts text to token ids.
     *
     * @param text - raw input text.
     * @returns the token ids in order; occurrences of added tokens become their id directly.
     */
    encode = (text: string): number[] => {
        // Cut the text at every added token, replacing each occurrence with its id.
        let parts: (string | number)[] = [text]
        for (const literal of Object.keys(this.addedTokens)) {
            const id = this.addedTokens[literal]
            const next: (string | number)[] = []
            for (const part of parts) {
                if (typeof part !== 'string') {
                    next.push(part)
                    continue
                }
                const pieces = part.split(literal)
                for (const [j, piece] of pieces.entries()) {
                    if (j > 0) {
                        next.push(id)
                    }
                    next.push(piece)
                }
            }
            parts = next
        }

        const tokens: number[] = []
        for (const part of parts) {
            if (typeof part !== 'string') {
                tokens.push(part)
                continue
            }
            for (const match of part.matchAll(pat)) {
                const byteEncoded = encodeStr(match[0])
                    .map((byte) => byte_encoder[byte])
                    .join('')
                // Push one id at a time: spreading a long result into push()
                // can exceed the engine's argument limit on large inputs.
                for (const symbol of this.bpe(byteEncoded).split(' ')) {
                    tokens.push(this.encoder[symbol])
                }
            }
        }

        return tokens
    }

    /**
     * Converts token ids back to text.
     *
     * @param tokens - token ids; ids missing from `decoder` contribute nothing.
     * @returns the decoded string.
     */
    decode = (tokens: any[]): string => {
        const symbols = tokens.map((id) => this.decoder[id]).join('')
        return decodeStr(
            // Characters outside the byte table (possible in added-token text)
            // are passed through as their own UTF-8 bytes.
            [...symbols].flatMap((ch) => byte_decoder[ch] ?? [...textEncoder.encode(ch)])
        )
    }

    /**
     * Lists every vocabulary entry whose symbol string contains `str`.
     *
     * @param str - substring to look for in the (byte-encoded) vocabulary keys.
     */
    tokensContaining = (str: string): { token: string; id: number }[] => {
        return Object.keys(this.encoder)
            .filter((key) => key.includes(str))
            .map((key) => ({ token: key, id: this.encoder[key] }))
    }
}

/** Shape of a bundled tokenizer JSON file as consumed by loadEncoder. */
interface Tokenizer {
    /** Passed to the Encoder constructor as its vocabulary. */
    vocab: any
    /** Passed to the Encoder constructor as its merge list. */
    merges: string[]
    /** Optional added tokens; merged with the caller's extra tokens. */
    addedTokens?: any
}

/**
 * Creates an Encoder from one of the tokenizer files bundled with this module.
 *
 * @param url - 'pile_tokenizer.json' or 'genji_tokenizer.json'; any other value selects the GPT-2 tokenizer.
 * @param extraTokens - additional added tokens; they override same-named entries from the tokenizer file.
 * @returns a promise resolving to the new Encoder.
 */
//not sure if this is a good idea
export const loadEncoder = async (url: string, extraTokens: any): Promise<Encoder> => {
    let tokenizer: Tokenizer = gpt2_tokenizer
    if (url === 'pile_tokenizer.json') {
        tokenizer = pile_tokenizer
    } else if (url === 'genji_tokenizer.json') {
        tokenizer = genji_tokenizer
    }
    const addedTokens = { ...tokenizer.addedTokens, ...extraTokens }
    return new Encoder(tokenizer.vocab, tokenizer.merges, addedTokens)
}
