diff --git a/torchrec/metrics/metrics_namespace.py b/torchrec/metrics/metrics_namespace.py index 9dea96b0d..24f2f1780 100644 --- a/torchrec/metrics/metrics_namespace.py +++ b/torchrec/metrics/metrics_namespace.py @@ -93,6 +93,10 @@ class MetricName(MetricNameBase): NMSE = "nmse" NRMSE = "nrmse" + NUM_POSITIVE_SAMPLES = "num_positive_samples" + SUM_WEIGHTS = "sum_weights" + NUM_MISSING_LABELS = "num_missing_labels" + class MetricNamespaceBase(StrValueMixin, Enum): pass @@ -153,6 +157,10 @@ class MetricNamespace(MetricNamespaceBase): NMSE = "nmse" + NUM_POSITIVE_SAMPLES = "num_positive_samples" + SUM_WEIGHTS = "sum_weights" + NUM_MISSING_LABELS = "num_missing_labels" + class MetricPrefix(StrValueMixin, Enum): DEFAULT = "" diff --git a/torchrec/metrics/num_missing_labels.py b/torchrec/metrics/num_missing_labels.py new file mode 100644 index 000000000..6c921bc6a --- /dev/null +++ b/torchrec/metrics/num_missing_labels.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, cast, Dict, List, Optional, Type + +import torch +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, +) + + +def compute_missing_label_sum( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, +) -> torch.Tensor: + return torch.sum(torch.where(torch.isnan(labels), weights, 0), dim=-1) + + +def get_num_missing_labels_states( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: Optional[torch.Tensor], +) -> Dict[str, torch.Tensor]: + if weights is None: + weights = torch.ones_like(labels) + return { + "missing_label_sum": compute_missing_label_sum(labels, predictions, weights), + } + + +class NumMissingLabelsMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for weighted number of missing labels. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. 
+ """ + + def __init__(self, *args: Any, threshold: float = 0.5, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._add_state( + "missing_label_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + states = get_num_missing_labels_states(labels, predictions, weights) + num_samples = predictions.shape[-1] + + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + def _compute(self) -> List[MetricComputationReport]: + reports = [ + MetricComputationReport( + name=MetricName.NUM_MISSING_LABELS, + metric_prefix=MetricPrefix.LIFETIME, + value=cast(torch.Tensor, self.missing_label_sum), + ), + MetricComputationReport( + name=MetricName.NUM_MISSING_LABELS, + metric_prefix=MetricPrefix.WINDOW, + value=self.get_window_state("missing_label_sum"), + ), + ] + return reports + + +class NumMissingLabelsMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.NUM_MISSING_LABELS + _computation_class: Type[RecMetricComputation] = NumMissingLabelsMetricComputation diff --git a/torchrec/metrics/num_positive_samples.py b/torchrec/metrics/num_positive_samples.py new file mode 100644 index 000000000..cce48d637 --- /dev/null +++ b/torchrec/metrics/num_positive_samples.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, cast, Dict, List, Optional, Type + +import torch +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, +) + + +def compute_weighted_pos_sum( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, +) -> torch.Tensor: + return torch.sum(weights * torch.nan_to_num(labels, 0), dim=-1) + + +def get_num_positive_sample_states( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: Optional[torch.Tensor], +) -> Dict[str, torch.Tensor]: + if weights is None: + weights = torch.ones_like(labels) + return { + "weighted_pos_sum": compute_weighted_pos_sum(labels, predictions, weights), + } + + +class NumPositiveSamplesMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for weighted number of positives. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. 
+ """ + + def __init__(self, *args: Any, threshold: float = 0.5, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._add_state( + "weighted_pos_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + states = get_num_positive_sample_states(labels, predictions, weights) + num_samples = predictions.shape[-1] + + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + def _compute(self) -> List[MetricComputationReport]: + reports = [ + MetricComputationReport( + name=MetricName.NUM_POSITIVE_SAMPLES, + metric_prefix=MetricPrefix.LIFETIME, + value=cast(torch.Tensor, self.weighted_pos_sum), + ), + MetricComputationReport( + name=MetricName.NUM_POSITIVE_SAMPLES, + metric_prefix=MetricPrefix.WINDOW, + value=self.get_window_state("weighted_pos_sum"), + ), + ] + return reports + + +class NumPositiveSamplesMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.NUM_POSITIVE_SAMPLES + _computation_class: Type[RecMetricComputation] = NumPositiveSamplesMetricComputation diff --git a/torchrec/metrics/sum_weights.py b/torchrec/metrics/sum_weights.py new file mode 100644 index 000000000..a8ee89d08 --- /dev/null +++ b/torchrec/metrics/sum_weights.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, cast, Dict, List, Optional, Type + +import torch +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, +) + + +def compute_weighted_sum( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, +) -> torch.Tensor: + return torch.sum(weights, dim=-1) + + +def get_weighted_sum_states( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: Optional[torch.Tensor], +) -> Dict[str, torch.Tensor]: + if weights is None: + weights = torch.ones_like(labels) + return { + "weighted_sum": compute_weighted_sum(labels, predictions, weights), + } + + +class SumWeightsMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for sum of weights. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. 
+ """ + + def __init__(self, *args: Any, threshold: float = 0.5, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._add_state( + "weighted_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + states = get_weighted_sum_states(labels, predictions, weights) + num_samples = predictions.shape[-1] + + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + def _compute(self) -> List[MetricComputationReport]: + reports = [ + MetricComputationReport( + name=MetricName.SUM_WEIGHTS, + metric_prefix=MetricPrefix.LIFETIME, + value=cast(torch.Tensor, self.weighted_sum), + ), + MetricComputationReport( + name=MetricName.SUM_WEIGHTS, + metric_prefix=MetricPrefix.WINDOW, + value=self.get_window_state("weighted_sum"), + ), + ] + return reports + + +class SumWeightsMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.SUM_WEIGHTS + _computation_class: Type[RecMetricComputation] = SumWeightsMetricComputation diff --git a/torchrec/metrics/tests/test_num_missing_labels.py b/torchrec/metrics/tests/test_num_missing_labels.py new file mode 100644 index 000000000..952969409 --- /dev/null +++ b/torchrec/metrics/tests/test_num_missing_labels.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+# pyre-strict
+
+import unittest
+from typing import Dict, Iterable, Optional, Type, Union
+
+import torch
+from torch import no_grad
+from torchrec.metrics.metrics_config import DefaultTaskInfo
+from torchrec.metrics.num_missing_labels import (
+    compute_missing_label_sum,
+    NumMissingLabelsMetric,
+)
+from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
+from torchrec.metrics.test_utils import (
+    metric_test_helper,
+    rec_metric_gpu_sync_test_launcher,
+    rec_metric_value_test_launcher,
+    RecTaskInfo,
+    sync_test_helper,
+    TestMetric,
+)
+
+
+WORLD_SIZE = 4
+
+
+class TestNumMissingLabelsMetric(TestMetric):
+    @staticmethod
+    def _get_states(
+        labels: torch.Tensor,
+        predictions: torch.Tensor,
+        weights: torch.Tensor,
+        required_inputs_tensor: Optional[torch.Tensor] = None,
+    ) -> Dict[str, torch.Tensor]:
+        missing_label_sum = torch.sum(
+            torch.where(torch.isnan(labels), weights, 0), dim=-1
+        )
+        return {
+            "missing_label_sum": missing_label_sum,
+        }
+
+    @staticmethod
+    def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor:
+        return states["missing_label_sum"]
+
+
+class NumMissingLabelsMetricTest(unittest.TestCase):
+    target_clazz: Type[RecMetric] = NumMissingLabelsMetric
+    task_name: str = "num_missing_labels"
+
+    def test_num_missing_labels_unfused(self) -> None:
+        rec_metric_value_test_launcher(
+            target_clazz=NumMissingLabelsMetric,
+            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
+            test_clazz=TestNumMissingLabelsMetric,
+            metric_name=NumMissingLabelsMetricTest.task_name,
+            task_names=["t1", "t2", "t3"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=WORLD_SIZE,
+            entry_point=metric_test_helper,
+        )
+
+    def test_num_missing_labels_fused_tasks(self) -> None:
+        rec_metric_value_test_launcher(
+            target_clazz=NumMissingLabelsMetric,
+            target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+            test_clazz=TestNumMissingLabelsMetric,
+            metric_name=NumMissingLabelsMetricTest.task_name,
+            task_names=["t1", "t2", "t3"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=WORLD_SIZE,
+            entry_point=metric_test_helper,
+        )
+
+    def test_num_missing_labels_fused_tasks_and_states(self) -> None:
+        rec_metric_value_test_launcher(
+            target_clazz=NumMissingLabelsMetric,
+            target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
+            test_clazz=TestNumMissingLabelsMetric,
+            metric_name=NumMissingLabelsMetricTest.task_name,
+            task_names=["t1", "t2", "t3"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=WORLD_SIZE,
+            entry_point=metric_test_helper,
+        )
+
+
+class NumMissingLabelsGPUSyncTest(unittest.TestCase):
+    clazz: Type[RecMetric] = NumMissingLabelsMetric
+    task_name: str = "num_missing_labels"
+
+    def test_sync_num_missing_labels(self) -> None:
+        rec_metric_gpu_sync_test_launcher(
+            target_clazz=NumMissingLabelsMetric,
+            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
+            test_clazz=TestNumMissingLabelsMetric,
+            metric_name=NumMissingLabelsGPUSyncTest.task_name,
+            task_names=["t1"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=2,
+            batch_size=5,
+            batch_window_size=20,
+            entry_point=sync_test_helper,
+        )
+
+
+def generate_model_outputs_cases() -> Iterable[Dict[str, Optional[torch.Tensor]]]:
+    return [
+        # random_inputs
+        {
+            "labels": torch.tensor([[1, 0, 0, 1, 1]]),
+            "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]),
+            "weights": torch.tensor([[0.1, 0.1, 0.1, 0.1, 0.6]]),
+            "expected_num_missing_labels": torch.tensor([0]),
+        },
+        # no weight
+        {
+            "labels": torch.tensor([[1, 0, 1, 0, 1, 0]]),
+            "predictions": torch.tensor([[0.5] * 6]),
+            "weights": None,
+            "expected_num_missing_labels": torch.tensor([0]),
+        },
+        # weights are 0.5
+        {
+            "labels": torch.tensor([[1, 0, 1, 0, 1, 0]]),
+            "predictions": torch.tensor([[0.5] * 6]),
+            "weights": torch.tensor([[0.5] * 6]),
+            "expected_num_missing_labels": torch.tensor([0]),
+        },
+        # all weights are zero
+        {
+            "labels": torch.tensor([[1, 1, 1, 1, 1]]),
+            "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]),
+            "weights": torch.tensor([[0] * 5]),
+            "expected_num_missing_labels": torch.tensor([0]),
+        },
+        # Missing labels
+        {
+            "labels": torch.tensor([[float("nan"), 1, float("nan"), 1, 0]]),
+            "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]),
+            "weights": torch.tensor([[1] * 5]),
+            "expected_num_missing_labels": torch.tensor([2]),
+        },
+        # Multi tasks
+        {
+            "labels": torch.tensor([[[1, 1, 1, 1, 1]], [[1, 0, 0, 1, 1]]]),
+            "predictions": torch.tensor(
+                [[0.2, 0.6, 0.8, 0.4, 0.9], [0.2, 0.6, 0.8, 0.4, 0.9]]
+            ),
+            "weights": torch.tensor([[1] * 5, [2] * 5]),
+            "expected_num_missing_labels": torch.tensor([0, 0]),
+        },
+        # Missing labels with different weights
+        {
+            "labels": torch.tensor([[float("nan"), 1, float("nan"), 0, float("nan")]]),
+            "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]),
+            "weights": torch.tensor([[0.2, 0.2, 0.8, 1, 1]]),
+            "expected_num_missing_labels": torch.tensor([2]),
+        },
+    ]
+
+
+class NumMissingLabelsTest(unittest.TestCase):
+    r"""This set of tests verifies the computation logic of num missing labels in
+    several corner cases where the expected results are known. The goal is to
+    provide some confidence in the correctness of the math formula.
+ """ + + @torch.no_grad() + def _test_num_missing_labels_helper( + self, + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + expected_num_missing_labels: torch.Tensor, + ) -> None: + num_task = labels.shape[0] + batch_size = labels.shape[0] + task_list = [] + inputs: Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor]] = { + "predictions": {}, + "labels": {}, + "weights": {}, + } + for i in range(num_task): + task_info = RecTaskInfo( + name=f"Task:{i}", + label_name="label", + prediction_name="prediction", + weight_name="weight", + ) + task_list.append(task_info) + # pyre-ignore + inputs["predictions"][task_info.name] = predictions[i] + # pyre-ignore + inputs["labels"][task_info.name] = labels[i] + if weights is None: + # pyre-ignore + inputs["weights"] = None + else: + # pyre-ignore + inputs["weights"][task_info.name] = weights[i] + + num_missing_labels = NumMissingLabelsMetric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=batch_size, + tasks=task_list, + ) + num_missing_labels.update(**inputs) + actual_num_missing_labels = num_missing_labels.compute() + + for task_id, task in enumerate(task_list): + cur_actual_num_missing_labels = actual_num_missing_labels[ + f"num_missing_labels-{task.name}|window_num_missing_labels" + ][0] + cur_expected_num_missing_labels = expected_num_missing_labels[task_id] + if cur_expected_num_missing_labels.isnan().any(): + self.assertTrue(cur_actual_num_missing_labels.isnan().any()) + else: + torch.testing.assert_close( + cur_actual_num_missing_labels, + cur_expected_num_missing_labels, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + msg=f"Actual: {cur_actual_num_missing_labels}, Expected: {cur_expected_num_missing_labels}", + ) + + def test_num_missing_labels(self) -> None: + test_data = generate_model_outputs_cases() + for inputs in test_data: + try: + # pyre-ignore + self._test_num_missing_labels_helper(**inputs) + except AssertionError: + print("Assertion error caught with data set ", inputs) + raise diff --git a/torchrec/metrics/tests/test_num_positive_samples.py b/torchrec/metrics/tests/test_num_positive_samples.py new file mode 100644 index 000000000..774a92846 --- /dev/null +++ b/torchrec/metrics/tests/test_num_positive_samples.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+# pyre-strict
+
+import unittest
+from typing import Dict, Iterable, Optional, Type, Union
+
+import torch
+from torch import no_grad
+from torchrec.metrics.metrics_config import DefaultTaskInfo
+from torchrec.metrics.num_positive_samples import (
+    compute_weighted_pos_sum,
+    NumPositiveSamplesMetric,
+)
+from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
+from torchrec.metrics.test_utils import (
+    metric_test_helper,
+    rec_metric_gpu_sync_test_launcher,
+    rec_metric_value_test_launcher,
+    RecTaskInfo,
+    sync_test_helper,
+    TestMetric,
+)
+
+
+WORLD_SIZE = 4
+
+
+class TestNumPositiveSamplesMetric(TestMetric):
+    @staticmethod
+    def _get_states(
+        labels: torch.Tensor,
+        predictions: torch.Tensor,
+        weights: torch.Tensor,
+        required_inputs_tensor: Optional[torch.Tensor] = None,
+    ) -> Dict[str, torch.Tensor]:
+        weighted_pos_sum = torch.sum(weights * torch.nan_to_num(labels, 0), dim=-1)
+        return {
+            "weighted_pos_sum": weighted_pos_sum,
+        }
+
+    @staticmethod
+    def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor:
+        return states["weighted_pos_sum"]
+
+
+class NumPositiveSamplesMetricTest(unittest.TestCase):
+    target_clazz: Type[RecMetric] = NumPositiveSamplesMetric
+    task_name: str = "num_positive_samples"
+
+    def test_num_positive_samples_unfused(self) -> None:
+        rec_metric_value_test_launcher(
+            target_clazz=NumPositiveSamplesMetric,
+            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
+            test_clazz=TestNumPositiveSamplesMetric,
+            metric_name=NumPositiveSamplesMetricTest.task_name,
+            task_names=["t1", "t2", "t3"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=WORLD_SIZE,
+            entry_point=metric_test_helper,
+        )
+
+    def test_num_positive_samples_fused_tasks(self) -> None:
+        rec_metric_value_test_launcher(
+            target_clazz=NumPositiveSamplesMetric,
+            target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+            test_clazz=TestNumPositiveSamplesMetric,
+            metric_name=NumPositiveSamplesMetricTest.task_name,
+            task_names=["t1", "t2", "t3"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=WORLD_SIZE,
+            entry_point=metric_test_helper,
+        )
+
+    def test_num_positive_samples_fused_tasks_and_states(self) -> None:
+        rec_metric_value_test_launcher(
+            target_clazz=NumPositiveSamplesMetric,
+            target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
+            test_clazz=TestNumPositiveSamplesMetric,
+            metric_name=NumPositiveSamplesMetricTest.task_name,
+            task_names=["t1", "t2", "t3"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=WORLD_SIZE,
+            entry_point=metric_test_helper,
+        )
+
+
+class NumPositiveSamplesGPUSyncTest(unittest.TestCase):
+    clazz: Type[RecMetric] = NumPositiveSamplesMetric
+    task_name: str = "num_positive_samples"
+
+    def test_sync_num_positive_samples(self) -> None:
+        rec_metric_gpu_sync_test_launcher(
+            target_clazz=NumPositiveSamplesMetric,
+            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
+            test_clazz=TestNumPositiveSamplesMetric,
+            metric_name=NumPositiveSamplesGPUSyncTest.task_name,
+            task_names=["t1"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=2,
+            batch_size=5,
+            batch_window_size=20,
+            entry_point=sync_test_helper,
+        )
+
+
+def generate_model_outputs_cases() -> Iterable[Dict[str, Optional[torch.Tensor]]]:
+    return [
+        # random_inputs
+        {
+            "labels": torch.tensor([[1, 0, 0, 1, 1]]),
+            "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]),
+            "weights": torch.tensor([[0.1, 0.1, 0.1, 0.1, 0.6]]),
"weights": torch.tensor([[0.1, 0.1, 0.1, 0.1, 0.6]]), + "expected_num_positive_samples": torch.tensor([0.8]), + }, + # no weight + { + "labels": torch.tensor([[1, 0, 1, 0, 1, 0]]), + "predictions": torch.tensor([[0.5] * 6]), + "weights": None, + "expected_num_positive_samples": torch.tensor([3]), + }, + # weights are 0.5 + { + "labels": torch.tensor([[1, 0, 1, 0, 1, 0]]), + "predictions": torch.tensor([[0.5] * 6]), + "weights": torch.tensor([[0.5] * 6]), + "expected_num_positive_samples": torch.tensor([1.5]), + }, + # all weights are zero + { + "labels": torch.tensor([[1, 1, 1, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0] * 5]), + "expected_num_positive_samples": torch.tensor([0]), + }, + # Missing labels + { + "labels": torch.tensor([[float("nan"), 1, float("nan"), 1, 0]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[1] * 5]), + "expected_num_positive_samples": torch.tensor([2]), + }, + # Multi tasks + { + "labels": torch.tensor([[[1, 1, 1, 1, 1]], [[1, 0, 0, 1, 1]]]), + "predictions": torch.tensor( + [[0.2, 0.6, 0.8, 0.4, 0.9], [0.2, 0.6, 0.8, 0.4, 0.9]] + ), + "weights": torch.tensor([[1] * 5, [2] * 5]), + "expected_num_positive_samples": torch.tensor([5, 6]), + }, + ] + + +class NumPositiveSamplesTest(unittest.TestCase): + r"""This set of tests verify the computation logic of num positive samples in several + corner cases that we know the computation results. The goal is to + provide some confidence of the correctness of the math formula. + """ + + @torch.no_grad() + def _test_num_positive_samples_helper( + self, + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + expected_num_positive_samples: torch.Tensor, + ) -> None: + num_task = labels.shape[0] + batch_size = labels.shape[0] + task_list = [] + inputs: Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor]] = { + "predictions": {}, + "labels": {}, + "weights": {}, + } + for i in range(num_task): + task_info = RecTaskInfo( + name=f"Task:{i}", + label_name="label", + prediction_name="prediction", + weight_name="weight", + ) + task_list.append(task_info) + # pyre-ignore + inputs["predictions"][task_info.name] = predictions[i] + # pyre-ignore + inputs["labels"][task_info.name] = labels[i] + if weights is None: + # pyre-ignore + inputs["weights"] = None + else: + # pyre-ignore + inputs["weights"][task_info.name] = weights[i] + + num_positive_samples = NumPositiveSamplesMetric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=batch_size, + tasks=task_list, + ) + num_positive_samples.update(**inputs) + actual_num_positive_samples = num_positive_samples.compute() + + for task_id, task in enumerate(task_list): + cur_actual_num_positive_samples = actual_num_positive_samples[ + f"num_positive_samples-{task.name}|window_num_positive_samples" + ][0] + cur_expected_num_positive_samples = expected_num_positive_samples[task_id] + if cur_expected_num_positive_samples.isnan().any(): + self.assertTrue(cur_actual_num_positive_samples.isnan().any()) + else: + torch.testing.assert_close( + cur_actual_num_positive_samples, + cur_expected_num_positive_samples, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + msg=f"Actual: {cur_actual_num_positive_samples}, Expected: {cur_expected_num_positive_samples}", + ) + + def test_num_positive_samples(self) -> None: + test_data = generate_model_outputs_cases() + for inputs in test_data: + try: + # pyre-ignore + self._test_num_positive_samples_helper(**inputs) + except AssertionError: + 
+                print("Assertion error caught with data set ", inputs)
+                raise
diff --git a/torchrec/metrics/tests/test_sum_weights.py b/torchrec/metrics/tests/test_sum_weights.py
new file mode 100644
index 000000000..352bd6bb2
--- /dev/null
+++ b/torchrec/metrics/tests/test_sum_weights.py
@@ -0,0 +1,234 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import unittest
+from typing import Dict, Iterable, Optional, Type, Union
+
+import torch
+from torch import no_grad
+from torchrec.metrics.metrics_config import DefaultTaskInfo
+from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
+from torchrec.metrics.sum_weights import compute_weighted_sum, SumWeightsMetric
+from torchrec.metrics.test_utils import (
+    metric_test_helper,
+    rec_metric_gpu_sync_test_launcher,
+    rec_metric_value_test_launcher,
+    RecTaskInfo,
+    sync_test_helper,
+    TestMetric,
+)
+
+
+WORLD_SIZE = 4
+
+
+class TestSumWeightsMetric(TestMetric):
+    @staticmethod
+    def _get_states(
+        labels: torch.Tensor,
+        predictions: torch.Tensor,
+        weights: torch.Tensor,
+        required_inputs_tensor: Optional[torch.Tensor] = None,
+    ) -> Dict[str, torch.Tensor]:
+        weighted_sum = torch.sum(weights, dim=-1)
+        return {
+            "weighted_sum": weighted_sum,
+        }
+
+    @staticmethod
+    def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor:
+        return states["weighted_sum"]
+
+
+class SumWeightsMetricTest(unittest.TestCase):
+    target_clazz: Type[RecMetric] = SumWeightsMetric
+    task_name: str = "sum_weights"
+
+    def test_sum_weights_unfused(self) -> None:
+        rec_metric_value_test_launcher(
+            target_clazz=SumWeightsMetric,
+            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
+            test_clazz=TestSumWeightsMetric,
+            metric_name=SumWeightsMetricTest.task_name,
+            task_names=["t1", "t2", "t3"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=WORLD_SIZE,
+            entry_point=metric_test_helper,
+        )
+
+    def test_sum_weights_fused_tasks(self) -> None:
+        rec_metric_value_test_launcher(
+            target_clazz=SumWeightsMetric,
+            target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+            test_clazz=TestSumWeightsMetric,
+            metric_name=SumWeightsMetricTest.task_name,
+            task_names=["t1", "t2", "t3"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=WORLD_SIZE,
+            entry_point=metric_test_helper,
+        )
+
+    def test_sum_weights_fused_tasks_and_states(self) -> None:
+        rec_metric_value_test_launcher(
+            target_clazz=SumWeightsMetric,
+            target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
+            test_clazz=TestSumWeightsMetric,
+            metric_name=SumWeightsMetricTest.task_name,
+            task_names=["t1", "t2", "t3"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=WORLD_SIZE,
+            entry_point=metric_test_helper,
+        )
+
+
+class SumWeightsGPUSyncTest(unittest.TestCase):
+    clazz: Type[RecMetric] = SumWeightsMetric
+    task_name: str = "sum_weights"
+
+    def test_sync_sum_weights(self) -> None:
+        rec_metric_gpu_sync_test_launcher(
+            target_clazz=SumWeightsMetric,
+            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
+            test_clazz=TestSumWeightsMetric,
+            metric_name=SumWeightsGPUSyncTest.task_name,
+            task_names=["t1"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=2,
+            batch_size=5,
+            batch_window_size=20,
+            entry_point=sync_test_helper,
+        )
+
+
+def generate_model_outputs_cases() -> Iterable[Dict[str, Optional[torch.Tensor]]]:
+    return [
+        # random_inputs
+        {
+            "labels": torch.tensor([[1, 0, 0, 1, 1]]),
+            "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]),
+            "weights": torch.tensor([[0.1, 0.1, 0.1, 0.1, 0.6]]),
+            "expected_sum_weights": torch.tensor([1.0]),
+        },
+        # no weight
+        {
+            "labels": torch.tensor([[1, 0, 1, 0, 1, 0]]),
+            "predictions": torch.tensor([[0.5] * 6]),
+            "weights": None,
+            "expected_sum_weights": torch.tensor([6]),
+        },
+        # weights are 0.5
+        {
+            "labels": torch.tensor([[1, 0, 1, 0, 1, 0]]),
+            "predictions": torch.tensor([[0.5] * 6]),
+            "weights": torch.tensor([[0.5] * 6]),
+            "expected_sum_weights": torch.tensor([3]),
+        },
+        # all weights are zero
+        {
+            "labels": torch.tensor([[1, 1, 1, 1, 1]]),
+            "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]),
+            "weights": torch.tensor([[0] * 5]),
+            "expected_sum_weights": torch.tensor([0]),
+        },
+        # Multi tasks
+        {
+            "labels": torch.tensor([[[1, 1, 1, 1, 1]], [[1, 0, 0, 1, 1]]]),
+            "predictions": torch.tensor(
+                [[0.2, 0.6, 0.8, 0.4, 0.9], [0.2, 0.6, 0.8, 0.4, 0.9]]
+            ),
+            "weights": torch.tensor([[1] * 5, [2] * 5]),
+            "expected_sum_weights": torch.tensor([5, 10]),
+        },
+    ]
+
+
+class SumWeightsTest(unittest.TestCase):
+    r"""This set of tests verifies the computation logic of sum of weights in
+    several corner cases where the expected results are known. The goal is to
+    provide some confidence in the correctness of the math formula.
+    """
+
+    @torch.no_grad()
+    def _test_sum_weights_helper(
+        self,
+        labels: torch.Tensor,
+        predictions: torch.Tensor,
+        weights: torch.Tensor,
+        expected_sum_weights: torch.Tensor,
+    ) -> None:
+        num_task = labels.shape[0]
+        batch_size = labels.shape[0]
+        task_list = []
+        inputs: Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor]] = {
+            "predictions": {},
+            "labels": {},
+            "weights": {},
+        }
+        for i in range(num_task):
+            task_info = RecTaskInfo(
+                name=f"Task:{i}",
+                label_name="label",
+                prediction_name="prediction",
+                weight_name="weight",
+            )
+            task_list.append(task_info)
+            # pyre-ignore
+            inputs["predictions"][task_info.name] = predictions[i]
+            # pyre-ignore
+            inputs["labels"][task_info.name] = labels[i]
+            if weights is None:
+                # pyre-ignore
+                inputs["weights"] = None
+            else:
+                # pyre-ignore
+                inputs["weights"][task_info.name] = weights[i]
+
+        sum_weights = SumWeightsMetric(
+            world_size=WORLD_SIZE,
+            my_rank=0,
+            batch_size=batch_size,
+            tasks=task_list,
+        )
+        sum_weights.update(**inputs)
+        actual_sum_weights = sum_weights.compute()
+
+        for task_id, task in enumerate(task_list):
+            cur_actual_sum_weights = actual_sum_weights[
+                f"sum_weights-{task.name}|window_sum_weights"
+            ][0]
+            cur_expected_sum_weights = expected_sum_weights[task_id]
+            if cur_expected_sum_weights.isnan().any():
+                self.assertTrue(cur_actual_sum_weights.isnan().any())
+            else:
+                torch.testing.assert_close(
+                    cur_actual_sum_weights,
+                    cur_expected_sum_weights,
+                    atol=1e-4,
+                    rtol=1e-4,
+                    check_dtype=False,
+                    msg=f"Actual: {cur_actual_sum_weights}, Expected: {cur_expected_sum_weights}",
+                )
+
+    def test_sum_weights(self) -> None:
+        test_data = generate_model_outputs_cases()
+        for inputs in test_data:
+            try:
+                # pyre-ignore
+                self._test_sum_weights_helper(**inputs)
+            except AssertionError:
+                print("Assertion error caught with data set ", inputs)
+                raise
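
A hand-checkable sketch of the NaN-handling semantics that distinguish the two counting helpers introduced above: compute_missing_label_sum credits a sample's weight to the missing-label count when its label is NaN, while compute_weighted_pos_sum zeroes NaN labels out before taking the weighted positive sum. Only functions from this diff are used; the sample values are made up for illustration.

    import torch

    from torchrec.metrics.num_missing_labels import compute_missing_label_sum
    from torchrec.metrics.num_positive_samples import compute_weighted_pos_sum

    labels = torch.tensor([[float("nan"), 1.0, 0.0]])
    predictions = torch.tensor([[0.2, 0.6, 0.8]])  # accepted but unused by both helpers
    weights = torch.tensor([[0.5, 1.0, 2.0]])

    # The NaN label contributes its weight (0.5) to the missing-label sum.
    print(compute_missing_label_sum(labels, predictions, weights))  # tensor([0.5000])
    # The NaN label is zeroed out, so only the one positive label counts: 1.0 * 1.
    print(compute_weighted_pos_sum(labels, predictions, weights))  # tensor([1.])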
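
And a minimal end-to-end usage sketch for one of the new metrics, mirroring the update/compute pattern the tests exercise, but on a single rank with a single task. The task name "t1" and the input values are illustrative, and RecTaskInfo is assumed to be importable from torchrec.metrics.metrics_config (the module the tests import DefaultTaskInfo from).

    import torch

    from torchrec.metrics.metrics_config import RecTaskInfo
    from torchrec.metrics.num_positive_samples import NumPositiveSamplesMetric

    task = RecTaskInfo(
        name="t1",
        label_name="label",
        prediction_name="prediction",
        weight_name="weight",
    )
    metric = NumPositiveSamplesMetric(
        world_size=1,
        my_rank=0,
        batch_size=5,
        tasks=[task],
    )
    metric.update(
        predictions={"t1": torch.tensor([0.2, 0.6, 0.8, 0.4, 0.9])},
        labels={"t1": torch.tensor([1.0, 0.0, 0.0, 1.0, 1.0])},
        weights={"t1": torch.tensor([0.1, 0.1, 0.1, 0.1, 0.6])},
    )
    # Weighted positive sum = 0.1 + 0.1 + 0.6 = 0.8, reported under the same
    # window key format the tests read back.
    result = metric.compute()
    print(result["num_positive_samples-t1|window_num_positive_samples"])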