@@ -615,16 +615,14 @@ def test_guided_decoding(
615615 tokenizer = load_hf_tokenizer (pretrained_model_name_or_path = model )
616616 np .random .seed (0 )
617617 sampling_params = {
618- "repetition_penalties" : np .array (20.2 , dtype = np .float32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
619- "presence_penalties" : np .array (10.5 , dtype = np .float32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
620- # "frequency_penalties": np.array(0.5 , dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
621- "temperatures" : np .array (4 .0 , dtype = np .float32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
618+ "repetition_penalties" : np .array (1.0 , dtype = np .float32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
619+ "presence_penalties" : np .array (0.0 , dtype = np .float32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
620+ # "frequency_penalties": np.array(0.0 , dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
621+ "temperatures" : np .array (0 .0 , dtype = np .float32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
622622 "top_ks" : np .array (1024 , dtype = np .int32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
623- "top_ps" : np .array (0.89 , dtype = np .float32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
624- "min_ps" : np .array (0.6 , dtype = np .float32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
625- "random_numbers" : np .tile (np .random .uniform (low = 0.0 , high = 1.0 , size = 1024 ), (full_batch_size , 1 )).astype (
626- np .float32
627- ),
623+ "top_ps" : np .array (1.0 , dtype = np .float32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
624+ "min_ps" : np .array (0.0 , dtype = np .float32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
625+ "random_numbers" : np .zeros ((full_batch_size , 1024 ), dtype = np .float32 ),
628626 }
629627 model_w_sampler_w_guided_decoding_exec_info = model_w_sampler_w_guided_decoding .generate (
630628 tokenizer = tokenizer ,
@@ -636,8 +634,9 @@ def test_guided_decoding(
636634 sampling_params = {
637635 ** sampling_params ,
638636 ** {
639- "token_bitmasks" : np .random .choice (
640- [True , False ], size = (full_batch_size , model_w_sampler_w_guided_decoding .model .config .vocab_size )
637+ "token_bitmasks" : np .tile (
638+ np .random .choice ([True , False ], size = (model_w_sampler_w_guided_decoding .model .config .vocab_size ,)),
639+ (full_batch_size , 1 ),
641640 )
642641 },
643642 },
@@ -653,4 +652,4 @@ def test_guided_decoding(
653652 assert (
654653 model_w_sampler_w_guided_decoding_exec_info .generated_ids
655654 != model_w_sampler_wo_guided_decoding_exec_info .generated_ids
656- ), "Sampler outputs with and without guided decoding should not match"
655+ ). any () , "Sampler outputs with and without guided decoding should not match"
0 commit comments