Review note: this patch was recovered from a copy whose line breaks and
indentation had been collapsed. Hunk line counts were re-derived, but the
whitespace of context lines is inferred; run
`git apply --check --ignore-whitespace` before applying.
Two code changes were made relative to the recovered patch, both in
jagged_tensor_validator.py: _validate_lengths_and_offsets_consistency no
longer gains a `return True` (its return annotation sits outside every hunk
and is presumably still None, like its siblings were), and
_validate_feature_range now logs lazily with plain Python numbers.

diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index b9fbc8e6f..07ad9f494 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -1543,7 +1543,7 @@ def input_dist(
                 "pytorch/torchrec:enable_kjt_validation"
             ):
                 logger.info("Validating input features...")
-                validate_keyed_jagged_tensor(features)
+                validate_keyed_jagged_tensor(features, self._embedding_bag_configs)
 
             self._create_input_dist(features.keys())
             self._has_uninitialized_input_dist = False
diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index f4f7008ac..e68af4f2a 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -2825,11 +2825,15 @@ def to_dict(self, compute_offsets: bool = True) -> Dict[str, JaggedTensor]:
             logger.warning(
                 "Trying to non-strict torch.export KJT to_dict, which is extremely slow and not recommended!"
             )
+        length_per_key = self.length_per_key()
+        if isinstance(length_per_key, torch.Tensor):
+            # length_per_key should be a list of ints, but in some (incorrect) cases it is a tensor
+            length_per_key = length_per_key.tolist()
         _jt_dict = _maybe_compute_kjt_to_jt_dict(
             stride=self.stride(),
             stride_per_key=self.stride_per_key(),
             keys=self.keys(),
-            length_per_key=self.length_per_key(),
+            length_per_key=length_per_key,
             lengths=self.lengths(),
             values=self.values(),
             variable_stride_per_key=self.variable_stride_per_key(),
diff --git a/torchrec/sparse/jagged_tensor_validator.py b/torchrec/sparse/jagged_tensor_validator.py
index 84132c23e..8f71a34f8 100644
--- a/torchrec/sparse/jagged_tensor_validator.py
+++ b/torchrec/sparse/jagged_tensor_validator.py
@@ -7,30 +7,50 @@
 
 # pyre-strict
 
+import logging
+from typing import Dict, List, Optional
+
 import torch
+from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
+logger: logging.Logger = logging.getLogger(__name__)
+
 
 def validate_keyed_jagged_tensor(
-    kjt: KeyedJaggedTensor,
-) -> None:
+    kjt: KeyedJaggedTensor, configs: Optional[List[EmbeddingBagConfig]] = None
+) -> bool:
     """
     Validates the inputs that construct a KeyedJaggedTensor.
 
     Any invalid input will result in a ValueError being thrown.
+
+    Returns:
+        bool: True if all validations pass (including feature range),
+            False if feature range validation fails (soft warning).
     """
     _validate_lengths_and_offsets(kjt)
     _validate_keys(kjt)
     _validate_weights(kjt)
+    if configs is not None:
+        return _validate_feature_range(kjt, configs)
+    else:
+        return True
 
 
-def _validate_lengths_and_offsets(kjt: KeyedJaggedTensor) -> None:
+def _validate_lengths_and_offsets(kjt: KeyedJaggedTensor) -> bool:
     """
     Validates the lengths and offsets of a KJT.
 
     - At least one of lengths or offsets is provided
     - If both are provided, they are consistent with each other
     - The dimensions of these tensors align with the values tensor
+
+    Returns:
+        bool: True if validation passes.
+
+    Raises:
+        ValueError: If validation fails.
     """
     lengths = kjt.lengths_or_none()
     offsets = kjt.offsets_or_none()
@@ -44,6 +64,7 @@ def _validate_lengths_and_offsets(kjt: KeyedJaggedTensor) -> None:
         _validate_lengths(lengths, kjt.values())
     elif offsets is not None:
         _validate_offsets(offsets, kjt.values())
+    return True
 
 
 def _validate_lengths_and_offsets_consistency(
@@ -59,16 +80,35 @@ def _validate_lengths_and_offsets_consistency(
 
     if not lengths.equal(torch.diff(offsets)):
         raise ValueError("offsets is not equal to the cumulative sum of lengths")
 
 
-def _validate_lengths(lengths: torch.Tensor, values: torch.Tensor) -> None:
+def _validate_lengths(lengths: torch.Tensor, values: torch.Tensor) -> bool:
+    """
+    Validates lengths tensor.
+
+    Returns:
+        bool: True if validation passes.
+
+    Raises:
+        ValueError: If validation fails.
+    """
     if lengths.sum().item() != values.numel():
         raise ValueError(
             f"Sum of lengths must equal the number of values, but got {lengths.sum().item()} and {values.numel()}"
         )
+    return True
+
 
+def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> bool:
+    """
+    Validates offsets tensor.
 
-def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> None:
+    Returns:
+        bool: True if validation passes.
+
+    Raises:
+        ValueError: If validation fails.
+    """
     if offsets.numel() == 0:
         raise ValueError("offsets cannot be empty")
 
@@ -79,14 +119,21 @@ def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> None:
         raise ValueError(
             f"The last element of offsets must equal to the number of values, but got {offsets[-1]} and {values.numel()}"
         )
+    return True
 
 
-def _validate_keys(kjt: KeyedJaggedTensor) -> None:
+def _validate_keys(kjt: KeyedJaggedTensor) -> bool:
     """
     Validates KJT keys, assuming the lengths/offsets input are valid.
 
     - keys must be unique
     - For non-VBE cases, the size of lengths is divisible by the number of keys
+
+    Returns:
+        bool: True if validation passes.
+
+    Raises:
+        ValueError: If validation fails.
     """
     keys = kjt.keys()
 
@@ -110,14 +157,78 @@ def _validate_keys(kjt: KeyedJaggedTensor) -> None:
                 raise ValueError(
                     f"lengths size must be divisible by keys size, but got {lengths_size} and {len(keys)}"
                 )
+    return True
 
 
-def _validate_weights(kjt: KeyedJaggedTensor) -> None:
+def _validate_weights(kjt: KeyedJaggedTensor) -> bool:
     """
     Validates if the KJT weights has the same size as values.
+
+    Returns:
+        bool: True if validation passes.
+
+    Raises:
+        ValueError: If validation fails.
     """
     weights = kjt.weights_or_none()
     if weights is not None and weights.numel() != kjt.values().numel():
         raise ValueError(
             f"weights size must equal to values size, but got {weights.numel()} and {kjt.values().numel()}"
         )
+    return True
+
+
+def _validate_feature_range(
+    kjt: KeyedJaggedTensor, configs: List[EmbeddingBagConfig]
+) -> bool:
+    """
+    Validates that the values of every configured KJT feature are valid row ids.
+
+    Each feature listed in ``configs`` must only carry values in
+    ``[0, num_embeddings)`` of its table. This is a soft check: violations are
+    logged as warnings and reported through the return value, never raised.
+    KJT features that no config mentions are skipped.
+
+    Args:
+        kjt (KeyedJaggedTensor): input whose per-feature values are checked.
+        configs (List[EmbeddingBagConfig]): tables providing ``feature_names``
+            and ``num_embeddings``.
+
+    Returns:
+        bool: True if all features are within valid range, False otherwise.
+    """
+    feature_to_range_map: Dict[str, int] = {
+        feature: config.num_embeddings
+        for config in configs
+        for feature in config.feature_names
+    }
+
+    if feature_to_range_map.keys().isdisjoint(kjt.keys()):
+        logger.info(
+            "None of KJT._keys %s in the config %s",
+            kjt.keys(),
+            feature_to_range_map.keys(),
+        )
+        return True
+
+    valid = True
+    for feature, jt in kjt.to_dict().items():
+        num_embeddings = feature_to_range_map.get(feature)
+        if num_embeddings is None:
+            logger.info("Feature %s is not in the config", feature)
+            continue
+        values = jt.values()
+        if values.numel() == 0:
+            continue
+        # Convert to Python numbers once so the comparison and the log message
+        # work on plain scalars instead of 0-dim tensors.
+        min_value, max_value = values.min().item(), values.max().item()
+        if min_value < 0 or max_value >= num_embeddings:
+            logger.warning(
+                "Feature %s has range %s which is out of range %s",
+                feature,
+                (min_value, max_value),
+                (0, num_embeddings),
+            )
+            valid = False
+    return valid
diff --git a/torchrec/sparse/tests/test_jagged_tensor_validator.py b/torchrec/sparse/tests/test_jagged_tensor_validator.py
index c4c4531e6..8d681dd71 100644
--- a/torchrec/sparse/tests/test_jagged_tensor_validator.py
+++ b/torchrec/sparse/tests/test_jagged_tensor_validator.py
@@ -14,6 +14,7 @@
 import torch
 from hypothesis import given, settings, strategies as st, Verbosity
 from parameterized import param, parameterized
+from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 from torchrec.sparse.jagged_tensor_validator import validate_keyed_jagged_tensor
 
@@ -236,4 +237,188 @@ def test_valid_kjt_from_offsets(
 
     def test_valid_empty_kjt(self) -> None:
         kjt = KeyedJaggedTensor.empty()
-        validate_keyed_jagged_tensor(kjt)
+        result = validate_keyed_jagged_tensor(kjt)
+        self.assertTrue(result)
+
+    def test_feature_range_valid_values(self) -> None:
+        # Setup: Create KJT with values within valid range and corresponding configs
+        kjt = KeyedJaggedTensor(
+            keys=["feature1", "feature2"],
+            values=torch.tensor([0, 5, 10, 1, 2, 3]),
+            lengths=torch.tensor([3, 3]),
+        )
+        configs = [
+            EmbeddingBagConfig(
+                num_embeddings=20,
+                embedding_dim=8,
+                name="table1",
+                feature_names=["feature1", "feature2"],
+            )
+        ]
+
+        # Execute: Validate KJT with configs
+        result = validate_keyed_jagged_tensor(kjt, configs)
+
+        # Assert: Validation should return True for valid values
+        self.assertTrue(result)
+
+    def test_feature_range_boundary_values(self) -> None:
+        # Setup: Create KJT with boundary values (0 and num_embeddings-1)
+        kjt = KeyedJaggedTensor(
+            keys=["feature1"],
+            values=torch.tensor([0, 9]),
+            lengths=torch.tensor([2]),
+        )
+        configs = [
+            EmbeddingBagConfig(
+                num_embeddings=10,
+                embedding_dim=8,
+                name="table1",
+                feature_names=["feature1"],
+            )
+        ]
+
+        # Execute: Validate KJT with configs
+        result = validate_keyed_jagged_tensor(kjt, configs)
+
+        # Assert: Boundary values should be valid
+        self.assertTrue(result)
+
+    def test_feature_range_negative_value_returns_false(self) -> None:
+        # Setup: Create KJT with negative value
+        kjt = KeyedJaggedTensor(
+            keys=["feature1"],
+            values=torch.tensor([-1, 5, 10]),
+            lengths=torch.tensor([3]),
+        )
+        configs = [
+            EmbeddingBagConfig(
+                num_embeddings=20,
+                embedding_dim=8,
+                name="table1",
+                feature_names=["feature1"],
+            )
+        ]
+
+        # Execute: Validate KJT with configs
+        result = validate_keyed_jagged_tensor(kjt, configs)
+
+        # Assert: Validation should return False for out of range values
+        self.assertFalse(result)
+
+    def test_feature_range_value_exceeds_num_embeddings_returns_false(
+        self,
+    ) -> None:
+        # Setup: Create KJT with value >= num_embeddings
+        kjt = KeyedJaggedTensor(
+            keys=["feature1"],
+            values=torch.tensor([0, 5, 20]),
+            lengths=torch.tensor([3]),
+        )
+        configs = [
+            EmbeddingBagConfig(
+                num_embeddings=20,
+                embedding_dim=8,
+                name="table1",
+                feature_names=["feature1"],
+            )
+        ]
+
+        # Execute: Validate KJT with configs
+        result = validate_keyed_jagged_tensor(kjt, configs)
+
+        # Assert: Validation should return False for out of range values
+        self.assertFalse(result)
+
+    def test_feature_range_feature_not_in_config_returns_true(self) -> None:
+        # Setup: Create KJT with feature not in any config
+        kjt = KeyedJaggedTensor(
+            keys=["feature1", "feature2"],
+            values=torch.tensor([0, 5, 1, 2]),
+            lengths=torch.tensor([2, 2]),
+        )
+        configs = [
+            EmbeddingBagConfig(
+                num_embeddings=20,
+                embedding_dim=8,
+                name="table1",
+                feature_names=["feature1"],
+            )
+        ]
+
+        # Execute: Validate KJT with configs
+        result = validate_keyed_jagged_tensor(kjt, configs)
+
+        # Assert: Should return True since feature2 is just not in config (not invalid)
+        self.assertTrue(result)
+
+    def test_feature_range_multiple_tables_with_different_ranges(self) -> None:
+        # Setup: Create KJT with features from different tables with different num_embeddings
+        kjt = KeyedJaggedTensor(
+            keys=["feature1", "feature2"],
+            values=torch.tensor([0, 5, 10, 15, 25]),
+            lengths=torch.tensor([3, 2]),
+        )
+        configs = [
+            EmbeddingBagConfig(
+                num_embeddings=15,
+                embedding_dim=8,
+                name="table1",
+                feature_names=["feature1"],
+            ),
+            EmbeddingBagConfig(
+                num_embeddings=30,
+                embedding_dim=16,
+                name="table2",
+                feature_names=["feature2"],
+            ),
+        ]
+
+        # Execute: Validate KJT with configs
+        result = validate_keyed_jagged_tensor(kjt, configs)
+
+        # Assert: feature1 has max 10 < 15, feature2 has max 25 < 30 - all valid
+        self.assertTrue(result)
+
+    def test_feature_range_multiple_features_one_out_of_range(self) -> None:
+        # Setup: Create KJT with one feature in range and one out of range
+        kjt = KeyedJaggedTensor(
+            keys=["feature1", "feature2"],
+            values=torch.tensor([0, 5, 10, 15, 25]),
+            lengths=torch.tensor([3, 2]),
+        )
+        configs = [
+            EmbeddingBagConfig(
+                num_embeddings=15,
+                embedding_dim=8,
+                name="table1",
+                feature_names=["feature1"],
+            ),
+            EmbeddingBagConfig(
+                num_embeddings=20,
+                embedding_dim=16,
+                name="table2",
+                feature_names=["feature2"],
+            ),
+        ]
+
+        # Execute: Validate KJT with configs
+        result = validate_keyed_jagged_tensor(kjt, configs)
+
+        # Assert: Should return False since feature2 has value 25 >= 20
+        self.assertFalse(result)
+
+    def test_feature_range_empty_configs(self) -> None:
+        # Setup: Create KJT and an empty (but not None) list of configs
+        kjt = KeyedJaggedTensor(
+            keys=["feature1"],
+            values=torch.tensor([0, 5, 10]),
+            lengths=torch.tensor([3]),
+        )
+        configs = []
+
+        # Execute: Validate KJT with empty configs
+        result = validate_keyed_jagged_tensor(kjt, configs)
+
+        # Assert: Should return True since an empty config list matches no KJT feature
+        self.assertTrue(result)