From eccaeef87c5da12aa8fa7af8b4243e825f08134f Mon Sep 17 00:00:00 2001
From: Riko Uphoff <riko.uphoff@student.uni-halle.de>
Date: Fri, 4 Apr 2025 16:15:03 +0200
Subject: [PATCH] Added scheduling

---
 main.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/main.py b/main.py
index 3972801..de77c26 100644
--- a/main.py
+++ b/main.py
@@ -1,3 +1,5 @@
+from torch.optim.lr_scheduler import LinearLR
+
 from load_data import load_data
 from load_models import get_model
 from load_optimizers import get_optimizer
@@ -85,9 +87,10 @@ if __name__ == "__main__":
     num_batches = len(dataloader) if args.mode == "finetuning" else ceil(args.num_training_tokens / args.batch_size)
     num_steps = args.num_epochs * num_batches
 
-    scheduler = get_scheduler(
-        optimizer, args.lr_scheduler, args.warm_up_fraction, num_steps, args.lr, args.lr_min
-    )
+    # scheduler = get_scheduler(
+    #     optimizer, args.lr_scheduler, args.warm_up_fraction, num_steps, args.lr, args.lr_min
+    # )
+    scheduler = LinearLR(optimizer, args.lr, args.lr_min, num_steps) # TODO
 
     trained_model = train(device, accelerator, scheduler, model, optimizer, dataloader, num_epochs=args.num_epochs)
 
-- 
GitLab
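
Reviewer note on the TODO line: torch.optim.lr_scheduler.LinearLR does not take
absolute learning rates. Its positional parameters after the optimizer are
start_factor, end_factor, and total_iters, where the two factors are multipliers
applied to the base lr the optimizer was constructed with (both must lie in
(0, 1]). Passing args.lr and args.lr_min positionally therefore treats them as
factors, which silently produces the wrong schedule. Below is a minimal sketch of
a linear decay from args.lr down to args.lr_min, assuming args.lr is the base lr
the optimizer was built with; the stand-in values and the AdamW/Linear setup are
illustrative, not part of the patched main.py:

    import torch
    from torch.optim.lr_scheduler import LinearLR

    # Stand-ins for args.lr, args.lr_min, and num_steps from main.py.
    lr, lr_min, num_steps = 1e-3, 1e-5, 1000

    # Stand-in model/optimizer; the optimizer's base lr is what LinearLR scales.
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    scheduler = LinearLR(
        optimizer,
        start_factor=1.0,         # begin at the full base lr
        end_factor=lr_min / lr,   # end at lr_min, expressed as a fraction of the base lr
        total_iters=num_steps,    # interpolate linearly over the whole training run
    )

    # One scheduler.step() per optimizer step, as in a typical training loop.
    for _ in range(num_steps):
        optimizer.step()
        scheduler.step()

Note also that, unlike the commented-out get_scheduler call, this change drops the
warm-up phase (args.warm_up_fraction): LinearLR alone only interpolates between
the two factors.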