import { Injectable } from "@nestjs/common";
import * as amqp from "amqplib";
import * as events from "events";
import { nanoid } from "nanoid";
import {
  AiGenerateParametersInternal,
  AiSequenceClassificationResponse,
} from "./dto/ai-generate.dto";

/** RabbitMQ direct reply-to pseudo-queue; RPC replies are consumed from here. */
const REPLY_QUEUE = "amq.rabbitmq.reply-to";

/**
 * Queue-name prefixes. The service appends `_<model>` and, outside
 * production, a `_dev` suffix to form the actual queue name.
 */
const TEXT_GENERATION_PREFIX = "generation_jobs";
// Unused for now — image generation queue prefix kept for reference:
// const IMAGE_GENERATION_PREFIX = "generation_jobs_images";
const CLASSIFICATION_PREFIX = "generation_jobs_B2B";

/**
 * An amqplib channel augmented with an emitter on which RPC reply bodies are
 * emitted under their message's correlation id.
 */
interface AMQPChannelWithResponseEmitter extends amqp.Channel {
  responseEmitter: events.EventEmitter;
}

/**
 * RabbitMQ client for the AI job workers.
 *
 * Jobs are published to per-model queues. RPC-style calls get their answer
 * through RabbitMQ direct reply-to (`amq.rabbitmq.reply-to`), matched to the
 * waiting caller by `correlationId`.
 *
 * If the `HYDRA_DUMMY` environment variable is set (to any value), no broker
 * connection is made. The public methods then return canned results, or do
 * nothing for fire-and-forget jobs.
 */
@Injectable()
export class HydraService {
  isDummy = false;

  _client: amqp.Connection | null = null;
  _channel: AMQPChannelWithResponseEmitter | null = null;

  constructor() {
    this.isDummy = process.env.HYDRA_DUMMY !== undefined;
  }

  /**
   * Connects to the broker (URL taken from `AI_KEY`) and opens a channel.
   * Then starts consuming the direct reply-to pseudo-queue, so RPC replies
   * are re-emitted on `responseEmitter`. Must complete before any job is
   * sent. Does nothing in dummy mode.
   */
  async _asyncCreate() {
    if (this.isDummy) return;

    this._client = await amqp.connect(process.env.AI_KEY);
    // Attach the emitter with Object.assign so the augmented type is real,
    // not an `as` cast made before the property exists.
    const channel = Object.assign(await this._client.createChannel(), {
      responseEmitter: new events.EventEmitter(),
    });
    // Each in-flight RPC call holds one listener. Disable the default cap of
    // 10, which would otherwise warn under load.
    channel.responseEmitter.setMaxListeners(0);
    this._channel = channel;

    // Direct reply-to must be consumed in no-ack mode. Awaiting this means a
    // failed consume rejects `_asyncCreate` instead of becoming an
    // unhandled rejection.
    await channel.consume(
      REPLY_QUEUE,
      (msg) => {
        // amqplib delivers `null` when the broker cancels the consumer.
        // NOTE(review): pending RPC calls will then never resolve; consider
        // logging or reconnecting here.
        if (!msg) return;
        channel.responseEmitter.emit(
          msg.properties.correlationId,
          msg.content.toString("utf8"),
        );
      },
      { noAck: true },
    );
  }

  /** Returns the open channel, or throws if `_asyncCreate` has not completed. */
  private _requireChannel(): AMQPChannelWithResponseEmitter {
    if (this._channel === null) {
      throw new Error(
        "HydraService channel is not initialised; await _asyncCreate() (or use createHydraService()) first",
      );
    }
    return this._channel;
  }

  /** Builds `<prefix>_<model>`, with a `_dev` suffix outside production. */
  private _queueName(prefix: string, model: string): string {
    const jobEnvironment = process.env.NODE_ENV === "production" ? "" : "_dev";
    return `${prefix}_${model}${jobEnvironment}`;
  }

  /** Rounds a priority to an integer; `null`/`undefined` become `undefined` (no priority) instead of NaN. */
  private _roundPriority(
    priority: number | null | undefined,
  ): number | undefined {
    return priority == null ? undefined : Math.round(priority);
  }

  /**
   * Publishes `message` to `queue`.
   *
   * @param queue Destination queue name.
   * @param message Serialized payload, sent as UTF-8 bytes.
   * @param isRpc If true, a fresh correlation id and the reply-to queue are
   *   attached, and a promise of the reply body (UTF-8 string) is returned.
   * @param priority Optional AMQP message priority.
   * @returns The reply promise for RPC calls; otherwise nothing.
   * @throws Error if the channel has not been initialised.
   */
  private _sendRMQData(
    queue: string,
    message: string,
    isRpc = true,
    priority: number | undefined = undefined,
  ): Promise<any> | void {
    const channel = this._requireChannel();
    const correlationId = isRpc ? nanoid() : undefined;

    channel.sendToQueue(queue, Buffer.from(message), {
      correlationId,
      replyTo: isRpc ? REPLY_QUEUE : undefined,
      priority,
    });

    if (correlationId === undefined) return;

    // Registering the listener right after publishing is safe. The reply can
    // only be delivered by the consumer callback on a later event-loop turn.
    // NOTE(review): there is no timeout. If no reply arrives, this promise
    // never settles and its listener stays registered.
    return new Promise<string>((resolve) => {
      channel.responseEmitter.once(correlationId, resolve);
    });
  }

  /**
   * Submits a text-generation job to `generation_jobs_<used_model>[_dev]`.
   *
   * If `parameters.generation_id` is set, the job is fire-and-forget (the
   * result is streamed elsewhere) and nothing is returned. Otherwise the call
   * waits for the worker's reply and returns it parsed as JSON.
   * In dummy mode it returns `{ output: "dummy" }`.
   */
  async generateText(
    input: string,
    parameters: AiGenerateParametersInternal,
  ): Promise<any | void> {
    if (this.isDummy)
      return {
        output: "dummy",
      };

    const shouldBeStreamed = !!parameters.generation_id;
    const isRpc = !shouldBeStreamed;

    const queue = this._queueName(
      TEXT_GENERATION_PREFIX,
      parameters.used_model,
    );

    const reply = this._sendRMQData(
      queue,
      JSON.stringify({
        input,
        parameters,
      }),
      isRpc,
      this._roundPriority(parameters.priority),
    );

    if (!isRpc) return;

    // NOTE(review): the reply JSON is not validated against a schema.
    return JSON.parse((await reply) as string);
  }

  /**
   * Publishes a fire-and-forget training job to `training_jobs_<model>[_dev]`.
   * In dummy mode there is no broker channel, so the job is dropped.
   */
  sendTrainingJob(model: string, message: Record<string, unknown>) {
    if (this.isDummy) return;

    const queueName = this._queueName("training_jobs", model);
    this._requireChannel().sendToQueue(
      queueName,
      Buffer.from(JSON.stringify(message)),
    );
  }

  /**
   * Sends a classification request to
   * `generation_jobs_B2B_<request.model>[_dev]` and waits for the reply,
   * which is returned parsed as JSON. In dummy mode it returns a single
   * `{ label: "dummy", score: 0 }` result.
   */
  async classifyText(
    aiClassificationRequest: any,
    priority: number,
  ): Promise<AiSequenceClassificationResponse> {
    if (this.isDummy)
      return {
        output: [{ label: "dummy", score: 0 }],
      };

    const queue = this._queueName(
      CLASSIFICATION_PREFIX,
      aiClassificationRequest.model,
    );

    const responseText = (await this._sendRMQData(
      queue,
      JSON.stringify(aiClassificationRequest),
      true,
      this._roundPriority(priority),
    )) as string;

    // NOTE(review): the reply is trusted to match AiSequenceClassificationResponse.
    return JSON.parse(responseText);
  }
}

/**
 * Factory that constructs a HydraService and resolves only after its
 * asynchronous broker setup (`_asyncCreate`) has finished.
 */
export async function createHydraService(): Promise<HydraService> {
  const service = new HydraService();
  return service._asyncCreate().then(() => service);
}
