diff --git a/main.py b/main.py index 7342f3dc53d25b3cde0bdd4f9f7b51648cf597e6..2cc320355402aab72809d73cbe631a2f69d9cc69 100644 --- a/main.py +++ b/main.py @@ -74,7 +74,7 @@ if __name__ == "__main__": optimizer, model = get_optimizer(args, model) - num_steps = ceil(args.epochs * len(dataset) / args.batch_size) + num_steps = ceil(args.num_epochs * len(dataset) / args.batch_size) scheduler = get_scheduler( optimizer, args.lr_scheduler, args.warm_up_fraction, num_steps, args.lr, args.lr_min )