
Unified stochastic kernel #259

Draft
habibutsu wants to merge 1 commit into trymirai:main from habibutsu:unified_kernel

Conversation

@habibutsu
Contributor

This is a draft of the kernel.
To check performance, run:

cargo test -p uzu --test kernel perf_batch -- --nocapture

On my laptop I get the following results:

test sampling::sampling_perf_test::perf_batch1_128k_vocab ... [uzu::backends::metal::backend::Metal] batch=1 vocab=128000
  sequential: mean=0.312ms  min=0.298ms
  unified:    mean=0.137ms  min=0.134ms
  speedup:      2.23x
ok
test sampling::sampling_perf_test::perf_batch64_128k_vocab ... [uzu::backends::metal::backend::Metal] batch=64 vocab=128000
  sequential: mean=8.402ms  min=8.271ms
  unified:    mean=1.921ms  min=1.736ms
  speedup:      4.77x
ok


@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 5cb70631a4


}

// Precompute top_p threshold in unnorm space to avoid per-round division.
const float top_p_mass = top_p * sum_exp;

P1: Compute top-p mass from top-k-filtered logits

This computes top_p_mass from sum_exp before any top-k filtering, but the existing stochastic path applies TopK before TopP (crates/uzu/src/backends/common/kernel/sampling.rs calls self.topk.encode before self.topp.encode). When both parameters are enabled (including common defaults such as top_k=20 with top_p=0.95), unified sampling checks top-p against the full-vocab mass instead of the top-k subset. This makes top-p substantially weaker and changes sampling behavior relative to the current production path.
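The ordering the comment describes can be sketched on the CPU (an illustrative reference, not the Metal kernel): filter to the top-k survivors first, then compute the top-p mass over that subset only.

```rust
// CPU sketch (hypothetical helper, not the kernel code): compute the
// unnormalized top-p threshold mass only over the top-k survivors,
// mirroring the sequential path's TopK-before-TopP ordering.
fn top_p_mass_after_top_k(logits: &[f32], top_k: usize, top_p: f32) -> f32 {
    let mut sorted = logits.to_vec();
    // Descending sort; logits are assumed finite, so partial_cmp is safe.
    sorted.sort_by(|a, b| b.partial_cmp(a).unwrap());
    let max = sorted[0];
    // Unnormalized softmax mass of the k survivors, shifted by max for stability.
    let sum_exp: f32 = sorted.iter().take(top_k).map(|&l| (l - max).exp()).sum();
    top_p * sum_exp
}
```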


Comment on lines +13 to +20
/// Unified single-pass path: all filtering and Gumbel-max in one kernel dispatch,
/// operating on logits loaded into private registers.
UnifiedStochastic {
temperature: Option<f32>,
top_k: Option<u32>,
top_p: Option<f32>,
min_p: Option<f32>,
},
Contributor

This should not be a separate sampling policy. Unified stochastic is an implementation detail, not a different sampling policy.
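A hypothetical shape of what this asks for (names are illustrative, not the crate's actual API): one public Stochastic policy, with the unified-vs-separate choice made internally by the backend.

```rust
// One public sampling policy; the kernel strategy is an internal detail.
#[derive(Debug, PartialEq)]
pub enum SamplingPolicy {
    Greedy,
    Stochastic {
        temperature: Option<f32>,
        top_k: Option<u32>,
        top_p: Option<f32>,
        min_p: Option<f32>,
    },
}

// Implementation choice lives behind the policy, not in it.
#[derive(Debug, PartialEq)]
enum KernelPath {
    Unified,
    Separate,
}

fn pick_path(policy: &SamplingPolicy, backend_supports_unified: bool) -> KernelPath {
    match (policy, backend_supports_unified) {
        (SamplingPolicy::Stochastic { .. }, true) => KernelPath::Unified,
        _ => KernelPath::Separate,
    }
}
```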

Comment on lines +179 to +188
if let Some(bitmask_buffer) = bitmask_buffer {
self.bitmask.encode(
None::<&B::Buffer>,
(bitmask_buffer, bitmask_offset),
logits_buffer.deref_mut(),
batch_size as u32,
vocab_size as u32,
command_buffer,
);
}
Contributor

This should also be fused

Comment on lines 41 to 48
bitmask: <B::Kernels as Kernels>::BitmaskKernel,
temperature: <B::Kernels as Kernels>::TemperatureKernel,
topk: <B::Kernels as Kernels>::TopKKernel,
topp: <B::Kernels as Kernels>::TopPKernel,
minp: <B::Kernels as Kernels>::MinPKernel,
gumbel: <B::Kernels as Kernels>::GumbelKernel,
unified: <B::Kernels as Kernels>::UnifiedStochasticKernel,
argmax_implementation: ArgmaxImplementation<B>,
Contributor

Are there any cases where separate kernels are faster? If not, unified stochastic should replace all of the old kernels rather than being a new option

Comment on lines +29 to +31
/// Sampling method (default: stochastic with model's generation config)
#[arg(long)]
sampler: Option<SamplerArg>,
Contributor

Implementation details (unified vs. separate sampling kernels) shouldn't be exposed in the CLI

// with logit-space pivots. Mirrors the Metal kernel logic.
#[kernel(UnifiedStochastic)]
#[variants(T, f32, f16, bf16)]
pub fn unified_stochastic<T: ArrayElement + Float>(
Contributor

The CPU backend is meant as a reference; it shouldn't be doing anything fancy like unified stochastic. I think there should be a single SamplingKernel trait that is implemented in the most straightforward, textbook way possible on CPU, and on Metal either with the unified kernel directly or with multiple private Metal kernels.
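A minimal sketch of that split (trait name and signature are hypothetical, not the crate's actual API): one trait, with a textbook CPU reference shown here as plain per-row argmax.

```rust
// Hypothetical trait: the only shared contract between backends.
pub trait SamplingKernel {
    /// Sample one token id per batch row from row-major `logits` (batch x vocab).
    fn sample(&self, logits: &[f32], batch: usize, vocab: usize) -> Vec<u32>;
}

// Textbook CPU reference implementation: plain argmax per row, nothing fancy.
pub struct CpuArgmax;

impl SamplingKernel for CpuArgmax {
    fn sample(&self, logits: &[f32], batch: usize, vocab: usize) -> Vec<u32> {
        (0..batch)
            .map(|b| {
                let row = &logits[b * vocab..(b + 1) * vocab];
                row.iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                    .map(|(i, _)| i as u32)
                    .unwrap()
            })
            .collect()
    }
}
```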

Comment on lines +63 to +69
// ── Unified stochastic sampling: temperature + top_k/p/min_p + sampling in one dispatch ──
//
// NOTE: No Gumbel noise, no argmax.
// Gumbel-max (add Gumbel noise to logits → argmax) is mathematically equivalent to
// inverse-transform sampling from the softmax distribution (draw u ~ U(0,1), find
// token at CDF position u). This kernel uses the latter: one uniform draw per round,
// located via a cooperative prefix-sum walk — no per-token noise, no full-vocab argmax.
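The equivalence the comment relies on can be illustrated with a small CPU sketch (not the kernel itself) of inverse-transform sampling from the softmax distribution:

```rust
// Illustrative CPU version: draw u ~ U(0,1), scale by the total
// unnormalized mass, and walk the prefix sum until it crosses the
// target. Equal in distribution to Gumbel-max over the same logits.
fn inverse_transform_sample(logits: &[f32], u: f32) -> usize {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    // Unnormalized softmax weights, shifted by max for stability.
    let weights: Vec<f32> = logits.iter().map(|&l| (l - max).exp()).collect();
    let total: f32 = weights.iter().sum();
    let target = u * total; // u in [0, 1)
    let mut acc = 0.0;
    for (i, &w) in weights.iter().enumerate() {
        acc += w;
        if acc > target {
            return i;
        }
    }
    weights.len() - 1 // guard against floating-point rounding at u near 1
}
```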
Contributor

We use Gumbel with a shared seed between speculator and LLM sampling for an increased acceptance rate

Contributor

@uuuvn left a comment


Accidentally selected the previous review as "approve" when it should have been "request changes"; I don't see a way to undo it

@uuuvn marked this pull request as draft on March 25, 2026 at 22:06
@@ -0,0 +1,151 @@
mod common;
Contributor

Why are you adding this file? crates/uzu/tests/integration/session/chat_session/context_mode_test.rs

@@ -0,0 +1,478 @@
mod common;
Contributor

Why are you adding this file? crates/uzu/tests/unit/encodable_block/sampling_test.rs

@@ -0,0 +1 @@
mod sampling_perf_test;
Contributor

Why are you adding a whole new directory for a single file?

min_p: f32,
#[specialize] has_bitmask: bool,
) {
let _ = min_p;
Contributor

?

} else {
top_p
};
let bitmask_stride = (vocab_size + 31) / 32;
Contributor

div_ceil
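That is, the manual rounding-up idiom can be replaced with `u32::div_ceil` (stable since Rust 1.73), which computes the same stride:

```rust
// Both expressions compute the number of 32-bit words needed to hold
// one bit per vocabulary entry.
fn bitmask_stride(vocab_size: u32) -> u32 {
    debug_assert_eq!(vocab_size.div_ceil(32), (vocab_size + 31) / 32);
    vocab_size.div_ceil(32)
}
```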

const MAX_TOP_K: u32 = 64;

pub struct SamplingKernel<B: Backend> {
bitmask: <B::Kernels as Kernels>::BitmaskKernel,
Contributor

Let's add bitmask to argmax instead of having a separate kernel only for argmax

Comment on lines +58 to +59
#[error("Stochastic: top_k={0} exceeds N_CANDIDATES={MAX_TOP_K}")]
TopKTooLarge(u32),
Contributor

We should probably still support top_k > 64 via some fallback.
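One possible shape for such a fallback (hypothetical, names illustrative): gate on the kernel's capacity and route oversized requests to the separate-kernel path instead of returning TopKTooLarge.

```rust
// Mirrors the MAX_TOP_K = 64 constant from the diff; larger values would
// fall back to the multi-kernel path rather than error out.
const MAX_TOP_K: u32 = 64;

fn fits_unified_path(top_k: Option<u32>) -> bool {
    // None means top-k filtering is disabled, which needs no candidate buffer.
    top_k.map_or(true, |k| k <= MAX_TOP_K)
}
```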

Comment on lines -178 to +160
processing_order,
..
Contributor

You're silently ignoring processing_order

@uuuvn
Contributor

uuuvn commented Apr 3, 2026

Please carefully read your diff before re-requesting review; there are a lot of things that are either obviously wrong or very dirty



2 participants