diff --git a/main.py b/main.py
index 3972801d6776c456ba62ec87fbe729113ab6e691..de77c26658971ea8b41536f2db76d8ce3dd8db7c 100644
--- a/main.py
+++ b/main.py
@@ -1,3 +1,5 @@
+from torch.optim.lr_scheduler import LinearLR
+
 from load_data import load_data
 from load_models import get_model
 from load_optimizers import get_optimizer
@@ -85,9 +87,10 @@ if __name__ == "__main__":
 
     num_batches = len(dataloader) if args.mode == "finetuning" else ceil(args.num_training_tokens / args.batch_size)
     num_steps = args.num_epochs * num_batches
 
-    scheduler = get_scheduler(
-        optimizer, args.lr_scheduler, args.warm_up_fraction, num_steps, args.lr, args.lr_min
-    )
+    # scheduler = get_scheduler(
+    #     optimizer, args.lr_scheduler, args.warm_up_fraction, num_steps, args.lr, args.lr_min
+    # )
+    scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=args.lr_min / args.lr, total_iters=num_steps)  # TODO
 
     trained_model = train(device, accelerator, scheduler, model, optimizer, dataloader, num_epochs=args.num_epochs)
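
Note on the patched scheduler line: torch.optim.lr_scheduler.LinearLR takes multiplicative factors of the optimizer's base lr (start_factor, end_factor), not absolute learning rates, so the call passes end_factor=args.lr_min / args.lr rather than args.lr_min directly (this assumes the optimizer was constructed with lr=args.lr). A minimal standalone sketch of the behavior, with lr, lr_min, and num_steps standing in for args.lr, args.lr_min, and the step count computed above:

    import torch
    from torch.optim.lr_scheduler import LinearLR

    lr, lr_min, num_steps = 1e-3, 1e-5, 10       # stand-ins for args.lr, args.lr_min, num_steps
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=lr)

    # LinearLR multiplies the base lr by a factor interpolated linearly from
    # start_factor to end_factor over total_iters scheduler steps, so the lr
    # moves from lr (factor 1.0) down to lr_min (factor lr_min / lr).
    scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=lr_min / lr, total_iters=num_steps)

    for step in range(num_steps):
        optimizer.step()                         # dummy step; normally after loss.backward()
        scheduler.step()
        print(step, scheduler.get_last_lr()[0])  # 1e-3 decaying linearly toward 1e-5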