From de7efbcdc4f8ef6e851d0126cab252a13aafe552 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Wed, 6 May 2026 10:00:53 -0700 Subject: [PATCH] Fix imports for Model Customization interfaces --- .../src/sagemaker/train/__init__.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/sagemaker-train/src/sagemaker/train/__init__.py b/sagemaker-train/src/sagemaker/train/__init__.py index 38a6fda76d..6a6d3b31e8 100644 --- a/sagemaker-train/src/sagemaker/train/__init__.py +++ b/sagemaker-train/src/sagemaker/train/__init__.py @@ -28,6 +28,24 @@ def __getattr__(name): elif name == "ModelTrainer": from sagemaker.train.model_trainer import ModelTrainer return ModelTrainer + elif name == "SFTTrainer": + from sagemaker.train.sft_trainer import SFTTrainer + return SFTTrainer + elif name == "DPOTrainer": + from sagemaker.train.dpo_trainer import DPOTrainer + return DPOTrainer + elif name == "RLVRTrainer": + from sagemaker.train.rlvr_trainer import RLVRTrainer + return RLVRTrainer + elif name == "RLAIFTrainer": + from sagemaker.train.rlaif_trainer import RLAIFTrainer + return RLAIFTrainer + elif name == "TrainingType": + from sagemaker.train.common import TrainingType + return TrainingType + elif name == "CustomizationTechnique": + from sagemaker.train.common import CustomizationTechnique + return CustomizationTechnique elif name == "logger": from sagemaker.core.utils.utils import logger return logger