-
Notifications
You must be signed in to change notification settings - Fork 49
Unified stochastic kernel #259
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,12 +11,9 @@ use crate::{ | |
| backends::common::{ | ||
| Backend, Context, Encoder, Kernels, | ||
| gpu_types::ArgmaxPair, | ||
| kernel::{ | ||
| ArgmaxFinalKernel, ArgmaxMainKernel, ArgmaxSingleKernel, BitmaskKernel, GumbelKernel, MinPKernel, | ||
| TemperatureKernel, TopKKernel, TopPKernel, | ||
| }, | ||
| kernel::{ArgmaxFinalKernel, ArgmaxMainKernel, ArgmaxSingleKernel, BitmaskKernel, StochasticKernel}, | ||
| }, | ||
| session::parameter::{SamplingMethod, SamplingProcessingOrder}, | ||
| session::parameter::SamplingMethod, | ||
| }; | ||
|
|
||
| #[derive(Debug, Clone, Copy, PartialEq)] | ||
|
|
@@ -36,14 +33,14 @@ enum ArgmaxImplementation<B: Backend> { | |
| }, | ||
| } | ||
|
|
||
| // StochasticKernel retains N_CANDIDATES=64 candidates; top_k must not exceed this. | ||
| const MAX_TOP_K: u32 = 64; | ||
|
|
||
| pub struct SamplingKernel<B: Backend> { | ||
| bitmask: <B::Kernels as Kernels>::BitmaskKernel, | ||
| temperature: <B::Kernels as Kernels>::TemperatureKernel, | ||
| topk: <B::Kernels as Kernels>::TopKKernel, | ||
| topp: <B::Kernels as Kernels>::TopPKernel, | ||
| minp: <B::Kernels as Kernels>::MinPKernel, | ||
| gumbel: <B::Kernels as Kernels>::GumbelKernel, | ||
| argmax_implementation: ArgmaxImplementation<B>, | ||
|
Comment on lines
40
to
41
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Are there any cases where separate kernels are faster? If no, unified stochastic should replace all of the old kernels, not just be a new option |
||
| stochastic: <B::Kernels as Kernels>::StochasticKernel, | ||
| stochastic_masked: <B::Kernels as Kernels>::StochasticKernel, | ||
| max_batch_size: usize, | ||
| max_vocab_size: usize, | ||
| } | ||
|
|
@@ -58,6 +55,10 @@ pub enum SamplingError<B: Backend> { | |
| BatchSizeExceeded(usize, usize), | ||
| #[error("Vocab size {0} exceeds maximum {1}")] | ||
| VocabSizeExceeded(usize, usize), | ||
| #[error("Stochastic: top_k={0} exceeds N_CANDIDATES={MAX_TOP_K}")] | ||
| TopKTooLarge(u32), | ||
|
Comment on lines
+58
to
+59
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We should probably still support top_k > 64 via some fallback. |
||
| #[error("Stochastic: top_p={0} is not in (0, 1]")] | ||
| TopPOutOfRange(f32), | ||
| } | ||
|
|
||
| impl<B: Backend> SamplingKernel<B> { | ||
|
|
@@ -79,16 +80,6 @@ impl<B: Backend> SamplingKernel<B> { | |
| ) -> Result<Self, SamplingError<B>> { | ||
| let bitmask = <B::Kernels as Kernels>::BitmaskKernel::new(context, data_type, true) | ||
| .map_err(SamplingError::BackendError)?; | ||
| let temperature = <B::Kernels as Kernels>::TemperatureKernel::new(context, data_type, true) | ||
| .map_err(SamplingError::BackendError)?; | ||
| let topk = | ||
| <B::Kernels as Kernels>::TopKKernel::new(context, data_type, true).map_err(SamplingError::BackendError)?; | ||
| let topp = | ||
| <B::Kernels as Kernels>::TopPKernel::new(context, data_type, true).map_err(SamplingError::BackendError)?; | ||
| let minp = | ||
| <B::Kernels as Kernels>::MinPKernel::new(context, data_type, true).map_err(SamplingError::BackendError)?; | ||
| let gumbel = <B::Kernels as Kernels>::GumbelKernel::new(context, data_type, true) | ||
| .map_err(SamplingError::BackendError)?; | ||
|
|
||
| let argmax_implementation = match argmax_strategy { | ||
| ArgmaxStrategy::SinglePass => { | ||
|
|
@@ -126,14 +117,16 @@ impl<B: Backend> SamplingKernel<B> { | |
| }, | ||
| }; | ||
|
|
||
| let stochastic = <B::Kernels as Kernels>::StochasticKernel::new(context, data_type, false) | ||
| .map_err(SamplingError::BackendError)?; | ||
| let stochastic_masked = <B::Kernels as Kernels>::StochasticKernel::new(context, data_type, true) | ||
| .map_err(SamplingError::BackendError)?; | ||
|
|
||
| Ok(Self { | ||
| bitmask, | ||
| temperature, | ||
| topk, | ||
| topp, | ||
| minp, | ||
| gumbel, | ||
| argmax_implementation, | ||
| stochastic, | ||
| stochastic_masked, | ||
| max_batch_size, | ||
| max_vocab_size, | ||
| }) | ||
|
|
@@ -159,85 +152,50 @@ impl<B: Backend> SamplingKernel<B> { | |
| return Err(SamplingError::VocabSizeExceeded(vocab_size, self.max_vocab_size)); | ||
| } | ||
|
|
||
| if let Some(bitmask_buffer) = bitmask_buffer { | ||
| self.bitmask.encode( | ||
| None::<&B::Buffer>, | ||
| (bitmask_buffer, bitmask_offset), | ||
| logits_buffer.deref_mut(), | ||
| batch_size as u32, | ||
| vocab_size as u32, | ||
| encoder, | ||
| ); | ||
| } | ||
|
|
||
| if let SamplingMethod::Stochastic { | ||
| temperature, | ||
| top_k, | ||
| top_p, | ||
| min_p, | ||
| processing_order, | ||
| .. | ||
|
Comment on lines
-178
to
+160
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You're silently ignoring processing order |
||
| } = sampling_method | ||
| { | ||
| if let Some(temperature) = temperature | ||
| && processing_order == SamplingProcessingOrder::TemperatureThenFilters | ||
| { | ||
| self.temperature.encode( | ||
| None::<&B::Buffer>, | ||
| logits_buffer.deref_mut(), | ||
| batch_size as u32, | ||
| vocab_size as u32, | ||
| temperature, | ||
| encoder, | ||
| ); | ||
| let top_k_val = top_k.unwrap_or(0); | ||
| if top_k_val > MAX_TOP_K { | ||
| return Err(SamplingError::TopKTooLarge(top_k_val)); | ||
| } | ||
|
|
||
| if let Some(top_k) = top_k { | ||
| self.topk.encode( | ||
| None::<&B::Buffer>, | ||
| logits_buffer.deref_mut(), | ||
| batch_size as u32, | ||
| vocab_size as u32, | ||
| top_k, | ||
| encoder, | ||
| ); | ||
| } | ||
| if let Some(top_p) = top_p { | ||
| self.topp.encode( | ||
| None::<&B::Buffer>, | ||
| logits_buffer.deref_mut(), | ||
| batch_size as u32, | ||
| vocab_size as u32, | ||
| top_p, | ||
| encoder, | ||
| ); | ||
| } | ||
| if let Some(min_p) = min_p { | ||
| self.minp.encode( | ||
| None::<&B::Buffer>, | ||
| logits_buffer.deref_mut(), | ||
| batch_size as u32, | ||
| vocab_size as u32, | ||
| min_p, | ||
| encoder, | ||
| ); | ||
| if let Some(p) = top_p { | ||
| if p <= 0.0 || p > 1.0 { | ||
| return Err(SamplingError::TopPOutOfRange(p)); | ||
| } | ||
| } | ||
|
|
||
| if let Some(temperature) = temperature | ||
| && processing_order == SamplingProcessingOrder::FiltersThenTemperature | ||
| { | ||
| self.temperature.encode( | ||
| None::<&B::Buffer>, | ||
| logits_buffer.deref_mut(), | ||
| batch_size as u32, | ||
| vocab_size as u32, | ||
| temperature, | ||
| encoder, | ||
| ); | ||
| } | ||
| let kernel = if bitmask_buffer.is_some() { | ||
| &self.stochastic_masked | ||
| } else { | ||
| &self.stochastic | ||
| }; | ||
| kernel.encode( | ||
| logits_buffer.deref(), | ||
| (seeds_buffer, seeds_offset), | ||
| sampled_tokens_buffer, | ||
| bitmask_buffer.map(|b| (b, bitmask_offset)), | ||
| batch_size as u32, | ||
| vocab_size as u32, | ||
| temperature.unwrap_or(1.0), | ||
| top_k_val, | ||
| top_p.unwrap_or(1.0), | ||
| min_p.unwrap_or(0.0), | ||
| encoder, | ||
| ); | ||
| return Ok(()); | ||
| } | ||
|
|
||
| self.gumbel.encode( | ||
| // SamplingMethod::Greedy | ||
| if let Some(bitmask_buffer) = bitmask_buffer { | ||
| self.bitmask.encode( | ||
| None::<&B::Buffer>, | ||
| (seeds_buffer, seeds_offset), | ||
| (bitmask_buffer, bitmask_offset), | ||
| logits_buffer.deref_mut(), | ||
| batch_size as u32, | ||
| vocab_size as u32, | ||
|
|
||
This file was deleted.
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,3 @@ | ||
| pub mod argmax; | ||
| pub mod bitmask; | ||
| pub mod gumbel; | ||
| pub mod min_p; | ||
| pub mod temperature; | ||
| pub mod top_k; | ||
| pub mod top_p; | ||
| pub mod stochastic; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add bitmask to argmax instead of having a separate kernel only for argmax