From f60b182c0884d919503a47840401ed0ceb3fda1a Mon Sep 17 00:00:00 2001
From: Riko Uphoff <riko.uphoff@student.uni-halle.de>
Date: Fri, 4 Apr 2025 13:35:31 +0200
Subject: [PATCH] Fixed(?) load_data for glue tasks

---
 load_data.py | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/load_data.py b/load_data.py
index 24311a1..6c7981c 100644
--- a/load_data.py
+++ b/load_data.py
@@ -38,14 +38,31 @@ def load_data_finetune(args, tokenizer):
         "glue_qqp": ("glue", "qqp"),
         "glue_rte": ("glue", "rte"),
         "glue_sts-b": ("glue", "stsb"),
+        "glue_wnli": ("glue", "wnli")
+    }
+    task_to_keys = {
+        "glue_cola": ("sentence", None),
+        "glue_mnli": ("premise", "hypothesis"),
+        "glue_mrpc": ("sentence1", "sentence2"),
+        "glue_qnli": ("question", "sentence"),
+        "glue_qqp": ("question1", "question2"),
+        "glue_rte": ("sentence1", "sentence2"),
+        "glue_sst-2": ("sentence", None),
+        "glue_sts-b": ("sentence1", "sentence2"),
+        "glue_wnli": ("sentence1", "sentence2"),
     }
 
     if args.dataset not in arg_map:
         raise ValueError(f"Data set '{args.dataset}' not supported for mode 'finetuning'!")
     dataset = load_dataset(*arg_map[args.dataset])
+    # Extract useful text
+    sentence1_key, sentence2_key = task_to_keys[args.dataset]
+
     def tokenize_function_finetune(batch):
-        # FIXME This fails for GLUE MNLI
-        return tokenizer(batch["sentence"], truncation=True, padding="max_length", max_length=args.max_length)
+        texts = (
+            (batch[sentence1_key],) if sentence2_key is None else (batch[sentence1_key], batch[sentence2_key])
+        )
+        return tokenizer(*texts, truncation=True, padding="max_length", max_length=args.max_length)
 
     dataset = dataset.map(tokenize_function_finetune)
     dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
--
GitLab
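
The sketch below is not part of the patch; it is a minimal sanity check of the new pair-aware tokenization, replayed standalone for glue_mnli, the task the removed FIXME flagged. It assumes the Hugging Face datasets/transformers stack that load_data.py already uses; the checkpoint name, max_length value, and split name are illustrative assumptions, only the load_dataset call and the tokenizer call mirror the patched code.

# Standalone smoke test, not part of the patch: checks that a sentence-pair
# task tokenizes both columns. Assumes the Hugging Face `datasets` and
# `transformers` packages; the checkpoint and max_length are made-up values.
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint
dataset = load_dataset("glue", "mnli")                          # task the old FIXME pointed at

# Keys copied from the new task_to_keys entry for glue_mnli
sentence1_key, sentence2_key = "premise", "hypothesis"

def tokenize_function_finetune(batch):
    # Same logic as the patched function: one column for single-sentence
    # tasks, two columns for sentence-pair tasks.
    texts = (
        (batch[sentence1_key],) if sentence2_key is None else (batch[sentence1_key], batch[sentence2_key])
    )
    return tokenizer(*texts, truncation=True, padding="max_length", max_length=128)

batch = dataset["validation_matched"][:2]   # two examples are enough for a smoke test
encoded = tokenize_function_finetune(batch)
assert len(encoded["input_ids"]) == 2 and len(encoded["input_ids"][0]) == 128
print("pair tokenization OK")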