Commit b0c3c3cc authored by Riko Corwin Uphoff

Fixed UnexpectedKeywordError during training

parent 03dc8efa
Pipeline #25424 passed
@@ -86,8 +86,13 @@ def load_data_finetune(args, tokenizer):
         result["labels"] = batch["label"]
         return result
-    dataset = dataset.map(tokenize_function_finetune)
-    dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
+    dataset = dataset.map(
+        tokenize_function_finetune,
+        batched=True,
+        remove_columns=dataset["train"].column_names,
+        desc="Running tokenizer on dataset",
+    )
+    dataset.set_format(type="torch")
     eval_dataset = dataset["validation_matched" if args.dataset == "glue_mnli" else "validation"]
     train_dataset = dataset["train"]
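For context, a minimal self-contained sketch of the batched tokenization pattern introduced above. The dataset name, text column, and padding choice are illustrative assumptions, not taken from this repository:

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
dataset = load_dataset("glue", "sst2")  # assumed task; any GLUE-style dataset works

def tokenize_function_finetune(batch):
    result = tokenizer(batch["sentence"], padding="max_length", truncation=True)
    result["labels"] = batch["label"]
    return result

dataset = dataset.map(
    tokenize_function_finetune,
    batched=True,
    # Dropping the raw columns (text, label, idx, ...) means only the
    # tokenized tensor columns remain, so nothing unexpected is passed
    # on to the model's forward() when batches are built.
    remove_columns=dataset["train"].column_names,
    desc="Running tokenizer on dataset",
)
dataset.set_format(type="torch")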
@@ -33,11 +33,10 @@ def get_model(args):
         finetuning_task=arg_map[args.dataset][1],
     )
     if args.dtype == "bf16":
-        model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=num_labels,
-                                                                    torch_dtype=torch.bfloat16, config=config)
-    else:
-        model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=num_labels,
-                                                                    config=config)
+        model = AutoModelForSequenceClassification.from_pretrained("roberta-base", torch_dtype=torch.bfloat16,
+                                                                    config=config)
+    else:
+        model = AutoModelForSequenceClassification.from_pretrained("roberta-base", config=config)
     tokenizer = AutoTokenizer.from_pretrained("roberta-base")
     # elif args.model == "gpt2":
     #     model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=2)
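A minimal sketch of the model-construction change, under the assumption that the error came from passing num_labels alongside an already-built config: when a config instance is given to from_pretrained, leftover keyword arguments can be forwarded to the model constructor and trigger an unexpected-keyword TypeError. The label count and task name below are illustrative:

import torch
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

# num_labels and finetuning_task already live on the config, so they are
# not repeated as keyword arguments in from_pretrained below.
config = AutoConfig.from_pretrained("roberta-base", num_labels=3, finetuning_task="mnli")

model = AutoModelForSequenceClassification.from_pretrained(
    "roberta-base",
    torch_dtype=torch.bfloat16,  # bf16 branch; omit for the default dtype
    config=config,
)
tokenizer = AutoTokenizer.from_pretrained("roberta-base")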