18 changes: 18 additions & 0 deletions torchrec/distributed/planner/planners.py
@@ -70,6 +70,23 @@
)
from torchrec.distributed.utils import get_device_type, none_throws

try:
# This is a safety measure against torch package issues for when
# Torchrec is included in the inference side model code. We should
# remove this once we are sure all model side packages have the required
# dependencies
from torchrec.distributed.logger import _torchrec_method_logger
except Exception:

def _torchrec_method_logger(*args, **kwargs):
"""A no-op decorator that accepts any arguments."""

def decorator(func):
return func

return decorator


logger: logging.Logger = logging.getLogger(__name__)


@@ -498,6 +515,7 @@ def collective_plan(
sharders,
)

@_torchrec_method_logger()
def plan(
self,
module: nn.Module,
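Editor's note: the same import-fallback appears in every module this PR touches. If torchrec.distributed.logger cannot be imported, a no-op factory is substituted so each @_torchrec_method_logger() call site keeps working unchanged. A minimal, self-contained sketch of the pattern, with a hypothetical fancy_logger module standing in for the real dependency:

try:
    from fancy_logger import method_logger  # may be absent in packaged model code
except Exception:

    def method_logger(*args, **kwargs):
        # Fallback factory: returns a decorator that leaves func untouched.
        def decorator(func):
            return func

        return decorator


@method_logger()
def plan():
    return "ok"


print(plan())  # prints "ok" whether or not fancy_logger is importable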
18 changes: 18 additions & 0 deletions torchrec/distributed/planner/shard_estimators.py
@@ -53,6 +53,23 @@

from torchrec.modules.embedding_modules import EmbeddingBagCollectionInterface

try:
# This is a safety measure against torch package issues for when
# Torchrec is included in the inference side model code. We should
# remove this once we are sure all model side packages have the required
# dependencies
from torchrec.distributed.logger import _torchrec_method_logger
except Exception:

def _torchrec_method_logger(*args, **kwargs):
"""A no-op decorator that accepts any arguments."""

def decorator(func):
return func

return decorator


logger: logging.Logger = logging.getLogger(__name__)


@@ -955,6 +972,7 @@ class EmbeddingStorageEstimator(ShardEstimator):
is_inference (bool): Whether the model is an inference model. Defaults to False.
"""

@_torchrec_method_logger()
def __init__(
self,
topology: Topology,
18 changes: 18 additions & 0 deletions torchrec/distributed/shard.py
@@ -30,6 +30,22 @@
from torchrec.distributed.utils import init_parameters
from torchrec.modules.utils import reset_module_states_post_sharding

try:
# This is a safety measure against torch package issues for when
# Torchrec is included in the inference side model code. We should
# remove this once we are sure all model side packages have the required
# dependencies
from torchrec.distributed.logger import _torchrec_method_logger
except Exception:

def _torchrec_method_logger(*args, **kwargs):
"""A no-op decorator that accepts any arguments."""

def decorator(func):
return func

return decorator


def _join_module_path(path: str, name: str) -> str:
return (path + "." + name) if path else name
@@ -146,6 +162,7 @@ def _shard(

# pyre-ignore
@contract()
@_torchrec_method_logger()
def shard_modules(
module: nn.Module,
env: Optional[ShardingEnv] = None,
@@ -194,6 +211,7 @@ def init_weights(m):
return _shard_modules(module, env, device, plan, sharders, init_params)


@_torchrec_method_logger()
def _shard_modules( # noqa: C901
module: nn.Module,
# TODO: Consolidate to using Dict[str, ShardingEnv]
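Editor's note: note the stacking order on shard_modules above. Decorators apply bottom-up, so @_torchrec_method_logger() wraps the raw function first and @contract() then wraps the logged version. A tiny illustration of that ordering, with hypothetical outer/inner decorators:

def outer(func):
    def wrapper():
        return "outer(" + func() + ")"

    return wrapper


def inner(func):
    def wrapper():
        return "inner(" + func() + ")"

    return wrapper


@outer
@inner
def f():
    return "f"


print(f())  # prints "outer(inner(f))": the decorator nearest the def runs innermost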
46 changes: 43 additions & 3 deletions torchrec/distributed/test_utils/test_model_parallel_base.py
@@ -253,7 +253,21 @@ def test_sharding_ebc_input_validation_enabled(self, mock_jk: Mock) -> None:

with self.assertRaisesRegex(ValueError, "keys must be unique"):
model(kjt)
mock_jk.assert_any_call("pytorch/torchrec:enable_kjt_validation")

# Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
# This ignores any other calls to justknobs_check() with other inputs
# and protects the test from breaking when new JK checks are added.
validation_calls = [
call
for call in mock_jk.call_args_list
if len(call[0]) > 0
and call[0][0] == "pytorch/torchrec:enable_kjt_validation"
]
self.assertEqual(
1,
len(validation_calls),
"There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation",
)

@patch("torch._utils_internal.justknobs_check")
def test_sharding_ebc_validate_input_only_once(self, mock_jk: Mock) -> None:
@@ -271,7 +285,20 @@ def test_sharding_ebc_validate_input_only_once(self, mock_jk: Mock) -> None:
model(kjt)
model(kjt)

mock_jk.assert_any_call("pytorch/torchrec:enable_kjt_validation")
# Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
# This ignores any other calls to justknobs_check() with other inputs
# and protects the test from breaking when new JK checks are added.
validation_calls = [
call
for call in mock_jk.call_args_list
if len(call[0]) > 0
and call[0][0] == "pytorch/torchrec:enable_kjt_validation"
]
self.assertEqual(
1,
len(validation_calls),
"There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation",
)
matched_logs = list(
filter(lambda s: "Validating input features..." in s, logs.output)
)
@@ -294,7 +321,20 @@ def test_sharding_ebc_input_validation_disabled(self, mock_jk: Mock) -> None:
except ValueError:
self.fail("Input validation should not be enabled.")

mock_jk.assert_any_call("pytorch/torchrec:enable_kjt_validation")
# Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
# This ignores any other calls to justknobs_check() with other inputs
# and protects the test from breaking when new JK checks are added.
validation_calls = [
call
for call in mock_jk.call_args_list
if len(call[0]) > 0
and call[0][0] == "pytorch/torchrec:enable_kjt_validation"
]
self.assertEqual(
1,
len(validation_calls),
"There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation",
)

def _create_sharded_model(
self, embedding_dim: int = 128, num_embeddings: int = 256
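Editor's note: the call_args_list filter used in all three tests above can be exercised on its own. A minimal sketch with a bare Mock; only the knob string comes from the tests, the second call is an illustrative stand-in for an unrelated JK check:

from unittest.mock import Mock

mock_jk = Mock(return_value=True)
mock_jk("pytorch/torchrec:enable_kjt_validation")
mock_jk("pytorch/torchrec:some_other_knob")  # hypothetical unrelated check

validation_calls = [
    call
    for call in mock_jk.call_args_list
    if len(call[0]) > 0
    and call[0][0] == "pytorch/torchrec:enable_kjt_validation"
]
assert len(validation_calls) == 1  # unaffected by the unrelated call

Unlike assert_called_once_with, this tolerates extra unrelated calls; unlike the assert_any_call it replaces, it also fails if the knob is checked more than once.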
18 changes: 18 additions & 0 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -73,6 +73,22 @@
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.streamable import Pipelineable

try:
# This is a safety measure against torch package issues for when
# Torchrec is included in the inference side model code. We should
# remove this once we are sure all model side packages have the required
# dependencies
from torchrec.distributed.logger import _torchrec_method_logger
except Exception:

def _torchrec_method_logger(*args, **kwargs):
"""A no-op decorator that accepts any arguments."""

def decorator(func):
return func

return decorator


logger: logging.Logger = logging.getLogger(__name__)

@@ -106,6 +122,8 @@ class TrainPipeline(abc.ABC, Generic[In, Out]):
def progress(self, dataloader_iter: Iterator[In]) -> Out:
pass

# pyre-ignore [56]
@_torchrec_method_logger()
def __init__(self) -> None:
# pipeline state such as in forward, in backward, etc.; used in training recovery scenarios
self._state: PipelineState = PipelineState.IDLE
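Editor's note: the PR never shows _torchrec_method_logger itself; its call sites only establish that it is a decorator factory taking arbitrary arguments. Purely as a hedged sketch of a plausible shape (the real implementation in torchrec.distributed.logger may differ):

import functools
import logging

logger = logging.getLogger(__name__)


def _torchrec_method_logger(**factory_kwargs):
    # Hypothetical implementation -- illustrative only.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            logger.debug("torchrec call: %s", func.__qualname__)
            return func(*args, **kwargs)

        return wrapper

    return decorator

The # pyre-ignore [56] on TrainPipeline.__init__ suppresses Pyre's invalid-decoration error, presumably because Pyre cannot type the dynamically defined fallback decorator.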