diff --git a/torchrec/distributed/planner/planners.py b/torchrec/distributed/planner/planners.py
index 84a21bd12..607bbadac 100644
--- a/torchrec/distributed/planner/planners.py
+++ b/torchrec/distributed/planner/planners.py
@@ -70,6 +70,23 @@
 )
 from torchrec.distributed.utils import get_device_type, none_throws
 
+try:
+    # This is a safety measure against torch package issues for when
+    # Torchrec is included in the inference side model code. We should
+    # remove this once we are sure all model side packages have the required
+    # dependencies
+    from torchrec.distributed.logger import _torchrec_method_logger
+except Exception:
+
+    def _torchrec_method_logger(*args, **kwargs):
+        """A no-op decorator that accepts any arguments."""
+
+        def decorator(func):
+            return func
+
+        return decorator
+
+
 logger: logging.Logger = logging.getLogger(__name__)
 
 
@@ -498,6 +515,7 @@ def collective_plan(
             sharders,
         )
 
+    @_torchrec_method_logger()
     def plan(
         self,
         module: nn.Module,
diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py
index e9fe723e8..acf672047 100644
--- a/torchrec/distributed/planner/shard_estimators.py
+++ b/torchrec/distributed/planner/shard_estimators.py
@@ -53,6 +53,23 @@
 from torchrec.modules.embedding_modules import EmbeddingBagCollectionInterface
 
 
+try:
+    # This is a safety measure against torch package issues for when
+    # Torchrec is included in the inference side model code. We should
+    # remove this once we are sure all model side packages have the required
+    # dependencies
+    from torchrec.distributed.logger import _torchrec_method_logger
+except Exception:
+
+    def _torchrec_method_logger(*args, **kwargs):
+        """A no-op decorator that accepts any arguments."""
+
+        def decorator(func):
+            return func
+
+        return decorator
+
+
 logger: logging.Logger = logging.getLogger(__name__)
 
 
@@ -955,6 +972,7 @@ class EmbeddingStorageEstimator(ShardEstimator):
         is_inference (bool): If the model is inference model. Default to False.
     """
 
+    @_torchrec_method_logger()
     def __init__(
         self,
         topology: Topology,
diff --git a/torchrec/distributed/shard.py b/torchrec/distributed/shard.py
index 0a27711a7..3c0880420 100644
--- a/torchrec/distributed/shard.py
+++ b/torchrec/distributed/shard.py
@@ -30,6 +30,22 @@
 from torchrec.distributed.utils import init_parameters
 from torchrec.modules.utils import reset_module_states_post_sharding
 
+try:
+    # This is a safety measure against torch package issues for when
+    # Torchrec is included in the inference side model code. We should
+    # remove this once we are sure all model side packages have the required
+    # dependencies
+    from torchrec.distributed.logger import _torchrec_method_logger
+except Exception:
+
+    def _torchrec_method_logger(*args, **kwargs):
+        """A no-op decorator that accepts any arguments."""
+
+        def decorator(func):
+            return func
+
+        return decorator
+
 
 def _join_module_path(path: str, name: str) -> str:
     return (path + "." + name) if path else name
@@ -146,6 +162,7 @@ def _shard(
 
 # pyre-ignore
 @contract()
+@_torchrec_method_logger()
 def shard_modules(
     module: nn.Module,
     env: Optional[ShardingEnv] = None,
@@ -194,6 +211,7 @@ def init_weights(m):
     return _shard_modules(module, env, device, plan, sharders, init_params)
 
 
+@_torchrec_method_logger()
 def _shard_modules(  # noqa: C901
     module: nn.Module,
     # TODO: Consolidate to using Dict[str, ShardingEnv]
diff --git a/torchrec/distributed/test_utils/test_model_parallel_base.py b/torchrec/distributed/test_utils/test_model_parallel_base.py
index b81088389..e402cfa10 100644
--- a/torchrec/distributed/test_utils/test_model_parallel_base.py
+++ b/torchrec/distributed/test_utils/test_model_parallel_base.py
@@ -253,7 +253,21 @@ def test_sharding_ebc_input_validation_enabled(self, mock_jk: Mock) -> None:
         with self.assertRaisesRegex(ValueError, "keys must be unique"):
             model(kjt)
 
-        mock_jk.assert_any_call("pytorch/torchrec:enable_kjt_validation")
+
+        # Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
+        # This ignores any other calls to justknobs_check() with other inputs
+        # and protects the test from breaking when new JK checks are added.
+        validation_calls = [
+            call
+            for call in mock_jk.call_args_list
+            if len(call[0]) > 0
+            and call[0][0] == "pytorch/torchrec:enable_kjt_validation"
+        ]
+        self.assertEqual(
+            1,
+            len(validation_calls),
+            "There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation",
+        )
 
     @patch("torch._utils_internal.justknobs_check")
     def test_sharding_ebc_validate_input_only_once(self, mock_jk: Mock) -> None:
@@ -271,7 +285,20 @@ def test_sharding_ebc_validate_input_only_once(self, mock_jk: Mock) -> None:
             model(kjt)
             model(kjt)
 
-        mock_jk.assert_any_call("pytorch/torchrec:enable_kjt_validation")
+        # Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
+        # This ignores any other calls to justknobs_check() with other inputs
+        # and protects the test from breaking when new JK checks are added.
+        validation_calls = [
+            call
+            for call in mock_jk.call_args_list
+            if len(call[0]) > 0
+            and call[0][0] == "pytorch/torchrec:enable_kjt_validation"
+        ]
+        self.assertEqual(
+            1,
+            len(validation_calls),
+            "There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation",
+        )
         matched_logs = list(
             filter(lambda s: "Validating input features..." in s, logs.output)
         )
@@ -294,7 +321,20 @@ def test_sharding_ebc_input_validation_disabled(self, mock_jk: Mock) -> None:
         except ValueError:
             self.fail("Input validation should not be enabled.")
 
-        mock_jk.assert_any_call("pytorch/torchrec:enable_kjt_validation")
+        # Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
+        # This ignores any other calls to justknobs_check() with other inputs
+        # and protects the test from breaking when new JK checks are added.
+        validation_calls = [
+            call
+            for call in mock_jk.call_args_list
+            if len(call[0]) > 0
+            and call[0][0] == "pytorch/torchrec:enable_kjt_validation"
+        ]
+        self.assertEqual(
+            1,
+            len(validation_calls),
+            "There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation",
+        )
 
     def _create_sharded_model(
         self, embedding_dim: int = 128, num_embeddings: int = 256
diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index 61c430bd8..dd6d1fb59 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -73,6 +73,22 @@
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 from torchrec.streamable import Pipelineable
 
+try:
+    # This is a safety measure against torch package issues for when
+    # Torchrec is included in the inference side model code. We should
+    # remove this once we are sure all model side packages have the required
+    # dependencies
+    from torchrec.distributed.logger import _torchrec_method_logger
+except Exception:
+
+    def _torchrec_method_logger(*args, **kwargs):
+        """A no-op decorator that accepts any arguments."""
+
+        def decorator(func):
+            return func
+
+        return decorator
+
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -106,6 +122,8 @@ class TrainPipeline(abc.ABC, Generic[In, Out]):
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
         pass
 
+    # pyre-ignore [56]
+    @_torchrec_method_logger()
     def __init__(self) -> None:
         # pipeline state such as in foward, in backward etc, used in training recover scenarios
         self._state: PipelineState = PipelineState.IDLE
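
For reviewers, a minimal standalone sketch of the fallback pattern this diff repeats in each module: if `torchrec.distributed.logger._torchrec_method_logger` cannot be imported (e.g. when TorchRec is packaged into inference-side model code), a no-op decorator factory with the same call shape is substituted, so `@_torchrec_method_logger()` call sites keep working unchanged. The `example_plan` function below is hypothetical and only illustrates usage; it is not part of the diff.

```python
try:
    # Real method logger; assumed importable in normal (non-packaged) environments.
    from torchrec.distributed.logger import _torchrec_method_logger
except Exception:
    # Fallback: a decorator factory that accepts any arguments and returns the
    # function unchanged, so decorated call sites work without logging.
    def _torchrec_method_logger(*args, **kwargs):
        def decorator(func):
            return func

        return decorator


@_torchrec_method_logger()  # called as a factory, mirroring the usage in the diff
def example_plan(x: int) -> int:  # hypothetical function, for illustration only
    return x + 1


# Behaves the same whether the real logger or the no-op fallback is in effect.
assert example_plan(1) == 2
```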