Skip to content

Commit 55e76e9

Browse files
author
sanising
committed
Fix bug
Signed-off-by: sanising <sanising@qti.qualcomm.com>
1 parent a24a55d commit 55e76e9

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

examples/performance/on_device_sampling.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,11 @@ def main(args, **kwargs):
6868
# Ideally this should come from a logits processor like xgrammar, but for the sake of the
6969
# example, we generate a random bitmask
7070
sampling_params.update(
71-
{"token_bitmasks": np.random.choice([True, False], size=(bs, qeff_model.model.config.vocab_size))}
71+
{
72+
"token_bitmasks": np.tile(
73+
np.random.choice([True, False], size=(qeff_model.model.config.vocab_size,)), (bs, 1)
74+
)
75+
}
7276
)
7377
print("sampling_params:")
7478
pprint(sampling_params)

tests/transformers/sampler/test_sampler.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -615,16 +615,14 @@ def test_guided_decoding(
615615
tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model)
616616
np.random.seed(0)
617617
sampling_params = {
618-
"repetition_penalties": np.array(20.2, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
619-
"presence_penalties": np.array(10.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
620-
# "frequency_penalties": np.array(0.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
621-
"temperatures": np.array(4.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
618+
"repetition_penalties": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
619+
"presence_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
620+
# "frequency_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
621+
"temperatures": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
622622
"top_ks": np.array(1024, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1),
623-
"top_ps": np.array(0.89, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
624-
"min_ps": np.array(0.6, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
625-
"random_numbers": np.tile(np.random.uniform(low=0.0, high=1.0, size=1024), (full_batch_size, 1)).astype(
626-
np.float32
627-
),
623+
"top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
624+
"min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
625+
"random_numbers": np.zeros((full_batch_size, 1024), dtype=np.float32),
628626
}
629627
model_w_sampler_w_guided_decoding_exec_info = model_w_sampler_w_guided_decoding.generate(
630628
tokenizer=tokenizer,
@@ -636,8 +634,9 @@ def test_guided_decoding(
636634
sampling_params={
637635
**sampling_params,
638636
**{
639-
"token_bitmasks": np.random.choice(
640-
[True, False], size=(full_batch_size, model_w_sampler_w_guided_decoding.model.config.vocab_size)
637+
"token_bitmasks": np.tile(
638+
np.random.choice([True, False], size=(model_w_sampler_w_guided_decoding.model.config.vocab_size,)),
639+
(full_batch_size, 1),
641640
)
642641
},
643642
},
@@ -653,4 +652,4 @@ def test_guided_decoding(
653652
assert (
654653
model_w_sampler_w_guided_decoding_exec_info.generated_ids
655654
!= model_w_sampler_wo_guided_decoding_exec_info.generated_ids
656-
), "Sampler outputs with and without guided decoding should not match"
655+
).any(), "Sampler outputs with and without guided decoding should not match"

0 commit comments

Comments (0)