import { HttpException, HttpStatus, Inject, Injectable } from "@nestjs/common";
import { nanoid } from "nanoid";
import {
  ImageGnerationLimit,
  MAX_TIER,
  SubscriptionTierData,
  SubscriptionTiers,
} from "../shared/subscription-tiers";
import { User, UserType } from "../user/user.entity";
import {
  AiGenerateParametersInternal,
  AiGenerateRequest,
  AiGenerateStreamableResponse,
  AiSequenceClassificationDebertaRequest,
  AiSequenceClassificationBartRequest,
  OnNewTokenCallback,
  AiGenerateImageRequest,
} from "./dto/ai-generate.dto";
import { HydraService } from "./hydra.service";
import { TaskPriorityService } from "../user/task-priority.service";
import Redis from "ioredis";
import { Cron, CronExpression } from "@nestjs/schedule";
import { InjectRepository } from "@nestjs/typeorm";
import { Repository } from "typeorm";
import { TokenMetrics } from "./analytics/token-metrics.entity";
import {
  GenerationModelAccessRightsData,
  GenerationModels,
  ImageGenerationModelSettingsData,
} from "../shared/generation-models";
import { SubscriptionTierPerks } from "../shared/subscription-tiers";
import { ApiProperty } from "@nestjs/swagger";

/** Returns a promise that resolves (with `undefined`) after `time` milliseconds. */
const sleep = (time: number) =>
  new Promise((resolve) => {
    setTimeout(resolve, time);
  });

// Redis pub/sub channel used to relay streamed generation chunks to other
// subscribers when this instance holds no callback for the generation id.
const PUBSUB_GENERATION_CHANNEL = "streamed_generation_events";

/** Bookkeeping for one in-flight streamed generation. */
class StreamCallbackData {
  /** Unix time in (fractional) seconds when data last arrived for this stream. */
  lastUpdatedTime!: number;
  /** Receives every streamed chunk, including a final timeout error if the stream goes stale. */
  cb!: OnNewTokenCallback;
}

/** Interval, in milliseconds, between sweeps of pending stream callbacks. */
const LEFTOVER_CALLBACK_CHECK_PERIOD = 60 * 1000;
/** Intended maximum age, in seconds, of a pending stream's last update before the sweep treats it as stale. */
const LEFTOVER_CALLBACK_NOT_UPDATED_MAX_LIFETIME = 60;

// Settings applied to generation requests made without a logged-in user.
const ANONYMOUS_SUBSCRIPTION_TIER = SubscriptionTiers.SCROLL;
const ANONYMOUS_QUEUE_PRIORITY = 5;
/** Anonymous text-generation quota passed to checkIpBasedLimit. */
const ANONYMOUS_ALLOWED_GENERATIONS_AMOUNT: number = 100;
/** Lifetime of an anonymous text-quota record: one week, in seconds. */
const ANONYMOUS_ALLOWED_GENERATIONS_RECORD_TTL = 7 * 24 * 60 * 60;

/**
 * Anonymous voice-generation quota. NOTE(review): not used in this file;
 * presumably passed to AIService.checkIpBasedLimit by the voice endpoint.
 */
export const ANONYMOUS_ALLOWED_VOICE_GENERATIONS_AMOUNT: number = 100;
/** Lifetime of an anonymous voice-quota record: one week, in seconds. */
export const ANONYMOUS_ALLOWED_VOICE_GENERATIONS_RECORD_TTL = 7 * 24 * 60 * 60;

/**
 * Anonymous image-generation quota. NOTE(review): not used in this file;
 * presumably passed to AIService.checkIpBasedLimit by the image endpoint.
 */
export const ANONYMOUS_ALLOWED_IMAGE_GENERATIONS_AMOUNT = 20;
/** Lifetime of an anonymous image-quota record: one week, in seconds. */
export const ANONYMOUS_ALLOWED_IMAGE_GENERATIONS_RECORD_TTL = 7 * 24 * 60 * 60;

// Tier and queue priority returned by checkSubscription for a trial user
// (trial activated, actions left, no active subscription).
const TRIAL_SUBSCRIPTION_TIER = SubscriptionTiers.SCROLL;
const TRIAL_QUEUE_PRIORITY = 6;

/**
 * Pricing breakdown of one image generation request, as produced by
 * AIService.calculateImageStepCost.
 */
export class StepCostData {
  /** Step cost of a single image; calculateImageStepCost never returns less than 2. */
  @ApiProperty({
    description: "Cost for one image",
  })
  costPerPrompt!: number;

  /**
   * Number of images to bill. When only part of the request exceeds the
   * unlimited-generation limit, this is just that excess, not the total
   * requested. NOTE(review): the Swagger description below says "requested".
   */
  @ApiProperty({
    description: "Amount of images requested",
  })
  numPrompts!: number;

  /** True when the whole request is covered by unlimited generation. */
  @ApiProperty({
    description: "If true, steps are not subtracted for all prompts",
  })
  requestEligibleForUnlimitedGeneration!: boolean;

  /**
   * Images covered for free when the request exceeds the unlimited-generation
   * limit; 0 otherwise (including when the entire request is free).
   */
  @ApiProperty({
    description: "Amount of free prompts in this generation",
  })
  freePrompts!: number;
}

/**
 * Step cost of generating one image of `width` x `height` pixels with `steps`
 * sampling steps, before any strength multiplier or minimum is applied.
 *
 * The cost grows exponentially with the pixel count. The magic constants are
 * presumably an empirical fit at 28 steps, scaled linearly to other step counts.
 */
function calculatePromptPrice(
  width: number,
  height: number,
  steps: number,
): number {
  // Pixel count in units of 1024 * 1024 pixels.
  const megapixels = (width * height) / 1048576;
  const costAt28Steps =
    15.266497014243718 * Math.exp(0.6326248927474729 * megapixels) -
    15.225164493059737;
  return (costAt28Steps / 28) * steps;
}

/**
 * Runs text generation, text classification and image step-cost calculation
 * on top of the Hydra generation service. Also performs the quota and
 * subscription checks that gate those requests.
 *
 * Streamed text generations are tracked in an in-memory callback map. When
 * `REDIS_HOST` is set, chunks for generations whose callback this instance
 * does not hold are published to other subscribers on
 * `PUBSUB_GENERATION_CHANNEL`. Chunks received on that channel are dispatched
 * to local callbacks.
 */
@Injectable()
export class AIService {
  /** Pending streamed generations, keyed by generation id. */
  private _stream_callbacks: Map<string, StreamCallbackData>;
  /** Receives relayed stream chunks; null when `REDIS_HOST` is not set. */
  private _redis_client_subscriber: Redis | null;
  /** Publishes stream chunks and stores quota counters; null when `REDIS_HOST` is not set. */
  private _redis_client_publisher: Redis | null;

  /** Sum of requested `min_length` values since the last metric flush. */
  private _token_count: number;

  constructor(
    @Inject("GENERATION_SERVICE")
    private readonly hydraService: HydraService,

    @InjectRepository(TokenMetrics)
    private tokenMetricsRepository: Repository<TokenMetrics>,

    @InjectRepository(User)
    private usersRepository: Repository<User>,

    private readonly taskPriorityService: TaskPriorityService,
  ) {
    this._stream_callbacks = new Map();
    this._token_count = 0;
    // Explicit null rather than undefined, matching the declared field types.
    this._redis_client_subscriber = null;
    this._redis_client_publisher = null;

    if (process.env.REDIS_HOST !== undefined) {
      this._redis_client_subscriber = new Redis(6379, process.env.REDIS_HOST);
      this._redis_client_publisher = new Redis(6379, process.env.REDIS_HOST);

      // subscribe() returns a promise; log a failure instead of leaving it unhandled.
      void this._redis_client_subscriber
        .subscribe(PUBSUB_GENERATION_CHANNEL)
        .catch((ex) =>
          console.error("Failed to subscribe to streamed channel!", ex),
        );

      this._redis_client_subscriber.on(
        "message",
        (channel: string, messageText: string) => {
          if (channel !== PUBSUB_GENERATION_CHANNEL) return;

          try {
            const message = JSON.parse(messageText);
            // onStreamedGenerationData is async, so its failures arrive as a
            // rejection this try/catch cannot see; handle them explicitly.
            if (message.uuid)
              void this.onStreamedGenerationData(message).catch((ex) =>
                console.error("Failed to handle streamed channel data!", ex),
              );
          } catch (ex) {
            console.error("Failed to parse streamed channel data!", ex);
          }
        },
      );
    }

    // Periodically expire streams that stopped receiving data.
    setInterval(
      this.checkLeftoverCallbacks.bind(this),
      LEFTOVER_CALLBACK_CHECK_PERIOD,
    );
  }

  /**
   * Counts one request against a per-identifier quota stored in Redis.
   *
   * The read and the increment are separate round-trips, not one atomic
   * operation. Concurrent requests for the same identifier may therefore
   * slightly exceed `maxAmount`.
   *
   * @param requestIdentifier Value the quota is tracked under. It is prefixed
   *   with `"amu:"` to form the Redis key.
   * @param maxAmount Number of requests allowed per identifier.
   * @param recordTTL Counter lifetime in seconds. It is refreshed on every
   *   counted request.
   * @returns `false` (without counting) when the quota is exhausted, otherwise
   *   `true` after incrementing the counter.
   * @throws HttpException (500) when Redis is not configured.
   */
  async checkIpBasedLimit(
    requestIdentifier: any,
    maxAmount: number,
    recordTTL: number,
  ): Promise<boolean> {
    if (!this._redis_client_publisher)
      throw new HttpException(
        "Redis not initialized.",
        HttpStatus.INTERNAL_SERVER_ERROR,
      );

    const identifierKey = "amu:" + requestIdentifier;
    const identifierExistingData = await this._redis_client_publisher.get(
      identifierKey,
    );

    const identifierCounter = identifierExistingData
      ? parseInt(identifierExistingData, 10) || 0
      : 0;

    // The counter holds the number of requests already granted. Using `>` here
    // let maxAmount + 1 requests through.
    if (identifierCounter >= maxAmount) return false;

    // EXPIRE takes a relative TTL in seconds. Passing an absolute Unix
    // timestamp made the records effectively never expire.
    await this._redis_client_publisher
      .multi()
      .incr(identifierKey)
      .expire(identifierKey, recordTTL)
      .exec();

    return true;
  }

  /**
   * Resolves the tier, context size and queue priority for a generation
   * request, consuming whatever allowance applies:
   *
   * - No user: counts against the anonymous text quota keyed by
   *   `requestIdentifier` and returns the anonymous tier.
   * - Trial user (trial activated, actions left, no subscription) on a model
   *   that allows trial access: returns the trial tier. Decrements and
   *   persists `trialActions` when the model's settings require it.
   * - `user.hasSubscription(true)`: consumes one priority action. If that
   *   fails, the request is still served with priority 10.
   * - Non-retail user: consumes `b2bMetric` characters and returns MAX_TIER
   *   with 2048 context tokens and priority 11.
   *
   * @param user Authenticated user, if any.
   * @param model Requested model; only used for the trial-access check.
   * @param requestIdentifier Anonymous quota key; required when `user` is absent.
   * @param b2bMetric Characters to consume for non-retail users.
   *   NOTE(review): the default of -1 is passed straight to
   *   tryConsumeCharacters; confirm how it is interpreted there.
   * @throws HttpException with status:
   *   - 500 when there is neither a user nor an identifier, or Redis is missing.
   *   - 402 when the anonymous quota is exhausted or no subscription is active.
   *   - 409 when a non-retail user lacks characters.
   */
  async checkSubscription(
    user?: User,
    model?: GenerationModels | null,
    requestIdentifier: any = null,
    b2bMetric = -1,
  ): Promise<{
    tier: SubscriptionTiers;
    contextTokens: number;
    priority: number;
  }> {
    if (!user) {
      if (!requestIdentifier)
        throw new HttpException(
          "Request identifier not provided.",
          HttpStatus.INTERNAL_SERVER_ERROR,
        );

      if (
        !(await this.checkIpBasedLimit(
          "txt:" + requestIdentifier,
          ANONYMOUS_ALLOWED_GENERATIONS_AMOUNT,
          ANONYMOUS_ALLOWED_GENERATIONS_RECORD_TTL,
        ))
      )
        throw new HttpException(
          "Anonymous quota reached.",
          HttpStatus.PAYMENT_REQUIRED,
        );

      return {
        tier: ANONYMOUS_SUBSCRIPTION_TIER,
        contextTokens:
          SubscriptionTierData[ANONYMOUS_SUBSCRIPTION_TIER].contextTokens,
        priority: ANONYMOUS_QUEUE_PRIORITY,
      };
    }

    // An unknown model yields undefined settings. Optional chaining skips the
    // trial path instead of throwing a TypeError.
    const modelAccessSettings = model
      ? GenerationModelAccessRightsData[model]
      : undefined;

    if (
      modelAccessSettings?.trialCanAccess &&
      user.trialActivated &&
      user.trialActions > 0 &&
      !user.hasSubscription()
    ) {
      if (modelAccessSettings.trialDecreaseActions) {
        user.trialActions--;
        // Awaited so a failed save surfaces here rather than as an unhandled rejection.
        await this.usersRepository.save(user);
      }

      return {
        tier: TRIAL_SUBSCRIPTION_TIER,
        contextTokens:
          SubscriptionTierData[TRIAL_SUBSCRIPTION_TIER].contextTokens,
        priority: TRIAL_QUEUE_PRIORITY,
      };
    }

    if (user.hasSubscription(true)) {
      const consumed = await this.taskPriorityService.tryConsumeMPA(user, 1);

      return {
        tier: user.subscriptionTier,
        contextTokens: user.getSubscriptionPerks().contextTokens,
        // When no priority action could be consumed, still serve at priority 10.
        priority: consumed.success ? consumed.priority : 10,
      };
    }

    if (user.userType !== UserType.RETAIL) {
      const success = await this.taskPriorityService.tryConsumeCharacters(
        user,
        b2bMetric,
      );
      if (!success)
        throw new HttpException(
          `Insufficent characters left (need ${b2bMetric}, have ${user.classificationCharactersLeft}).`,
          HttpStatus.CONFLICT,
        );

      return {
        tier: MAX_TIER,
        contextTokens: 2048,
        priority: 11,
      };
    }

    throw new HttpException(
      "Active subscription required.",
      HttpStatus.PAYMENT_REQUIRED,
    );
  }

  /**
   * Turns a public generation request into Hydra's internal parameters.
   * It tags the parameters with the requested model and adds the request's
   * `min_length` to the token metric.
   *
   * Note: this mutates and returns `aiGenerateRequest.parameters` itself.
   */
  private prepareGenerationParameters(
    aiGenerateRequest: AiGenerateRequest,
  ): AiGenerateParametersInternal {
    const internalParameters =
      aiGenerateRequest.parameters as AiGenerateParametersInternal;

    internalParameters.used_model =
      aiGenerateRequest.model as unknown as string;

    // Guard so an omitted min_length cannot turn the running total into NaN.
    this._token_count += aiGenerateRequest.parameters.min_length ?? 0;

    return internalParameters;
  }

  /* Request-based generation */

  /** Generates text in a single request and returns Hydra's result. */
  async predictText(aiGenerateRequest: AiGenerateRequest) {
    const internalParameters =
      this.prepareGenerationParameters(aiGenerateRequest);

    return this.hydraService.generateText(
      aiGenerateRequest.input,
      internalParameters,
    );
  }

  /**
   * Expires pending streams whose last update is older than
   * LEFTOVER_CALLBACK_NOT_UPDATED_MAX_LIFETIME seconds. Each expired stream's
   * callback receives a final "Node timeout" error.
   */
  async checkLeftoverCallbacks() {
    const currentUnixTime = Date.now() / 1000;

    for (const [key, value] of this._stream_callbacks.entries()) {
      // Age since the last chunk. The operands were previously reversed, so
      // the age was negative and stale streams were never expired.
      if (
        currentUnixTime - value.lastUpdatedTime <
        LEFTOVER_CALLBACK_NOT_UPDATED_MAX_LIFETIME
      )
        continue;

      // Deleting the current entry during Map iteration is safe. Delete first
      // so predictTextStream stops polling even if the callback throws.
      this._stream_callbacks.delete(key);
      try {
        value.cb({
          error: "Node timeout",
          final: true,
        });
      } catch (ex) {
        console.error("Stream callback failed while reporting a timeout!", ex);
      }
    }
  }

  /* Stream-based generation */

  /**
   * Delivers one streamed chunk to the local callback for its generation id,
   * if there is one. The `uuid` routing field is removed from `data` in place
   * before delivery. A chunk marked `final` unregisters the generation.
   */
  async onStreamedGenerationData(data: AiGenerateStreamableResponse) {
    const tagged = data as AiGenerateStreamableResponse & { uuid?: string };
    const generationId = tagged.uuid;
    delete tagged.uuid;

    if (generationId === undefined) return;

    const callbackData = this._stream_callbacks.get(generationId);

    // Unregister before invoking the callback so a throwing callback cannot
    // leave a finished generation waiting for the stale-callback sweep.
    if (data.final) this._stream_callbacks.delete(generationId);

    if (callbackData) {
      callbackData.lastUpdatedTime = Date.now() / 1000;
      callbackData.cb(data);
    }
  }

  /**
   * Entry point for chunks coming from a generation node. Chunks for a
   * generation tracked by this instance are handled locally. All others are
   * published on the Redis channel for other subscribers.
   */
  async onNewNodeData(data: AiGenerateStreamableResponse) {
    const generationId = (
      data as AiGenerateStreamableResponse & { uuid?: string }
    ).uuid;

    if (generationId !== undefined && this._stream_callbacks.has(generationId))
      // Should be parsed locally, no need to reroute it to other subscribers
      return this.onStreamedGenerationData(data);

    if (!this._redis_client_publisher) {
      console.error(
        "Receiving node data before Redis connection has been established!",
      );

      return;
    }

    // No local callback for that stream: publish it to other subscribers.
    // This stays fire-and-forget, but failures are logged rather than left unhandled.
    void this._redis_client_publisher
      .publish(PUBSUB_GENERATION_CHANNEL, JSON.stringify(data))
      .catch((ex) =>
        console.error("Failed to publish streamed generation data!", ex),
      );
  }

  /**
   * Starts a streamed generation and resolves once it has finished.
   *
   * A stream finishes when a `final` chunk arrives or the stale-callback sweep
   * expires it. Each chunk is passed to `onNewTokenCallback`. With a dummy
   * Hydra service, a single final "dummy" token is delivered immediately.
   */
  async predictTextStream(
    aiGenerateRequest: AiGenerateRequest,
    onNewTokenCallback: OnNewTokenCallback,
  ) {
    if (this.hydraService.isDummy) {
      onNewTokenCallback({
        final: true,
        ptr: 0,
        token: "dummy",
      });
      return;
    }

    const internalParameters =
      this.prepareGenerationParameters(aiGenerateRequest);

    const generationId = nanoid();
    this._stream_callbacks.set(generationId, {
      cb: onNewTokenCallback,
      lastUpdatedTime: Date.now() / 1000,
    });

    internalParameters.generation_id = generationId;

    // Tokens arrive out-of-band via onNewNodeData, so the request itself is
    // not awaited. A rejection is still handled: unhandled rejections are
    // fatal by default on Node 15+. If no data ever arrives,
    // checkLeftoverCallbacks expires the stream.
    void Promise.resolve(
      this.hydraService.generateText(
        aiGenerateRequest.input,
        internalParameters,
      ),
    ).catch((ex) => console.error("Streamed generation request failed!", ex));

    while (this._stream_callbacks.has(generationId)) await sleep(500);
  }

  /* BERT classification */

  /** Forwards a sequence-classification request to Hydra at the given priority. */
  async classifyText(
    aiClassificationRequest:
      | AiSequenceClassificationDebertaRequest
      | AiSequenceClassificationBartRequest,
    priority: number,
  ) {
    return this.hydraService.classifyText(aiClassificationRequest, priority);
  }

  /**
   * Computes the step cost of an image generation request for a user with the
   * given subscription perks.
   *
   * @returns The per-image cost (at least 2) and the number of images to bill
   *   (only the excess when part of the request is free). Also whether the
   *   whole request is covered by unlimited generation, and how many images
   *   are free.
   * @throws HttpException (400) for any of:
   *   - an unknown model;
   *   - non-numeric width/height/steps/strength;
   *   - `n_samples` that is not an integer in 1-10;
   *   - more than 1024x1024 pixels;
   *   - `n_samples > 1` at exactly 1024x1024.
   */
  calculateImageStepCost(
    aiGenerateRequest: AiGenerateImageRequest,
    perks: SubscriptionTierPerks,
  ): StepCostData {
    const modelConfig =
      ImageGenerationModelSettingsData[aiGenerateRequest.model];
    if (!modelConfig) {
      throw new HttpException(
        "Invalid request parameters.",
        HttpStatus.BAD_REQUEST,
      );
    }

    const parameters = aiGenerateRequest.parameters;

    // Models without custom resolution support are always priced at their default size.
    const width = modelConfig.customResolutionSupported
      ? +(parameters["width"] ?? modelConfig.defaultWidth)
      : modelConfig.defaultWidth;
    const height = modelConfig.customResolutionSupported
      ? +(parameters["height"] ?? modelConfig.defaultHeight)
      : modelConfig.defaultHeight;
    let numPrompts = +(parameters["n_samples"] ?? 1);
    const steps = +(parameters["steps"] ?? 28);
    // Strength only scales the price for image-to-image requests.
    const strengthMultiplier = parameters["image"]
      ? +(parameters["strength"] ?? 1)
      : 1;

    // steps and strength were previously unchecked, letting NaN reach the cost.
    if (
      isNaN(width) ||
      isNaN(height) ||
      isNaN(steps) ||
      isNaN(strengthMultiplier) ||
      !Number.isInteger(numPrompts) ||
      numPrompts > 10 ||
      numPrompts < 1
    )
      throw new HttpException(
        "Invalid request parameters.",
        HttpStatus.BAD_REQUEST,
      );

    // Small images are treated as 256x256 for the limit checks below. The
    // price itself still uses the real width and height.
    let pixels = width * height;
    if (pixels < 256 * 256) {
      pixels = 256 * 256;
    }
    if (pixels > 1024 * 1024)
      throw new HttpException(
        "Invalid request resolution.",
        HttpStatus.BAD_REQUEST,
      );

    if (pixels == 1024 * 1024 && numPrompts > 1)
      throw new HttpException(
        "1024x1024 does not support n_samples > 1.",
        HttpStatus.BAD_REQUEST,
      );

    let requestEligibleForUnlimitedGeneration =
      perks.unlimitedImageGeneration && steps <= 28;
    let overFreeLimit = 0;
    let freePrompts = 0;
    if (
      requestEligibleForUnlimitedGeneration &&
      perks.unlimitedImageGenerationLimits.length > 0
    ) {
      let effectiveLimit: ImageGnerationLimit | null = null;

      // The last limit covering this resolution wins. NOTE(review): presumably
      // the perks list is ordered so this picks the intended bracket; confirm.
      for (const limit of perks.unlimitedImageGenerationLimits)
        if (limit.resolution >= pixels) effectiveLimit = limit;

      if (effectiveLimit && numPrompts > effectiveLimit.maxPrompts) {
        requestEligibleForUnlimitedGeneration = false;
        overFreeLimit = numPrompts - effectiveLimit.maxPrompts;
        freePrompts = effectiveLimit.maxPrompts;
      }
    }

    const costPerPrompt = Math.max(
      Math.ceil(
        calculatePromptPrice(width, height, steps) * strengthMultiplier,
      ),
      2,
    );

    // Only the prompts above the free allowance are billed.
    if (overFreeLimit) {
      numPrompts = overFreeLimit;
    }

    return {
      costPerPrompt,
      numPrompts,
      requestEligibleForUnlimitedGeneration,
      freePrompts,
    };
  }

  /* Token metric */

  /**
   * Hourly: persists the accumulated token count as a TokenMetrics row and
   * resets the counter. On failure, the tokens are restored for the next
   * flush and the error is rethrown.
   */
  @Cron(CronExpression.EVERY_HOUR)
  async flushTokenMetric() {
    if (this._token_count === 0) return; // No tokens to save the metric of.

    // Snapshot and reset before awaiting. Tokens counted while the save is in
    // flight are kept for the next flush instead of being zeroed afterwards.
    const tokenCount = this._token_count;
    this._token_count = 0;

    try {
      // Repository.create is synchronous; it only builds the entity.
      const newMetric = this.tokenMetricsRepository.create({ tokenCount });
      await this.tokenMetricsRepository.save(newMetric);
    } catch (ex) {
      this._token_count += tokenCount;
      throw ex;
    }
  }
}
