Commit 1ea710a6 authored by Riko Corwin Uphoff

Merged main (scheduler fix)

parents 06c069aa 2687e71a
Pipeline #25397 passed
@@ -17,9 +17,9 @@ def get_scheduler(
     warm_up_scheduler = ConstantLR(optimizer, 1.0, warm_up_steps)
     if scheduler_type == "constant":
-        annealing_scheduler = ConstantLR(optimizer, max_lr, annealing_steps)
+        annealing_scheduler = ConstantLR(optimizer, 1.0, annealing_steps)
     elif scheduler_type == "linear":
-        annealing_scheduler = LinearLR(optimizer, max_lr, min_lr, annealing_steps)
+        annealing_scheduler = LinearLR(optimizer, 1.0, min_lr / max_lr, annealing_steps)
     elif scheduler_type == "cosine":
         annealing_scheduler = CosineAnnealingLR(optimizer, annealing_steps, min_lr)
     else:
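For context on this fix: torch.optim's ConstantLR and LinearLR interpret their arguments as multiplicative factors applied to the optimizer's base learning rate, not as absolute learning rates, while CosineAnnealingLR takes an absolute eta_min. Once the base lr is taken from args.lr (see the next hunk), the factors have to be relative values, i.e. 1.0 and min_lr / max_lr. A minimal sketch of that interaction; the concrete numbers and step counts are illustrative, not taken from the repo:

```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ConstantLR, LinearLR

max_lr, min_lr, steps = 4e-4, 1e-8, 100  # illustrative values only

def make_opt():
    # Base lr is the real max_lr, matching default_lr = args.lr in the next hunk.
    return AdamW([torch.nn.Parameter(torch.zeros(1))], lr=max_lr)

# ConstantLR: factor multiplies the base lr, so 1.0 keeps lr == max_lr for the phase.
opt_c = make_opt()
constant = ConstantLR(opt_c, factor=1.0, total_iters=steps)

# LinearLR: start_factor/end_factor are also multipliers of the base lr,
# so a decay from max_lr down to min_lr needs end_factor = min_lr / max_lr.
opt_l = make_opt()
linear = LinearLR(opt_l, start_factor=1.0, end_factor=min_lr / max_lr, total_iters=steps)

for _ in range(steps):
    opt_l.step()
    linear.step()

print(opt_l.param_groups[0]["lr"])  # ~1e-8 (== min_lr) after the annealing phase
```
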
@@ -45,7 +45,7 @@ def load_galore_config(args):
 def get_optimizer(args, model):
     """Creates optimizer (GaLore, LoRa, or baseline AdamW)"""
-    default_lr = 1.0  # Will be scheduled by LRScheduler
+    default_lr = args.lr
     if args.optimizer == "baseline":
         return AdamW(model.parameters(), lr=default_lr, weight_decay=args.weight_decay), model
@@ -66,7 +66,7 @@ def get_optimizer(args, model):
     if args.optimizer == "lora":
         return AdamW(model.parameters(), lr=args.lr), model
     else:
-        galore_config = load_galore_config()
+        galore_config = load_galore_config(args)
         trainable_params = [p for p in model.parameters() if p.requires_grad and p.dim() > 1]
         param_groups = [
             {"params": trainable_params, **galore_config}
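The param group built here carries the GaLore projection settings alongside the parameters, which is how the galore-torch optimizers expect them. A plausible sketch of what load_galore_config(args) might return, assuming the repo builds on the galore-torch package and that the flag names match the run script below (--rank, --galore_alpha, --galore_T); the exact key names and mapping in the repo may differ:

```python
import torch
from galore_torch import GaLoreAdamW  # assumes the repo builds on the galore-torch package

def load_galore_config(args):
    """Hypothetical sketch: map the CLI flags onto galore-torch's per-group settings."""
    return {
        "rank": args.rank,                 # low-rank projection dimension (--rank)
        "scale": args.galore_alpha,        # GaLore scale factor (--galore_alpha)
        "update_proj_gap": args.galore_T,  # steps between projector updates (--galore_T)
        "proj_type": "std",
    }

class Args:  # stand-in for the parsed CLI flags in the run script below
    rank, galore_alpha, galore_T = 8, 2.0, 200
    lr, weight_decay = 1e-5, 0.0

model = torch.nn.Linear(16, 16)
# Only >1-D weight matrices get the low-rank treatment, as in the hunk above.
trainable_params = [p for p in model.parameters() if p.requires_grad and p.dim() > 1]
param_groups = [{"params": trainable_params, **load_galore_config(Args)}]
optimizer = GaLoreAdamW(param_groups, lr=Args.lr, weight_decay=Args.weight_decay)
```
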
 @echo off
 python main.py ^
---mode pretraining ^
+--mode finetuning ^
 --optimizer galore ^
---model llama_60m ^
---batch_size 8 ^
+--model roberta ^
+--dataset glue_cola ^
+--batch_size 32 ^
 --num_epochs 30 ^
 --max_length 512 ^
 --num_training_tokens 1000000 ^
 --shuffle false ^
 --dtype bf16 ^
---lr 4e-4 ^
---weight_decay 0.01 ^
---tmax 30 ^
---test true
\ No newline at end of file
+--lr_scheduler constant ^
+--lr 1e-5 ^
+--lr_min 1e-8 ^
+--warm_up_fraction 0 ^
+--weight_decay 0 ^
+--rank 8 ^
+--galore_alpha 2 ^
+--galore_T 200 ^
+--lora_alpha 8 ^
+--lora_dropout 0.1 ^
+--test false
\ No newline at end of file
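The updated script switches the run to fine-tuning RoBERTa on GLUE CoLA with the constant scheduler and no warm-up. A rough sketch of how these flags would plug into get_scheduler's two-phase setup; the SequentialLR chaining, the CoLA step count, and the step arithmetic are assumptions for illustration, not code from the repo:

```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ConstantLR, SequentialLR

# Stand-ins for the flags in the run script above.
num_epochs, batch_size = 30, 32
steps_per_epoch = 8551 // batch_size          # CoLA train split, roughly 267 steps/epoch
warm_up_fraction, lr = 0.0, 1e-5

total_steps = num_epochs * steps_per_epoch
warm_up_steps = max(int(warm_up_fraction * total_steps), 1)  # guard the zero-warm-up case
annealing_steps = total_steps - warm_up_steps

model = torch.nn.Linear(8, 2)
optimizer = AdamW(model.parameters(), lr=lr)  # base lr = args.lr after the fix above

warm_up = ConstantLR(optimizer, factor=1.0, total_iters=warm_up_steps)
# --lr_scheduler constant: factor 1.0 keeps lr at 1e-5 for the whole run,
# so --lr_min only matters for the linear/cosine options.
annealing = ConstantLR(optimizer, factor=1.0, total_iters=annealing_steps)
scheduler = SequentialLR(optimizer, [warm_up, annealing], milestones=[warm_up_steps])
```
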