From c81959e2d7d2928ef850b21354dc4c9ba548b888 Mon Sep 17 00:00:00 2001 From: Jeff Kim Date: Tue, 10 Feb 2026 13:11:18 -0800 Subject: [PATCH] RecMetricModule: torch.cat all tensor lists before gloo all gathers (#3593) Summary: metric_module's get_pre_compute_states() provides an API to perform gloo all gathers instead of the default torchmetric.Metric's sync_dist (nccl). However, the mechanism calls gloo all gathers for each element in a list of tensors. This can be problematic because: - AUC's 3 state tensors hold a list of tensors, not a single tensor. - The size of the tensor list is theoretically unbounded. (In practice, it can grow to orders of 100K) - gloo all gathers are inherently much slower. Instead, this patch aims to: - apply the reduction function prior to the all gather if we're processing a tensor list - enforce that the reduction_fn does not rely on ordering Reviewed By: iamzainhuda Differential Revision: D88297404 --- torchrec/metrics/metric_module.py | 80 +-- torchrec/metrics/tests/test_metric_module.py | 496 ++++++++++++++++++- 2 files changed, 544 insertions(+), 32 deletions(-) diff --git a/torchrec/metrics/metric_module.py b/torchrec/metrics/metric_module.py index 2193c48c61..ecb941301c 100644 --- a/torchrec/metrics/metric_module.py +++ b/torchrec/metrics/metric_module.py @@ -406,27 +406,62 @@ def _get_metric_states( world_size: int, process_group: Union[dist.ProcessGroup, DeviceMesh], ) -> Dict[str, Dict[str, Union[torch.Tensor, List[torch.Tensor]]]]: + """ + Gather metric states from all ranks and apply reduction. + + For list states (e.g., AUC predictions/labels): + - Concatenate local tensors into single tensor for efficient transport + - Single all_gather instead of N all_gathers + - Apply reduction once on gathered tensors + + For tensor states (e.g., NE cross_entropy_sum): + - Stack gathered tensors + - Apply reduction + + This approach works with ANY reduction function (no associativity requirement). + """ result = defaultdict(dict) for task, computation in zip(metric._tasks, metric._metrics_computations): # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute # `items`. for state_name, reduction_fn in computation._reductions.items(): - tensor_or_list: Union[List[torch.Tensor], torch.Tensor] = getattr( - computation, state_name - ) - - if isinstance(tensor_or_list, list): - gathered = _all_gather_tensor_list( - tensor_or_list, world_size, process_group - ) - else: - gathered = torch.stack( - _all_gather_tensor(tensor_or_list, world_size, process_group) + with record_function(f"## RecMetricModule: {state_name} all gather ##"): + tensor_or_list: Union[List[torch.Tensor], torch.Tensor] = getattr( + computation, state_name ) - reduced = ( - reduction_fn(gathered) if reduction_fn is not None else gathered - ) - result[task.name][state_name] = reduced + + if isinstance(tensor_or_list, list): + if len(tensor_or_list) == 0: + result[task.name][state_name] = [] + continue + + # Concatenate local tensors into single tensor for transport + # This reduces N all_gathers to 1 all_gather + local_concat = torch.cat(tensor_or_list, dim=-1) + + # Single all_gather for the concatenated tensor + gathered_list = _all_gather_tensor( + local_concat, world_size, process_group + ) + + # Apply reduction once on the gathered tensors (one per rank) + if reduction_fn is not None: + reduced = reduction_fn(gathered_list) + else: + reduced = gathered_list + else: + # Single tensor case + gathered = torch.stack( + _all_gather_tensor( + tensor_or_list, world_size, process_group + ) + ) + if reduction_fn is not None: + reduced = reduction_fn(gathered) + else: + reduced = gathered + + result[task.name][state_name] = reduced return result @@ -475,7 +510,8 @@ def get_pre_compute_states( # throughput metric requires special handling, since it's not a RecMetric throughput_metric = self.throughput_metric if throughput_metric is not None: - aggregated_states[throughput_metric._namespace.value] = ( + # Merge in case there are rec metric namespaces that overlap with throughput metric namespace + aggregated_states.setdefault(throughput_metric._namespace.value, {}).update( self._get_throughput_metric_states(throughput_metric) ) @@ -662,15 +698,3 @@ def _all_gather_tensor( out = [torch.empty_like(tensor) for _ in range(world_size)] # pragma: no cover dist.all_gather(out, tensor, group=pg) return out - - -def _all_gather_tensor_list( - tensors: List[torch.Tensor], - world_size: int, - pg: Union[dist.ProcessGroup, DeviceMesh], -) -> List[torch.Tensor]: - """All-gather every tensor in a list and flatten the result.""" - gathered: List[torch.Tensor] = [] # pragma: no cover - for t in tensors: - gathered.extend(_all_gather_tensor(t, world_size, pg)) - return gathered diff --git a/torchrec/metrics/tests/test_metric_module.py b/torchrec/metrics/tests/test_metric_module.py index ea9b12c6ef..de3fa12b89 100644 --- a/torchrec/metrics/tests/test_metric_module.py +++ b/torchrec/metrics/tests/test_metric_module.py @@ -23,7 +23,7 @@ MultiProcessContext, MultiProcessTestBase, ) -from torchrec.metrics.auc import AUCMetric +from torchrec.metrics.auc import _state_reduction, AUCMetric from torchrec.metrics.metric_module import ( generate_metric_module, MetricsResult, @@ -43,7 +43,8 @@ ) from torchrec.metrics.model_utils import parse_task_model_outputs from torchrec.metrics.rec_metric import RecMetricException, RecMetricList, RecTaskInfo -from torchrec.metrics.test_utils import gen_test_batch +from torchrec.metrics.test_utils import gen_test_batch, gen_test_tasks +from torchrec.metrics.test_utils.mock_metrics import MockRecMetric from torchrec.metrics.throughput import ThroughputMetric from torchrec.test_utils import get_free_port, seed_and_log, skip_if_asan_class @@ -850,6 +851,68 @@ def test_async_compute_raises_exception(self) -> None: ): metric_module.async_compute() + def test_shutdown(self) -> None: + metric_module = generate_metric_module( + TestMetricModule, + metrics_config=DefaultMetricsConfig, + batch_size=128, + world_size=1, + my_rank=0, + state_metrics_mapping={}, + device=torch.device("cpu"), + ) + # shutdown() should not raise any exception + metric_module.shutdown() + + def test_local_compute(self) -> None: + metric_module = generate_metric_module( + TestMetricModule, + metrics_config=DefaultMetricsConfig, + batch_size=128, + world_size=1, + my_rank=0, + state_metrics_mapping={}, + device=torch.device("cpu"), + ) + metric_module.update(gen_test_batch(128)) + result = metric_module.local_compute() + self.assertIsInstance(result, dict) + + def test_get_required_inputs(self) -> None: + metric_module = generate_metric_module( + TestMetricModule, + metrics_config=DefaultMetricsConfig, + batch_size=128, + world_size=1, + my_rank=0, + state_metrics_mapping={}, + device=torch.device("cpu"), + ) + # get_required_inputs delegates to rec_metrics + result = metric_module.get_required_inputs() + # Result can be None or a list depending on metric configuration + self.assertTrue(result is None or isinstance(result, list)) + + def test_invalid_max_compute_interval(self) -> None: + with self.assertRaises(ValueError) as context: + RecMetricModule( + batch_size=128, + world_size=1, + min_compute_interval=5.0, + max_compute_interval=0.0, # Invalid: <= 0 when min is set + ) + self.assertIn("Max compute interval", str(context.exception)) + + def test_invalid_min_compute_interval(self) -> None: + with self.assertRaises(ValueError) as context: + RecMetricModule( + batch_size=128, + world_size=1, + min_compute_interval=-1.0, # Invalid: < 0 + max_compute_interval=30.0, + ) + self.assertIn("Min compute interval", str(context.exception)) + def test_load_state_dict_with_trained_batches_key(self) -> None: metric_module = generate_metric_module( TestMetricModule, @@ -1001,7 +1064,7 @@ def test_post_init_raises_when_rec_tasks_is_none(self) -> None: # Execute & Assert: should raise ValueError about rec_tasks being None with self.assertRaises(ValueError) as context: - config = MetricsConfig( + _ = MetricsConfig( rec_tasks=None, # pyre-ignore[6]: Intentionally passing None for testing rec_metrics={ RecMetricEnum.AUC: RecMetricDef(rec_task_indices=[0]), @@ -1021,7 +1084,7 @@ def test_post_init_raises_when_rec_task_index_out_of_range(self) -> None: # Execute & Assert: should raise ValueError about index out of range with self.assertRaises(ValueError) as context: - config = MetricsConfig( + _ = MetricsConfig( rec_tasks=rec_tasks, rec_metrics={ RecMetricEnum.NE: RecMetricDef( @@ -1074,3 +1137,428 @@ def test_metric_module_gather_state(self) -> None: batch_size=batch_size, config=metrics_config, ) + + +@skip_if_asan_class +class MetricModuleGlooDistributedTest(MultiProcessTestBase): + """ + Distributed tests using GLOO backend (works on CPU). + Tests _get_metric_states functionality with torch.cat optimization. + """ + + def setUp(self) -> None: + super().setUp() + self.device = torch.device("cpu") + + def test_get_metric_states_list_reduction(self) -> None: + """ + Test _get_metric_states with list states and concatenation reduction. + Validates the torch.cat optimization for AUC-like metrics. + """ + world_size = 2 + backend = "gloo" + + self._run_multi_process_test( + callable=_test_get_metric_states_with_list_reduction, + world_size=world_size, + backend=backend, + ) + + def test_get_metric_states_tensor_reduction(self) -> None: + """ + Test _get_metric_states with tensor states and sum reduction. + Validates standard reduction for NE-like metrics. + """ + world_size = 2 + backend = "gloo" + + self._run_multi_process_test( + callable=_test_get_metric_states_with_tensor_reduction, + world_size=world_size, + backend=backend, + ) + + def test_get_metric_states_single_tensor(self) -> None: + """ + Test _get_metric_states with a single tensor in the list. + Edge case validation. + """ + world_size = 2 + backend = "gloo" + + self._run_multi_process_test( + callable=_test_get_metric_states_with_single_tensor, + world_size=world_size, + backend=backend, + ) + + def test_get_metric_states_reduction_fn_none(self) -> None: + """ + Test _get_metric_states with reduction_fn=None. + Validates that no TypeError is raised when reduction_fn is None. + """ + world_size = 2 + backend = "gloo" + + self._run_multi_process_test( + callable=_test_get_metric_states_with_reduction_fn_none, + world_size=world_size, + backend=backend, + ) + + def test_get_metric_states_asymmetric_batches(self) -> None: + """ + Test _get_metric_states with different batch values across ranks. + + This validates that the torch.cat approach correctly aggregates + data when ranks have different tensor values (same batch count). + - Rank 0: 3 batch updates with values 1-6 + - Rank 1: 3 batch updates with values 7-12 + """ + world_size = 2 + backend = "gloo" + + self._run_multi_process_test( + callable=_test_get_metric_states_with_asymmetric_batches, + world_size=world_size, + backend=backend, + ) + + +def _test_get_metric_states_with_list_reduction( + rank: int, + world_size: int, + backend: str, +) -> None: + """Test _get_metric_states with list states and concatenation reduction (AUC-like).""" + with MultiProcessContext(rank, world_size, backend) as ctx: + # Create mock metric with list state using concatenation reduction + tasks = gen_test_tasks(["task1"]) + mock_metric = MockRecMetric( + world_size=world_size, + my_rank=rank, + batch_size=10, + tasks=tasks, + is_tensor_list=True, + reduction_fn=_state_reduction, + initial_states={"predictions": []}, + ) + + # Each rank appends different local tensors to simulate batch updates + # Rank 0: [[1, 2], [3, 4]] -> after local concat: [1, 2, 3, 4] + # Rank 1: [[5, 6], [7, 8]] -> after local concat: [5, 6, 7, 8] + # After global gather: [[1, 2, 3, 4], [5, 6, 7, 8]] -> reduction -> [[1, 2, 3, 4, 5, 6, 7, 8]] + if rank == 0: + mock_metric.append_to_computation_states( + {"predictions": torch.tensor([[1.0, 2.0]], device=ctx.device)} + ) + mock_metric.append_to_computation_states( + {"predictions": torch.tensor([[3.0, 4.0]], device=ctx.device)} + ) + else: # rank == 1 + mock_metric.append_to_computation_states( + {"predictions": torch.tensor([[5.0, 6.0]], device=ctx.device)} + ) + mock_metric.append_to_computation_states( + {"predictions": torch.tensor([[7.0, 8.0]], device=ctx.device)} + ) + + # Execute: Call _get_metric_states + metric_module = RecMetricModule( + batch_size=10, + world_size=world_size, + rec_tasks=tasks, + rec_metrics=RecMetricList([mock_metric]), + ) + + result = metric_module._get_metric_states( + metric=mock_metric, + world_size=world_size, + process_group=ctx.pg or dist.group.WORLD, + ) + + # Assert: Verify result matches expected + # With torch.cat approach: + # - Local concat: rank0 [1,2,3,4], rank1 [5,6,7,8] + # - All-gather produces [[1,2,3,4], [5,6,7,8]] + # - _state_reduction concatenates: [1,2,3,4,5,6,7,8] + expected = [ + torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]], device=ctx.device) + ] + + actual = result["task1"]["predictions"] + + assert len(actual) == len( + expected + ), f"Expected {len(expected)} tensors, got {len(actual)}" + torch.testing.assert_close( + actual[0], + expected[0], + msg="Mismatch in gathered predictions", + ) + + +def _test_get_metric_states_with_tensor_reduction( + rank: int, + world_size: int, + backend: str, +) -> None: + """Test _get_metric_states with tensor states and sum reduction (NE-like).""" + with MultiProcessContext(rank, world_size, backend) as ctx: + # Create mock metric with tensor state using sum reduction + tasks = gen_test_tasks(["task1"]) + initial_value = torch.tensor( + [float(rank + 1)], device=ctx.device + ) # Rank 0: [1.0], Rank 1: [2.0] + mock_metric = MockRecMetric( + world_size=world_size, + my_rank=rank, + batch_size=10, + tasks=tasks, + is_tensor_list=False, + reduction_fn="sum", + initial_states={"state1": initial_value}, + ) + + # Execute: Call _get_metric_states + metric_module = RecMetricModule( + batch_size=10, + world_size=world_size, + rec_tasks=tasks, + rec_metrics=RecMetricList([mock_metric]), + ) + + result = metric_module._get_metric_states( + metric=mock_metric, + world_size=world_size, + process_group=ctx.pg or dist.group.WORLD, + ) + + # Assert: Verify result matches expected + # Expected: sum([rank0_value, rank1_value]) = sum([1.0, 2.0]) = 3.0 + expected = torch.tensor([3.0], device=ctx.device) + + actual = result["task1"]["state1"] + + torch.testing.assert_close( + actual, + expected, + msg="Mismatch in summed state", + ) + + +def _test_get_metric_states_with_single_tensor( + rank: int, + world_size: int, + backend: str, +) -> None: + """Test _get_metric_states with a single tensor in the list (edge case).""" + with MultiProcessContext(rank, world_size, backend) as ctx: + + # Create mock metric with list state containing a single tensor + tasks = gen_test_tasks(["task1"]) + mock_metric = MockRecMetric( + world_size=world_size, + my_rank=rank, + batch_size=10, + tasks=tasks, + is_tensor_list=True, + reduction_fn=_state_reduction, + initial_states={"predictions": []}, + ) + + # Each rank has a single tensor + # Rank 0: [[1, 2]] + # Rank 1: [[3, 4]] + # After local concat (no-op since single tensor): [1, 2] and [3, 4] + # After global gather+concat: [[1, 2, 3, 4]] + if rank == 0: + mock_metric.append_to_computation_states( + {"predictions": torch.tensor([[1.0, 2.0]], device=ctx.device)} + ) + else: # rank == 1 + mock_metric.append_to_computation_states( + {"predictions": torch.tensor([[3.0, 4.0]], device=ctx.device)} + ) + + # Execute: Call _get_metric_states + metric_module = RecMetricModule( + batch_size=10, + world_size=world_size, + rec_tasks=tasks, + rec_metrics=RecMetricList([mock_metric]), + ) + + result = metric_module._get_metric_states( + metric=mock_metric, + world_size=world_size, + process_group=ctx.pg or dist.group.WORLD, + ) + + # Assert: Verify result matches expected + # Expected: [[1, 2, 3, 4]] + expected = [torch.tensor([[1.0, 2.0, 3.0, 4.0]], device=ctx.device)] + + actual = result["task1"]["predictions"] + + assert len(actual) == len( + expected + ), f"Expected {len(expected)} tensors, got {len(actual)}" + torch.testing.assert_close( + actual[0], + expected[0], + msg="Mismatch in gathered predictions for single tensor case", + ) + + +def _test_get_metric_states_with_reduction_fn_none( + rank: int, + world_size: int, + backend: str, +) -> None: + """Test _get_metric_states with reduction_fn=None (no reduction applied).""" + with MultiProcessContext(rank, world_size, backend) as ctx: + # Create mock metric with list state and reduction_fn=None + tasks = gen_test_tasks(["task1"]) + mock_metric = MockRecMetric( + world_size=world_size, + my_rank=rank, + batch_size=10, + tasks=tasks, + is_tensor_list=True, + reduction_fn=None, # No reduction + initial_states={"predictions": []}, + ) + + # Each rank has tensors + if rank == 0: + mock_metric.append_to_computation_states( + {"predictions": torch.tensor([[1.0, 2.0]], device=ctx.device)} + ) + else: # rank == 1 + mock_metric.append_to_computation_states( + {"predictions": torch.tensor([[3.0, 4.0]], device=ctx.device)} + ) + + # Execute: Call _get_metric_states - should NOT raise TypeError + metric_module = RecMetricModule( + batch_size=10, + world_size=world_size, + rec_tasks=tasks, + rec_metrics=RecMetricList([mock_metric]), + ) + + result = metric_module._get_metric_states( + metric=mock_metric, + world_size=world_size, + process_group=ctx.pg or dist.group.WORLD, + ) + + # Assert: With reduction_fn=None, gathered_list is returned as-is + # After torch.cat locally: rank0 [1,2], rank1 [3,4] + # After all-gather: [[1,2], [3,4]] (list of 2 tensors) + actual = result["task1"]["predictions"] + + # Should be a list of 2 tensors (one per rank) + assert isinstance(actual, list), f"Expected list, got {type(actual)}" + assert ( + len(actual) == world_size + ), f"Expected {world_size} tensors, got {len(actual)}" + + # Verify the gathered tensors + expected_rank0 = torch.tensor([[1.0, 2.0]], device=ctx.device) + expected_rank1 = torch.tensor([[3.0, 4.0]], device=ctx.device) + + torch.testing.assert_close( + actual[0], + expected_rank0, + msg="Mismatch in rank 0 gathered tensor", + ) + torch.testing.assert_close( + actual[1], + expected_rank1, + msg="Mismatch in rank 1 gathered tensor", + ) + + +def _test_get_metric_states_with_asymmetric_batches( + rank: int, + world_size: int, + backend: str, +) -> None: + """ + Test _get_metric_states with different batch values across ranks. + + This validates that the torch.cat approach correctly aggregates data + when ranks have different tensor values (same batch count). + """ + with MultiProcessContext(rank, world_size, backend) as ctx: + tasks = gen_test_tasks(["task1"]) + mock_metric = MockRecMetric( + world_size=world_size, + my_rank=rank, + batch_size=10, + tasks=tasks, + is_tensor_list=True, + reduction_fn=_state_reduction, + initial_states={"predictions": []}, + ) + + # Same batch count per rank (3 batches each), but different values: + # Rank 0: [[1,2], [3,4], [5,6]] -> local concat -> [1,2,3,4,5,6] + # Rank 1: [[7,8], [9,10], [11,12]] -> local concat -> [7,8,9,10,11,12] + # After all_gather: [[1,2,3,4,5,6], [7,8,9,10,11,12]] + # After reduction (_state_reduction = concat): [1,2,3,4,5,6,7,8,9,10,11,12] + if rank == 0: + mock_metric.append_to_computation_states( + {"predictions": torch.tensor([[1.0, 2.0]], device=ctx.device)} + ) + mock_metric.append_to_computation_states( + {"predictions": torch.tensor([[3.0, 4.0]], device=ctx.device)} + ) + mock_metric.append_to_computation_states( + {"predictions": torch.tensor([[5.0, 6.0]], device=ctx.device)} + ) + else: # rank == 1 + mock_metric.append_to_computation_states( + {"predictions": torch.tensor([[7.0, 8.0]], device=ctx.device)} + ) + mock_metric.append_to_computation_states( + {"predictions": torch.tensor([[9.0, 10.0]], device=ctx.device)} + ) + mock_metric.append_to_computation_states( + {"predictions": torch.tensor([[11.0, 12.0]], device=ctx.device)} + ) + + metric_module = RecMetricModule( + batch_size=10, + world_size=world_size, + rec_tasks=tasks, + rec_metrics=RecMetricList([mock_metric]), + ) + + result = metric_module._get_metric_states( + metric=mock_metric, + world_size=world_size, + process_group=ctx.pg or dist.group.WORLD, + ) + + # Expected: All values concatenated in rank order + # [1,2,3,4,5,6,7,8,9,10,11,12] + expected = [ + torch.tensor( + [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]], + device=ctx.device, + ) + ] + + actual = result["task1"]["predictions"] + + assert len(actual) == len( + expected + ), f"Expected {len(expected)} tensors, got {len(actual)}" + torch.testing.assert_close( + actual[0], + expected[0], + msg="Mismatch in gathered predictions with multiple batches per rank", + )