import { HttpException, HttpStatus, Inject, Injectable } from "@nestjs/common";
import { InjectConnection, InjectRepository } from "@nestjs/typeorm";
import { User } from "../user/user.entity";
import {
  Connection,
  getConnection,
  LessThanOrEqual,
  Repository,
} from "typeorm";
import { AiModule } from "./ai-module.entity";
import {
  AiModuleTrainRequest,
  AiModuleDto,
  AiTrainingUpdate,
} from "./dto/ai-modules.dto";
import AWS = require("aws-sdk");
import { HydraService } from "./hydra.service";
import { paddleCreateSubscriptionCharge } from "../user/subscription/paddle.payment-processor";
import { PaymentProcessors } from "../user/subscription/payment-processors";
import { Cron, CronExpression } from "@nestjs/schedule";
import { StepCharge } from "./step-charge.entity";

/**
 * Rounds `num` to `decimalPlaces` decimal places (half rounds up, as with
 * `Math.round`).
 *
 * The decimal point is shifted by rewriting the exponent of the number's
 * string form instead of multiplying by a power of ten, which avoids binary
 * floating point artefacts such as `1.005 * 100 === 100.49999999999999`.
 *
 * @param num - Value to round.
 * @param decimalPlaces - Number of decimals to keep; negative values round to
 *   tens, hundreds, etc. Defaults to 0.
 * @returns The rounded value. Non-finite inputs (`NaN`, `±Infinity`) are
 *   returned unchanged.
 */
function round(num: number, decimalPlaces = 0): number {
  if (!Number.isFinite(num)) return num;

  // Shifts the decimal point of `value` by `places` digits. Very small or very
  // large numbers already stringify in exponent notation ("1e-7", "1.5e+21"),
  // so the existing exponent has to be merged with the shift rather than
  // blindly appending a second "e…" suffix (which would parse as NaN).
  const shift = (value: number, places: number): number => {
    const [mantissa, exponent = "0"] = value.toString().split("e");
    return Number(`${mantissa}e${Number(exponent) + places}`);
  };

  return shift(Math.round(shift(num, decimalPlaces)), -decimalPlaces);
}

/**
 * Manages user-owned AI training modules: creating training jobs (the training
 * data is uploaded to S3), applying status updates reported for those jobs,
 * soft-deleting modules, and charging / selling module training steps.
 */
@Injectable()
export class AIModuleService {
  // Only initialised when AWS_ENDPOINT_ADDR is set (see constructor); methods
  // that touch S3 will fail if it is not.
  private _s3: AWS.S3;

  constructor(
    @InjectRepository(AiModule)
    private trainingModulesRepository: Repository<AiModule>,

    @InjectRepository(User)
    private usersRepository: Repository<User>,

    @InjectRepository(StepCharge)
    private stepChargeRepository: Repository<StepCharge>,

    @InjectConnection()
    private connection: Connection,

    @Inject("GENERATION_SERVICE")
    private readonly hydraService: HydraService,
  ) {
    if (process.env.AWS_ENDPOINT_ADDR) {
      this._s3 = new AWS.S3({
        endpoint: new AWS.Endpoint(process.env.AWS_ENDPOINT_ADDR),
        credentials: {
          accessKeyId: process.env.AWS_ACCESS_KEY,
          secretAccessKey: process.env.AWS_SECRET_KEY,
        },
      });
    }
  }

  /**
   * Cron job (every minute): flags modules that have been in the "training"
   * state without an update for more than 10 minutes as failed, storing the
   * timeout reason in their `data` column.
   */
  @Cron(CronExpression.EVERY_MINUTE)
  async purgeStuckJobs() {
    const MAX_TIMEOUT_MINUTES_TRAINING = 10;

    const ERROR_REASON = "training node timeout";

    const MS_PER_MINUTE = 60000;

    // Anything last updated at or before this instant counts as stuck.
    const dateBeforeTraining = new Date(
      Date.now() - MAX_TIMEOUT_MINUTES_TRAINING * MS_PER_MINUTE,
    );

    await this.trainingModulesRepository.update(
      {
        status: "training",
        lastUpdatedAt: LessThanOrEqual(dateBeforeTraining),
      },
      {
        status: "error",
        data: Buffer.from(ERROR_REASON, "utf-8"),
      },
    );

    // const dateBeforePending = new Date(
    //   +new Date() - MAX_TIMEOUT_MINUTES_PENDING * MS_PER_MINUTE,
    // );

    // await this.trainingModulesRepository.update(
    //   {
    //     status: "pending",
    //     lastUpdatedAt: LessThanOrEqual(dateBeforePending),
    //   },
    //   {
    //     status: "error",
    //     data: Buffer.from(ERROR_REASON, "utf-8"),
    //   },
    // );
  }

  /**
   * Looks up a single non-deleted module by id, restricted to the given owner.
   *
   * @returns The module, or `null` when no matching row exists.
   */
  async getUserModule(user: User, id: string) {
    return this.trainingModulesRepository.findOneBy({
      owner: { id: user.id },
      id,
      deleted: false,
    });
  }

  /**
   * Builds the S3 object key under which a module's training data is stored:
   * `<model>_<id>_<createdAt in epoch seconds>`, prefixed with `dev_` unless
   * NODE_ENV is "production".
   */
  getStorageKey(aiModule: AiModule): string {
    const jobEnvironment = process.env.NODE_ENV === "production" ? "" : "dev_";
    return (
      jobEnvironment +
      aiModule.model +
      "_" +
      aiModule.id +
      "_" +
      Math.floor(aiModule.createdAt.getTime() / 1000)
    );
  }

  /**
   * Creates a new module in the "pending" state for `user` and uploads the
   * training payload (data, learning rate, steps) to S3 as JSON.
   *
   * @returns The DTO of the newly created module.
   * @throws HttpException (500) when the upload to S3 fails.
   */
  async trainModule(
    user: User,
    aiTrainingRequest: AiModuleTrainRequest,
  ): Promise<AiModuleDto> {
    const aiModule = this.trainingModulesRepository.create();

    aiModule.name = aiTrainingRequest.name;
    aiModule.description = aiTrainingRequest.description;
    aiModule.lr = aiTrainingRequest.lr;
    aiModule.steps = aiTrainingRequest.steps;
    aiModule.status = "pending";
    aiModule.owner = user;
    aiModule.lossHistory = [];
    aiModule.model = aiTrainingRequest.model;

    // Saved before the upload because the storage key depends on the
    // module's id and createdAt.
    await this.trainingModulesRepository.save(aiModule);

    const trainingData = {
      data: aiTrainingRequest.data,
      lr: aiTrainingRequest.lr,
      steps: aiTrainingRequest.steps,
    };
    const trainingDataKey = this.getStorageKey(aiModule);

    try {
      await this._s3
        .putObject({
          Bucket: process.env.AWS_BUCKET_NAME,
          Key: trainingDataKey,
          Body: JSON.stringify(trainingData),
          ContentType: "application/json; charset=utf-8",
        })
        .promise();
    } catch (ex) {
      console.error(ex); // for troubleshooting

      // The training data never reached storage, so the row saved above can
      // never be trained. Soft-delete it (best effort) so it does not linger
      // as a "pending" module in the user's list.
      try {
        await this.trainingModulesRepository.update(aiModule.id, {
          deleted: true,
        });
      } catch (cleanupEx) {
        console.error(cleanupEx);
      }

      throw new HttpException(
        "Data upload error.",
        HttpStatus.INTERNAL_SERVER_ERROR,
      );
    }

    /*if (!this.hydraService.isDummy)
      this.hydraService.sendTrainingJob("6B-v3", {
        id: aiModule.id,
        key: trainingDataKey,
        lr: aiModule.lr,
        steps: aiModule.steps,
      });*/

    return this.moduleToModuleDto(aiModule);
  }

  /**
   * Deletes the module's training data object from the S3 bucket.
   */
  async removeTrainingDataFromBucket(module: AiModule) {
    await this._s3
      .deleteObjects({
        Bucket: process.env.AWS_BUCKET_NAME,
        Delete: {
          Objects: [
            {
              Key: this.getStorageKey(module),
            },
          ],
        },
      })
      .promise();
  }

  /**
   * Soft-deletes a module: removes its training data from S3 first, then
   * flags the row as deleted (the row itself is kept).
   */
  async deleteModule(module: AiModule) {
    await this.removeTrainingDataFromBucket(module);
    await this.trainingModulesRepository.update(module.id, { deleted: true });
  }

  /**
   * Returns all non-deleted modules owned by `user`, mapped to DTOs.
   */
  async getAllUserModules(user: User): Promise<AiModuleDto[]> {
    const modules = await this.trainingModulesRepository.findBy({
      owner: { id: user.id },
      deleted: false,
    });

    return modules.map((item) => this.moduleToModuleDto(item));
  }

  /**
   * Applies a status update reported for a training job.
   *
   * - Unknown or deleted modules are ignored.
   * - A module that is already "ready" cannot be updated again.
   * - On transition to "ready" the owner is charged the module's steps.
   * - While "training", a JSON payload carrying a `loss` value is appended to
   *   the loss history; if the reported `percentage` went backwards compared
   *   to the previously stored payload, the history is reset first.
   *
   * @throws HttpException (409) when the module is already "ready".
   */
  async onTrainingUpdate(data: AiTrainingUpdate) {
    const id = data.id;
    const trainingModule = await this.trainingModulesRepository.findOne({
      relations: ["owner"],
      where: {
        id,
        deleted: false,
      },
    });
    if (!trainingModule) return;

    // Captured before the fields are overwritten below.
    const oldStatus = trainingModule.status;
    const oldData = trainingModule.data;

    if (oldStatus === "ready") {
      throw new HttpException(
        "Completed module training can't be updated again",
        HttpStatus.CONFLICT,
      );
    }

    trainingModule.status = data.status;
    // An empty payload keeps whatever data was stored previously.
    if (data.data.length > 0)
      trainingModule.data = Buffer.from(data.data, "utf-8");

    // NOTE(review): the charge result is ignored and the charge happens before
    // the module is saved — confirm this ordering is intended.
    if (data.status === "ready")
      await this.tryChargeTrainingSteps(
        trainingModule.owner,
        trainingModule.steps,
        "training_ready",
        trainingModule.id,
      );

    if (data.status === "training" && data.data.length > 0) {
      try {
        const parsedData = JSON.parse(data.data);

        if (parsedData.loss !== undefined) {
          if (
            typeof parsedData.percentage === "number" &&
            oldData &&
            oldData.length > 0
          ) {
            try {
              const previousData = JSON.parse(oldData.toString("utf-8"));

              // Progress went backwards: start the history over.
              if (typeof previousData.percentage === "number")
                if (previousData.percentage > parsedData.percentage)
                  trainingModule.lossHistory = [];
            } catch {
              // Previously stored data is not JSON; nothing to compare against.
            }
          }

          // lossHistory may be null/undefined (moduleToModuleDto guards
          // against that as well); without this the push would throw and be
          // swallowed by the catch below, silently dropping the loss value.
          if (!trainingModule.lossHistory) trainingModule.lossHistory = [];

          trainingModule.lossHistory.push(parsedData.loss);
        }
      } catch {
        // Best effort: payloads that are not valid JSON carry no loss value.
      }
    }

    await this.trainingModulesRepository.save(trainingModule);
  }

  /**
   * Maps a module entity to its DTO: `lastUpdatedAt` becomes epoch seconds,
   * `data` is decoded as UTF-8 and a missing loss history becomes `[]`.
   */
  moduleToModuleDto(module: AiModule): AiModuleDto {
    return {
      id: module.id,
      lastUpdatedAt: Math.floor(+module.lastUpdatedAt / 1000),
      name: module.name,
      description: module.description,
      status: module.status,
      lr: module.lr,
      steps: module.steps,
      data: module.data?.toString("utf-8"),
      lossHistory: module.lossHistory || [],
      model: module.model,
    };
  }

  /**
   * Deducts `steps` training steps from a user's balance inside a database
   * transaction. The fixed allowance (`availableModuleTrainingSteps`) is used
   * first; whatever remains is taken from `purchasedModuleTrainingSteps`.
   *
   * On success the passed-in `user` object is updated with the new balances,
   * its cache is flushed and, when `addLog` is true, a StepCharge record is
   * stored.
   *
   * @param user - User to charge; the balances are re-read from the database.
   * @param steps - Number of steps to charge; values <= 0 are a no-op.
   * @param type - Stored on the StepCharge log entry.
   * @param description - Stored on the StepCharge log entry.
   * @param addLog - Whether to write a StepCharge log entry (default true).
   * @returns `true` when the charge was applied (or nothing had to be
   *   charged), `false` when the user row could not be found.
   */
  async tryChargeTrainingSteps(
    user: User,
    steps: number,
    type: string,
    description: string,
    addLog = true,
  ): Promise<boolean> {
    if (steps <= 0) return true;

    const didCharge = await this.connection.transaction(async (manager) => {
      const _user = await manager.findOne(User, {
        where: {
          id: user.id,
        },
      });

      // The user row is gone; there is nothing to charge.
      if (!_user) return false;

      let stepBudget = steps;

      // Try to subtract from the fixed steps first
      if (_user.availableModuleTrainingSteps > 0) {
        // Subtract either the module step amount or the rest of the steps there are
        const amountToSubtractFromFixedSteps = Math.min(
          _user.availableModuleTrainingSteps,
          steps,
        );

        _user.availableModuleTrainingSteps -= amountToSubtractFromFixedSteps;
        stepBudget -= amountToSubtractFromFixedSteps;
      }

      // Still have some leftover steps? Subtract them from the purchased step amount
      // NOTE(review): there is no balance check, so the purchased balance can
      // become negative — confirm that is intended.
      if (stepBudget > 0) {
        _user.purchasedModuleTrainingSteps -= stepBudget;
      }

      await manager.update(
        User,
        { id: user.id },
        {
          purchasedModuleTrainingSteps: _user.purchasedModuleTrainingSteps,
          availableModuleTrainingSteps: _user.availableModuleTrainingSteps,
        },
      );
      // Keep the caller's in-memory user in sync with the stored balances.
      user.purchasedModuleTrainingSteps = _user.purchasedModuleTrainingSteps;
      user.availableModuleTrainingSteps = _user.availableModuleTrainingSteps;

      return true;
    });

    if (didCharge && addLog) {
      await this.stepChargeRepository.save(
        this.stepChargeRepository.create({
          amount: steps,
          owner: user,
          type,
          description,
        }),
      );
    }
    if (didCharge) {
      await user.flushCache();
    }

    return didCharge;
  }

  /**
   * Sells `steps` module training steps to a subscribed user by creating a
   * one-off charge on their Paddle subscription, then credits the steps to
   * `purchasedModuleTrainingSteps`.
   *
   * The charged amount is `2 + steps / 1111 - 0.01`, rounded to two decimals.
   *
   * @throws HttpException (409) when the user has no subscription, the
   *   subscription is not handled by Paddle, or the charge / bookkeeping
   *   fails.
   * @throws HttpException (400) when `steps` is not a positive finite number.
   */
  async tryBuyTrainingSteps(user: User, steps: number) {
    if (!user.hasSubscription())
      throw new HttpException(
        "Can't purchase module training steps while not subscribed.",
        HttpStatus.CONFLICT,
      );

    if (
      !user.subscriptionData ||
      user.subscriptionData.paymentProcessor != PaymentProcessors.PADDLE
    )
      throw new HttpException(
        "Incorrect subscription payment processor (A recurrent subscription is required).",
        HttpStatus.CONFLICT,
      );

    // Never bill for nothing: zero steps would still cost the base price and
    // a negative count would reduce the user's balance after charging them.
    if (!Number.isFinite(steps) || steps <= 0)
      throw new HttpException(
        "Invalid amount of module training steps.",
        HttpStatus.BAD_REQUEST,
      );

    const amount = round(2 + steps / 1111 - 0.01, 2);

    try {
      const data = await paddleCreateSubscriptionCharge(
        +user.subscriptionData.id,
        amount,
        `${steps} Module Training Steps`,
      );

      if (!data.success)
        throw new HttpException(
          data.error
            ? data.error.message
            : "The request was not successful yet no error was provided.",
          HttpStatus.CONFLICT,
        );

      user.purchasedModuleTrainingSteps += steps;
      await this.usersRepository.save(user);
    } catch (ex) {
      console.error(ex); // for troubleshooting

      // NOTE(review): this also catches the HttpException thrown above, so
      // clients always receive the generic message below; a failing save()
      // after a successful charge is reported the same way — confirm both
      // are intended.
      throw new HttpException("Payment processor error.", HttpStatus.CONFLICT);
    }
  }
}
