Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchrec/distributed/embeddingbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1543,7 +1543,7 @@ def input_dist(
"pytorch/torchrec:enable_kjt_validation"
):
logger.info("Validating input features...")
validate_keyed_jagged_tensor(features)
validate_keyed_jagged_tensor(features, self._embedding_bag_configs)

self._create_input_dist(features.keys())
self._has_uninitialized_input_dist = False
Expand Down
6 changes: 5 additions & 1 deletion torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2825,11 +2825,15 @@ def to_dict(self, compute_offsets: bool = True) -> Dict[str, JaggedTensor]:
logger.warning(
"Trying to non-strict torch.export KJT to_dict, which is extremely slow and not recommended!"
)
length_per_key = self.length_per_key()
if isinstance(length_per_key, torch.Tensor):
# length_per_key should be a list of ints, but in some (incorrect) cases it is a tensor
length_per_key = length_per_key.tolist()
_jt_dict = _maybe_compute_kjt_to_jt_dict(
stride=self.stride(),
stride_per_key=self.stride_per_key(),
keys=self.keys(),
length_per_key=self.length_per_key(),
length_per_key=length_per_key,
lengths=self.lengths(),
values=self.values(),
variable_stride_per_key=self.variable_stride_per_key(),
Expand Down
138 changes: 122 additions & 16 deletions torchrec/sparse/jagged_tensor_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,50 +7,83 @@

# pyre-strict

import logging
from typing import Dict, List, Optional

import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

logger: logging.Logger = logging.getLogger(__name__)


def validate_keyed_jagged_tensor(
    kjt: KeyedJaggedTensor, configs: Optional[List[EmbeddingBagConfig]] = None
) -> bool:
    """
    Validates the inputs that construct a KeyedJaggedTensor.

    Structural problems (lengths/offsets, keys, weights) raise a ValueError.
    Feature-range violations (only checked when ``configs`` is provided) are
    treated as a soft failure: they are logged and reflected in the return
    value instead of raising.

    Args:
        kjt: the KeyedJaggedTensor to validate.
        configs: optional embedding bag configs; when provided, feature values
            are additionally checked against each table's num_embeddings.

    Returns:
        bool: True if all validations pass (including feature range),
        False if feature range validation fails (soft warning).

    Raises:
        ValueError: if any structural validation fails.
    """
    valid = True
    valid = valid and _validate_lengths_and_offsets(kjt)
    valid = valid and _validate_keys(kjt)
    valid = valid and _validate_weights(kjt)
    if configs is not None:
        valid = valid and _validate_feature_range(kjt, configs)
    return valid


def _validate_lengths_and_offsets(kjt: KeyedJaggedTensor) -> bool:
    """
    Validates the lengths and offsets of a KJT.

    - At least one of lengths or offsets is provided
    - If both are provided, they are consistent with each other
    - The dimensions of these tensors align with the values tensor

    Returns:
        bool: True if validation passes.

    Raises:
        ValueError: If validation fails.
    """
    lengths = kjt.lengths_or_none()
    offsets = kjt.offsets_or_none()
    valid = True
    if lengths is None and offsets is None:
        raise ValueError(
            "lengths and offsets cannot be both empty in KeyedJaggedTensor"
        )
    elif lengths is not None and offsets is not None:
        # Both provided: check each individually and their mutual consistency.
        valid = valid and _validate_lengths_and_offsets_consistency(
            lengths, offsets, kjt.values()
        )
    elif lengths is not None:
        valid = valid and _validate_lengths(lengths, kjt.values())
    elif offsets is not None:
        valid = valid and _validate_offsets(offsets, kjt.values())
    return valid


def _validate_lengths_and_offsets_consistency(
lengths: torch.Tensor, offsets: torch.Tensor, values: torch.Tensor
) -> None:
_validate_lengths(lengths, values)
_validate_offsets(offsets, values)
) -> bool:
"""
Validates consistency between lengths and offsets.

Returns:
bool: True if validation passes.

Raises:
ValueError: If validation fails.
"""
valid = _validate_lengths(lengths, values)
valid = valid and _validate_offsets(offsets, values)

if lengths.numel() != offsets.numel() - 1:
raise ValueError(
Expand All @@ -59,16 +92,36 @@ def _validate_lengths_and_offsets_consistency(

if not lengths.equal(torch.diff(offsets)):
raise ValueError("offsets is not equal to the cumulative sum of lengths")
return valid


def _validate_lengths(lengths: torch.Tensor, values: torch.Tensor) -> None:
def _validate_lengths(lengths: torch.Tensor, values: torch.Tensor) -> bool:
"""
Validates lengths tensor.

Returns:
bool: True if validation passes.

Raises:
ValueError: If validation fails.
"""
if lengths.sum().item() != values.numel():
raise ValueError(
f"Sum of lengths must equal the number of values, but got {lengths.sum().item()} and {values.numel()}"
)
return True


def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> None:
def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> bool:
"""
Validates offsets tensor.

Returns:
bool: True if validation passes.

Raises:
ValueError: If validation fails.
"""
if offsets.numel() == 0:
raise ValueError("offsets cannot be empty")

Expand All @@ -79,14 +132,21 @@ def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> None:
raise ValueError(
f"The last element of offsets must equal to the number of values, but got {offsets[-1]} and {values.numel()}"
)
return True


def _validate_keys(kjt: KeyedJaggedTensor) -> None:
def _validate_keys(kjt: KeyedJaggedTensor) -> bool:
"""
Validates KJT keys, assuming the lengths/offsets input are valid.

- keys must be unique
- For non-VBE cases, the size of lengths is divisible by the number of keys

Returns:
bool: True if validation passes.

Raises:
ValueError: If validation fails.
"""
keys = kjt.keys()

Expand All @@ -110,14 +170,60 @@ def _validate_keys(kjt: KeyedJaggedTensor) -> None:
raise ValueError(
f"lengths size must be divisible by keys size, but got {lengths_size} and {len(keys)}"
)
return True


def _validate_weights(kjt: KeyedJaggedTensor) -> bool:
    """
    Validates if the KJT weights has the same size as values.

    A KJT without weights is always valid; when weights are present they
    must have exactly one entry per value.

    Returns:
        bool: True if validation passes.

    Raises:
        ValueError: If weights is present and its size differs from values.
    """
    weights = kjt.weights_or_none()
    if weights is not None and weights.numel() != kjt.values().numel():
        raise ValueError(
            f"weights size must equal to values size, but got {weights.numel()} and {kjt.values().numel()}"
        )
    return True


def _validate_feature_range(
    kjt: KeyedJaggedTensor, configs: List[EmbeddingBagConfig]
) -> bool:
    """
    Validates if the KJT feature range is valid.

    A feature's values are in range when 0 <= value < num_embeddings of the
    table serving that feature. Violations are logged as warnings (soft
    failure) rather than raised.

    Args:
        kjt: the KeyedJaggedTensor whose values are checked.
        configs: embedding bag configs mapping feature names to table sizes.

    Returns:
        bool: True if all features are within valid range, False otherwise.
    """
    feature_to_range_map: Dict[str, int] = {
        feature: config.num_embeddings
        for config in configs
        for feature in config.feature_names
    }

    # Nothing to check when no KJT key is covered by the configs.
    if not set(kjt.keys()) & feature_to_range_map.keys():
        logger.info(
            f"None of KJT._keys {kjt.keys()} in the config {feature_to_range_map.keys()}"
        )
        return True

    valid = True
    jt_dict = kjt.to_dict()
    for feature, jt in jt_dict.items():
        if feature not in feature_to_range_map:
            logger.info(f"Feature {feature} is not in the config")
            continue
        if jt.values().numel() == 0:
            # An empty feature is trivially in range.
            continue
        min_val, max_val = jt.values().min(), jt.values().max()
        if min_val < 0 or max_val >= feature_to_range_map[feature]:
            logger.warning(
                f"Feature {feature} has range {min_val, max_val} "
                f"which is out of range {0, feature_to_range_map[feature]}"
            )
            valid = False
    return valid
Loading
Loading