package gpu

import (
	"context"
	"github.com/NovelAI/interfaces/anlatan/completion"
	"github.com/wbrown/novelai-research-tool/gpt-bpe"
	"google.golang.org/grpc"
	"io"
	"log"
)

// GenerationParams holds the sampling and length settings that buildRequest
// uses to populate a completion request.
type GenerationParams struct {
	// Temperature is sent as the sampling temperature.
	Temperature float64
	// OutputLength is the number of tokens requested on top of the prompt;
	// buildRequest adds the prompt's token count to it to form MaxTokens.
	OutputLength uint32
	// PresencePenalty and FrequencyPenalty are sent as the request's
	// frequency parameters.
	PresencePenalty, FrequencyPenalty float64
	// LogProbs is sent as the engine's Logprobs setting.
	LogProbs uint32
	// TopP is sent as the sampling TopP value.
	TopP float64
	// TopK is sent as the sampling TopK value.
	TopK uint32
	// TFS is sent as the sampling TailFreeSampling value.
	TFS float64
}

// buildRequest assembles a completion.Request for prompt from the receiver's
// generation settings.
//
// The prompt is tokenized with gpt_bpe.Encoder only to count its tokens:
// MaxTokens is that count plus OutputLength, capped at 2048. Exactly one
// completion is requested. When echo is non-nil its value is converted to
// int32 and sent as the Echo index; otherwise Echo is left nil.
//
// The returned request holds a snapshot of the settings, so modifying
// genParams afterwards does not alter a request that was already built.
func (genParams *GenerationParams) buildRequest(prompt string,
	echo *int) *completion.Request {
	// Upper bound applied to prompt tokens plus requested output tokens.
	const maxTotalTokens = 2048

	// The request stores pointers to its parameters; point them at a private
	// copy rather than at the receiver's own fields.
	settings := *genParams

	tokens := gpt_bpe.Encoder.Encode(&prompt)
	numTokensPrompt := len(*tokens)

	// Start at the cap and only lower it when the sum provably fits, so a
	// very large OutputLength cannot wrap around uint32 and slip under it.
	maxTokens := uint32(maxTotalTokens)
	if numTokensPrompt < maxTotalTokens &&
		settings.OutputLength < maxTotalTokens-uint32(numTokensPrompt) {
		maxTokens = settings.OutputLength + uint32(numTokensPrompt)
	}

	completions := uint32(1)

	var echoParam *completion.Echo
	if echo != nil {
		echoVal := int32(*echo)
		echoParam = &completion.Echo{
			Index: &echoVal,
		}
	}

	return &completion.Request{
		Prompt: []*completion.Prompt{
			{Prompt: &completion.Prompt_Text{
				Text: prompt,
			}}},
		ModelParams: &completion.ModelParams{
			SamplingParams: &completion.SamplingParams{
				Order:            nil,
				Temperature:      &settings.Temperature,
				TopP:             &settings.TopP,
				TopK:             &settings.TopK,
				TailFreeSampling: &settings.TFS,
			},
			FrequencyParams: &completion.FrequencyParams{
				PresencePenalty:  &settings.PresencePenalty,
				FrequencyPenalty: &settings.FrequencyPenalty,
			},
			LogitBias: nil,
		},
		EngineParams: &completion.EngineParams{
			MaxTokens:   &maxTokens,
			Completions: &completions,
			Logprobs:    &settings.LogProbs,
			Echo:        echoParam,
		},
	}
}

// TestCompletion is a manual smoke test for the completion service. It dials
// localhost:50051 with grpc.WithInsecure, sends a single completion request
// for a fixed prompt, and logs each streamed response until the server ends
// the stream. Any dial, request, or receive error is reported with log.Fatal.
func TestCompletion() {
	// Connect to the completion service.
	serverAddr := "localhost:50051"
	conn, grpcErr := grpc.Dial(serverAddr, grpc.WithInsecure())
	if grpcErr != nil {
		log.Fatal(grpcErr)
	}
	defer conn.Close()
	client := completion.NewCompletionServiceClient(conn)

	genSettings := GenerationParams{
		Temperature:      1.0,
		OutputLength:     40,
		PresencePenalty:  0,
		FrequencyPenalty: 0,
		LogProbs:         30,
		TopP:             0,
		TopK:             0,
		TFS:              0.8,
	}

	rq := genSettings.buildRequest("This is a test prompt", nil)

	stream, rqErr := client.Completion(context.Background(), rq)
	if rqErr != nil {
		log.Fatal(rqErr)
	}
	for {
		answer, err := stream.Recv()
		if err == io.EOF {
			// The server closed the stream normally.
			break
		}
		if err != nil {
			// Without this check a failed stream would never hit the EOF
			// branch and the loop would spin forever logging nil answers.
			log.Fatal(err)
		}
		log.Printf("%v", answer)
	}
}
