diff --git a/load_data.py b/load_data.py
index 34eef833f857537762cec9b33e4261f40f59a3bd..241f2a73b294c145fe2f02f0db78d1d01059186a 100644
--- a/load_data.py
+++ b/load_data.py
@@ -82,7 +82,9 @@ def load_data_finetune(args, tokenizer):
         texts = (
             (batch[sentence1_key],) if sentence2_key is None else (batch[sentence1_key], batch[sentence2_key])
         )
-        return tokenizer(*texts, truncation=True, padding="max_length", max_length=args.max_length)
+        result = tokenizer(*texts, truncation=True, padding="max_length", max_length=args.max_length)
+        result["labels"] = batch["label"]
+        return result
     
     dataset = dataset.map(tokenize_function_finetune)
     dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
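
Note on this hunk: `set_format` with an explicit `columns` list filters everything else out of the tensors the dataset returns, so the new `labels` key would be dropped here unless the list is updated (or the original `label` column is left for `default_data_collator`, which renames `label` to `labels` on its own). A minimal sketch of the intended wiring, assuming a GLUE-style split with a `label` column; the dataset and tokenizer names are illustrative, not taken from this repo:

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
dataset = load_dataset("glue", "sst2", split="train")

def tokenize_fn(batch):
    result = tokenizer(batch["sentence"], truncation=True, padding="max_length", max_length=128)
    result["labels"] = batch["label"]
    return result

dataset = dataset.map(tokenize_fn, batched=True)
# "labels" has to appear in the columns list, otherwise set_format filters it
# out and the assignment above never reaches the DataLoader.
dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
print(dataset[0].keys())  # dict_keys(['input_ids', 'attention_mask', 'labels'])
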
diff --git a/load_models.py b/load_models.py
index 4ae519380916818030caa017d450f0d7e456d725..42975c32114b9ea25d242f4d7d89f32121f2a936 100644
--- a/load_models.py
+++ b/load_models.py
@@ -1,6 +1,8 @@
 from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
 import torch
 
+from load_data import arg_map
+
 
 def get_model(args):
     """ Creates model for Pretraining or Fine-Tuning """
@@ -24,18 +26,25 @@ def get_model(args):
 
     elif args.mode == "finetuning":
         if args.model == "roberta":
+            num_labels = 1 if args.dataset == "glue_stsb" else 2  # TODO: right for STS-B (regression) but MNLI needs 3 labels
+            config = AutoConfig.from_pretrained(
+                args.model_name_or_path,
+                num_labels=num_labels,
+                finetuning_task=arg_map[args.dataset][1],
+            )
             if args.dtype == "bf16":
-                model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2,
-                                                                           torch_dtype=torch.bfloat16)
+                model = AutoModelForSequenceClassification.from_pretrained("roberta-base", config=config,
+                                                                           torch_dtype=torch.bfloat16)
             else:
-                model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2)
+                model = AutoModelForSequenceClassification.from_pretrained("roberta-base",
+                                                                           config=config)
             tokenizer = AutoTokenizer.from_pretrained("roberta-base")
-        elif args.model == "gpt2":
-            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=2)
-            tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
-
-            tokenizer.pad_token = tokenizer.eos_token
-            model.config.pad_token_id = tokenizer.pad_token_id
+        # elif args.model == "gpt2":
+        #     model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=2)
+        #     tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
+        #
+        #     tokenizer.pad_token = tokenizer.eos_token
+        #     model.config.pad_token_id = tokenizer.pad_token_id
         else:
             raise ValueError("Invalid model name. Choose 'roberta' or 'gpt2'")
     else:
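
On the `num_labels` TODO: transformers sequence-classification heads treat num_labels=1 as regression (MSE loss), which is what STS-B needs, but the binary fallback is wrong for MNLI, which has three classes. A sketch of the full GLUE mapping, assuming arg_map[args.dataset][1] yields standard GLUE task names; glue_num_labels is a hypothetical helper, not part of this repo:

# Standard GLUE label counts (task names assumed to match arg_map[...][1]).
GLUE_NUM_LABELS = {
    "cola": 2, "sst2": 2, "mrpc": 2, "qqp": 2,
    "qnli": 2, "rte": 2, "wnli": 2,
    "mnli": 3,   # entailment / neutral / contradiction
    "stsb": 1,   # regression: num_labels=1 selects MSELoss in transformers
}

def glue_num_labels(task_name: str) -> int:
    return GLUE_NUM_LABELS[task_name]

# e.g. num_labels = glue_num_labels(arg_map[args.dataset][1])
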