from datetime import datetime, timezone
from types import NoneType
from typing import List, Union, Optional

import tiktoken
from pydantic import BaseModel, ConfigDict, Field, field_validator

from scraper.card_types.item_extensions import ItemExtensions, CharArchiveExtension
from scraper.card_types.st_timestamp import parse_st_create_date_timestamp


class LorebookV2Entry(BaseModel):
    """A single entry of a LorebookV2.

    Only ``keys``, ``content`` and ``enabled`` are required; every other field
    has a default. The field set appears to mirror the Character Card V2
    ``character_book`` entry layout (NOTE(review): confirm against the spec).
    """
    keys: List[str]  # presumably the trigger keywords for this entry -- confirm
    content: str  # entry text
    enabled: bool
    id: Optional[int] = None
    # default_factory gives every instance its own fresh dict.
    extensions: dict = Field(default_factory=dict)
    selective: Optional[bool] = False
    insertion_order: Optional[int] = None
    case_sensitive: Optional[bool] = False
    name: Optional[str] = ""
    priority: Optional[int] = 100
    comment: Optional[str] = ""
    # default_factory gives every instance its own fresh list.
    secondary_keys: List[str] = Field(default_factory=list)
    constant: Optional[bool] = False
    # Accepts a string, an integer, or None.
    position: Union[str, int, NoneType] = None


class LorebookV2(BaseModel):
    """A lorebook: book-level settings plus its list of entries.

    ``create_date`` is not a normal lorebook field; it is kept here to follow
    the character format. Incoming ``create_date`` values are passed through
    ``parse_st_create_date_timestamp`` before type validation.
    """

    model_config = ConfigDict(
        # Serialize datetimes as ISO-8601 strings converted to UTC.
        json_encoders={
            datetime: lambda v: v.astimezone(timezone.utc).isoformat()
        },
        # Re-validate (and re-run parse_date) when attributes are assigned.
        validate_assignment=True,
    )

    name: Optional[str] = ""
    description: Optional[str] = ""
    scan_depth: Optional[int] = None
    token_budget: Optional[int] = None
    recursive_scanning: Optional[bool] = False
    extensions: Optional[ItemExtensions] = Field(default_factory=ItemExtensions)
    # Must default to a list: pydantic does not validate default values, so a
    # dict factory here would leave a bare {} where a List is declared.
    entries: List[LorebookV2Entry] = Field(default_factory=list)
    # Not a normal lorebook field, but adding it to follow the character format.
    # NOTE(review): this is a naive datetime for the Unix epoch in local time;
    # the JSON encoder's astimezone() interprets it as local time.
    create_date: Optional[datetime] = datetime.fromtimestamp(0)

    @field_validator('create_date', mode='before')
    @classmethod
    def parse_date(cls, v):
        """Normalize raw create_date input with parse_st_create_date_timestamp."""
        return parse_st_create_date_timestamp(v)


def normalize_lorebook_entries(entries: Union[list, dict]) -> list:
    """Return lorebook entries as a list with a normalized ``position`` key.

    When ``entries`` is a dict (presumably keyed by entry id -- confirm), its
    values are taken in insertion order. Every entry dict then gets a
    ``position`` key: missing or blank values (None, '', other falsy
    non-integer values) become None, while integer positions -- including
    0 -- are preserved.

    Args:
        entries: A list of entry dicts, or a dict whose values are entry dicts.

    Returns:
        A list of entry dicts. A list argument is returned as the same object,
        and the entry dicts are modified in place.
    """
    if isinstance(entries, dict):
        entries = list(entries.values())
    for entry in entries:
        position = entry.get('position')
        # 0 is a real integer position; a bare `not position` would erase it.
        # bool is excluded because False == 0 but is not a meaningful position.
        if position == 0 and not isinstance(position, bool):
            continue
        if not position:
            entry['position'] = None
    return entries


def normalize_lorebook(lore_data: dict, source: str) -> LorebookV2:
    """Normalize raw lorebook data in place and build a validated LorebookV2.

    Steps:
      * ``entries`` is converted to a list via normalize_lorebook_entries
        (a missing or None value is treated as no entries).
      * ``extensions`` is replaced with a fresh dict when missing or falsy.
      * ``extensions['char_archive']`` is created or updated with ``source``
        and, when ``create_date`` is present, a parsed ``created`` value.
        Existing char_archive keys are preserved. The result is passed through
        CharArchiveExtension and stored back as a plain dict.

    Args:
        lore_data: Raw lorebook dict. It is mutated in place.
        source: Value stored as ``extensions['char_archive']['source']``.

    Returns:
        A LorebookV2 constructed from the normalized ``lore_data``.
    """
    lore_data['entries'] = normalize_lorebook_entries(lore_data.get('entries') or [])

    # Replace a missing/None/empty extensions value with a fresh dict.
    extensions = lore_data.get('extensions') or {}
    lore_data['extensions'] = extensions

    # Create the char_archive extension while preserving existing char_archive data.
    archive = extensions.get('char_archive') or {}
    archive['source'] = source
    if lore_data.get('create_date'):
        archive['created'] = parse_st_create_date_timestamp(lore_data['create_date'])
    extensions['char_archive'] = CharArchiveExtension(**archive).model_dump()

    return LorebookV2(**lore_data)


def count_v2lorebook_tokens(card: LorebookV2):
    """Return the cl100k_base token count of a lorebook's entry contents.

    The ``content`` of every entry is concatenated with no separator and
    encoded once. ``disallowed_special=()`` turns off tiktoken's rejection of
    text that matches special tokens (see tiktoken's ``Encoding.encode``).
    """
    combined_text = "".join(entry.content for entry in card.entries)
    encoder = tiktoken.get_encoding("cl100k_base")
    token_ids = encoder.encode(combined_text, disallowed_special=())
    return len(token_ids)
