Commit a5a95127 authored by Riko Corwin Uphoff

Tried to fix the unexpected "label" argument during fine-tuning

parent a245efdc
Pipeline #25412 passed
@@ -82,7 +82,9 @@ def load_data_finetune(args, tokenizer):
         texts = (
             (batch[sentence1_key],) if sentence2_key is None else (batch[sentence1_key], batch[sentence2_key])
         )
-        return tokenizer(*texts, truncation=True, padding="max_length", max_length=args.max_length)
+        result = tokenizer(*texts, truncation=True, padding="max_length", max_length=args.max_length)
+        result["labels"] = batch["label"]
+        return result

     dataset = dataset.map(tokenize_function_finetune)
     dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
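
A note on the change above: the forward() signatures of Hugging Face models accept a labels keyword, while GLUE datasets from the datasets library ship their gold column as label. Unpacking a raw batch straight into the model therefore raises TypeError: forward() got an unexpected keyword argument 'label', which is what the commit message refers to. Copying the column under the labels key at tokenization time, as the diff does, sidesteps that; note, though, that set_format still selects label, so if batches are fed via model(**batch) the column list would likely need labels instead. A minimal self-contained sketch of the pattern, assuming SST-2 as a stand-in task (the repo derives the text keys from sentence1_key/sentence2_key instead):

# Minimal sketch, assuming a GLUE-style dataset whose gold column is named
# "label"; "sst2", the tiny split, and the 128 max length are illustrative.
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
dataset = load_dataset("glue", "sst2", split="train[:32]")

def tokenize_function_finetune(batch):
    result = tokenizer(batch["sentence"], truncation=True,
                       padding="max_length", max_length=128)
    # forward() expects "labels", but GLUE ships the column as "label",
    # so copy it under the name the model understands.
    result["labels"] = batch["label"]
    return result

dataset = dataset.map(tokenize_function_finetune, batched=True)
# To unpack batches straight into model(**batch), keep "labels" here;
# selecting "label" would drop the copy made above.
dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])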
...
 from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
 import torch
+from load_data import arg_map

 def get_model(args):
     """ Creates model for Pretraining or Fine-Tuning """
@@ -24,18 +26,25 @@ def get_model(args):
     elif args.mode == "finetuning":
         if args.model == "roberta":
+            num_labels = 1 if args.dataset == "glue_stsb" else 2  # TODO might be wrong
+            config = AutoConfig.from_pretrained(
+                args.model_name_or_path,
+                num_labels=num_labels,
+                finetuning_task=arg_map[args.dataset][1],
+            )
             if args.dtype == "bf16":
-                model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2,
-                                                                           torch_dtype=torch.bfloat16)
+                model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=num_labels,
+                                                                           torch_dtype=torch.bfloat16, config=config)
             else:
-                model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2)
+                model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=num_labels,
+                                                                           config=config)
             tokenizer = AutoTokenizer.from_pretrained("roberta-base")
-        elif args.model == "gpt2":
-            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=2)
-            tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
-
-            tokenizer.pad_token = tokenizer.eos_token
-            model.config.pad_token_id = tokenizer.pad_token_id
+        # elif args.model == "gpt2":
+        #     model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=2)
+        #     tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
+        #
+        #     tokenizer.pad_token = tokenizer.eos_token
+        #     model.config.pad_token_id = tokenizer.pad_token_id
         else:
             raise ValueError("Invalid model name. Choose 'roberta' or 'gpt2'")
     else:
...
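
For orientation, a condensed sketch of the config-first loading pattern the new RoBERTa branch uses. The argument names mirror the diff, but get_finetune_model and the task-name derivation are hypothetical stand-ins for the repo's load_data.arg_map lookup. Two side notes: STS-B is a regression task, so num_labels=1 is in fact correct, while the else 2 fallback would be wrong for e.g. MNLI (three classes), which is presumably what the TODO hints at; and since num_labels already lives in the config, repeating it as a keyword next to config= is redundant (and, depending on the transformers version, may even be rejected), so the sketch passes the config alone.

# Condensed sketch of the pattern above; get_finetune_model and the
# task-name derivation are hypothetical, not the repo's actual helpers.
import torch
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

def get_finetune_model(dataset="glue_stsb", dtype="bf16",
                       model_name_or_path="roberta-base"):
    num_labels = 1 if dataset == "glue_stsb" else 2
    config = AutoConfig.from_pretrained(
        model_name_or_path,
        num_labels=num_labels,
        finetuning_task=dataset.removeprefix("glue_"),  # e.g. "stsb"
    )
    # num_labels is already carried by `config`, so it is not repeated here.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name_or_path,
        config=config,
        torch_dtype=torch.bfloat16 if dtype == "bf16" else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    return model, tokenizer

model, tokenizer = get_finetune_model()

The commented-out GPT-2 branch follows the same API; GPT-2 ships without a padding token, so reusing eos_token as pad and mirroring it into model.config.pad_token_id is the usual workaround once that branch is re-enabled.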