From a5a951273182c32904f5504c1a2fc15f495e4830 Mon Sep 17 00:00:00 2001
From: Riko Uphoff <riko.uphoff@student.uni-halle.de>
Date: Mon, 7 Apr 2025 17:33:04 +0200
Subject: [PATCH] Fix unexpected keyword argument "label" during
 fine-tuning

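The batches produced by tokenize_function_finetune carried the raw
"label" column into the model's forward pass, which raised the
unexpected-argument error. The tokenize function now copies
batch["label"] into a "labels" key and set_format exposes that column,
since "labels" is the keyword that
AutoModelForSequenceClassification.forward expects for loss
computation. Roughly (a minimal sketch, assuming the batch layout used
elsewhere in this repo):

    outputs = model(input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"])  # HF computes the loss here
    loss = outputs.loss

On the model side, the RoBERTa classification head is now built from an
AutoConfig with a task-dependent num_labels (1 for the STS-B regression
target, 2 for the binary GLUE tasks handled so far); num_labels is kept
out of from_pretrained because the explicit config already carries it.
GPT-2 fine-tuning is commented out until it gets the same treatment.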
---
 load_data.py   |  6 ++++--
 load_models.py | 29 +++++++++++++++++++----------
 2 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/load_data.py b/load_data.py
index 34eef83..241f2a7 100644
--- a/load_data.py
+++ b/load_data.py
@@ -82,7 +82,9 @@ def load_data_finetune(args, tokenizer):
         texts = (
             (batch[sentence1_key],) if sentence2_key is None else (batch[sentence1_key], batch[sentence2_key])
         )
-        return tokenizer(*texts, truncation=True, padding="max_length", max_length=args.max_length)
+        result = tokenizer(*texts, truncation=True, padding="max_length", max_length=args.max_length)
+        result["labels"] = batch["label"]
+        return result
     
     dataset = dataset.map(tokenize_function_finetune)
-    dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
+    dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
diff --git a/load_models.py b/load_models.py
index 4ae5193..42975c3 100644
--- a/load_models.py
+++ b/load_models.py
@@ -1,6 +1,8 @@
 from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
 import torch
 
+from load_data import arg_map
+
 
 def get_model(args):
     """ Creates model for Pretraining or Fine-Tuning """
@@ -24,18 +26,25 @@ def get_model(args):
 
     elif args.mode == "finetuning":
         if args.model == "roberta":
+            num_labels = 1 if args.dataset == "glue_stsb" else 2  # STS-B is regression; TODO: 3-class tasks (e.g. MNLI) would need 3
+            config = AutoConfig.from_pretrained(
+                "roberta-base",
+                num_labels=num_labels,
+                finetuning_task=arg_map[args.dataset][1],
+            )
             if args.dtype == "bf16":
-                model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2,
-                                                                           torch_dtype=torch.bfloat16)
+                model = AutoModelForSequenceClassification.from_pretrained("roberta-base", config=config,
+                                                                           torch_dtype=torch.bfloat16)
             else:
-                model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2)
+                model = AutoModelForSequenceClassification.from_pretrained("roberta-base",
+                                                                           config=config)
             tokenizer = AutoTokenizer.from_pretrained("roberta-base")
-        elif args.model == "gpt2":
-            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=2)
-            tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
-
-            tokenizer.pad_token = tokenizer.eos_token
-            model.config.pad_token_id = tokenizer.pad_token_id
+        # elif args.model == "gpt2":
+        #     model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=2)
+        #     tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
+        #
+        #     tokenizer.pad_token = tokenizer.eos_token
+        #     model.config.pad_token_id = tokenizer.pad_token_id
         else:
-            raise ValueError("Invalid model name. Choose 'roberta' or 'gpt2'")
+            raise ValueError("Invalid model name. Only 'roberta' is currently supported; GPT-2 fine-tuning is disabled")
     else:
-- 
GitLab