diff --git a/main.py b/main.py
index 3972801d6776c456ba62ec87fbe729113ab6e691..de77c26658971ea8b41536f2db76d8ce3dd8db7c 100644
--- a/main.py
+++ b/main.py
@@ -1,3 +1,5 @@
+from torch.optim.lr_scheduler import LinearLR
+
 from load_data import load_data
 from load_models import get_model
 from load_optimizers import get_optimizer
@@ -85,9 +87,10 @@ if __name__ == "__main__":
 
     num_batches = len(dataloader) if args.mode == "finetuning" else ceil(args.num_training_tokens / args.batch_size)
     num_steps = args.num_epochs * num_batches
-    scheduler = get_scheduler(
-        optimizer, args.lr_scheduler, args.warm_up_fraction, num_steps, args.lr, args.lr_min
-    )
+    # scheduler = get_scheduler(
+    #     optimizer, args.lr_scheduler, args.warm_up_fraction, num_steps, args.lr, args.lr_min
+    # )
+    scheduler = LinearLR(optimizer, args.lr, args.lr_min, num_steps)  # TODO
 
     trained_model = train(device, accelerator, scheduler, model, optimizer, dataloader, num_epochs=args.num_epochs)