package gobot

import (
	"bytes"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"log"
	"math/rand"
	gobotHttp "nai-gobot/http"
	"nai-gobot/includes"
	"net/http"
	"strconv"
	"strings"

	"github.com/bwmarrin/discordgo"
)

// MakeImageGenerationRequest builds an ImageGenerationRequest from the options
// of an application-command interaction.
//
// Every field starts from the corresponding default in conf (the model is
// looked up in conf.ImageModels under conf.DefaultModel) and the seed is drawn
// from math/rand. Each recognised option then overrides its field; options
// with unrecognised names are ignored, and when a name repeats the last
// occurrence wins. The "model" option is resolved through conf.ImageModels.
func MakeImageGenerationRequest(conf *Config, data *discordgo.ApplicationCommandInteractionData) *includes.ImageGenerationRequest {
	req := &includes.ImageGenerationRequest{
		Model:          conf.ImageModels[conf.DefaultModel],
		Module:         conf.DefaultModule,
		Sampler:        conf.DefaultSampler,
		Steps:          conf.DefaultSteps,
		Width:          conf.DefaultWidth,
		Height:         conf.DefaultHeight,
		NumGenerations: conf.DefaultNumGen,
		Scale:          conf.DefaultScale,
		Strength:       conf.DefaultStrength,
		Noise:          conf.DefaultNoise,
		Seed:           rand.Intn(4294967295),
	}

	// Option name -> request field, grouped by the option's value type.
	textFields := map[string]*string{
		"prompt":  &req.Prompt,
		"sampler": &req.Sampler,
		"module":  &req.Module,
		"uc":      &req.Uc,
	}
	intFields := map[string]*int{
		"height":  &req.Height,
		"width":   &req.Width,
		"num_gen": &req.NumGenerations,
		"scale":   &req.Scale,
		"steps":   &req.Steps,
		"seed":    &req.Seed,
	}
	floatFields := map[string]*float64{
		// NOTE(review): the option named "floatArgs" sets Strength; confirm
		// this matches the option name registered for the slash command.
		"floatArgs": &req.Strength,
		"noise":     &req.Noise,
	}

	for _, opt := range data.Options {
		if opt.Name == "model" {
			req.Model = conf.ImageModels[opt.StringValue()]
			continue
		}
		if dst, found := textFields[opt.Name]; found {
			*dst = opt.StringValue()
			continue
		}
		if dst, found := intFields[opt.Name]; found {
			*dst = int(opt.IntValue())
			continue
		}
		if dst, found := floatFields[opt.Name]; found {
			*dst = opt.FloatValue()
		}
	}
	return req
}

// SendImage generates images for genReq against apiObject and posts them as a
// follow-up message to interaction i.
//
// The requesting user is i.Member.User when the interaction has a Member and
// i.User otherwise; the message mentions that user and echoes the request as
// "/draw <genReq.GetStr()>". Each returned image is base64-decoded and
// attached as generation_<index>.png. When allowAdvancedMode is true, the
// "strength", "noise" and "enhance" components from conf.Buttons are attached,
// each in its own action row.
//
// It returns an error if no requesting user can be determined (checked before
// any generation work is done), if generation fails, if an image cannot be
// decoded, or if the follow-up message cannot be created.
func SendImage(conf *Config, s *discordgo.Session, i *discordgo.InteractionCreate, genReq *includes.ImageGenerationRequest, apiObject *includes.APIEndpoint, allowAdvancedMode bool) error {
	var user *discordgo.User
	if i.Member != nil {
		user = i.Member.User
	} else {
		user = i.User
	}
	// Fail fast: user.Mention() below would panic on a nil user, and checking
	// here avoids paying for a generation whose result cannot be posted.
	if user == nil {
		return errors.New("could not determine the requesting user")
	}

	generations, err := gobotHttp.GenerateImage(*genReq, apiObject, conf.LogRequests)
	if err != nil {
		return err
	}

	discordFiles := make([]*discordgo.File, 0, len(generations.Images))
	for index, generation := range generations.Images {
		img, err := base64.StdEncoding.DecodeString(generation)
		if err != nil {
			return fmt.Errorf("an error occurred decoding the image: %w", err)
		}
		discordFiles = append(discordFiles, &discordgo.File{
			Name:   fmt.Sprintf("generation_%d.png", index),
			Reader: bytes.NewReader(img),
		})
	}

	// Non-nil even when empty, matching the original behavior.
	components := make([]discordgo.MessageComponent, 0, 3)
	if allowAdvancedMode {
		for _, id := range []string{"strength", "noise", "enhance"} {
			components = append(components, discordgo.ActionsRow{
				Components: []discordgo.MessageComponent{conf.Buttons[id]},
			})
		}
	}

	_, err = s.FollowupMessageCreate(i.Interaction, true, &discordgo.WebhookParams{
		Content:    fmt.Sprintf("Generated for %v --- `/draw %v`", user.Mention(), genReq.GetStr()),
		Files:      discordFiles,
		Components: components,
	})
	return err
}

// StrToImageGenerationRequest reconstructs an image generation request and the
// API endpoint it targets from the request's textual command form.
//
// After backticks are removed, str must begin with a five-byte command prefix
// (presumably "/draw" — confirm with callers), followed by
// "prompt: <prompt> [uc: <uc> ]sampler: ..." where the part starting at
// "sampler: " splits into exactly 22 space-separated tokens. Values are read
// from fixed token positions; NOTE(review): this assumes the label/value
// layout produced by includes.ImageGenerationRequest.GetStr — confirm when
// that format changes. The model is not parsed and stays at its zero value.
//
// If imgUrl is non-empty, the image at that URL is downloaded and stored
// base64-encoded in the request's Image field.
//
// It returns an "invalid format" error when str does not match the expected
// layout, and an error when the image cannot be fetched or read.
func StrToImageGenerationRequest(str string, imgUrl string) (*includes.ImageGenerationRequest, *includes.APIEndpoint, error) {
	const (
		promptLabel  = "prompt: "
		ucLabel      = "uc: "
		samplerLabel = "sampler: "
	)
	errFormat := errors.New("invalid format")

	// Remove the backticks of the Markdown code span the command is echoed in.
	str = strings.ReplaceAll(str, "`", "")
	if len(str) < 5 {
		return nil, nil, errFormat
	}
	// Skip the five-byte command prefix.
	rest := str[5:]

	promptIdx := strings.Index(rest, promptLabel)
	samplerBegin := strings.Index(rest, samplerLabel)
	if promptIdx < 0 || samplerBegin < 0 {
		return nil, nil, errFormat
	}
	promptEnd := promptIdx + len(promptLabel)

	// prompt and uc are free text that may contain spaces, so they are cut out
	// by label position. The byte just before the following label is the
	// separating space and is dropped. Bounds are validated so malformed input
	// yields an error instead of a slice-bounds panic.
	var prompt, uc string
	if ucBegin := strings.Index(rest, ucLabel); ucBegin >= 0 {
		ucEnd := ucBegin + len(ucLabel)
		if ucBegin-1 < promptEnd || samplerBegin-1 < ucEnd {
			return nil, nil, errFormat
		}
		uc = rest[ucEnd : samplerBegin-1]
		prompt = rest[promptEnd : ucBegin-1]
	} else {
		if samplerBegin-1 < promptEnd {
			return nil, nil, errFormat
		}
		prompt = rest[promptEnd : samplerBegin-1]
	}

	// Even indices are labels and odd indices are their values; the last token
	// keeps any remaining text because of SplitN.
	fields := strings.SplitN(rest[samplerBegin:], " ", 22)
	if len(fields) != 22 {
		return nil, nil, errFormat
	}
	req := includes.ImageGenerationRequest{
		Prompt:         prompt,
		Uc:             uc, // previously parsed but discarded when req was reassigned
		Sampler:        fields[1],
		NumGenerations: StrToInt(fields[5]),
		Steps:          StrToInt(fields[7]),
		Height:         StrToInt(fields[9]),
		Width:          StrToInt(fields[11]),
		Scale:          StrToInt(fields[13]),
		Seed:           StrToInt(fields[15]),
		Module:         fields[17],
		Strength:       StrToF64(fields[19]),
		Noise:          StrToF64(fields[21]),
	}
	endpoint := includes.APIEndpoint{
		URL:  fields[3],
		Type: "imagegen",
	}

	if imgUrl != "" {
		// NOTE(review): http.Get uses the default client, which has no
		// timeout; consider a client with a timeout for remote URLs.
		resp, err := http.Get(imgUrl)
		if err != nil {
			return nil, nil, fmt.Errorf("Error getting image: %w", err)
		}
		// Registered before the status check so the body is also closed on a
		// non-200 response.
		defer func(body io.ReadCloser) {
			if cerr := body.Close(); cerr != nil {
				log.Printf("error closing image: %v", cerr)
			}
		}(resp.Body)
		if resp.StatusCode != http.StatusOK {
			return nil, nil, errors.New("Getting image failed, status code: " + strconv.Itoa(resp.StatusCode))
		}
		imageBytes, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, nil, fmt.Errorf("Error reading image: %w", err)
		}
		req.Image = base64.StdEncoding.EncodeToString(imageBytes)
	}
	return &req, &endpoint, nil
}

// MessageComponentHandler applies a message-component interaction to genReq
// in place.
//
// The component's custom ID must be a key of conf.Buttons. For the "strength"
// and "noise" components, the first selected value is parsed with StrToF64 and
// stored in genReq.Strength or genReq.Noise respectively. For every recognised
// component, when both Height and Width are at most 512 they are doubled.
//
// It returns an error, leaving genReq unmodified, if i is not a
// message-component interaction, if the custom ID is not in conf.Buttons, or
// if a "strength"/"noise" interaction carries no selected value.
func MessageComponentHandler(conf *Config, i *discordgo.InteractionCreate, genReq *includes.ImageGenerationRequest) error {
	if i.Type != discordgo.InteractionMessageComponent {
		return errors.New("not a message component")
	}
	data := i.MessageComponentData()
	if _, ok := conf.Buttons[data.CustomID]; !ok {
		return errors.New("invalid button")
	}

	// Guard Values before indexing: an empty selection previously panicked.
	switch data.CustomID {
	case "strength":
		if len(data.Values) == 0 {
			return errors.New("no strength value selected")
		}
		genReq.Strength = StrToF64(data.Values[0])
	case "noise":
		if len(data.Values) == 0 {
			return errors.New("no noise value selected")
		}
		genReq.Noise = StrToF64(data.Values[0])
	}

	// NOTE(review): the resolution is doubled for every recognised component,
	// not only "enhance"; confirm this is intended.
	if genReq.Height <= 512 && genReq.Width <= 512 {
		genReq.Height *= 2
		genReq.Width *= 2
	}
	return nil
}
