diff --git a/load_data.py b/load_data.py index 24311a1e44fd2a0f4044214718623c75c8672afb..6c7981c7a5ec65f52bdf85f83b625a2be139c6e4 100644 --- a/load_data.py +++ b/load_data.py @@ -38,14 +38,31 @@ def load_data_finetune(args, tokenizer): "glue_qqp": ("glue", "qqp"), "glue_rte": ("glue", "rte"), "glue_sts-b": ("glue", "stsb"), + "glue_wnli": ("glue", "wnli") + } + task_to_keys = { + "glue_cola": ("sentence", None), + "glue_mnli": ("premise", "hypothesis"), + "glue_mrpc": ("sentence1", "sentence2"), + "glue_qnli": ("question", "sentence"), + "glue_qqp": ("question1", "question2"), + "glue_rte": ("sentence1", "sentence2"), + "glue_sst-2": ("sentence", None), + "glue_sts-b": ("sentence1", "sentence2"), + "glue_wnli": ("sentence1", "sentence2"), } if args.dataset not in arg_map: raise ValueError(f"Data set '{args.dataset}' not supported for mode 'finetuning'!") dataset = load_dataset(*arg_map[args.dataset]) + # Extract useful text + sentence1_key, sentence2_key = task_to_keys[args.dataset] + def tokenize_function_finetune(batch): - # FIXME This fails for GLUE MNLI - return tokenizer(batch["sentence"], truncation=True, padding="max_length", max_length=args.max_length) + texts = ( + (batch[sentence1_key],) if sentence2_key is None else (batch[sentence1_key], batch[sentence2_key]) + ) + return tokenizer(*texts, truncation=True, padding="max_length", max_length=args.max_length) dataset = dataset.map(tokenize_function_finetune) dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])