import { ApiProperty } from "@nestjs/swagger";
import { Type } from "class-transformer";
import {
  IsArray,
  IsBoolean,
  IsEnum,
  IsInt,
  IsNotEmpty,
  IsNumber,
  IsObject,
  IsOptional,
  IsString,
  Max,
  MaxLength,
  Min,
  MinLength,
  ValidateNested,
} from "class-validator";
import { SubscriptionTiers } from "src/shared/subscription-tiers";
import {
  GenerationModels,
  DEFAULT_GENERATION_MODEL,
  ClassificationModels,
  ImageGenerationModels,
} from "../../shared/generation-models";

/**
 * A flat list of token ids. Used as the element type of the nested-array
 * parameters `stop_sequences`, `bad_words_ids` and `logit_bias` below.
 */
type LogitData = number[];
/**
 * Loosely-typed element of `logit_bias_exp`: an array mixing token-id arrays,
 * numbers and booleans.
 *
 * NOTE(review): the OpenAPI schema and example on `logit_bias_exp` describe
 * objects (`{ sequence, bias, ensure_sequence_finish?, generate_once? }`),
 * not arrays, so this alias does not match the documented wire format.
 * Confirm how consumers read the field before tightening the type.
 */
type LogitDataAdvanced = (number[] | number | boolean)[];

/**
 * Text generation parameters accepted from clients.
 *
 * Runtime validation comes from the class-validator decorators. `@ApiProperty`
 * only shapes the generated OpenAPI document. `@ApiProperty` treats a property
 * as required unless `required: false` is given, so every optional (`?`,
 * `@IsOptional()`) property that carries an explicit `@ApiProperty` sets
 * `required: false`. Without it the schema would contradict the validator.
 */
export class AiGenerateParameters {
  @IsOptional()
  @IsNumber()
  @Min(0.1)
  @Max(100.0)
  temperature?: number;

  // min_length and max_length are the only mandatory parameters; both are
  // bounded to [1, 2048].
  @IsInt()
  @Min(1)
  @Max(2048)
  min_length: number;

  @IsInt()
  @Min(1)
  @Max(2048)
  max_length: number;

  @IsOptional()
  @IsBoolean()
  do_sample?: boolean;

  @IsOptional()
  @IsBoolean()
  early_stopping?: boolean;

  @IsOptional()
  @IsInt()
  num_beams?: number;

  @IsOptional()
  @IsInt()
  top_k?: number;

  @IsOptional()
  @IsNumber()
  top_a?: number;

  @IsOptional()
  @IsNumber()
  top_p?: number;

  @IsOptional()
  @IsNumber()
  typical_p?: number;

  @IsOptional()
  @IsNumber()
  repetition_penalty?: number;

  @IsOptional()
  @IsInt()
  pad_token_id?: number;

  @IsOptional()
  @IsInt()
  bos_token_id?: number;

  @IsOptional()
  @IsInt()
  eos_token_id?: number;

  @IsOptional()
  @IsNumber()
  length_penalty?: number;

  @IsOptional()
  @IsInt()
  no_repeat_ngram_size?: number;

  @IsOptional()
  @IsInt()
  encoder_no_repeat_ngram_size?: number;

  @IsOptional()
  @IsInt()
  num_return_sequences?: number;

  // Only the outer array is checked at runtime (@IsArray); the nested
  // integer constraints below exist solely in the OpenAPI schema.
  @IsOptional()
  @IsArray()
  @ApiProperty({
    type: "array",
    items: {
      type: "array",
      items: {
        type: "integer",
        format: "int32",
        minimum: 0,
      },
    },
    required: false,
  })
  stop_sequences?: LogitData[];

  @IsOptional()
  @IsArray()
  @ApiProperty({
    type: "array",
    items: {
      type: "array",
      items: {
        type: "integer",
        format: "int32",
        minimum: 0,
      },
    },
    required: false,
  })
  bad_words_ids?: LogitData[];

  @IsOptional()
  @IsNumber()
  max_time?: number;

  @IsOptional()
  @IsBoolean()
  use_cache?: boolean;

  @IsOptional()
  @IsInt()
  num_beam_groups?: number;

  @IsOptional()
  @IsNumber()
  diversity_penalty?: number;

  @IsOptional()
  @IsNumber()
  @Min(0.0)
  @Max(1.0)
  tail_free_sampling?: number;

  @IsOptional()
  @IsInt()
  @Min(0)
  @Max(2048)
  repetition_penalty_range?: number;

  @IsOptional()
  @IsNumber()
  @Min(0.0)
  @Max(10.0)
  repetition_penalty_slope?: number;

  @IsOptional()
  @IsBoolean()
  @ApiProperty({
    description:
      "If false, input and output strings should be Base64-encoded uint16 numbers representing tokens",
    default: false,
    required: false,
  })
  use_string?: boolean;

  @IsOptional()
  @IsBoolean()
  get_hidden_states?: boolean;

  @IsOptional()
  @IsNumber()
  @Min(-2.0)
  @Max(2.0)
  repetition_penalty_frequency?: number;

  @IsOptional()
  @IsNumber()
  @Min(-2.0)
  @Max(2.0)
  repetition_penalty_presence?: number;

  // Documented as a list of 2-element pairs; the pair length is not enforced
  // at runtime (only @IsArray on the outer list).
  // NOTE(review): the schema restricts both pair members to non-negative
  // integers. If the second member is a bias value, that may be too strict;
  // confirm against the generation node.
  @IsOptional()
  @IsArray()
  @ApiProperty({
    type: "array",
    items: {
      type: "array",
      items: {
        type: "integer",
        format: "int32",
        minimum: 0,
      },
      minItems: 2,
      maxItems: 2,
    },
    required: false,
  })
  logit_bias?: LogitData[];

  // NOTE(review): the schema/example describe objects, but the declared TS
  // element type (LogitDataAdvanced) is an array type. Only @IsArray is
  // enforced at runtime, so the entries' shape is not validated.
  @IsOptional()
  @IsArray()
  @ApiProperty({
    type: "array",
    items: {
      type: "object",
      properties: {
        sequence: {
          type: "array",
          items: {
            type: "integer",
            format: "int32",
            minimum: 0,
          },
          minItems: 1,
        },
        bias: {
          type: "number",
          format: "float",
        },
        ensure_sequence_finish: {
          type: "boolean",
        },
        generate_once: {
          type: "boolean",
        },
      },
      required: ["sequence", "bias"],
    },
    example: [
      {
        sequence: [9288, 286, 10690],
        bias: 4,
        ensure_sequence_finish: true,
        generate_once: true,
      },
      { sequence: [9288, 286], bias: 2 },
    ],
    required: false,
  })
  logit_bias_exp?: LogitDataAdvanced[];

  @IsOptional()
  @IsBoolean()
  next_word?: boolean;

  @IsOptional()
  @IsString()
  @IsNotEmpty()
  prefix?: string;

  // minItems/maxItems are documented in the schema only; the validator
  // accepts arrays of any length here.
  @IsOptional()
  @IsArray()
  @ApiProperty({
    type: "array",
    items: {
      type: "integer",
      format: "int32",
      minimum: 0,
    },
    minItems: 0,
    maxItems: 6,
    required: false,
  })
  order?: number[];

  @IsOptional()
  @IsArray()
  @ApiProperty({
    type: "array",
    items: {
      type: "integer",
      format: "int32",
      minimum: 0,
    },
    required: false,
  })
  repetition_penalty_whitelist?: number[];

  @IsOptional()
  @IsBoolean()
  output_nonzero_probs?: boolean;

  @IsOptional()
  @IsBoolean()
  generate_until_sentence?: boolean;

  @IsOptional()
  @IsInt()
  @Min(0)
  @Max(30)
  num_logprobs?: number;
}

/**
 * AiGenerateParameters extended with three bookkeeping fields that carry no
 * validation or OpenAPI decorators.
 *
 * NOTE(review): presumably filled in server-side rather than supplied by
 * clients. Confirm against the code that constructs this class.
 */
export class AiGenerateParametersInternal extends AiGenerateParameters {
  public generation_id: string;
  public used_model: string;
  public priority: number;
}

/**
 * Request body for text generation: an input string, an optional model and
 * the nested generation parameters.
 */
export class AiGenerateRequest {
  // @IsNotEmpty already rejects "", so @MinLength(1) is redundant but harmless.
  @IsString()
  @IsNotEmpty()
  @MinLength(1)
  @MaxLength(14000)
  @ApiProperty({
    description: "Input for the text generation model",
    required: true,
    example: "Text generation example.",
  })
  input: string;

  // Optional at runtime (@IsOptional) and backed by a property initializer,
  // so the OpenAPI schema must not advertise it as required. The initializer
  // only applies when the body is materialised as an instance of this class
  // (e.g. via class-transformer).
  @IsOptional()
  @IsString()
  @IsNotEmpty()
  @IsEnum(GenerationModels)
  @ApiProperty({
    description: "Used text generation model",
    default: DEFAULT_GENERATION_MODEL,
    required: false,
  })
  // NOTE(review): the double cast suggests DEFAULT_GENERATION_MODEL is not
  // typed as a GenerationModels member. Typing the constant would let the
  // cast be dropped.
  model: GenerationModels = DEFAULT_GENERATION_MODEL as unknown as GenerationModels;

  // @Type tells class-transformer which class to instantiate, which
  // @ValidateNested needs in order to run AiGenerateParameters' own
  // validators on the nested object.
  @IsObject()
  @Type(() => AiGenerateParameters)
  @ValidateNested()
  @ApiProperty({
    description: "Generation parameters",
    required: true,
    example: {
      use_string: true,
      temperature: 1.0,
      min_length: 10,
      max_length: 30,
    },
  })
  parameters: AiGenerateParameters;
}

/**
 * Non-streaming text generation result.
 *
 * Both fields are optional; per their descriptions, whichever one is defined
 * carries the outcome. `required: false` keeps the OpenAPI schema in line
 * with the `?` modifiers, because `@ApiProperty` defaults to required.
 */
export class AiGenerateResponse {
  @ApiProperty({
    description: "Output from the text generation model, if defined",
    required: false,
  })
  output?: string;

  @ApiProperty({
    description: "Error from the generation node, if defined",
    required: false,
  })
  error?: string;
}

/**
 * A single chunk of a streamed text generation.
 *
 * Every field is optional, so each `@ApiProperty` sets `required: false`.
 * Otherwise the OpenAPI schema would mark all of them as required.
 */
export class AiGenerateStreamableResponse {
  @ApiProperty({
    description: "Incrementing token pointer",
    required: false,
  })
  ptr?: number;

  @ApiProperty({
    description: "Generated token",
    required: false,
  })
  token?: string;

  @ApiProperty({
    description: "Set to true if the token is final and the generation ended",
    required: false,
  })
  final?: boolean;

  @ApiProperty({
    description:
      "Error from the generation node, if defined. Usually means the end of stream",
    required: false,
  })
  error?: string;
}

/** Handler that receives one AiGenerateStreamableResponse; its return value is ignored. */
export type OnNewTokenCallback = (chunk: AiGenerateStreamableResponse) => void;

/**
 * A single chunk of a streamed image generation.
 *
 * Every field is optional, so each `@ApiProperty` sets `required: false`.
 * Otherwise the OpenAPI schema would mark all of them as required.
 */
export class AiGenerateImageResponse {
  @ApiProperty({
    description: "Incrementing version pointer",
    required: false,
  })
  ptr?: number;

  @ApiProperty({
    description: "Generated image in base64",
    required: false,
  })
  image?: string;

  @ApiProperty({
    description: "Set to true if the image is final and the generation ended",
    required: false,
  })
  final?: boolean;

  @ApiProperty({
    description:
      "Error from the generation node, if defined. Usually means the end of stream",
    required: false,
  })
  error?: string;
}

/** Handler that receives one AiGenerateImageResponse; its return value is ignored. */
export type OnNewImageDataCallback = (update: AiGenerateImageResponse) => void;

/**
 * Request body for image generation: a prompt, an image model and free-form,
 * model-specific parameters.
 */
export class AiGenerateImageRequest {
  @IsString()
  @IsNotEmpty()
  @MinLength(1)
  @MaxLength(14000)
  @ApiProperty({
    description: "Input for the image generation model",
    required: true,
    example: "Image generation example.",
  })
  input: string;

  @IsString()
  @IsNotEmpty()
  @IsEnum(ImageGenerationModels)
  @ApiProperty({
    description: "Used image generation model",
    required: true,
  })
  model: ImageGenerationModels;

  // Only checked to be an object; its contents depend on the model and are
  // not validated here.
  @IsObject()
  @ApiProperty({
    description: "Generation parameters (model specific)",
    required: true,
  })
  parameters: Record<string, unknown>;
}

/**
 * Request body for pricing an image generation request against a
 * subscription tier.
 */
export class AiRequestImageGenerationPriceRequest {
  @IsString()
  @IsNotEmpty()
  @IsEnum(SubscriptionTiers)
  @ApiProperty({
    description: "Tier to check against",
    required: true,
  })
  tier: SubscriptionTiers;

  // Without @Type + @ValidateNested only "is an object" was checked, so the
  // nested request's own validators (input length, model enum, parameters)
  // never ran. This mirrors how AiGenerateRequest.parameters is validated.
  // NOTE(review): with a whitelisting ValidationPipe, unknown keys inside
  // `request` will now be stripped/rejected. Confirm clients send only
  // input/model/parameters.
  @IsObject()
  @Type(() => AiGenerateImageRequest)
  @ValidateNested()
  @ApiProperty({
    description: "Request object",
    required: true,
  })
  request: AiGenerateImageRequest;
}

/**
 * Request body for a tag search: which image model to query and the prompt
 * string to search with. Both fields are mandatory, non-empty strings.
 */
export class AiRequestImageGenerationTagsRequest {
  @ApiProperty({
    description: "Used image generation model",
    required: true,
  })
  @IsString()
  @IsNotEmpty()
  @IsEnum(ImageGenerationModels)
  model: ImageGenerationModels;

  @ApiProperty({
    description: "Tag search prompt string",
    required: true,
  })
  @IsString()
  @IsNotEmpty()
  prompt: string;
}

/** One entry of the tag search result list (see AiRequestImageGenerationTagsResponse). */
export class AiRequestImageGenerationTag {
  public tag: string;
  public count: number;
  public confidence: number;
}

/** Response body of a tag search: the list of matching tags. */
export class AiRequestImageGenerationTagsResponse {
  public tags: Array<AiRequestImageGenerationTag>;
}

/**
 * Request body for sequence classification with the "debertaxxl" model.
 *
 * `sentences` and `labels` are lists of string pairs in the OpenAPI schema.
 * At runtime only the outer arrays are checked (@IsArray); the pair length
 * is documented, not enforced.
 *
 * NOTE(review): `model` is typed as the literal "debertaxxl" but validated
 * with @IsEnum(ClassificationModels), which accepts any classification model.
 */
export class AiSequenceClassificationDebertaRequest {
  @ApiProperty({
    type: "array",
    items: {
      type: "array",
      items: {
        type: "string",
      },
      minItems: 2,
      maxItems: 2,
    },
    example: [
      ["sentence1", "sentence2"],
      ["sentence3", "sentence4"],
    ],
  })
  @IsArray()
  sentences: string[][];

  @ApiProperty({
    type: "array",
    items: {
      type: "array",
      items: {
        type: "string",
      },
      minItems: 2,
      maxItems: 2,
    },
    example: [
      ["label1", "label2"],
      ["label3", "label4"],
    ],
  })
  @IsArray()
  labels: string[][];

  @ApiProperty({
    description: "Used classification model",
    required: true,
  })
  @IsString()
  @IsNotEmpty()
  @IsEnum(ClassificationModels)
  model: "debertaxxl";
}

/**
 * Request body for sequence classification with the "bartlarge" model:
 * a list of input strings plus nested lists of candidate labels.
 *
 * Only the outer arrays are checked at runtime (@IsArray).
 *
 * NOTE(review): `model` is typed as the literal "bartlarge" but validated
 * with @IsEnum(ClassificationModels), which accepts any classification model.
 */
export class AiSequenceClassificationBartRequest {
  @ApiProperty({
    type: "array",
    items: {
      type: "string",
    },
  })
  @IsArray()
  sequence: string[];

  @ApiProperty({
    type: "array",
    items: {
      type: "array",
      items: {
        type: "string",
      },
    },
    example: [["mobile", "account access", "billing", "website"]],
  })
  @IsArray()
  labels: string[][];

  @ApiProperty({
    description: "Used classification model",
    required: true,
  })
  @IsString()
  @IsNotEmpty()
  @IsEnum(ClassificationModels)
  model: "bartlarge";
}

/** One label/score entry of a classification result. */
class ClassificationData {
  public label: string;
  public score: number;
}

/**
 * Sequence classification result: either `output` or `error`.
 *
 * Both fields are optional, so each `@ApiProperty` sets `required: false`.
 * Without it `@ApiProperty` would mark them required in the OpenAPI schema.
 *
 * NOTE(review): the second `output` variant is described three different
 * ways. The schema says `{ scores: number[] }`, the example shows
 * `{ scores: number[][] }`, and the TS type is `{ sequence, labels }`.
 * Confirm the real payload and align all three.
 */
export class AiSequenceClassificationResponse {
  @ApiProperty({
    description: "Output, if defined",
    anyOf: [
      {
        type: "array",
        items: {
          type: "object",
          properties: {
            label: {
              type: "string",
            },
            score: {
              type: "number",
            },
          },
        },
      },
      {
        type: "object",
        properties: {
          scores: {
            type: "array",
            items: {
              type: "number",
            },
          },
        },
      },
    ],
    example: [
      [
        {
          label: "label",
          score: 0.0,
        },
      ],
      {
        scores: [
          [
            0.9600785970687866, 0.01683211140334606, 0.014393973164260387,
            0.008695312775671482,
          ],
        ],
      },
    ],
    required: false,
  })
  output?:
    | ClassificationData[]
    | {
        sequence: string[];
        labels: string[][];
      };

  @ApiProperty({
    description: "Error, if defined",
    required: false,
  })
  error?: string;
}
