Skip to content

Commit 30d6061

Browse files
author
sanising
committed
Update tests with new random sampling logic
Signed-off-by: sanising <sanising@qti.qualcomm.com>
1 parent 1a01d57 commit 30d6061

File tree

1 file changed

+27
-25
lines changed

1 file changed

+27
-25
lines changed

tests/transformers/sampler/test_sampler.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def test_greedy_sampling(
211211
"top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1),
212212
"top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
213213
"min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
214-
"random_numbers": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
214+
"random_numbers": np.zeros((full_batch_size, 512), dtype=np.float32),
215215
},
216216
)
217217
model_wo_sampler_exec_info = model_wo_sampler.generate(
@@ -233,7 +233,6 @@ def test_greedy_sampling(
233233

234234

235235
@pytest.mark.on_qaic
236-
@pytest.mark.skip
237236
@pytest.mark.parametrize(
238237
"model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length",
239238
random_sampling_configs,
@@ -291,6 +290,7 @@ def test_random_sampling(
291290

292291
# Generate texts from prompts
293292
tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model)
293+
np.random.seed(0)
294294
model_w_sampler_exec_info = model_w_sampler.generate(
295295
tokenizer=tokenizer,
296296
prompts=prompts,
@@ -301,11 +301,13 @@ def test_random_sampling(
301301
"repetition_penalties": np.array(20.2, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
302302
"presence_penalties": np.array(10.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
303303
# "frequency_penalties": np.array(0.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
304-
"temperatures": np.array(100.1, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
305-
"top_ks": np.array(54720, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1),
304+
"temperatures": np.array(4.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
305+
"top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1),
306306
"top_ps": np.array(0.89, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
307307
"min_ps": np.array(0.6, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
308-
"random_numbers": np.array(0.26, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
308+
"random_numbers": np.tile(np.random.uniform(low=0.0, high=1.0, size=512), (full_batch_size, 1)).astype(
309+
np.float32
310+
),
309311
},
310312
)
311313
model_wo_sampler_exec_info = model_wo_sampler.generate(
@@ -319,32 +321,32 @@ def test_random_sampling(
319321

320322
# Compare generated texts
321323
golden_texts = {
322-
"w_sampler": "Raymond and my favorite color, alongside reds or purples (I can’t have them both",
324+
"w_sampler": "Aiden and I am a freelance writer who loves to explore the world. With over",
323325
"wo_sampler": "John Smith and I am a software engineer. I have been working in the industry for the past ",
324326
}
325327
golden_ids = {
326328
"w_sampler": [
327329
[
328-
21380,
330+
319,
331+
3615,
329332
322,
330-
590,
331-
25448,
332-
2927,
333-
29892,
334-
19963,
335-
2654,
336-
29879,
337-
470,
338-
3708,
339-
2701,
340-
313,
341-
29902,
342-
508,
343-
30010,
344-
29873,
345-
505,
346-
963,
347-
1716,
333+
306,
334+
626,
335+
263,
336+
3005,
337+
295,
338+
749,
339+
9227,
340+
1058,
341+
12355,
342+
267,
343+
304,
344+
26987,
345+
278,
346+
3186,
347+
29889,
348+
2973,
349+
975,
348350
]
349351
],
350352
"wo_sampler": [

0 commit comments

Comments
 (0)