Skip to content

Commit b94c224

Browse files
TroyGarden authored and facebook-github-bot committed
add feature range validation for ebc
Summary: Enable validation of KeyedJaggedTensor feature values against EmbeddingBagConfig ranges to catch out-of-range embedding lookups early.

Modified all validation functions to return boolean values, allowing callers to programmatically distinguish between hard failures (structural errors that raise ValueError) and soft failures (out-of-range values that return False with warnings).

This supports two use cases:
1. Production monitoring — detect invalid embedding IDs without crashing.
2. Data quality checks — identify features with values outside [0, num_embeddings).

All validation functions now return bool for consistency, maintaining full backward compatibility since existing code can continue to ignore return values.

Differential Revision: D88013492
1 parent b451635 commit b94c224

File tree

3 files changed

+307
-18
lines changed

3 files changed

+307
-18
lines changed

torchrec/distributed/embeddingbag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1526,7 +1526,7 @@ def input_dist(
15261526
"pytorch/torchrec:enable_kjt_validation"
15271527
):
15281528
logger.info("Validating input features...")
1529-
validate_keyed_jagged_tensor(features)
1529+
validate_keyed_jagged_tensor(features, self._embedding_bag_configs)
15301530

15311531
self._create_input_dist(features.keys())
15321532
self._has_uninitialized_input_dist = False

torchrec/sparse/jagged_tensor_validator.py

Lines changed: 120 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,50 +7,83 @@
77

88
# pyre-strict
99

10+
import logging
11+
from typing import Dict, List, Optional
12+
1013
import torch
14+
from torchrec.modules.embedding_configs import EmbeddingBagConfig
1115
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
1216

17+
logger: logging.Logger = logging.getLogger(__name__)
18+
1319

1420
def validate_keyed_jagged_tensor(
15-
kjt: KeyedJaggedTensor,
16-
) -> None:
21+
kjt: KeyedJaggedTensor, configs: Optional[List[EmbeddingBagConfig]] = None
22+
) -> bool:
1723
"""
1824
Validates the inputs that construct a KeyedJaggedTensor.
1925
2026
Any invalid input will result in a ValueError being thrown.
27+
28+
Returns:
29+
bool: True if all validations pass (including feature range),
30+
False if feature range validation fails (soft warning).
2131
"""
22-
_validate_lengths_and_offsets(kjt)
23-
_validate_keys(kjt)
24-
_validate_weights(kjt)
32+
valid = True
33+
valid = valid and _validate_lengths_and_offsets(kjt)
34+
valid = valid and _validate_keys(kjt)
35+
valid = valid and _validate_weights(kjt)
36+
if configs is not None:
37+
valid = valid and _validate_feature_range(kjt, configs)
38+
return valid
2539

2640

27-
def _validate_lengths_and_offsets(kjt: KeyedJaggedTensor) -> None:
41+
def _validate_lengths_and_offsets(kjt: KeyedJaggedTensor) -> bool:
2842
"""
2943
Validates the lengths and offsets of a KJT.
3044
3145
- At least one of lengths or offsets is provided
3246
- If both are provided, they are consistent with each other
3347
- The dimensions of these tensors align with the values tensor
48+
49+
Returns:
50+
bool: True if validation passes.
51+
52+
Raises:
53+
ValueError: If validation fails.
3454
"""
3555
lengths = kjt.lengths_or_none()
3656
offsets = kjt.offsets_or_none()
57+
valid = True
3758
if lengths is None and offsets is None:
3859
raise ValueError(
3960
"lengths and offsets cannot be both empty in KeyedJaggedTensor"
4061
)
4162
elif lengths is not None and offsets is not None:
42-
_validate_lengths_and_offsets_consistency(lengths, offsets, kjt.values())
63+
valid = valid and _validate_lengths_and_offsets_consistency(
64+
lengths, offsets, kjt.values()
65+
)
4366
elif lengths is not None:
44-
_validate_lengths(lengths, kjt.values())
67+
valid = valid and _validate_lengths(lengths, kjt.values())
4568
elif offsets is not None:
46-
_validate_offsets(offsets, kjt.values())
69+
valid = valid and _validate_offsets(offsets, kjt.values())
70+
return valid
4771

4872

4973
def _validate_lengths_and_offsets_consistency(
5074
lengths: torch.Tensor, offsets: torch.Tensor, values: torch.Tensor
51-
) -> None:
52-
_validate_lengths(lengths, values)
53-
_validate_offsets(offsets, values)
75+
) -> bool:
76+
"""
77+
Validates consistency between lengths and offsets.
78+
79+
Returns:
80+
bool: True if validation passes.
81+
82+
Raises:
83+
ValueError: If validation fails.
84+
"""
85+
valid = _validate_lengths(lengths, values)
86+
valid = valid and _validate_offsets(offsets, values)
5487

5588
if lengths.numel() != offsets.numel() - 1:
5689
raise ValueError(
@@ -59,16 +92,36 @@ def _validate_lengths_and_offsets_consistency(
5992

6093
if not lengths.equal(torch.diff(offsets)):
6194
raise ValueError("offsets is not equal to the cumulative sum of lengths")
95+
return valid
6296

6397

64-
def _validate_lengths(lengths: torch.Tensor, values: torch.Tensor) -> None:
98+
def _validate_lengths(lengths: torch.Tensor, values: torch.Tensor) -> bool:
99+
"""
100+
Validates lengths tensor.
101+
102+
Returns:
103+
bool: True if validation passes.
104+
105+
Raises:
106+
ValueError: If validation fails.
107+
"""
65108
if lengths.sum().item() != values.numel():
66109
raise ValueError(
67110
f"Sum of lengths must equal the number of values, but got {lengths.sum().item()} and {values.numel()}"
68111
)
112+
return True
69113

70114

71-
def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> None:
115+
def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> bool:
116+
"""
117+
Validates offsets tensor.
118+
119+
Returns:
120+
bool: True if validation passes.
121+
122+
Raises:
123+
ValueError: If validation fails.
124+
"""
72125
if offsets.numel() == 0:
73126
raise ValueError("offsets cannot be empty")
74127

@@ -79,14 +132,21 @@ def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> None:
79132
raise ValueError(
80133
f"The last element of offsets must equal to the number of values, but got {offsets[-1]} and {values.numel()}"
81134
)
135+
return True
82136

83137

84-
def _validate_keys(kjt: KeyedJaggedTensor) -> None:
138+
def _validate_keys(kjt: KeyedJaggedTensor) -> bool:
85139
"""
86140
Validates KJT keys, assuming the lengths/offsets input are valid.
87141
88142
- keys must be unique
89143
- For non-VBE cases, the size of lengths is divisible by the number of keys
144+
145+
Returns:
146+
bool: True if validation passes.
147+
148+
Raises:
149+
ValueError: If validation fails.
90150
"""
91151
keys = kjt.keys()
92152

@@ -110,14 +170,58 @@ def _validate_keys(kjt: KeyedJaggedTensor) -> None:
110170
raise ValueError(
111171
f"lengths size must be divisible by keys size, but got {lengths_size} and {len(keys)}"
112172
)
173+
return True
113174

114175

115-
def _validate_weights(kjt: KeyedJaggedTensor) -> None:
176+
def _validate_weights(kjt: KeyedJaggedTensor) -> bool:
116177
"""
117178
Validates if the KJT weights has the same size as values.
179+
180+
Returns:
181+
bool: True if validation passes.
182+
183+
Raises:
184+
ValueError: If validation fails.
118185
"""
119186
weights = kjt.weights_or_none()
120187
if weights is not None and weights.numel() != kjt.values().numel():
121188
raise ValueError(
122189
f"weights size must equal to values size, but got {weights.numel()} and {kjt.values().numel()}"
123190
)
191+
return True
192+
193+
194+
def _validate_feature_range(
195+
kjt: KeyedJaggedTensor, configs: List[EmbeddingBagConfig]
196+
) -> bool:
197+
"""
198+
Validates if the KJT feature range is valid.
199+
200+
Returns:
201+
bool: True if all features are within valid range, False otherwise.
202+
"""
203+
feature_to_range_map: Dict[str, int] = {}
204+
for config in configs:
205+
for feature in config.feature_names:
206+
feature_to_range_map[feature] = config.num_embeddings
207+
208+
if len(kjt.keys() & feature_to_range_map.keys()) == 0:
209+
logger.info(
210+
f"None of KJT._keys {kjt.keys()} in the config {feature_to_range_map.keys()}"
211+
)
212+
return True
213+
214+
valid = True
215+
jtd = kjt.to_dict()
216+
for feature, jt in jtd.items():
217+
if feature not in feature_to_range_map:
218+
logger.info(f"Feature {feature} is not in the config")
219+
continue
220+
Min, Max = jt.values().min(), jt.values().max()
221+
if Min < 0 or Max >= feature_to_range_map[feature]:
222+
logger.warning(
223+
f"Feature {feature} has range {Min, Max} "
224+
f"which is out of range {0, feature_to_range_map[feature]}"
225+
)
226+
valid = False
227+
return valid

0 commit comments

Comments (0)