Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 50 additions & 92 deletions crates/uzu/src/backends/common/kernel/sampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,9 @@ use crate::{
backends::common::{
Backend, Context, Encoder, Kernels,
gpu_types::ArgmaxPair,
kernel::{
ArgmaxFinalKernel, ArgmaxMainKernel, ArgmaxSingleKernel, BitmaskKernel, GumbelKernel, MinPKernel,
TemperatureKernel, TopKKernel, TopPKernel,
},
kernel::{ArgmaxFinalKernel, ArgmaxMainKernel, ArgmaxSingleKernel, BitmaskKernel, StochasticKernel},
},
session::parameter::{SamplingMethod, SamplingProcessingOrder},
session::parameter::SamplingMethod,
};

#[derive(Debug, Clone, Copy, PartialEq)]
Expand All @@ -36,14 +33,14 @@ enum ArgmaxImplementation<B: Backend> {
},
}

// StochasticKernel retains N_CANDIDATES=64 candidates; top_k must not exceed this.
const MAX_TOP_K: u32 = 64;

pub struct SamplingKernel<B: Backend> {
bitmask: <B::Kernels as Kernels>::BitmaskKernel,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add bitmask to argmax instead of having a separate kernel only for argmax

temperature: <B::Kernels as Kernels>::TemperatureKernel,
topk: <B::Kernels as Kernels>::TopKKernel,
topp: <B::Kernels as Kernels>::TopPKernel,
minp: <B::Kernels as Kernels>::MinPKernel,
gumbel: <B::Kernels as Kernels>::GumbelKernel,
argmax_implementation: ArgmaxImplementation<B>,
Comment on lines 40 to 41
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any cases where the separate kernels are faster? If not, the unified stochastic kernel should replace all of the old kernels rather than just being a new option.

stochastic: <B::Kernels as Kernels>::StochasticKernel,
stochastic_masked: <B::Kernels as Kernels>::StochasticKernel,
max_batch_size: usize,
max_vocab_size: usize,
}
Expand All @@ -58,6 +55,10 @@ pub enum SamplingError<B: Backend> {
BatchSizeExceeded(usize, usize),
#[error("Vocab size {0} exceeds maximum {1}")]
VocabSizeExceeded(usize, usize),
#[error("Stochastic: top_k={0} exceeds N_CANDIDATES={MAX_TOP_K}")]
TopKTooLarge(u32),
Comment on lines +58 to +59
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably still support top_k > 64 via some fallback.

#[error("Stochastic: top_p={0} is not in (0, 1]")]
TopPOutOfRange(f32),
}

impl<B: Backend> SamplingKernel<B> {
Expand All @@ -79,16 +80,6 @@ impl<B: Backend> SamplingKernel<B> {
) -> Result<Self, SamplingError<B>> {
let bitmask = <B::Kernels as Kernels>::BitmaskKernel::new(context, data_type, true)
.map_err(SamplingError::BackendError)?;
let temperature = <B::Kernels as Kernels>::TemperatureKernel::new(context, data_type, true)
.map_err(SamplingError::BackendError)?;
let topk =
<B::Kernels as Kernels>::TopKKernel::new(context, data_type, true).map_err(SamplingError::BackendError)?;
let topp =
<B::Kernels as Kernels>::TopPKernel::new(context, data_type, true).map_err(SamplingError::BackendError)?;
let minp =
<B::Kernels as Kernels>::MinPKernel::new(context, data_type, true).map_err(SamplingError::BackendError)?;
let gumbel = <B::Kernels as Kernels>::GumbelKernel::new(context, data_type, true)
.map_err(SamplingError::BackendError)?;

let argmax_implementation = match argmax_strategy {
ArgmaxStrategy::SinglePass => {
Expand Down Expand Up @@ -126,14 +117,16 @@ impl<B: Backend> SamplingKernel<B> {
},
};

let stochastic = <B::Kernels as Kernels>::StochasticKernel::new(context, data_type, false)
.map_err(SamplingError::BackendError)?;
let stochastic_masked = <B::Kernels as Kernels>::StochasticKernel::new(context, data_type, true)
.map_err(SamplingError::BackendError)?;

Ok(Self {
bitmask,
temperature,
topk,
topp,
minp,
gumbel,
argmax_implementation,
stochastic,
stochastic_masked,
max_batch_size,
max_vocab_size,
})
Expand All @@ -159,85 +152,50 @@ impl<B: Backend> SamplingKernel<B> {
return Err(SamplingError::VocabSizeExceeded(vocab_size, self.max_vocab_size));
}

if let Some(bitmask_buffer) = bitmask_buffer {
self.bitmask.encode(
None::<&B::Buffer>,
(bitmask_buffer, bitmask_offset),
logits_buffer.deref_mut(),
batch_size as u32,
vocab_size as u32,
encoder,
);
}

if let SamplingMethod::Stochastic {
temperature,
top_k,
top_p,
min_p,
processing_order,
..
Comment on lines -178 to +160
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're silently ignoring `processing_order` here: it is no longer destructured, so it is swallowed by the `..` pattern.

} = sampling_method
{
if let Some(temperature) = temperature
&& processing_order == SamplingProcessingOrder::TemperatureThenFilters
{
self.temperature.encode(
None::<&B::Buffer>,
logits_buffer.deref_mut(),
batch_size as u32,
vocab_size as u32,
temperature,
encoder,
);
let top_k_val = top_k.unwrap_or(0);
if top_k_val > MAX_TOP_K {
return Err(SamplingError::TopKTooLarge(top_k_val));
}

if let Some(top_k) = top_k {
self.topk.encode(
None::<&B::Buffer>,
logits_buffer.deref_mut(),
batch_size as u32,
vocab_size as u32,
top_k,
encoder,
);
}
if let Some(top_p) = top_p {
self.topp.encode(
None::<&B::Buffer>,
logits_buffer.deref_mut(),
batch_size as u32,
vocab_size as u32,
top_p,
encoder,
);
}
if let Some(min_p) = min_p {
self.minp.encode(
None::<&B::Buffer>,
logits_buffer.deref_mut(),
batch_size as u32,
vocab_size as u32,
min_p,
encoder,
);
if let Some(p) = top_p {
if p <= 0.0 || p > 1.0 {
return Err(SamplingError::TopPOutOfRange(p));
}
}

if let Some(temperature) = temperature
&& processing_order == SamplingProcessingOrder::FiltersThenTemperature
{
self.temperature.encode(
None::<&B::Buffer>,
logits_buffer.deref_mut(),
batch_size as u32,
vocab_size as u32,
temperature,
encoder,
);
}
let kernel = if bitmask_buffer.is_some() {
&self.stochastic_masked
} else {
&self.stochastic
};
kernel.encode(
logits_buffer.deref(),
(seeds_buffer, seeds_offset),
sampled_tokens_buffer,
bitmask_buffer.map(|b| (b, bitmask_offset)),
batch_size as u32,
vocab_size as u32,
temperature.unwrap_or(1.0),
top_k_val,
top_p.unwrap_or(1.0),
min_p.unwrap_or(0.0),
encoder,
);
return Ok(());
}

self.gumbel.encode(
// SamplingMethod::Greedy
if let Some(bitmask_buffer) = bitmask_buffer {
self.bitmask.encode(
None::<&B::Buffer>,
(seeds_buffer, seeds_offset),
(bitmask_buffer, bitmask_offset),
logits_buffer.deref_mut(),
batch_size as u32,
vocab_size as u32,
Expand Down
37 changes: 0 additions & 37 deletions crates/uzu/src/backends/cpu/kernel/sampling/gumbel.rs

This file was deleted.

46 changes: 0 additions & 46 deletions crates/uzu/src/backends/cpu/kernel/sampling/min_p.rs

This file was deleted.

6 changes: 1 addition & 5 deletions crates/uzu/src/backends/cpu/kernel/sampling/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
pub mod argmax;
pub mod bitmask;
pub mod gumbel;
pub mod min_p;
pub mod temperature;
pub mod top_k;
pub mod top_p;
pub mod stochastic;
Loading
Loading