diff --git a/load_data.py b/load_data.py
index 34eef833f857537762cec9b33e4261f40f59a3bd..241f2a73b294c145fe2f02f0db78d1d01059186a 100644
--- a/load_data.py
+++ b/load_data.py
@@ -82,7 +82,9 @@ def load_data_finetune(args, tokenizer):
         texts = (
             (batch[sentence1_key],) if sentence2_key is None else (batch[sentence1_key], batch[sentence2_key])
         )
-        return tokenizer(*texts, truncation=True, padding="max_length", max_length=args.max_length)
+        result = tokenizer(*texts, truncation=True, padding="max_length", max_length=args.max_length)
+        result["labels"] = batch["label"]
+        return result
 
     dataset = dataset.map(tokenize_function_finetune)
     dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
diff --git a/load_models.py b/load_models.py
index 4ae519380916818030caa017d450f0d7e456d725..42975c32114b9ea25d242f4d7d89f32121f2a936 100644
--- a/load_models.py
+++ b/load_models.py
@@ -1,6 +1,8 @@
 from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
 import torch
 
+from load_data import arg_map
+
 def get_model(args):
     """ Creates model for Pretraining or Fine-Tuning """
 
@@ -24,18 +26,25 @@ def get_model(args):
 
     elif args.mode == "finetuning":
         if args.model == "roberta":
+            num_labels = 1 if args.dataset == "glue_stsb" else 2  # TODO might be wrong
+            config = AutoConfig.from_pretrained(
+                args.model_name_or_path,
+                num_labels=num_labels,
+                finetuning_task=arg_map[args.dataset][1],
+            )
             if args.dtype == "bf16":
-                model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2,
-                                                                           torch_dtype=torch.bfloat16)
+                model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=num_labels,
+                                                                           torch_dtype=torch.bfloat16, config=config)
             else:
-                model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2)
+                model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=num_labels,
+                                                                           config=config)
             tokenizer = AutoTokenizer.from_pretrained("roberta-base")
-        elif args.model == "gpt2":
-            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=2)
-            tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
-
-            tokenizer.pad_token = tokenizer.eos_token
-            model.config.pad_token_id = tokenizer.pad_token_id
+        # elif args.model == "gpt2":
+        #     model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=2)
+        #     tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
+        #
+        #     tokenizer.pad_token = tokenizer.eos_token
+        #     model.config.pad_token_id = tokenizer.pad_token_id
         else:
             raise ValueError("Invalid model name. Choose 'roberta' or 'gpt2'")
     else:
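
Note on the new AutoConfig call: arg_map is imported from load_data.py and indexed as arg_map[args.dataset][1] to obtain the task name passed to finetuning_task, while num_labels is set to 1 for STS-B (a regression task) and 2 otherwise. arg_map itself is defined outside this diff; a minimal sketch consistent with that usage (the keys and tuple layout below are assumptions, not the actual mapping) would be:

    # Hypothetical sketch -- the real arg_map lives in load_data.py and is
    # not shown in this diff. Assumed layout: dataset argument ->
    # (HF datasets path, task/config name), so that arg_map[args.dataset][1]
    # yields the GLUE task name passed to finetuning_task.
    arg_map = {
        "glue_sst2": ("glue", "sst2"),
        "glue_mrpc": ("glue", "mrpc"),
        "glue_stsb": ("glue", "stsb"),  # regression task -> num_labels = 1
    }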