@@ -7,50 +7,83 @@

 # pyre-strict

+import logging
+from typing import Dict, List, Optional
+
 import torch
+from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

+logger: logging.Logger = logging.getLogger(__name__)
+

 def validate_keyed_jagged_tensor(
-    kjt: KeyedJaggedTensor,
-) -> None:
+    kjt: KeyedJaggedTensor, configs: Optional[List[EmbeddingBagConfig]] = None
+) -> bool:
     """
     Validates the inputs that construct a KeyedJaggedTensor.

     Any invalid input will result in a ValueError being thrown.
+
+    Returns:
+        bool: True if all validations pass (including feature range),
+        False if feature range validation fails (soft warning).
     """
-    _validate_lengths_and_offsets(kjt)
-    _validate_keys(kjt)
-    _validate_weights(kjt)
+    valid = True
+    valid = valid and _validate_lengths_and_offsets(kjt)
+    valid = valid and _validate_keys(kjt)
+    valid = valid and _validate_weights(kjt)
+    if configs is not None:
+        valid = valid and _validate_feature_range(kjt, configs)
+    return valid


-def _validate_lengths_and_offsets(kjt: KeyedJaggedTensor) -> None:
+def _validate_lengths_and_offsets(kjt: KeyedJaggedTensor) -> bool:
     """
     Validates the lengths and offsets of a KJT.

     - At least one of lengths or offsets is provided
     - If both are provided, they are consistent with each other
     - The dimensions of these tensors align with the values tensor
+
+    Returns:
+        bool: True if validation passes.
+
+    Raises:
+        ValueError: If validation fails.
     """
     lengths = kjt.lengths_or_none()
     offsets = kjt.offsets_or_none()
+    valid = True
     if lengths is None and offsets is None:
         raise ValueError(
             "lengths and offsets cannot be both empty in KeyedJaggedTensor"
         )
     elif lengths is not None and offsets is not None:
-        _validate_lengths_and_offsets_consistency(lengths, offsets, kjt.values())
+        valid = valid and _validate_lengths_and_offsets_consistency(
+            lengths, offsets, kjt.values()
+        )
     elif lengths is not None:
-        _validate_lengths(lengths, kjt.values())
+        valid = valid and _validate_lengths(lengths, kjt.values())
     elif offsets is not None:
-        _validate_offsets(offsets, kjt.values())
+        valid = valid and _validate_offsets(offsets, kjt.values())
+    return valid


 def _validate_lengths_and_offsets_consistency(
     lengths: torch.Tensor, offsets: torch.Tensor, values: torch.Tensor
-) -> None:
-    _validate_lengths(lengths, values)
-    _validate_offsets(offsets, values)
+) -> bool:
+    """
+    Validates consistency between lengths and offsets.
+
+    Returns:
+        bool: True if validation passes.
+
+    Raises:
+        ValueError: If validation fails.
+    """
+    valid = _validate_lengths(lengths, values)
+    valid = valid and _validate_offsets(offsets, values)

     if lengths.numel() != offsets.numel() - 1:
         raise ValueError(
@@ -59,16 +92,36 @@ def _validate_lengths_and_offsets_consistency(

     if not lengths.equal(torch.diff(offsets)):
         raise ValueError("offsets is not equal to the cumulative sum of lengths")
+    return valid


-def _validate_lengths(lengths: torch.Tensor, values: torch.Tensor) -> None:
+def _validate_lengths(lengths: torch.Tensor, values: torch.Tensor) -> bool:
+    """
+    Validates lengths tensor.
+
+    Returns:
+        bool: True if validation passes.
+
+    Raises:
+        ValueError: If validation fails.
+    """
     if lengths.sum().item() != values.numel():
         raise ValueError(
             f"Sum of lengths must equal the number of values, but got {lengths.sum().item()} and {values.numel()}"
         )
+    return True


-def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> None:
+def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> bool:
+    """
+    Validates offsets tensor.
+
+    Returns:
+        bool: True if validation passes.
+
+    Raises:
+        ValueError: If validation fails.
+    """
     if offsets.numel() == 0:
         raise ValueError("offsets cannot be empty")

@@ -79,14 +132,21 @@ def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> None:
         raise ValueError(
             f"The last element of offsets must equal to the number of values, but got {offsets[-1]} and {values.numel()}"
         )
+    return True


-def _validate_keys(kjt: KeyedJaggedTensor) -> None:
+def _validate_keys(kjt: KeyedJaggedTensor) -> bool:
     """
     Validates KJT keys, assuming the lengths/offsets input are valid.

     - keys must be unique
     - For non-VBE cases, the size of lengths is divisible by the number of keys
+
+    Returns:
+        bool: True if validation passes.
+
+    Raises:
+        ValueError: If validation fails.
     """
     keys = kjt.keys()

@@ -110,14 +170,58 @@ def _validate_keys(kjt: KeyedJaggedTensor) -> None:
         raise ValueError(
             f"lengths size must be divisible by keys size, but got {lengths_size} and {len(keys)}"
         )
+    return True


-def _validate_weights(kjt: KeyedJaggedTensor) -> None:
+def _validate_weights(kjt: KeyedJaggedTensor) -> bool:
     """
     Validates if the KJT weights has the same size as values.
+
+    Returns:
+        bool: True if validation passes.
+
+    Raises:
+        ValueError: If validation fails.
     """
     weights = kjt.weights_or_none()
     if weights is not None and weights.numel() != kjt.values().numel():
         raise ValueError(
             f"weights size must equal to values size, but got {weights.numel()} and {kjt.values().numel()}"
         )
+    return True
+
+
+def _validate_feature_range(
+    kjt: KeyedJaggedTensor, configs: List[EmbeddingBagConfig]
+) -> bool:
+    """
+    Validates that each KJT feature's values fall within the configured
+    embedding range [0, num_embeddings).
+
+    Returns:
+        bool: True if all features are within the valid range, False otherwise.
+    """
+    feature_to_range_map: Dict[str, int] = {}
+    for config in configs:
+        for feature in config.feature_names:
+            feature_to_range_map[feature] = config.num_embeddings
+
+    # Skip the check entirely when none of the KJT keys appear in the configs.
+    if len(kjt.keys() & feature_to_range_map.keys()) == 0:
+        logger.info(
+            f"None of the KJT keys {kjt.keys()} are in the config {feature_to_range_map.keys()}"
+        )
+        return True
+
+    valid = True
+    jt_dict = kjt.to_dict()
+    for feature, jt in jt_dict.items():
+        if feature not in feature_to_range_map:
+            logger.info(f"Feature {feature} is not in the config")
+            continue
+        min_id, max_id = jt.values().min(), jt.values().max()
+        if min_id < 0 or max_id >= feature_to_range_map[feature]:
+            logger.warning(
+                f"Feature {feature} has values in [{min_id}, {max_id}], "
+                f"which is outside the valid range [0, {feature_to_range_map[feature]})"
+            )
+            valid = False
+    return valid
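
For reference, a minimal usage sketch of the validator added in this diff, assuming the public KeyedJaggedTensor and EmbeddingBagConfig constructors from torchrec; the tensor contents and config values below are illustrative only, not taken from the PR:

import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two features ("f1", "f2") over a batch of 3; lengths sum to the number of values.
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5, 6]),
    lengths=torch.tensor([1, 2, 0, 1, 1, 1]),
)

# num_embeddings bounds the valid id range [0, 10) for both features.
configs = [
    EmbeddingBagConfig(
        name="t1",
        embedding_dim=8,
        num_embeddings=10,
        feature_names=["f1", "f2"],
    )
]

# Structural problems (lengths/offsets/keys/weights) raise ValueError;
# out-of-range ids only log a warning and make the call return False.
is_valid = validate_keyed_jagged_tensor(kjt, configs)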