diff --git a/src/chat_model.py b/src/chat_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..87b41c089995303f3e046dd931a5127e2bed6efd
--- /dev/null
+++ b/src/chat_model.py
@@ -0,0 +1,561 @@
+import torch
+import json
+from typing import Any, Union, Dict, Generator, List, Optional, Tuple
+from threading import Thread
+from transformers import GenerationConfig, PreTrainedTokenizer, TextIteratorStreamer
+
+from .config_parser import get_infer_args
+from .load import dispatch_model, load_model_and_tokenizer
+from .model_trainer import get_logits_processor
+from .data_args import (
+    DEFAULT_PROMPT_DICT,
+    ALPACA_PROMPT_DICT,
+    SQL_PROMPT_DICT,
+    Template,
+    Llama2Template,
+)
+from .loggings import get_logger
+
+
+logger = get_logger(__name__)
+
+templates: Dict[str, Template] = {}
+
+def get_template_and_fix_tokenizer(
+    name: str, tokenizer: "PreTrainedTokenizer"
+) -> Template:
+    if name is None:  # no template requested; nothing to fix on the tokenizer side
+        return None
+
+    template = templates.get(name, None)
+    assert template is not None, "Template {} does not exist.".format(name)
+
+    additional_special_tokens = template.stop_words
+
+    if tokenizer.eos_token_id is None:
+        tokenizer.eos_token = "<|endoftext|>"
+        logger.info("Add eos token: {}".format(tokenizer.eos_token))
+
+    if tokenizer.pad_token_id is None:
+        if tokenizer.unk_token_id is not None:
+            tokenizer.pad_token = tokenizer.unk_token
+        else:
+            tokenizer.pad_token = tokenizer.eos_token
+        logger.info("Add pad token: {}".format(tokenizer.pad_token))
+
+    tokenizer.add_special_tokens(
+        dict(additional_special_tokens=additional_special_tokens),
+        replace_additional_special_tokens=False,
+    )
+    return template
+
+
+class ChatModel:
+    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
+        (
+            model_args,
+            self.data_args,
+            finetuning_args,
+            self.generating_args,
+        ) = get_infer_args(args)
+        self.model, self.tokenizer = load_model_and_tokenizer(
+            model_args, finetuning_args
+        )
+        self.tokenizer.padding_side = "left"
+        self.model = dispatch_model(self.model)
+        self.template = get_template_and_fix_tokenizer(
+            self.data_args.template, self.tokenizer
+        )
+        self.system_prompt = self.data_args.system_prompt
+
+    def process_args(
+        self,
+        query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        system: Optional[str] = None,
+        **input_kwargs
+    ) -> Tuple[Dict[str, Any], int]:
+        system = system or self.system_prompt
+
+        prompt, _ = self.template.encode_oneturn(
+            tokenizer=self.tokenizer,
+            query=query,
+            resp="",
+            history=history,
+            system=system,
+        )
+        input_ids = torch.tensor([prompt], device=self.model.device)
+        prompt_length = len(input_ids[0])
+
+        do_sample = input_kwargs.pop("do_sample", None)
+        temperature = input_kwargs.pop("temperature", None)
+        top_p = input_kwargs.pop("top_p", None)
+        top_k = input_kwargs.pop("top_k", None)
+        repetition_penalty = input_kwargs.pop("repetition_penalty", None)
+        max_length = input_kwargs.pop("max_length", None)
+        max_new_tokens = input_kwargs.pop("max_new_tokens", None)
+
+        generating_args = self.generating_args.to_dict()
+        generating_args.update(
+            dict(
+                do_sample=do_sample
+                if do_sample is not None
+                else generating_args["do_sample"],
+                temperature=temperature or generating_args["temperature"],
+                top_p=top_p or generating_args["top_p"],
+                top_k=top_k or generating_args["top_k"],
+                repetition_penalty=repetition_penalty
+                or generating_args["repetition_penalty"],
+                eos_token_id=[self.tokenizer.eos_token_id]
+                + self.tokenizer.additional_special_tokens_ids,
+                pad_token_id=self.tokenizer.pad_token_id,
+            )
+        )
+
+        if max_length:
+            generating_args.pop("max_new_tokens", None)
+            generating_args["max_length"] = max_length
+
+        if max_new_tokens:
+            generating_args.pop("max_length", None)
+            generating_args["max_new_tokens"] = max_new_tokens
+
+        gen_kwargs = dict(
+            inputs=input_ids,
+            generation_config=GenerationConfig(**generating_args),
+            logits_processor=get_logits_processor(),
+        )
+
+        return gen_kwargs, prompt_length
+
+    @torch.inference_mode()
+    def chat(
+        self,
+        query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        system: Optional[str] = None,
+        **input_kwargs
+    ) -> Tuple[str, Tuple[int, int]]:
+        gen_kwargs, prompt_length = self.process_args(
+            query, history, system, **input_kwargs
+        )
+        generation_output = self.model.generate(**gen_kwargs)
+        outputs = generation_output.tolist()[0][prompt_length:]
+        response = self.tokenizer.decode(outputs, skip_special_tokens=True)
+        response_length = len(outputs)
+        return response, (prompt_length, response_length)
+
+    @torch.inference_mode()
+    def stream_chat(
+        self,
+        query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        system: Optional[str] = None,
+        **input_kwargs
+    ) -> Generator[str, None, None]:
+        gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
+        streamer = TextIteratorStreamer(
+            self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
+        )
+        gen_kwargs["streamer"] = streamer
+
+        thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
+        thread.start()
+
+        yield from streamer
+
+
+def register_template(
+    name: str,
+    prefix: List[Union[str, Dict[str, str]]],
+    prompt: List[Union[str, Dict[str, str]]],
+    system: str,
+    sep: List[Union[str, Dict[str, str]]],
+    stop_words: Optional[List[str]] = [],
+    use_history: Optional[bool] = True,
+) -> None:
+    template_class = Llama2Template if "llama2" in name else Template
+    templates[name] = template_class(
+        prefix=prefix,
+        prompt=prompt,
+        system=system,
+        sep=sep,
+        stop_words=stop_words,
+        use_history=use_history,
+    )
+
+
+r"""
+Supports language model inference without histories.
+"""
+register_template(
+    name="vanilla",
+    prefix=[],
+    prompt=["{{query}}"],
+    system="",
+    sep=[],
+    use_history=False,
+)
+
+r"""
+Supports Mistral-style models such as sqlcoder-7b.
+"""
+register_template(
+    name="mistral",
+    prefix=["{{system}}"],
+    prompt=["[INST] {{query}} [/INST]"],
+    system="",
+    sep=[],
+)
+
+
+r"""
+Default template.
+"""
+register_template(
+    name="default",
+    prefix=["{{system}}"],
+    prompt=["Human: {{query}}\nAssistant: "],
+    system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    sep=["\n"],
+)
+
+
+r"""
+Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
+          https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
+          https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
+"""
+register_template(
+    name="llama2",
+    prefix=["<<SYS>>\n{{system}}\n<</SYS>>\n\n"],
+    prompt=["[INST] {{query}} [/INST] "],
+    system=(
+        "You are a helpful, respectful and honest assistant. "
+        "Always answer as helpfully as possible, while being safe.  "
+        "Your answers should not include any harmful, unethical, "
+        "racist, sexist, toxic, dangerous, or illegal content. "
+        "Please ensure that your responses are socially unbiased and positive in nature.\n"
+        "If a question does not make any sense, or is not factually coherent, "
+        "explain why instead of answering something not correct. "
+        "If you don't know the answer to a question, please don't share false information."
+    ),
+    sep=[],
+)
+
+register_template(
+    name="llama3",
+    prefix=["<|start_header_id|>system<|end_header_id|>\n\n{{system}}<|eot_id|>\n"],
+    prompt=["<|start_header_id|>user<|end_header_id|>\n\n{{query}}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"],
+    system=(
+        "You are a helpful, respectful and honest assistant. "
+        "Always answer as helpfully as possible, while being safe.  "
+        "Your answers should not include any harmful, unethical, "
+        "racist, sexist, toxic, dangerous, or illegal content. "
+        "Please ensure that your responses are socially unbiased and positive in nature.\n"
+        "If a question does not make any sense, or is not factually coherent, "
+        "explain why instead of answering something not correct. "
+        "If you don't know the answer to a question, please don't share false information."
+    ),
+    sep=[],
+)
+
+r"""
+Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
+          https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
+"""
+register_template(
+    name="llama2_zh",
+    prefix=["<<SYS>>\n{{system}}\n<</SYS>>\n\n"],
+    prompt=["[INST] {{query}} [/INST] "],
+    system="You are a helpful assistant. 你是一个乐于助人的助手。",
+    sep=[],
+)
+
+
+r"""
+Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
+          https://github.com/ymcui/Chinese-LLaMA-Alpaca
+"""
+register_template(
+    name="alpaca",
+    prefix=["{{system}}"],
+    prompt=["### Instruction:\n{{query}}\n\n### Response:\n"],
+    system=(
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request."
+    ),
+    sep=["\n\n"],
+)
+
+
+r"""
+Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
+          https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
+"""
+register_template(
+    name="vicuna",
+    prefix=["{{system}}"],
+    prompt=["USER: {{query}} ASSISTANT: "],
+    system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    sep=[],
+)
+
+
+r"""
+Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
+"""
+register_template(
+    name="belle",
+    prefix=["{{system}}"],
+    prompt=["Human: {{query}}\n\nBelle: "],
+    system="",
+    sep=["\n\n"],
+)
+
+
+r"""
+Supports: https://github.com/CVI-SZU/Linly
+"""
+register_template(
+    name="linly",
+    prefix=["{{system}}"],
+    prompt=["User: {{query}}\nBot: "],
+    system="",
+    sep=["\n"],
+)
+
+
+r"""
+Supports: https://github.com/Neutralzz/BiLLa
+"""
+register_template(
+    name="billa",
+    prefix=["{{system}}"],
+    prompt=["Human: {{query}}\nAssistant: "],
+    system="",
+    sep=["\n"],
+)
+
+
+r"""
+Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
+"""
+register_template(
+    name="ziya",
+    prefix=["{{system}}"],
+    prompt=[{"token": "<human>"}, ":{{query}}\n", {"token": "<bot>"}, ":"],
+    system="",
+    sep=["\n"],
+)
+
+
+r"""
+Supports: https://huggingface.co/qhduan/aquilachat-7b
+"""
+register_template(
+    name="aquila",
+    prefix=["{{system}}"],
+    prompt=["Human: {{query}}###Assistant: "],
+    system=(
+        "A chat between a curious human and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the human's questions."
+    ),
+    sep=["###"],
+)
+
+
+r"""
+Supports: https://huggingface.co/internlm/internlm-chat-7b
+"""
+register_template(
+    name="intern",
+    prefix=["{{system}}"],
+    prompt=["<|User|>:{{query}}", {"token": "<eoh>"}, "\n<|Bot|>:"],
+    system="",
+    sep=["\n"],
+    stop_words=["</s>", "<eoa>"],  # internlm cannot replace eos token
+)
+
+
+r"""
+Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
+Used for training and inference of the fine-tuned models.
+"""
+register_template(
+    name="baichuan",
+    prefix=["{{system}}"],
+    prompt=[
+        {"token": "<reserved_102>"},  # user token
+        "{{query}}",
+        {"token": "<reserved_103>"},  # assistant token
+    ],
+    system="",
+    sep=[],
+    stop_words=[],
+)
+
+
+r"""
+Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
+Used for inference of the original model.
+"""
+register_template(
+    name="baichuan_eval",
+    prefix=["{{system}}", {"token": "<reserved_102>"}],  # user token
+    prompt=["{{query}}", {"token": "<reserved_103>"}],  # assistant token
+    system="",
+    sep=[],
+    stop_words=["<reserved_102>"],  # user token
+)
+
+r"""
+Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
+          https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
+Used for training and inference of the fine-tuned models.
+"""
+register_template(
+    name="baichuan2",
+    prefix=["{{system}}"],
+    prompt=[
+        {"token": "<reserved_106>"},  # user token
+        "{{query}}",
+        {"token": "<reserved_107>"},  # assistant token
+    ],
+    system="",
+    sep=[],
+)
+
+
+r"""
+Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
+          https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
+Used for inference of the original model.
+"""
+register_template(
+    name="baichuan2_eval",
+    prefix=["{{system}}", {"token": "<reserved_106>"}],  # user token
+    prompt=["{{query}}", {"token": "<reserved_107>"}],  # assistant token
+    system="",
+    sep=[],
+    stop_words=["<reserved_106>"],  # user token
+)
+
+
+r"""
+Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
+          https://huggingface.co/HuggingFaceH4/starchat-beta
+
+"""
+register_template(
+    name="starchat",
+    prefix=[{"token": "<|system|>"}, "\n{{system}}", {"token": "<|end|>"}],
+    prompt=[
+        {"token": "<|user|>"},
+        "\n{{query}}",
+        {"token": "<|end|>"},
+        "\n",
+        {"token": "<|assistant|>"},
+    ],
+    system="",
+    sep=["\n"],
+    stop_words=["<|end|>"],
+)
+
+
+r"""
+Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
+"""
+register_template(
+    name="chatml",
+    prefix=[{"token": "<|im_start|>"}, "system\n{{system}}", {"token": "<|im_end|>"}],
+    prompt=[
+        {"token": "<|im_start|>"},
+        "user\n{{query}}",
+        {"token": "<|im_end|>"},
+        "\n",
+        {"token": "<|im_start|>"},
+        "assistant\n",
+    ],
+    system="You are a helpful assistant.",
+    sep=["\n"],
+    stop_words=["<|im_end|>"],
+)
+
+
+r"""
+Supports: https://huggingface.co/THUDM/chatglm2-6b
+"""
+register_template(
+    name="chatglm2",
+    prefix=[{"token": "[gMASK]"}, {"token": "sop"}, "{{system}}"],
+    prompt=["[Round {{idx}}]\n\n问:{{query}}\n\n答:"],
+    system="",
+    sep=["\n\n"],
+)
+
+
+r"""
+Supports: https://huggingface.co/THUDM/chatglm3-6b
+"""
+register_template(
+    name="chatglm3",
+    prefix=[
+        {"token": "[gMASK]"},
+        {"token": "sop"},
+        {"token": "<|system|>"},
+        "\n",
+        "{{system}}",
+    ],
+    prompt=[
+        {"token": "<|user|>"},
+        "\n",
+        "{{query}}",
+        {"token": "<|assistant|>"},
+        "\n",  # add an extra newline to avoid error in ChatGLM's process_response method
+    ],
+    system=(
+        "You are ChatGLM3, a large language model trained by Zhipu.AI. "
+        "Follow the user's instructions carefully. Respond using markdown."
+    ),
+    sep=[],
+    stop_words=["<|user|>", "<|observation|>"],
+)
+
+register_template(
+    name="chatglm3_raw",  # the raw template for tool tuning
+    prefix=[
+        {"token": "[gMASK]"},
+        {"token": "sop"},
+        {"token": "<|system|>"},
+        "\n",
+        "{{system}}",
+    ],
+    prompt=[{"token": "<|user|>"}, "\n", "{{query}}", {"token": "<|assistant|>"}],
+    system=(
+        "You are ChatGLM3, a large language model trained by Zhipu.AI. "
+        "Follow the user's instructions carefully. Respond using markdown."
+    ),
+    sep=[],
+    stop_words=["<|user|>", "<|observation|>"],
+)
+
+
+r"""
+Supports: https://huggingface.co/xverse/XVERSE-13B-Chat
+"""
+register_template(
+    name="xverse",
+    prefix=["{{system}}"],
+    prompt=["Human: {{query}}\n\nAssistant: "],
+    system="",
+    sep=[],
+)
+
+
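+
+# --- Hedged usage sketch (not part of the original module) -------------------
+# Illustrates how ChatModel can be driven for one-shot and streaming generation.
+# The "model_name_or_path" key and the checkpoint path below are assumptions
+# about ModelArguments and must be adapted to the local setup; "template" maps
+# to one of the names registered above.
+if __name__ == "__main__":
+    chat_model = ChatModel(
+        args={
+            "model_name_or_path": "/path/to/Baichuan2-13B-Chat",  # assumed field name
+            "template": "baichuan2",
+        }
+    )
+
+    # one-shot generation: returns the decoded response and (prompt, response) lengths
+    response, (prompt_len, response_len) = chat_model.chat(
+        "Show the names of all singers older than 30."
+    )
+    print(response, prompt_len, response_len)
+
+    # streaming generation: text pieces are yielded as they are produced
+    for piece in chat_model.stream_chat("List all tables in the database."):
+        print(piece, end="", flush=True)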
diff --git a/src/config.py b/src/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..64df5416ced8a9c8174a5560b2dd3167e43f5c53
--- /dev/null
+++ b/src/config.py
@@ -0,0 +1,225 @@
+
+import os
+
+### path config
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+# ROOT_PATH = "/root/autodl-tmp"
+# MODELS_PARENT_PATH = "/home/model_files/codellama/"
+# DEFAULT_FT_MODEL_NAME = "CodeLlama-7b-Instruct-hf"
+MODELS_PARENT_PATH = "/home/ambcj/BA/text2sql-sft/models"
+DEFAULT_FT_MODEL_NAME = "Baichuan2-13B-Chat"
+MODEL_PATH = os.path.join(MODELS_PARENT_PATH, DEFAULT_FT_MODEL_NAME)
+
+# MODEL_PATH = os.path.join(ROOT_PATH, "model")
+ADAPTER_PATH = os.path.join(ROOT_PATH, "text2sql-sft/adapter")
+MERGED_MODELS = os.path.join(ROOT_PATH, "text2sql-sft/merged_models")
+
+# DATA_PATH = "/root/autodl-tmp/data/spider/pre_processed_data"
+# OUT_DIR= "/root/autodl-tmp/codellama"
+
+DATA_PATH = os.path.join(ROOT_PATH, "text2sql-sft/data")
+PREDICTED_DATA_PATH = os.path.join(ROOT_PATH, "text2sql-sft/data/eval_data/dev_sql.json")
+PREDICTED_OUT_FILENAME = "pred_sql.sql"
+# OUT_DIR = os.path.join(DATA_PATH, "out_pred")
+OUT_DIR = os.path.join(ROOT_PATH, "text2sql-sft/output/")
+
+## model constants
+IGNORE_INDEX = -100
+DEFAULT_PAD_TOKEN = "[PAD]"
+DEFAULT_EOS_TOKEN = "</s>"
+DEFAULT_BOS_TOKEN = "<s>"
+DEFAULT_UNK_TOKEN = "<unk>"
+
+
+LOG_FILE_NAME = "trainer_log.jsonl"
+
+# file name used to save the value head state_dict
+VALUE_HEAD_FILE_NAME = "value_head.bin"
+
+# file name used when finetuning_args are saved to JSON in the output dir
+FINETUNING_ARGS_NAME = "finetuning_args.json"
+
+# layer-norm parameter names used in prepare_model_for_training
+LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
+EXT2TYPE = {"csv": "csv", "json": "json", "jsonl": "json", "txt": "text"}
+
+# text2sql dataset information for processing sql data
+# TODO: add BIRD, WikiSQL, ...
+SQL_DATA_INFO = [
+    {
+        "data_source": "spider",
+        "train_file": ["train_spider.json", "train_others.json"],
+        "dev_file": ["dev.json"],
+        "train_tables_file": "tables.json",
+        "dev_tables_file": "tables.json",
+        "db_id_name": "db_id",
+        "output_name": "query",
+        "is_multiple_turn": False,
+    }
+    # {
+    #     "data_source": "bird",
+    #     "train_file": ["train/train.json"],
+    #     "dev_file": ["dev/dev.json"],
+    #     "train_tables_file": "train/train_tables.json",
+    #     "dev_tables_file": "dev/dev_tables.json",
+    #     "db_id_name": "db_id",
+    #     "output_name": "SQL",
+    #     "is_multiple_turn": False,
+    # }
+    # ,
+    # {
+    #     "data_source": "chase",
+    #     "train_file": ["Chase/chase_train.json"],
+    #     "dev_file": ["Chase/chase_dev.json"],
+    #     "tables_file": "Chase/chase_tables.json",
+    #     "db_id_name": "database_id",
+    #     "is_multiple_turn": True,
+    # }
+    # ,
+    # {
+    #     "data_source": "cosql_dataset",
+    #     "train_file": ["sql_state_tracking/cosql_train.json"],
+    #     "dev_file": ["sql_state_tracking/cosql_dev.json"],
+    #     "tables_file": "tables.json",
+    #     "db_id_name": "database_id",
+    #     "is_multiple_turn": True,
+    # }
+    # ,
+    # {
+    # {
+    #     "data_source": "sparc",
+    #     "train_file": ["train.json"],
+    #     "train_tables_file": "tables.json",
+    #     "dev_tables_file": "tables.json",
+    #     "dev_file": ["dev.json"],
+    #     "db_id_name": "database_id",
+    #     "is_multiple_turn": True,
+    #     "output_name": "query",
+    # }
+]
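+
+# Illustrative helper (an assumption, not used by the original pipeline): shows
+# how a preprocessing step could resolve the raw files listed in SQL_DATA_INFO
+# relative to DATA_PATH, assuming each dataset lives in DATA_PATH/<data_source>/.
+def _iter_raw_train_files():
+    for info in SQL_DATA_INFO:
+        base_dir = os.path.join(DATA_PATH, info["data_source"])
+        tables_file = os.path.join(base_dir, info["train_tables_file"])
+        for train_file in info["train_file"]:
+            # yields (dataset name, path to the examples file, path to its tables file)
+            yield info["data_source"], os.path.join(base_dir, train_file), tables_file
+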
+CODE_REPRESENTATION_PROMPT = """\
+/* Given the following database schema: */\n{}\n\n"""
+CR_INPUT_PROMPT = """\
+/* Answer the following: {}\n*/
+"""
+INSTRUCTION_PROMPT = """\
+I want you to act as a SQL terminal in front of an example database, \
+you need only to return the sql command to me.Below is an instruction that describes a task, \
+Write a response that appropriately completes the request.\n"
+##Instruction:\n{}\n"""
+INPUT_PROMPT = "###Input:\n{}\n\n###Response:"
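+
+# Formatting sketch (illustrative; the exact call sites live in the data
+# preprocessing script, which is assumed here): INSTRUCTION_PROMPT receives the
+# serialized database schema and INPUT_PROMPT receives the natural-language
+# question, e.g.
+#   prompt = INSTRUCTION_PROMPT.format(schema_desc) + INPUT_PROMPT.format(question)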
+
+ALPACA_PROMPT = """\
+Below is an instruction that describes a task , paired with an input that provides further context . Write a response that appropriately completes the request.\n \
+\n###Instruction:\nWrite a sql to answer the question "{}"\n\n"""
+ALPACA_INPUT_PROMPT = "###Input:\n{}\n\n###Response:\n"
+
+
+INSTRUCTION_ONE_SHOT_PROMPT = """\
+I want you to act as a SQL terminal in front of an example database. \
+You need only to return the sql command to me. \
+First, I will show you few examples of an instruction followed by the correct SQL response. \
+Then, I will give you a new instruction, and you should write the SQL response that appropriately completes the request.\
+\n### Example1 Instruction:
+The database contains tables such as employee, salary, and position. \
+Table employee has columns such as employee_id, name, age, and position_id. employee_id is the primary key. \
+Table salary has columns such as employee_id, amount, and date. employee_id is the primary key. \
+Table position has columns such as position_id, title, and department. position_id is the primary key. \
+The employee_id of salary is the foreign key of employee_id of employee. \
+The position_id of employee is the foreign key of position_id of position.\
+\n### Example1 Input:\nList the names and ages of employees in the 'Engineering' department.\n\
+\n### Example1 Response:\nSELECT employee.name, employee.age FROM employee JOIN position ON employee.position_id = position.position_id WHERE position.department = 'Engineering';\
+\n###New Instruction:\n{}\n"""
+
+# EXAMPLES =[EXAMPLE1, EXAMPLE1]
+
+# EXAMPLE1 = "\n### Example1 Input:\nList the names and ages of employees in the 'Engineering' department.\n\
+# \n### Example1 Response:\nSELECT employee.name, employee.age FROM employee JOIN position ON employee.position_id = position.position_id WHERE position.department = 'Engineering';\
+# \n###New Instruction:\n{}\n"
+
+### commented-out reference constants (kept for future use) --------------------
+
+
+# METHODS = ["full", "freeze", "lora"]
+
+# STAGES = ["SFT", "Reward Modeling", "PPO", "DPO", "Pre-Training"]
+
+# DATASET_STAGE_MAP = {
+#     "SFT": "sft",
+#     "Pre-Training": "pt",
+#     "Reward Modeling": "rm",
+#     "PPO": "sft",
+#     "DPO": "rm",
+# }
+
+# SUPPORTED_MODELS = {
+#     "LLaMA-7B": "huggyllama/llama-7b",
+#     "LLaMA-13B": "huggyllama/llama-13b",
+#     "LLaMA-30B": "huggyllama/llama-30b",
+#     "LLaMA-65B": "huggyllama/llama-65b",
+#     "LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
+#     "LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
+#     "LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
+#     "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
+#     "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
+#     "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
+#     "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
+#     "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
+#     "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
+#     "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b",
+#     "BLOOM-560M": "bigscience/bloom-560m",
+#     "BLOOM-3B": "bigscience/bloom-3b",
+#     "BLOOM-7B1": "bigscience/bloom-7b1",
+#     "BLOOMZ-560M": "bigscience/bloomz-560m",
+#     "BLOOMZ-3B": "bigscience/bloomz-3b",
+#     "BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
+#     "Falcon-7B": "tiiuae/falcon-7b",
+#     "Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
+#     "Falcon-40B": "tiiuae/falcon-40b",
+#     "Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
+#     "Baichuan-7B": "baichuan-inc/Baichuan-7B",
+#     "Baichuan-13B": "baichuan-inc/Baichuan-13B-Base",
+#     "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
+#     "Baichuan2-7B": "baichuan-inc/Baichuan2-7B-Base",
+#     "Baichuan2-13B": "baichuan-inc/Baichuan2-13B-Base",
+#     "Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
+#     "Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat",
+#     "InternLM-7B": "internlm/internlm-7b",
+#     "InternLM-7B-Chat": "internlm/internlm-chat-7b",
+#     "Qwen-7B": "Qwen/Qwen-7B",
+#     "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
+#     "XVERSE-13B": "xverse/XVERSE-13B",
+#     "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b",
+#     "ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base",
+#     "ChatGLM3-6B-Chat": "THUDM/chatglm3-6b"
+# }
+
+# DEFAULT_MODULE = {
+#     "LLaMA": "q_proj,v_proj",
+#     "LLaMA2": "q_proj,v_proj",
+#     "ChineseLLaMA2": "q_proj,v_proj",
+#     "BLOOM": "query_key_value",
+#     "BLOOMZ": "query_key_value",
+#     "Falcon": "query_key_value",
+#     "Baichuan": "W_pack",
+#     "Baichuan2": "W_pack",
+#     "InternLM": "q_proj,v_proj",
+#     "Qwen": "c_attn",
+#     "XVERSE": "q_proj,v_proj",
+#     "ChatGLM2": "query_key_value",
+#     "ChatGLM3": "query_key_value",
+
+# }
+
+# DEFAULT_TEMPLATE = {
+#     "LLaMA2": "llama2",
+#     "ChineseLLaMA2": "llama2_zh",
+#     "Baichuan": "baichuan",
+#     "Baichuan2": "baichuan2",
+#     "InternLM": "intern",
+#     "Qwen": "chatml",
+#     "ChatGLM2": "chatglm2",
+#     "ChatGLM3": "chatglm3",
+
+# }
diff --git a/src/config_parser.py b/src/config_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ef8d04bdfcb1df8db922613b405a1175d28b94d
--- /dev/null
+++ b/src/config_parser.py
@@ -0,0 +1,258 @@
+import os
+import sys
+import torch
+import transformers
+import datasets
+from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
+from transformers.modeling_utils import load_sharded_checkpoint
+from transformers import HfArgumentParser, Seq2SeqTrainingArguments
+from transformers.trainer_utils import get_last_checkpoint
+from typing import Any, Dict, Optional, Tuple
+from .loggings import get_logger
+from .model_args import (
+    ModelArguments,
+    FinetuningArguments,
+    GeneratingArguments,
+)
+from .data_args import DataArguments
+
+
+logger = get_logger(__name__)
+
+
+def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
+    state_dict: Dict[str, torch.Tensor] = model.state_dict()
+    filtered_state_dict = {}
+
+    for k, v in model.named_parameters():
+        if v.requires_grad:
+            filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
+
+    return filtered_state_dict
+
+
+def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
+    weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
+    if os.path.exists(weights_file):
+        model_state_dict = torch.load(weights_file, map_location="cpu")
+        model.load_state_dict(model_state_dict, strict=False)  # skip missing keys
+    elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
+        load_sharded_checkpoint(model, checkpoint_dir, strict=False)
+    else:
+        logger.warning(
+            "Provided path ({}) does not contain pre-trained weights.".format(
+                checkpoint_dir
+            )
+        )
+        return False
+    return True
+
+
+def _parse_args(
+    parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None
+) -> Tuple[Any]:
+    if args is not None:
+        return parser.parse_dict(args)
+    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
+        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
+    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
+    else:
+        return parser.parse_args_into_dataclasses()
+
+
+def parse_train_args(
+    args: Optional[Dict[str, Any]] = None
+) -> Tuple[
+    ModelArguments,
+    DataArguments,
+    Seq2SeqTrainingArguments,
+    FinetuningArguments,
+    GeneratingArguments,
+]:
+    parser = HfArgumentParser(
+        (
+            ModelArguments,
+            DataArguments,
+            Seq2SeqTrainingArguments,
+            FinetuningArguments,
+            GeneratingArguments,
+        )
+    )
+    return _parse_args(parser, args)
+
+
+def parse_infer_args(
+    args: Optional[Dict[str, Any]] = None
+) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
+    parser = HfArgumentParser(
+        (ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments)
+    )
+    return _parse_args(parser, args)
+
+
+def get_train_args(
+    args: Optional[Dict[str, Any]] = None, data_args_init: bool = True
+) -> Tuple[
+    ModelArguments,
+    DataArguments,
+    Seq2SeqTrainingArguments,
+    FinetuningArguments,
+    GeneratingArguments,
+]:
+    (
+        model_args,
+        data_args,
+        training_args,
+        finetuning_args,
+        generating_args,
+    ) = parse_train_args(args)
+
+    # Setup logging
+    if training_args.should_log:
+        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
+        transformers.utils.logging.set_verbosity_info()
+
+    log_level = training_args.get_process_log_level()
+    datasets.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.enable_default_handler()
+    transformers.utils.logging.enable_explicit_format()
+
+    # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
+    if data_args_init:
+        data_args.init_for_training()
+
+    if training_args.max_steps == -1 and data_args.streaming:
+        raise ValueError("Please specify `max_steps` in streaming mode.")
+
+    if data_args.val_size > 1e-6 and data_args.val_size < 1 and data_args.streaming:
+        raise ValueError("Streaming mode should have an integer val size.")
+
+    if training_args.do_train and training_args.predict_with_generate:
+        raise ValueError(
+            "`predict_with_generate` cannot be set as True while training."
+        )
+
+    if (
+        training_args.do_train
+        and finetuning_args.finetuning_type == "lora"
+        and finetuning_args.lora_target is None
+    ):
+        raise ValueError("Please specify `lora_target` in LoRA training.")
+
+    if (
+        model_args.quantization_bit is not None
+        and finetuning_args.finetuning_type != "lora"
+    ):
+        raise ValueError("Quantization is only compatible with the LoRA method.")
+
+    if model_args.checkpoint_dir is not None:
+        if finetuning_args.finetuning_type != "lora":
+            if len(model_args.checkpoint_dir) != 1:
+                raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+        elif (
+            model_args.quantization_bit is not None
+            and len(model_args.checkpoint_dir) != 1
+        ):
+            raise ValueError("Quantized model only accepts a single checkpoint.")
+
+    if model_args.quantization_bit is not None and (not training_args.do_train):
+        logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
+
+    if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
+        logger.warning("We recommend enable mixed precision training.")
+
+    # postprocess data_args
+    if data_args.max_samples is not None and data_args.streaming:
+        logger.warning(
+            "`max_samples` is incompatible with `streaming`. Disabling max_samples."
+        )
+        data_args.max_samples = None
+
+    # postprocess training_args
+    if (
+        training_args.local_rank != -1
+        and training_args.ddp_find_unused_parameters is None
+        and finetuning_args.finetuning_type == "lora"
+    ):
+        logger.warning(
+            "`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training."
+        )
+        training_args_dict = training_args.to_dict()
+        training_args_dict.update(dict(ddp_find_unused_parameters=False))
+        training_args = Seq2SeqTrainingArguments(**training_args_dict)
+
+    if (
+        training_args.resume_from_checkpoint is None
+        and training_args.do_train
+        and os.path.isdir(training_args.output_dir)
+        and not training_args.overwrite_output_dir
+    ):
+        last_checkpoint = get_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+            raise ValueError(
+                "Output directory already exists and is not empty. Use `overwrite_output_dir`."
+            )
+
+        if last_checkpoint is not None:
+            training_args_dict = training_args.to_dict()
+            training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
+            training_args = Seq2SeqTrainingArguments(**training_args_dict)
+            logger.info(
+                "Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid."
+            )
+
+    # postprocess model_args
+    if training_args.bf16:
+        if not torch.cuda.is_bf16_supported():
+            raise ValueError("Current device does not support bf16 training.")
+        model_args.compute_dtype = torch.bfloat16
+    else:
+        model_args.compute_dtype = torch.float16
+
+    model_args.model_max_length = (
+        data_args.max_source_length + data_args.max_target_length
+    )
+
+    # Log on each process the small summary:
+    logger.info(
+        "Process rank: {}, device: {}, n_gpu: {}\n  distributed training: {}, compute dtype: {}".format(
+            training_args.local_rank,
+            training_args.device,
+            training_args.n_gpu,
+            bool(training_args.local_rank != -1),
+            str(model_args.compute_dtype),
+        )
+    )
+    logger.info(f"Training/evaluation parameters {training_args}")
+
+    # Set seed before initializing model.
+    transformers.set_seed(training_args.seed)
+
+    return model_args, data_args, training_args, finetuning_args, generating_args
+
+
+def get_infer_args(
+    args: Optional[Dict[str, Any]] = None
+) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
+    model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
+
+    if (
+        model_args.quantization_bit is not None
+        and finetuning_args.finetuning_type != "lora"
+    ):
+        raise ValueError("Quantization is only compatible with the LoRA method.")
+
+    if model_args.checkpoint_dir is not None:
+        if finetuning_args.finetuning_type != "lora":
+            if len(model_args.checkpoint_dir) != 1:
+                raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+        elif (
+            model_args.quantization_bit is not None
+            and len(model_args.checkpoint_dir) != 1
+        ):
+            raise ValueError("Quantized model only accepts a single checkpoint.")
+
+    return model_args, data_args, finetuning_args, generating_args
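+
+
+# --- Hedged usage sketch (not part of the original module) -------------------
+# get_infer_args accepts an explicit dict, a single .yaml/.json path passed on
+# the command line, or regular CLI flags. The "model_name_or_path" key below is
+# an assumed ModelArguments field; "template" and "finetuning_type" come from
+# DataArguments and FinetuningArguments respectively.
+if __name__ == "__main__":
+    model_args, data_args, finetuning_args, generating_args = get_infer_args(
+        {
+            "model_name_or_path": "/path/to/Baichuan2-13B-Chat",  # assumed field name
+            "template": "baichuan2",
+            "finetuning_type": "lora",
+        }
+    )
+    print(data_args.template, finetuning_args.finetuning_type)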
diff --git a/src/data_args.py b/src/data_args.py
new file mode 100644
index 0000000000000000000000000000000000000000..55fc306d25cf9c29e12b73bd52bce3045b7819ba
--- /dev/null
+++ b/src/data_args.py
@@ -0,0 +1,417 @@
+
+import os
+import json
+import tiktoken
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
+
+
+DEFAULT_PROMPT_DICT = {
+    "prompt_input": ("{instruction}\n\n{input}\n\n"),
+    "prompt_no_input": ("{instruction}\n\n"),
+}
+
+
+CR_PROMPT_DICT = {
+    "prompt_input": (
+        "/* Given the following database schema : */\n "
+        "{instruction}\n\n/* Answer the following: {input}\n*/\nSELECT"
+    ),
+    "prompt_no_input": (
+        "/* Given the following database schema : */\n "
+        "{instruction}\n\nSELECT"
+    ),
+}
+
+ALPACA_PROMPT_DICT = {
+    "prompt_input": (
+        "Below is an instruction that describes a task, paired with an input that provides further context. "
+        "Write a response that appropriately completes the request.\n\n"
+        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response: "
+    ),
+    "prompt_no_input": (
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request.\n\n"
+        "### Instruction:\n{instruction}\n\n### Response: "
+    ),
+}
+
+SQL_PROMPT_DICT = {
+    "prompt_input": (
+        "I want you to act as a SQL terminal in front of an example database, \
+         you need only to return the sql command to me.Below is an instruction that describes a task, \
+         Write a response that appropriately completes the request.\n"
+        "##Instruction:\n{instruction}\n###Input:\n{input}\n\n###Response:"
+    ),
+    "prompt_no_input": (
+        "I want you to act as a SQL terminal in front of an example database, \
+        you need only to return the sql command to me.Below is an instruction that describes a task, \
+        Write a response that appropriately completes the request.\n"
+        "####Instruction:\n{instruction}\n\###Response: "
+    ),
+}
+
+
+@dataclass
+class DatasetAttr:
+    load_from: str
+    dataset_name: Optional[str] = None
+    dataset_sha1: Optional[str] = None
+    system_prompt: Optional[str] = None
+    stage: Optional[str] = None
+
+    def __repr__(self) -> str:
+        return self.dataset_name
+
+    def __post_init__(self):
+        self.prompt = "instruction"
+        self.query = "input"
+        self.response = "output"
+        self.history = None
+
+
+@dataclass
+class DataArguments:
+    r"""
+    Arguments pertaining to what data we are going to input our model for training and evaluation.
+    """
+    template: str = field(
+        metadata={
+            "help": "Which template to use for constructing prompts in training and inference."
+        }
+    )
+    dataset: Optional[str] = field(
+        default="example_text2sql",
+        metadata={
+            "help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."
+        },
+    )
+    dataset_dir: Optional[str] = field(
+        default="data/",
+        metadata={"help": "The name of the folder containing datasets."},
+    )
+    cutoff_len: Optional[int] = field(
+        default=1024,
+        metadata={"help": "The maximum length of the model inputs after tokenization."},
+    )
+    reserved_label_len: Optional[int] = field(
+        default=1,
+        metadata={"help": "The maximum length reserved for label after tokenization."},
+    )
+    split: Optional[str] = field(
+        default="train",
+        metadata={"help": "Which dataset split to use for training and evaluation."},
+    )
+    streaming: Optional[bool] = field(
+        default=False, metadata={"help": "Enable streaming mode."}
+    )
+    buffer_size: Optional[int] = field(
+        default=16384,
+        metadata={
+            "help": "Size of the buffer to randomly sample examples from in streaming mode."
+        },
+    )
+    mix_strategy: Optional[
+        Literal["concat", "interleave_under", "interleave_over"]
+    ] = field(default="concat", metadata={"help": "Strategy to use in dataset mixing."})
+    interleave_probs: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."
+        },
+    )
+    overwrite_cache: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Overwrite the cached training and evaluation sets."},
+    )
+    preprocessing_num_workers: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of processes to use for the preprocessing."},
+    )
+    max_source_length: Optional[int] = field(
+        default=512,
+        metadata={
+            "help": "The maximum total input sequence length after tokenization."
+        },
+    )
+    max_target_length: Optional[int] = field(
+        default=512,
+        metadata={
+            "help": "The maximum total output sequence length after tokenization."
+        },
+    )
+    max_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes, truncate the number of examples for each dataset."
+        },
+    )
+    eval_num_beams: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"
+        },
+    )
+    ignore_pad_token_for_loss: Optional[bool] = field(
+        default=True,
+        metadata={
+            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
+        },
+    )
+    system_prompt: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."
+        },
+    )
+    val_size: Optional[float] = field(
+        default=0,
+        metadata={
+            "help": "Size of the development set, should be an integer or a float in range `[0,1)`."
+        },
+    )
+    predicted_input_filename: Optional[str] = field(
+        default="dbgpt_hub/data/example_text2sql_dev.json",
+        metadata={"help": "Predict input filename to do pred "},
+    )
+    predicted_out_filename: Optional[str] = field(
+        default="pred_sql.sql",
+        metadata={"help": "Filename to save predicted outcomes"},
+    )
+
+    def init_for_training(self):  # support mixing multiple datasets
+        dataset_names = [ds.strip() for ds in self.dataset.split(",")]
+        with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
+            dataset_info = json.load(f)
+
+        prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
+        prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
+        assert len(prompt_list) == len(
+            dataset_names
+        ), "Number of system prompts should be equal to datasets or 1."
+
+        if self.interleave_probs is not None:
+            self.interleave_probs = [
+                float(prob.strip()) for prob in self.interleave_probs.split(",")
+            ]
+
+        self.dataset_list: List[DatasetAttr] = []
+        for i, name in enumerate(dataset_names):
+            if name not in dataset_info:
+                raise ValueError(
+                    "Undefined dataset {} in dataset_info.json.".format(name)
+                )
+
+            if "hf_hub_url" in dataset_info[name]:
+                dataset_attr = DatasetAttr(
+                    "hf_hub",
+                    dataset_name=dataset_info[name]["hf_hub_url"],
+                    stage=dataset_info[name].get("stage", None),
+                )
+            elif "script_url" in dataset_info[name]:
+                dataset_attr = DatasetAttr(
+                    "script",
+                    dataset_name=dataset_info[name]["script_url"],
+                    stage=dataset_info[name].get("stage", None),
+                )
+            else:
+                dataset_attr = DatasetAttr(
+                    "file",
+                    dataset_name=dataset_info[name]["file_name"],
+                    dataset_sha1=dataset_info[name].get("file_sha1", None),
+                    stage=dataset_info[name].get("stage", None),
+                )
+
+            if "columns" in dataset_info[name]:
+                dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
+                dataset_attr.query = dataset_info[name]["columns"].get("query", None)
+                dataset_attr.response = dataset_info[name]["columns"].get(
+                    "response", None
+                )
+                dataset_attr.history = dataset_info[name]["columns"].get(
+                    "history", None
+                )
+
+            dataset_attr.system_prompt = prompt_list[i]
+            self.dataset_list.append(dataset_attr)
+
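+# Illustrative sketch (reconstructed from the branches above; the concrete
+# entries are assumptions) of the dataset_info.json layout that
+# init_for_training expects:
+#
+# {
+#   "example_text2sql": {
+#     "file_name": "example_text2sql_train.json",
+#     "file_sha1": "...",                 # optional
+#     "stage": "sft",                     # optional
+#     "columns": {                        # optional column remapping
+#       "prompt": "instruction",
+#       "query": "input",
+#       "response": "output",
+#       "history": null
+#     }
+#   },
+#   "some_hub_dataset": {"hf_hub_url": "org/dataset-name"},
+#   "some_script_dataset": {"script_url": "path/to/loading_script"}
+# }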
+
+@dataclass
+class Template:
+    prefix: List[Union[str, Dict[str, str]]]
+    prompt: List[Union[str, Dict[str, str]]]
+    system: str
+    sep: List[Union[str, Dict[str, str]]]
+    stop_words: List[str]
+    use_history: bool
+
+    def encode_oneturn(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        query: str,
+        resp: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        system: Optional[str] = None,
+    ) -> Tuple[List[int], List[int]]:
+        r"""
+        Returns a single pair of token ids representing prompt and response respectively.
+        """
+        system, history = self._format(query, resp, history, system)
+        encoded_pairs = self._encode(tokenizer, system, history)
+        prompt_ids = []
+        for query_ids, resp_ids in encoded_pairs[:-1]:
+            prompt_ids = prompt_ids + query_ids + resp_ids
+        prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
+        return prompt_ids, answer_ids
+
+    def encode_multiturn(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        query: str,
+        resp: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        system: Optional[str] = None,
+    ) -> List[Tuple[List[int], List[int]]]:
+        r"""
+        Returns multiple pairs of token ids representing prompts and responses respectively.
+        """
+        system, history = self._format(query, resp, history, system)
+        encoded_pairs = self._encode(tokenizer, system, history)
+        return encoded_pairs
+
+    def _format(
+        self,
+        query: str,
+        resp: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        system: Optional[str] = None,
+    ) -> Tuple[str, List[Tuple[str, str]]]:
+        r"""
+        Aligns inputs to the standard format.
+        """
+        system = system or self.system  # use system if provided
+        history = history if (history and self.use_history) else []
+        history = history + [(query, resp)]
+        return system, history
+
+    def _get_special_ids(
+        self, tokenizer: "PreTrainedTokenizer"
+    ) -> Tuple[List[int], List[int]]:
+        if tokenizer.bos_token_id is not None and getattr(
+            tokenizer, "add_bos_token", True
+        ):  # baichuan-13b has no bos token
+            bos_ids = [tokenizer.bos_token_id]
+        else:
+            bos_ids = []  # bos token is optional
+
+        if tokenizer.eos_token_id is not None:
+            eos_ids = [tokenizer.eos_token_id]
+        else:
+            raise ValueError("EOS token is required.")
+
+        return bos_ids, eos_ids
+
+    def _encode(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        system: str,
+        history: List[Tuple[str, str]],
+    ) -> List[Tuple[List[int], List[int]]]:
+        r"""
+        Encodes formatted inputs to pairs of token ids.
+        Turn 0: bos + prefix + sep + query    resp + eos
+        Turn t: sep + bos + query             resp + eos
+        """
+        bos_ids, eos_ids = self._get_special_ids(tokenizer)
+        sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
+        encoded_pairs = []
+        for turn_idx, (query, resp) in enumerate(history):
+            if turn_idx == 0:
+                prefix_ids = self._convert_inputs_to_ids(
+                    tokenizer, context=self.prefix, system=system
+                )
+                if len(prefix_ids) != 0:  # has prefix
+                    prefix_ids = bos_ids + prefix_ids + sep_ids
+                else:
+                    prefix_ids = bos_ids
+            else:
+                prefix_ids = sep_ids + bos_ids
+
+            query_ids = self._convert_inputs_to_ids(
+                tokenizer, context=self.prompt, query=query, idx=str(turn_idx)
+            )
+            resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
+            encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
+        return encoded_pairs
+
+    def _convert_inputs_to_ids(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        context: List[Union[str, Dict[str, str]]],
+        system: Optional[str] = None,
+        query: Optional[str] = None,
+        idx: Optional[str] = None,
+    ) -> List[int]:
+        r"""
+        Converts context to token ids.
+        """
+        if isinstance(
+            getattr(tokenizer, "tokenizer", None), tiktoken.Encoding
+        ):  # for tiktoken tokenizer (Qwen)
+            kwargs = dict(allowed_special="all")
+        else:
+            kwargs = dict(add_special_tokens=False)
+
+        token_ids = []
+        for elem in context:
+            if isinstance(elem, str):
+                elem = (
+                    elem.replace("{{system}}", system, 1)
+                    if system is not None
+                    else elem
+                )
+                elem = (
+                    elem.replace("{{query}}", query, 1) if query is not None else elem
+                )
+                elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
+                token_ids = token_ids + tokenizer.encode(elem, **kwargs)
+            elif isinstance(elem, dict):
+                token_ids = token_ids + [
+                    tokenizer.convert_tokens_to_ids(elem.get("token"))
+                ]
+            else:
+                raise NotImplementedError
+
+        return token_ids
+
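+# Concrete illustration (derived from _encode and the "default" template that
+# chat_model.py registers, i.e. prefix="{{system}}", sep="\n",
+# prompt="Human: {{query}}\nAssistant: "): the first turn is encoded roughly as
+#   bos + "<system prompt>" + "\n" + "Human: <query>\nAssistant: "  |  "<resp>" + eos
+# and every later turn becomes
+#   "\n" + bos + "Human: <query>\nAssistant: "                      |  "<resp>" + eos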
+
+@dataclass
+class Llama2Template(Template):
+    def _encode(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        system: str,
+        history: List[Tuple[str, str]],
+    ) -> List[Tuple[List[int], List[int]]]:
+        r"""
+        Encodes formatted inputs to pairs of token ids.
+        Turn 0: bos + prefix + query    resp + eos
+        Turn t: bos + query             resp + eos
+        """
+        bos_ids, eos_ids = self._get_special_ids(tokenizer)
+        encoded_pairs = []
+        for turn_idx, (query, resp) in enumerate(history):
+            if turn_idx == 0:  # llama2 template has no sep_ids
+                query = self.prefix[0].replace("{{system}}", system) + query
+            query_ids = self._convert_inputs_to_ids(
+                tokenizer, context=self.prompt, query=query
+            )
+            resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
+            encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
+        return encoded_pairs
+
+
+templates: Dict[str, Template] = {}
diff --git a/src/data_utils.py b/src/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0071434db72b04166daf2d747b810ce9b2223a17
--- /dev/null
+++ b/src/data_utils.py
@@ -0,0 +1,1030 @@
+import hashlib
+import os
+import numpy as np
+import pandas as pd
+import tiktoken
+from itertools import chain
+from typing import (
+    Any,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    TYPE_CHECKING,
+    Generator,
+    Literal,
+)
+from datasets import (
+    Dataset,
+    DatasetDict,
+    concatenate_datasets,
+    load_dataset,
+    interleave_datasets,
+)
+from transformers.tokenization_utils import PreTrainedTokenizer
+
+from .config import EXT2TYPE, IGNORE_INDEX
+from .data_args import (
+    DEFAULT_PROMPT_DICT,
+    ALPACA_PROMPT_DICT,
+    SQL_PROMPT_DICT,
+    Template,
+    Llama2Template,
+)
+
+if TYPE_CHECKING:
+    from .model_args import ModelArguments
+    from .data_args import DataArguments
+    from datasets import IterableDataset
+    from transformers import TrainingArguments, Seq2SeqTrainingArguments
+
+from .loggings import get_logger
+
+
+logger = get_logger(__name__)
+
+
+def extract_default_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]:
+    # Not random, use pre-defined templates
+    if example.get("input", "") != "":
+        prompt_template = DEFAULT_PROMPT_DICT["prompt_input"]
+    else:
+        prompt_template = DEFAULT_PROMPT_DICT["prompt_no_input"]
+
+    # Format prompt with example
+    formated_prompt = prompt_template.format(**example)
+
+    return {"input": formated_prompt}
+
+
+def extract_alpaca_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]:
+    if example.get("input", "") != "":
+        prompt_format = ALPACA_PROMPT_DICT["prompt_input"]
+    else:
+        prompt_format = ALPACA_PROMPT_DICT["prompt_no_input"]
+    return {"input": prompt_format.format(**example)}
+
+
+def extract_sql_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]:
+    if example.get("input", "") != "":
+        prompt_format = SQL_PROMPT_DICT["prompt_input"]
+    else:
+        prompt_format = SQL_PROMPT_DICT["prompt_no_input"]
+    return {"input": prompt_format.format(**example)}
+
+
+def infer_max_len(
+    source_len: int, target_len: int, data_args: "DataArguments"
+) -> Tuple[int, int]:
+    max_target_len = int(
+        data_args.cutoff_len * (target_len / (source_len + target_len))
+    )
+    max_target_len = max(max_target_len, data_args.reserved_label_len)
+    max_source_len = data_args.cutoff_len - max_target_len
+    return max_source_len, max_target_len
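+
+# Worked example (illustrative numbers): with cutoff_len=512, reserved_label_len=16,
+# source_len=300 and target_len=100, max_target_len = int(512 * 100 / 400) = 128
+# (already >= 16), so max_source_len = 512 - 128 = 384.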
+
+
+def local_dataset(
+    dataset_path: str, eval_dataset_size: float = 0.1
+) -> DatasetDict:
+    """
+    Reads in a dataset from a file and returns it as a split train-test dataset.
+
+    Args:
+        dataset_path (str): The name of the dataset file to read in. \
+            The format is inferred based on the file extension.
+
+    Returns:
+        A DatasetDict with a "train" split and, when the file had no predefined split,
+        a "test" split created according to `eval_dataset_size`.
+    Raises:
+        ValueError: If the specified file format is unsupported.
+
+    """
+
+    # Read in the full dataset from file based on the file format
+    if dataset_path.endswith(".json"):
+        full_dataset = load_dataset("json", data_files=dataset_path)
+    elif dataset_path.endswith(".jsonl"):
+        full_dataset = load_dataset("json", data_files=dataset_path)
+    elif dataset_path.endswith(".csv"):
+        full_dataset = Dataset.from_pandas(pd.read_csv(dataset_path))
+    elif dataset_path.endswith(".tsv"):
+        full_dataset = Dataset.from_pandas(pd.read_csv(dataset_path, delimiter="\t"))
+    else:
+        raise ValueError(f"Unsupported dataset format: {dataset_path}")
+    if "train" not in full_dataset:
+        split_dataset = full_dataset.train_test_split(test_size=eval_dataset_size)
+        return split_dataset
+    else:
+        return full_dataset
+
+
+def load_data(
+    dataset_path: str, eval_dataset_size: float = 0.1
+) -> Union[Dict[str, Dataset], None]:
+    """
+    Load a dataset based on its name.
+
+    Args:
+        dataset_path: A string representing the path to the dataset to be loaded.
+
+    Returns:
+        A dictionary containing the loaded dataset if the dataset exists.
+        None if the dataset does not exist.
+
+    Raises:
+        NotImplementedError: If the dataset name provided is not implemented yet or if
+            the dataset is not released.
+
+    Examples:
+        >>> load_data('alpaca')
+        {'train': Dataset(...), 'validation': Dataset(...), 'test': Dataset(...)}
+
+    """
+    if not os.path.exists(dataset_path):
+        # Download dataset from HuggingFace Datasets
+        print(
+            f"Lodding dataset from huggingface, please ref to https://huggingface.co/datasets/{dataset_path}"
+        )
+        dataset = load_dataset(dataset_path, cache_dir="~/.cache/huggingface/datasets")
+        return dataset
+    else:
+        # Load dataset from local file
+        try:
+            print(f"Lodding dataset from local path: {dataset_path}")
+            dataset = local_dataset(dataset_path, eval_dataset_size)
+            return dataset
+        except Exception as e:
+            raise ValueError(f"Error loading dataset from {dataset_path}") from e
+
+
+templates: Dict[str, Template] = {}
+
+
+def get_template_and_fix_tokenizer(
+    name: str, tokenizer: "PreTrainedTokenizer"
+) -> Template:
+    if tokenizer.eos_token_id is None:
+        tokenizer.eos_token = "<|endoftext|>"
+        logger.info("Add eos token: {}".format(tokenizer.eos_token))
+
+    if tokenizer.pad_token_id is None:
+        if tokenizer.unk_token_id is not None:
+            tokenizer.pad_token = tokenizer.unk_token
+        else:
+            tokenizer.pad_token = tokenizer.eos_token
+        logger.info("Add pad token: {}".format(tokenizer.pad_token))
+
+    if name is None:  # only fix the tokenizer when no template is requested
+        return None
+
+    template = templates.get(name, None)
+    assert template is not None, "Template {} does not exist.".format(name)
+
+    tokenizer.add_special_tokens(
+        dict(additional_special_tokens=template.stop_words),
+        replace_additional_special_tokens=False,
+    )
+    return template
+
+
+def register_template(
+    name: str,
+    prefix: List[Union[str, Dict[str, str]]],
+    prompt: List[Union[str, Dict[str, str]]],
+    system: str,
+    sep: List[Union[str, Dict[str, str]]],
+    stop_words: Optional[List[str]] = [],
+    use_history: Optional[bool] = True,
+) -> None:
+    template_class = Llama2Template if "llama2" in name else Template
+    templates[name] = template_class(
+        prefix=prefix,
+        prompt=prompt,
+        system=system,
+        sep=sep,
+        stop_words=stop_words,
+        use_history=use_history,
+    )
+
+
+r"""
+Supports language model inference without histories.
+"""
+register_template(
+    name="vanilla",
+    prefix=[],
+    prompt=["{{query}}"],
+    system="",
+    sep=[],
+    use_history=False,
+)
+
+r"""
+Supports Mistral-based models such as sqlcoder-7b.
+"""
+register_template(
+    name="mistral",
+    prefix=["{{system}}"],
+    prompt=["[INST] {{query}} [/INST]"],
+    system="",
+    sep=[],
+)
+
+
+r"""
+Default template.
+"""
+register_template(
+    name="default",
+    prefix=["{{system}}"],
+    prompt=["Human: {{query}}\nAssistant: "],
+    system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    sep=["\n"],
+)
+
+
+r"""
+Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
+          https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
+          https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
+"""
+register_template(
+    name="llama2",
+    prefix=["<<SYS>>\n{{system}}\n<</SYS>>\n\n"],
+    prompt=["[INST] {{query}} [/INST] "],
+    system=(
+        "You are a helpful, respectful and honest assistant. "
+        "Always answer as helpfully as possible, while being safe.  "
+        "Your answers should not include any harmful, unethical, "
+        "racist, sexist, toxic, dangerous, or illegal content. "
+        "Please ensure that your responses are socially unbiased and positive in nature.\n"
+        "If a question does not make any sense, or is not factually coherent, "
+        "explain why instead of answering something not correct. "
+        "If you don't know the answer to a question, please don't share false information."
+    ),
+    sep=[],
+)
+
+register_template(
+    name="llama3",
+    prefix=["<|start_header_id|>system<|end_header_id|>\n\n{{system}}<|eot_id|>\n"],
+    prompt=["<|start_header_id|>user<|end_header_id|>\n\n{{query}}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"],
+    system=(
+        "You are a helpful, respectful and honest assistant. "
+        "Always answer as helpfully as possible, while being safe.  "
+        "Your answers should not include any harmful, unethical, "
+        "racist, sexist, toxic, dangerous, or illegal content. "
+        "Please ensure that your responses are socially unbiased and positive in nature.\n"
+        "If a question does not make any sense, or is not factually coherent, "
+        "explain why instead of answering something not correct. "
+        "If you don't know the answer to a question, please don't share false information."
+    ),
+    sep=[],
+)
+
+r"""
+Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
+          https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
+"""
+register_template(
+    name="llama2_zh",
+    prefix=["<<SYS>>\n{{system}}\n<</SYS>>\n\n"],
+    prompt=["[INST] {{query}} [/INST] "],
+    system="You are a helpful assistant. 你是一个乐于助人的助手。",
+    sep=[],
+)
+
+
+r"""
+Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
+          https://github.com/ymcui/Chinese-LLaMA-Alpaca
+"""
+register_template(
+    name="alpaca",
+    prefix=["{{system}}"],
+    prompt=["### Instruction:\n{{query}}\n\n### Response:\n"],
+    system=(
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request."
+    ),
+    sep=["\n\n"],
+)
+
+
+r"""
+Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
+          https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
+"""
+register_template(
+    name="vicuna",
+    prefix=["{{system}}"],
+    prompt=["USER: {{query}} ASSISTANT: "],
+    system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    sep=[],
+)
+
+
+r"""
+Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
+"""
+register_template(
+    name="belle",
+    prefix=["{{system}}"],
+    prompt=["Human: {{query}}\n\nBelle: "],
+    system="",
+    sep=["\n\n"],
+)
+
+
+r"""
+Supports: https://github.com/CVI-SZU/Linly
+"""
+register_template(
+    name="linly",
+    prefix=["{{system}}"],
+    prompt=["User: {{query}}\nBot: "],
+    system="",
+    sep=["\n"],
+)
+
+
+r"""
+Supports: https://github.com/Neutralzz/BiLLa
+"""
+register_template(
+    name="billa",
+    prefix=["{{system}}"],
+    prompt=["Human: {{query}}\nAssistant: "],
+    system="",
+    sep=["\n"],
+)
+
+
+r"""
+Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
+"""
+register_template(
+    name="ziya",
+    prefix=["{{system}}"],
+    prompt=[{"token": "<human>"}, ":{{query}}\n", {"token": "<bot>"}, ":"],
+    system="",
+    sep=["\n"],
+)
+
+
+r"""
+Supports: https://huggingface.co/qhduan/aquilachat-7b
+"""
+register_template(
+    name="aquila",
+    prefix=["{{system}}"],
+    prompt=["Human: {{query}}###Assistant: "],
+    system=(
+        "A chat between a curious human and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the human's questions."
+    ),
+    sep=["###"],
+)
+
+
+r"""
+Supports: https://huggingface.co/internlm/internlm-chat-7b
+"""
+register_template(
+    name="intern",
+    prefix=["{{system}}"],
+    prompt=["<|User|>:{{query}}", {"token": "<eoh>"}, "\n<|Bot|>:"],
+    system="",
+    sep=["\n"],
+    stop_words=["</s>", "<eoa>"],  # internlm cannot replace eos token
+)
+
+
+r"""
+Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
+Used for training and inference of the fine-tuned models.
+"""
+register_template(
+    name="baichuan",
+    prefix=["{{system}}"],
+    prompt=[
+        {"token": "<reserved_102>"},  # user token
+        "{{query}}",
+        {"token": "<reserved_103>"},  # assistant token
+    ],
+    system="",
+    sep=[],
+    stop_words=[],
+)
+
+
+r"""
+Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
+Used for inference of the original model.
+"""
+register_template(
+    name="baichuan_eval",
+    prefix=["{{system}}", {"token": "<reserved_102>"}],  # user token
+    prompt=["{{query}}", {"token": "<reserved_103>"}],  # assistant token
+    system="",
+    sep=[],
+    stop_words=["<reserved_102>"],  # user token
+)
+
+r"""
+Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
+          https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
+Used for training and inference of the fine-tuned models.
+"""
+register_template(
+    name="baichuan2",
+    prefix=["{{system}}"],
+    prompt=[
+        {"token": "<reserved_106>"},  # user token
+        "{{query}}",
+        {"token": "<reserved_107>"},  # assistant token
+    ],
+    system="",
+    sep=[],
+)
+
+
+r"""
+Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
+          https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
+Used for inference of the original model.
+"""
+register_template(
+    name="baichuan2_eval",
+    prefix=["{{system}}", {"token": "<reserved_106>"}],  # user token
+    prompt=["{{query}}", {"token": "<reserved_107>"}],  # assistant token
+    system="",
+    sep=[],
+    stop_words=["<reserved_106>"],  # user token
+)
+
+
+r"""
+Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
+          https://huggingface.co/HuggingFaceH4/starchat-beta
+
+"""
+register_template(
+    name="starchat",
+    prefix=[{"token": "<|system|>"}, "\n{{system}}", {"token": "<|end|>"}],
+    prompt=[
+        {"token": "<|user|>"},
+        "\n{{query}}",
+        {"token": "<|end|>"},
+        "\n",
+        {"token": "<|assistant|>"},
+    ],
+    system="",
+    sep=["\n"],
+    stop_words=["<|end|>"],
+)
+
+
+r"""
+Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
+"""
+register_template(
+    name="chatml",
+    prefix=[{"token": "<|im_start|>"}, "system\n{{system}}", {"token": "<|im_end|>"}],
+    prompt=[
+        {"token": "<|im_start|>"},
+        "user\n{{query}}",
+        {"token": "<|im_end|>"},
+        "\n",
+        {"token": "<|im_start|>"},
+        "assistant\n",
+    ],
+    system="You are a helpful assistant.",
+    sep=["\n"],
+    stop_words=["<|im_end|>"],
+)
+
+
+r"""
+Supports: https://huggingface.co/THUDM/chatglm2-6b
+"""
+register_template(
+    name="chatglm2",
+    prefix=[{"token": "[gMASK]"}, {"token": "sop"}, "{{system}}"],
+    prompt=["[Round {{idx}}]\n\n问:{{query}}\n\n答:"],
+    system="",
+    sep=["\n\n"],
+)
+
+
+r"""
+Supports: https://huggingface.co/THUDM/chatglm3-6b
+"""
+register_template(
+    name="chatglm3",
+    prefix=[
+        {"token": "[gMASK]"},
+        {"token": "sop"},
+        {"token": "<|system|>"},
+        "\n",
+        "{{system}}",
+    ],
+    prompt=[
+        {"token": "<|user|>"},
+        "\n",
+        "{{query}}",
+        {"token": "<|assistant|>"},
+        "\n",  # add an extra newline to avoid error in ChatGLM's process_response method
+    ],
+    system=(
+        "You are ChatGLM3, a large language model trained by Zhipu.AI. "
+        "Follow the user's instructions carefully. Respond using markdown."
+    ),
+    sep=[],
+    stop_words=["<|user|>", "<|observation|>"],
+)
+
+register_template(
+    name="chatglm3_raw",  # the raw template for tool tuning
+    prefix=[
+        {"token": "[gMASK]"},
+        {"token": "sop"},
+        {"token": "<|system|>"},
+        "\n",
+        "{{system}}",
+    ],
+    prompt=[{"token": "<|user|>"}, "\n", "{{query}}", {"token": "<|assistant|>"}],
+    system=(
+        "You are ChatGLM3, a large language model trained by Zhipu.AI. "
+        "Follow the user's instructions carefully. Respond using markdown."
+    ),
+    sep=[],
+    stop_words=["<|user|>", "<|observation|>"],
+)
+
+
+r"""
+Supports: https://huggingface.co/xverse/XVERSE-13B-Chat
+"""
+register_template(
+    name="xverse",
+    prefix=["{{system}}"],
+    prompt=["Human: {{query}}\n\nAssistant: "],
+    system="",
+    sep=[],
+)
+
+
+def split_dataset(
+    dataset: Union["Dataset", "IterableDataset"],
+    data_args: "DataArguments",
+    training_args: "TrainingArguments",
+) -> Dict[str, "Dataset"]:
+    if training_args.do_train:
+        if data_args.val_size > 1e-6:  # Split the dataset
+            if data_args.streaming:
+                dataset = dataset.shuffle(
+                    buffer_size=data_args.buffer_size, seed=training_args.seed
+                )  # shuffle before splitting so both splits see shuffled data
+                val_set = dataset.take(int(data_args.val_size))
+                train_set = dataset.skip(int(data_args.val_size))
+                return {"train_dataset": train_set, "eval_dataset": val_set}
+            else:
+                val_size = (
+                    int(data_args.val_size)
+                    if data_args.val_size > 1
+                    else data_args.val_size
+                )
+                dataset = dataset.train_test_split(
+                    test_size=val_size, seed=training_args.seed
+                )
+                return {
+                    "train_dataset": dataset["train"],
+                    "eval_dataset": dataset["test"],
+                }
+        else:
+            if data_args.streaming:
+                dataset = dataset.shuffle(
+                    buffer_size=data_args.buffer_size, seed=training_args.seed
+                )
+            return {"train_dataset": dataset}
+    else:  # do_eval or do_predict
+        return {"eval_dataset": dataset}
+
+
+def preprocess_dataset(
+    dataset: Union["Dataset", "IterableDataset"],
+    tokenizer: "PreTrainedTokenizer",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    stage: Literal["pt", "sft", "rm", "ppo"],
+) -> Union["Dataset", "IterableDataset"]:
+    column_names = list(next(iter(dataset)).keys())
+    template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
+
+    def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
+        for i in range(len(examples["prompt"])):
+            query, response = examples["prompt"][i], examples["response"][i]
+            query = (
+                query + "\n" + examples["query"][i]
+                if "query" in examples and examples["query"][i]
+                else query
+            )
+            history = examples["history"][i] if "history" in examples else None
+            system = examples["system"][i] if "system" in examples else None
+            yield query, response, history, system
+
+    def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
+        # build grouped texts with format `X1 X2 X3 ...` (without <eos>)
+        if isinstance(
+            getattr(tokenizer, "tokenizer", None), tiktoken.Encoding
+        ):  # for tiktoken tokenizer (Qwen)
+            kwargs = dict(allowed_special="all")
+        else:
+            kwargs = dict(add_special_tokens=False)
+
+        tokenized_examples = tokenizer(examples["prompt"], **kwargs)
+        concatenated_examples = {
+            k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()
+        }
+        total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
+        block_size = data_args.max_source_length
+        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
+        total_length = (total_length // block_size) * block_size
+        # split by chunks of max_source_length
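+        # e.g. (illustrative) total_length=1050 with block_size=512 keeps 1024 tokens,
+        # yielding two blocks of 512; the trailing 26 tokens are dropped.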
+        result = {
+            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+            for k, t in concatenated_examples.items()
+        }
+        return result
+
+    def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
+        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
+        # for multiturn examples, we only mask the prompt part in each prompt-response pair.
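+        # e.g. (illustrative) source_ids=[s1, s2], target_ids=[t1, t2, eos] give
+        # input_ids=[s1, s2, t1, t2, eos], labels=[IGNORE_INDEX, IGNORE_INDEX, t1, t2, eos].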
+        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+        max_length = data_args.max_source_length + data_args.max_target_length
+
+        for query, response, history, system in construct_example(examples):
+            input_ids, labels = [], []
+
+            for source_ids, target_ids in template.encode_multiturn(
+                tokenizer, query, response, history, system
+            ):
+                if len(source_ids) > data_args.max_source_length:
+                    source_ids = source_ids[: data_args.max_source_length]
+                if len(target_ids) > data_args.max_target_length:
+                    target_ids = target_ids[: data_args.max_target_length]
+
+                if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
+                    break
+
+                input_ids += source_ids + target_ids
+                labels += [IGNORE_INDEX] * len(source_ids) + target_ids
+
+            model_inputs["input_ids"].append(input_ids)
+            model_inputs["attention_mask"].append([1] * len(input_ids))
+            model_inputs["labels"].append(labels)
+
+        return model_inputs
+
+    def preprocess_unsupervised_dataset(
+        examples: Dict[str, List[Any]]
+    ) -> Dict[str, Any]:
+        # build inputs with format `<bos> X` and labels with format `Y <eos>`
+        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+
+        for query, response, history, system in construct_example(examples):
+            source_ids, target_ids = template.encode_oneturn(
+                tokenizer, query, response, history, system
+            )
+
+            if len(source_ids) > data_args.max_source_length:
+                source_ids = source_ids[: data_args.max_source_length]
+            if len(target_ids) > data_args.max_target_length:
+                target_ids = target_ids[: data_args.max_target_length]
+
+            model_inputs["input_ids"].append(source_ids)
+            model_inputs["attention_mask"].append([1] * len(source_ids))
+            model_inputs["labels"].append(target_ids)
+
+        return model_inputs
+
+    def preprocess_pairwise_dataset(
+        examples: Dict[str, List[Any]]
+    ) -> Dict[str, List[List[int]]]:
+        # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>` for rm stage
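+        # assumption read from the filter below: each example's response is a list where
+        # response[0] is the chosen answer and response[1] is the rejected one.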
+        model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
+        for query, response, history, system in construct_example(examples):
+            if not (
+                isinstance(query, str)
+                and isinstance(response, list)
+                and query != ""
+                and len(response) > 1
+            ):
+                continue
+
+            prompt_ids, chosen_ids = template.encode_oneturn(
+                tokenizer, query, response[0], history, system
+            )
+            _, rejected_ids = template.encode_oneturn(
+                tokenizer, query, response[1], history, system
+            )
+
+            # append eos to both candidates so the reward model sees complete responses
+            chosen_ids += [tokenizer.eos_token_id]
+            rejected_ids += [tokenizer.eos_token_id]
+
+            source_len, target_len = len(prompt_ids), max(
+                len(chosen_ids), len(rejected_ids)
+            )
+            max_source_len, max_target_len = infer_max_len(
+                source_len, target_len, data_args
+            )
+            if source_len > max_source_len:
+                prompt_ids = prompt_ids[:max_source_len]
+            if target_len > max_target_len:
+                chosen_ids = chosen_ids[:max_target_len]
+                rejected_ids = rejected_ids[:max_target_len]
+
+            model_inputs["prompt_ids"].append(prompt_ids)
+            model_inputs["chosen_ids"].append(chosen_ids)
+            model_inputs["rejected_ids"].append(rejected_ids)
+
+        return model_inputs
+
+    def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None:
+        print("prompt_ids:\n{}".format(example["prompt_ids"]))
+        print(
+            "prompt:\n{}".format(
+                tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)
+            )
+        )
+        print("chosen_ids:\n{}".format(example["chosen_ids"]))
+        print(
+            "chosen:\n{}".format(
+                tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)
+            )
+        )
+        print("rejected_ids:\n{}".format(example["rejected_ids"]))
+        print(
+            "rejected:\n{}".format(
+                tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)
+            )
+        )
+
+    def print_supervised_dataset_example(example):
+        print("input_ids:\n{}".format(example["input_ids"]))
+        print(
+            "inputs:\n{}".format(
+                tokenizer.decode(example["input_ids"], skip_special_tokens=False)
+            )
+        )
+        print("label_ids:\n{}".format(example["labels"]))
+        print(
+            "labels:\n{}".format(
+                tokenizer.decode(
+                    [
+                        token_id if token_id != IGNORE_INDEX else tokenizer.pad_token_id
+                        for token_id in example["labels"]
+                    ],
+                    skip_special_tokens=False,
+                )
+            )
+        )
+
+    if stage == "pt":
+        pass
+    elif stage == "sft" and not training_args.predict_with_generate:
+        preprocess_function = preprocess_supervised_dataset
+        print_function = print_supervised_dataset_example
+    elif stage == "rm":
+        print(111111111111111111)
+        preprocess_function = preprocess_pairwise_dataset
+        print_function = print_pairwise_dataset_example
+    else:
+        pass
+
+    with training_args.main_process_first(desc="dataset map pre-processing"):
+        kwargs = {}
+        if not data_args.streaming:
+            kwargs = dict(
+                num_proc=data_args.preprocessing_num_workers,
+                load_from_cache_file=not data_args.overwrite_cache,
+                desc="Running tokenizer on dataset",
+            )
+
+        dataset = dataset.map(
+            preprocess_function, batched=True, remove_columns=column_names, **kwargs
+        )
+
+        print_function(next(iter(dataset)))
+        return dataset
+
+
+# Used in get_dataset.
+def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
+    if file_sha1 is None:
+        logger.warning(
+            "Checksum failed: missing SHA-1 hash value in dataset_info.json."
+        )
+        return
+
+    if len(data_files) != 1:
+        logger.warning("Checksum failed: too many files.")
+        return
+
+    with open(data_files[0], "rb") as f:
+        sha1 = hashlib.sha1(f.read()).hexdigest()
+        if sha1 != file_sha1:
+            logger.warning(
+                "Checksum failed: mismatched SHA-1 hash value at {}.".format(
+                    data_files[0]
+                )
+            )
+
+
+def get_dataset(
+    model_args: "ModelArguments", data_args: "DataArguments"
+) -> Union["Dataset", "IterableDataset"]:
+    max_samples = data_args.max_samples
+    all_datasets: List[
+        Union["Dataset", "IterableDataset"]
+    ] = []  # support multiple datasets
+
+    for dataset_attr in data_args.dataset_list:
+        logger.info("Loading dataset {}...".format(dataset_attr))
+
+        if dataset_attr.load_from == "hf_hub":
+            data_path = dataset_attr.dataset_name
+            data_files = None
+        elif dataset_attr.load_from == "script":
+            data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+            data_files = None
+        elif dataset_attr.load_from == "file":
+            data_path = None
+            data_files: List[str] = []
+
+            if os.path.isdir(
+                os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+            ):  # directory
+                for file_name in os.listdir(
+                    os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+                ):
+                    data_files.append(
+                        os.path.join(
+                            data_args.dataset_dir, dataset_attr.dataset_name, file_name
+                        )
+                    )
+                    if data_path is None:
+                        data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
+                    else:
+                        assert data_path == EXT2TYPE.get(
+                            file_name.split(".")[-1], None
+                        ), "file type does not match."
+            elif os.path.isfile(
+                os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+            ):  # single file
+                data_files.append(
+                    os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+                )
+                data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
+            else:
+                raise ValueError("File not found.")
+
+            assert data_path, "File extension must be txt, csv, json or jsonl."
+            checksum(data_files, dataset_attr.dataset_sha1)
+        else:
+            raise NotImplementedError
+
+        dataset = load_dataset(
+            data_path,
+            data_files=data_files,
+            split=data_args.split,
+            cache_dir=model_args.cache_dir,
+            streaming=data_args.streaming,
+            use_auth_token=True if model_args.use_auth_token else None,
+        )
+
+        if max_samples is not None:
+            max_samples_temp = min(len(dataset), max_samples)
+            dataset = dataset.select(range(max_samples_temp))
+
+        for column_name in ["prompt", "query", "response", "history"]:  # align datasets
+            if (
+                getattr(dataset_attr, column_name)
+                and getattr(dataset_attr, column_name) != column_name
+            ):
+                dataset = dataset.rename_column(
+                    getattr(dataset_attr, column_name), column_name
+                )
+
+        if dataset_attr.system_prompt:  # add system prompt
+            if data_args.streaming:
+                dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt})
+            else:
+                dataset = dataset.add_column(
+                    "system", [dataset_attr.system_prompt] * len(dataset)
+                )
+
+        all_datasets.append(dataset)
+
+    if len(data_args.dataset_list) == 1:
+        return all_datasets[0]
+    elif data_args.mix_strategy == "concat":
+        if data_args.streaming:
+            logger.warning(
+                "The samples between different datasets will not be mixed in streaming mode."
+            )
+        return concatenate_datasets(all_datasets)
+    elif data_args.mix_strategy.startswith("interleave"):
+        if not data_args.streaming:
+            logger.warning(
+                "We recommend using `mix_strategy=concat` in non-streaming mode."
+            )
+        stopping_strategy = (
+            "first_exhausted"
+            if data_args.mix_strategy.endswith("under")
+            else "all_exhausted"
+        )
+        return interleave_datasets(
+            all_datasets,
+            data_args.interleave_probs,
+            stopping_strategy=stopping_strategy,
+        )
+    else:
+        raise ValueError("Unknown mixing strategy.")
+
+
+def split_train_eval(
+    dataset: DatasetDict,
+    do_eval: bool = False,
+    eval_dataset_size: float = 0.1,
+    max_eval_samples: Optional[int] = None,
+    do_train: bool = True,
+    max_train_samples: Optional[int] = None,
+) -> Tuple[Optional[Dataset], Optional[Dataset]]:
+    """
+    Prepare the training and evaluation datasets for a machine learning model.
+
+    Args:
+        dataset (DatasetDict): The complete dataset containing train, validation, and test splits.
+        do_eval (bool, optional): Whether to use an evaluation dataset or not. Defaults to False.
+        eval_dataset_size (float, optional): The size of the validation set if splitting from the training data.
+            Ignored if `do_eval` is False. Defaults to 0.1.
+        max_eval_samples (int, optional): The maximum number of samples to keep in the evaluation dataset.
+            Ignored if `do_eval` is False or `None`. Defaults to None.
+        do_train (bool, optional): Whether to use a training dataset or not. Defaults to True.
+        max_train_samples (int, optional): The maximum number of samples to keep in the training dataset.
+            Ignored if `do_train` is False or `None`. Defaults to None.
+
+    Returns:
+        Tuple[Dataset, Dataset]: The prepared training and evaluation datasets;
+        either element is None when the corresponding flag is disabled.
+    """
+    if not isinstance(dataset, DatasetDict):
+        raise TypeError("The 'dataset' argument must be a DatasetDict object.")
+
+    train_dataset, eval_dataset = None, None
+    # Prepare evaluation dataset
+    if do_eval:
+        if "eval" in dataset:
+            eval_dataset = dataset["eval"]
+        else:
+            # Split train dataset in train and validation according to `eval_dataset_size`
+            print(
+                f"Splitting the dataset into train and validation according to `eval_dataset_size`:  {eval_dataset_size}"
+            )
+            dataset = dataset["train"].train_test_split(
+                test_size=eval_dataset_size, shuffle=True, seed=42
+            )
+            eval_dataset = dataset["test"]
+
+        # Reduce evaluation dataset size (if specified)
+        if max_eval_samples is not None and len(eval_dataset) > max_eval_samples:
+            print(
+                f"You have set max_eval_samples: {max_eval_samples}, sampling the evaluation set ..."
+            )
+            eval_dataset = eval_dataset.select(np.arange(max_eval_samples))
+
+    # Prepare training dataset
+    if do_train:
+        train_dataset = dataset["train"]
+
+        # Reduce training dataset size (if specified)
+        if max_train_samples is not None and len(train_dataset) > max_train_samples:
+            print(
+                f"You have set max_train_samples: {max_train_samples}, sampling the training set ..."
+            )
+            train_dataset = train_dataset.select(np.arange(max_train_samples))
+
+    return train_dataset, eval_dataset
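+
+# Minimal usage sketch (hypothetical values):
+#   train_ds, eval_ds = split_train_eval(
+#       dataset_dict, do_eval=True, eval_dataset_size=0.1, max_eval_samples=1000
+#   )
+# returns the training split and a 10% evaluation split capped at 1000 examples.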
diff --git a/src/ds_config.json b/src/ds_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..e96d4d9b2a886a49eeb7077db31eaa9ef13b0ef0
--- /dev/null
+++ b/src/ds_config.json
@@ -0,0 +1,23 @@
+{
+    "train_micro_batch_size_per_gpu": "auto",
+    "gradient_accumulation_steps": "auto",
+    "gradient_clipping": "auto",
+    "zero_allow_untested_optimizer": true,
+    "fp16": {
+      "enabled": "auto",
+      "loss_scale": 0,
+      "initial_scale_power": 16,
+      "loss_scale_window": 1000,
+      "hysteresis": 2,
+      "min_loss_scale": 1
+    },  
+    "zero_optimization": {
+      "stage": 2,
+      "allgather_partitions": true,
+      "allgather_bucket_size": 5e8,
+      "reduce_scatter": true,
+      "reduce_bucket_size": 5e8,
+      "overlap_comm": false,
+      "contiguous_gradients": true
+    }
+  }
diff --git a/src/export.py b/src/export.py
new file mode 100644
index 0000000000000000000000000000000000000000..0daa4d12915828078153ea444723793311439f0d
--- /dev/null
+++ b/src/export.py
@@ -0,0 +1,14 @@
+import os
+import sys
+
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(ROOT_PATH)
+from .model_trainer import export_model
+
+
+def main():
+    export_model()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/load.py b/src/load.py
new file mode 100644
index 0000000000000000000000000000000000000000..8567a84254fd9ba892820d7dd736d9ac1402a0be
--- /dev/null
+++ b/src/load.py
@@ -0,0 +1,418 @@
+import os
+import torch
+import inspect
+import math
+from typing import TYPE_CHECKING, Optional, Tuple, Dict, Literal, List
+from peft import PeftModel, TaskType, LoraConfig, get_peft_model
+from peft.utils import CONFIG_NAME, WEIGHTS_NAME
+from transformers import PreTrainedModel, PreTrainedTokenizer
+from transformers.utils import check_min_version, cached_file
+from transformers.utils.versions import require_version
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    PretrainedConfig,
+    PreTrainedModel,
+    PreTrainedTokenizerBase,
+)
+from transformers.deepspeed import is_deepspeed_zero3_enabled
+from types import MethodType
+from .config import LAYERNORM_NAMES, VALUE_HEAD_FILE_NAME
+from .config_parser import load_trainable_params
+
+if TYPE_CHECKING:
+    from transformers.modeling_utils import PreTrainedModel
+    from .model_args import ModelArguments, FinetuningArguments
+
+
+def prepare_model_for_training(
+    model: "PreTrainedModel",
+    finetuning_type: str,
+    output_layer_name: Optional[str] = "lm_head",
+    use_gradient_checkpointing: Optional[bool] = True,
+    layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES,
+) -> "PreTrainedModel":
+    for name, param in model.named_parameters():
+        if param.ndim == 1 and any(
+            layer_norm_name in name for layer_norm_name in layer_norm_names
+        ):
+            param.data = param.data.to(torch.float32)
+
+    if use_gradient_checkpointing:
+        if hasattr(model, "enable_input_require_grads"):
+            model.enable_input_require_grads()
+        else:
+
+            def make_inputs_require_grad(module, input, output):
+                output.requires_grad_(True)
+
+            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+        model.gradient_checkpointing_enable()
+        model.config.use_cache = False  # turn off when gradient checkpointing is enabled
+
+    if finetuning_type != "full" and hasattr(model, output_layer_name):
+        output_layer: torch.nn.Linear = getattr(model, output_layer_name)
+        input_dtype = output_layer.weight.dtype
+
+        class CastOutputToFloat(torch.nn.Sequential):
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return super().forward(x.to(input_dtype)).to(torch.float32)
+
+        setattr(model, output_layer_name, CastOutputToFloat(output_layer))
+
+    return model
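+
+# Design note (as implemented above): layer-norm weights are upcast to float32 and, for
+# parameter-efficient tuning, the lm_head output is cast back to float32, which is the
+# usual recipe for keeping fp16/bf16 fine-tuning numerically stable.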
+
+def init_adapter(
+    model: "PreTrainedModel",
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
+    is_trainable: bool,
+    is_mergeable: bool,
+) -> "PreTrainedModel":
+    r"""
+    Initializes the adapters.
+
+    Supports full-parameter, freeze, LoRA and QLoRA training.
+
+    Note that the trainable parameters must be cast to float32.
+    """
+
+    if finetuning_args.finetuning_type == "none" and is_trainable:
+        raise ValueError("You cannot use finetuning_type=none while training.")
+
+    if finetuning_args.finetuning_type == "full" and is_trainable:
+        print("Fine-tuning method: Full")
+        model = model.float()
+
+    if finetuning_args.finetuning_type == "freeze":
+        print("Fine-tuning method: Freeze")
+
+        for name, param in model.named_parameters():
+            if not any(
+                trainable_layer in name
+                for trainable_layer in finetuning_args.trainable_layers
+            ):
+                param.requires_grad_(False)
+            else:
+                param.data = param.data.to(torch.float32)
+
+        if model_args.checkpoint_dir is not None:
+            assert load_trainable_params(
+                model, model_args.checkpoint_dir[0]
+            ), "Model checkpoint is not correctly loaded."
+
+    if finetuning_args.finetuning_type == "lora":
+        print("Fine-tuning method: LoRA")
+        latest_checkpoint = None
+
+        if model_args.checkpoint_dir is not None:
+            assert os.path.exists(
+                os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)
+            ), "Provided path ({}) does not contain a LoRA weight.".format(
+                model_args.checkpoint_dir[0]
+            )
+            assert os.path.exists(
+                os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)
+            ), "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
+
+            if (is_trainable and finetuning_args.resume_lora_training) or (
+                not is_mergeable
+            ):  # continually fine-tuning
+                checkpoints_to_merge, latest_checkpoint = (
+                    model_args.checkpoint_dir[:-1],
+                    model_args.checkpoint_dir[-1],
+                )
+            else:
+                checkpoints_to_merge = model_args.checkpoint_dir
+
+            for checkpoint in checkpoints_to_merge:
+                model = PeftModel.from_pretrained(model, checkpoint)
+                model = model.merge_and_unload()
+
+            if len(checkpoints_to_merge) > 0:
+                print(
+                    "Merged {} model checkpoint(s).".format(len(checkpoints_to_merge))
+                )
+
+            if (
+                latest_checkpoint is not None
+            ):  # resume lora training or quantized inference
+                model = PeftModel.from_pretrained(
+                    model, latest_checkpoint, is_trainable=is_trainable
+                )
+
+        if (
+            is_trainable and latest_checkpoint is None
+        ):  # create new lora weights while training
+            lora_config = LoraConfig(
+                task_type=TaskType.CAUSAL_LM,
+                inference_mode=False,
+                r=finetuning_args.lora_rank,
+                lora_alpha=finetuning_args.lora_alpha,
+                lora_dropout=finetuning_args.lora_dropout,
+                target_modules=finetuning_args.lora_target,
+            )
+            model = get_peft_model(model, lora_config)
+
+    if model_args.checkpoint_dir is not None:
+        print(
+            "Loaded fine-tuned model from checkpoint(s): {}".format(
+                ",".join(model_args.checkpoint_dir)
+            )
+        )
+
+    return model
+
+def load_model_and_tokenizer(
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
+    is_trainable: Optional[bool] = False,
+    add_valuehead: Optional[bool] = False,
+) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
+    r"""
+    Loads pretrained model and tokenizer.
+
+    Supports both training and inference.
+    """
+    if (not is_trainable) and model_args.checkpoint_dir is None:
+        from .model_args import FinetuningArguments  # runtime import; the top-level one is TYPE_CHECKING only
+
+        print("Checkpoint is not found at evaluation, loading the original model.")
+        finetuning_args = FinetuningArguments(finetuning_type="none")
+
+    config_kwargs = {
+        "trust_remote_code": True,
+        "cache_dir": model_args.cache_dir,
+        "revision": model_args.model_revision,
+        "use_auth_token": True if model_args.use_auth_token else None,
+    }
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.model_name_or_path,
+        use_fast=model_args.use_fast_tokenizer,
+        split_special_tokens=model_args.split_special_tokens,
+        padding_side="right",  # training with left-padded tensors in fp16 precision may cause overflow
+        **config_kwargs
+    )
+
+    if (
+        finetuning_args.finetuning_type == "full"
+        and model_args.checkpoint_dir is not None
+    ):
+        model_to_load = model_args.checkpoint_dir[0]
+    else:
+        model_to_load = model_args.model_name_or_path
+
+    config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
+
+    if hasattr(config, "fp16") and hasattr(config, "bf16"):  # fix Qwen config
+        if model_args.compute_dtype == torch.bfloat16:
+            setattr(config, "bf16", True)
+        else:
+            setattr(config, "fp16", True)
+
+    # Fix config (for Qwen)
+    #if getattr(config, "model_type", None) == "qwen":
+    #    for dtype_name, dtype in [
+    #        ("fp16", torch.float16),
+    #        ("bf16", torch.bfloat16),
+    #        ("fp32", torch.float32),
+    #    ]:
+    #        setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
+
+    # Set RoPE scaling
+    if model_args.rope_scaling is not None:
+        if hasattr(config, "use_dynamic_ntk"):  # for Qwen models
+            if is_trainable:
+                print("Qwen model does not support RoPE scaling in training.")
+            else:
+                setattr(config, "use_dynamic_ntk", True)
+                setattr(config, "use_logn_attn", True)
+                print("Using dynamic NTK scaling.")
+
+        elif hasattr(config, "rope_scaling"):  # for LLaMA models
+            require_version(
+                "transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0"
+            )
+
+            if is_trainable:
+                if model_args.rope_scaling == "dynamic":
+                    print(
+                        "Dynamic NTK may not work well with fine-tuning. "
+                        "See: https://github.com/huggingface/transformers/pull/24653"
+                    )
+
+                current_max_length = getattr(config, "max_position_embeddings", None)
+                if (
+                    current_max_length
+                    and model_args.model_max_length > current_max_length
+                ):
+                    scaling_factor = float(
+                        math.ceil(model_args.model_max_length / current_max_length)
+                    )
+                else:
+                    print(
+                        "Input length is smaller than max length. Consider increasing the input length."
+                    )
+                    scaling_factor = 1.0
+            else:
+                scaling_factor = 2.0
+
+            setattr(
+                config,
+                "rope_scaling",
+                {"type": model_args.rope_scaling, "factor": scaling_factor},
+            )
+            print(
+                "Using {} scaling strategy and setting scaling factor to {}".format(
+                    model_args.rope_scaling, scaling_factor
+                )
+            )
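+            # e.g. (illustrative) model_max_length=8192 with max_position_embeddings=4096
+            # gives ceil(8192 / 4096) = 2.0 as the training-time scaling factor.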
+
+        else:
+            print("Current model does not support RoPE scaling.")
+
+    # Quantization configurations (using bitsandbytes library).
+    is_mergeable = True
+    if model_args.quantization_bit is not None:
+        if is_deepspeed_zero3_enabled():
+            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
+
+        if model_args.quantization_bit == 8:
+            require_version(
+                "bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0"
+            )
+            config_kwargs["load_in_8bit"] = True
+            config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
+
+        elif model_args.quantization_bit == 4:
+            require_version(
+                "bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0"
+            )
+            config_kwargs["load_in_4bit"] = True
+            config_kwargs["quantization_config"] = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=model_args.compute_dtype,
+                bnb_4bit_use_double_quant=model_args.double_quantization,
+                bnb_4bit_quant_type=model_args.quantization_type,
+            )
+
+        is_mergeable = False
+        config_kwargs["device_map"] = (
+            {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
+        )
+        print("Quantizing model to {} bit.".format(model_args.quantization_bit))
+
+    # Load and prepare pre-trained models (without valuehead).
+    model = AutoModelForCausalLM.from_pretrained(
+        model_to_load,
+        config=config,
+        torch_dtype=model_args.compute_dtype,
+        low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
+        **config_kwargs
+    )
+
+    # Disable custom generate method (for Qwen)
+    #if "GenerationMixin" not in str(model.generate.__func__):
+    #    model.generate = MethodType(PreTrainedModel.generate, model)
+
+    # Fix LM head (for ChatGLM2,ChatGLM3)
+    #if not hasattr(model, "lm_head") and hasattr(model, "transformer"):
+    #    setattr(model, "lm_head", model.transformer.output_layer)
+
+    # Register auto class to save the custom code files.
+    if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(
+        config, "auto_map", {}
+    ):
+        config.__class__.register_for_auto_class()
+    if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(
+        config, "auto_map", {}
+    ):
+        model.__class__.register_for_auto_class()
+    if isinstance(
+        tokenizer, PreTrainedTokenizerBase
+    ) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
+        tokenizer.__class__.register_for_auto_class()
+
+    # Initialize adapters
+    model = (
+        prepare_model_for_training(model, finetuning_args.finetuning_type)
+        if is_trainable
+        else model
+    )
+    model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
+
+    # Prepare model with valuehead for RLHF
+    #if add_valuehead:
+    #    model: "AutoModelForCausalLMWithValueHead" = (
+    #        AutoModelForCausalLMWithValueHead.from_pretrained(model)
+    #    )
+    #    ignore_modules = [
+    #        name for name, _ in model.named_parameters() if "pretrained_model" in name
+    #    ]
+    #    setattr(model, "_keys_to_ignore_on_save", ignore_modules)
+    #    setattr(
+    #        model, "tie_weights", MethodType(lambda _: None, model)
+    #    )  # use empty method
+    #    vhead_path = (
+    #        model_args.checkpoint_dir[-1]
+    #        if model_args.checkpoint_dir is not None
+    #        else model_args.model_name_or_path
+    #    )
+    #    vhead_params = load_valuehead_params(vhead_path, model_args)
+    #    if vhead_params is not None:
+    #         model.load_state_dict(vhead_params, strict=False)
+    #         logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
+
+    # Prepare model for inference
+    if not is_trainable:
+        model.requires_grad_(False)  # fix all model params
+        infer_dtype = (
+            torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        )  # detect cuda capability
+        model = model.to(infer_dtype) if model_args.quantization_bit is None else model
+
+    #trainable_params, all_param = count_parameters(model)
+    #print(
+    #    "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
+    #        trainable_params, all_param, 100 * trainable_params / all_param
+    #    )
+    #)
+
+    return model, tokenizer
+
+def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
+    r"""
+    Dispatches a pre-trained model to GPUs with balanced memory.
+    Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
+    """
+    if getattr(model, "is_loaded_in_8bit", False) or getattr(
+        model, "is_loaded_in_4bit", False
+    ):  # do nothing
+        return model
+
+    if torch.cuda.device_count() > 1:
+        from accelerate import dispatch_model
+        from accelerate.utils import infer_auto_device_map, get_balanced_memory
+
+        if model._no_split_modules is None:
+            raise ValueError(
+                "The model class needs to implement the `_no_split_modules` attribute."
+            )
+
+        kwargs = {
+            "dtype": model.dtype,
+            "no_split_module_classes": model._no_split_modules,
+        }
+        max_memory = get_balanced_memory(model, **kwargs)
+        # Make sure tied weights are tied before creating the device map.
+        model.tie_weights()
+        device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
+        return dispatch_model(model, device_map)
+    else:
+        return model.cuda()
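+
+# Typical call site (illustrative): after load_model_and_tokenizer(), pass the model through
+# dispatch_model() so multi-GPU inference splits layers across devices with balanced memory;
+# on a single GPU the model is simply moved to CUDA.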
diff --git a/src/loggings.py b/src/loggings.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a9b4fa70e0018f10cc17c1a1943dc705c9ced90
--- /dev/null
+++ b/src/loggings.py
@@ -0,0 +1,227 @@
+import sys
+import logging
+import os
+import json
+import time
+from typing import TYPE_CHECKING
+from datetime import timedelta
+from transformers import TrainerCallback
+from transformers.trainer_utils import has_length
+from .config import LOG_FILE_NAME
+
+if TYPE_CHECKING:
+    from transformers import TrainingArguments, TrainerState, TrainerControl
+
+
+def reset_logging():
+    r"""
+    Removes basic config of root logger
+    """
+    root = logging.getLogger()
+    list(map(root.removeHandler, root.handlers))
+    list(map(root.removeFilter, root.filters))
+
+
+def get_logger(name: str) -> logging.Logger:
+    formatter = logging.Formatter(
+        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+    )
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setFormatter(formatter)
+
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.INFO)
+    logger.addHandler(handler)
+
+    return logger
+
+
+logger = get_logger(__name__)
+
+
+class LoggerHandler(logging.Handler):
+    def __init__(self):
+        super().__init__()
+        self.log = ""
+
+    def reset(self):
+        self.log = ""
+
+    def emit(self, record):
+        if record.name == "httpx":
+            return
+        log_entry = self.format(record)
+        self.log += log_entry
+        self.log += "\n\n"
+
+
+class LogCallback(TrainerCallback):
+    def __init__(self, runner=None):
+        self.runner = runner
+        self.in_training = False
+        self.start_time = time.time()
+        self.cur_steps = 0
+        self.max_steps = 0
+        self.elapsed_time = ""
+        self.remaining_time = ""
+
+    def timing(self):
+        cur_time = time.time()
+        elapsed_time = cur_time - self.start_time
+        avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
+        remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
+        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
+        self.remaining_time = str(timedelta(seconds=int(remaining_time)))
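+        # e.g. (illustrative) 250 of 1000 steps done in 500s -> 2s per step on average,
+        # so remaining_time becomes str(timedelta(seconds=1500)) == "0:25:00".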
+
+    def on_train_begin(
+        self,
+        args: "TrainingArguments",
+        state: "TrainerState",
+        control: "TrainerControl",
+        **kwargs
+    ):
+        r"""
+        Event called at the beginning of training.
+        """
+        if state.is_local_process_zero:
+            self.in_training = True
+            self.start_time = time.time()
+            self.max_steps = state.max_steps
+            if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)):
+                logger.warning("Previous log file in this folder will be deleted.")
+                os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
+
+    def on_train_end(
+        self,
+        args: "TrainingArguments",
+        state: "TrainerState",
+        control: "TrainerControl",
+        **kwargs
+    ):
+        r"""
+        Event called at the end of training.
+        """
+        if state.is_local_process_zero:
+            self.in_training = False
+            self.cur_steps = 0
+            self.max_steps = 0
+
+    def on_substep_end(
+        self,
+        args: "TrainingArguments",
+        state: "TrainerState",
+        control: "TrainerControl",
+        **kwargs
+    ):
+        r"""
+        Event called at the end of a substep during gradient accumulation.
+        """
+        if (
+            state.is_local_process_zero
+            and self.runner is not None
+            and self.runner.aborted
+        ):
+            control.should_epoch_stop = True
+            control.should_training_stop = True
+
+    def on_step_end(
+        self,
+        args: "TrainingArguments",
+        state: "TrainerState",
+        control: "TrainerControl",
+        **kwargs
+    ):
+        r"""
+        Event called at the end of a training step.
+        """
+        if state.is_local_process_zero:
+            self.cur_steps = state.global_step
+            self.timing()
+            if self.runner is not None and self.runner.aborted:
+                control.should_epoch_stop = True
+                control.should_training_stop = True
+
+    def on_evaluate(
+        self,
+        args: "TrainingArguments",
+        state: "TrainerState",
+        control: "TrainerControl",
+        **kwargs
+    ):
+        r"""
+        Event called after an evaluation phase.
+        """
+        if state.is_local_process_zero and not self.in_training:
+            self.cur_steps = 0
+            self.max_steps = 0
+
+    def on_predict(
+        self,
+        args: "TrainingArguments",
+        state: "TrainerState",
+        control: "TrainerControl",
+        *other,
+        **kwargs
+    ):
+        r"""
+        Event called after a successful prediction.
+        """
+        if state.is_local_process_zero and not self.in_training:
+            self.cur_steps = 0
+            self.max_steps = 0
+
+    def on_log(
+        self,
+        args: "TrainingArguments",
+        state: "TrainerState",
+        control: "TrainerControl",
+        **kwargs
+    ) -> None:
+        r"""
+        Event called after logging the last logs.
+        """
+        if not state.is_local_process_zero:
+            return
+
+        logs = dict(
+            current_steps=self.cur_steps,
+            total_steps=self.max_steps,
+            loss=state.log_history[-1].get("loss", None),
+            eval_loss=state.log_history[-1].get("eval_loss", None),
+            predict_loss=state.log_history[-1].get("predict_loss", None),
+            reward=state.log_history[-1].get("reward", None),
+            learning_rate=state.log_history[-1].get("learning_rate", None),
+            epoch=state.log_history[-1].get("epoch", None),
+            percentage=round(self.cur_steps / self.max_steps * 100, 2)
+            if self.max_steps != 0
+            else 100,
+            elapsed_time=self.elapsed_time,
+            remaining_time=self.remaining_time,
+        )
+        os.makedirs(args.output_dir, exist_ok=True)
+        with open(
+            os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8"
+        ) as f:
+            f.write(json.dumps(logs) + "\n")
+
+    def on_prediction_step(
+        self,
+        args: "TrainingArguments",
+        state: "TrainerState",
+        control: "TrainerControl",
+        **kwargs
+    ):
+        r"""
+        Event called after a prediction step.
+        """
+        eval_dataloader = kwargs.pop("eval_dataloader", None)
+        if (
+            state.is_local_process_zero
+            and has_length(eval_dataloader)
+            and not self.in_training
+        ):
+            if self.max_steps == 0:
+                self.max_steps = len(eval_dataloader)
+            self.cur_steps += 1
+            self.timing()
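+
+
+# Usage sketch (illustrative): pass the callback to a Trainer, e.g.
+#   trainer = Seq2SeqTrainer(..., callbacks=[LogCallback()])
+# Each `on_log` event then appends one JSON line to <output_dir>/trainer_log.jsonl.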
diff --git a/src/model_args.py b/src/model_args.py
new file mode 100644
index 0000000000000000000000000000000000000000..61ad2667a096b185f1c4761725a794409857821b
--- /dev/null
+++ b/src/model_args.py
@@ -0,0 +1,413 @@
+import json
+import torch
+from dataclasses import dataclass, field, asdict
+from typing import Optional, Any, Dict, Literal
+from transformers import Seq2SeqTrainingArguments
+from .config import (
+    MODEL_PATH,
+    ADAPTER_PATH,
+)
+
+
+@dataclass
+class ModelArguments:
+    r"""
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
+    """
+    model_name_or_path: str = field(
+        metadata={
+            "help": "Path to pretrained model or model identifier from huggingface.co/models."
+        }
+    )
+    cache_dir: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Where to store the pretrained models downloaded from huggingface.co."
+        },
+    )
+    use_fast_tokenizer: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
+        },
+    )
+    use_auth_token: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Will use the token generated when running `huggingface-cli login`."
+        },
+    )
+    model_revision: Optional[str] = field(
+        default="main",
+        metadata={
+            "help": "The specific model version to use (can be a branch name, tag name or commit id)."
+        },
+    )
+    padding_side: Optional[Literal["left", "right"]] = field(
+        default="left",
+        metadata={"help": "The side on which the model should have padding applied."},
+    )
+    quantization_bit: Optional[int] = field(
+        default=None, metadata={"help": "The number of bits to quantize the model."}
+    )
+    quantization_type: Optional[Literal["fp4", "nf4"]] = field(
+        default="nf4",
+        metadata={"help": "Quantization data type to use in int4 training."},
+    )
+    double_quantization: Optional[bool] = field(
+        default=True,
+        metadata={
+            "help": "Whether to use double quantization in int4 training or not."
+        },
+    )
+    rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
+        default=None, metadata={"help": "Adopt scaled rotary positional embeddings."}
+    )
+    checkpoint_dir: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."
+        },
+    )
+    # reward_model: Optional[str] = field(
+    #     default=None,
+    #     metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
+    # )
+    plot_loss: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether to plot the training loss after fine-tuning or not."
+        },
+    )
+    hf_auth_token: Optional[str] = field(
+        default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."}
+    )
+    compute_dtype: Optional[torch.dtype] = field(
+        default=None,
+        metadata={
+            "help": "Used in quantization configs. Do not specify this argument manually."
+        },
+    )
+    model_max_length: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Used in rope scaling. Do not specify this argument manually."
+        },
+    )
+    hf_hub_token: Optional[str] = field(
+        default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."}
+    )
+    split_special_tokens: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether or not the special tokens should be split during the tokenization process."
+        },
+    )
+
+    def __post_init__(self):
+        if self.compute_dtype is not None or self.model_max_length is not None:
+            raise ValueError("These arguments cannot be specified.")
+
+        if self.checkpoint_dir is not None:  # support merging multiple lora weights
+            self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
+
+        if self.quantization_bit is not None:
+            assert self.quantization_bit in [
+                4,
+                8,
+            ], "We only accept 4-bit or 8-bit quantization."
+
+        if self.use_auth_token and self.hf_auth_token is not None:
+            from huggingface_hub.hf_api import HfFolder  # lazy load
+
+            HfFolder.save_token(self.hf_auth_token)
+
+
+@dataclass
+class GeneratingArguments:
+    r"""
+    Arguments specifying the decoding parameters.
+    """
+    do_sample: Optional[bool] = field(
+        default=True,
+        metadata={
+            "help": "Whether or not to use sampling, use greedy decoding otherwise."
+        },
+    )
+    temperature: Optional[float] = field(
+        default=0.95,
+        metadata={"help": "The value used to modulate the next token probabilities."},
+    )
+    top_p: Optional[float] = field(
+        default=0.7,
+        metadata={
+            "help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
+        },
+    )
+    top_k: Optional[int] = field(
+        default=50,
+        metadata={
+            "help": "The number of highest probability vocabulary tokens to keep for top-k filtering."
+        },
+    )
+    num_beams: Optional[int] = field(
+        default=1,
+        metadata={"help": "Number of beams for beam search. 1 means no beam search."},
+    )
+    max_length: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."
+        },
+    )
+    max_new_tokens: Optional[int] = field(
+        default=512,
+        metadata={
+            "help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."
+        },
+    )
+    repetition_penalty: Optional[float] = field(
+        default=1.0,
+        metadata={
+            "help": "The parameter for repetition penalty. 1.0 means no penalty."
+        },
+    )
+    length_penalty: Optional[float] = field(
+        default=1.0,
+        metadata={
+            "help": "Exponential penalty to the length that is used with beam-based generation."
+        },
+    )
+
+    def to_dict(self) -> Dict[str, Any]:
+        args = asdict(self)
+        if args.get("max_new_tokens", None):
+            args.pop("max_length", None)
+        return args
+
+
+@dataclass
+class FinetuningArguments:
+    r"""
+    Arguments pertaining to which fine-tuning techniques we are going to use.
+    """
+    stage: Optional[Literal["sft", "rm"]] = field(
+        default="sft", metadata={"help": "Which stage will be performed in training."}
+    )
+    finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
+        default="lora", metadata={"help": "Which fine-tuning method to use."}
+    )
+    num_hidden_layers: Optional[int] = field(
+        default=32,
+        metadata={
+            "help": 'Number of decoder blocks in the model for partial-parameter (freeze) fine-tuning. \
+                  LLaMA choices: ["32", "40", "60", "80"], \
+                  LLaMA-2 choices: ["32", "40", "80"], \
+                  BLOOM choices: ["24", "30", "70"], \
+                  Falcon choices: ["32", "60"], \
+                  Baichuan choices: ["32", "40"] \
+                  Qwen choices: ["32"], \
+                  XVERSE choices: ["40"], \
+                  ChatGLM2 choices: ["28"],\
+                  ChatGLM3 choices: ["28"]'
+        },
+    )
+    num_layer_trainable: Optional[int] = field(
+        default=3,
+        metadata={
+            "help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."
+        },
+    )
+    name_module_trainable: Optional[
+        Literal["mlp", "self_attn", "self_attention"]
+    ] = field(
+        default="mlp",
+        metadata={
+            "help": 'Name of trainable modules for partial-parameter (freeze) fine-tuning. \
+                  LLaMA choices: ["mlp", "self_attn"], \
+                  BLOOM & Falcon & ChatGLM2 & ChatGLM3 choices: ["mlp", "self_attention"], \
+                  Baichuan choices: ["mlp", "self_attn"], \
+                  Qwen choices: ["mlp", "attn"], \
+                  LLaMA-2, InternLM, XVERSE choices: the same as LLaMA.'
+        },
+    )
+    lora_rank: Optional[int] = field(
+        default=8, metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
+    )
+    lora_alpha: Optional[float] = field(
+        default=32.0,
+        metadata={
+            "help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."
+        },
+    )
+    lora_dropout: Optional[float] = field(
+        default=0.1, metadata={"help": "Dropout rate for the LoRA fine-tuning."}
+    )
+    lora_target: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": 'Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
+                  LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
+                  BLOOM & Falcon & ChatGLM2  & ChatGLM3 choices: ["query_key_value", "self_attention.dense", "mlp.dense"], \
+                  Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
+                  Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
+                  LLaMA-2, InternLM, XVERSE choices: the same as LLaMA.'
+        },
+    )
+    resume_lora_training: Optional[bool] = field(
+        default=True,
+        metadata={
+            "help": "Whether to resume training from the last LoRA weights or create new weights after merging them."
+        },
+    )
+    ppo_score_norm: Optional[bool] = field(
+        default=False, metadata={"help": "Use score normalization in PPO Training."}
+    )
+    dpo_beta: Optional[float] = field(
+        default=0.1, metadata={"help": "The beta parameter for the DPO loss."}
+    )
+
+    def __post_init__(self):
+        if isinstance(
+            self.lora_target, str
+        ):  # support custom target modules/layers of LoRA
+            self.lora_target = [
+                target.strip() for target in self.lora_target.split(",")
+            ]
+
+        if (
+            self.num_layer_trainable > 0
+        ):  # fine-tuning the last n layers if num_layer_trainable > 0
+            trainable_layer_ids = [
+                self.num_hidden_layers - k - 1 for k in range(self.num_layer_trainable)
+            ]
+        else:  # fine-tuning the first n layers if num_layer_trainable < 0
+            trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
+
+        self.trainable_layers = [
+            "{:d}.{}".format(idx, self.name_module_trainable)
+            for idx in trainable_layer_ids
+        ]
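+        # Illustrative result: num_hidden_layers=32, num_layer_trainable=3 and
+        # name_module_trainable="mlp" yield ["31.mlp", "30.mlp", "29.mlp"].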
+
+        assert self.finetuning_type in [
+            "lora",
+            "freeze",
+            "full",
+            "none",
+        ], "Invalid fine-tuning method."
+
+    def save_to_json(self, json_path: str):
+        r"""Saves the content of this instance in JSON format inside `json_path`."""
+        json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
+        with open(json_path, "w", encoding="utf-8") as f:
+            f.write(json_string)
+
+    @classmethod
+    def load_from_json(cls, json_path: str):
+        r"""Creates an instance from the content of `json_path`."""
+        with open(json_path, "r", encoding="utf-8") as f:
+            text = f.read()
+        return cls(**json.loads(text))
+
+
+@dataclass
+class TrainingArguments(Seq2SeqTrainingArguments):
+    cache_dir: Optional[str] = field(default=None)
+    train_on_source: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether to train on the input in addition to the target text."
+        },
+    )
+    full_finetune: bool = field(
+        default=False, metadata={"help": "Finetune the entire model without adapters."}
+    )
+    do_train: bool = field(
+        default=True,
+        metadata={"help": "To train or not to train, that is the question?"},
+    )
+    sample_generate: bool = field(
+        default=False, metadata={"help": "If do sample generation on evaluation."}
+    )
+    optim: str = field(
+        default="paged_adamw_32bit", metadata={"help": "The optimizer to be used"}
+    )
+    max_grad_norm: float = field(
+        default=0.3,
+        metadata={
+            "help": "Gradient clipping max norm. This is tuned and works well for all models tested."
+        },
+    )
+    gradient_checkpointing: bool = field(
+        default=True,
+        metadata={"help": "Use gradient checkpointing. You want to use this."},
+    )
+    predict_with_generate: bool = field(
+        default=False,
+        metadata={
+            "help": "Group sequences into batches with same length. Saves memory and speeds up training considerably."
+        },
+    )
+    model_max_length: int = field(
+        default=2048,
+        metadata={
+            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+        },
+    )
+    output_dir: str = field(
+        default=ADAPTER_PATH,
+        metadata={"help": "The output dir for logs and checkpoints"},
+    )
+    per_device_train_batch_size: int = field(
+        default=1,
+        metadata={
+            "help": "The training batch size per GPU. Increase for better speed."
+        },
+    )
+    gradient_accumulation_steps: int = field(
+        default=16,
+        metadata={
+            "help": "How many gradients to accumulate before to perform an optimizer step"
+        },
+    )
+    max_steps: int = field(
+        default=10000, metadata={"help": "How many optimizer update steps to take"}
+    )
+    # use lora dropout instead for regularization if needed
+    weight_decay: float = field(
+        default=0.0, metadata={"help": "The L2 weight decay rate of AdamW"}
+    )
+    learning_rate: float = field(default=0.0002, metadata={"help": "The learning rate"})
+    remove_unused_columns: bool = field(
+        default=False,
+        metadata={"help": "Removed unused columns. Needed to make this codebase work."},
+    )
+    lr_scheduler_type: str = field(
+        default="constant",
+        metadata={
+            "help": "Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis"
+        },
+    )
+    warmup_ratio: float = field(
+        default=0.03, metadata={"help": "Fraction of steps to do a warmup for"}
+    )
+    logging_steps: int = field(
+        default=10,
+        metadata={"help": "The frequency of update steps after which to log the loss"},
+    )
+    group_by_length: bool = field(
+        default=True,
+        metadata={
+            "help": "Group sequences into batches with same length. Saves memory and speeds up training considerably."
+        },
+    )
+    save_strategy: str = field(
+        default="steps", metadata={"help": "When to save checkpoints"}
+    )
+    save_steps: int = field(default=250, metadata={"help": "How often to save a model"})
+    save_total_limit: int = field(
+        default=40,
+        metadata={
+            "help": "How many checkpoints to save before the oldest is overwritten"
+        },
+    )
diff --git a/src/model_trainer.py b/src/model_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae8544d948e3f7672966ec9d99f464c8c89bb9aa
--- /dev/null
+++ b/src/model_trainer.py
@@ -0,0 +1,412 @@
+import os
+import json
+import torch
+import numpy as np
+import torch.nn as nn
+import jieba
+import matplotlib.pyplot as plt
+import math
+from rouge_chinese import Rouge
+from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+from dataclasses import dataclass
+from .config import IGNORE_INDEX
+from .loggings import get_logger
+from .config_parser import (
+    get_train_args,
+    get_state_dict,
+    load_trainable_params,
+)
+from .load import load_model_and_tokenizer
+from .config import VALUE_HEAD_FILE_NAME, FINETUNING_ARGS_NAME
+from transformers import Seq2SeqTrainer
+from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
+from transformers.modeling_utils import (
+    PreTrainedModel,
+    unwrap_model,
+    load_sharded_checkpoint,
+)
+from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME, TRAINER_STATE_NAME
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import LogitsProcessorList
+
+from peft import PeftModel
+from trl import PreTrainedModelWrapper
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, Sequence
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState
+    from transformers.trainer import PredictionOutput
+    from .model_args import FinetuningArguments
+
+
+logger = get_logger(__name__)
+
+
+class PeftModelMixin:
+    r"""
+    Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead.
+    """
+
+    def __init__(self) -> None:  # for type checking
+        self.model: PreTrainedModel = None
+        self.tokenizer: "PreTrainedTokenizer" = None
+        self.args: "Seq2SeqTrainingArguments" = None
+        self.finetuning_args: "FinetuningArguments" = None
+        self.state: "TrainerState" = None
+        raise AssertionError("Mixin should not be initialized.")
+
+    def _save(
+        self,
+        output_dir: Optional[str] = None,
+        state_dict: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> None:
+        r"""
+        Saves trainable parameters as model checkpoint.
+
+        This function will only be executed at the process zero.
+
+        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
+        """
+        output_dir = output_dir if output_dir is not None else self.args.output_dir
+        os.makedirs(output_dir, exist_ok=True)
+        logger.info(f"Saving model checkpoint to {output_dir}")
+
+        model = unwrap_model(self.model)
+        if isinstance(model, PreTrainedModelWrapper):
+            # Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200
+            model_state_dict = state_dict or model.state_dict()
+            v_head_state_dict = {
+                name.replace("v_head.", ""): model_state_dict[name]
+                .cpu()
+                .clone()
+                .detach()
+                for name in model_state_dict.keys()
+                if name.startswith("v_head.")
+            }
+
+            torch.save(
+                v_head_state_dict, os.path.join(output_dir, VALUE_HEAD_FILE_NAME)
+            )
+            model = model.pretrained_model
+
+        state_dict = state_dict or get_state_dict(model)
+        if isinstance(model, (PeftModel, PreTrainedModel)):
+            model.config.use_cache = True
+            model.save_pretrained(
+                output_dir,
+                state_dict=state_dict,
+                safe_serialization=self.args.save_safetensors,
+            )
+            model.config.use_cache = False
+        else:
+            torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+
+        if (
+            self.finetuning_args.finetuning_type == "full"
+            and self.tokenizer is not None
+        ):
+            try:
+                self.tokenizer.save_pretrained(output_dir)
+            except Exception:
+                logger.warning("Cannot save tokenizer, copy the files manually.")
+
+        with open(
+            os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8"
+        ) as f:
+            f.write(self.args.to_json_string() + "\n")
+
+        self.finetuning_args.save_to_json(
+            os.path.join(output_dir, FINETUNING_ARGS_NAME)
+        )
+
+    def _load_best_model(self):
+        r"""
+        Loads trainable parameters from model checkpoint.
+
+        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
+        """
+        logger.info(
+            f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
+        )
+        model = unwrap_model(self.model)
+
+        if isinstance(model, PreTrainedModelWrapper):
+            model.v_head.load_state_dict(
+                torch.load(
+                    os.path.join(
+                        self.state.best_model_checkpoint, VALUE_HEAD_FILE_NAME
+                    ),
+                    map_location="cpu",
+                )
+            )
+            model = model.pretrained_model
+
+        if isinstance(model, PeftModel):
+            model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
+        else:  # freeze/full-tuning
+            load_trainable_params(model, self.state.best_model_checkpoint)
+
+
+class PeftTrainer(PeftModelMixin, Seq2SeqTrainer):
+    r"""
+    Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
+    """
+
+    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
+        Seq2SeqTrainer.__init__(self, **kwargs)
+        self.finetuning_args = finetuning_args
+
+
+class Seq2SeqPeftTrainer(PeftTrainer):
+    r"""
+    Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
+    """
+
+    def prediction_step(
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
+    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        r"""
+        Removes the prompt part in the generated tokens.
+
+        Subclass and override to inject custom behavior.
+        """
+        prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
+        if prompt_len > label_len:
+            inputs["labels"] = self._pad_tensors_to_target_len(
+                inputs["labels"], inputs["input_ids"]
+            )
+        if label_len > prompt_len:
+            inputs["input_ids"] = self._pad_tensors_to_target_len(
+                inputs["input_ids"], inputs["labels"]
+            )
+            if "attention_mask" in inputs:
+                inputs["attention_mask"] = self._pad_tensors_to_target_len(
+                    inputs["attention_mask"], inputs["labels"], pad_token_id=0
+                )
+            if "position_ids" in inputs:
+                inputs["position_ids"] = self._pad_tensors_to_target_len(
+                    inputs["position_ids"], inputs["labels"], pad_token_id=0
+                )
+
+        loss, generated_tokens, labels = super().prediction_step(
+            model,
+            inputs,
+            prediction_loss_only=prediction_loss_only,
+            ignore_keys=ignore_keys,
+        )
+        if generated_tokens is not None:
+            generated_tokens[
+                :, : max(prompt_len, label_len)
+            ] = self.tokenizer.pad_token_id * torch.ones_like(
+                generated_tokens[:, : max(prompt_len, label_len)]
+            )
+
+        return loss, generated_tokens, labels
+
+    def _pad_tensors_to_target_len(
+        self,
+        src_tensor: torch.Tensor,
+        tgt_tensor: torch.Tensor,
+        pad_token_id: Optional[int] = None,
+    ) -> torch.Tensor:
+        r"""
+        Pads the tensor to the same length as the target tensor.
+
+        Should only be called when predict_with_generate=True.
+        """
+        if pad_token_id is None:
+            if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
+                assert (
+                    self.tokenizer.padding_side == "left"
+                ), "This method only accepts left-padded tensor."
+                pad_token_id = self.tokenizer.pad_token_id
+            else:
+                raise ValueError("PAD token is required.")
+
+        padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
+        padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor  # adopt left-padding
+        return padded_tensor.contiguous()  # in contiguous memory
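+        # Illustrative example (hypothetical values): with pad_token_id=0,
+        # src_tensor=[[7, 8]] and tgt_tensor of shape (1, 4), the result is
+        # [[0, 0, 7, 8]], i.e. the source stays right-aligned (left padding).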
+
+    def save_predictions(self, predict_results: "PredictionOutput") -> None:
+        r"""
+        Saves model predictions to `output_dir`.
+
+        A custom behavior that is not contained in Seq2SeqTrainer.
+        """
+        if not self.is_world_process_zero():
+            return
+
+        output_prediction_file = os.path.join(
+            self.args.output_dir, "generated_predictions.jsonl"
+        )
+        logger.info(f"Saving prediction results to {output_prediction_file}")
+
+        preds = np.where(
+            predict_results.predictions != IGNORE_INDEX,
+            predict_results.predictions,
+            self.tokenizer.pad_token_id,
+        )
+        labels = np.where(
+            predict_results.label_ids != IGNORE_INDEX,
+            predict_results.label_ids,
+            self.tokenizer.pad_token_id,
+        )
+
+        decoded_preds = self.tokenizer.batch_decode(
+            preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
+        )
+        decoded_labels = self.tokenizer.batch_decode(
+            labels, skip_special_tokens=True, clean_up_tokenization_spaces=True
+        )
+
+        with open(output_prediction_file, "w", encoding="utf-8") as writer:
+            res: List[str] = []
+            for pred, label in zip(decoded_preds, decoded_labels):
+                res.append(
+                    json.dumps({"label": label, "predict": pred}, ensure_ascii=False)
+                )
+            writer.write("\n".join(res))
+
+
+@dataclass
+class ComputeMetrics:
+    r"""
+    Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
+    """
+
+    tokenizer: "PreTrainedTokenizer"
+
+    def __call__(
+        self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]
+    ) -> Dict[str, float]:
+        r"""
+        Uses the model predictions to compute metrics.
+        """
+        preds, labels = eval_preds
+        score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
+
+        preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
+        labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
+
+        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
+        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+        for pred, label in zip(decoded_preds, decoded_labels):
+            hypothesis = list(jieba.cut(pred))
+            reference = list(jieba.cut(label))
+
+            if (
+                len(" ".join(hypothesis).split()) == 0
+                or len(" ".join(reference).split()) == 0
+            ):
+                result = {
+                    "rouge-1": {"f": 0.0},
+                    "rouge-2": {"f": 0.0},
+                    "rouge-l": {"f": 0.0},
+                }
+            else:
+                rouge = Rouge()
+                scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
+                result = scores[0]
+
+            for k, v in result.items():
+                score_dict[k].append(round(v["f"] * 100, 4))
+
+            bleu_score = sentence_bleu(
+                [list(label)],
+                list(pred),
+                smoothing_function=SmoothingFunction().method3,
+            )
+            score_dict["bleu-4"].append(round(bleu_score * 100, 4))
+
+        return {k: float(np.mean(v)) for k, v in score_dict.items()}
+
+
+# Avoid runtime error in model.generate(do_sample=True).
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+    ) -> torch.FloatTensor:
+        if torch.isnan(scores).any() or torch.isinf(scores).any():
+            scores.zero_()
+            scores[..., 0] = 1.0
+        return scores
+
+
+def get_logits_processor() -> LogitsProcessorList:
+    logits_processor = LogitsProcessorList()
+    logits_processor.append(InvalidScoreLogitsProcessor())
+    return logits_processor
+
+
+# EMA smoothing used by plot_loss below
+def smooth(scalars: List[float]) -> List[float]:
+    r"""
+    EMA implementation according to TensorBoard.
+    """
+    last = scalars[0]
+    smoothed = list()
+    weight = 1.8 * (
+        1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5
+    )  # a sigmoid function
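+    # For long series the sigmoid term approaches 0.5, so the weight tends to
+    # 1.8 * 0.5 = 0.9, i.e. roughly an EMA with decay 0.9.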
+    for next_val in scalars:
+        smoothed_val = last * weight + (1 - weight) * next_val
+        smoothed.append(smoothed_val)
+        last = smoothed_val
+    return smoothed
+
+
+def plot_loss(
+    save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]
+) -> None:
+    with open(
+        os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8"
+    ) as f:
+        data = json.load(f)
+
+    for key in keys:
+        steps, metrics = [], []
+        for i in range(len(data["log_history"])):
+            if key in data["log_history"][i]:
+                steps.append(data["log_history"][i]["step"])
+                metrics.append(data["log_history"][i][key])
+
+        if len(metrics) == 0:
+            logger.warning(f"No metric {key} to plot.")
+            continue
+
+        plt.figure()
+        plt.plot(steps, metrics, alpha=0.4, label="original")
+        plt.plot(steps, smooth(metrics), label="smoothed")
+        plt.title("training {} of {}".format(key, save_dictionary))
+        plt.xlabel("step")
+        plt.ylabel(key)
+        plt.legend()
+        plt.savefig(
+            os.path.join(save_dictionary, "training_{}.png".format(key)),
+            format="png",
+            dpi=100,
+        )
+        print(
+            "Figure saved:",
+            os.path.join(save_dictionary, "training_{}.png".format(key)),
+        )
+
+
+def export_model(
+    args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"
+):
+    model_args, _, training_args, finetuning_args, _ = get_train_args(
+        args, data_args_init=False
+    )
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
+    model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size)
+    try:
+        tokenizer.save_pretrained(training_args.output_dir)
+    except Exception:
+        logger.warning("Cannot save tokenizer, please copy the files manually.")
diff --git a/src/output/logs/pred_test_20240717_1311.log b/src/output/logs/pred_test_20240717_1311.log
new file mode 100644
index 0000000000000000000000000000000000000000..c25ce3ba0d648dc5885b0dba3e1fb5ad4e340908
--- /dev/null
+++ b/src/output/logs/pred_test_20240717_1311.log
@@ -0,0 +1,8 @@
+ Pred Start time: 2024-07-17 13:11:06
+############pred end###############
+pred End time: Wed Jul 17 01:11:06 PM CEST 2024
+Time elapsed:   hour 0 min 
+ Pred Start time: 2024-07-17 13:11:48
+############pred end###############
+pred End time: Wed Jul 17 01:11:48 PM CEST 2024
+Time elapsed:   hour 0 min 
diff --git a/src/output/logs/pred_test_20240717_1312.log b/src/output/logs/pred_test_20240717_1312.log
new file mode 100644
index 0000000000000000000000000000000000000000..0cd485e8ddf5bb3721005a43ce542747130e8596
--- /dev/null
+++ b/src/output/logs/pred_test_20240717_1312.log
@@ -0,0 +1,4 @@
+ Pred Start time: 2024-07-17 13:12:11
+############pred end###############
+pred End time: Wed Jul 17 01:12:11 PM CEST 2024
+Time elapsed:   hour 0 min 
diff --git a/src/output/logs/pred_test_20240717_1313.log b/src/output/logs/pred_test_20240717_1313.log
new file mode 100644
index 0000000000000000000000000000000000000000..1e24ed1bdc1cbf084c6e0851dd41cb65e5045a80
--- /dev/null
+++ b/src/output/logs/pred_test_20240717_1313.log
@@ -0,0 +1,5 @@
+ Pred Start time: 2024-07-17 13:13:19
+[2024-07-17 13:13:27,533] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
+############pred end###############
+pred End time: Wed Jul 17 01:13:35 PM CEST 2024
+Time elapsed:   hour 0 min 
diff --git a/src/output/logs/pred_test_20240717_1315.log b/src/output/logs/pred_test_20240717_1315.log
new file mode 100644
index 0000000000000000000000000000000000000000..1be9d95b11a3d4985d2ca5a70a15eac237d4d43f
--- /dev/null
+++ b/src/output/logs/pred_test_20240717_1315.log
@@ -0,0 +1,5 @@
+ Pred Start time: 2024-07-17 13:15:49
+[2024-07-17 13:15:56,605] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
+############pred end###############
+pred End time: Wed Jul 17 01:16:03 PM CEST 2024
+Time elapsed:   hour 0 min 
diff --git a/src/output/logs/pred_test_20240717_1316.log b/src/output/logs/pred_test_20240717_1316.log
new file mode 100644
index 0000000000000000000000000000000000000000..359ca98c8d82c9f3571529bd845e218a69a1a957
--- /dev/null
+++ b/src/output/logs/pred_test_20240717_1316.log
@@ -0,0 +1,5 @@
+ Pred Start time: 2024-07-17 13:16:58
+[2024-07-17 13:17:03,895] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
+############pred end###############
+pred End time: Wed Jul 17 01:17:08 PM CEST 2024
+Time elapsed:   hour 0 min 
diff --git a/src/output/logs/pred_test_20240717_1317.log b/src/output/logs/pred_test_20240717_1317.log
new file mode 100644
index 0000000000000000000000000000000000000000..674fe338f693d9b7a34577b7a850f8cb0c549ced
--- /dev/null
+++ b/src/output/logs/pred_test_20240717_1317.log
@@ -0,0 +1,4 @@
+ Pred Start time: 2024-07-17 13:17:44
+############pred end###############
+pred End time: Wed Jul 17 01:17:44 PM CEST 2024
+Time elapsed:   hour 0 min 
diff --git a/src/predict.py b/src/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..2860262cf13a90762dc08d4483d80844d35cab11
--- /dev/null
+++ b/src/predict.py
@@ -0,0 +1,111 @@
+import os
+#from unsloth import FastLanguageModel
+import torch
+from transformers import TrainingArguments, logging
+from datasets import load_dataset
+import json
+import re
+from tqdm import tqdm
+from typing import List, Dict, Optional, Any
+from .chat_model import ChatModel
+from .data_args import SQL_PROMPT_DICT, CR_PROMPT_DICT, DEFAULT_PROMPT_DICT
+
+
+logging.set_verbosity_error()
+
+#file = "Spider/alpaca_sft_prompts/dev.jsonl"
+#dataset = load_dataset("json", data_files = {"dev" : file}, split = "dev")
+
+#model, tokenizer = FastLanguageModel.from_pretrained(
+#        model_name = "Wangzaistone123/CodeLlama-13b-sql-lora", # YOUR MODEL YOU USED FOR TRAINING
+#        max_seq_length = 2048,
+#        dtype = None,
+#        load_in_4bit = False,
+#    )
+#FastLanguageModel.for_inference(model)
+
+#def generate_text(text):
+#    inputs = tokenizer(text, return_tensors="pt").to("cuda:0")
+#    outputs = model.generate(**inputs, max_new_tokens=20)
+#from dbgpt_hub.llm_base.config_parser import load_trainable_params
+
+def prepare_dataset(
+    predict_file_path: Optional[str] = None,
+) -> List[Dict]:
+    with open(predict_file_path, "r") as fp:
+        data = json.load(fp)
+    predict_data = [extract_default_prompt_dataset(item) for item in data]
+    return predict_data
+
+def extract_sql_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]:
+    if example.get("input", "") != "":
+        prompt_format = SQL_PROMPT_DICT["prompt_input"]
+    else:
+        prompt_format = SQL_PROMPT_DICT["prompt_no_input"]
+    return {"input": prompt_format.format(**example)}
+
+def extract_cr_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]:
+    if example.get("input", "") != "":
+        prompt_format = CR_PROMPT_DICT["prompt_input"]
+    else:
+        prompt_format = CR_PROMPT_DICT["prompt_no_input"]
+    print({"input": prompt_format.format(**example)})
+    return {"input": prompt_format.format(**example)}
+
+def extract_default_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]:
+    if example.get("input", "") != "":
+        prompt_format = DEFAULT_PROMPT_DICT["prompt_input"]
+    else:
+        prompt_format = DEFAULT_PROMPT_DICT["prompt_no_input"]
+    print({"input": prompt_format.format(**example)})
+    return {"input": prompt_format.format(**example)}
+
+
+def predict(model: ChatModel, **input_kwargs):
+    args = model.data_args
+    res = []
+    predict_data = prepare_dataset(args.predicted_input_filename)
+
+    # test
+    # for item in predict_data[:20]:
+    for item in tqdm(predict_data, desc="Inference Progress", unit="item"):
+        print(f"item[input] \n{item['input']}")
+        response, _ = model.chat(query=item["input"], history=[], **input_kwargs)
+        res.append(response)
+
+    with open(args.predicted_out_filename, "w") as f:
+        for p in res:
+            try:
+                f.write(p.replace("\n", " ") + "\n")
+            except Exception:
+                f.write("Invalid Output!\n")
+
+if __name__ == "__main__":
+    model = ChatModel()
+    predict(model)
+
+
+# Extract the part after "### Response:" and remove newlines
+#    if "### Response:" in decoded_output:
+#        response = decoded_output.split("### Response:")[1].strip()
+#        response = re.sub(r'\s+', ' ', response)  # Replace multiple spaces/newlines with a single space
+#    else:
+#        response = re.sub(r'\s+', ' ', decoded_output)  # Replace multiple spaces/newlines with a single space
+#    
+#    return response
+#    
+#results = []
+
+#for example in dataset:
+#    generated_output = generate_text(example['instruction'])
+#    expected_output = example['output']
+#    results.append({
+#        "generated_output": generated_output,
+#        "expected_output": expected_output
+#    })
+
+#output_file = 'dbgpt_hub/eval_output/codellama13b_2.json'
+#with open(output_file, 'w') as f:
+#    json.dump(results, f, indent=4)
+
+#print(f"Output written to {output_file}")
diff --git a/src/sft_train.py b/src/sft_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbe0fb28fda49aa0597e9457f2009cb7a41c2a25
--- /dev/null
+++ b/src/sft_train.py
@@ -0,0 +1,165 @@
+import os
+import sys
+
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(ROOT_PATH)
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
+
+from .loggings import LogCallback, get_logger
+from .config_parser import get_train_args
+from .load import load_model_and_tokenizer
+from .data_utils import (
+    get_dataset,
+    preprocess_dataset,
+    split_dataset,
+)
+from .config import IGNORE_INDEX
+from .model_trainer import (
+    Seq2SeqPeftTrainer,
+    ComputeMetrics,
+    get_logits_processor,
+    plot_loss,
+)
+
+
+if TYPE_CHECKING:
+    from transformers import TrainerCallback
+    from .model_args import (
+        ModelArguments,
+        FinetuningArguments,
+        GeneratingArguments,
+    )
+    from .data_args import DataArguments
+
+
+logger = get_logger(__name__)
+
+
+def run_sft(
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    generating_args: "GeneratingArguments",
+    callbacks: Optional[List["TrainerCallback"]] = None,
+):
+    dataset = get_dataset(model_args, data_args)
+    model, tokenizer = load_model_and_tokenizer(
+        model_args, finetuning_args, training_args.do_train
+    )
+    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, "sft")
+    data_collator = DataCollatorForSeq2Seq(
+        tokenizer=tokenizer,
+        label_pad_token_id=IGNORE_INDEX
+        if data_args.ignore_pad_token_for_loss
+        else tokenizer.pad_token_id,
+    )
+
+    # Override the decoding parameters of Seq2SeqTrainer
+    training_args_dict = training_args.to_dict()
+    training_args_dict.update(
+        dict(
+            generation_max_length=training_args.generation_max_length
+            or data_args.max_target_length,
+            generation_num_beams=data_args.eval_num_beams
+            or training_args.generation_num_beams,
+        )
+    )
+    training_args = Seq2SeqTrainingArguments(**training_args_dict)
+
+    # Initialize our Trainer
+    trainer = Seq2SeqPeftTrainer(
+        finetuning_args=finetuning_args,
+        model=model,
+        args=training_args,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
+        callbacks=callbacks,
+        compute_metrics=ComputeMetrics(tokenizer)
+        if training_args.predict_with_generate
+        else None,
+        **split_dataset(dataset, data_args, training_args)
+    )
+
+    # Keyword arguments for `model.generate`
+    gen_kwargs = generating_args.to_dict()
+    gen_kwargs["eos_token_id"] = list(
+        set([tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids)
+    )
+    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
+    gen_kwargs["logits_processor"] = get_logits_processor()
+
+    # Training
+    if training_args.do_train:
+        train_result = trainer.train(
+            resume_from_checkpoint=training_args.resume_from_checkpoint
+        )
+        trainer.log_metrics("train", train_result.metrics)
+        trainer.save_metrics("train", train_result.metrics)
+        trainer.save_state()
+        trainer.save_model()
+        if trainer.is_world_process_zero() and model_args.plot_loss:
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+
+    # Evaluation
+    if training_args.do_eval:
+        metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
+        if (
+            training_args.predict_with_generate
+        ):  # eval_loss will be wrong if predict_with_generate is enabled
+            metrics.pop("eval_loss", None)
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
+
+    # Predict
+    if training_args.do_predict:
+        predict_results = trainer.predict(
+            dataset, metric_key_prefix="predict", **gen_kwargs
+        )
+        if (
+            training_args.predict_with_generate
+        ):  # predict_loss will be wrong if predict_with_generate is enabled
+            predict_results.metrics.pop("predict_loss", None)
+        trainer.log_metrics("predict", predict_results.metrics)
+        trainer.save_metrics("predict", predict_results.metrics)
+        trainer.save_predictions(predict_results)
+
+
+def train(
+    args: Optional[Dict[str, Any]] = None,
+    callbacks: Optional[List["TrainerCallback"]] = None,
+):
+    (
+        model_args,
+        data_args,
+        training_args,
+        finetuning_args,
+        generating_args,
+    ) = get_train_args(args)
+    callbacks = [LogCallback()] if callbacks is None else callbacks
+
+    run_sft(
+        model_args,
+        data_args,
+        training_args,
+        finetuning_args,
+        generating_args,
+        callbacks,
+    )
+
+
+def export_model(
+    args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"
+):
+    model_args, _, training_args, finetuning_args, _ = get_train_args(args)
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
+    model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size)
+    try:
+        tokenizer.save_pretrained(training_args.output_dir)
+    except:
+        logger.warning("Cannot save tokenizer, please copy the files manually.")
+
+
+if __name__ == "__main__":
+    train()
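+
+
+# Launch sketch (illustrative; assumes get_train_args in config_parser.py also parses
+# command-line flags, which is not shown in this file):
+#   python -m src.sft_train --model_name_or_path path/to/base_model --do_train true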
diff --git a/src/sql_data_process.py b/src/sql_data_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..f45eb19d275ada3e50380fa189a38848ebe310f3
--- /dev/null
+++ b/src/sql_data_process.py
@@ -0,0 +1,281 @@
+import os
+import json
+import jsonlines
+import sys
+import re
+import argparse
+
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(ROOT_PATH)
+
+from tqdm import tqdm
+
+from .config import (
+    SQL_DATA_INFO,
+    DATA_PATH,
+    INPUT_PROMPT,
+    INSTRUCTION_PROMPT,
+    INSTRUCTION_ONE_SHOT_PROMPT,
+    CODE_REPRESENTATION_PROMPT,
+    CR_INPUT_PROMPT,
+    ALPACA_PROMPT,
+    ALPACA_INPUT_PROMPT
+)
+
+
+class ProcessSqlData:
+    def __init__(
+        self, train_file=None, dev_file=None, num_shot=0, code_representation=False
+    ) -> None:
+        self.train_file = train_file
+        self.dev_file = dev_file
+        self.num_shot = num_shot
+        self.code_representation = code_representation
+
+    def decode_json_file(
+        self,
+        data_file_list,
+        table_file,
+        db_folder_path,
+        db_id_name,
+        output_name,
+        is_multiple_turn=False,
+    ):
+        """
+        TO DO:
+            1.将相关prompt放入config中
+            2.将不同数据来源的字段信息放入config中
+        """
+
+        if table_file.endswith(".jsonl"):
+            tables = jsonlines.open(table_file)
+            datas = []
+            for data_file in data_file_list:
+                datas.extend(jsonlines.open(data_file))
+
+        elif table_file.endswith(".json"):
+            tables = json.load(open(table_file))
+            datas = []
+            for data_file in data_file_list:
+                datas.extend(json.load(open(data_file)))
+        else:
+            raise ValueError(f"Unsupported table file type: {table_file}")
+
+        # First, build the table and column description for each db_id
+        db_dict = {}
+        for item in tables:
+            tables = item["table_names_original"]
+            coloumns = item["column_names_original"][1:]
+            primary_key = item["primary_keys"]
+            foreign_keys = item["foreign_keys"]
+            #source = (
+            #    item["db_id"] + " contains tables such as " + ", ".join(tables) + ". "
+            #)
+            source = ""
+            for i, name in enumerate(tables):
+                data = [coloumn[1] for coloumn in coloumns if coloumn[0] == i]
+                source += (
+                    name + "(" + ", ".join(data) + ")\n"
+                )
+
+                # get primary key info
+                for j in range(len(primary_key)):
+                    if type(primary_key[j]) == int:
+                        if coloumns[primary_key[j] - 1][0] == i:
+                            source += (
+                                coloumns[primary_key[j] - 1][1]
+                                + " is the primary key."
+                                + "\n"
+                            )
+                    # composite (multi-column) primary key
+                    elif type(primary_key[j]) == list:
+                        combine_p = "The combination of ("
+                        keys = []
+                        for k in range(len(primary_key[j])):
+                            if coloumns[primary_key[j][k] - 1][0] == i:
+                                keys.append(coloumns[primary_key[j][k] - 1][1])
+                        source += (
+                            combine_p
+                            + ", ".join(keys)
+                            + ") are the primary key."
+                            + "\n"
+                        )
+                    else:
+                        print("Unsupported primary key type:", type(primary_key[j]))
+                        continue
+
+            # get foreign key info
+            for key in foreign_keys:
+                source += (
+                    "The "
+                    + coloumns[key[0] - 1][1]
+                    + " of "
+                    + tables[coloumns[key[0] - 1][0]]
+                    + " is the foreign key of "
+                    + coloumns[key[1] - 1][1]
+                    + " of "
+                    + tables[coloumns[key[1] - 1][0]]
+                    + ".\n"
+                )
+
+            db_dict[item["db_id"]] = source
+
+        res = []
+        base_instruction = ALPACA_PROMPT
+        if self.num_shot == 1:
+            base_instruction = INSTRUCTION_ONE_SHOT_PROMPT
+
+        count = 0
+        for data in tqdm(datas):
+            if data[db_id_name] in db_dict.keys():
+                if is_multiple_turn:  # multi-turn dialogue
+                    history = []
+                    for interaction in data["interaction"]:
+                        input = {
+                            "db_id": data[db_id_name],
+                            "instruction": base_instruction.format(
+                                db_dict[data[db_id_name]]
+                            ),
+                            "input": INPUT_PROMPT.format(interaction["utterance"]),
+                            "output": interaction[output_name],
+                            "history": history,
+                        }
+                        res.append(input)
+                        history.append(
+                            (
+                                INPUT_PROMPT.format(interaction["utterance"]),
+                                interaction[output_name],
+                            )
+                        )
+                else:  # single-turn
+                    if self.code_representation:
+                        db_path = os.path.join(db_folder_path, data[db_id_name])
+                        sql_file_path = next(
+                            (
+                                file
+                                for file in os.listdir(db_path)
+                                if file.endswith(".sql")
+                            ),
+                            None,
+                        )
+                        if sql_file_path is None:
+                            print(f"Skipping {data[db_id_name]} due to missing .sql file")
+                            continue  # skip this example
+                        schema_file_path = os.path.join(db_path, sql_file_path)
+                        with open(schema_file_path, "r", errors="ignore") as file:
+                            schema_content = file.read()
+                        create_statements = re.findall(
+                            r"CREATE\s.*?;", schema_content, re.DOTALL|re.IGNORECASE
+                        )
+                        input = {
+                            "db_id": data[db_id_name],
+                            "instruction": CODE_REPRESENTATION_PROMPT.format(create_statements),
+                            "input": CR_INPUT_PROMPT.format(data["question"]),
+                            "output": data[output_name],
+                            "history": [],
+                        }
+                        res.append(input)
+                        count += 1
+                        print(f"Generated {count} inputs")
+                    else:
+                        input = {
+                            "db_id": data[db_id_name],
+                            "instruction": base_instruction.format(
+                                data["question"]
+                            ),
+                            "input": ALPACA_INPUT_PROMPT.format(db_dict[data[db_id_name]]),
+                            "output": data[output_name],
+                            "history": [],
+                        }
+                        print(db_dict[data[db_id_name]])
+                        res.append(input)
+                        #count += 1
+                        #print(f"Generated {count} inputs")
+        return res
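+    # Illustrative shape of a generated schema description (a value of db_dict);
+    # the table and column names below are hypothetical, the real text is built
+    # from the tables file parsed above:
+    #   "singer(singer_id, name, country)\nsinger_id is the primary key.\n..."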
+
+    def create_sft_raw_data(self):
+        train_data = []
+        dev_data = []
+        for data_info in SQL_DATA_INFO:
+            train_data_file_list = [
+                os.path.join(DATA_PATH, data_info["data_source"], file)
+                for file in data_info["train_file"]
+            ]
+            train_data.extend(
+                self.decode_json_file(
+                    data_file_list=train_data_file_list,
+                    table_file=os.path.join(
+                        DATA_PATH,
+                        data_info["data_source"],
+                        data_info["train_tables_file"],
+                    ),
+                    db_folder_path=os.path.join(
+                        DATA_PATH,
+                        data_info["data_source"],
+                        "database",
+                    ),
+                    db_id_name=data_info["db_id_name"],
+                    output_name=data_info["output_name"],
+                    is_multiple_turn=data_info["is_multiple_turn"],
+                )
+            )
+
+            dev_data_file_list = [
+                os.path.join(DATA_PATH, data_info["data_source"], file)
+                for file in data_info["dev_file"]
+            ]
+            dev_data.extend(
+                self.decode_json_file(
+                    data_file_list=dev_data_file_list,
+                    table_file=os.path.join(
+                        DATA_PATH,
+                        data_info["data_source"],
+                        data_info["dev_tables_file"],
+                    ),
+                    db_folder_path=os.path.join(
+                        DATA_PATH,
+                        data_info["data_source"],
+                        "database",
+                    ),
+                    db_id_name=data_info["db_id_name"],
+                    output_name=data_info["output_name"],
+                    is_multiple_turn=data_info["is_multiple_turn"],
+                )
+            )
+        with open(self.train_file, "w", encoding="utf-8") as s:
+            json.dump(train_data, s, indent=4, ensure_ascii=False)
+        with open(self.dev_file, "w", encoding="utf-8") as s:
+            json.dump(dev_data, s, indent=4, ensure_ascii=False)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--code_representation", help="Enable code representation", action="store_true"
+    )
+    args = parser.parse_args()
+
+    all_in_one_train_file = os.path.join(DATA_PATH, "example_text2sql_train_alpaca_noselect.json")
+    all_in_one_dev_file = os.path.join(DATA_PATH, "example_text2sql_dev_alpaca_noselect.json")
+    process = ProcessSqlData(
+        train_file=all_in_one_train_file,
+        dev_file=all_in_one_dev_file,
+        code_representation=args.code_representation,
+    )
+    process.create_sft_raw_data()
+
+    # one-shot
+    one_shot_all_in_one_train_file = os.path.join(
+        DATA_PATH, "example_text2sql_train_one_shot.json"
+    )
+    one_shot_all_in_one_dev_file = os.path.join(
+        DATA_PATH, "example_text2sql_dev_one_shot.json"
+    )
+    #one_shot_process = ProcessSqlData(
+    #    train_file=one_shot_all_in_one_train_file,
+    #    dev_file=one_shot_all_in_one_dev_file,
+    #    num_shot=1,
+    #    code_representation=args.code_representation,
+    #)
+    #one_shot_process.create_sft_raw_data()
diff --git a/src/tuner.py b/src/tuner.py
new file mode 100644
index 0000000000000000000000000000000000000000..634a5f8044bf3dc3d02353651f3fec6c23baadd5
--- /dev/null
+++ b/src/tuner.py
@@ -0,0 +1,72 @@
+
+import os
+from unsloth import FastLanguageModel
+import torch
+from trl import SFTTrainer
+from transformers import TrainingArguments
+from datasets import load_dataset
+
+max_seq_length = 2048
+file = "Spider/alpaca_sft_prompts/train.jsonl"
+dataset = load_dataset("json", data_files = {"train" : file}, split = "train")
+print(f"Number of examples in the dataset: {len(dataset)}")
+
+# Load model
+model, tokenizer = FastLanguageModel.from_pretrained(
+    model_name = "unsloth/codellama-7b",
+    max_seq_length = max_seq_length,
+    dtype = None,
+    load_in_4bit = True,
+)
+
+# Patch the model and attach fast LoRA adapters for training
+model = FastLanguageModel.get_peft_model(
+    model,
+    r = 64,
+    target_modules = ["q_proj", "v_proj"],
+    lora_alpha = 32,
+    lora_dropout = 0, # Supports any, but = 0 is optimized
+    bias = "none",    # Supports any, but = "none" is optimized
+    use_gradient_checkpointing = True,
+    random_state = 3407,
+    max_seq_length = max_seq_length,
+    use_rslora = False,  # Rank stabilized LoRA
+    loftq_config = None, # LoftQ
+)
+
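+# With packing=True, TRL's SFTTrainer applies formatting_func per example and expects a single prompt string back.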
+def formatting_func(example):
+    text = f"{example['instruction']}\n{example['output']}"
+    return text
+
+trainer = SFTTrainer(
+    model = model,
+    train_dataset = dataset,
+    formatting_func = formatting_func,
+    max_seq_length = max_seq_length,
+    packing=True,
+    tokenizer = tokenizer,
+    args = TrainingArguments(
+        per_device_train_batch_size = 1,
+        gradient_accumulation_steps = 16,
+        warmup_steps = 30,   # when warmup_steps > 0 it takes precedence over warmup_ratio
+        warmup_ratio = 0.03,
+        num_train_epochs = 8,
+        fp16 = not torch.cuda.is_bf16_supported(),
+        bf16 = torch.cuda.is_bf16_supported(),
+        logging_steps = 50,
+        save_steps = 2000,
+        output_dir = "overfitting/codellama7b_blog",
+        optim = "adamw_8bit",
+        weight_decay = 1,
+        lr_scheduler_type = "cosine_with_restarts",
+        learning_rate = 2e-04,
+        seed = 3407,
+    ),
+)
+trainer.train()
+
+# Save the model
+model.save_pretrained("lora_model_codellama7b_blog")
+model.save_pretrained_merged("overfitting/codellama7b_blog", tokenizer, save_method = "merged_16bit",)
+#model.push_to_hub_merged("oleherbst/llama3-8b-oig-unsloth-merged", tokenizer, save_method = "merged_16bit", token = os.environ.get("HF_TOKEN"))
+#model.push_to_hub("oleherbst/llama3-8b-oig-unsloth", tokenizer, save_method = "lora", token = os.environ.get("HF_TOKEN"))