diff --git a/sagemaker-train/src/sagemaker/train/__init__.py b/sagemaker-train/src/sagemaker/train/__init__.py index 38a6fda76d..6a6d3b31e8 100644 --- a/sagemaker-train/src/sagemaker/train/__init__.py +++ b/sagemaker-train/src/sagemaker/train/__init__.py @@ -28,6 +28,24 @@ def __getattr__(name): elif name == "ModelTrainer": from sagemaker.train.model_trainer import ModelTrainer return ModelTrainer + elif name == "SFTTrainer": + from sagemaker.train.sft_trainer import SFTTrainer + return SFTTrainer + elif name == "DPOTrainer": + from sagemaker.train.dpo_trainer import DPOTrainer + return DPOTrainer + elif name == "RLVRTrainer": + from sagemaker.train.rlvr_trainer import RLVRTrainer + return RLVRTrainer + elif name == "RLAIFTrainer": + from sagemaker.train.rlaif_trainer import RLAIFTrainer + return RLAIFTrainer + elif name == "TrainingType": + from sagemaker.train.common import TrainingType + return TrainingType + elif name == "CustomizationTechnique": + from sagemaker.train.common import CustomizationTechnique + return CustomizationTechnique elif name == "logger": from sagemaker.core.utils.utils import logger return logger