Skip to content
Snippets Groups Projects
Commit f60b182c authored by Riko Corwin Uphoff's avatar Riko Corwin Uphoff
Browse files

Fixed(?) load_data for glue tasks

parent 4e857330
No related branches found
No related tags found
No related merge requests found
......@@ -38,14 +38,31 @@ def load_data_finetune(args, tokenizer):
"glue_qqp": ("glue", "qqp"),
"glue_rte": ("glue", "rte"),
"glue_sts-b": ("glue", "stsb"),
"glue_wnli": ("glue", "wnli")
}
task_to_keys = {
"glue_cola": ("sentence", None),
"glue_mnli": ("premise", "hypothesis"),
"glue_mrpc": ("sentence1", "sentence2"),
"glue_qnli": ("question", "sentence"),
"glue_qqp": ("question1", "question2"),
"glue_rte": ("sentence1", "sentence2"),
"glue_sst-2": ("sentence", None),
"glue_sts-b": ("sentence1", "sentence2"),
"glue_wnli": ("sentence1", "sentence2"),
}
if args.dataset not in arg_map:
raise ValueError(f"Data set '{args.dataset}' not supported for mode 'finetuning'!")
dataset = load_dataset(*arg_map[args.dataset])
# Extract useful text
sentence1_key, sentence2_key = task_to_keys[args.dataset]
def tokenize_function_finetune(batch):
    """Tokenize one GLUE example (or batch) for fine-tuning.

    Uses the task's column names resolved into ``sentence1_key`` /
    ``sentence2_key`` from the enclosing scope: single-sentence tasks
    (e.g. CoLA, SST-2) tokenize one column, sentence-pair tasks
    (e.g. MNLI, MRPC, QQP) tokenize both columns as a pair.

    Returns the tokenizer encoding (input_ids, attention_mask, ...)
    padded/truncated to ``args.max_length``.
    """
    # Build a 1- or 2-tuple of text columns so a single tokenizer call
    # handles both single-sentence and sentence-pair tasks (fixes the
    # previous hard-coded batch["sentence"], which failed for MNLI etc.).
    texts = (
        (batch[sentence1_key],)
        if sentence2_key is None
        else (batch[sentence1_key], batch[sentence2_key])
    )
    return tokenizer(*texts, truncation=True, padding="max_length", max_length=args.max_length)
dataset = dataset.map(tokenize_function_finetune)
dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
......
0% Loading or an error occurred.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment