diff --git a/src/cnlpt/_cli/train.py b/src/cnlpt/_cli/train.py index 44dc1544..c0786208 100644 --- a/src/cnlpt/_cli/train.py +++ b/src/cnlpt/_cli/train.py @@ -6,7 +6,7 @@ from rich.markup import escape as escape_rich_markup from transformers.hf_argparser import HfArgumentParser from transformers.models.auto.modeling_auto import AutoModel -from transformers.trainer_utils import EvaluationStrategy, IntervalStrategy +from transformers.trainer_utils import IntervalStrategy from transformers.training_args import TrainingArguments from ..data.cnlp_dataset import CnlpDataset, HierarchicalDataConfig, TruncationSide @@ -362,9 +362,7 @@ def transformers_arg_option(field_name: str, *args, **kwargs): DoTrainArg = Annotated[bool, transformers_arg_option("do_train", "--do_train")] DoEvalArg = Annotated[bool, transformers_arg_option("do_eval", "--do_eval")] DoPredictArg = Annotated[bool, transformers_arg_option("do_predict", "--do_predict")] -EvalStrategyArg = Annotated[ - EvaluationStrategy, transformers_arg_option("eval_strategy") -] +EvalStrategyArg = Annotated[IntervalStrategy, transformers_arg_option("eval_strategy")] def train(