diff --git a/load_data.py b/load_data.py
index 24311a1e44fd2a0f4044214718623c75c8672afb..6c7981c7a5ec65f52bdf85f83b625a2be139c6e4 100644
--- a/load_data.py
+++ b/load_data.py
@@ -38,14 +38,33 @@ def load_data_finetune(args, tokenizer):
         "glue_qqp": ("glue", "qqp"),
         "glue_rte": ("glue", "rte"),
         "glue_sts-b": ("glue", "stsb"),
+        "glue_wnli": ("glue", "wnli")
+    }
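+    # Map each GLUE task to the column name(s) holding its input text; None means the task uses a single sentence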
+    task_to_keys = {
+        "glue_cola": ("sentence", None),
+        "glue_mnli": ("premise", "hypothesis"),
+        "glue_mrpc": ("sentence1", "sentence2"),
+        "glue_qnli": ("question", "sentence"),
+        "glue_qqp": ("question1", "question2"),
+        "glue_rte": ("sentence1", "sentence2"),
+        "glue_sst-2": ("sentence", None),
+        "glue_sts-b": ("sentence1", "sentence2"),
+        "glue_wnli": ("sentence1", "sentence2"),
     }
     if args.dataset not in arg_map:
         raise ValueError(f"Data set '{args.dataset}' not supported for mode 'finetuning'!")
     dataset = load_dataset(*arg_map[args.dataset])
 
+    # Look up which text column(s) this task provides as model input
+    sentence1_key, sentence2_key = task_to_keys[args.dataset]
+
     def tokenize_function_finetune(batch):
-        # FIXME This fails for GLUE MNLI
-        return tokenizer(batch["sentence"], truncation=True, padding="max_length", max_length=args.max_length)
+        texts = (
+            (batch[sentence1_key],) if sentence2_key is None else (batch[sentence1_key], batch[sentence2_key])
+        )
+        return tokenizer(*texts, truncation=True, padding="max_length", max_length=args.max_length)
     
     dataset = dataset.map(tokenize_function_finetune)
     dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])