Skip to content
Snippets Groups Projects
Commit c9a47732 authored by hiyouga's avatar hiyouga
Browse files

fix #3316

parent 6d641af7
Branches
No related tags found
No related merge requests found
import inspect
from enum import Enum, unique
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Optional
......@@ -129,7 +130,11 @@ def gradient_checkpointing_enable(
return gradient_checkpointing_func(func, *args, **kwargs)
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else:
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment