package main

import (
	"context"
	"crypto/subtle"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/NovelAI/interfaces/anlatan/auth"
	"github.com/NovelAI/interfaces/anlatan/completion"
	"github.com/cockroachdb/cmux"
	"github.com/improbable-eng/grpc-web/go/grpcweb"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
)

// anlatanServer implements both the completion and the auth gRPC services
// registered in main, and proxies completion requests to compute nodes.
type anlatanServer struct {
	// Embedded generated service interfaces. Nothing in this file assigns
	// them, so they stay nil; NOTE(review): any service method not defined
	// on anlatanServer itself would dispatch to a nil interface.
	completion.CompletionServiceServer
	auth.AuthServiceServer

	// nodeSpec is the compute-node address spec read from NODE_SPEC in
	// main; Completion combines it with the request's EngineId.
	nodeSpec string

	// ComputeNodes caches one client per resolved node address; entries
	// are added by GetComputeNode.
	ComputeNodes map[string]*completion.CompletionServiceClient

	// authenticated is set by Authenticate and checked by Completion.
	// NOTE(review): this is a single flag for the whole server value, not
	// per client, and it is accessed without synchronization.
	authenticated bool
}

var (
	// lameStaticToken is the static bearer token Authenticate accepts. main
	// assigns it from STATIC_TOKEN (or a built-in fallback) before serving.
	lameStaticToken string
)

// Authenticate checks the caller's static bearer token against
// lameStaticToken (set in main from STATIC_TOKEN) and, on a match, marks the
// server as authenticated and replies Authorized: true.
//
// It returns an error when no bearer is supplied or when it does not match.
//
// NOTE(review): the authenticated flag lives on the shared server value, so
// one successful call authorizes every client of this process, and the write
// is not synchronized with concurrent Completion reads. Fixing that needs a
// per-request credential (e.g. a bearer in call metadata).
func (anls *anlatanServer) Authenticate(ctx context.Context,
	request *auth.AuthRequest) (*auth.AuthResponse, error) {
	// Best-effort CORS header for browser callers; failing to set it must not
	// turn an otherwise valid reply into an error.
	_ = grpc.SetHeader(ctx, metadata.Pairs("Access-Control-Allow-Origin", "*"))

	if request.StaticBearer == nil {
		log.Println("AUTH: Empty auth")
		return nil, errors.New("Empty authentication request")
	}

	// Constant-time comparison so response timing does not reveal how much
	// of the token a caller has guessed correctly.
	if subtle.ConstantTimeCompare([]byte(*request.StaticBearer),
		[]byte(lameStaticToken)) != 1 {
		log.Println("AUTH: Authentication failed")
		return nil, errors.New("Authentication failed")
	}

	anls.authenticated = true
	log.Println("AUTH: Authentication successful")
	return &auth.AuthResponse{Authorized: true}, nil
}

// countTokens sums, over every choice in ans, the length of that choice's
// Logprobs.Tokens logprob list; Completion uses the sum as the answer's token
// count. A choice whose Logprobs or Tokens message is absent contributes zero
// instead of panicking the server.
func countTokens(ans *completion.Answer) int {
	total := 0
	for _, choice := range ans.GetChoices() {
		// Generated protobuf getters are nil-receiver safe, so this chain
		// tolerates a compute node that omits logprobs.
		total += len(choice.GetLogprobs().GetTokens().GetLogprobs())
	}
	return total
}

// Completion proxies a streaming completion request to a compute node and
// relays every Answer back to the caller.
//
// The node address comes from resolveNodeAddress(nodeSpec, EngineId). When
// the upstream stream ends, latency and throughput are logged. It returns nil
// on a clean upstream EOF. Otherwise it returns the first error from the
// authentication check, node acquisition, the upstream call or stream, or
// sending to the client.
func (anls *anlatanServer) Completion(request *completion.Request,
	server completion.CompletionService_CompletionServer) error {
	// Best-effort CORS header; failing to set it should not fail the RPC.
	_ = server.SetHeader(metadata.Pairs("Access-Control-Allow-Origin", "*"))
	if !anls.authenticated {
		log.Println("COMPLETION: Unauthenticated attempt")
		return errors.New("Not authenticated!")
	}
	resolved := resolveNodeAddress(anls.nodeSpec, request.EngineId)
	log.Printf("COMPLETION: got request for engine %s, resolved to: %s",
		request.EngineId, resolved)
	computeNode, err := anls.GetComputeNode(resolved)
	if err != nil {
		log.Printf("ERROR ON NODE ACQUISITON: %v", err)
		return err
	}

	// Bind the upstream call to the client's stream. If the client goes away,
	// or we return early because Send failed, cancel tears down the upstream
	// stream instead of leaking it.
	ctx, cancel := context.WithCancel(server.Context())
	defer cancel()

	rqBegin := time.Now()
	stream, err := (*computeNode).Completion(ctx, request)
	if err != nil {
		log.Println(err)
		return err
	}

	tokens := 0
	var firstToken time.Time // stays zero until an answer carries tokens
	for {
		answer, recvErr := stream.Recv()
		if recvErr != nil {
			logCompletionStats(rqBegin, firstToken, tokens, recvErr)
			if errors.Is(recvErr, io.EOF) {
				return nil
			}
			return recvErr
		}
		n := countTokens(answer)
		if firstToken.IsZero() && n > 0 {
			firstToken = time.Now()
		}
		tokens += n
		if sendErr := server.Send(answer); sendErr != nil {
			return sendErr
		}
	}
}

// resolveNodeAddress builds the compute-node address for engineID. A spec
// containing a '%' verb is formatted with the engine ID. A spec without one,
// such as the default "localhost:50051", is a fixed address and is returned
// unchanged rather than becoming a garbled "%!(EXTRA ...)" target.
func resolveNodeAddress(spec, engineID string) string {
	if !strings.Contains(spec, "%") {
		return spec
	}
	return fmt.Sprintf(spec, engineID)
}

// logCompletionStats logs end-of-stream figures for one proxied completion:
// elapsed ms since begin, token count and rate, ms to the first token (-1 if
// no token arrived), and the status that ended the stream.
func logCompletionStats(begin, firstToken time.Time, tokens int, status error) {
	elapsed := time.Since(begin)
	tokensPerSecond := 0.0
	if secs := elapsed.Seconds(); secs > 0 {
		tokensPerSecond = float64(tokens) / secs
	}
	firstMs := int64(-1)
	if !firstToken.IsZero() {
		firstMs = firstToken.Sub(begin).Milliseconds()
	}
	log.Printf("COMPLETION: %d ms for %d tokens (%0.2f tokens/s), "+
		"%d ms to first token, status: %v",
		elapsed.Milliseconds(), tokens, tokensPerSecond, firstMs, status)
}

// computeNodesMu serializes access to anlatanServer.ComputeNodes. gRPC runs
// RPCs on separate goroutines, so concurrent Completion calls would otherwise
// read and write the map simultaneously, which is a data race and can abort
// the process.
//
// NOTE(review): it is package-level only to keep this fix local to
// GetComputeNode; a mutex field on anlatanServer would be the cleaner home.
var computeNodesMu sync.Mutex

// GetComputeNode returns the cached completion client for addr. On first use
// it dials addr (insecure transport) and caches a new client. The cache map
// is allocated on demand, so a server without one does not panic.
func (anls *anlatanServer) GetComputeNode(addr string) (*completion.CompletionServiceClient, error) {
	computeNodesMu.Lock()
	defer computeNodesMu.Unlock()

	if cached, ok := anls.ComputeNodes[addr]; ok {
		return cached, nil
	}

	// Without grpc.WithBlock, grpc.Dial returns immediately and connects in
	// the background, so holding the lock across it is cheap.
	// NOTE(review): WithInsecure is deprecated in newer grpc-go releases in
	// favour of insecure transport credentials.
	conn, err := grpc.Dial(addr, grpc.WithInsecure())
	if err != nil {
		return nil, err
	}
	if anls.ComputeNodes == nil {
		anls.ComputeNodes = make(map[string]*completion.CompletionServiceClient)
	}
	client := completion.NewCompletionServiceClient(conn)
	anls.ComputeNodes[addr] = &client
	return &client, nil
}

// newAnlatanServer returns an unauthenticated anlatanServer with an empty,
// ready-to-use compute-node cache. The caller still has to set nodeSpec.
func newAnlatanServer() *anlatanServer {
	return &anlatanServer{
		ComputeNodes: make(map[string]*completion.CompletionServiceClient),
	}
}

// main configures the proxy from the environment, then serves native gRPC
// and HTTP (grpc-web plus fallback handlers) on one listener split by cmux.
//
// Environment:
//
//	STATIC_TOKEN  bearer accepted by Authenticate (insecure default if unset)
//	PORT          if set, listen on 0.0.0.0:$PORT
//	BIND_HOST     listen address when PORT is unset (default localhost:8443)
//	NODE_SPEC     compute-node spec stored in anlatanServer.nodeSpec
//	              (default localhost:50051)
func main() {
	var exist bool
	if lameStaticToken, exist = os.LookupEnv("STATIC_TOKEN"); !exist {
		lameStaticToken = "gooseHonkHonk"
		// The fallback is visible in source, so anyone can authenticate.
		log.Println("WARNING: STATIC_TOKEN is not set; using the built-in default token")
	}
	var bindHost string
	if port, ok := os.LookupEnv("PORT"); ok {
		bindHost = fmt.Sprintf("0.0.0.0:%s", port)
	} else if bindHost, exist = os.LookupEnv("BIND_HOST"); !exist {
		bindHost = "localhost:8443"
	}
	nodeSpec, ok := os.LookupEnv("NODE_SPEC")
	if !ok {
		nodeSpec = "localhost:50051"
	}
	log.Printf("GRPC server listening on: %s", bindHost)
	log.Printf("Node Specs: '%s'", nodeSpec)

	grpcServer := grpc.NewServer()
	lis, err := net.Listen("tcp", bindHost)
	if err != nil {
		log.Fatal(err)
	}
	m := cmux.New(lis)
	// HTTP/2 connections with a gRPC content type go to the native server.
	grpcL := m.Match(cmux.HTTP2HeaderField("content-type",
		"application/grpc"))
	// Everything else falls through to the HTTP server below.
	httpL := m.Match(cmux.Any())

	wrappedGrpc := grpcweb.WrapServer(grpcServer,
		grpcweb.WithAllowedRequestHeaders([]string{"*"}))
	server := http.Server{
		Addr: bindHost,
		// Bound header reads (slowloris). No Read/WriteTimeout because
		// completion responses are long-lived streams.
		ReadHeaderTimeout: 10 * time.Second,
		Handler: http.HandlerFunc(func(resp http.ResponseWriter,
			req *http.Request) {
			// Permissive CORS so browser grpc-web clients can call in;
			// preflight requests are answered directly.
			resp.Header().Set("Access-Control-Allow-Origin", "*")
			resp.Header().Set("Access-Control-Allow-Headers", "*")
			if req.Method == http.MethodOptions {
				resp.WriteHeader(http.StatusOK)
				return
			}
			if wrappedGrpc.IsGrpcWebRequest(req) {
				wrappedGrpc.ServeHTTP(resp, req)
				return
			}
			// Fall back to other servers.
			http.DefaultServeMux.ServeHTTP(resp, req)
		}),
	}

	anls := newAnlatanServer()
	anls.nodeSpec = nodeSpec
	anls.ComputeNodes = make(map[string]*completion.CompletionServiceClient)

	completion.RegisterCompletionServiceServer(grpcServer, anls)
	auth.RegisterAuthServiceServer(grpcServer, anls)

	go func() {
		// http.Server.Serve always returns a non-nil error when it stops.
		if srvErr := server.Serve(httpL); srvErr != nil {
			log.Printf("HTTP server stopped: %v", srvErr)
		}
	}()
	go func() {
		// Without the multiplexer no new connection is accepted at all, so
		// exit rather than linger as a zombie process.
		if muxErr := m.Serve(); muxErr != nil {
			log.Fatalf("connection multiplexer stopped: %v", muxErr)
		}
	}()

	if srvErr := grpcServer.Serve(grpcL); srvErr != nil {
		log.Fatal(srvErr)
	}
}
