
Commit bd5a187

Enable logging for the plan() function, ShardEstimators and TrainingPipeline class constructors (#3576)

Authored by nipung90, committed by facebook-github-bot.

Summary: Pull Request resolved #3576. Differential Revision: D87910772.

1 parent ca2f687 · commit bd5a187
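
All seven files route through the same decorator, `_torchrec_method_logger` from `torchrec.distributed.logger`. Its implementation is not part of this diff; the following is a minimal, hypothetical sketch of what such a decorator factory could look like, assuming it does nothing more than log the qualified name of each wrapped call:

```python
# Hypothetical sketch only: the real torchrec.distributed.logger
# implementation is not shown in this commit.
import functools
import logging
from typing import Any, Callable

logger: logging.Logger = logging.getLogger(__name__)


def _torchrec_method_logger(
    **factory_kwargs: Any,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Return a decorator that logs every call to the wrapped method."""

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Log entry before delegating to the wrapped callable.
            logger.info("torchrec call: %s", func.__qualname__)
            return func(*args, **kwargs)

        return wrapper

    return decorator
```

Applied as `@_torchrec_method_logger()` above `plan()`, the estimator and pipeline constructors, and the sharding entry points, a decorator of this shape would emit one log line per invocation without changing any signatures.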

File tree: 7 files changed (+60 −4 lines)

torchrec/distributed/embeddingbag.py
torchrec/distributed/planner/planners.py
torchrec/distributed/planner/shard_estimators.py
torchrec/distributed/shard.py
torchrec/distributed/test_utils/test_model_parallel_base.py
torchrec/distributed/train_pipeline/train_pipelines.py
torchrec/modules/mc_embedding_modules.py

torchrec/distributed/embeddingbag.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -55,6 +55,7 @@
     FUSED_PARAM_IS_SSD_TABLE,
     FUSED_PARAM_SSD_TABLE_LIST,
 )
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
 from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
 from torchrec.distributed.sharding.dynamic_sharding import (
@@ -466,6 +467,7 @@ class ShardedEmbeddingBagCollection(
     This is part of the public API to allow for manual data dist pipelining.
     """

+    @_torchrec_method_logger()
     def __init__(
         self,
         module: EmbeddingBagCollectionInterface,
@@ -2021,6 +2023,7 @@ class ShardedEmbeddingBag(
     This is part of the public API to allow for manual data dist pipelining.
     """

+    @_torchrec_method_logger()
     def __init__(
         self,
         module: nn.EmbeddingBag,
```

torchrec/distributed/planner/planners.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -18,6 +18,7 @@
 from torch import nn
 from torchrec.distributed.collective_utils import invoke_on_rank_and_broadcast_result
 from torchrec.distributed.comm import get_local_size
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.distributed.planner.constants import BATCH_SIZE, MAX_SIZE
 from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
 from torchrec.distributed.planner.partitioners import (
@@ -498,6 +499,7 @@ def collective_plan(
             sharders,
         )

+    @_torchrec_method_logger()
     def plan(
         self,
         module: nn.Module,
```

torchrec/distributed/planner/shard_estimators.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -16,6 +16,7 @@
 import torchrec.optim as trec_optim
 from torch import nn
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.distributed.planner.constants import (
     BATCHED_COPY_PERF_FACTOR,
     BIGINT_DTYPE,
@@ -955,6 +956,7 @@ class EmbeddingStorageEstimator(ShardEstimator):
         is_inference (bool): If the model is inference model. Default to False.
     """

+    @_torchrec_method_logger()
     def __init__(
         self,
         topology: Topology,
```

torchrec/distributed/shard.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -15,6 +15,7 @@
 from torch.distributed._composable.contract import contract
 from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.global_settings import get_propogate_device
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.distributed.model_parallel import get_default_sharders
 from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
 from torchrec.distributed.sharding_plan import (
@@ -146,6 +147,7 @@ def _shard(

 # pyre-ignore
 @contract()
+@_torchrec_method_logger()
 def shard_modules(
     module: nn.Module,
     env: Optional[ShardingEnv] = None,
@@ -194,6 +196,7 @@ def init_weights(m):
     return _shard_modules(module, env, device, plan, sharders, init_params)


+@_torchrec_method_logger()
 def _shard_modules(  # noqa: C901
     module: nn.Module,
     # TODO: Consolidate to using Dict[str, ShardingEnv]
```
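
One detail worth noting in shard.py: the new logger sits below the existing `@contract()` decorator. Python applies stacked decorators bottom-up, so `_torchrec_method_logger()` wraps `shard_modules` directly and `contract()` receives the already-logged callable. A minimal, self-contained illustration of that ordering (the names `outer`, `inner`, and `f` are hypothetical):

```python
from typing import Any, Callable


def outer(func: Callable[..., Any]) -> Callable[..., Any]:
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        print("outer")  # runs first at call time
        return func(*args, **kwargs)

    return wrapper


def inner(func: Callable[..., Any]) -> Callable[..., Any]:
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        print("inner")  # wraps the raw function
        return func(*args, **kwargs)

    return wrapper


@outer  # applied second, outermost at call time
@inner  # applied first, closest to f
def f() -> None:
    print("f")


f()  # prints: outer, inner, f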

torchrec/distributed/test_utils/test_model_parallel_base.py

Lines changed: 44 additions & 4 deletions
```diff
@@ -248,12 +248,26 @@ def test_sharding_ebc_input_validation_enabled(self, mock_jk: Mock) -> None:
             values=torch.tensor([1, 2, 3, 4, 5]),
             lengths=torch.tensor([1, 2, 0, 2]),
             offsets=torch.tensor([0, 1, 3, 3, 5]),
-        )
+        ).to(self.device)
         mock_jk.return_value = True

         with self.assertRaisesRegex(ValueError, "keys must be unique"):
             model(kjt)
-        mock_jk.assert_called_once_with("pytorch/torchrec:enable_kjt_validation")
+
+        # Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
+        # This ignores any other calls to justknobs_check() with other inputs
+        # and protects the test from breaking when new JK checks are added.
+        validation_calls = [
+            call
+            for call in mock_jk.call_args_list
+            if len(call[0]) > 0
+            and call[0][0] == "pytorch/torchrec:enable_kjt_validation"
+        ]
+        self.assertEqual(
+            1,
+            len(validation_calls),
+            "There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation",
+        )

     @patch("torch._utils_internal.justknobs_check")
     def test_sharding_ebc_validate_input_only_once(self, mock_jk: Mock) -> None:
@@ -271,7 +285,20 @@ def test_sharding_ebc_validate_input_only_once(self, mock_jk: Mock) -> None:
         model(kjt)
         model(kjt)

-        mock_jk.assert_called_once_with("pytorch/torchrec:enable_kjt_validation")
+        # Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
+        # This ignores any other calls to justknobs_check() with other inputs
+        # and protects the test from breaking when new JK checks are added.
+        validation_calls = [
+            call
+            for call in mock_jk.call_args_list
+            if len(call[0]) > 0
+            and call[0][0] == "pytorch/torchrec:enable_kjt_validation"
+        ]
+        self.assertEqual(
+            1,
+            len(validation_calls),
+            "There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation",
+        )
         matched_logs = list(
             filter(lambda s: "Validating input features..." in s, logs.output)
         )
@@ -294,7 +321,20 @@ def test_sharding_ebc_input_validation_disabled(self, mock_jk: Mock) -> None:
         except ValueError:
             self.fail("Input validation should not be enabled.")

-        mock_jk.assert_called_once_with("pytorch/torchrec:enable_kjt_validation")
+        # Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
+        # This ignores any other calls to justknobs_check() with other inputs
+        # and protects the test from breaking when new JK checks are added.
+        validation_calls = [
+            call
+            for call in mock_jk.call_args_list
+            if len(call[0]) > 0
+            and call[0][0] == "pytorch/torchrec:enable_kjt_validation"
+        ]
+        self.assertEqual(
+            1,
+            len(validation_calls),
+            "There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation",
+        )

     def _create_sharded_model(
         self, embedding_dim: int = 128, num_embeddings: int = 256
```
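
The test change replaces `assert_called_once_with` with a filter over `mock.call_args_list`: the decorator added by this commit (and any future change) may trigger additional `justknobs_check` lookups, which would make a strict single-call assertion fail. A minimal standalone illustration of the pattern, where the knob name is the real one from the tests and the second call stands in for a hypothetical unrelated lookup:

```python
from unittest.mock import Mock

mock_jk = Mock(return_value=True)

# Calls as they might occur during a test run: the knob under test,
# plus an unrelated knob added by some future change.
mock_jk("pytorch/torchrec:enable_kjt_validation")
mock_jk("pytorch/torchrec:some_other_knob")

# assert_called_once_with(...) would now fail, since the mock was
# called twice overall. Filtering call_args_list scopes the assertion
# to the single knob this test cares about.
validation_calls = [
    call
    for call in mock_jk.call_args_list
    if len(call[0]) > 0
    and call[0][0] == "pytorch/torchrec:enable_kjt_validation"
]
assert len(validation_calls) == 1
```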

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -32,6 +32,7 @@
 import torch
 from torch.autograd.profiler import record_function
 from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.distributed.model_parallel import ShardedModule
 from torchrec.distributed.train_pipeline.pipeline_context import (
     EmbeddingTrainPipelineContext,
@@ -106,6 +107,8 @@ class TrainPipeline(abc.ABC, Generic[In, Out]):
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
         pass

+    # pyre-ignore [56]
+    @_torchrec_method_logger()
     def __init__(self) -> None:
         # pipeline state such as in foward, in backward etc, used in training recover scenarios
         self._state: PipelineState = PipelineState.IDLE
```

torchrec/modules/mc_embedding_modules.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -12,6 +12,7 @@

 import torch
 import torch.nn as nn
+from torchrec.distributed.logger import _torchrec_method_logger

 from torchrec.modules.embedding_modules import (
     EmbeddingBagCollection,
@@ -125,6 +126,7 @@ class ManagedCollisionEmbeddingCollection(BaseManagedCollisionEmbeddingCollection):

     """

+    @_torchrec_method_logger()
     def __init__(
         self,
         embedding_collection: EmbeddingCollection,
@@ -164,6 +166,7 @@ class ManagedCollisionEmbeddingBagCollection(BaseManagedCollisionEmbeddingBagCollection):

     """

+    @_torchrec_method_logger()
     def __init__(
         self,
         embedding_bag_collection: EmbeddingBagCollection,
```
