@@ -248,12 +248,26 @@ def test_sharding_ebc_input_validation_enabled(self, mock_jk: Mock) -> None:
248248 values = torch .tensor ([1 , 2 , 3 , 4 , 5 ]),
249249 lengths = torch .tensor ([1 , 2 , 0 , 2 ]),
250250 offsets = torch .tensor ([0 , 1 , 3 , 3 , 5 ]),
251- )
251+ ). to ( self . device )
252252 mock_jk .return_value = True
253253
254254 with self .assertRaisesRegex (ValueError , "keys must be unique" ):
255255 model (kjt )
256- mock_jk .assert_called_once_with ("pytorch/torchrec:enable_kjt_validation" )
256+
257+ # Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
258+ # This ignores any other calls to justknobs_check() with other inputs
259+ # and protects the test from breaking when new JK checks are added.
260+ validation_calls = [
261+ call
262+ for call in mock_jk .call_args_list
263+ if len (call [0 ]) > 0
264+ and call [0 ][0 ] == "pytorch/torchrec:enable_kjt_validation"
265+ ]
266+ self .assertEqual (
267+ 1 ,
268+ len (validation_calls ),
269+ "There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation" ,
270+ )
257271
258272 @patch ("torch._utils_internal.justknobs_check" )
259273 def test_sharding_ebc_validate_input_only_once (self , mock_jk : Mock ) -> None :
@@ -271,7 +285,20 @@ def test_sharding_ebc_validate_input_only_once(self, mock_jk: Mock) -> None:
271285 model (kjt )
272286 model (kjt )
273287
274- mock_jk .assert_called_once_with ("pytorch/torchrec:enable_kjt_validation" )
288+ # Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
289+ # This ignores any other calls to justknobs_check() with other inputs
290+ # and protects the test from breaking when new JK checks are added.
291+ validation_calls = [
292+ call
293+ for call in mock_jk .call_args_list
294+ if len (call [0 ]) > 0
295+ and call [0 ][0 ] == "pytorch/torchrec:enable_kjt_validation"
296+ ]
297+ self .assertEqual (
298+ 1 ,
299+ len (validation_calls ),
300+ "There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation" ,
301+ )
275302 matched_logs = list (
276303 filter (lambda s : "Validating input features..." in s , logs .output )
277304 )
@@ -294,7 +321,20 @@ def test_sharding_ebc_input_validation_disabled(self, mock_jk: Mock) -> None:
294321 except ValueError :
295322 self .fail ("Input validation should not be enabled." )
296323
297- mock_jk .assert_called_once_with ("pytorch/torchrec:enable_kjt_validation" )
324+ # Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
325+ # This ignores any other calls to justknobs_check() with other inputs
326+ # and protects the test from breaking when new JK checks are added.
327+ validation_calls = [
328+ call
329+ for call in mock_jk .call_args_list
330+ if len (call [0 ]) > 0
331+ and call [0 ][0 ] == "pytorch/torchrec:enable_kjt_validation"
332+ ]
333+ self .assertEqual (
334+ 1 ,
335+ len (validation_calls ),
336+ "There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation" ,
337+ )
298338
299339 def _create_sharded_model (
300340 self , embedding_dim : int = 128 , num_embeddings : int = 256
0 commit comments