from torch import optim
import numpy as np
# BasedOptimizer: an AdamW wrapper whose learning rate follows a linear warmup, then a cosine anneal.

def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr):
    """Return the learning rate for ``step``: linear warmup, then cosine anneal.

    The rate ramps linearly from 0 to ``lr`` over the first ``warmup_steps``
    steps, then follows half a cosine wave from ``lr`` to ``end_lr`` over the
    next ``anneal_steps`` steps, and stays at ``end_lr`` afterwards.

    Args:
        step: Current step index. Values outside the schedule are clipped.
        warmup_steps: Length of the linear warmup. If ``<= 0`` the warmup is
            treated as already complete (full ``lr`` from the first step)
            instead of dividing by zero.
        anneal_steps: Length of the cosine anneal. If ``<= 0`` the rate jumps
            to ``end_lr`` on the first step past the warmup instead of
            dividing by zero.
        lr: Peak learning rate reached at the end of the warmup.
        end_lr: Learning rate reached at the end of the anneal.

    Returns:
        The scheduled learning rate for ``step``.
    """
    if warmup_steps > 0:
        warmup_percent = np.clip(step, 0, warmup_steps) / warmup_steps
    else:
        # A zero-length warmup has nothing to ramp through.
        warmup_percent = 1.0

    if anneal_steps > 0:
        anneal_percent = np.clip(step - warmup_steps, 0, anneal_steps) / anneal_steps
    else:
        # A zero-length anneal is a step function: peak lr up to and including
        # the last warmup step, end_lr from the next step on.
        anneal_percent = np.where(np.greater(step, warmup_steps), 1.0, 0.0)

    # Cosine schedule for annealing: the factor goes 0 -> 1 as anneal_percent goes 0 -> 1.
    anneal_factor = (1 - np.cos(np.pi * anneal_percent)) / 2
    return lr * warmup_percent - (lr - end_lr) * anneal_factor


class BasedOptimizer:
    """Wrap a torch optimizer and set its learning rate from ``lr_schedule``.

    The wrapped optimizer is created with ``lr=0``; every call to ``step``
    advances an internal step counter, recomputes the scheduled learning rate
    and writes it into all of the optimizer's parameter groups.

    Args:
        parameters: Parameters (or parameter groups) handed to the torch
            optimizer.
        config: Mapping of hyperparameters. ``"lr"`` is required. Optional
            keys and their defaults: ``"end_lr"`` (same as ``lr``),
            ``"warmup_steps"`` (1), ``"anneal_steps"`` (1), ``"total_steps"``
            (None), ``"weight_decay"`` (0), ``"tokens"`` (None), ``"epochs"``
            (None), ``"beta1"`` (0.9), ``"beta2"`` (0.95), ``"eps"`` (1e-4).
        optimizer: Name of the optimizer to build. Only ``"adamw"`` is
            supported.

    Raises:
        KeyError: If ``config`` has no ``"lr"`` entry.
        ValueError: If ``optimizer`` is not a supported name.
    """

    def __init__(self, parameters, config, optimizer):
        self.lr = config["lr"]
        self.end_lr = config.get("end_lr", self.lr)
        self.warmup_steps = config.get("warmup_steps", 1)
        self.anneal_steps = config.get("anneal_steps", 1)
        self.total_steps = config.get("total_steps", None)
        self.weight_decay = config.get("weight_decay", 0)
        # TODO: tokens and epochs should not live here. Compute the number of
        # steps elsewhere and pass that to BasedOptimizer instead.
        self.tokens = config.get("tokens", None)
        self.epochs = config.get("epochs", None)
        self.beta1 = config.get("beta1", 0.9)
        self.beta2 = config.get("beta2", 0.95)
        self.eps = config.get("eps", 1e-4)
        # Set once the schedule has reached its final value; after that the
        # parameter groups are no longer rewritten.
        self.max_lr = False
        self.curr_step = 0
        self.curr_lr = 0

        if optimizer == "adamw":
            # lr starts at 0; the real value is written by step().
            self.optimizer = optim.AdamW(
                parameters,
                lr=0,
                weight_decay=self.weight_decay,
                betas=(self.beta1, self.beta2),
                eps=self.eps,
            )
        else:
            # Fail here rather than with an AttributeError on the first step().
            raise ValueError(
                f"unsupported optimizer {optimizer!r}; expected 'adamw'"
            )

    def step(self):
        """Run one optimizer update, then advance the learning-rate schedule.

        The new learning rate is written after the wrapped optimizer has
        stepped, so the rate computed for step ``k`` is the one used by update
        ``k + 1`` (and the very first update runs with ``lr=0``).
        """
        self.optimizer.step()
        self.curr_step += 1
        self.curr_lr = lr_schedule(
            self.curr_step, self.warmup_steps, self.anneal_steps, self.lr, self.end_lr
        )

        if self.max_lr:
            return

        # The schedule may only be declared finished once the warmup is over:
        # the warmup ramp can pass exactly through end_lr, and freezing there
        # would pin the learning rate mid-warmup. The step-count check also
        # catches the end of the anneal when float rounding keeps curr_lr from
        # comparing equal to end_lr.
        warmup_done = self.curr_step >= self.warmup_steps
        schedule_done = self.curr_step >= self.warmup_steps + self.anneal_steps
        if warmup_done and (schedule_done or self.curr_lr == self.end_lr):
            print("max lr reached.")
            self.max_lr = True

        # Still applied on the step that finishes the schedule, so the final
        # value is written once before updates stop.
        for group in self.optimizer.param_groups:
            group["lr"] = self.curr_lr

    def zero_grad(self):
        """Clear the gradients held by the wrapped optimizer."""
        self.optimizer.zero_grad()

    def get_current_lr(self):
        """Return the learning rate computed by the most recent ``step`` call."""
        return self.curr_lr

    def print_info(self):
        """Print the schedule settings and, once stepping has begun, the current lr."""
        print(f"end_lr: {self.end_lr}")
        print(f"warmup_steps: {self.warmup_steps}")
        print(f"total_steps: {self.total_steps}")
        print(f"weight_decay: {self.weight_decay}")
        print(f"step: {self.curr_step}")
        if self.curr_step != 0:
            print(f"curr_lr: {self.get_current_lr()}")