Skip to content
Snippets Groups Projects
Unverified Commit 6700a1b9 authored by hoshi-hiyouga's avatar hoshi-hiyouga Committed by GitHub
Browse files

Update trainer.py

parent 38a56706
Branches
No related tags found
No related merge requests found
import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import numpy as np
......@@ -9,8 +10,7 @@ from transformers import Seq2SeqTrainer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from ..utils import create_custom_optimzer, create_custom_scheduler
from types import MethodType
from packaging import version
if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
......@@ -31,6 +31,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self.finetuning_args = finetuning_args
if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
def create_optimizer(self) -> "torch.optim.Optimizer":
......
0% Loading or loading failed.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment