
Commit b1edc6f

nipung90 authored and facebook-github-bot committed
Enable logging for the plan() function, ShardEstimators and TrainingPipeline class constructors (meta-pytorch#3576)
Summary:
Pull Request resolved: meta-pytorch#3576

Reviewed By: DocherPap

Differential Revision: D87910772
1 parent 0a2cebd commit b1edc6f

File tree

torchrec/distributed/planner/planners.py
torchrec/distributed/planner/shard_estimators.py
torchrec/distributed/shard.py
torchrec/distributed/test_utils/test_model_parallel_base.py
torchrec/distributed/train_pipeline/train_pipelines.py

5 files changed: +115 -3 lines changed

torchrec/distributed/planner/planners.py

Lines changed: 18 additions & 0 deletions
@@ -70,6 +70,23 @@
 )
 from torchrec.distributed.utils import get_device_type, none_throws
 
+try:
+    # This is a safety measure against torch package issues for when
+    # Torchrec is included in the inference side model code. We should
+    # remove this once we are sure all model side packages have the required
+    # dependencies
+    from torchrec.distributed.logger import _torchrec_method_logger
+except Exception:
+
+    def _torchrec_method_logger(*args, **kwargs):
+        """A no-op decorator that accepts any arguments."""
+
+        def decorator(func):
+            return func
+
+        return decorator
+
+
 logger: logging.Logger = logging.getLogger(__name__)
 
 
@@ -498,6 +515,7 @@ def collective_plan(
             sharders,
         )
 
+    @_torchrec_method_logger()
     def plan(
         self,
         module: nn.Module,
torchrec/distributed/planner/shard_estimators.py

Lines changed: 18 additions & 0 deletions
@@ -53,6 +53,23 @@
 
 from torchrec.modules.embedding_modules import EmbeddingBagCollectionInterface
 
+try:
+    # This is a safety measure against torch package issues for when
+    # Torchrec is included in the inference side model code. We should
+    # remove this once we are sure all model side packages have the required
+    # dependencies
+    from torchrec.distributed.logger import _torchrec_method_logger
+except Exception:
+
+    def _torchrec_method_logger(*args, **kwargs):
+        """A no-op decorator that accepts any arguments."""
+
+        def decorator(func):
+            return func
+
+        return decorator
+
+
 logger: logging.Logger = logging.getLogger(__name__)
 
 
@@ -955,6 +972,7 @@ class EmbeddingStorageEstimator(ShardEstimator):
         is_inference (bool): If the model is inference model. Default to False.
     """
 
+    @_torchrec_method_logger()
     def __init__(
         self,
         topology: Topology,

torchrec/distributed/shard.py

Lines changed: 18 additions & 0 deletions
@@ -30,6 +30,22 @@
 from torchrec.distributed.utils import init_parameters
 from torchrec.modules.utils import reset_module_states_post_sharding
 
+try:
+    # This is a safety measure against torch package issues for when
+    # Torchrec is included in the inference side model code. We should
+    # remove this once we are sure all model side packages have the required
+    # dependencies
+    from torchrec.distributed.logger import _torchrec_method_logger
+except Exception:
+
+    def _torchrec_method_logger(*args, **kwargs):
+        """A no-op decorator that accepts any arguments."""
+
+        def decorator(func):
+            return func
+
+        return decorator
+
 
 def _join_module_path(path: str, name: str) -> str:
     return (path + "." + name) if path else name
@@ -146,6 +162,7 @@ def _shard(
 
 # pyre-ignore
 @contract()
+@_torchrec_method_logger()
 def shard_modules(
     module: nn.Module,
     env: Optional[ShardingEnv] = None,
@@ -194,6 +211,7 @@ def init_weights(m):
     return _shard_modules(module, env, device, plan, sharders, init_params)
 
 
+@_torchrec_method_logger()
 def _shard_modules(  # noqa: C901
     module: nn.Module,
     # TODO: Consolidate to using Dict[str, ShardingEnv]
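
Note: in shard.py the new logging decorator is listed below the existing @contract() decorator, so it is applied first (closest to shard_modules) and @contract() wraps the already-logged function. A rough sketch of that stacking order with placeholder decorators; outer and inner below are illustrative only, not the real torchrec @contract() or @_torchrec_method_logger():

def outer(func):
    def wrapper(*args, **kwargs):
        print("outer (analogous to @contract) sees the call first")
        return func(*args, **kwargs)

    return wrapper


def inner(func):
    def wrapper(*args, **kwargs):
        print("inner (analogous to @_torchrec_method_logger) runs closest to the function")
        return func(*args, **kwargs)

    return wrapper


@outer  # listed first, applied last -> outermost wrapper
@inner  # listed last, applied first -> innermost wrapper
def shard_modules_stub():  # hypothetical stand-in for shard_modules
    return "sharded"


shard_modules_stub()
# prints:
# outer (analogous to @contract) sees the call first
# inner (analogous to @_torchrec_method_logger) runs closest to the function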

torchrec/distributed/test_utils/test_model_parallel_base.py

Lines changed: 43 additions & 3 deletions
@@ -253,7 +253,21 @@ def test_sharding_ebc_input_validation_enabled(self, mock_jk: Mock) -> None:
 
         with self.assertRaisesRegex(ValueError, "keys must be unique"):
             model(kjt)
-        mock_jk.assert_any_call("pytorch/torchrec:enable_kjt_validation")
+
+        # Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
+        # This ignores any other calls to justknobs_check() with other inputs
+        # and protects the test from breaking when new JK checks are added.
+        validation_calls = [
+            call
+            for call in mock_jk.call_args_list
+            if len(call[0]) > 0
+            and call[0][0] == "pytorch/torchrec:enable_kjt_validation"
+        ]
+        self.assertEqual(
+            1,
+            len(validation_calls),
+            "There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation",
+        )
 
     @patch("torch._utils_internal.justknobs_check")
     def test_sharding_ebc_validate_input_only_once(self, mock_jk: Mock) -> None:
@@ -271,7 +285,20 @@ def test_sharding_ebc_validate_input_only_once(self, mock_jk: Mock) -> None:
             model(kjt)
             model(kjt)
 
-        mock_jk.assert_any_call("pytorch/torchrec:enable_kjt_validation")
+        # Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
+        # This ignores any other calls to justknobs_check() with other inputs
+        # and protects the test from breaking when new JK checks are added.
+        validation_calls = [
+            call
+            for call in mock_jk.call_args_list
+            if len(call[0]) > 0
+            and call[0][0] == "pytorch/torchrec:enable_kjt_validation"
+        ]
+        self.assertEqual(
+            1,
+            len(validation_calls),
+            "There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation",
+        )
         matched_logs = list(
             filter(lambda s: "Validating input features..." in s, logs.output)
         )
@@ -294,7 +321,20 @@ def test_sharding_ebc_input_validation_disabled(self, mock_jk: Mock) -> None:
         except ValueError:
             self.fail("Input validation should not be enabled.")
 
-        mock_jk.assert_any_call("pytorch/torchrec:enable_kjt_validation")
+        # Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
+        # This ignores any other calls to justknobs_check() with other inputs
+        # and protects the test from breaking when new JK checks are added.
+        validation_calls = [
+            call
+            for call in mock_jk.call_args_list
+            if len(call[0]) > 0
+            and call[0][0] == "pytorch/torchrec:enable_kjt_validation"
+        ]
+        self.assertEqual(
+            1,
+            len(validation_calls),
+            "There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation",
+        )
 
     def _create_sharded_model(
         self, embedding_dim: int = 128, num_embeddings: int = 256
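
Note: the test changes above replace assert_any_call with an explicit count of matching calls, so unrelated justknobs_check() calls added later cannot break the assertion. A standalone illustration of the same unittest.mock pattern; the knob names and test class below are made up for the example:

import unittest
from unittest.mock import Mock


class FilterCallsExample(unittest.TestCase):
    def test_exactly_one_matching_call(self) -> None:
        mock_check = Mock(return_value=True)

        # Simulate a code path that queries several knobs, only one of which
        # is the knob this test cares about.
        mock_check("example/knob:other_feature")
        mock_check("example/knob:enable_validation")
        mock_check("example/knob:another_feature")

        # call_args_list holds (args, kwargs) for every invocation; keep only
        # calls whose first positional argument is the knob under test.
        matching = [
            call
            for call in mock_check.call_args_list
            if len(call[0]) > 0 and call[0][0] == "example/knob:enable_validation"
        ]
        self.assertEqual(1, len(matching))


if __name__ == "__main__":
    unittest.main()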

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 18 additions & 0 deletions
@@ -73,6 +73,22 @@
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 from torchrec.streamable import Pipelineable
 
+try:
+    # This is a safety measure against torch package issues for when
+    # Torchrec is included in the inference side model code. We should
+    # remove this once we are sure all model side packages have the required
+    # dependencies
+    from torchrec.distributed.logger import _torchrec_method_logger
+except Exception:
+
+    def _torchrec_method_logger(*args, **kwargs):
+        """A no-op decorator that accepts any arguments."""
+
+        def decorator(func):
+            return func
+
+        return decorator
+
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -106,6 +122,8 @@ class TrainPipeline(abc.ABC, Generic[In, Out]):
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
         pass
 
+    # pyre-ignore [56]
+    @_torchrec_method_logger()
     def __init__(self) -> None:
         # pipeline state such as in foward, in backward etc, used in training recover scenarios
         self._state: PipelineState = PipelineState.IDLE
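
Note: placing the decorator on TrainPipeline.__init__ means every pipeline construction goes through the logger. The real _torchrec_method_logger is not shown in this diff; the sketch below uses a generic constructor-logging decorator (log_constructor and ExamplePipeline are illustrative names, not torchrec APIs) to show why decorating __init__ captures each instantiation:

import functools
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def log_constructor(**metadata):
    """Illustrative decorator factory: logs each call to the wrapped callable."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            logger.info("calling %s with metadata %s", func.__qualname__, metadata)
            return func(*args, **kwargs)

        return wrapper

    return decorator


class ExamplePipeline:  # stand-in for a TrainPipeline subclass
    @log_constructor(component="train_pipeline")
    def __init__(self) -> None:
        self.state = "idle"


pipeline = ExamplePipeline()  # emits one log line per constructed pipeline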
