from basedformer import optimizer
import torch
from tqdm import tqdm
import wandb
import os
from pathlib import Path
# Schedule hyperparameters handed to basedformer's BasedOptimizer. Presumably
# peak LR, final LR, warmup length and anneal length respectively — the exact
# semantics are defined by basedformer.optimizer; TODO confirm there.
train_config = dict(
    lr=5e-4,
    end_lr=1e-4,
    warmup_steps=100,
    anneal_steps=90,
)

# A small Linear layer whose parameters are handed to the optimizer.
model = torch.nn.Linear(10, 100)
save_folder = "models/test_optimizer6"

# Resume from a saved optimizer when one exists, otherwise start a fresh
# "adamw" BasedOptimizer configured by train_config. The existence check looks
# for an "opt" subdirectory inside save_folder — presumably where
# BasedOptimizer.save() writes its state; TODO confirm against basedformer.
# Both the check and the load derive from the same pathlib path.
if (Path(save_folder) / "opt").is_dir():
    opt = optimizer.BasedOptimizer.load(model.parameters(), Path(save_folder))
else:
    opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")

wandb.init(project="opt-test", name="test")

# Loop runs up to (but not including) this step index.
TOTAL_STEPS = 100
# Loop index at which the optimizer state is written to save_folder.
SAVE_AT_STEP = 60

# Start from the optimizer's current step so a reloaded optimizer continues
# where its checkpoint left off instead of replaying from 0.
for step in tqdm(range(opt.curr_step, TOTAL_STEPS)):
    # NOTE(review): dry_run=True presumably advances the LR schedule without
    # updating parameters — semantics live in basedformer; TODO confirm.
    opt.step(dry_run=True)
    # current step gets iterated before the logging, so negate 1.
    logged_step = opt.curr_step - 1
    # tqdm.write prints above the progress bar instead of corrupting it.
    tqdm.write(f"Step {logged_step}: LR {opt.curr_lr}")
    # Log against the optimizer's own step so a resumed run is plotted at its
    # true position rather than restarting wandb's internal counter at 0.
    wandb.log({"lr": opt.curr_lr}, step=logged_step)
    if step == SAVE_AT_STEP:
        opt.save(Path(save_folder))
