diff --git a/src/chat_model.py b/src/chat_model.py new file mode 100644 index 0000000000000000000000000000000000000000..87b41c089995303f3e046dd931a5127e2bed6efd --- /dev/null +++ b/src/chat_model.py @@ -0,0 +1,561 @@ +import torch +import json +from typing import Any, Union, Dict, Generator, List, Optional, Tuple +from threading import Thread +from transformers import GenerationConfig, TextIteratorStreamer + +from .config_parser import get_infer_args +from .load import dispatch_model, load_model_and_tokenizer +from .model_trainer import get_logits_processor +from .data_args import ( + DEFAULT_PROMPT_DICT, + ALPACA_PROMPT_DICT, + SQL_PROMPT_DICT, + Template, + Llama2Template, +) +from .loggings import get_logger + + +logger = get_logger(__name__) + +templates: Dict[str,
Template] = {} + +def get_template_and_fix_tokenizer( + name: str, tokenizer: "PreTrainedTokenizer" +) -> Template: + template = templates.get(name, None) + assert template is not None, "Template {} does not exist.".format(name) + + additional_special_tokens = template.stop_words + + if tokenizer.eos_token_id is None: + tokenizer.eos_token = "<|endoftext|>" + logger.info("Add eos token: {}".format(tokenizer.eos_token)) + + if tokenizer.pad_token_id is None: + if tokenizer.unk_token_id is not None: + tokenizer.pad_token = tokenizer.unk_token + else: + tokenizer.pad_token = tokenizer.eos_token + logger.info("Add pad token: {}".format(tokenizer.pad_token)) + + if name is None: + return None + + tokenizer.add_special_tokens( + dict(additional_special_tokens=additional_special_tokens), + replace_additional_special_tokens=False, + ) + return template + + +class ChatModel: + def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: + ( + model_args, + self.data_args, + finetuning_args, + self.generating_args, + ) = get_infer_args(args) + self.model, self.tokenizer = load_model_and_tokenizer( + model_args, finetuning_args + ) + self.tokenizer.padding_side = "left" + self.model = dispatch_model(self.model) + self.template = get_template_and_fix_tokenizer( + self.data_args.template, self.tokenizer + ) + self.system_prompt = self.data_args.system_prompt + + def process_args( + self, + query: str, + history: Optional[List[Tuple[str, str]]] = None, + system: Optional[str] = None, + **input_kwargs + ) -> Tuple[Dict[str, Any], int]: + system = system or self.system_prompt + + prompt, _ = self.template.encode_oneturn( + tokenizer=self.tokenizer, + query=query, + resp="", + history=history, + system=system, + ) + input_ids = torch.tensor([prompt], device=self.model.device) + prompt_length = len(input_ids[0]) + + do_sample = input_kwargs.pop("do_sample", None) + temperature = input_kwargs.pop("temperature", None) + top_p = input_kwargs.pop("top_p", None) + top_k = input_kwargs.pop("top_k", None) + repetition_penalty = input_kwargs.pop("repetition_penalty", None) + max_length = input_kwargs.pop("max_length", None) + max_new_tokens = input_kwargs.pop("max_new_tokens", None) + + generating_args = self.generating_args.to_dict() + generating_args.update( + dict( + do_sample=do_sample + if do_sample is not None + else generating_args["do_sample"], + temperature=temperature or generating_args["temperature"], + top_p=top_p or generating_args["top_p"], + top_k=top_k or generating_args["top_k"], + repetition_penalty=repetition_penalty + or generating_args["repetition_penalty"], + eos_token_id=[self.tokenizer.eos_token_id] + + self.tokenizer.additional_special_tokens_ids, + pad_token_id=self.tokenizer.pad_token_id, + ) + ) + + if max_length: + generating_args.pop("max_new_tokens", None) + generating_args["max_length"] = max_length + + if max_new_tokens: + generating_args.pop("max_length", None) + generating_args["max_new_tokens"] = max_new_tokens + + gen_kwargs = dict( + inputs=input_ids, + generation_config=GenerationConfig(**generating_args), + logits_processor=get_logits_processor(), + ) + + return gen_kwargs, prompt_length + + @torch.inference_mode() + def chat( + self, + query: str, + history: Optional[List[Tuple[str, str]]] = None, + system: Optional[str] = None, + **input_kwargs + ) -> Tuple[str, Tuple[int, int]]: + gen_kwargs, prompt_length = self.process_args( + query, history, system, **input_kwargs + ) + generation_output = self.model.generate(**gen_kwargs) + outputs = 
generation_output.tolist()[0][prompt_length:] + response = self.tokenizer.decode(outputs, skip_special_tokens=True) + response_length = len(outputs) + return response, (prompt_length, response_length) + + @torch.inference_mode() + def stream_chat( + self, + query: str, + history: Optional[List[Tuple[str, str]]] = None, + system: Optional[str] = None, + **input_kwargs + ) -> Generator[str, None, None]: + gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs) + streamer = TextIteratorStreamer( + self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True + ) + gen_kwargs["streamer"] = streamer + + thread = Thread(target=self.model.generate, kwargs=gen_kwargs) + thread.start() + + yield from streamer + + +def register_template( + name: str, + prefix: List[Union[str, Dict[str, str]]], + prompt: List[Union[str, Dict[str, str]]], + system: str, + sep: List[Union[str, Dict[str, str]]], + stop_words: Optional[List[str]] = [], + use_history: Optional[bool] = True, +) -> None: + template_class = Llama2Template if "llama2" in name else Template + templates[name] = template_class( + prefix=prefix, + prompt=prompt, + system=system, + sep=sep, + stop_words=stop_words, + use_history=use_history, + ) + + +r""" +Supports language model inference without histories. +""" +register_template( + name="vanilla", + prefix=[], + prompt=["{{query}}"], + system="", + sep=[], + use_history=False, +) + +r""" +Supports language model for mistral sqlcoder-7b +""" +register_template( + name="mistral", + prefix=["{{system}}"], + prompt=["[INST] {{query}} [/INST]"], + system="", + sep=[], +) + + +r""" +Default template. +""" +register_template( + name="default", + prefix=["{{system}}"], + prompt=["Human: {{query}}\nAssistant: "], + system=( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ), + sep=["\n"], +) + + +r""" +Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf + https://huggingface.co/meta-llama/Llama-2-13b-chat-hf + https://huggingface.co/meta-llama/Llama-2-70b-chat-hf +""" +register_template( + name="llama2", + prefix=["<<SYS>>\n{{system}}\n<</SYS>>\n\n"], + prompt=["[INST] {{query}} [/INST] "], + system=( + "You are a helpful, respectful and honest assistant. " + "Always answer as helpfully as possible, while being safe. " + "Your answers should not include any harmful, unethical, " + "racist, sexist, toxic, dangerous, or illegal content. " + "Please ensure that your responses are socially unbiased and positive in nature.\n" + "If a question does not make any sense, or is not factually coherent, " + "explain why instead of answering something not correct. " + "If you don't know the answer to a question, please don't share false information." + ), + sep=[], +) + +register_template( + name="llama3", + prefix=["<|start_header_id|>system<|end_header_id|>\n\n{{system}}<|eot_id|>\n"], + prompt=["<|start_header_id|>user<|end_header_id|>\n\n{{query}}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"], + system=( + "You are a helpful, respectful and honest assistant. " + "Always answer as helpfully as possible, while being safe. " + "Your answers should not include any harmful, unethical, " + "racist, sexist, toxic, dangerous, or illegal content. 
" + "Please ensure that your responses are socially unbiased and positive in nature.\n" + "If a question does not make any sense, or is not factually coherent, " + "explain why instead of answering something not correct. " + "If you don't know the answer to a question, please don't share false information." + ), + sep=[], +) + +r""" +Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2 + https://huggingface.co/ziqingyang/chinese-alpaca-2-7b +""" +register_template( + name="llama2_zh", + prefix=["<<SYS>>\n{{system}}\n<</SYS>>\n\n"], + prompt=["[INST] {{query}} [/INST] "], + system="You are a helpful assistant. 你是一个乐于助人的助手。", + sep=[], +) + + +r""" +Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff + https://github.com/ymcui/Chinese-LLaMA-Alpaca +""" +register_template( + name="alpaca", + prefix=["{{system}}"], + prompt=["### Instruction:\n{{query}}\n\n### Response:\n"], + system=( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request." + ), + sep=["\n\n"], +) + + +r""" +Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1 + https://huggingface.co/lmsys/vicuna-13b-delta-v1.1 +""" +register_template( + name="vicuna", + prefix=["{{system}}"], + prompt=["USER: {{query}} ASSISTANT: "], + system=( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ), + sep=[], +) + + +r""" +Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B +""" +register_template( + name="belle", + prefix=["{{system}}"], + prompt=["Human: {{query}}\n\nBelle: "], + system="", + sep=["\n\n"], +) + + +r""" +Supports: https://github.com/CVI-SZU/Linly +""" +register_template( + name="linly", + prefix=["{{system}}"], + prompt=["User: {{query}}\nBot: "], + system="", + sep=["\n"], +) + + +r""" +Supports: https://github.com/Neutralzz/BiLLa +""" +register_template( + name="billa", + prefix=["{{system}}"], + prompt=["Human: {{query}}\nAssistant: "], + system="", + sep=["\n"], +) + + +r""" +Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1 +""" +register_template( + name="ziya", + prefix=["{{system}}"], + prompt=[{"token": "<human>"}, ":{{query}}\n", {"token": "<bot>"}, ":"], + system="", + sep=["\n"], +) + + +r""" +Supports: https://huggingface.co/qhduan/aquilachat-7b +""" +register_template( + name="aquila", + prefix=["{{system}}"], + prompt=["Human: {{query}}###Assistant: "], + system=( + "A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions." + ), + sep=["###"], +) + + +r""" +Supports: https://huggingface.co/internlm/internlm-chat-7b +""" +register_template( + name="intern", + prefix=["{{system}}"], + prompt=["<|User|>:{{query}}", {"token": "<eoh>"}, "\n<|Bot|>:"], + system="", + sep=["\n"], + stop_words=["</s>", "<eoa>"], # internlm cannot replace eos token +) + + +r""" +Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat +Used for training and inference of the fine-tuned models. +""" +register_template( + name="baichuan", + prefix=["{{system}}"], + prompt=[ + {"token": "<reserved_102>"}, # user token + "{{query}}", + {"token": "<reserved_103>"}, # assistant token + ], + system="", + sep=[], + stop_words=[], +) + + +r""" +Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat +Used for inference of the original model. 
+""" +register_template( + name="baichuan_eval", + prefix=["{{system}}", {"token": "<reserved_102>"}], # user token + prompt=["{{query}}", {"token": "<reserved_103>"}], # assistant token + system="", + sep=[], + stop_words=["<reserved_102>"], # user token +) + +r""" +Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat + https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat +Used for training and inference of the fine-tuned models. +""" +register_template( + name="baichuan2", + prefix=["{{system}}"], + prompt=[ + {"token": "<reserved_106>"}, # user token + "{{query}}", + {"token": "<reserved_107>"}, # assistant token + ], + system="", + sep=[], +) + + +r""" +Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat + https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat +Used for inference of the original model. +""" +register_template( + name="baichuan2_eval", + prefix=["{{system}}", {"token": "<reserved_106>"}], # user token + prompt=["{{query}}", {"token": "<reserved_107>"}], # assistant token + system="", + sep=[], + stop_words=["<reserved_106>"], # user token +) + + +r""" +Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha + https://huggingface.co/HuggingFaceH4/starchat-beta + +""" +register_template( + name="starchat", + prefix=[{"token": "<|system|>"}, "\n{{system}}", {"token": "<|end|>"}], + prompt=[ + {"token": "<|user|>"}, + "\n{{query}}", + {"token": "<|end|>"}, + "\n", + {"token": "<|assistant|>"}, + ], + system="", + sep=["\n"], + stop_words=["<|end|>"], +) + + +r""" +Supports: https://huggingface.co/Qwen/Qwen-7B-Chat +""" +register_template( + name="chatml", + prefix=[{"token": "<|im_start|>"}, "system\n{{system}}", {"token": "<|im_end|>"}], + prompt=[ + {"token": "<|im_start|>"}, + "user\n{{query}}", + {"token": "<|im_end|>"}, + "\n", + {"token": "<|im_start|>"}, + "assistant\n", + ], + system="You are a helpful assistant.", + sep=["\n"], + stop_words=["<|im_end|>"], +) + + +r""" +Supports: https://huggingface.co/THUDM/chatglm2-6b +""" +register_template( + name="chatglm2", + prefix=[{"token": "[gMASK]"}, {"token": "sop"}, "{{system}}"], + prompt=["[Round {{idx}}]\n\n问:{{query}}\n\n答:"], + system="", + sep=["\n\n"], +) + + +r""" +Supports: https://huggingface.co/THUDM/chatglm3-6b +""" +register_template( + name="chatglm3", + prefix=[ + {"token": "[gMASK]"}, + {"token": "sop"}, + {"token": "<|system|>"}, + "\n", + "{{system}}", + ], + prompt=[ + {"token": "<|user|>"}, + "\n", + "{{query}}", + {"token": "<|assistant|>"}, + "\n", # add an extra newline to avoid error in ChatGLM's process_response method + ], + system=( + "You are ChatGLM3, a large language model trained by Zhipu.AI. " + "Follow the user's instructions carefully. Respond using markdown." + ), + sep=[], + stop_words=["<|user|>", "<|observation|>"], +) + +register_template( + name="chatglm3_raw", # the raw template for tool tuning + prefix=[ + {"token": "[gMASK]"}, + {"token": "sop"}, + {"token": "<|system|>"}, + "\n", + "{{system}}", + ], + prompt=[{"token": "<|user|>"}, "\n", "{{query}}", {"token": "<|assistant|>"}], + system=( + "You are ChatGLM3, a large language model trained by Zhipu.AI. " + "Follow the user's instructions carefully. Respond using markdown." 
+ ), + sep=[], + stop_words=["<|user|>", "<|observation|>"], +) + + +r""" +Supports: https://huggingface.co/xverse/XVERSE-13B-Chat +""" +register_template( + name="xverse", + prefix=["{{system}}"], + prompt=["Human: {{query}}\n\nAssistant: "], + system="", + sep=[], +) + + diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000000000000000000000000000000000000..64df5416ced8a9c8174a5560b2dd3167e43f5c53 --- /dev/null +++ b/src/config.py @@ -0,0 +1,225 @@ + +import os + +### path config +ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# ROOT_PATH = "/root/autodl-tmp" +# MODELS_PARENT_PATH = "/home/model_files/codellama/" +# DEFAULT_FT_MODEL_NAME = "CodeLlama-7b-Instruct-hf" +MODELS_PARENT_PATH = "/home/ambcj/BA/text2sql-sft/models" +DEFAULT_FT_MODEL_NAME = "Baichuan2-13B-Chat" +MODEL_PATH = os.path.join(MODELS_PARENT_PATH, DEFAULT_FT_MODEL_NAME) + +# MODEL_PATH = os.path.join(ROOT_PATH, "model") +ADAPTER_PATH = os.path.join(ROOT_PATH, "text2sql-sft/adapter") +MERGED_MODELS = os.path.join(ROOT_PATH, "text2sql-sft/merged_models") + +# DATA_PATH = "/root/autodl-tmp/data/spider/pre_processed_data" +# OUT_DIR= "/root/autodl-tmp/codellama" + +DATA_PATH = os.path.join(ROOT_PATH, "text2sql-sft/data") +PREDICTED_DATA_PATH = os.path.join(ROOT_PATH, "text2sql-sft/data/eval_data/dev_sql.json") +PREDICTED_OUT_FILENAME = "pred_sql.sql" +# OUT_DIR = os.path.join(DATA_PATH, "out_pred") +OUT_DIR = os.path.join(ROOT_PATH, "text2sql-sft/output/") + +## model constants +IGNORE_INDEX = -100 +DEFAULT_PAD_TOKEN = "[PAD]" +DEFAULT_EOS_TOKEN = "</s>" +DEFAULT_BOS_TOKEN = "<s>" +DEFAULT_UNK_TOKEN = "<unk>" + + +LOG_FILE_NAME = "trainer_log.jsonl" + +# head_state_dict,model save name +VALUE_HEAD_FILE_NAME = "value_head.bin" + +# output ,finetuning_args save_to_json name +FINETUNING_ARGS_NAME = "finetuning_args.json" + +# when prepare_model_for_training ,layer_norm_names +LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] +EXT2TYPE = {"csv": "csv", "json": "json", "jsonl": "json", "txt": "text"} + +# text2sql dataset information for processing sql data +# TODO: BIRD \ WiKiSQL \ ... 
+SQL_DATA_INFO = [ + { + "data_source": "spider", + "train_file": ["train_spider.json", "train_others.json"], + "dev_file": ["dev.json"], + "train_tables_file": "tables.json", + "dev_tables_file": "tables.json", + "db_id_name": "db_id", + "output_name": "query", + "is_multiple_turn": False, + } + # { + # "data_source": "bird", + # "train_file": ["train/train.json"], + # "dev_file": ["dev/dev.json"], + # "train_tables_file": "train/train_tables.json", + # "dev_tables_file": "dev/dev_tables.json", + # "db_id_name": "db_id", + # "output_name": "SQL", + # "is_multiple_turn": False, + # } + # , + # { + # "data_source": "chase", + # "train_file": ["Chase/chase_train.json"], + # "dev_file": ["Chase/chase_dev.json"], + # "tables_file": "Chase/chase_tables.json", + # "db_id_name": "database_id", + # "is_multiple_turn": True, + # } + # , + # { + # "data_source": "cosql_dataset", + # "train_file": ["sql_state_tracking/cosql_train.json"], + # "dev_file": ["sql_state_tracking/cosql_dev.json"], + # "tables_file": "tables.json", + # "db_id_name": "database_id", + # "is_multiple_turn": True, + # } + # , + # { + # { + # "data_source": "sparc", + # "train_file": ["train.json"], + # "train_tables_file": "tables.json", + # "dev_tables_file": "tables.json", + # "dev_file": ["dev.json"], + # "db_id_name": "database_id", + # "is_multiple_turn": True, + # "output_name": "query", + # } +] +CODE_REPRESENTATION_PROMPT = """\ +/* Given the following database schema: */\n{}\n\n""" +CR_INPUT_PROMPT = """\ +/* Answer the following: {}\n*/ +""" +INSTRUCTION_PROMPT = """\ +I want you to act as a SQL terminal in front of an example database, \ +you need only to return the sql command to me.Below is an instruction that describes a task, \ +Write a response that appropriately completes the request.\n" +##Instruction:\n{}\n""" +INPUT_PROMPT = "###Input:\n{}\n\n###Response:" + +ALPACA_PROMPT = """\ +Below is an instruction that describes a task , paired with an input that provides further context . Write a response that appropriately completes the request.\n \ +\n###Instruction:\nWrite a sql to answer the question "{}"\n\n""" +ALPACA_INPUT_PROMPT = "###Input:\n{}\n\n###Response:\n" + + +INSTRUCTION_ONE_SHOT_PROMPT = """\ +I want you to act as a SQL terminal in front of an example database. \ +You need only to return the sql command to me. \ +First, I will show you few examples of an instruction followed by the correct SQL response. \ +Then, I will give you a new instruction, and you should write the SQL response that appropriately completes the request.\ +\n### Example1 Instruction: +The database contains tables such as employee, salary, and position. \ +Table employee has columns such as employee_id, name, age, and position_id. employee_id is the primary key. \ +Table salary has columns such as employee_id, amount, and date. employee_id is the primary key. \ +Table position has columns such as position_id, title, and department. position_id is the primary key. \ +The employee_id of salary is the foreign key of employee_id of employee. 
\ +The position_id of employee is the foreign key of position_id of position.\ +\n### Example1 Input:\nList the names and ages of employees in the 'Engineering' department.\n\ +\n### Example1 Response:\nSELECT employee.name, employee.age FROM employee JOIN position ON employee.position_id = position.position_id WHERE position.department = 'Engineering';\ +\n###New Instruction:\n{}\n""" + +# EXAMPLES =[EXAMPLE1, EXAMPLE1] + +# EXAMPLE1 = "\n### Example1 Input:\nList the names and ages of employees in the 'Engineering' department.\n\ +# \n### Example1 Response:\nSELECT employee.name, employee.age FROM employee JOIN position ON employee.position_id = position.position_id WHERE position.department = 'Engineering';\ +# \n###New Instruction:\n{}\n" + +### test-------------------- + + +# METHODS = ["full", "freeze", "lora"] + +# STAGES = ["SFT", "Reward Modeling", "PPO", "DPO", "Pre-Training"] + +# DATASET_STAGE_MAP = { +# "SFT": "sft", +# "Pre-Training": "pt", +# "Reward Modeling": "rm", +# "PPO": "sft", +# "DPO": "rm", +# } + +# SUPPORTED_MODELS = { +# "LLaMA-7B": "huggyllama/llama-7b", +# "LLaMA-13B": "huggyllama/llama-13b", +# "LLaMA-30B": "huggyllama/llama-30b", +# "LLaMA-65B": "huggyllama/llama-65b", +# "LLaMA2-7B": "meta-llama/Llama-2-7b-hf", +# "LLaMA2-13B": "meta-llama/Llama-2-13b-hf", +# "LLaMA2-70B": "meta-llama/Llama-2-70b-hf", +# "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf", +# "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf", +# "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf", +# "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b", +# "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b", +# "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b", +# "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b", +# "BLOOM-560M": "bigscience/bloom-560m", +# "BLOOM-3B": "bigscience/bloom-3b", +# "BLOOM-7B1": "bigscience/bloom-7b1", +# "BLOOMZ-560M": "bigscience/bloomz-560m", +# "BLOOMZ-3B": "bigscience/bloomz-3b", +# "BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt", +# "Falcon-7B": "tiiuae/falcon-7b", +# "Falcon-7B-Chat": "tiiuae/falcon-7b-instruct", +# "Falcon-40B": "tiiuae/falcon-40b", +# "Falcon-40B-Chat": "tiiuae/falcon-40b-instruct", +# "Baichuan-7B": "baichuan-inc/Baichuan-7B", +# "Baichuan-13B": "baichuan-inc/Baichuan-13B-Base", +# "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat", +# "Baichuan2-7B": "baichuan-inc/Baichuan2-7B-Base", +# "Baichuan2-13B": "baichuan-inc/Baichuan2-13B-Base", +# "Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat", +# "Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat", +# "InternLM-7B": "internlm/internlm-7b", +# "InternLM-7B-Chat": "internlm/internlm-chat-7b", +# "Qwen-7B": "Qwen/Qwen-7B", +# "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat", +# "XVERSE-13B": "xverse/XVERSE-13B", +# "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b", +# "ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base", +# "ChatGLM3-6B-Chat": "THUDM/chatglm3-6b" +# } + +# DEFAULT_MODULE = { +# "LLaMA": "q_proj,v_proj", +# "LLaMA2": "q_proj,v_proj", +# "ChineseLLaMA2": "q_proj,v_proj", +# "BLOOM": "query_key_value", +# "BLOOMZ": "query_key_value", +# "Falcon": "query_key_value", +# "Baichuan": "W_pack", +# "Baichuan2": "W_pack", +# "InternLM": "q_proj,v_proj", +# "Qwen": "c_attn", +# "XVERSE": "q_proj,v_proj", +# "ChatGLM2": "query_key_value", +# "ChatGLM3": "query_key_value", + +# } + +# DEFAULT_TEMPLATE = { +# "LLaMA2": "llama2", +# "ChineseLLaMA2": "llama2_zh", +# "Baichuan": "baichuan", +# "Baichuan2": "baichuan2", +# "InternLM": "intern", +# "Qwen": "chatml", +# "ChatGLM2": "chatglm2", +# 
"ChatGLM3": "chatglm3", + +# } diff --git a/src/config_parser.py b/src/config_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..7ef8d04bdfcb1df8db922613b405a1175d28b94d --- /dev/null +++ b/src/config_parser.py @@ -0,0 +1,258 @@ +import os +import sys +import torch +import transformers +import datasets +from transformers.trainer import WEIGHTS_NAME +from transformers.modeling_utils import load_sharded_checkpoint +from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME +from transformers import HfArgumentParser, Seq2SeqTrainingArguments +from transformers.trainer_utils import get_last_checkpoint +from typing import Any, Dict, Optional, Tuple +from .loggings import get_logger +from .model_args import ( + ModelArguments, + FinetuningArguments, + GeneratingArguments, +) +from .data_args import DataArguments + + +logger = get_logger(__name__) + + +def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: + state_dict: Dict[str, torch.Tensor] = model.state_dict() + filtered_state_dict = {} + + for k, v in model.named_parameters(): + if v.requires_grad: + filtered_state_dict[k] = state_dict[k].cpu().clone().detach() + + return filtered_state_dict + + +def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool: + weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME) + if os.path.exists(weights_file): + model_state_dict = torch.load(weights_file, map_location="cpu") + model.load_state_dict(model_state_dict, strict=False) # skip missing keys + elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)): + load_sharded_checkpoint(model, checkpoint_dir, strict=False) + else: + logger.warning( + "Provided path ({}) does not contain pre-trained weights.".format( + checkpoint_dir + ) + ) + return False + return True + + +def _parse_args( + parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None +) -> Tuple[Any]: + if args is not None: + return parser.parse_dict(args) + elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): + return parser.parse_yaml_file(os.path.abspath(sys.argv[1])) + elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + return parser.parse_json_file(os.path.abspath(sys.argv[1])) + else: + return parser.parse_args_into_dataclasses() + + +def parse_train_args( + args: Optional[Dict[str, Any]] = None +) -> Tuple[ + ModelArguments, + DataArguments, + Seq2SeqTrainingArguments, + FinetuningArguments, + GeneratingArguments, +]: + parser = HfArgumentParser( + ( + ModelArguments, + DataArguments, + Seq2SeqTrainingArguments, + FinetuningArguments, + GeneratingArguments, + ) + ) + return _parse_args(parser, args) + + +def parse_infer_args( + args: Optional[Dict[str, Any]] = None +) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]: + parser = HfArgumentParser( + (ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments) + ) + return _parse_args(parser, args) + + +def get_train_args( + args: Optional[Dict[str, Any]] = None, data_args_init: bool = True +) -> Tuple[ + ModelArguments, + DataArguments, + Seq2SeqTrainingArguments, + FinetuningArguments, + GeneratingArguments, +]: + ( + model_args, + data_args, + training_args, + finetuning_args, + generating_args, + ) = parse_train_args(args) + + # Setup logging + if training_args.should_log: + # The default of training_args.log_level is passive, so we set log level at info here to have that default. 
+ transformers.utils.logging.set_verbosity_info() + + log_level = training_args.get_process_log_level() + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Check arguments (do not check finetuning_args since it may be loaded from checkpoints) + if data_args_init: + data_args.init_for_training() + + if training_args.max_steps == -1 and data_args.streaming: + raise ValueError("Please specify `max_steps` in streaming mode.") + + if data_args.val_size > 1e-6 and data_args.val_size < 1 and data_args.streaming: + raise ValueError("Streaming mode should have an integer val size.") + + if training_args.do_train and training_args.predict_with_generate: + raise ValueError( + "`predict_with_generate` cannot be set as True while training." + ) + + if ( + training_args.do_train + and finetuning_args.finetuning_type == "lora" + and finetuning_args.lora_target is None + ): + raise ValueError("Please specify `lora_target` in LoRA training.") + + if ( + model_args.quantization_bit is not None + and finetuning_args.finetuning_type != "lora" + ): + raise ValueError("Quantization is only compatible with the LoRA method.") + + if model_args.checkpoint_dir is not None: + if finetuning_args.finetuning_type != "lora": + if len(model_args.checkpoint_dir) != 1: + raise ValueError("Only LoRA tuning accepts multiple checkpoints.") + elif ( + model_args.quantization_bit is not None + and len(model_args.checkpoint_dir) != 1 + ): + raise ValueError("Quantized model only accepts a single checkpoint.") + + if model_args.quantization_bit is not None and (not training_args.do_train): + logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.") + + if training_args.do_train and (not training_args.fp16) and (not training_args.bf16): + logger.warning("We recommend enable mixed precision training.") + + # postprocess data_args + if data_args.max_samples is not None and data_args.streaming: + logger.warning( + "`max_samples` is incompatible with `streaming`. Disabling max_samples." + ) + data_args.max_samples = None + + # postprocess training_args + if ( + training_args.local_rank != -1 + and training_args.ddp_find_unused_parameters is None + and finetuning_args.finetuning_type == "lora" + ): + logger.warning( + "`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training." + ) + training_args_dict = training_args.to_dict() + training_args_dict.update(dict(ddp_find_unused_parameters=False)) + training_args = Seq2SeqTrainingArguments(**training_args_dict) + + if ( + training_args.resume_from_checkpoint is None + and training_args.do_train + and os.path.isdir(training_args.output_dir) + and not training_args.overwrite_output_dir + ): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + "Output directory already exists and is not empty. Use `overwrite_output_dir`." + ) + + if last_checkpoint is not None: + training_args_dict = training_args.to_dict() + training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint)) + training_args = Seq2SeqTrainingArguments(**training_args_dict) + logger.info( + "Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid." 
+ ) + + # postprocess model_args + if training_args.bf16: + if not torch.cuda.is_bf16_supported(): + raise ValueError("Current device does not support bf16 training.") + model_args.compute_dtype = torch.bfloat16 + else: + model_args.compute_dtype = torch.float16 + + model_args.model_max_length = ( + data_args.max_source_length + data_args.max_target_length + ) + + # Log on each process the small summary: + logger.info( + "Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format( + training_args.local_rank, + training_args.device, + training_args.n_gpu, + bool(training_args.local_rank != -1), + str(model_args.compute_dtype), + ) + ) + logger.info(f"Training/evaluation parameters {training_args}") + + # Set seed before initializing model. + transformers.set_seed(training_args.seed) + + return model_args, data_args, training_args, finetuning_args, generating_args + + +def get_infer_args( + args: Optional[Dict[str, Any]] = None +) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]: + model_args, data_args, finetuning_args, generating_args = parse_infer_args(args) + + if ( + model_args.quantization_bit is not None + and finetuning_args.finetuning_type != "lora" + ): + raise ValueError("Quantization is only compatible with the LoRA method.") + + if model_args.checkpoint_dir is not None: + if finetuning_args.finetuning_type != "lora": + if len(model_args.checkpoint_dir) != 1: + raise ValueError("Only LoRA tuning accepts multiple checkpoints.") + elif ( + model_args.quantization_bit is not None + and len(model_args.checkpoint_dir) != 1 + ): + raise ValueError("Quantized model only accepts a single checkpoint.") + + return model_args, data_args, finetuning_args, generating_args diff --git a/src/data_args.py b/src/data_args.py new file mode 100644 index 0000000000000000000000000000000000000000..55fc306d25cf9c29e12b73bd52bce3045b7819ba --- /dev/null +++ b/src/data_args.py @@ -0,0 +1,417 @@ + +import os +import json +import tiktoken +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + +DEFAULT_PROMPT_DICT = { + "prompt_input": ("{instruction}\n\n{input}\n\n"), + "prompt_no_input": ("{instruction}\n\n"), +} + + +CR_PROMPT_DICT = { + "prompt_input": ( + "/* Given the following database schema : */\n " + "{instruction}\n\n/* Answer the following: {input}\n*/\nSELECT" + ), + "prompt_no_input": ( + "/* Given the following database schema : */\n " + "{instruction}\n\nSELECT" + ), +} + +ALPACA_PROMPT_DICT = { + "prompt_input": ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response: " + ), + "prompt_no_input": ( + "Below is an instruction that describes a task. 
" + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Response: " + ), +} + +SQL_PROMPT_DICT = { + "prompt_input": ( + "I want you to act as a SQL terminal in front of an example database, \ + you need only to return the sql command to me.Below is an instruction that describes a task, \ + Write a response that appropriately completes the request.\n" + "##Instruction:\n{instruction}\n###Input:\n{input}\n\n###Response:" + ), + "prompt_no_input": ( + "I want you to act as a SQL terminal in front of an example database, \ + you need only to return the sql command to me.Below is an instruction that describes a task, \ + Write a response that appropriately completes the request.\n" + "####Instruction:\n{instruction}\n\###Response: " + ), +} + + +@dataclass +class DatasetAttr: + load_from: str + dataset_name: Optional[str] = None + dataset_sha1: Optional[str] = None + system_prompt: Optional[str] = None + stage: Optional[str] = None + + def __repr__(self) -> str: + return self.dataset_name + + def __post_init__(self): + self.prompt = "instruction" + self.query = "input" + self.response = "output" + self.history = None + + +@dataclass +class DataArguments: + r""" + Arguments pertaining to what data we are going to input our model for training and evaluation. + """ + template: str = field( + metadata={ + "help": "Which template to use for constructing prompts in training and inference." + } + ) + dataset: Optional[str] = field( + default="example_text2sql", + metadata={ + "help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets." + }, + ) + dataset_dir: Optional[str] = field( + default="data/", + metadata={"help": "The name of the folder containing datasets."}, + ) + cutoff_len: Optional[int] = field( + default=1024, + metadata={"help": "The maximum length of the model inputs after tokenization."}, + ) + reserved_label_len: Optional[int] = field( + default=1, + metadata={"help": "The maximum length reserved for label after tokenization."}, + ) + split: Optional[str] = field( + default="train", + metadata={"help": "Which dataset split to use for training and evaluation."}, + ) + streaming: Optional[bool] = field( + default=False, metadata={"help": "Enable streaming mode."} + ) + buffer_size: Optional[int] = field( + default=16384, + metadata={ + "help": "Size of the buffer to randomly sample examples from in streaming mode." + }, + ) + mix_strategy: Optional[ + Literal["concat", "interleave_under", "interleave_over"] + ] = field(default="concat", metadata={"help": "Strategy to use in dataset mixing."}) + interleave_probs: Optional[str] = field( + default=None, + metadata={ + "help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets." + }, + ) + overwrite_cache: Optional[bool] = field( + default=False, + metadata={"help": "Overwrite the cached training and evaluation sets."}, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_source_length: Optional[int] = field( + default=512, + metadata={ + "help": "The maximum total input sequence length after tokenization." + }, + ) + max_target_length: Optional[int] = field( + default=512, + metadata={ + "help": "The maximum total output sequence length after tokenization." 
+ }, + ) + max_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes, truncate the number of examples for each dataset." + }, + ) + eval_num_beams: Optional[int] = field( + default=None, + metadata={ + "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`" + }, + ) + ignore_pad_token_for_loss: Optional[bool] = field( + default=True, + metadata={ + "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." + }, + ) + system_prompt: Optional[str] = field( + default=None, + metadata={ + "help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training." + }, + ) + val_size: Optional[float] = field( + default=0, + metadata={ + "help": "Size of the development set, should be an integer or a float in range `[0,1)`." + }, + ) + predicted_input_filename: Optional[str] = field( + default="dbgpt_hub/data/example_text2sql_dev.json", + metadata={"help": "Predict input filename to do pred "}, + ) + predicted_out_filename: Optional[str] = field( + default="pred_sql.sql", + metadata={"help": "Filename to save predicted outcomes"}, + ) + + def init_for_training(self): # support mixing multiple datasets + dataset_names = [ds.strip() for ds in self.dataset.split(",")] + with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f: + dataset_info = json.load(f) + + prompt_list = self.system_prompt.split("|") if self.system_prompt else [None] + prompt_list = prompt_list * (len(dataset_names) // len(prompt_list)) + assert len(prompt_list) == len( + dataset_names + ), "Number of system prompts should be equal to datasets or 1." + + if self.interleave_probs is not None: + self.interleave_probs = [ + float(prob.strip()) for prob in self.interleave_probs.split(",") + ] + + self.dataset_list: List[DatasetAttr] = [] + for i, name in enumerate(dataset_names): + if name not in dataset_info: + raise ValueError( + "Undefined dataset {} in dataset_info.json.".format(name) + ) + + if "hf_hub_url" in dataset_info[name]: + dataset_attr = DatasetAttr( + "hf_hub", + dataset_name=dataset_info[name]["hf_hub_url"], + stage=dataset_info[name].get("stage", None), + ) + elif "script_url" in dataset_info[name]: + dataset_attr = DatasetAttr( + "script", + dataset_name=dataset_info[name]["script_url"], + stage=dataset_info[name].get("stage", None), + ) + else: + dataset_attr = DatasetAttr( + "file", + dataset_name=dataset_info[name]["file_name"], + dataset_sha1=dataset_info[name].get("file_sha1", None), + stage=dataset_info[name].get("stage", None), + ) + + if "columns" in dataset_info[name]: + dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None) + dataset_attr.query = dataset_info[name]["columns"].get("query", None) + dataset_attr.response = dataset_info[name]["columns"].get( + "response", None + ) + dataset_attr.history = dataset_info[name]["columns"].get( + "history", None + ) + + dataset_attr.system_prompt = prompt_list[i] + self.dataset_list.append(dataset_attr) + + +@dataclass +class Template: + prefix: List[Union[str, Dict[str, str]]] + prompt: List[Union[str, Dict[str, str]]] + system: str + sep: List[Union[str, Dict[str, str]]] + stop_words: List[str] + use_history: bool + + def encode_oneturn( + self, + tokenizer: "PreTrainedTokenizer", + query: str, + resp: str, + history: Optional[List[Tuple[str, str]]] = None, + system: Optional[str] = None, + ) -> Tuple[List[int], List[int]]: + r""" + Returns a single pair of token ids 
representing prompt and response respectively. + """ + system, history = self._format(query, resp, history, system) + encoded_pairs = self._encode(tokenizer, system, history) + prompt_ids = [] + for query_ids, resp_ids in encoded_pairs[:-1]: + prompt_ids = prompt_ids + query_ids + resp_ids + prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1] + return prompt_ids, answer_ids + + def encode_multiturn( + self, + tokenizer: "PreTrainedTokenizer", + query: str, + resp: str, + history: Optional[List[Tuple[str, str]]] = None, + system: Optional[str] = None, + ) -> List[Tuple[List[int], List[int]]]: + r""" + Returns multiple pairs of token ids representing prompts and responses respectively. + """ + system, history = self._format(query, resp, history, system) + encoded_pairs = self._encode(tokenizer, system, history) + return encoded_pairs + + def _format( + self, + query: str, + resp: str, + history: Optional[List[Tuple[str, str]]] = None, + system: Optional[str] = None, + ) -> Tuple[str, List[Tuple[str, str]]]: + r""" + Aligns inputs to the standard format. + """ + system = system or self.system # use system if provided + history = history if (history and self.use_history) else [] + history = history + [(query, resp)] + return system, history + + def _get_special_ids( + self, tokenizer: "PreTrainedTokenizer" + ) -> Tuple[List[int], List[int]]: + if tokenizer.bos_token_id is not None and getattr( + tokenizer, "add_bos_token", True + ): # baichuan-13b has no bos token + bos_ids = [tokenizer.bos_token_id] + else: + bos_ids = [] # bos token is optional + + if tokenizer.eos_token_id is not None: + eos_ids = [tokenizer.eos_token_id] + else: + raise ValueError("EOS token is required.") + + return bos_ids, eos_ids + + def _encode( + self, + tokenizer: "PreTrainedTokenizer", + system: str, + history: List[Tuple[str, str]], + ) -> List[Tuple[List[int], List[int]]]: + r""" + Encodes formatted inputs to pairs of token ids. + Turn 0: bos + prefix + sep + query resp + eos + Turn t: sep + bos + query resp + eos + """ + bos_ids, eos_ids = self._get_special_ids(tokenizer) + sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep) + encoded_pairs = [] + for turn_idx, (query, resp) in enumerate(history): + if turn_idx == 0: + prefix_ids = self._convert_inputs_to_ids( + tokenizer, context=self.prefix, system=system + ) + if len(prefix_ids) != 0: # has prefix + prefix_ids = bos_ids + prefix_ids + sep_ids + else: + prefix_ids = bos_ids + else: + prefix_ids = sep_ids + bos_ids + + query_ids = self._convert_inputs_to_ids( + tokenizer, context=self.prompt, query=query, idx=str(turn_idx) + ) + resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) + encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids)) + return encoded_pairs + + def _convert_inputs_to_ids( + self, + tokenizer: "PreTrainedTokenizer", + context: List[Union[str, Dict[str, str]]], + system: Optional[str] = None, + query: Optional[str] = None, + idx: Optional[str] = None, + ) -> List[int]: + r""" + Converts context to token ids. 
+ """ + if isinstance( + getattr(tokenizer, "tokenizer", None), tiktoken.Encoding + ): # for tiktoken tokenizer (Qwen) + kwargs = dict(allowed_special="all") + else: + kwargs = dict(add_special_tokens=False) + + token_ids = [] + for elem in context: + if isinstance(elem, str): + elem = ( + elem.replace("{{system}}", system, 1) + if system is not None + else elem + ) + elem = ( + elem.replace("{{query}}", query, 1) if query is not None else elem + ) + elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem + token_ids = token_ids + tokenizer.encode(elem, **kwargs) + elif isinstance(elem, dict): + token_ids = token_ids + [ + tokenizer.convert_tokens_to_ids(elem.get("token")) + ] + else: + raise NotImplementedError + + return token_ids + + +@dataclass +class Llama2Template(Template): + def _encode( + self, + tokenizer: "PreTrainedTokenizer", + system: str, + history: List[Tuple[str, str]], + ) -> List[Tuple[List[int], List[int]]]: + r""" + Encodes formatted inputs to pairs of token ids. + Turn 0: bos + prefix + query resp + eos + Turn t: bos + query resp + eos + """ + bos_ids, eos_ids = self._get_special_ids(tokenizer) + encoded_pairs = [] + for turn_idx, (query, resp) in enumerate(history): + if turn_idx == 0: # llama2 template has no sep_ids + query = self.prefix[0].replace("{{system}}", system) + query + query_ids = self._convert_inputs_to_ids( + tokenizer, context=self.prompt, query=query + ) + resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) + encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids)) + return encoded_pairs + + +templates: Dict[str, Template] = {} diff --git a/src/data_utils.py b/src/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0071434db72b04166daf2d747b810ce9b2223a17 --- /dev/null +++ b/src/data_utils.py @@ -0,0 +1,1030 @@ +import hashlib +import os +import numpy as np +import pandas as pd +import tiktoken +from itertools import chain +from typing import ( + Any, + Dict, + List, + Optional, + Tuple, + Union, + TYPE_CHECKING, + Generator, + Literal, +) +from datasets import ( + Dataset, + DatasetDict, + concatenate_datasets, + load_dataset, + interleave_datasets, +) +from transformers.tokenization_utils import PreTrainedTokenizer + +from .config import EXT2TYPE, IGNORE_INDEX +from .data_args import ( + DEFAULT_PROMPT_DICT, + ALPACA_PROMPT_DICT, + SQL_PROMPT_DICT, + Template, + Llama2Template, +) + +if TYPE_CHECKING: + from .model_args import ModelArguments + from .data_args import DataArguments + from datasets import IterableDataset + from transformers import TrainingArguments, Seq2SeqTrainingArguments + +from .loggings import get_logger + + +logger = get_logger(__name__) + + +def extract_default_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]: + # Not random, use pre-defined templates + if example.get("input", "") != "": + prompt_template = DEFAULT_PROMPT_DICT["prompt_input"] + else: + prompt_template = DEFAULT_PROMPT_DICT["prompt_no_input"] + + # Format prompt with example + formated_prompt = prompt_template.format(**example) + + return {"input": formated_prompt} + + +def extract_alpaca_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]: + if example.get("input", "") != "": + prompt_format = ALPACA_PROMPT_DICT["prompt_input"] + else: + prompt_format = ALPACA_PROMPT_DICT["prompt_no_input"] + return {"input": prompt_format.format(**example)} + + +def extract_sql_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]: + if example.get("input", "") != "": + prompt_format = 
SQL_PROMPT_DICT["prompt_input"] + else: + prompt_format = SQL_PROMPT_DICT["prompt_no_input"] + return {"input": prompt_format.format(**example)} + + +def infer_max_len( + source_len: int, target_len: int, data_args: "DataArguments" +) -> Tuple[int, int]: + max_target_len = int( + data_args.cutoff_len * (target_len / (source_len + target_len)) + ) + max_target_len = max(max_target_len, data_args.reserved_label_len) + max_source_len = data_args.cutoff_len - max_target_len + return max_source_len, max_target_len + + +def local_dataset( + dataset_path: str, eval_dataset_size: float = 0.1 +) -> Tuple[Dataset, Dataset]: + """ + Reads in a dataset from a file and returns it as a split train-test dataset. + + Args: + dataset_path (str): The name of the dataset file to read in. \ + The format is inferred based on the file extension. + + Returns: + A tuple containing two datasets - the training subset and the testing subset. + Raises: + ValueError: If the specified file format is unsupported. + + """ + + # Read in the full dataset from file based on the file format + if dataset_path.endswith(".json"): + full_dataset = load_dataset("json", data_files=dataset_path) + elif dataset_path.endswith(".jsonl"): + full_dataset = load_dataset("json", data_files=dataset_path) + elif dataset_path.endswith(".csv"): + full_dataset = Dataset.from_pandas(pd.read_csv(dataset_path)) + elif dataset_path.endswith(".tsv"): + full_dataset = Dataset.from_pandas(pd.read_csv(dataset_path, delimiter="\t")) + else: + raise ValueError(f"Unsupported dataset format: {dataset_path}") + if "train" not in full_dataset: + split_dataset = full_dataset.train_test_split(test_size=eval_dataset_size) + return split_dataset + else: + return full_dataset + + +def load_data( + dataset_path: str, eval_dataset_size: float = 0.1 +) -> Union[Dict[str, Dataset], None]: + """ + Load a dataset based on its name. + + Args: + dataset_path: A string representing the path to the dataset to be loaded. + + Returns: + A dictionary containing the loaded dataset if the dataset exists. + None if the dataset does not exist. + + Raises: + NotImplementedError: If the dataset name provided is not implemented yet or if + the dataset is not released. 
+ + Examples: + >>> load_data('alpaca') + {'train': Dataset(...), 'validation': Dataset(...), 'test': Dataset(...)} + + """ + if not os.path.exists(dataset_path): + # Download dataset from HuggingFace Datasets + print( + f"Lodding dataset from huggingface, please ref to https://huggingface.co/datasets/{dataset_path}" + ) + dataset = load_dataset(dataset_path, cache_dir="~/.cache/huggingface/datasets") + return dataset + else: + # Load dataset from local file + try: + print(f"Lodding dataset from local path: {dataset_path}") + dataset = local_dataset(dataset_path, eval_dataset_size) + return dataset + except: + raise ValueError(f"Error loading dataset from {dataset_path}") + + +templates: Dict[str, Template] = {} + + +def get_template_and_fix_tokenizer( + name: str, tokenizer: "PreTrainedTokenizer" +) -> Template: + template = templates.get(name, None) + assert template is not None, "Template {} does not exist.".format(name) + + additional_special_tokens = template.stop_words + + if tokenizer.eos_token_id is None: + tokenizer.eos_token = "<|endoftext|>" + logger.info("Add eos token: {}".format(tokenizer.eos_token)) + + if tokenizer.pad_token_id is None: + if tokenizer.unk_token_id is not None: + tokenizer.pad_token = tokenizer.unk_token + else: + tokenizer.pad_token = tokenizer.eos_token + logger.info("Add pad token: {}".format(tokenizer.pad_token)) + + if name is None: + return None + + tokenizer.add_special_tokens( + dict(additional_special_tokens=additional_special_tokens), + replace_additional_special_tokens=False, + ) + return template + + +def register_template( + name: str, + prefix: List[Union[str, Dict[str, str]]], + prompt: List[Union[str, Dict[str, str]]], + system: str, + sep: List[Union[str, Dict[str, str]]], + stop_words: Optional[List[str]] = [], + use_history: Optional[bool] = True, +) -> None: + template_class = Llama2Template if "llama2" in name else Template + templates[name] = template_class( + prefix=prefix, + prompt=prompt, + system=system, + sep=sep, + stop_words=stop_words, + use_history=use_history, + ) + + +r""" +Supports language model inference without histories. +""" +register_template( + name="vanilla", + prefix=[], + prompt=["{{query}}"], + system="", + sep=[], + use_history=False, +) + +r""" +Supports language model for mistral sqlcoder-7b +""" +register_template( + name="mistral", + prefix=["{{system}}"], + prompt=["[INST] {{query}} [/INST]"], + system="", + sep=[], +) + + +r""" +Default template. +""" +register_template( + name="default", + prefix=["{{system}}"], + prompt=["Human: {{query}}\nAssistant: "], + system=( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ), + sep=["\n"], +) + + +r""" +Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf + https://huggingface.co/meta-llama/Llama-2-13b-chat-hf + https://huggingface.co/meta-llama/Llama-2-70b-chat-hf +""" +register_template( + name="llama2", + prefix=["<<SYS>>\n{{system}}\n<</SYS>>\n\n"], + prompt=["[INST] {{query}} [/INST] "], + system=( + "You are a helpful, respectful and honest assistant. " + "Always answer as helpfully as possible, while being safe. " + "Your answers should not include any harmful, unethical, " + "racist, sexist, toxic, dangerous, or illegal content. 
" + "Please ensure that your responses are socially unbiased and positive in nature.\n" + "If a question does not make any sense, or is not factually coherent, " + "explain why instead of answering something not correct. " + "If you don't know the answer to a question, please don't share false information." + ), + sep=[], +) + +register_template( + name="llama3", + prefix=["<|start_header_id|>system<|end_header_id|>\n\n{{system}}<|eot_id|>\n"], + prompt=["<|start_header_id|>user<|end_header_id|>\n\n{{query}}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"], + system=( + "You are a helpful, respectful and honest assistant. " + "Always answer as helpfully as possible, while being safe. " + "Your answers should not include any harmful, unethical, " + "racist, sexist, toxic, dangerous, or illegal content. " + "Please ensure that your responses are socially unbiased and positive in nature.\n" + "If a question does not make any sense, or is not factually coherent, " + "explain why instead of answering something not correct. " + "If you don't know the answer to a question, please don't share false information." + ), + sep=[], +) + +r""" +Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2 + https://huggingface.co/ziqingyang/chinese-alpaca-2-7b +""" +register_template( + name="llama2_zh", + prefix=["<<SYS>>\n{{system}}\n<</SYS>>\n\n"], + prompt=["[INST] {{query}} [/INST] "], + system="You are a helpful assistant. 你是一个乐于助人的助手。", + sep=[], +) + + +r""" +Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff + https://github.com/ymcui/Chinese-LLaMA-Alpaca +""" +register_template( + name="alpaca", + prefix=["{{system}}"], + prompt=["### Instruction:\n{{query}}\n\n### Response:\n"], + system=( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request." + ), + sep=["\n\n"], +) + + +r""" +Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1 + https://huggingface.co/lmsys/vicuna-13b-delta-v1.1 +""" +register_template( + name="vicuna", + prefix=["{{system}}"], + prompt=["USER: {{query}} ASSISTANT: "], + system=( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ), + sep=[], +) + + +r""" +Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B +""" +register_template( + name="belle", + prefix=["{{system}}"], + prompt=["Human: {{query}}\n\nBelle: "], + system="", + sep=["\n\n"], +) + + +r""" +Supports: https://github.com/CVI-SZU/Linly +""" +register_template( + name="linly", + prefix=["{{system}}"], + prompt=["User: {{query}}\nBot: "], + system="", + sep=["\n"], +) + + +r""" +Supports: https://github.com/Neutralzz/BiLLa +""" +register_template( + name="billa", + prefix=["{{system}}"], + prompt=["Human: {{query}}\nAssistant: "], + system="", + sep=["\n"], +) + + +r""" +Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1 +""" +register_template( + name="ziya", + prefix=["{{system}}"], + prompt=[{"token": "<human>"}, ":{{query}}\n", {"token": "<bot>"}, ":"], + system="", + sep=["\n"], +) + + +r""" +Supports: https://huggingface.co/qhduan/aquilachat-7b +""" +register_template( + name="aquila", + prefix=["{{system}}"], + prompt=["Human: {{query}}###Assistant: "], + system=( + "A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions." 
+ ), + sep=["###"], +) + + +r""" +Supports: https://huggingface.co/internlm/internlm-chat-7b +""" +register_template( + name="intern", + prefix=["{{system}}"], + prompt=["<|User|>:{{query}}", {"token": "<eoh>"}, "\n<|Bot|>:"], + system="", + sep=["\n"], + stop_words=["</s>", "<eoa>"], # internlm cannot replace eos token +) + + +r""" +Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat +Used for training and inference of the fine-tuned models. +""" +register_template( + name="baichuan", + prefix=["{{system}}"], + prompt=[ + {"token": "<reserved_102>"}, # user token + "{{query}}", + {"token": "<reserved_103>"}, # assistant token + ], + system="", + sep=[], + stop_words=[], +) + + +r""" +Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat +Used for inference of the original model. +""" +register_template( + name="baichuan_eval", + prefix=["{{system}}", {"token": "<reserved_102>"}], # user token + prompt=["{{query}}", {"token": "<reserved_103>"}], # assistant token + system="", + sep=[], + stop_words=["<reserved_102>"], # user token +) + +r""" +Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat + https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat +Used for training and inference of the fine-tuned models. +""" +register_template( + name="baichuan2", + prefix=["{{system}}"], + prompt=[ + {"token": "<reserved_106>"}, # user token + "{{query}}", + {"token": "<reserved_107>"}, # assistant token + ], + system="", + sep=[], +) + + +r""" +Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat + https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat +Used for inference of the original model. +""" +register_template( + name="baichuan2_eval", + prefix=["{{system}}", {"token": "<reserved_106>"}], # user token + prompt=["{{query}}", {"token": "<reserved_107>"}], # assistant token + system="", + sep=[], + stop_words=["<reserved_106>"], # user token +) + + +r""" +Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha + https://huggingface.co/HuggingFaceH4/starchat-beta + +""" +register_template( + name="starchat", + prefix=[{"token": "<|system|>"}, "\n{{system}}", {"token": "<|end|>"}], + prompt=[ + {"token": "<|user|>"}, + "\n{{query}}", + {"token": "<|end|>"}, + "\n", + {"token": "<|assistant|>"}, + ], + system="", + sep=["\n"], + stop_words=["<|end|>"], +) + + +r""" +Supports: https://huggingface.co/Qwen/Qwen-7B-Chat +""" +register_template( + name="chatml", + prefix=[{"token": "<|im_start|>"}, "system\n{{system}}", {"token": "<|im_end|>"}], + prompt=[ + {"token": "<|im_start|>"}, + "user\n{{query}}", + {"token": "<|im_end|>"}, + "\n", + {"token": "<|im_start|>"}, + "assistant\n", + ], + system="You are a helpful assistant.", + sep=["\n"], + stop_words=["<|im_end|>"], +) + + +r""" +Supports: https://huggingface.co/THUDM/chatglm2-6b +""" +register_template( + name="chatglm2", + prefix=[{"token": "[gMASK]"}, {"token": "sop"}, "{{system}}"], + prompt=["[Round {{idx}}]\n\n问:{{query}}\n\n答:"], + system="", + sep=["\n\n"], +) + + +r""" +Supports: https://huggingface.co/THUDM/chatglm3-6b +""" +register_template( + name="chatglm3", + prefix=[ + {"token": "[gMASK]"}, + {"token": "sop"}, + {"token": "<|system|>"}, + "\n", + "{{system}}", + ], + prompt=[ + {"token": "<|user|>"}, + "\n", + "{{query}}", + {"token": "<|assistant|>"}, + "\n", # add an extra newline to avoid error in ChatGLM's process_response method + ], + system=( + "You are ChatGLM3, a large language model trained by Zhipu.AI. " + "Follow the user's instructions carefully. 
Respond using markdown." + ), + sep=[], + stop_words=["<|user|>", "<|observation|>"], +) + +register_template( + name="chatglm3_raw", # the raw template for tool tuning + prefix=[ + {"token": "[gMASK]"}, + {"token": "sop"}, + {"token": "<|system|>"}, + "\n", + "{{system}}", + ], + prompt=[{"token": "<|user|>"}, "\n", "{{query}}", {"token": "<|assistant|>"}], + system=( + "You are ChatGLM3, a large language model trained by Zhipu.AI. " + "Follow the user's instructions carefully. Respond using markdown." + ), + sep=[], + stop_words=["<|user|>", "<|observation|>"], +) + + +r""" +Supports: https://huggingface.co/xverse/XVERSE-13B-Chat +""" +register_template( + name="xverse", + prefix=["{{system}}"], + prompt=["Human: {{query}}\n\nAssistant: "], + system="", + sep=[], +) + + +def split_dataset( + dataset: Union["Dataset", "IterableDataset"], + data_args: "DataArguments", + training_args: "TrainingArguments", +) -> Dict[str, "Dataset"]: + if training_args.do_train: + if data_args.val_size > 1e-6: # Split the dataset + if data_args.streaming: + val_set = dataset.take(int(data_args.val_size)) + train_set = dataset.skip(int(data_args.val_size)) + dataset = dataset.shuffle( + buffer_size=data_args.buffer_size, seed=training_args.seed + ) + return {"train_dataset": train_set, "eval_dataset": val_set} + else: + val_size = ( + int(data_args.val_size) + if data_args.val_size > 1 + else data_args.val_size + ) + dataset = dataset.train_test_split( + test_size=val_size, seed=training_args.seed + ) + return { + "train_dataset": dataset["train"], + "eval_dataset": dataset["test"], + } + else: + if data_args.streaming: + dataset = dataset.shuffle( + buffer_size=data_args.buffer_size, seed=training_args.seed + ) + return {"train_dataset": dataset} + else: # do_eval or do_predict + return {"eval_dataset": dataset} + + +def preprocess_dataset( + dataset: Union["Dataset", "IterableDataset"], + tokenizer: "PreTrainedTokenizer", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + stage: Literal["pt", "sft", "rm", "ppo"], +) -> Union["Dataset", "IterableDataset"]: + column_names = list(next(iter(dataset)).keys()) + template = get_template_and_fix_tokenizer(data_args.template, tokenizer) + + def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]: + for i in range(len(examples["prompt"])): + query, response = examples["prompt"][i], examples["response"][i] + query = ( + query + "\n" + examples["query"][i] + if "query" in examples and examples["query"][i] + else query + ) + history = examples["history"][i] if "history" in examples else None + system = examples["system"][i] if "system" in examples else None + yield query, response, history, system + + def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: + # build grouped texts with format `X1 X2 X3 ...` (without <eos>) + if isinstance( + getattr(tokenizer, "tokenizer", None), tiktoken.Encoding + ): # for tiktoken tokenizer (Qwen) + kwargs = dict(allowed_special="all") + else: + kwargs = dict(add_special_tokens=False) + + tokenized_examples = tokenizer(examples["prompt"], **kwargs) + concatenated_examples = { + k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys() + } + total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) + block_size = data_args.max_source_length + # we drop the small remainder, and if the total_length < block_size, we exclude this batch + total_length = (total_length // block_size) * block_size + # split by chunks of 
max_source_length + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + return result + + def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: + # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>` + # for multiturn examples, we only mask the prompt part in each prompt-response pair. + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + max_length = data_args.max_source_length + data_args.max_target_length + + for query, response, history, system in construct_example(examples): + input_ids, labels = [], [] + + for source_ids, target_ids in template.encode_multiturn( + tokenizer, query, response, history, system + ): + if len(source_ids) > data_args.max_source_length: + source_ids = source_ids[: data_args.max_source_length] + if len(target_ids) > data_args.max_target_length: + target_ids = target_ids[: data_args.max_target_length] + + if len(input_ids) + len(source_ids) + len(target_ids) > max_length: + break + + input_ids += source_ids + target_ids + labels += [IGNORE_INDEX] * len(source_ids) + target_ids + + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + + return model_inputs + + def preprocess_unsupervised_dataset( + examples: Dict[str, List[Any]] + ) -> Dict[str, Any]: + # build inputs with format `<bos> X` and labels with format `Y <eos>` + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + + for query, response, history, system in construct_example(examples): + source_ids, target_ids = template.encode_oneturn( + tokenizer, query, response, history, system + ) + + if len(source_ids) > data_args.max_source_length: + source_ids = source_ids[: data_args.max_source_length] + if len(target_ids) > data_args.max_target_length: + target_ids = target_ids[: data_args.max_target_length] + + model_inputs["input_ids"].append(source_ids) + model_inputs["attention_mask"].append([1] * len(source_ids)) + model_inputs["labels"].append(target_ids) + + return model_inputs + + def preprocess_pairwise_dataset( + examples: Dict[str, List[Any]] + ) -> Dict[str, List[List[int]]]: + # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>` for rm stage + model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []} + for query, response, history, system in construct_example(examples): + if not ( + isinstance(query, str) + and isinstance(response, list) + and query != "" + and len(response) > 1 + ): + continue + + prompt_ids, chosen_ids = template.encode_oneturn( + tokenizer, query, response[0], history, system + ) + _, rejected_ids = template.encode_oneturn( + tokenizer, query, response[1], history, system + ) + + # if template.efficient_eos: + chosen_ids += [tokenizer.eos_token_id] + rejected_ids += [tokenizer.eos_token_id] + + source_len, target_len = len(prompt_ids), max( + len(chosen_ids), len(rejected_ids) + ) + max_source_len, max_target_len = infer_max_len( + source_len, target_len, data_args + ) + if source_len > max_source_len: + prompt_ids = prompt_ids[:max_source_len] + if target_len > max_target_len: + chosen_ids = chosen_ids[:max_target_len] + rejected_ids = rejected_ids[:max_target_len] + + model_inputs["prompt_ids"].append(prompt_ids) + model_inputs["chosen_ids"].append(chosen_ids) + model_inputs["rejected_ids"].append(rejected_ids) + + return model_inputs + + def 
print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None: + print("prompt_ids:\n{}".format(example["prompt_ids"])) + print( + "prompt:\n{}".format( + tokenizer.decode(example["prompt_ids"], skip_special_tokens=False) + ) + ) + print("chosen_ids:\n{}".format(example["chosen_ids"])) + print( + "chosen:\n{}".format( + tokenizer.decode(example["chosen_ids"], skip_special_tokens=False) + ) + ) + print("rejected_ids:\n{}".format(example["rejected_ids"])) + print( + "rejected:\n{}".format( + tokenizer.decode(example["rejected_ids"], skip_special_tokens=False) + ) + ) + + def print_supervised_dataset_example(example): + print("input_ids:\n{}".format(example["input_ids"])) + print( + "inputs:\n{}".format( + tokenizer.decode(example["input_ids"], skip_special_tokens=False) + ) + ) + print("label_ids:\n{}".format(example["labels"])) + print( + "labels:\n{}".format( + tokenizer.decode( + [ + token_id if token_id != IGNORE_INDEX else tokenizer.pad_token_id + for token_id in example["labels"] + ], + skip_special_tokens=False, + ) + ) + ) + + if stage == "pt": + pass + elif stage == "sft" and not training_args.predict_with_generate: + preprocess_function = preprocess_supervised_dataset + print_function = print_supervised_dataset_example + elif stage == "rm": + print(111111111111111111) + preprocess_function = preprocess_pairwise_dataset + print_function = print_pairwise_dataset_example + else: + pass + + with training_args.main_process_first(desc="dataset map pre-processing"): + kwargs = {} + if not data_args.streaming: + kwargs = dict( + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on dataset", + ) + + dataset = dataset.map( + preprocess_function, batched=True, remove_columns=column_names, **kwargs + ) + + print_function(next(iter(dataset))) + return dataset + + +## used in get_dataset +def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None: + if file_sha1 is None: + logger.warning( + "Checksum failed: missing SHA-1 hash value in dataset_info.json." 
+ ) + return + + if len(data_files) != 1: + logger.warning("Checksum failed: too many files.") + return + + with open(data_files[0], "rb") as f: + sha1 = hashlib.sha1(f.read()).hexdigest() + if sha1 != file_sha1: + logger.warning( + "Checksum failed: mismatched SHA-1 hash value at {}.".format( + data_files[0] + ) + ) + + +def get_dataset( + model_args: "ModelArguments", data_args: "DataArguments" +) -> Union["Dataset", "IterableDataset"]: + max_samples = data_args.max_samples + all_datasets: List[ + Union["Dataset", "IterableDataset"] + ] = [] # support multiple datasets + + for dataset_attr in data_args.dataset_list: + logger.info("Loading dataset {}...".format(dataset_attr)) + + if dataset_attr.load_from == "hf_hub": + data_path = dataset_attr.dataset_name + data_files = None + elif dataset_attr.load_from == "script": + data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) + data_files = None + elif dataset_attr.load_from == "file": + data_path = None + data_files: List[str] = [] + + if os.path.isdir( + os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) + ): # directory + for file_name in os.listdir( + os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) + ): + data_files.append( + os.path.join( + data_args.dataset_dir, dataset_attr.dataset_name, file_name + ) + ) + if data_path is None: + data_path = EXT2TYPE.get(file_name.split(".")[-1], None) + else: + assert data_path == EXT2TYPE.get( + file_name.split(".")[-1], None + ), "file type does not match." + elif os.path.isfile( + os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) + ): # single file + data_files.append( + os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) + ) + data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None) + else: + raise ValueError("File not found.") + + assert data_path, "File extension must be txt, csv, json or jsonl." + checksum(data_files, dataset_attr.dataset_sha1) + else: + raise NotImplementedError + + dataset = load_dataset( + data_path, + data_files=data_files, + split=data_args.split, + cache_dir=model_args.cache_dir, + streaming=data_args.streaming, + use_auth_token=True if model_args.use_auth_token else None, + ) + + if max_samples is not None: + max_samples_temp = min(len(dataset), max_samples) + dataset = dataset.select(range(max_samples_temp)) + + for column_name in ["prompt", "query", "response", "history"]: # align datasets + if ( + getattr(dataset_attr, column_name) + and getattr(dataset_attr, column_name) != column_name + ): + dataset = dataset.rename_column( + getattr(dataset_attr, column_name), column_name + ) + + if dataset_attr.system_prompt: # add system prompt + if data_args.streaming: + dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt}) + else: + dataset = dataset.add_column( + "system", [dataset_attr.system_prompt] * len(dataset) + ) + + all_datasets.append(dataset) + + if len(data_args.dataset_list) == 1: + return all_datasets[0] + elif data_args.mix_strategy == "concat": + if data_args.streaming: + logger.warning( + "The samples between different datasets will not be mixed in streaming mode." + ) + return concatenate_datasets(all_datasets) + elif data_args.mix_strategy.startswith("interleave"): + if not data_args.streaming: + logger.warning( + "We recommend using `mix_strategy=concat` in non-streaming mode." 
+ ) + stopping_strategy = ( + "first_exhausted" + if data_args.mix_strategy.endswith("under") + else "all_exhausted" + ) + return interleave_datasets( + all_datasets, + data_args.interleave_probs, + stopping_strategy=stopping_strategy, + ) + else: + raise ValueError("Unknown mixing strategy.") + + +def split_train_eval( + dataset: Dataset, + do_eval: bool = False, + eval_dataset_size: float = 0.1, + max_eval_samples: int = None, + do_train: bool = True, + max_train_samples: int = None, +) -> Dict[str, Dataset]: + """ + Prepare the training and evaluation datasets for a machine learning model. + + Args: + dataset (DatasetDict): The complete dataset containing train, validation, and test splits. + do_eval (bool, optional): Whether to use an evaluation dataset or not. Defaults to False. + eval_dataset_size (float, optional): The size of the validation set if splitting from the training data. + Ignored if `do_eval` is False. Defaults to 0.2. + max_eval_samples (int, optional): The maximum number of samples to keep in the evaluation dataset. + Ignored if `do_eval` is False or `None`. Defaults to None. + do_train (bool, optional): Whether to use a training dataset or not. Defaults to True. + max_train_samples (int, optional): The maximum number of samples to keep in the training dataset. + Ignored if `do_train` is False or `None`. Defaults to None. + + Returns: + Dict[str, Dataset]: A dictionary containing the prepared training and evaluation datasets + (if used), where the keys are 'train' and 'eval', respectively. + """ + if not isinstance(dataset, DatasetDict): + raise TypeError("The 'dataset' argument must be a DatasetDict object.") + + train_dataset, eval_dataset = None, None + # Prepare evaluation dataset + if do_eval: + if "eval" in dataset: + eval_dataset = dataset["eval"] + else: + # Split train dataset in train and validation according to `eval_dataset_size` + print( + f"Splitting the dataset into train and validation according to `eval_dataset_size`: {eval_dataset_size}" + ) + dataset = dataset["train"].train_test_split( + test_size=eval_dataset_size, shuffle=True, seed=42 + ) + eval_dataset = dataset["test"] + + # Reduce evaluation dataset size (if specified) + print( + f"You have set the max_eval_samples: {max_eval_samples}, will do sampling ..." + ) + if max_eval_samples is not None and len(eval_dataset) > max_eval_samples: + eval_dataset = eval_dataset.select(np.arange(max_eval_samples)) + + # Prepare training dataset + if do_train: + train_dataset = dataset["train"] + + # Reduce training dataset size (if specified) + print( + f"You have set the max_train_samples: {max_train_samples}, will do sampling ..." 
+ ) + if max_train_samples is not None and len(train_dataset) > max_train_samples: + train_dataset = train_dataset.select(np.arange(max_train_samples)) + + return train_dataset, eval_dataset diff --git a/src/ds_config.json b/src/ds_config.json new file mode 100644 index 0000000000000000000000000000000000000000..e96d4d9b2a886a49eeb7077db31eaa9ef13b0ef0 --- /dev/null +++ b/src/ds_config.json @@ -0,0 +1,23 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "zero_allow_untested_optimizer": true, + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "initial_scale_power": 16, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "zero_optimization": { + "stage": 2, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "reduce_scatter": true, + "reduce_bucket_size": 5e8, + "overlap_comm": false, + "contiguous_gradients": true + } + } diff --git a/src/export.py b/src/export.py new file mode 100644 index 0000000000000000000000000000000000000000..0daa4d12915828078153ea444723793311439f0d --- /dev/null +++ b/src/export.py @@ -0,0 +1,14 @@ +import os +import sys + +ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(ROOT_PATH) +from .model_trainer import export_model + + +def main(): + export_model() + + +if __name__ == "__main__": + main() diff --git a/src/load.py b/src/load.py new file mode 100644 index 0000000000000000000000000000000000000000..8567a84254fd9ba892820d7dd736d9ac1402a0be --- /dev/null +++ b/src/load.py @@ -0,0 +1,418 @@ +import os +import torch +import inspect +import math +from typing import TYPE_CHECKING, Optional, Tuple, Dict, Literal, List +from peft import PeftModel, TaskType, LoraConfig, get_peft_model +from peft.utils import CONFIG_NAME, WEIGHTS_NAME +from transformers import PreTrainedModel, PreTrainedTokenizer +from transformers.utils import check_min_version, cached_file +from transformers.utils.versions import require_version +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, + PretrainedConfig, + PreTrainedModel, + PreTrainedTokenizerBase, +) +from transformers.deepspeed import is_deepspeed_zero3_enabled +from types import MethodType +from .config import LAYERNORM_NAMES, VALUE_HEAD_FILE_NAME +from .config_parser import load_trainable_params + +if TYPE_CHECKING: + from transformers.modeling_utils import PreTrainedModel + from .model_args import ModelArguments, FinetuningArguments + + +def prepare_model_for_training( + model: "PreTrainedModel", + finetuning_type: str, + output_layer_name: Optional[str] = "lm_head", + use_gradient_checkpointing: Optional[bool] = True, + layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES, +) -> "PreTrainedModel": + for name, param in model.named_parameters(): + if param.ndim == 1 and any( + layer_norm_name in name for layer_norm_name in layer_norm_names + ): + param.data = param.data.to(torch.float32) + + if use_gradient_checkpointing: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + model.gradient_checkpointing_enable() + model.config.use_cache = ( + False # turn off when gradient checkpointing is enabled + ) + + if finetuning_type != "full" and hasattr(model, output_layer_name): + output_layer: torch.nn.Linear 
= getattr(model, output_layer_name) + input_dtype = output_layer.weight.dtype + + class CastOutputToFloat(torch.nn.Sequential): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x.to(input_dtype)).to(torch.float32) + + setattr(model, output_layer_name, CastOutputToFloat(output_layer)) + + return model + +def init_adapter( + model: "PreTrainedModel", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool, + is_mergeable: bool, +) -> "PreTrainedModel": + r""" + Initializes the adapters. + + Support full-parameter, freeze and LoRA ,QLoRA,training. + + Note that the trainable parameters must be cast to float32. + """ + + if finetuning_args.finetuning_type == "none" and is_trainable: + raise ValueError("You cannot use finetuning_type=none while training.") + + if finetuning_args.finetuning_type == "full" and is_trainable: + print("Fine-tuning method: Full") + model = model.float() + + if finetuning_args.finetuning_type == "freeze": + print("Fine-tuning method: Freeze") + + for name, param in model.named_parameters(): + if not any( + trainable_layer in name + for trainable_layer in finetuning_args.trainable_layers + ): + param.requires_grad_(False) + else: + param.data = param.data.to(torch.float32) + + if model_args.checkpoint_dir is not None: + assert load_trainable_params( + model, model_args.checkpoint_dir[0] + ), "Model checkpoint is not correctly loaded." + + if finetuning_args.finetuning_type == "lora": + print("Fine-tuning method: LoRA") + latest_checkpoint = None + + if model_args.checkpoint_dir is not None: + assert os.path.exists( + os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME) + ), "Provided path ({}) does not contain a LoRA weight.".format( + model_args.checkpoint_dir[0] + ) + assert os.path.exists( + os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME) + ), "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead." 
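+            # How existing LoRA checkpoints are consumed here: when training continues
+            # from previous LoRA weights (is_trainable with resume_lora_training), or the
+            # quantized base model cannot absorb merged weights (is_mergeable is False),
+            # all but the most recent checkpoint are merged into the base weights and the
+            # last one is reloaded below as the active, trainable adapter; otherwise every
+            # checkpoint is merged, and a fresh adapter is created only when training
+            # starts without a checkpoint to resume from.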
+ + if (is_trainable and finetuning_args.resume_lora_training) or ( + not is_mergeable + ): # continually fine-tuning + checkpoints_to_merge, latest_checkpoint = ( + model_args.checkpoint_dir[:-1], + model_args.checkpoint_dir[-1], + ) + else: + checkpoints_to_merge = model_args.checkpoint_dir + + for checkpoint in checkpoints_to_merge: + model = PeftModel.from_pretrained(model, checkpoint) + model = model.merge_and_unload() + + if len(checkpoints_to_merge) > 0: + print( + "Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)) + ) + + if ( + latest_checkpoint is not None + ): # resume lora training or quantized inference + model = PeftModel.from_pretrained( + model, latest_checkpoint, is_trainable=is_trainable + ) + + if ( + is_trainable and latest_checkpoint is None + ): # create new lora weights while training + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + r=finetuning_args.lora_rank, + lora_alpha=finetuning_args.lora_alpha, + lora_dropout=finetuning_args.lora_dropout, + target_modules=finetuning_args.lora_target, + ) + model = get_peft_model(model, lora_config) + + if model_args.checkpoint_dir is not None: + print( + "Loaded fine-tuned model from checkpoint(s): {}".format( + ",".join(model_args.checkpoint_dir) + ) + ) + + return model + +def load_model_and_tokenizer( + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: Optional[bool] = False, + add_valuehead: Optional[bool] = False, +) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]: + r""" + Loads pretrained model and tokenizer. + + Support both training and inference. + """ + if (not is_trainable) and model_args.checkpoint_dir is None: + print( + "Checkpoint is not found at evaluation, load the original model." 
+ ) + finetuning_args = FinetuningArguments(finetuning_type="none") + + config_kwargs = { + "trust_remote_code": True, + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "use_auth_token": True if model_args.use_auth_token else None, + } + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=model_args.use_fast_tokenizer, + split_special_tokens=model_args.split_special_tokens, + padding_side="right", # training with left-padded tensors in fp16 precision may cause overflow + **config_kwargs + ) + + if ( + finetuning_args.finetuning_type == "full" + and model_args.checkpoint_dir is not None + ): + model_to_load = model_args.checkpoint_dir[0] + else: + model_to_load = model_args.model_name_or_path + + config = AutoConfig.from_pretrained(model_to_load, **config_kwargs) + + if hasattr(config, "fp16") and hasattr(config, "bf16"): # fix Qwen config + if model_args.compute_dtype == torch.bfloat16: + setattr(config, "bf16", True) + else: + setattr(config, "fp16", True) + + # Fix config (for Qwen) + #if getattr(config, "model_type", None) == "qwen": + # for dtype_name, dtype in [ + # ("fp16", torch.float16), + # ("bf16", torch.bfloat16), + # ("fp32", torch.float32), + # ]: + # setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype) + + # Set RoPE scaling + if model_args.rope_scaling is not None: + if hasattr(config, "use_dynamic_ntk"): # for Qwen models + if is_trainable: + print("Qwen model does not support RoPE scaling in training.") + else: + setattr(config, "use_dynamic_ntk", True) + setattr(config, "use_logn_attn", True) + print("Using dynamic NTK scaling.") + + elif hasattr(config, "rope_scaling"): # for LLaMA models + require_version( + "transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0" + ) + + if is_trainable: + if model_args.rope_scaling == "dynamic": + print( + "Dynamic NTK may not work well with fine-tuning. " + "See: https://github.com/huggingface/transformers/pull/24653" + ) + + current_max_length = getattr(config, "max_position_embeddings", None) + if ( + current_max_length + and model_args.model_max_length > current_max_length + ): + scaling_factor = float( + math.ceil(model_args.model_max_length / current_max_length) + ) + else: + print( + "Input length is smaller than max length. Consider increase input length." + ) + scaling_factor = 1.0 + else: + scaling_factor = 2.0 + + setattr( + config, + "rope_scaling", + {"type": model_args.rope_scaling, "factor": scaling_factor}, + ) + print( + "Using {} scaling strategy and setting scaling factor to {}".format( + model_args.rope_scaling, scaling_factor + ) + ) + + else: + print("Current model does not support RoPE scaling.") + + # Quantization configurations (using bitsandbytes library). 
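+    # 8-bit loading relies on bitsandbytes LLM.int8(), while 4-bit loading uses the
+    # QLoRA-style setup (NF4/FP4 weights, optional double quantization, and a separate
+    # compute dtype). Quantized weights cannot be merged with LoRA deltas afterwards,
+    # so is_mergeable is set to False, and the model is pinned to the local rank's GPU
+    # during training (device_map="auto" is used for inference instead). DeepSpeed
+    # ZeRO-3 is rejected because it is incompatible with this quantization path.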
+ is_mergeable = True + if model_args.quantization_bit is not None: + if is_deepspeed_zero3_enabled(): + raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") + + if model_args.quantization_bit == 8: + require_version( + "bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0" + ) + config_kwargs["load_in_8bit"] = True + config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) + + elif model_args.quantization_bit == 4: + require_version( + "bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0" + ) + config_kwargs["load_in_4bit"] = True + config_kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=model_args.compute_dtype, + bnb_4bit_use_double_quant=model_args.double_quantization, + bnb_4bit_quant_type=model_args.quantization_type, + ) + + is_mergeable = False + config_kwargs["device_map"] = ( + {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto" + ) + print("Quantizing model to {} bit.".format(model_args.quantization_bit)) + + # Load and prepare pre-trained models (without valuehead). + model = AutoModelForCausalLM.from_pretrained( + model_to_load, + config=config, + torch_dtype=model_args.compute_dtype, + low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()), + **config_kwargs + ) + + # Disable custom generate method (for Qwen) + #if "GenerationMixin" not in str(model.generate.__func__): + # model.generate = MethodType(PreTrainedModel.generate, model) + + # Fix LM head (for ChatGLM2,ChatGLM3) + #if not hasattr(model, "lm_head") and hasattr(model, "transformer"): + # setattr(model, "lm_head", model.transformer.output_layer) + + # Register auto class to save the custom code files. + if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr( + config, "auto_map", {} + ): + config.__class__.register_for_auto_class() + if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr( + config, "auto_map", {} + ): + model.__class__.register_for_auto_class() + if isinstance( + tokenizer, PreTrainedTokenizerBase + ) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}): + tokenizer.__class__.register_for_auto_class() + + # Initialize adapters + model = ( + prepare_model_for_training(model, finetuning_args.finetuning_type) + if is_trainable + else model + ) + model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable) + + # Prepare model with valuehead for RLHF + #if add_valuehead: + # model: "AutoModelForCausalLMWithValueHead" = ( + # AutoModelForCausalLMWithValueHead.from_pretrained(model) + # ) + # ignore_modules = [ + # name for name, _ in model.named_parameters() if "pretrained_model" in name + # ] + # setattr(model, "_keys_to_ignore_on_save", ignore_modules) + # setattr( + # model, "tie_weights", MethodType(lambda _: None, model) + # ) # use empty method + # vhead_path = ( + # model_args.checkpoint_dir[-1] + # if model_args.checkpoint_dir is not None + # else model_args.model_name_or_path + # ) + # vhead_params = load_valuehead_params(vhead_path, model_args) + # if vhead_params is not None: + # model.load_state_dict(vhead_params, strict=False) + # logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path)) + + # Prepare model for inference + if not is_trainable: + model.requires_grad_(False) # fix all model params + infer_dtype = ( + torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + ) # detect cuda capability + model = model.to(infer_dtype) if model_args.quantization_bit is None 
else model + + #trainable_params, all_param = count_parameters(model) + #print( + # "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format( + # trainable_params, all_param, 100 * trainable_params / all_param + # ) + #) + + return model, tokenizer + +def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel": + r""" + Dispatches a pre-trained model to GPUs with balanced memory. + Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803 + """ + if getattr(model, "is_loaded_in_8bit", False) or getattr( + model, "is_loaded_in_4bit", False + ): # do nothing + return model + + if torch.cuda.device_count() > 1: + from accelerate import dispatch_model + from accelerate.utils import infer_auto_device_map, get_balanced_memory + + if model._no_split_modules is None: + raise ValueError( + "The model class needs to implement the `_no_split_modules` attribute." + ) + + kwargs = { + "dtype": model.dtype, + "no_split_module_classes": model._no_split_modules, + } + max_memory = get_balanced_memory(model, **kwargs) + # Make sure tied weights are tied before creating the device map. + model.tie_weights() + device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs) + return dispatch_model(model, device_map) + else: + return model.cuda() diff --git a/src/loggings.py b/src/loggings.py new file mode 100644 index 0000000000000000000000000000000000000000..3a9b4fa70e0018f10cc17c1a1943dc705c9ced90 --- /dev/null +++ b/src/loggings.py @@ -0,0 +1,227 @@ +import sys +import logging +import os +import json +import time +from typing import TYPE_CHECKING +from datetime import timedelta +from transformers import TrainerCallback +from transformers.trainer_utils import has_length +from .config import LOG_FILE_NAME + +if TYPE_CHECKING: + from transformers import TrainingArguments, TrainerState, TrainerControl + + +def reset_logging(): + r""" + Removes basic config of root logger + """ + root = logging.getLogger() + list(map(root.removeHandler, root.handlers)) + list(map(root.removeFilter, root.filters)) + + +def get_logger(name: str) -> logging.Logger: + formatter = logging.Formatter( + fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + ) + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(formatter) + + logger = logging.getLogger(name) + logger.setLevel(logging.INFO) + logger.addHandler(handler) + + return logger + + +logger = get_logger(__name__) + + +class LoggerHandler(logging.Handler): + def __init__(self): + super().__init__() + self.log = "" + + def reset(self): + self.log = "" + + def emit(self, record): + if record.name == "httpx": + return + log_entry = self.format(record) + self.log += log_entry + self.log += "\n\n" + + +class LogCallback(TrainerCallback): + def __init__(self, runner=None): + self.runner = runner + self.in_training = False + self.start_time = time.time() + self.cur_steps = 0 + self.max_steps = 0 + self.elapsed_time = "" + self.remaining_time = "" + + def timing(self): + cur_time = time.time() + elapsed_time = cur_time - self.start_time + avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0 + remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step + self.elapsed_time = str(timedelta(seconds=int(elapsed_time))) + self.remaining_time = str(timedelta(seconds=int(remaining_time))) + + def on_train_begin( + self, + args: "TrainingArguments", + state: "TrainerState", + control: "TrainerControl", + **kwargs + ): 
+ r""" + Event called at the beginning of training. + """ + if state.is_local_process_zero: + self.in_training = True + self.start_time = time.time() + self.max_steps = state.max_steps + if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)): + logger.warning("Previous log file in this folder will be deleted.") + os.remove(os.path.join(args.output_dir, LOG_FILE_NAME)) + + def on_train_end( + self, + args: "TrainingArguments", + state: "TrainerState", + control: "TrainerControl", + **kwargs + ): + r""" + Event called at the end of training. + """ + if state.is_local_process_zero: + self.in_training = False + self.cur_steps = 0 + self.max_steps = 0 + + def on_substep_end( + self, + args: "TrainingArguments", + state: "TrainerState", + control: "TrainerControl", + **kwargs + ): + r""" + Event called at the end of an substep during gradient accumulation. + """ + if ( + state.is_local_process_zero + and self.runner is not None + and self.runner.aborted + ): + control.should_epoch_stop = True + control.should_training_stop = True + + def on_step_end( + self, + args: "TrainingArguments", + state: "TrainerState", + control: "TrainerControl", + **kwargs + ): + r""" + Event called at the end of a training step. + """ + if state.is_local_process_zero: + self.cur_steps = state.global_step + self.timing() + if self.runner is not None and self.runner.aborted: + control.should_epoch_stop = True + control.should_training_stop = True + + def on_evaluate( + self, + args: "TrainingArguments", + state: "TrainerState", + control: "TrainerControl", + **kwargs + ): + r""" + Event called after an evaluation phase. + """ + if state.is_local_process_zero and not self.in_training: + self.cur_steps = 0 + self.max_steps = 0 + + def on_predict( + self, + args: "TrainingArguments", + state: "TrainerState", + control: "TrainerControl", + *other, + **kwargs + ): + r""" + Event called after a successful prediction. + """ + if state.is_local_process_zero and not self.in_training: + self.cur_steps = 0 + self.max_steps = 0 + + def on_log( + self, + args: "TrainingArguments", + state: "TrainerState", + control: "TrainerControl", + **kwargs + ) -> None: + r""" + Event called after logging the last logs. + """ + if not state.is_local_process_zero: + return + + logs = dict( + current_steps=self.cur_steps, + total_steps=self.max_steps, + loss=state.log_history[-1].get("loss", None), + eval_loss=state.log_history[-1].get("eval_loss", None), + predict_loss=state.log_history[-1].get("predict_loss", None), + reward=state.log_history[-1].get("reward", None), + learning_rate=state.log_history[-1].get("learning_rate", None), + epoch=state.log_history[-1].get("epoch", None), + percentage=round(self.cur_steps / self.max_steps * 100, 2) + if self.max_steps != 0 + else 100, + elapsed_time=self.elapsed_time, + remaining_time=self.remaining_time, + ) + os.makedirs(args.output_dir, exist_ok=True) + with open( + os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8" + ) as f: + f.write(json.dumps(logs) + "\n") + + def on_prediction_step( + self, + args: "TrainingArguments", + state: "TrainerState", + control: "TrainerControl", + **kwargs + ): + r""" + Event called after a prediction step. 
+ """ + eval_dataloader = kwargs.pop("eval_dataloader", None) + if ( + state.is_local_process_zero + and has_length(eval_dataloader) + and not self.in_training + ): + if self.max_steps == 0: + self.max_steps = len(eval_dataloader) + self.cur_steps += 1 + self.timing() diff --git a/src/model_args.py b/src/model_args.py new file mode 100644 index 0000000000000000000000000000000000000000..61ad2667a096b185f1c4761725a794409857821b --- /dev/null +++ b/src/model_args.py @@ -0,0 +1,413 @@ +import json +import torch +from dataclasses import dataclass, field, asdict +from typing import Optional, Any, Dict, Literal +from transformers import Seq2SeqTrainingArguments +from .config import ( + MODEL_PATH, + ADAPTER_PATH, +) + + +@dataclass +class ModelArguments: + r""" + Arguments pertaining to which model/config/tokenizer we are going to fine-tune. + """ + model_name_or_path: str = field( + metadata={ + "help": "Path to pretrained model or model identifier from huggingface.co/models." + } + ) + cache_dir: Optional[str] = field( + default=None, + metadata={ + "help": "Where to store the pretrained models downloaded from huggingface.co." + }, + ) + use_fast_tokenizer: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." + }, + ) + use_auth_token: Optional[bool] = field( + default=False, + metadata={ + "help": "Will use the token generated when running `huggingface-cli login`." + }, + ) + model_revision: Optional[str] = field( + default="main", + metadata={ + "help": "The specific model version to use (can be a branch name, tag name or commit id)." + }, + ) + padding_side: Optional[Literal["left", "right"]] = field( + default="left", + metadata={"help": "The side on which the model should have padding applied."}, + ) + quantization_bit: Optional[int] = field( + default=None, metadata={"help": "The number of bits to quantize the model."} + ) + quantization_type: Optional[Literal["fp4", "nf4"]] = field( + default="nf4", + metadata={"help": "Quantization data type to use in int4 training."}, + ) + double_quantization: Optional[bool] = field( + default=True, + metadata={ + "help": "Whether to use double quantization in int4 training or not." + }, + ) + rope_scaling: Optional[Literal["linear", "dynamic"]] = field( + default=None, metadata={"help": "Adopt scaled rotary positional embeddings."} + ) + checkpoint_dir: Optional[str] = field( + default=None, + metadata={ + "help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations." + }, + ) + # reward_model: Optional[str] = field( + # default=None, + # metadata={"help": "Path to the directory containing the checkpoints of the reward model."} + # ) + plot_loss: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to plot the training loss after fine-tuning or not." + }, + ) + hf_auth_token: Optional[str] = field( + default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."} + ) + compute_dtype: Optional[torch.dtype] = field( + default=None, + metadata={ + "help": "Used in quantization configs. Do not specify this argument manually." + }, + ) + model_max_length: Optional[int] = field( + default=None, + metadata={ + "help": "Used in rope scaling. Do not specify this argument manually." 
+ }, + ) + hf_hub_token: Optional[str] = field( + default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."} + ) + split_special_tokens: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether or not the special tokens should be split during the tokenization process." + }, + ) + + def __post_init__(self): + if self.compute_dtype is not None or self.model_max_length is not None: + raise ValueError("These arguments cannot be specified.") + + if self.checkpoint_dir is not None: # support merging multiple lora weights + self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] + + if self.quantization_bit is not None: + assert self.quantization_bit in [ + 4, + 8, + ], "We only accept 4-bit or 8-bit quantization." + + if self.use_auth_token == True and self.hf_auth_token is not None: + from huggingface_hub.hf_api import HfFolder # lazy load + + HfFolder.save_token(self.hf_auth_token) + + +@dataclass +class GeneratingArguments: + r""" + Arguments pertaining to specify the decoding parameters. + """ + do_sample: Optional[bool] = field( + default=True, + metadata={ + "help": "Whether or not to use sampling, use greedy decoding otherwise." + }, + ) + temperature: Optional[float] = field( + default=0.95, + metadata={"help": "The value used to modulate the next token probabilities."}, + ) + top_p: Optional[float] = field( + default=0.7, + metadata={ + "help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept." + }, + ) + top_k: Optional[int] = field( + default=50, + metadata={ + "help": "The number of highest probability vocabulary tokens to keep for top-k filtering." + }, + ) + num_beams: Optional[int] = field( + default=1, + metadata={"help": "Number of beams for beam search. 1 means no beam search."}, + ) + max_length: Optional[int] = field( + default=None, + metadata={ + "help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens." + }, + ) + max_new_tokens: Optional[int] = field( + default=512, + metadata={ + "help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt." + }, + ) + repetition_penalty: Optional[float] = field( + default=1.0, + metadata={ + "help": "The parameter for repetition penalty. 1.0 means no penalty." + }, + ) + length_penalty: Optional[float] = field( + default=1.0, + metadata={ + "help": "Exponential penalty to the length that is used with beam-based generation." + }, + ) + + def to_dict(self) -> Dict[str, Any]: + args = asdict(self) + if args.get("max_new_tokens", None): + args.pop("max_length", None) + return args + + +@dataclass +class FinetuningArguments: + r""" + Arguments pertaining to which techniques we are going to fine-tuning with. + """ + stage: Optional[Literal["sft", "rm"]] = field( + default="sft", metadata={"help": "Which stage will be performed in training."} + ) + finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field( + default="lora", metadata={"help": "Which fine-tuning method to use."} + ) + num_hidden_layers: Optional[int] = field( + default=32, + metadata={ + "help": 'Number of decoder blocks in the model for partial-parameter (freeze) fine-tuning. 
\ + LLaMA choices: ["32", "40", "60", "80"], \ + LLaMA-2 choices: ["32", "40", "80"], \ + BLOOM choices: ["24", "30", "70"], \ + Falcon choices: ["32", "60"], \ + Baichuan choices: ["32", "40"] \ + Qwen choices: ["32"], \ + XVERSE choices: ["40"], \ + ChatGLM2 choices: ["28"],\ + ChatGLM3 choices: ["28"]' + }, + ) + num_layer_trainable: Optional[int] = field( + default=3, + metadata={ + "help": "Number of trainable layers for partial-parameter (freeze) fine-tuning." + }, + ) + name_module_trainable: Optional[ + Literal["mlp", "self_attn", "self_attention"] + ] = field( + default="mlp", + metadata={ + "help": 'Name of trainable modules for partial-parameter (freeze) fine-tuning. \ + LLaMA choices: ["mlp", "self_attn"], \ + BLOOM & Falcon & ChatGLM2 & ChatGLM3choices: ["mlp", "self_attention"], \ + Baichuan choices: ["mlp", "self_attn"], \ + Qwen choices: ["mlp", "attn"], \ + LLaMA-2, InternLM, XVERSE choices: the same as LLaMA.' + }, + ) + lora_rank: Optional[int] = field( + default=8, metadata={"help": "The intrinsic dimension for LoRA fine-tuning."} + ) + lora_alpha: Optional[float] = field( + default=32.0, + metadata={ + "help": "The scale factor for LoRA fine-tuning (similar with the learning rate)." + }, + ) + lora_dropout: Optional[float] = field( + default=0.1, metadata={"help": "Dropout rate for the LoRA fine-tuning."} + ) + lora_target: Optional[str] = field( + default=None, + metadata={ + "help": 'Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \ + LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \ + BLOOM & Falcon & ChatGLM2 & ChatGLM3 choices: ["query_key_value", "self_attention.dense", "mlp.dense"], \ + Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \ + Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \ + LLaMA-2, InternLM, XVERSE choices: the same as LLaMA.' + }, + ) + resume_lora_training: Optional[bool] = field( + default=True, + metadata={ + "help": "Whether to resume training from the last LoRA weights or create new weights after merging them." + }, + ) + ppo_score_norm: Optional[bool] = field( + default=False, metadata={"help": "Use score normalization in PPO Training."} + ) + dpo_beta: Optional[float] = field( + default=0.1, metadata={"help": "The beta parameter for the DPO loss."} + ) + + def __post_init__(self): + if isinstance( + self.lora_target, str + ): # support custom target modules/layers of LoRA + self.lora_target = [ + target.strip() for target in self.lora_target.split(",") + ] + + if ( + self.num_layer_trainable > 0 + ): # fine-tuning the last n layers if num_layer_trainable > 0 + trainable_layer_ids = [ + self.num_hidden_layers - k - 1 for k in range(self.num_layer_trainable) + ] + else: # fine-tuning the first n layers if num_layer_trainable < 0 + trainable_layer_ids = [k for k in range(-self.num_layer_trainable)] + + self.trainable_layers = [ + "{:d}.{}".format(idx, self.name_module_trainable) + for idx in trainable_layer_ids + ] + + assert self.finetuning_type in [ + "lora", + "freeze", + "full", + "none", + ], "Invalid fine-tuning method." 
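+
+    # Example (values are illustrative, not prescriptive): a typical LoRA setup for a
+    # LLaMA-style model could look like
+    #
+    #     args = FinetuningArguments(
+    #         stage="sft",
+    #         finetuning_type="lora",
+    #         lora_rank=8,
+    #         lora_alpha=32.0,
+    #         lora_dropout=0.1,
+    #         lora_target="q_proj,v_proj",
+    #     )
+    #
+    # __post_init__ then splits lora_target into ["q_proj", "v_proj"] and validates
+    # the chosen finetuning_type.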
+ + def save_to_json(self, json_path: str): + r"""Saves the content of this instance in JSON format inside `json_path`.""" + json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n" + with open(json_path, "w", encoding="utf-8") as f: + f.write(json_string) + + @classmethod + def load_from_json(cls, json_path: str): + r"""Creates an instance from the content of `json_path`.""" + with open(json_path, "r", encoding="utf-8") as f: + text = f.read() + return cls(**json.loads(text)) + + +@dataclass +class TrainingArguments(Seq2SeqTrainingArguments): + cache_dir: Optional[str] = field(default=None) + train_on_source: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to train on the input in addition to the target text." + }, + ) + full_finetune: bool = field( + default=False, metadata={"help": "Finetune the entire model without adapters."} + ) + do_train: bool = field( + default=True, + metadata={"help": "To train or not to train, that is the question?"}, + ) + sample_generate: bool = field( + default=False, metadata={"help": "If do sample generation on evaluation."} + ) + optim: str = field( + default="paged_adamw_32bit", metadata={"help": "The optimizer to be used"} + ) + max_grad_norm: float = field( + default=0.3, + metadata={ + "help": "Gradient clipping max norm. This is tuned and works well for all models tested." + }, + ) + gradient_checkpointing: bool = field( + default=True, + metadata={"help": "Use gradient checkpointing. You want to use this."}, + ) + predict_with_generate: bool = field( + default=False, + metadata={ + "help": "Group sequences into batches with same length. Saves memory and speeds up training considerably." + }, + ) + model_max_length: int = field( + default=2048, + metadata={ + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + output_dir: str = field( + default=ADAPTER_PATH, + metadata={"help": "The output dir for logs and checkpoints"}, + ) + per_device_train_batch_size: int = field( + default=1, + metadata={ + "help": "The training batch size per GPU. Increase for better speed." + }, + ) + gradient_accumulation_steps: int = field( + default=16, + metadata={ + "help": "How many gradients to accumulate before to perform an optimizer step" + }, + ) + max_steps: int = field( + default=10000, metadata={"help": "How many optimizer update steps to take"} + ) + # use lora dropout instead for regularization if needed + weight_decay: float = field( + default=0.0, metadata={"help": "The L2 weight decay rate of AdamW"} + ) + learning_rate: float = field(default=0.0002, metadata={"help": "The learnign rate"}) + remove_unused_columns: bool = field( + default=False, + metadata={"help": "Removed unused columns. Needed to make this codebase work."}, + ) + lr_scheduler_type: str = field( + default="constant", + metadata={ + "help": "Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis" + }, + ) + warmup_ratio: float = field( + default=0.03, metadata={"help": "Fraction of steps to do a warmup for"} + ) + logging_steps: int = field( + default=10, + metadata={"help": "The frequency of update steps after which to log the loss"}, + ) + group_by_length: bool = field( + default=True, + metadata={ + "help": "Group sequences into batches with same length. Saves memory and speeds up training considerably." 
+ }, + ) + save_strategy: str = field( + default="steps", metadata={"help": "When to save checkpoints"} + ) + save_steps: int = field(default=250, metadata={"help": "How often to save a model"}) + save_total_limit: int = field( + default=40, + metadata={ + "help": "How many checkpoints to save before the oldest is overwritten" + }, + ) diff --git a/src/model_trainer.py b/src/model_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..ae8544d948e3f7672966ec9d99f464c8c89bb9aa --- /dev/null +++ b/src/model_trainer.py @@ -0,0 +1,412 @@ +import os +import json +import torch +import numpy as np +import torch.nn as nn +import jieba +import matplotlib.pyplot as plt +import math +from rouge_chinese import Rouge +from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction +from dataclasses import dataclass +from .config import IGNORE_INDEX +from .loggings import get_logger +from .config_parser import ( + get_train_args, + get_state_dict, + load_trainable_params, +) +from .load import load_model_and_tokenizer +from .config import VALUE_HEAD_FILE_NAME, FINETUNING_ARGS_NAME +from transformers import Seq2SeqTrainer +from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME +from transformers.modeling_utils import ( + PreTrainedModel, + unwrap_model, + load_sharded_checkpoint, +) +from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME, TRAINER_STATE_NAME +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import LogitsProcessorList + +from peft import PeftModel +from trl import PreTrainedModelWrapper +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, Sequence + + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState + from transformers.trainer import PredictionOutput + from .model_args import FinetuningArguments + + +logger = get_logger(__name__) + + +class PeftModelMixin: + r""" + Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead. + """ + + def __init__(self) -> None: # for type checking + self.model: PreTrainedModel = None + self.tokenizer: "PreTrainedTokenizer" = None + self.args: "Seq2SeqTrainingArguments" = None + self.finetuning_args: "FinetuningArguments" = None + self.state: "TrainerState" = None + raise AssertionError("Mixin should not be initialized.") + + def _save( + self, + output_dir: Optional[str] = None, + state_dict: Optional[Dict[str, torch.Tensor]] = None, + ) -> None: + r""" + Saves trainable parameters as model checkpoint. + + This function will only be executed at the process zero. + + Subclass and override to inject custom behavior. It should not be directly used by external scripts. 
+ """ + output_dir = output_dir if output_dir is not None else self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Saving model checkpoint to {output_dir}") + + model = unwrap_model(self.model) + if isinstance(model, PreTrainedModelWrapper): + # Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200 + model_state_dict = state_dict or model.state_dict() + v_head_state_dict = { + name.replace("v_head.", ""): model_state_dict[name] + .cpu() + .clone() + .detach() + for name in model_state_dict.keys() + if name.startswith("v_head.") + } + + torch.save( + v_head_state_dict, os.path.join(output_dir, VALUE_HEAD_FILE_NAME) + ) + model = model.pretrained_model + + state_dict = state_dict or get_state_dict(model) + if isinstance(model, (PeftModel, PreTrainedModel)): + model.config.use_cache = True + model.save_pretrained( + output_dir, + state_dict=state_dict, + safe_serialization=self.args.save_safetensors, + ) + model.config.use_cache = False + else: + torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + + if ( + self.finetuning_args.finetuning_type == "full" + and self.tokenizer is not None + ): + try: + self.tokenizer.save_pretrained(output_dir) + except: + logger.warning("Cannot save tokenizer, copy the files manually.") + + with open( + os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8" + ) as f: + f.write(self.args.to_json_string() + "\n") + + self.finetuning_args.save_to_json( + os.path.join(output_dir, FINETUNING_ARGS_NAME) + ) + + def _load_best_model(self): + r""" + Loads trainable parameters from model checkpoint. + + Subclass and override to inject custom behavior. It should not be directly used by external scripts. + """ + logger.info( + f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})." + ) + model = unwrap_model(self.model) + + if isinstance(model, PreTrainedModelWrapper): + model.v_head.load_state_dict( + torch.load( + os.path.join( + self.state.best_model_checkpoint, VALUE_HEAD_FILE_NAME + ), + map_location="cpu", + ) + ) + model = model.pretrained_model + + if isinstance(model, PeftModel): + model.load_adapter(self.state.best_model_checkpoint, model.active_adapter) + else: # freeze/full-tuning + load_trainable_params(model, self.state.best_model_checkpoint) + + +class PeftTrainer(PeftModelMixin, Seq2SeqTrainer): + r""" + Inherits Seq2SeqTrainer to support parameter-efficient checkpoints. + """ + + def __init__(self, finetuning_args: "FinetuningArguments", **kwargs): + Seq2SeqTrainer.__init__(self, **kwargs) + self.finetuning_args = finetuning_args + + +class Seq2SeqPeftTrainer(PeftTrainer): + r""" + Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE. + """ + + def prediction_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: + r""" + Removes the prompt part in the generated tokens. + + Subclass and override to inject custom behavior. 
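+
+        `input_ids` and `labels` are first left-padded to a common length so that,
+        after generation, the prompt portion of `generated_tokens` can be overwritten
+        with the pad token and only newly generated tokens remain for metric computation.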
+        """
+        prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
+        if prompt_len > label_len:
+            inputs["labels"] = self._pad_tensors_to_target_len(
+                inputs["labels"], inputs["input_ids"]
+            )
+        if label_len > prompt_len:
+            inputs["input_ids"] = self._pad_tensors_to_target_len(
+                inputs["input_ids"], inputs["labels"]
+            )
+            if "attention_mask" in inputs:
+                inputs["attention_mask"] = self._pad_tensors_to_target_len(
+                    inputs["attention_mask"], inputs["labels"], pad_token_id=0
+                )
+            if "position_ids" in inputs:
+                inputs["position_ids"] = self._pad_tensors_to_target_len(
+                    inputs["position_ids"], inputs["labels"], pad_token_id=0
+                )
+
+        loss, generated_tokens, labels = super().prediction_step(
+            model,
+            inputs,
+            prediction_loss_only=prediction_loss_only,
+            ignore_keys=ignore_keys,
+        )
+        if generated_tokens is not None:
+            generated_tokens[
+                :, : max(prompt_len, label_len)
+            ] = self.tokenizer.pad_token_id * torch.ones_like(
+                generated_tokens[:, : max(prompt_len, label_len)]
+            )
+
+        return loss, generated_tokens, labels
+
+    def _pad_tensors_to_target_len(
+        self,
+        src_tensor: torch.Tensor,
+        tgt_tensor: torch.Tensor,
+        pad_token_id: Optional[int] = None,
+    ) -> torch.Tensor:
+        r"""
+        Pads the tensor to the same length as the target tensor.
+
+        Should only be called when predict_with_generate=True.
+        """
+        if pad_token_id is None:
+            if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
+                assert (
+                    self.tokenizer.padding_side == "left"
+                ), "This method only accepts left-padded tensor."
+                pad_token_id = self.tokenizer.pad_token_id
+            else:
+                raise ValueError("PAD token is required.")
+
+        padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
+        padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor  # adopt left-padding
+        return padded_tensor.contiguous()  # in contiguous memory
+
+    def save_predictions(self, predict_results: "PredictionOutput") -> None:
+        r"""
+        Saves model predictions to `output_dir`.
+
+        A custom behavior that is not contained in Seq2SeqTrainer.
+        """
+        if not self.is_world_process_zero():
+            return
+
+        output_prediction_file = os.path.join(
+            self.args.output_dir, "generated_predictions.jsonl"
+        )
+        logger.info(f"Saving prediction results to {output_prediction_file}")
+
+        preds = np.where(
+            predict_results.predictions != IGNORE_INDEX,
+            predict_results.predictions,
+            self.tokenizer.pad_token_id,
+        )
+        labels = np.where(
+            predict_results.label_ids != IGNORE_INDEX,
+            predict_results.label_ids,
+            self.tokenizer.pad_token_id,
+        )
+
+        decoded_preds = self.tokenizer.batch_decode(
+            preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
+        )
+        decoded_labels = self.tokenizer.batch_decode(
+            labels, skip_special_tokens=True, clean_up_tokenization_spaces=True
+        )
+
+        with open(output_prediction_file, "w", encoding="utf-8") as writer:
+            res: List[str] = []
+            for pred, label in zip(decoded_preds, decoded_labels):
+                res.append(
+                    json.dumps({"label": label, "predict": pred}, ensure_ascii=False)
+                )
+            writer.write("\n".join(res))
+
+
+@dataclass
+class ComputeMetrics:
+    r"""
+    Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
+    """
+
+    tokenizer: "PreTrainedTokenizer"
+
+    def __call__(
+        self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]
+    ) -> Dict[str, float]:
+        r"""
+        Uses the model predictions to compute metrics.
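+
+        Predictions and labels are decoded and segmented with jieba for ROUGE-1/2/L
+        (F1 scores), while BLEU-4 is computed at the character level with smoothing;
+        every score is averaged over the evaluation set and scaled to 0-100.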
+ """ + preds, labels = eval_preds + score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} + + preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id) + labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id) + + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) + + for pred, label in zip(decoded_preds, decoded_labels): + hypothesis = list(jieba.cut(pred)) + reference = list(jieba.cut(label)) + + if ( + len(" ".join(hypothesis).split()) == 0 + or len(" ".join(reference).split()) == 0 + ): + result = { + "rouge-1": {"f": 0.0}, + "rouge-2": {"f": 0.0}, + "rouge-l": {"f": 0.0}, + } + else: + rouge = Rouge() + scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference)) + result = scores[0] + + for k, v in result.items(): + score_dict[k].append(round(v["f"] * 100, 4)) + + bleu_score = sentence_bleu( + [list(label)], + list(pred), + smoothing_function=SmoothingFunction().method3, + ) + score_dict["bleu-4"].append(round(bleu_score * 100, 4)) + + return {k: float(np.mean(v)) for k, v in score_dict.items()} + + +# Avoid runtime error in model.generate(do_sample=True). +class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor + ) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 0] = 1.0 + return scores + + +def get_logits_processor() -> LogitsProcessorList: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + return logits_processor + + +# metric used +def smooth(scalars: List[float]) -> List[float]: + r""" + EMA implementation according to TensorBoard. 
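+
+    The smoothing weight is a sigmoid of the series length,
+    weight = 1.8 * (1 / (1 + exp(-0.05 * n)) - 0.5), so longer loss curves are
+    smoothed more strongly; each point becomes last * weight + (1 - weight) * current.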
+ """ + last = scalars[0] + smoothed = list() + weight = 1.8 * ( + 1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5 + ) # a sigmoid function + for next_val in scalars: + smoothed_val = last * weight + (1 - weight) * next_val + smoothed.append(smoothed_val) + last = smoothed_val + return smoothed + + +def plot_loss( + save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"] +) -> None: + with open( + os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8" + ) as f: + data = json.load(f) + + for key in keys: + steps, metrics = [], [] + for i in range(len(data["log_history"])): + if key in data["log_history"][i]: + steps.append(data["log_history"][i]["step"]) + metrics.append(data["log_history"][i][key]) + + if len(metrics) == 0: + logger.warning(f"No metric {key} to plot.") + continue + + plt.figure() + plt.plot(steps, metrics, alpha=0.4, label="original") + plt.plot(steps, smooth(metrics), label="smoothed") + plt.title("training {} of {}".format(key, save_dictionary)) + plt.xlabel("step") + plt.ylabel(key) + plt.legend() + plt.savefig( + os.path.join(save_dictionary, "training_{}.png".format(key)), + format="png", + dpi=100, + ) + print( + "Figure saved:", + os.path.join(save_dictionary, "training_{}.png".format(key)), + ) + + +def export_model( + args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB" +): + model_args, _, training_args, finetuning_args, _ = get_train_args( + args, data_args_init=False + ) + model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) + model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size) + try: + tokenizer.save_pretrained(training_args.output_dir) + except: + logger.warning("Cannot save tokenizer, please copy the files manually.") diff --git a/src/output/logs/pred_test_20240717_1311.log b/src/output/logs/pred_test_20240717_1311.log new file mode 100644 index 0000000000000000000000000000000000000000..c25ce3ba0d648dc5885b0dba3e1fb5ad4e340908 --- /dev/null +++ b/src/output/logs/pred_test_20240717_1311.log @@ -0,0 +1,8 @@ + Pred Start time: 2024-07-17 13:11:06 +############pred end############### +pred End time: Wed Jul 17 01:11:06 PM CEST 2024 +Time elapsed: hour 0 min + Pred Start time: 2024-07-17 13:11:48 +############pred end############### +pred End time: Wed Jul 17 01:11:48 PM CEST 2024 +Time elapsed: hour 0 min diff --git a/src/output/logs/pred_test_20240717_1312.log b/src/output/logs/pred_test_20240717_1312.log new file mode 100644 index 0000000000000000000000000000000000000000..0cd485e8ddf5bb3721005a43ce542747130e8596 --- /dev/null +++ b/src/output/logs/pred_test_20240717_1312.log @@ -0,0 +1,4 @@ + Pred Start time: 2024-07-17 13:12:11 +############pred end############### +pred End time: Wed Jul 17 01:12:11 PM CEST 2024 +Time elapsed: hour 0 min diff --git a/src/output/logs/pred_test_20240717_1313.log b/src/output/logs/pred_test_20240717_1313.log new file mode 100644 index 0000000000000000000000000000000000000000..1e24ed1bdc1cbf084c6e0851dd41cb65e5045a80 --- /dev/null +++ b/src/output/logs/pred_test_20240717_1313.log @@ -0,0 +1,5 @@ + Pred Start time: 2024-07-17 13:13:19 +[2024-07-17 13:13:27,533] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect) +############pred end############### +pred End time: Wed Jul 17 01:13:35 PM CEST 2024 +Time elapsed: hour 0 min diff --git a/src/output/logs/pred_test_20240717_1315.log b/src/output/logs/pred_test_20240717_1315.log new file mode 100644 index 
0000000000000000000000000000000000000000..1be9d95b11a3d4985d2ca5a70a15eac237d4d43f --- /dev/null +++ b/src/output/logs/pred_test_20240717_1315.log @@ -0,0 +1,5 @@ + Pred Start time: 2024-07-17 13:15:49 +[2024-07-17 13:15:56,605] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect) +############pred end############### +pred End time: Wed Jul 17 01:16:03 PM CEST 2024 +Time elapsed: hour 0 min diff --git a/src/output/logs/pred_test_20240717_1316.log b/src/output/logs/pred_test_20240717_1316.log new file mode 100644 index 0000000000000000000000000000000000000000..359ca98c8d82c9f3571529bd845e218a69a1a957 --- /dev/null +++ b/src/output/logs/pred_test_20240717_1316.log @@ -0,0 +1,5 @@ + Pred Start time: 2024-07-17 13:16:58 +[2024-07-17 13:17:03,895] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect) +############pred end############### +pred End time: Wed Jul 17 01:17:08 PM CEST 2024 +Time elapsed: hour 0 min diff --git a/src/output/logs/pred_test_20240717_1317.log b/src/output/logs/pred_test_20240717_1317.log new file mode 100644 index 0000000000000000000000000000000000000000..674fe338f693d9b7a34577b7a850f8cb0c549ced --- /dev/null +++ b/src/output/logs/pred_test_20240717_1317.log @@ -0,0 +1,4 @@ + Pred Start time: 2024-07-17 13:17:44 +############pred end############### +pred End time: Wed Jul 17 01:17:44 PM CEST 2024 +Time elapsed: hour 0 min diff --git a/src/predict.py b/src/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..2860262cf13a90762dc08d4483d80844d35cab11 --- /dev/null +++ b/src/predict.py @@ -0,0 +1,111 @@ +import os +#from unsloth import FastLanguageModel +import torch +from transformers import TrainingArguments, logging +from datasets import load_dataset +import json +import re +from tqdm import tqdm +from typing import List, Dict, Optional, Any +from .chat_model import ChatModel +from. 
data_args import SQL_PROMPT_DICT, CR_PROMPT_DICT, DEFAULT_PROMPT_DICT + + +logging.set_verbosity_error() + +#file = "Spider/alpaca_sft_prompts/dev.jsonl" +#dataset = load_dataset("json", data_files = {"dev" : file}, split = "dev") + +#model, tokenizer = FastLanguageModel.from_pretrained( +# model_name = "Wangzaistone123/CodeLlama-13b-sql-lora", # YOUR MODEL YOU USED FOR TRAINING +# max_seq_length = 2048, +# dtype = None, +# load_in_4bit = False, +# ) +#FastLanguageModel.for_inference(model) + +#def generate_text(text): +# inputs = tokenizer(text, return_tensors="pt").to("cuda:0") +# outputs = model.generate(**inputs, max_new_tokens=20) +#from dbgpt_hub.llm_base.config_parser import load_trainable_params + +def prepare_dataset( + predict_file_path: Optional[str] = None, +) -> List[Dict]: + with open(predict_file_path, "r") as fp: + data = json.load(fp) + predict_data = [extract_default_prompt_dataset(item) for item in data] + return predict_data + +def extract_sql_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]: + if example.get("input", "") != "": + prompt_format = SQL_PROMPT_DICT["prompt_input"] + else: + prompt_format = SQL_PROMPT_DICT["prompt_no_input"] + return {"input": prompt_format.format(**example)} + +def extract_cr_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]: + if example.get("input", "") != "": + prompt_format = CR_PROMPT_DICT["prompt_input"] + else: + prompt_format = CR_PROMPT_DICT["prompt_no_input"] + print({"input": prompt_format.format(**example)}) + return {"input": prompt_format.format(**example)} + +def extract_default_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]: + if example.get("input", "") != "": + prompt_format = DEFAULT_PROMPT_DICT["prompt_input"] + else: + prompt_format = DEFAULT_PROMPT_DICT["prompt_no_input"] + print({"input": prompt_format.format(**example)}) + return {"input": prompt_format.format(**example)} + + +def predict(model: ChatModel, **input_kwargs): + args = model.data_args + res = [] + predict_data = prepare_dataset(args.predicted_input_filename) + + # test + # for item in predict_data[:20]: + for item in tqdm(predict_data, desc="Inference Progress", unit="item"): + print(f"item[input] \n{item['input']}") + response, _ = model.chat(query=item["input"], history=[], **input_kwargs) + res.append(response) + + with open(args.predicted_out_filename, "w") as f: + for p in res: + try: + f.write(p.replace("\n", " ") + "\n") + except: + f.write("Invalid Output!\n") + +if __name__ == "__main__": + model = ChatModel() + predict(model) + + +# Extract the part after "### Response:" and remove newlines +# if "### Response:" in decoded_output: +# response = decoded_output.split("### Response:")[1].strip() +# response = re.sub(r'\s+', ' ', response) # Replace multiple spaces/newlines with a single space +# else: +# response = re.sub(r'\s+', ' ', decoded_output) # Replace multiple spaces/newlines with a single space +# +# return response +# +#results = [] + +#for example in dataset: +# generated_output = generate_text(example['instruction']) +# expected_output = example['output'] +# results.append({ +# "generated_output": generated_output, +# "expected_output": expected_output +# }) + +#output_file = 'dbgpt_hub/eval_output/codellama13b_2.json' +#with open(output_file, 'w') as f: +# json.dump(results, f, indent=4) + +#print(f"Output written to {output_file}") diff --git a/src/sft_train.py b/src/sft_train.py new file mode 100644 index 0000000000000000000000000000000000000000..bbe0fb28fda49aa0597e9457f2009cb7a41c2a25 --- /dev/null 
+++ b/src/sft_train.py @@ -0,0 +1,165 @@ +import os +import sys + +ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(ROOT_PATH) +from typing import TYPE_CHECKING, Any, Dict, List, Optional +from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments + +from .loggings import LogCallback, get_logger +from .config_parser import get_train_args +from .load import load_model_and_tokenizer +from .data_utils import ( + get_dataset, + preprocess_dataset, + split_dataset, +) +from .config import IGNORE_INDEX +from .model_trainer import ( + Seq2SeqPeftTrainer, + ComputeMetrics, + get_logits_processor, + plot_loss, +) + + +if TYPE_CHECKING: + from transformers import TrainerCallback + from .model_args import ( + ModelArguments, + FinetuningArguments, + GeneratingArguments, + ) + from .data_args import DataArguments + + +logger = get_logger(__name__) + + +def run_sft( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + callbacks: Optional[List["TrainerCallback"]] = None, +): + dataset = get_dataset(model_args, data_args) + model, tokenizer = load_model_and_tokenizer( + model_args, finetuning_args, training_args.do_train + ) + dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, "sft") + data_collator = DataCollatorForSeq2Seq( + tokenizer=tokenizer, + label_pad_token_id=IGNORE_INDEX + if data_args.ignore_pad_token_for_loss + else tokenizer.pad_token_id, + ) + + # Override the decoding parameters of Seq2SeqTrainer + training_args_dict = training_args.to_dict() + training_args_dict.update( + dict( + generation_max_length=training_args.generation_max_length + or data_args.max_target_length, + generation_num_beams=data_args.eval_num_beams + or training_args.generation_num_beams, + ) + ) + training_args = Seq2SeqTrainingArguments(**training_args_dict) + + # Initialize our Trainer + trainer = Seq2SeqPeftTrainer( + finetuning_args=finetuning_args, + model=model, + args=training_args, + tokenizer=tokenizer, + data_collator=data_collator, + callbacks=callbacks, + compute_metrics=ComputeMetrics(tokenizer) + if training_args.predict_with_generate + else None, + **split_dataset(dataset, data_args, training_args) + ) + + # Keyword arguments for `model.generate` + gen_kwargs = generating_args.to_dict() + gen_kwargs["eos_token_id"] = list( + set([tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids) + ) + gen_kwargs["pad_token_id"] = tokenizer.pad_token_id + gen_kwargs["logits_processor"] = get_logits_processor() + + # Training + if training_args.do_train: + train_result = trainer.train( + resume_from_checkpoint=training_args.resume_from_checkpoint + ) + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + trainer.save_model() + if trainer.is_world_process_zero() and model_args.plot_loss: + plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) + + # Evaluation + if training_args.do_eval: + metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) + if ( + training_args.predict_with_generate + ): # eval_loss will be wrong if predict_with_generate is enabled + metrics.pop("eval_loss", None) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Predict + if training_args.do_predict: + predict_results = trainer.predict( + dataset, 
+            metric_key_prefix="predict",
+            **gen_kwargs
+        )
+        if (
+            training_args.predict_with_generate
+        ):  # predict_loss will be wrong if predict_with_generate is enabled
+            predict_results.metrics.pop("predict_loss", None)
+        trainer.log_metrics("predict", predict_results.metrics)
+        trainer.save_metrics("predict", predict_results.metrics)
+        trainer.save_predictions(predict_results)
+
+
+def train(
+    args: Optional[Dict[str, Any]] = None,
+    callbacks: Optional[List["TrainerCallback"]] = None,
+):
+    (
+        model_args,
+        data_args,
+        training_args,
+        finetuning_args,
+        generating_args,
+    ) = get_train_args(args)
+    callbacks = [LogCallback()] if callbacks is None else callbacks
+
+    run_sft(
+        model_args,
+        data_args,
+        training_args,
+        finetuning_args,
+        generating_args,
+        callbacks,
+    )
+
+
+def export_model(
+    args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"
+):
+    model_args, _, training_args, finetuning_args, _ = get_train_args(args)
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
+    model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size)
+    try:
+        tokenizer.save_pretrained(training_args.output_dir)
+    except Exception:
+        logger.warning("Cannot save tokenizer, please copy the files manually.")
+
+
+if __name__ == "__main__":
+    train()
diff --git a/src/sql_data_process.py b/src/sql_data_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..f45eb19d275ada3e50380fa189a38848ebe310f3
--- /dev/null
+++ b/src/sql_data_process.py
@@ -0,0 +1,281 @@
+import os
+import json
+import jsonlines
+import sys
+import re
+import argparse
+
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(ROOT_PATH)
+
+from tqdm import tqdm
+
+from .config import (
+    SQL_DATA_INFO,
+    DATA_PATH,
+    INPUT_PROMPT,
+    INSTRUCTION_PROMPT,
+    INSTRUCTION_ONE_SHOT_PROMPT,
+    CODE_REPRESENTATION_PROMPT,
+    CR_INPUT_PROMPT,
+    ALPACA_PROMPT,
+    ALPACA_INPUT_PROMPT
+)
+
+
+class ProcessSqlData:
+    def __init__(
+        self, train_file=None, dev_file=None, num_shot=0, code_representation=False
+    ) -> None:
+        self.train_file = train_file
+        self.dev_file = dev_file
+        self.num_shot = num_shot
+        self.code_representation = code_representation
+
+    def decode_json_file(
+        self,
+        data_file_list,
+        table_file,
+        db_folder_path,
+        db_id_name,
+        output_name,
+        is_multiple_turn=False,
+    ):
+        """
+        TODO:
+        1. Move the related prompts into the config.
+        2. Move the field information of the different data sources into the config.
+        """
+
+        if table_file.endswith(".jsonl"):
+            tables = jsonlines.open(table_file)
+            datas = []
+            for data_file in data_file_list:
+                datas.extend(jsonlines.open(data_file))
+
+        elif table_file.endswith(".json"):
+            tables = json.load(open(table_file))
+            datas = []
+            for data_file in data_file_list:
+                datas.extend(json.load(open(data_file)))
+        else:
+            raise ValueError("Unsupported table file type; expected .json or .jsonl")
+
+        # First build the table and column description for each db_id
+        db_dict = {}
+        for item in tables:
+            tables = item["table_names_original"]
+            coloumns = item["column_names_original"][1:]
+            primary_key = item["primary_keys"]
+            foreign_keys = item["foreign_keys"]
+            #source = (
+            #    item["db_id"] + " contains tables such as " + ", ".join(tables) + ". "
+            #)
+            source = ""
+            for i, name in enumerate(tables):
+                data = [coloumn[1] for coloumn in coloumns if coloumn[0] == i]
+                source += (
+                    name + "(" + ", ".join(data) + ")\n"
+                )
+
+                # get primary key info
+                for j in range(len(primary_key)):
+                    if type(primary_key[j]) == int:
+                        if coloumns[primary_key[j] - 1][0] == i:
+                            source += (
+                                coloumns[primary_key[j] - 1][1]
+                                + " is the primary key."
+                                + "\n"
+                            )
+                    # combination primary key
+                    elif type(primary_key[j]) == list:
+                        combine_p = "The combination of ("
+                        keys = []
+                        for k in range(len(primary_key[j])):
+                            if coloumns[primary_key[j][k] - 1][0] == i:
+                                keys.append(coloumns[primary_key[j][k] - 1][1])
+                        source += (
+                            combine_p
+                            + ", ".join(keys)
+                            + ") are the primary key."
+                            + "\n"
+                        )
+                    else:
+                        print("Unsupported primary key type", type(primary_key[j]))
+                        continue
+
+            # get foreign key info
+            for key in foreign_keys:
+                source += (
+                    "The "
+                    + coloumns[key[0] - 1][1]
+                    + " of "
+                    + tables[coloumns[key[0] - 1][0]]
+                    + " is the foreign key of "
+                    + coloumns[key[1] - 1][1]
+                    + " of "
+                    + tables[coloumns[key[1] - 1][0]]
+                    + ".\n"
+                )
+
+            db_dict[item["db_id"]] = source
+
+        res = []
+        base_instruction = ALPACA_PROMPT
+        if self.num_shot == 1:
+            base_instruction = INSTRUCTION_ONE_SHOT_PROMPT
+
+        count = 0
+        for data in tqdm(datas):
+            if data[db_id_name] in db_dict.keys():
+                if is_multiple_turn:  # multi-turn
+                    history = []
+                    for interaction in data["interaction"]:
+                        input = {
+                            "db_id": data[db_id_name],
+                            "instruction": base_instruction.format(
+                                db_dict[data[db_id_name]]
+                            ),
+                            "input": INPUT_PROMPT.format(interaction["utterance"]),
+                            "output": interaction[output_name],
+                            "history": history,
+                        }
+                        res.append(input)
+                        history.append(
+                            (
+                                INPUT_PROMPT.format(interaction["utterance"]),
+                                interaction[output_name],
+                            )
+                        )
+                else:  # single-turn
+                    if self.code_representation:
+                        db_path = os.path.join(db_folder_path, data[db_id_name])
+                        sql_file_path = next(
+                            (
+                                file
+                                for file in os.listdir(db_path)
+                                if file.endswith(".sql")
+                            ),
+                            None,
+                        )
+                        if sql_file_path is None:
+                            print(f"Skipping {data[db_id_name]} due to missing .sql file")
+                            continue  # skip to the next example
+                        schema_file_path = os.path.join(db_path, sql_file_path)
+                        with open(schema_file_path, "r", errors="ignore") as file:
+                            schema_content = file.read()
+                        create_statements = re.findall(
+                            r"CREATE\s.*?;", schema_content, re.DOTALL|re.IGNORECASE
+                        )
+                        input = {
+                            "db_id": data[db_id_name],
+                            "instruction": CODE_REPRESENTATION_PROMPT.format(create_statements),
+                            "input": CR_INPUT_PROMPT.format(data["question"]),
+                            "output": data[output_name],
+                            "history": [],
+                        }
+                        res.append(input)
+                        count += 1
+                        print(f"Generated {count} inputs")
+                    else:
+                        input = {
+                            "db_id": data[db_id_name],
+                            "instruction": base_instruction.format(
+                                data["question"]
+                            ),
+                            "input": ALPACA_INPUT_PROMPT.format(db_dict[data[db_id_name]]),
+                            "output": data[output_name],
+                            "history": [],
+                        }
+                        print(db_dict[data[db_id_name]])
+                        res.append(input)
+                        #count += 1
+                        #print(f"Generated {count} inputs")
+        return res
+
+    def create_sft_raw_data(self):
+        train_data = []
+        dev_data = []
+        for data_info in SQL_DATA_INFO:
+            train_data_file_list = [
+                os.path.join(DATA_PATH, data_info["data_source"], file)
+                for file in data_info["train_file"]
+            ]
+            train_data.extend(
+                self.decode_json_file(
+                    data_file_list=train_data_file_list,
+                    table_file=os.path.join(
+                        DATA_PATH,
+                        data_info["data_source"],
+                        data_info["train_tables_file"],
+                    ),
+                    db_folder_path=os.path.join(
+                        DATA_PATH,
+                        data_info["data_source"],
+                        "database",
+                    ),
+                    db_id_name=data_info["db_id_name"],
+                    output_name=data_info["output_name"],
+                    is_multiple_turn=data_info["is_multiple_turn"],
+                )
+            )
+
+            dev_data_file_list = [
+                os.path.join(DATA_PATH, data_info["data_source"], file)
+                for file in data_info["dev_file"]
+            ]
+            dev_data.extend(
+                self.decode_json_file(
+                    data_file_list=dev_data_file_list,
+                    table_file=os.path.join(
+                        DATA_PATH,
+                        data_info["data_source"],
+                        data_info["dev_tables_file"],
+                    ),
+                    db_folder_path=os.path.join(
+                        DATA_PATH,
+                        data_info["data_source"],
+                        "database",
+                    ),
+                    db_id_name=data_info["db_id_name"],
+                    output_name=data_info["output_name"],
+                    is_multiple_turn=data_info["is_multiple_turn"],
+                )
+            )
+        with open(self.train_file, "w", encoding="utf-8") as s:
+            json.dump(train_data, s, indent=4, ensure_ascii=False)
+        with open(self.dev_file, "w", encoding="utf-8") as s:
+            json.dump(dev_data, s, indent=4, ensure_ascii=False)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--code_representation",
+        help="Enable code representation",
+        default=False,
+        # parse "True"/"False" strings into a real boolean instead of a truthy string
+        type=lambda x: str(x).lower() in ("true", "1"),
+    )
+    args = parser.parse_args()
+
+    all_in_one_train_file = os.path.join(DATA_PATH, "example_text2sql_train_alpaca_noselect.json")
+    all_in_one_dev_file = os.path.join(DATA_PATH, "example_text2sql_dev_alpaca_noselect.json")
+    process = ProcessSqlData(
+        train_file=all_in_one_train_file,
+        dev_file=all_in_one_dev_file,
+        code_representation=args.code_representation,
+    )
+    process.create_sft_raw_data()
+
+    # one-shot
+    one_shot_all_in_one_train_file = os.path.join(
+        DATA_PATH, "example_text2sql_train_one_shot.json"
+    )
+    one_shot_all_in_one_dev_file = os.path.join(
+        DATA_PATH, "example_text2sql_dev_one_shot.json"
+    )
+    #one_shot_process = ProcessSqlData(
+    #    train_file=one_shot_all_in_one_train_file,
+    #    dev_file=one_shot_all_in_one_dev_file,
+    #    num_shot=1,
+    #    code_representation=args.code_representation,
+    #)
+    #one_shot_process.create_sft_raw_data()
diff --git a/src/tuner.py b/src/tuner.py
new file mode 100644
index 0000000000000000000000000000000000000000..634a5f8044bf3dc3d02353651f3fec6c23baadd5
--- /dev/null
+++ b/src/tuner.py
@@ -0,0 +1,72 @@
+
+import os
+from unsloth import FastLanguageModel
+import torch
+from trl import SFTTrainer
+from transformers import TrainingArguments
+from datasets import load_dataset
+
+max_seq_length = 2048
+file = "Spider/alpaca_sft_prompts/train.jsonl"
+dataset = load_dataset("json", data_files = {"train" : file}, split = "train")
+print(f"Number of examples in the dataset: {len(dataset)}")
+
+# Load model
+model, tokenizer = FastLanguageModel.from_pretrained(
+    model_name = "unsloth/codellama-7b",
+    max_seq_length = max_seq_length,
+    dtype = None,
+    load_in_4bit = True,
+)
+
+# Do model patching and add fast LoRA weights and training
+model = FastLanguageModel.get_peft_model(
+    model,
+    r = 64,
+    target_modules = ["q_proj", "v_proj"],
+    lora_alpha = 32,
+    lora_dropout = 0, # Supports any, but = 0 is optimized
+    bias = "none", # Supports any, but = "none" is optimized
+    use_gradient_checkpointing = True,
+    random_state = 3407,
+    max_seq_length = max_seq_length,
+    use_rslora = False,  # Rank stabilized LoRA
+    loftq_config = None, # LoftQ
+)
+
+def formatting_func(example):
+    text = f"{example['instruction']}\n{example['output']}"
+    return text
+
+trainer = SFTTrainer(
+    model = model,
+    train_dataset = dataset,
+    formatting_func = formatting_func,
+    max_seq_length = max_seq_length,
+    packing=True,
+    tokenizer = tokenizer,
+    args = TrainingArguments(
+        per_device_train_batch_size = 1,
+        gradient_accumulation_steps = 16,
+        warmup_steps = 30,
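+        # note: when both are given, warmup_steps takes precedence over warmup_ratio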
warmup_ratio = 0.03, + num_train_epochs = 8, + fp16 = not torch.cuda.is_bf16_supported(), + bf16 = torch.cuda.is_bf16_supported(), + logging_steps = 50, + save_steps = 2000, + output_dir = "overfitting/codellama7b_blog", + optim = "adamw_8bit", + weight_decay = 1, + lr_scheduler_type = "cosine_with_restarts", + learning_rate = 2e-04, + seed = 3407, + ), +) +trainer.train() + +# Save the model +model.save_pretrained("lora_model_codellama7b_blog") +model.save_pretrained_merged("overfitting/codellama7b_blog", tokenizer, save_method = "merged_16bit",) +#model.push_to_hub_merged("oleherbst/llama3-8b-oig-unsloth-merged", tokenizer, save_method = "merged_16bit", token = os.environ.get("HF_TOKEN")) +#model.push_to_hub("oleherbst/llama3-8b-oig-unsloth", tokenizer, save_method = "lora", token = os.environ.get("HF_TOKEN"))
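+
+# Hedged usage sketch (kept commented out, paths and max_new_tokens are illustrative):
+# reloading the merged checkpoint for a quick generation test, mirroring the
+# commented-out block at the top of src/predict.py.
+#model, tokenizer = FastLanguageModel.from_pretrained(
+#    model_name = "overfitting/codellama7b_blog",
+#    max_seq_length = max_seq_length,
+#    dtype = None,
+#    load_in_4bit = False,
+#)
+#FastLanguageModel.for_inference(model)
+#inputs = tokenizer("your prompt here", return_tensors="pt").to("cuda:0")
+#outputs = model.generate(**inputs, max_new_tokens=64)
+#print(tokenizer.decode(outputs[0], skip_special_tokens=True))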