
Commit c934b08

nipung90 authored and facebook-github-bot committed

Enable logging for the plan() function, ShardEstimators and TrainingPipeline class constructors (meta-pytorch#3576)

Summary: Pull Request resolved: meta-pytorch#3576

Differential Revision: D87910772

1 parent 7ecee99 commit c934b08

File tree

8 files changed: +135 -4 lines changed

torchrec/distributed/embeddingbag.py

Lines changed: 11 additions & 0 deletions

@@ -120,6 +120,15 @@
 except OSError:
     pass

+try:
+    # This is a safety measure against torch package issues for when
+    # Torchrec is included in the inference side model code. We should
+    # remove this once we are sure all model side packages have the required
+    # dependencies
+    from torchrec.distributed.logger import _torchrec_method_logger
+except Exception:
+    from torchrec.distributed.empty_logger import _torchrec_method_logger
+
 logger: logging.Logger = logging.getLogger(__name__)

@@ -466,6 +475,7 @@ class ShardedEmbeddingBagCollection(
     This is part of the public API to allow for manual data dist pipelining.
     """

+    @_torchrec_method_logger()
     def __init__(
         self,
         module: EmbeddingBagCollectionInterface,

@@ -2021,6 +2031,7 @@ class ShardedEmbeddingBag(
     This is part of the public API to allow for manual data dist pipelining.
     """

+    @_torchrec_method_logger()
     def __init__(
         self,
         module: nn.EmbeddingBag,
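The same guarded import is repeated in every touched module. Below is a minimal, self-contained sketch of the pattern (the decorated function name is hypothetical, not part of the commit): if torchrec.distributed.logger or any of its transitive dependencies fails to import, for example when TorchRec is bundled into inference-side model code via torch.package, the no-op decorator from empty_logger is substituted and decorated call sites keep working, just without logging.

try:
    from torchrec.distributed.logger import _torchrec_method_logger
except Exception:
    # Fallback: a pass-through decorator with the same factory signature.
    from torchrec.distributed.empty_logger import _torchrec_method_logger


@_torchrec_method_logger()
def my_entry_point(x: int) -> int:
    # Logged when the real logger imported; a plain call otherwise.
    return x + 1


print(my_entry_point(41))  # 42 either way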
torchrec/distributed/empty_logger.py

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import functools
+from typing import Any, Callable, ParamSpec, TypeVar
+
+
+_T = TypeVar("_T")
+_P = ParamSpec("_P")
+
+
+def _torchrec_method_logger(
+    **wrapper_kwargs: Any,
+) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:  # pyre-ignore
+    """This method decorator logs the input, output, and exception of wrapped events."""
+
+    def decorator(func: Callable[_P, _T]):  # pyre-ignore
+        @functools.wraps(func)
+        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
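For reference, a quick usage sketch of the stub above (the function and keyword names are illustrative, not from the commit): the factory accepts arbitrary keyword arguments so call sites can pass whatever the real logger accepts, while the returned wrapper is a pure pass-through and functools.wraps keeps the wrapped function's metadata intact.

from torchrec.distributed.empty_logger import _torchrec_method_logger


@_torchrec_method_logger(run_id="demo")  # "run_id" is hypothetical; the stub swallows any kwargs
def add(a: int, b: int) -> int:
    return a + b


assert add(2, 3) == 5          # behaves exactly like the undecorated function
assert add.__name__ == "add"   # functools.wraps preserves the name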

torchrec/distributed/planner/planners.py

Lines changed: 10 additions & 0 deletions

@@ -70,6 +70,15 @@
 )
 from torchrec.distributed.utils import get_device_type, none_throws

+try:
+    # This is a safety measure against torch package issues for when
+    # Torchrec is included in the inference side model code. We should
+    # remove this once we are sure all model side packages have the required
+    # dependencies
+    from torchrec.distributed.logger import _torchrec_method_logger
+except Exception:
+    from torchrec.distributed.empty_logger import _torchrec_method_logger
+
 logger: logging.Logger = logging.getLogger(__name__)


@@ -498,6 +507,7 @@ def collective_plan(
             sharders,
         )

+    @_torchrec_method_logger()
     def plan(
         self,
         module: nn.Module,

torchrec/distributed/planner/shard_estimators.py

Lines changed: 10 additions & 0 deletions

@@ -53,6 +53,15 @@

 from torchrec.modules.embedding_modules import EmbeddingBagCollectionInterface

+try:
+    # This is a safety measure against torch package issues for when
+    # Torchrec is included in the inference side model code. We should
+    # remove this once we are sure all model side packages have the required
+    # dependencies
+    from torchrec.distributed.logger import _torchrec_method_logger
+except Exception:
+    from torchrec.distributed.empty_logger import _torchrec_method_logger
+
 logger: logging.Logger = logging.getLogger(__name__)


@@ -955,6 +964,7 @@ class EmbeddingStorageEstimator(ShardEstimator):
         is_inference (bool): If the model is inference model. Default to False.
     """

+    @_torchrec_method_logger()
     def __init__(
         self,
         topology: Topology,

torchrec/distributed/shard.py

Lines changed: 11 additions & 0 deletions

@@ -30,6 +30,15 @@
 from torchrec.distributed.utils import init_parameters
 from torchrec.modules.utils import reset_module_states_post_sharding

+try:
+    # This is a safety measure against torch package issues for when
+    # Torchrec is included in the inference side model code. We should
+    # remove this once we are sure all model side packages have the required
+    # dependencies
+    from torchrec.distributed.logger import _torchrec_method_logger
+except Exception:
+    from torchrec.distributed.empty_logger import _torchrec_method_logger
+

 def _join_module_path(path: str, name: str) -> str:
     return (path + "." + name) if path else name

@@ -146,6 +155,7 @@ def _shard(

 # pyre-ignore
 @contract()
+@_torchrec_method_logger()
 def shard_modules(
     module: nn.Module,
     env: Optional[ShardingEnv] = None,

@@ -194,6 +204,7 @@ def init_weights(m):
     return _shard_modules(module, env, device, plan, sharders, init_params)


+@_torchrec_method_logger()
 def _shard_modules(  # noqa: C901
     module: nn.Module,
     # TODO: Consolidate to using Dict[str, ShardingEnv]
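In shard.py the new decorator is stacked under the existing @contract(). As a reminder of how stacking composes (this is a generic illustration with a made-up trace decorator, not TorchRec code): the decorator nearest the function applies first, so _torchrec_method_logger wraps shard_modules directly and @contract() then wraps the already-logged function.

import functools


def trace(tag: str):
    def deco(fn):
        @functools.wraps(fn)
        def wrapped(*args, **kwargs):
            # Announce entry, then delegate to the next layer down.
            print(f"{tag}: entering {fn.__name__}")
            return fn(*args, **kwargs)
        return wrapped
    return deco


@trace("outer")  # analogous to @contract()
@trace("inner")  # analogous to @_torchrec_method_logger()
def f() -> None:
    pass


f()  # prints "outer: entering f", then "inner: entering f"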

torchrec/distributed/test_utils/test_model_parallel_base.py

Lines changed: 44 additions & 4 deletions

@@ -248,12 +248,26 @@ def test_sharding_ebc_input_validation_enabled(self, mock_jk: Mock) -> None:
             values=torch.tensor([1, 2, 3, 4, 5]),
             lengths=torch.tensor([1, 2, 0, 2]),
             offsets=torch.tensor([0, 1, 3, 3, 5]),
-        )
+        ).to(self.device)
         mock_jk.return_value = True

         with self.assertRaisesRegex(ValueError, "keys must be unique"):
             model(kjt)
-        mock_jk.assert_called_once_with("pytorch/torchrec:enable_kjt_validation")
+
+        # Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
+        # This ignores any other calls to justknobs_check() with other inputs
+        # and protects the test from breaking when new JK checks are added.
+        validation_calls = [
+            call
+            for call in mock_jk.call_args_list
+            if len(call[0]) > 0
+            and call[0][0] == "pytorch/torchrec:enable_kjt_validation"
+        ]
+        self.assertEqual(
+            1,
+            len(validation_calls),
+            "There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation",
+        )

     @patch("torch._utils_internal.justknobs_check")
     def test_sharding_ebc_validate_input_only_once(self, mock_jk: Mock) -> None:

@@ -271,7 +285,20 @@ def test_sharding_ebc_validate_input_only_once(self, mock_jk: Mock) -> None:
             model(kjt)
             model(kjt)

-        mock_jk.assert_called_once_with("pytorch/torchrec:enable_kjt_validation")
+        # Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
+        # This ignores any other calls to justknobs_check() with other inputs
+        # and protects the test from breaking when new JK checks are added.
+        validation_calls = [
+            call
+            for call in mock_jk.call_args_list
+            if len(call[0]) > 0
+            and call[0][0] == "pytorch/torchrec:enable_kjt_validation"
+        ]
+        self.assertEqual(
+            1,
+            len(validation_calls),
+            "There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation",
+        )
         matched_logs = list(
             filter(lambda s: "Validating input features..." in s, logs.output)
         )

@@ -294,7 +321,20 @@ def test_sharding_ebc_input_validation_disabled(self, mock_jk: Mock) -> None:
         except ValueError:
             self.fail("Input validation should not be enabled.")

-        mock_jk.assert_called_once_with("pytorch/torchrec:enable_kjt_validation")
+        # Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
+        # This ignores any other calls to justknobs_check() with other inputs
+        # and protects the test from breaking when new JK checks are added.
+        validation_calls = [
+            call
+            for call in mock_jk.call_args_list
+            if len(call[0]) > 0
+            and call[0][0] == "pytorch/torchrec:enable_kjt_validation"
+        ]
+        self.assertEqual(
+            1,
+            len(validation_calls),
+            "There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation",
+        )

     def _create_sharded_model(
         self, embedding_dim: int = 128, num_embeddings: int = 256
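The three tests switch from assert_called_once_with to filtering call_args_list, so unrelated justknobs_check calls added later (such as the new logger's own checks) cannot break them. A standalone illustration of the idiom using only unittest.mock (the mock setup here is hypothetical):

from unittest.mock import Mock

mock_jk = Mock(return_value=True)
mock_jk("pytorch/torchrec:enable_kjt_validation")
mock_jk("some/other:knob")  # an unrelated check no longer breaks the assertion

# call[0] is the tuple of positional args recorded for each call.
validation_calls = [
    call
    for call in mock_jk.call_args_list
    if len(call[0]) > 0
    and call[0][0] == "pytorch/torchrec:enable_kjt_validation"
]
assert len(validation_calls) == 1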

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 10 additions & 0 deletions

@@ -73,6 +73,14 @@
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 from torchrec.streamable import Pipelineable

+try:
+    # This is a safety measure against torch package issues for when
+    # Torchrec is included in the inference side model code. We should
+    # remove this once we are sure all model side packages have the required
+    # dependencies
+    from torchrec.distributed.logger import _torchrec_method_logger
+except Exception:
+    from torchrec.distributed.empty_logger import _torchrec_method_logger

 logger: logging.Logger = logging.getLogger(__name__)

@@ -106,6 +114,8 @@ class TrainPipeline(abc.ABC, Generic[In, Out]):
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
         pass

+    # pyre-ignore [56]
+    @_torchrec_method_logger()
     def __init__(self) -> None:
         # pipeline state such as in foward, in backward etc, used in training recover scenarios
         self._state: PipelineState = PipelineState.IDLE
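Decorating TrainPipeline.__init__ means every concrete pipeline that chains up through super().__init__() passes through the (possibly no-op) logger exactly once per construction. A minimal sketch of that behavior, with illustrative class names rather than the real pipeline hierarchy:

import abc

from torchrec.distributed.empty_logger import _torchrec_method_logger


class PipelineBase(abc.ABC):
    @_torchrec_method_logger()
    def __init__(self) -> None:
        self._state = "IDLE"


class MyPipeline(PipelineBase):
    def __init__(self) -> None:
        super().__init__()  # the decorated base __init__ runs here


MyPipeline()  # construction flows through the logger wrapper once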

torchrec/modules/mc_embedding_modules.py

Lines changed: 11 additions & 0 deletions

@@ -20,6 +20,15 @@
 from torchrec.modules.mc_modules import ManagedCollisionCollection
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor

+try:
+    # This is a safety measure against torch package issues for when
+    # Torchrec is included in the inference side model code. We should
+    # remove this once we are sure all model side packages have the required
+    # dependencies
+    from torchrec.distributed.logger import _torchrec_method_logger
+except Exception:
+    from torchrec.distributed.empty_logger import _torchrec_method_logger
+

 def evict(
     evictions: Dict[str, Optional[torch.Tensor]],

@@ -125,6 +134,7 @@ class ManagedCollisionEmbeddingCollection(BaseManagedCollisionEmbeddingCollectio
     """

+    @_torchrec_method_logger()
     def __init__(
         self,
         embedding_collection: EmbeddingCollection,

@@ -164,6 +174,7 @@ class ManagedCollisionEmbeddingBagCollection(BaseManagedCollisionEmbeddingBagCollec
     """

+    @_torchrec_method_logger()
     def __init__(
         self,
         embedding_bag_collection: EmbeddingBagCollection,
