Commit a245efdc authored by Riko Corwin Uphoff

Fixed naming bugs

parent 23462f92
Pipeline #25408 passed
@@ -87,7 +87,7 @@ def load_data_finetune(args, tokenizer):
     dataset = dataset.map(tokenize_function_finetune)
     dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
-    eval_dataset = dataset["validation_matched" if args.task_name == "glue_mnli" else "validation"]
+    eval_dataset = dataset["validation_matched" if args.dataset == "glue_mnli" else "validation"]
     train_dataset = dataset["train"]

     return train_dataset, eval_dataset
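The split lookup here depends on the dataset: GLUE MNLI ships no split literally named "validation", which is why the key has to be special-cased. A minimal sketch of the split layout, assuming the Hugging Face datasets library (which this loader appears to use):

from datasets import load_dataset

# MNLI provides two validation sets, matched and mismatched;
# there is no plain "validation" split.
mnli = load_dataset("glue", "mnli")
print(list(mnli.keys()))
# ['train', 'validation_matched', 'validation_mismatched',
#  'test_matched', 'test_mismatched']

# Most other GLUE tasks expose a single "validation" split.
sst2 = load_dataset("glue", "sst2")
print(list(sst2.keys()))  # ['train', 'validation', 'test']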
@@ -10,14 +10,13 @@ from math import ceil
 from logger import init_csv, log_to_csv
 from accelerate import Accelerator
 import torch
-import evaluate
+from evaluate import load
 from args import args
 from time import perf_counter

-def train(accelerator, scheduler, model, optimizer, train_dataloader, eval_dataloader, num_epochs):
+def train(scheduler, model, optimizer, train_dataloader, eval_dataloader, num_epochs):
     """ training model """
-    model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)  # To GPU
     model.train()

     for epoch in range(num_epochs):
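The import change is the core of the naming fix: this module defines its own evaluate() function further down, and that definition rebinds the name evaluate, shadowing the Hugging Face evaluate package. A minimal sketch of the failure mode, using hypothetical toy names:

import evaluate  # the Hugging Face "evaluate" package

def evaluate(model):  # rebinding: "evaluate" now names this function
    ...

metric = evaluate.load("accuracy")
# AttributeError: 'function' object has no attribute 'load'

Importing only load avoids the collision without renaming the local function.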
@@ -43,7 +42,7 @@ def train(accelerator, scheduler, model, optimizer, train_dataloader, eval_dataloader, num_epochs):
             if batch_cnt % args.eval_every == 0:
                 train_loss = loss.item()
-                eval_loss, eval_metric = evaluate(accelerator, model, tokenizer, eval_dataloader)
+                eval_loss, eval_metric = evaluate(accelerator, model, eval_dataloader)
                 log_to_csv(epoch + 1, batch_cnt, compute_time, train_loss, eval_loss, eval_metric)
                 reset_peak_memory_stats()
@@ -57,7 +56,7 @@ def train(accelerator, scheduler, model, optimizer, train_dataloader, eval_dataloader, num_epochs):
 def evaluate(accelerator, model, eval_dataloader):
     if args.mode == "finetuning":
-        metric = evaluate.load(*arg_map[args.task_name])
+        metric = load(*arg_map[args.task_name])
     else:
         metric = None  # TODO: metric for pretraining
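With the bare load import, the metric lookup works even though the name evaluate is taken by the local function. A sketch of the underlying library call, assuming arg_map maps a task name to a ("glue", "<config>") pair; the mapping itself is not shown in this diff:

from evaluate import load

# Hypothetical entry mirroring what arg_map presumably contains:
# arg_map = {"glue_stsb": ("glue", "stsb"), ...}
metric = load("glue", "stsb")  # equivalent to load(*arg_map["glue_stsb"])
metric.add_batch(predictions=[0.1, 4.8], references=[0.0, 5.0])
print(metric.compute())  # Pearson and Spearman correlations for STS-B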
@@ -73,7 +72,7 @@ def evaluate(accelerator, model, eval_dataloader):
             break

     if args.mode == "finetuning":
-        is_regression = args.task_name == "glue_stsb"
+        is_regression = args.dataset == "glue_stsb"
         labels = batch["labels"]

         with torch.no_grad():
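STS-B is the only GLUE task with a continuous target, so the is_regression flag changes how predictions are read from the logits. A hedged sketch of the usual downstream pattern (the corresponding lines fall outside this hunk):

import torch

with torch.no_grad():
    outputs = model(**batch)
logits = outputs.logits
# Regression (STS-B): the head emits one score per example.
# Classification: take the highest-scoring class index.
predictions = logits.squeeze(-1) if is_regression else torch.argmax(logits, dim=-1)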
@@ -137,7 +136,9 @@ if __name__ == "__main__":
     )

     # Train
-    trained_model = train(accelerator, lr_scheduler, model, optimizer, train_dataloader, num_epochs=args.num_epochs)
+    trained_model = train(
+        lr_scheduler, model, optimizer, train_dataloader, eval_dataloader, args.num_epochs
+    )

     # Evaluate
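Since train() no longer receives the accelerator, the accelerator.prepare(...) call removed above presumably moves into the __main__ setup before this call. A minimal sketch of the usual Accelerate pattern (names assumed, not taken from this diff):

from accelerate import Accelerator

accelerator = Accelerator()
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
# Inside the training loop, accelerator.backward(loss) replaces loss.backward().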