Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,8 @@ final class CoreAIPipelinedEngine: InferenceEngine, Sendable {
}

func validateSamplingStrategy(_ config: SamplingConfiguration) throws {
guard config.temperature > 0 else { return }
if config.topP != nil {
throw InferenceRuntimeError.invalidArgument(
"CoreAI pipelined GPU sampler does not support topP. "
+ "Only greedy (temperature=0) and temperature+topK are supported."
)
}
// All sampling configurations are now supported by the GPU sampler:
// greedy, temperature, topK, topP, and minP.
}

func warmup(queryLength: Int, sampling: SamplingConfiguration?) async throws {
Expand Down Expand Up @@ -534,7 +529,7 @@ private struct EngineImpl: ~Copyable {
let newSampler = try MPSGraphSamplerFactory.makeSampler(
device: device,
vocabSize: self.config.vocabSize,
temperature: temperature
config: config
)
cachedSampler = newSampler
cachedSamplerTemperature = temperature
Expand Down Expand Up @@ -677,7 +672,7 @@ private struct EngineImpl: ~Copyable {
let localGPUSampler = gpuSampler
let outputBuffer = inputTokensBuffer
let logitsOffset = (actualTokenCount - 1) * vocabSize * MemoryLayout<UInt16>.size
let samplerStrategy = gpuSampler is MPSGraphArgmaxSampler ? "GPU-argmax" : "GPU-topK"
let samplerStrategy = gpuSampler is MPSGraphArgmaxSampler ? "GPU-argmax" : "GPU-composite"
let samplerTemperature = cachedSamplerTemperature ?? 0.0

let sampleSpan = InstrumentsProfiler.beginSampleEncoding(
Expand Down Expand Up @@ -861,7 +856,7 @@ private struct EngineImpl: ~Copyable {
// queue FIFO ordering via MTLDispatchListApply), guaranteeing every
// continuation.yield has returned before the caller calls finish().
// We use a bare command buffer instead of the sampler to avoid the shared
// MPSGraphExecutableExecutionDescriptor issue in MPSGraphTopKSampler.
// MPSGraphExecutableExecutionDescriptor issue in MPSGraphCompositeSampler.
await withCheckedContinuation { (sentinelCont: CheckedContinuation<Void, Never>) in
do {
let queue = pipelineQueue
Expand Down
52 changes: 36 additions & 16 deletions swift/Sources/CoreAILanguageModels/Samplers/CompositeSampler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import Accelerate
import Foundation
import os.signpost

/// Unified CPU sampler supporting greedy, temperature, topK, and topP sampling.
/// Unified CPU sampler supporting greedy, temperature, topK, topP, and minP sampling.
///
/// This sampler uses Float32 internally for all non-greedy sampling to ensure
/// numerical precision. The Float16 → Float32 conversion happens once at entry,
Expand All @@ -21,11 +21,12 @@ import os.signpost
/// - **Temperature only**: Scale logits, softmax, multinomial sample
/// - **TopK**: Keep K highest probability tokens, sample from them
/// - **TopP (nucleus)**: Keep tokens until cumulative probability >= P
/// - **Combined**: Apply topP first (broader filter), then topK (hard limit)
/// - **MinP**: Keep tokens whose probability >= minP × max probability
/// - **Combined**: Take the topK window first (hard limit, when set), then apply minP, then topP within the remaining candidates
///
/// ## Algorithm Order
/// ```
/// logits → [temperature scaling] → [topP filter] → [topK filter] → [softmax] → [sample]
/// logits → [temperature scaling] → [topK window] → [minP filter] → [topP filter] → [softmax] → [sample]
/// ```
public struct CompositeSampler {
// MARK: - Public API (Float16 input - backward compatible)
Expand Down Expand Up @@ -158,8 +159,8 @@ public struct CompositeSampler {

/// Full sampling pipeline in Float32. Two paths:
///
/// **Fast path** — no topK/topP: vectorized softmax + multinomial over the full vocab.
/// **Slow path** — topK and/or topP set: identify the active subset (typically 5-500 tokens),
/// **Fast path** — no topK/topP/minP: vectorized softmax + multinomial over the full vocab.
/// **Slow path** — topK, topP, and/or minP set: identify the active subset (typically 5-500 tokens),
/// compact those logits into a small array, softmax + sample over just those K. Skips the
/// V-sized softmax and the V-sized multinomial scan entirely.
///
Expand All @@ -181,8 +182,9 @@ public struct CompositeSampler {

let needsTopP = (config.topP.map { $0 < 1.0 } ?? false)
let needsTopK = (config.topK != nil)
let needsMinP = (config.minP != nil)

if !needsTopP && !needsTopK {
if !needsTopP && !needsTopK && !needsMinP {
// Fast path: vectorized softmax + multinomial over the full vocab.
softmaxVectorized(&logits)
let token = Int32(multinomialSample(logits, using: &rng))
Expand All @@ -194,7 +196,8 @@ public struct CompositeSampler {
let activeIndices = selectActiveIndices(
logits: logits,
topK: config.topK,
topP: config.topP.map { Float($0) })
topP: config.topP.map { Float($0) },
minP: config.minP.map { Float($0) })

// Degenerate: empty active set (e.g. all logits -.infinity). Fall back to index 0.
guard !activeIndices.isEmpty else {
Expand Down Expand Up @@ -249,22 +252,26 @@ public struct CompositeSampler {

// MARK: - Active-subset selection (slow path)

/// Identify the indices kept after topK / topP filtering, sorted descending by logit.
/// Identify the indices kept after minP / topK / topP filtering, sorted descending by logit.
///
/// When `topK` is set: partial sort via min-heap of size K (O(V log K)) — avoids the
/// O(V log V) full sort and the V-sized [Int] allocation.
///
/// When only `topP` is set: full sort (variable cutoff means a partial-sort window
/// When only `topP` or `minP` is set: full sort (variable cutoff means a partial-sort window
/// would need a fallback path; deferred). Still wins over the previous code by
/// avoiding the downstream V-sized softmax + multinomial scan.
///
/// `minP` is applied first in logit space: tokens with logit < max_logit + log(minP)
/// are excluded. This is equivalent to excluding tokens with probability < minP × max_prob.
///
/// `topP` is applied within the partial-sort window. When topK is also set, this is
/// numerically equivalent to global topP because top-K already captures essentially
/// all probability mass for realistic distributions.
private static func selectActiveIndices(
logits: [Float],
topK: Int?,
topP: Float?
topP: Float?,
minP: Float?
) -> [Int] {
let vocabSize = logits.count

Expand All @@ -275,34 +282,44 @@ public struct CompositeSampler {
sortedIndices = (0..<vocabSize).sorted { logits[$0] > logits[$1] }
}

guard let p = topP else { return sortedIndices }
guard !sortedIndices.isEmpty else { return [] }

let maxLogit = logits[sortedIndices[0]]
// All-masked degenerate input.
guard maxLogit > -.infinity else { return [] }

// Apply minP in logit space: keep tokens where logit >= maxLogit + log(minP).
// This is equivalent to P(token) >= minP * P(best) in probability space.
var filtered = sortedIndices
if let m = minP {
let threshold = maxLogit + logf(m)
filtered = filtered.filter { logits[$0] >= threshold }
if filtered.isEmpty { return [sortedIndices[0]] }
}

guard let p = topP else { return filtered }

// Compute exp(logit - max) per candidate (only K entries, not V).
var expValues = [Float](repeating: 0, count: sortedIndices.count)
var expValues = [Float](repeating: 0, count: filtered.count)
var sumExp: Float = 0
for (i, idx) in sortedIndices.enumerated() {
for (i, idx) in filtered.enumerated() {
let e = expf(logits[idx] - maxLogit)
expValues[i] = e
sumExp += e
}
let invSumExp = 1.0 / max(sumExp, .leastNormalMagnitude)

var cumProb: Float = 0
var cutoff = sortedIndices.count
for i in 0..<sortedIndices.count {
var cutoff = filtered.count
for i in 0..<filtered.count {
cumProb += expValues[i] * invSumExp
if cumProb >= p {
cutoff = i + 1
break
}
}

return Array(sortedIndices.prefix(cutoff))
return Array(filtered.prefix(cutoff))
}

/// Returns the K indices with the largest logit values, sorted descending by logit.
Expand Down Expand Up @@ -394,6 +411,9 @@ public struct CompositeSampler {
if config.temperature > 0 {
parts.append("temp")
}
if config.minP != nil {
parts.append("minP")
}
if config.topP != nil {
parts.append("topP")
}
Expand Down
Loading