@@ -211,7 +211,7 @@ def test_greedy_sampling(
211211 "top_ks" : np .array (512 , dtype = np .int32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
212212 "top_ps" : np .array (1.0 , dtype = np .float32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
213213 "min_ps" : np .array (0.0 , dtype = np .float32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
214- "random_numbers" : np .array ( 0.0 , dtype = np .float32 ). repeat ( full_batch_size ). reshape ( - 1 , 1 ),
214+ "random_numbers" : np .zeros (( full_batch_size , 512 ), dtype = np .float32 ),
215215 },
216216 )
217217 model_wo_sampler_exec_info = model_wo_sampler .generate (
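Note on the hunk above: `random_numbers` changes from one scalar repeated per sequence, shape (full_batch_size, 1), to a vector of 512 values per sequence, shape (full_batch_size, 512), presumably one slot per top-k candidate to match the `top_ks` value of 512 used in these tests. A minimal shape check of the two layouts; the batch size here is a hypothetical value for illustration only, not one taken from the tests:

import numpy as np

full_batch_size = 4  # hypothetical batch size, for illustration only

# Old layout: a single random number per sequence.
old_random_numbers = np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1)
assert old_random_numbers.shape == (full_batch_size, 1)

# New layout: 512 values per sequence; all zeros keeps the greedy test deterministic.
new_random_numbers = np.zeros((full_batch_size, 512), dtype=np.float32)
assert new_random_numbers.shape == (full_batch_size, 512)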
@@ -233,7 +233,6 @@ def test_greedy_sampling(
 
 
 @pytest.mark.on_qaic
-@pytest.mark.skip
 @pytest.mark.parametrize(
     "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length",
     random_sampling_configs,
@@ -291,6 +290,7 @@ def test_random_sampling(
 
     # Generate texts from prompts
     tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model)
+    np.random.seed(0)
     model_w_sampler_exec_info = model_w_sampler.generate(
         tokenizer=tokenizer,
         prompts=prompts,
@@ -301,11 +301,13 @@ def test_random_sampling(
301301 "repetition_penalties" : np .array (20.2 , dtype = np .float32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
302302 "presence_penalties" : np .array (10.5 , dtype = np .float32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
303303 # "frequency_penalties": np.array(0.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
304- "temperatures" : np .array (100.1 , dtype = np .float32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
305- "top_ks" : np .array (54720 , dtype = np .int32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
304+ "temperatures" : np .array (4.0 , dtype = np .float32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
305+ "top_ks" : np .array (512 , dtype = np .int32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
306306 "top_ps" : np .array (0.89 , dtype = np .float32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
307307 "min_ps" : np .array (0.6 , dtype = np .float32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
308- "random_numbers" : np .array (0.26 , dtype = np .float32 ).repeat (full_batch_size ).reshape (- 1 , 1 ),
308+ "random_numbers" : np .tile (np .random .uniform (low = 0.0 , high = 1.0 , size = 512 ), (full_batch_size , 1 )).astype (
309+ np .float32
310+ ),
309311 },
310312 )
311313 model_wo_sampler_exec_info = model_wo_sampler .generate (
@@ -319,32 +321,32 @@ def test_random_sampling(
 
     # Compare generated texts
     golden_texts = {
-        "w_sampler": "Raymond and my favorite color, alongside reds or purples (I can’t have them both ",
+        "w_sampler": "Aiden and I am a freelance writer who loves to explore the world. With over ",
         "wo_sampler": "John Smith and I am a software engineer. I have been working in the industry for the past ",
     }
     golden_ids = {
         "w_sampler": [
             [
-                21380,
+                319,
+                3615,
                 322,
-                590,
-                25448,
-                2927,
-                29892,
-                19963,
-                2654,
-                29879,
-                470,
-                3708,
-                2701,
-                313,
-                29902,
-                508,
-                30010,
-                29873,
-                505,
-                963,
-                1716,
+                306,
+                626,
+                263,
+                3005,
+                295,
+                749,
+                9227,
+                1058,
+                12355,
+                267,
+                304,
+                26987,
+                278,
+                3186,
+                29889,
+                2973,
+                975,
             ]
         ],
         "wo_sampler": [