From f60b182c0884d919503a47840401ed0ceb3fda1a Mon Sep 17 00:00:00 2001
From: Riko Uphoff <riko.uphoff@student.uni-halle.de>
Date: Fri, 4 Apr 2025 13:35:31 +0200
Subject: [PATCH] Fix load_data tokenization for GLUE tasks

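load_data_finetune always tokenized batch["sentence"], which fails for
GLUE tasks whose text lives in other columns (e.g. MNLI uses
premise/hypothesis). Add a task_to_keys mapping from each supported
GLUE task to its text column(s) and tokenize either a single sentence
or a sentence pair accordingly. Also register glue_wnli in arg_map.
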
---
 load_data.py | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/load_data.py b/load_data.py
index 24311a1..6c7981c 100644
--- a/load_data.py
+++ b/load_data.py
@@ -38,14 +38,31 @@ def load_data_finetune(args, tokenizer):
         "glue_qqp": ("glue", "qqp"),
         "glue_rte": ("glue", "rte"),
         "glue_sts-b": ("glue", "stsb"),
+        "glue_wnli": ("glue", "wnli")
+    }
+    task_to_keys = {
+        "glue_cola": ("sentence", None),
+        "glue_mnli": ("premise", "hypothesis"),
+        "glue_mrpc": ("sentence1", "sentence2"),
+        "glue_qnli": ("question", "sentence"),
+        "glue_qqp": ("question1", "question2"),
+        "glue_rte": ("sentence1", "sentence2"),
+        "glue_sst-2": ("sentence", None),
+        "glue_sts-b": ("sentence1", "sentence2"),
+        "glue_wnli": ("sentence1", "sentence2"),
     }
     if args.dataset not in arg_map:
         raise ValueError(f"Data set '{args.dataset}' not supported for mode 'finetuning'!")
     dataset = load_dataset(*arg_map[args.dataset])
 
+    # Look up the text column(s) this GLUE task provides
+    sentence1_key, sentence2_key = task_to_keys[args.dataset]
+
     def tokenize_function_finetune(batch):
-        # FIXME This fails for GLUE MNLI
-        return tokenizer(batch["sentence"], truncation=True, padding="max_length", max_length=args.max_length)
+        texts = (
+            (batch[sentence1_key],) if sentence2_key is None else (batch[sentence1_key], batch[sentence2_key])
+        )
+        return tokenizer(*texts, truncation=True, padding="max_length", max_length=args.max_length)
     
     dataset = dataset.map(tokenize_function_finetune)
     dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
-- 
GitLab