diff --git a/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAIPipelinedEngine.swift b/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAIPipelinedEngine.swift index 4ddc8f2..6f63b8e 100644 --- a/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAIPipelinedEngine.swift +++ b/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAIPipelinedEngine.swift @@ -213,13 +213,8 @@ final class CoreAIPipelinedEngine: InferenceEngine, Sendable { } func validateSamplingStrategy(_ config: SamplingConfiguration) throws { - guard config.temperature > 0 else { return } - if config.topP != nil { - throw InferenceRuntimeError.invalidArgument( - "CoreAI pipelined GPU sampler does not support topP. " - + "Only greedy (temperature=0) and temperature+topK are supported." - ) - } + // All sampling configurations are now supported by the GPU sampler: + // greedy, temperature, topK, topP, and minP. } func warmup(queryLength: Int, sampling: SamplingConfiguration?) async throws { @@ -534,7 +529,7 @@ private struct EngineImpl: ~Copyable { let newSampler = try MPSGraphSamplerFactory.makeSampler( device: device, vocabSize: self.config.vocabSize, - temperature: temperature + config: config ) cachedSampler = newSampler cachedSamplerTemperature = temperature @@ -677,7 +672,7 @@ private struct EngineImpl: ~Copyable { let localGPUSampler = gpuSampler let outputBuffer = inputTokensBuffer let logitsOffset = (actualTokenCount - 1) * vocabSize * MemoryLayout.size - let samplerStrategy = gpuSampler is MPSGraphArgmaxSampler ? "GPU-argmax" : "GPU-topK" + let samplerStrategy = gpuSampler is MPSGraphArgmaxSampler ? "GPU-argmax" : "GPU-composite" let samplerTemperature = cachedSamplerTemperature ?? 0.0 let sampleSpan = InstrumentsProfiler.beginSampleEncoding( @@ -861,7 +856,7 @@ private struct EngineImpl: ~Copyable { // queue FIFO ordering via MTLDispatchListApply), guaranteeing every // continuation.yield has returned before the caller calls finish(). // We use a bare command buffer instead of the sampler to avoid the shared - // MPSGraphExecutableExecutionDescriptor issue in MPSGraphTopKSampler. + // MPSGraphExecutableExecutionDescriptor issue in MPSGraphCompositeSampler. await withCheckedContinuation { (sentinelCont: CheckedContinuation) in do { let queue = pipelineQueue diff --git a/swift/Sources/CoreAILanguageModels/Samplers/CompositeSampler.swift b/swift/Sources/CoreAILanguageModels/Samplers/CompositeSampler.swift index 9407bbc..5a8dac6 100644 --- a/swift/Sources/CoreAILanguageModels/Samplers/CompositeSampler.swift +++ b/swift/Sources/CoreAILanguageModels/Samplers/CompositeSampler.swift @@ -7,7 +7,7 @@ import Accelerate import Foundation import os.signpost -/// Unified CPU sampler supporting greedy, temperature, topK, and topP sampling. +/// Unified CPU sampler supporting greedy, temperature, topK, topP, and minP sampling. /// /// This sampler uses Float32 internally for all non-greedy sampling to ensure /// numerical precision. The Float16 → Float32 conversion happens once at entry, @@ -21,11 +21,12 @@ import os.signpost /// - **Temperature only**: Scale logits, softmax, multinomial sample /// - **TopK**: Keep K highest probability tokens, sample from them /// - **TopP (nucleus)**: Keep tokens until cumulative probability >= P -/// - **Combined**: Apply topP first (broader filter), then topK (hard limit) +/// - **MinP**: Keep tokens whose probability >= minP × max probability +/// - **Combined**: Apply minP first, then topP (broader filter), then topK (hard limit) /// /// ## Algorithm Order /// ``` -/// logits → [temperature scaling] → [topP filter] → [topK filter] → [softmax] → [sample] +/// logits → [temperature scaling] → [minP filter] → [topP filter] → [topK filter] → [softmax] → [sample] /// ``` public struct CompositeSampler { // MARK: - Public API (Float16 input - backward compatible) @@ -158,8 +159,8 @@ public struct CompositeSampler { /// Full sampling pipeline in Float32. Two paths: /// - /// **Fast path** — no topK/topP: vectorized softmax + multinomial over the full vocab. - /// **Slow path** — topK and/or topP set: identify the active subset (typically 5-500 tokens), + /// **Fast path** — no topK/topP/minP: vectorized softmax + multinomial over the full vocab. + /// **Slow path** — topK, topP, and/or minP set: identify the active subset (typically 5-500 tokens), /// compact those logits into a small array, softmax + sample over just those K. Skips the /// V-sized softmax and the V-sized multinomial scan entirely. /// @@ -181,8 +182,9 @@ public struct CompositeSampler { let needsTopP = (config.topP.map { $0 < 1.0 } ?? false) let needsTopK = (config.topK != nil) + let needsMinP = (config.minP != nil) - if !needsTopP && !needsTopK { + if !needsTopP && !needsTopK && !needsMinP { // Fast path: vectorized softmax + multinomial over the full vocab. softmaxVectorized(&logits) let token = Int32(multinomialSample(logits, using: &rng)) @@ -194,7 +196,8 @@ public struct CompositeSampler { let activeIndices = selectActiveIndices( logits: logits, topK: config.topK, - topP: config.topP.map { Float($0) }) + topP: config.topP.map { Float($0) }, + minP: config.minP.map { Float($0) }) // Degenerate: empty active set (e.g. all logits -.infinity). Fall back to index 0. guard !activeIndices.isEmpty else { @@ -249,22 +252,26 @@ public struct CompositeSampler { // MARK: - Active-subset selection (slow path) - /// Identify the indices kept after topK / topP filtering, sorted descending by logit. + /// Identify the indices kept after minP / topK / topP filtering, sorted descending by logit. /// /// When `topK` is set: partial sort via min-heap of size K (O(V log K)) — avoids the /// O(V log V) full sort and the V-sized [Int] allocation. /// - /// When only `topP` is set: full sort (variable cutoff means a partial-sort window + /// When only `topP` or `minP` is set: full sort (variable cutoff means a partial-sort window /// would need a fallback path; deferred). Still wins over the previous code by /// avoiding the downstream V-sized softmax + multinomial scan. /// + /// `minP` is applied first in logit space: tokens with logit < max_logit + log(minP) + /// are excluded. This is equivalent to excluding tokens with probability < minP × max_prob. + /// /// `topP` is applied within the partial-sort window. When topK is also set, this is /// numerically equivalent to global topP because top-K already captures essentially /// all probability mass for realistic distributions. private static func selectActiveIndices( logits: [Float], topK: Int?, - topP: Float? + topP: Float?, + minP: Float? ) -> [Int] { let vocabSize = logits.count @@ -275,17 +282,27 @@ public struct CompositeSampler { sortedIndices = (0.. logits[$1] } } - guard let p = topP else { return sortedIndices } guard !sortedIndices.isEmpty else { return [] } let maxLogit = logits[sortedIndices[0]] // All-masked degenerate input. guard maxLogit > -.infinity else { return [] } + // Apply minP in logit space: keep tokens where logit >= maxLogit + log(minP). + // This is equivalent to P(token) >= minP * P(best) in probability space. + var filtered = sortedIndices + if let m = minP { + let threshold = maxLogit + logf(m) + filtered = filtered.filter { logits[$0] >= threshold } + if filtered.isEmpty { return [sortedIndices[0]] } + } + + guard let p = topP else { return filtered } + // Compute exp(logit - max) per candidate (only K entries, not V). - var expValues = [Float](repeating: 0, count: sortedIndices.count) + var expValues = [Float](repeating: 0, count: filtered.count) var sumExp: Float = 0 - for (i, idx) in sortedIndices.enumerated() { + for (i, idx) in filtered.enumerated() { let e = expf(logits[idx] - maxLogit) expValues[i] = e sumExp += e @@ -293,8 +310,8 @@ public struct CompositeSampler { let invSumExp = 1.0 / max(sumExp, .leastNormalMagnitude) var cumProb: Float = 0 - var cutoff = sortedIndices.count - for i in 0..= p { cutoff = i + 1 @@ -302,7 +319,7 @@ public struct CompositeSampler { } } - return Array(sortedIndices.prefix(cutoff)) + return Array(filtered.prefix(cutoff)) } /// Returns the K indices with the largest logit values, sorted descending by logit. @@ -394,6 +411,9 @@ public struct CompositeSampler { if config.temperature > 0 { parts.append("temp") } + if config.minP != nil { + parts.append("minP") + } if config.topP != nil { parts.append("topP") } diff --git a/swift/Sources/CoreAILanguageModels/Samplers/MPSGraphSamplers.swift b/swift/Sources/CoreAILanguageModels/Samplers/MPSGraphSamplers.swift index 4f469d6..041c83c 100644 --- a/swift/Sources/CoreAILanguageModels/Samplers/MPSGraphSamplers.swift +++ b/swift/Sources/CoreAILanguageModels/Samplers/MPSGraphSamplers.swift @@ -16,18 +16,17 @@ import MetalPerformanceShadersGraph // ## Design Decisions // // ### Protocol-Based Architecture -// Both argmax (greedy) and TopK (probabilistic) samplers conform to the +// Both argmax (greedy) and composite (probabilistic) samplers conform to the // `MPSGraphSampler` protocol, enabling runtime selection based on temperature: // - temperature == 0: Argmax sampler (deterministic, fastest) -// - temperature > 0: TopK sampler (probabilistic, more creative) +// - temperature > 0: Composite sampler (probabilistic with topK/topP/minP) // // The factory pattern (`MPSGraphSamplerFactory`) selects the appropriate sampler // once at generation start, with the sampler cached for the entire generation. // // ### Fixed Vocab Size at Compile Time -// Unlike MPSGraphInferenceEngine which uses dynamic shape `[1, -1]` for the -// vocab dimension, these samplers fix the vocab size at compile time. This -// enables better MPSGraph optimization and eliminates runtime shape inference. +// These samplers fix the vocab size at compile time. This enables better +// MPSGraph optimization and eliminates runtime shape inference. // // ### Temperature at Init (Immutable) // Temperature is baked into the TopK sampler at initialization rather than @@ -39,25 +38,6 @@ import MetalPerformanceShadersGraph // extracting the last token's logits using a blit encoder before sampling. // This is critical for efficient prefill where we only need to sample from // the final position. -// -// ### Comparison with MPSGraphInferenceEngine -// | Feature | MPSGraphInferenceEngine | Core AI Samplers | -// |---------------------|-------------------------|----------------------| -// | Sampling types | Argmax only | Argmax + TopK | -// | Vocab shape | Dynamic [1, -1] | Fixed [1, vocabSize] | -// | Temperature | N/A (greedy only) | At init (immutable) | -// | Slice handling | N/A | Blit + encode | -// | Testing hooks | None | testingOnlyRandomOverride | -// | Buffer allocation | Per-call | Pre-allocated | -// -// ### Why Not Use MPSGraphInferenceEngine's Sampler? -// 1. Core AI needs TopK sampling with temperature for creative generation -// 2. Core AI uses ComputeStream's Metal3 queue (withMetal3Queue), not direct -// command queue - we need sampler methods that take a queue parameter -// 3. CoreAI.s pipelined architecture requires completion handlers for yielding -// tokens without blocking the main inference loop -// 4. Fixed vocab size enables better graph optimization for large vocabs -// (150K+ for Qwen models) // MARK: - MPSGraph Sampler Protocol @@ -120,22 +100,47 @@ enum MPSGraphSamplerFactory { /// /// Selection logic: /// - temperature == 0: Returns argmax sampler (greedy, deterministic) - /// - temperature > 0: Returns TopK sampler (probabilistic) + /// - temperature > 0: Returns composite sampler (topK + topP + minP) static func makeSampler( device: MTLDevice, vocabSize: Int, - temperature: Double + config: SamplingConfiguration ) throws -> any MPSGraphSampler { - if temperature == 0 { + if config.temperature == 0 { return try MPSGraphArgmaxSampler(device: device, vocabSize: vocabSize) + } + + // Determine effective K for the topK operation: + // - If topK is explicitly set, use it + // - If only topP or minP is set, use a generous window (1000) + // - Default (just temperature): use 40 + let effectiveK: Int + if let k = config.topK { + effectiveK = k + } else if config.topP != nil || config.minP != nil { + effectiveK = min(1000, vocabSize) } else { - return try MPSGraphTopKSampler( - device: device, - vocabSize: vocabSize, - k: 40, - temperature: Float(temperature) - ) + effectiveK = 40 } + + return try MPSGraphCompositeSampler( + device: device, + vocabSize: vocabSize, + k: effectiveK, + temperature: Float(config.temperature), + topP: config.topP.map { Float($0) } ?? 1.0, + minP: config.minP.map { Float($0) } ?? 0.0 + ) + } + + /// Legacy convenience for temperature-only creation (used by tests). + static func makeSampler( + device: MTLDevice, + vocabSize: Int, + temperature: Double + ) throws -> any MPSGraphSampler { + let config = SamplingConfiguration(temperature: temperature) + return try makeSampler(device: device, vocabSize: vocabSize, config: config) } } @@ -193,7 +198,6 @@ final class MPSGraphArgmaxSampler: @unchecked Sendable { self.graph = graph // Input: logits for a single token position [1, vocabSize] as Float16 - // Match MPSGraphInferenceEngine pattern: [1, vocabSize] with axis 1 reduction let inputPlaceholder = graph.placeholder( shape: [1, vocabSize as NSNumber], dataType: .float16, @@ -424,34 +428,24 @@ extension MPSGraphArgmaxSampler: MPSGraphSampler {} // MARK: - MPSGraph Top-K Sampler -/// MPSGraph-based Top-K sampler with temperature scaling. +/// MPSGraph-based composite sampler with temperature, TopK, TopP, and MinP. /// /// This sampler uses Apple's optimized `topK` operation combined with softmax -/// for probabilistic token sampling. Unlike greedy argmax, this enables: +/// for probabilistic token sampling. Supports: /// - Temperature-controlled randomness /// - Top-K filtering for quality/diversity tradeoff +/// - Top-P (nucleus) filtering for adaptive vocabulary +/// - Min-P filtering for relative probability thresholding /// /// ## Sampling Algorithm -/// 1. Extract Top-K logits and indices +/// 1. Extract Top-K logits and indices from full vocab /// 2. Apply temperature scaling: logits / temperature /// 3. Apply softmax to get probabilities -/// 4. Sample using multinomial (cumsum + random comparison) -/// -/// ## Usage with Core AI's ComputeStream -/// ```swift -/// computeStream.withMetal3Queue { queue in -/// topKSampler.encode( -/// to: queue, -/// logitsBuffer: logitsBuffer, -/// temperature: 0.7, -/// outputBuffer: tokenBuffer, -/// completion: { token in -/// continuation.yield(token) -/// } -/// ) -/// } -/// ``` -final class MPSGraphTopKSampler: @unchecked Sendable { +/// 4. Apply MinP filter: keep probs >= minP × max_prob +/// 5. Apply TopP filter: keep probs where exclusive cumsum < topP +/// 6. Re-normalize masked probabilities +/// 7. Sample using multinomial (cumsum + random comparison) +final class MPSGraphCompositeSampler: @unchecked Sendable { private let device: MTLDevice private let mpsDevice: MPSGraphDevice private let graph: MPSGraph @@ -460,6 +454,8 @@ final class MPSGraphTopKSampler: @unchecked Sendable { private let logitsPlaceholder: MPSGraphTensor private let temperaturePlaceholder: MPSGraphTensor private let randomPlaceholder: MPSGraphTensor + private let topPPlaceholder: MPSGraphTensor + private let minPPlaceholder: MPSGraphTensor private let outputTensor: MPSGraphTensor private let executable: MPSGraphExecutable @@ -473,12 +469,24 @@ final class MPSGraphTopKSampler: @unchecked Sendable { /// The temperature this sampler was configured with let temperature: Float + /// The topP value (1.0 = disabled) + let topP: Float + + /// The minP value (0.0 = disabled) + let minP: Float + /// Pre-allocated buffer for random value private let randomBuffer: MTLBuffer /// Pre-allocated buffer for temperature private let temperatureBuffer: MTLBuffer + /// Pre-allocated buffer for topP value + private let topPBuffer: MTLBuffer + + /// Pre-allocated buffer for minP value + private let minPBuffer: MTLBuffer + // Pre-allocated objects reused every step to avoid CPU object creation overhead. private var cachedLogitsData: MPSGraphTensorData? private var cachedOutputData: MPSGraphTensorData? @@ -486,36 +494,46 @@ final class MPSGraphTopKSampler: @unchecked Sendable { private var cachedOutputBuffer: MTLBuffer? private let temperatureData: MPSGraphTensorData private let randomData: MPSGraphTensorData + private let topPData: MPSGraphTensorData + private let minPData: MPSGraphTensorData private let execDescriptor: MPSGraphExecutableExecutionDescriptor /// Testing only: Override random value for deterministic tests. - /// When set, this value is used instead of generating a random number. - /// Set to nil for production use. var testingOnlyRandomOverride: Float? - /// Initialize the MPSGraph Top-K sampler. + /// Initialize the MPSGraph composite sampler. /// - Parameters: /// - device: Metal device /// - vocabSize: Vocabulary size (fixed for compilation) - /// - k: Number of top tokens to sample from (default: 40) - /// - temperature: Sampling temperature (default: 1.0) - init(device: MTLDevice, vocabSize: Int, k: Int = 40, temperature: Float = 1.0) throws { + /// - k: Number of top tokens to consider + /// - temperature: Sampling temperature + /// - topP: Nucleus sampling threshold (1.0 = disabled) + /// - minP: Minimum probability threshold (0.0 = disabled) + init(device: MTLDevice, vocabSize: Int, k: Int = 40, temperature: Float = 1.0, topP: Float = 1.0, minP: Float = 0.0) + throws + { self.device = device self.mpsDevice = MPSGraphDevice(mtlDevice: device) self.vocabSize = vocabSize self.k = k + self.temperature = temperature + self.topP = topP + self.minP = minP // Pre-allocate buffers guard let randomBuffer = device.makeBuffer(length: MemoryLayout.size, options: .storageModeShared), - let temperatureBuffer = device.makeBuffer(length: MemoryLayout.size, options: .storageModeShared) + let temperatureBuffer = device.makeBuffer(length: MemoryLayout.size, options: .storageModeShared), + let topPBuffer = device.makeBuffer(length: MemoryLayout.size, options: .storageModeShared), + let minPBuffer = device.makeBuffer(length: MemoryLayout.size, options: .storageModeShared) else { throw MPSGraphSamplerError.bufferAllocationFailed } - self.temperature = temperature self.randomBuffer = randomBuffer self.temperatureBuffer = temperatureBuffer + self.topPBuffer = topPBuffer + self.minPBuffer = minPBuffer - // Build the Top-K sampling graph + // Build the composite sampling graph let graph = MPSGraph() self.graph = graph @@ -543,45 +561,74 @@ final class MPSGraphTopKSampler: @unchecked Sendable { ) self.randomPlaceholder = randomPlaceholder + // TopP threshold [1] + let topPPlaceholder = graph.placeholder( + shape: [1 as NSNumber], + dataType: .float32, + name: "topP" + ) + self.topPPlaceholder = topPPlaceholder + + // MinP threshold [1] + let minPPlaceholder = graph.placeholder( + shape: [1 as NSNumber], + dataType: .float32, + name: "minP" + ) + self.minPPlaceholder = minPPlaceholder + // Cast logits to Float32 for numerical stability let logitsFloat32 = graph.cast(logitsPlaceholder, to: .float32, name: "logits_f32") - // Get Top-K values and indices - // topK returns a tuple: (values: [1, k], indices: [1, k]) + // Step 1: Get Top-K values and indices let topKResult = graph.topK(logitsFloat32, k: k, name: "topk") - let topKValues = topKResult[0] // [1, k] + let topKValues = topKResult[0] // [1, k] sorted descending let topKIndices = topKResult[1] // [1, k] as Int32 - // Apply temperature: values / temperature - // Broadcast temperature to match shape + // Step 2: Apply temperature: values / temperature let scaledValues = graph.division(topKValues, temperaturePlaceholder, name: "scaled") - // Softmax over the K dimension (axis 1) + // Step 3: Softmax over the K dimension (axis 1) let probabilities = graph.softMax(with: scaledValues, axis: 1, name: "probs") - // Multinomial sampling via cumulative sum + random comparison - // cumsum: [1, k] where each element is sum of probs up to that point - let cumsum = graph.cumulativeSum(probabilities, axis: 1, exclusive: false, reverse: false, name: "cumsum") - - // Compare: cumsum >= random (broadcast random across k dimension) - // This gives us a boolean mask where True means "this token or later" - let randomBroadcast = graph.broadcast(randomPlaceholder, shape: [1, k as NSNumber], name: "random_broadcast") - let mask = graph.greaterThanOrEqualTo(cumsum, randomBroadcast, name: "mask") - - // Convert mask to float and use argmax to find first True - let maskFloat = graph.cast(mask, to: .float32, name: "mask_float") - let selectedIdx = graph.reductionArgMaximum(with: maskFloat, axis: 1, name: "selected_idx") - - // Gather the token index from topKIndices using selectedIdx - // selectedIdx is [1] with value 0..k-1 - // We need to index into topKIndices[0, selectedIdx] to get the actual token ID + // Step 4: MinP filtering + // max_prob is the first element (topK returns sorted descending) + let maxProb = graph.sliceTensor(probabilities, dimension: 1, start: 0, length: 1, name: "max_prob") + // threshold = minP * max_prob + let minPThreshold = graph.multiplication(minPPlaceholder, maxProb, name: "minp_threshold") + // mask: probs >= threshold (broadcasts [1,1] to [1,k]) + let minPMask = graph.greaterThanOrEqualTo(probabilities, minPThreshold, name: "minp_mask") + + // Step 5: TopP filtering via exclusive cumulative sum + // exclusive_cumsum[i] = sum of probs[0..i-1], so position 0 always has value 0 + let exclusiveCumsum = graph.cumulativeSum( + probabilities, axis: 1, exclusive: true, reverse: false, name: "excl_cumsum") + // mask: exclusive_cumsum < topP (includes all tokens before cumsum reaches topP) + let topPMask = graph.lessThan(exclusiveCumsum, topPPlaceholder, name: "topp_mask") + + // Step 6: Combined mask = minP AND topP + let combinedMask = graph.logicalAND(minPMask, topPMask, name: "combined_mask") + let maskFloat = graph.cast(combinedMask, to: .float32, name: "mask_float") + + // Step 7: Apply mask and re-normalize + let maskedProbs = graph.multiplication(probabilities, maskFloat, name: "masked_probs") + let sumMasked = graph.reductionSum(with: maskedProbs, axis: 1, name: "sum_masked") + // Avoid division by zero: use max(sum, epsilon) + let epsilon = graph.constant(1e-10, dataType: .float32) + let safeDenominator = graph.maximum(sumMasked, epsilon, name: "safe_denom") + let normalizedProbs = graph.division(maskedProbs, safeDenominator, name: "normalized_probs") + + // Step 8: Multinomial sampling via cumulative sum + random comparison + let cumsum = graph.cumulativeSum(normalizedProbs, axis: 1, exclusive: false, reverse: false, name: "cumsum") + let selectionMask = graph.greaterThanOrEqualTo(cumsum, randomPlaceholder, name: "selection_mask") + let selectionMaskFloat = graph.cast(selectionMask, to: .float32, name: "selection_mask_float") + let selectedIdx = graph.reductionArgMaximum(with: selectionMaskFloat, axis: 1, name: "selected_idx") + + // Step 9: Gather the token index from topKIndices let selectedIdxInt32 = graph.cast(selectedIdx, to: .int32, name: "selected_idx_i32") - - // Flatten topKIndices to [k] and use gatherElements let indicesFlat = graph.reshape(topKIndices, shape: [k as NSNumber], name: "indices_flat") let selectedIdxFlat = graph.reshape(selectedIdxInt32, shape: [1 as NSNumber], name: "selected_flat") - // Gather the actual token ID let outputTensor = graph.gatherAlongAxis( 0, updates: indicesFlat, @@ -595,6 +642,8 @@ final class MPSGraphTopKSampler: @unchecked Sendable { logitsPlaceholder: MPSGraphShapedType(shape: [1, vocabSize as NSNumber], dataType: .float16), temperaturePlaceholder: MPSGraphShapedType(shape: [1 as NSNumber], dataType: .float32), randomPlaceholder: MPSGraphShapedType(shape: [1 as NSNumber], dataType: .float32), + topPPlaceholder: MPSGraphShapedType(shape: [1 as NSNumber], dataType: .float32), + minPPlaceholder: MPSGraphShapedType(shape: [1 as NSNumber], dataType: .float32), ] let compilationDescriptor = MPSGraphCompilationDescriptor() @@ -608,7 +657,7 @@ final class MPSGraphTopKSampler: @unchecked Sendable { compilationDescriptor: compilationDescriptor ) - // Pre-allocate tensor data for temperature and random buffers (never change) + // Pre-allocate tensor data for buffers self.temperatureData = MPSGraphTensorData( temperatureBuffer, shape: [1 as NSNumber], @@ -619,12 +668,20 @@ final class MPSGraphTopKSampler: @unchecked Sendable { shape: [1 as NSNumber], dataType: .float32 ) + self.topPData = MPSGraphTensorData( + topPBuffer, + shape: [1 as NSNumber], + dataType: .float32 + ) + self.minPData = MPSGraphTensorData( + minPBuffer, + shape: [1 as NSNumber], + dataType: .float32 + ) self.execDescriptor = MPSGraphExecutableExecutionDescriptor() } - /// Encode Top-K sampling asynchronously (protocol conformance). - /// - /// Uses the temperature configured at init time. + /// Encode composite sampling asynchronously (protocol conformance). func encode( to queue: MTLCommandQueue, logitsBuffer: MTLBuffer, @@ -633,10 +690,11 @@ final class MPSGraphTopKSampler: @unchecked Sendable { outputOffset: Int, completion: @escaping (Int32) -> Void ) { - // Write temperature to buffer (use configured temperature) + // Write runtime values to buffers temperatureBuffer.contents().assumingMemoryBound(to: Float.self).pointee = max(temperature, 0.01) + topPBuffer.contents().assumingMemoryBound(to: Float.self).pointee = topP + minPBuffer.contents().assumingMemoryBound(to: Float.self).pointee = minP - // Use override if set (for testing), otherwise generate random value [0, 1) let randomValue = testingOnlyRandomOverride ?? Float.random(in: 0..<1) randomBuffer.contents().assumingMemoryBound(to: Float.self).pointee = randomValue @@ -667,10 +725,9 @@ final class MPSGraphTopKSampler: @unchecked Sendable { cachedOutputBuffer = outputBuffer } - // Reuse pre-allocated execution descriptor, update completion handler execDescriptor.completionHandler = { [outputBuffer, outputOffset] (_, error) in if let error = error { - print("MPSGraph Top-K error: \(error)") + print("MPSGraph composite sampler error: \(error)") completion(0) return } @@ -682,18 +739,15 @@ final class MPSGraphTopKSampler: @unchecked Sendable { completion(result) } - // Run async — temperatureData and randomData are pre-allocated, buffer contents updated above executable.runAsync( with: queue, - inputs: [logitsData, temperatureData, randomData], + inputs: [logitsData, temperatureData, randomData, topPData, minPData], results: [outputData], executionDescriptor: execDescriptor ) } - /// Encode Top-K sampling with slice support for prefill scenarios (protocol conformance). - /// - /// Uses the temperature configured at init time. + /// Encode composite sampling with slice support for prefill scenarios. func encodeWithSlice( to queue: MTLCommandQueue, logitsBuffer: MTLBuffer, @@ -702,7 +756,6 @@ final class MPSGraphTopKSampler: @unchecked Sendable { outputOffset: Int, completion: @escaping (Int32) -> Void ) { - // For single-token decode, use direct encoding if queryLength == 1 { encode( to: queue, @@ -715,25 +768,19 @@ final class MPSGraphTopKSampler: @unchecked Sendable { return } - // For multi-token (prefill), we need to handle the offset - // Pattern: Commit blit separately, then use runAsync for sampling - // This avoids the issue where encode() to MPSCommandBuffer commits internally - let logitsOffset = (queryLength - 1) * vocabSize * MemoryLayout.size let sliceSize = vocabSize * MemoryLayout.size - // Create a temporary buffer for the single token's logits guard let tempBuffer = device.makeBuffer(length: sliceSize, options: .storageModeShared) else { completion(0) return } - // Step 1: Create and commit blit command buffer separately guard let blitCmdBuffer = queue.makeCommandBuffer() else { completion(0) return } - blitCmdBuffer.label = "MPSGraph Top-K Blit" + blitCmdBuffer.label = "MPSGraph Composite Blit" guard let blitEncoder = blitCmdBuffer.makeBlitCommandEncoder() else { completion(0) @@ -747,23 +794,22 @@ final class MPSGraphTopKSampler: @unchecked Sendable { size: sliceSize ) blitEncoder.endEncoding() - blitCmdBuffer.commit() // Commit blit immediately (GPU will order operations) + blitCmdBuffer.commit() - // Step 2: Use runAsync for sampling (executes after blit due to GPU queue ordering) - // Write temperature and random to buffers (use configured temperature) + // Write runtime values temperatureBuffer.contents().assumingMemoryBound(to: Float.self).pointee = max(self.temperature, 0.01) + topPBuffer.contents().assumingMemoryBound(to: Float.self).pointee = topP + minPBuffer.contents().assumingMemoryBound(to: Float.self).pointee = minP let randomValue = testingOnlyRandomOverride ?? Float.random(in: 0..<1) randomBuffer.contents().assumingMemoryBound(to: Float.self).pointee = randomValue - // Create tensor data (tempBuffer is unique per prefill call, can't cache) let logitsData = MPSGraphTensorData(tempBuffer, shape: [1, vocabSize as NSNumber], dataType: .float16) let outputData = MPSGraphTensorData(outputBuffer, shape: [1 as NSNumber], dataType: .int32) - // Use a separate execution descriptor for prefill let prefillExecDescriptor = MPSGraphExecutableExecutionDescriptor() prefillExecDescriptor.completionHandler = { [outputBuffer, outputOffset] (_, error) in if let error = error { - print("MPSGraph Top-K error: \(error)") + print("MPSGraph composite sampler error: \(error)") completion(0) return } @@ -775,10 +821,9 @@ final class MPSGraphTopKSampler: @unchecked Sendable { completion(result) } - // Run async - GPU naturally orders this after the blit due to queue ordering executable.runAsync( with: queue, - inputs: [logitsData, temperatureData, randomData], + inputs: [logitsData, temperatureData, randomData, topPData, minPData], results: [outputData], executionDescriptor: prefillExecDescriptor ) @@ -786,7 +831,7 @@ final class MPSGraphTopKSampler: @unchecked Sendable { } // Conformance to MPSGraphSampler protocol -extension MPSGraphTopKSampler: MPSGraphSampler {} +extension MPSGraphCompositeSampler: MPSGraphSampler {} // MARK: - Errors diff --git a/swift/Sources/CoreAILanguageModels/Samplers/SamplingConfiguration.swift b/swift/Sources/CoreAILanguageModels/Samplers/SamplingConfiguration.swift index 40c6838..aa7b358 100644 --- a/swift/Sources/CoreAILanguageModels/Samplers/SamplingConfiguration.swift +++ b/swift/Sources/CoreAILanguageModels/Samplers/SamplingConfiguration.swift @@ -11,13 +11,15 @@ import CoreAIShared /// - **Temperature**: Controls overall randomness /// - **TopK**: Limits vocabulary to K most likely tokens /// - **TopP (nucleus)**: Limits vocabulary by cumulative probability threshold +/// - **MinP**: Limits vocabulary by minimum probability relative to the most likely token /// /// ## Sampling Algorithm Order /// When multiple parameters are set, they are applied in this order: /// 1. Temperature scaling (logits / temperature) -/// 2. TopP filtering (cumulative probability cutoff) -/// 3. TopK filtering (hard limit on vocabulary) -/// 4. Softmax and multinomial sampling +/// 2. MinP filtering (relative probability threshold) +/// 3. TopP filtering (cumulative probability cutoff) +/// 4. TopK filtering (hard limit on vocabulary) +/// 5. Softmax and multinomial sampling /// /// ## Usage Example /// ```swift @@ -33,6 +35,9 @@ import CoreAIShared /// // Nucleus (TopP) sampling /// let nucleus = SamplingConfiguration(temperature: 0.9, topP: 0.95) /// +/// // MinP sampling (cheaper alternative to TopP) +/// let minP = SamplingConfiguration(temperature: 0.9, minP: 0.05) +/// /// // Combined TopK + TopP (recommended for best quality) /// let combined = SamplingConfiguration(temperature: 0.8, topK: 50, topP: 0.9) /// ``` @@ -72,6 +77,19 @@ public struct SamplingConfiguration: Sendable, Equatable, Hashable { /// When uncertain (flat distribution), vocabulary expands. public let topP: Double? + /// Min-P sampling: only consider tokens whose probability is at least minP times + /// the most likely token's probability. + /// + /// - **nil**: No min-P filtering + /// - **0.05**: Common default (keep tokens with >= 5% of top token's probability) + /// - **0.1**: More aggressive filtering + /// - **0.01**: Very permissive + /// + /// MinP is a simpler, cheaper alternative to TopP that adapts to the distribution shape. + /// When the model is confident, fewer tokens pass. When uncertain, more tokens pass. + /// Unlike TopP, it does not require sorting — it operates as a simple threshold in logit space. + public let minP: Double? + /// A boolean flag that requests the sampling operation be combined /// with logit inference. /// @@ -88,17 +106,21 @@ public struct SamplingConfiguration: Sendable, Equatable, Hashable { /// - temperature: The randomness factor for token generation. Must be >= 0.0. /// - topK: Optional top-K limit. Must be > 0 if set. /// - topP: Optional top-P threshold. Must be in (0, 1] if set. + /// - minP: Optional min-P threshold. Must be in (0, 1] if set. /// - combined: Whether to combine sampling with logit inference. Defaults to true. /// /// - Note: Call `validate()` to check for potentially suboptimal configurations. - public init(temperature: Double, topK: Int? = nil, topP: Double? = nil, combined: Bool = true) { + public init(temperature: Double, topK: Int? = nil, topP: Double? = nil, minP: Double? = nil, combined: Bool = true) + { precondition(temperature >= 0, "Temperature must be non-negative.") precondition(topK == nil || topK! > 0, "TopK must be positive if set.") precondition(topP == nil || (topP! > 0 && topP! <= 1), "TopP must be in (0, 1] if set.") + precondition(minP == nil || (minP! > 0 && minP! <= 1), "MinP must be in (0, 1] if set.") self.temperature = temperature self.topK = topK self.topP = topP + self.minP = minP self.combined = combined } @@ -125,11 +147,11 @@ public struct SamplingConfiguration: Sendable, Equatable, Hashable { temperature == 0 } - /// Whether this configuration requires composite sampling (topK and/or topP). + /// Whether this configuration requires composite sampling (topK, topP, and/or minP). /// - /// True when temperature > 0 and either topK or topP is set. + /// True when temperature > 0 and any of topK, topP, or minP is set. public var isComposite: Bool { - temperature > 0 && (topK != nil || topP != nil) + temperature > 0 && (topK != nil || topP != nil || minP != nil) } /// Validates the configuration and returns warnings for potentially suboptimal settings. @@ -137,7 +159,8 @@ public struct SamplingConfiguration: Sendable, Equatable, Hashable { /// This method checks for: /// - topK=1 with temperature>0 (should use greedy instead) /// - topP=1.0 (effectively disabled, same as nil) - /// - topK/topP set with temperature=0 (ignored for greedy) + /// - minP=1.0 (effectively greedy, should use temperature=0) + /// - topK/topP/minP set with temperature=0 (ignored for greedy) /// /// - Returns: Array of warning messages, empty if configuration is optimal. private func validate() -> [String] { @@ -159,11 +182,27 @@ public struct SamplingConfiguration: Sendable, Equatable, Hashable { ) } - // Check for topK/topP with temperature=0 (ignored) - if temperature == 0 && (topK != nil || topP != nil) { + // Check for minP=1.0 (effectively greedy) + if let m = minP, m == 1.0 { + warnings.append( + "minP=1.0 keeps only the single most-probable token. " + + "Use temperature=0 for deterministic output, or a smaller minP value." + ) + } + + // Check for minP + topP together (unusual, may indicate confusion) + if minP != nil && topP != nil { + warnings.append( + "Both minP and topP are set. They serve similar purposes (adaptive filtering). " + + "Both will apply (minP first, then topP), but typically only one is needed." + ) + } + + // Check for topK/topP/minP with temperature=0 (ignored) + if temperature == 0 && (topK != nil || topP != nil || minP != nil) { warnings.append( - "topK/topP are ignored when temperature=0 (greedy sampling). " - + "Set temperature>0 to enable filtering, or remove topK/topP." + "topK/topP/minP are ignored when temperature=0 (greedy sampling). " + + "Set temperature>0 to enable filtering, or remove topK/topP/minP." ) } @@ -183,27 +222,31 @@ public struct SamplingConfiguration: Sendable, Equatable, Hashable { /// Returns a normalized configuration with redundant settings removed. /// /// - topP=1.0 is replaced with nil (no effect) - /// - topK/topP are removed if temperature=0 (greedy ignores them) + /// - topK/topP/minP are removed if temperature=0 (greedy ignores them) /// /// - Returns: A new configuration with redundant settings removed. public func normalized() -> SamplingConfiguration { let effectiveTopK: Int? let effectiveTopP: Double? + let effectiveMinP: Double? if temperature == 0 { - // Greedy ignores topK/topP + // Greedy ignores topK/topP/minP effectiveTopK = nil effectiveTopP = nil + effectiveMinP = nil } else { effectiveTopK = topK // topP=1.0 is equivalent to nil effectiveTopP = (topP == 1.0) ? nil : topP + effectiveMinP = minP } return SamplingConfiguration( temperature: temperature, topK: effectiveTopK, topP: effectiveTopP, + minP: effectiveMinP, combined: combined ) } diff --git a/swift/Sources/Tools/llm-runner/LLMRunnerMain.swift b/swift/Sources/Tools/llm-runner/LLMRunnerMain.swift index a2bda2a..a12942b 100644 --- a/swift/Sources/Tools/llm-runner/LLMRunnerMain.swift +++ b/swift/Sources/Tools/llm-runner/LLMRunnerMain.swift @@ -84,6 +84,11 @@ struct LLMRunner: AsyncParsableCommand, Sendable { help: "Top-P (nucleus) sampling: consider tokens in top P probability mass (e.g., 0.9)") var topP: Double? + @Option( + name: .customLong("min-p"), + help: "Min-P sampling: keep tokens with probability >= minP × max probability (e.g., 0.05)") + var minP: Double? + @Option(help: "Sampling strategy. Options: 'temperature' (default), 'greedy'") var samplingStrategy: String = "temperature" @@ -303,6 +308,9 @@ struct LLMRunner: AsyncParsableCommand, Sendable { if let p = topP { CLILogger.log("TopP: \(p)", component: "Main") } + if let m = minP { + CLILogger.log("MinP: \(m)", component: "Main") + } if kvCacheStrategy != .auto { CLILogger.log("KV Cache Strategy: \(kvCacheStrategy.rawValue)", component: "Main") if let capacity = kvCacheInitialCapacity { @@ -669,13 +677,14 @@ struct LLMRunner: AsyncParsableCommand, Sendable { temperature: temperature, topK: topK, topP: topP, + minP: minP, combined: !synchronousSampling ) case "greedy": - // Fatal error if topK/topP set with greedy - if topK != nil || topP != nil { - print("Error: --top-k and --top-p cannot be used with --sampling-strategy greedy") - print("Use --sampling-strategy temperature with --top-k/--top-p, or remove --top-k/--top-p for greedy") + // Fatal error if topK/topP/minP set with greedy + if topK != nil || topP != nil || minP != nil { + print("Error: --top-k, --top-p, and --min-p cannot be used with --sampling-strategy greedy") + print("Use --sampling-strategy temperature with --top-k/--top-p/--min-p, or remove them for greedy") throw ExitCode.failure } config = SamplingConfiguration(temperature: 0, combined: !synchronousSampling) diff --git a/swift/Tests/LanguageModelsTests/CompositeSamplerTests.swift b/swift/Tests/LanguageModelsTests/CompositeSamplerTests.swift index 8099203..c084593 100644 --- a/swift/Tests/LanguageModelsTests/CompositeSamplerTests.swift +++ b/swift/Tests/LanguageModelsTests/CompositeSamplerTests.swift @@ -179,6 +179,104 @@ struct CompositeSamplerTests { } } + // MARK: - MinP sampling + + @Test("minP limits the sampled set to tokens with sufficient relative probability") + func minPLimitsSampledSet() { + let vocab = 100 + var rng = Xoshiro256StarStar(seed: 0xAAAA_BBBB) + + // Token 99 has logit 10.0, tokens 98-95 have logit ~9.5-8.0 + // Everything else is -5.0 (very low relative probability) + var rawLogits = [Float](repeating: -5.0, count: vocab) + rawLogits[99] = 10.0 + rawLogits[98] = 9.5 + rawLogits[97] = 9.0 + rawLogits[96] = 8.5 + rawLogits[95] = 8.0 + + var sampledIndices = Set() + for _ in 0..<2_000 { + var logits = rawLogits + let token = CompositeSampler.sample( + from: &logits, config: .init(temperature: 1.0, minP: 0.1), using: &rng) + sampledIndices.insert(token) + } + + // With minP=0.1, only tokens whose probability is >= 10% of the max token's probability + // should be sampled. The -5.0 tokens should be far below this threshold. + for idx in sampledIndices { + #expect(idx >= 95, "minP=0.1 sampled low-probability token \(idx)") + } + } + + @Test("minP=1.0 keeps only the most probable token (equivalent to greedy)") + func minPOneIsGreedy() { + let vocab = 50 + var rng = Xoshiro256StarStar(seed: 0xCCCC_DDDD) + + var rawLogits = [Float](repeating: 0, count: vocab) + rawLogits[42] = 5.0 + rawLogits[10] = 4.9 + + var allSame = true + for _ in 0..<100 { + var logits = rawLogits + let token = CompositeSampler.sample( + from: &logits, config: .init(temperature: 1.0, minP: 1.0), using: &rng) + if token != 42 { + allSame = false + break + } + } + #expect(allSame, "minP=1.0 should always pick the top token") + } + + @Test("minP + topK combined: both filters apply") + func minPWithTopK() { + let vocab = 100 + var rng = Xoshiro256StarStar(seed: 0x1111_2222) + + // Spread: top 5 tokens have high logits, next 5 have medium, rest are low + var rawLogits = [Float](repeating: -10.0, count: vocab) + for i in 95..<100 { rawLogits[i] = 5.0 } // high + for i in 90..<95 { rawLogits[i] = 2.0 } // medium + + var sampledIndices = Set() + for _ in 0..<2_000 { + var logits = rawLogits + // topK=10 would include indices 90-99, but minP=0.3 should exclude + // the medium ones since exp(2-5)/exp(0) = exp(-3) ≈ 0.05 < 0.3 + let token = CompositeSampler.sample( + from: &logits, config: .init(temperature: 1.0, topK: 10, minP: 0.3), using: &rng) + sampledIndices.insert(token) + } + + // Only the top 5 (indices 95-99) should survive both filters + for idx in sampledIndices { + #expect(idx >= 95, "minP+topK: sampled \(idx) which should have been filtered") + } + } + + @Test("minP with guided generation mask: masked tokens never sampled") + func minPWithGGMask() { + let vocab = 256 + let allowed: Set = [10, 20, 30, 40, 50] + var rng = Xoshiro256StarStar(seed: 0x3333_4444) + + var counts = [Int](repeating: 0, count: vocab) + for _ in 0..<5_000 { + var logits = makeMaskedLogits(vocab: vocab, allowed: allowed, fill: 2.0, sentinel: -.infinity) + let token = CompositeSampler.sample( + from: &logits, config: .init(temperature: 1.0, minP: 0.05), using: &rng) + counts[Int(token)] += 1 + } + + for i in 0..() + for _ in 0..<30 { + await withCheckedContinuation { continuation in + sampler.encode( + to: queue, + logitsBuffer: logitsBuffer, + logitsOffset: 0, + outputBuffer: outputBuffer, + outputOffset: 0, + completion: { _ in continuation.resume() } + ) + } + let result = outputBuffer.contents().assumingMemoryBound(to: Int32.self).pointee + sampledTokens.insert(result) + } + + // All sampled tokens should be from the high-probability set + for token in sampledTokens { + #expect(highProbTokens.contains(Int(token)), "TopP sampled unexpected token \(token)") + } + } + + @Test("TopP=1.0 (disabled) behaves same as plain TopK") + func topPDisabledMatchesTopK() async throws { + let device = try #require(Self.device) + let sampler = try MPSGraphCompositeSampler( + device: device, vocabSize: Self.vocabSize, k: 40, temperature: 1.0, topP: 1.0, minP: 0.0) + + let logitsBuffer = try #require(device.makeBuffer(length: Self.vocabSize * 2, options: .storageModeShared)) + let outputBuffer = try #require(device.makeBuffer(length: 4, options: .storageModeShared)) + + let logitsPtr = logitsBuffer.contents().assumingMemoryBound(to: Float16.self) + for i in 0..= 10% of max prob + let sampler = try MPSGraphCompositeSampler( + device: device, vocabSize: Self.vocabSize, k: 1000, temperature: 1.0, topP: 1.0, minP: 0.1) + + let logitsBuffer = try #require(device.makeBuffer(length: Self.vocabSize * 2, options: .storageModeShared)) + let outputBuffer = try #require(device.makeBuffer(length: 4, options: .storageModeShared)) + + let logitsPtr = logitsBuffer.contents().assumingMemoryBound(to: Float16.self) + for i in 0..() + for _ in 0..<30 { + await withCheckedContinuation { continuation in + sampler.encode( + to: queue, + logitsBuffer: logitsBuffer, + logitsOffset: 0, + outputBuffer: outputBuffer, + outputOffset: 0, + completion: { _ in continuation.resume() } + ) + } + let result = outputBuffer.contents().assumingMemoryBound(to: Int32.self).pointee + sampledTokens.insert(result) + } + + // Token 300 should never appear (too low relative probability) + #expect(!sampledTokens.contains(300), "minP=0.1 should filter out token 300, but it was sampled") + // High probability tokens should appear + #expect(!sampledTokens.isEmpty, "Should have sampled some tokens") + } + + @Test("MinP=0.0 (disabled) allows all top-K tokens") + func minPDisabled() async throws { + let device = try #require(Self.device) + let sampler = try MPSGraphCompositeSampler( + device: device, vocabSize: Self.vocabSize, k: 40, temperature: 1.0, topP: 1.0, minP: 0.0) + + let logitsBuffer = try #require(device.makeBuffer(length: Self.vocabSize * 2, options: .storageModeShared)) + let outputBuffer = try #require(device.makeBuffer(length: 4, options: .storageModeShared)) + + let logitsPtr = logitsBuffer.contents().assumingMemoryBound(to: Float16.self) + for i in 0..() + let randomValues: [Float] = [0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95] + for r in randomValues { + sampler.testingOnlyRandomOverride = r + await withCheckedContinuation { continuation in + sampler.encode( + to: queue, + logitsBuffer: logitsBuffer, + logitsOffset: 0, + outputBuffer: outputBuffer, + outputOffset: 0, + completion: { _ in continuation.resume() } + ) + } + let result = outputBuffer.contents().assumingMemoryBound(to: Int32.self).pointee + sampledTokens.insert(result) + } + + // With minP=0.0, all tokens should be accessible + #expect(sampledTokens.count >= 3, "minP=0.0 should allow diverse sampling, got \(sampledTokens.count) unique") + for token in sampledTokens { + #expect(tokens.contains(Int(token)), "Unexpected token \(token)") + } + sampler.testingOnlyRandomOverride = nil + } + + @Test("Combined TopP + MinP: both filters apply") + func combinedTopPAndMinP() async throws { + let device = try #require(Self.device) + // topP=0.95 + minP=0.1 + let sampler = try MPSGraphCompositeSampler( + device: device, vocabSize: Self.vocabSize, k: 1000, temperature: 1.0, topP: 0.95, minP: 0.1) + + let logitsBuffer = try #require(device.makeBuffer(length: Self.vocabSize * 2, options: .storageModeShared)) + let outputBuffer = try #require(device.makeBuffer(length: 4, options: .storageModeShared)) + + let logitsPtr = logitsBuffer.contents().assumingMemoryBound(to: Float16.self) + for i in 0..() + for _ in 0..<30 { + await withCheckedContinuation { continuation in + sampler.encode( + to: queue, + logitsBuffer: logitsBuffer, + logitsOffset: 0, + outputBuffer: outputBuffer, + outputOffset: 0, + completion: { _ in continuation.resume() } + ) + } + let result = outputBuffer.contents().assumingMemoryBound(to: Int32.self).pointee + sampledTokens.insert(result) + } + + // Token 2000 should be filtered by minP + #expect(!sampledTokens.contains(2000), "Combined topP+minP should filter token 2000") + // Should only sample from the top cluster + for token in sampledTokens { + #expect([1000, 1001, 1002].contains(Int(token)), "Unexpected token \(token)") + } + } +} + +// MARK: - MPSGraph Sampler Factory Tests + +@Suite("MPSGraph Sampler Factory Tests", .enabled(if: !CIEnvironment.isVM)) +struct MPSGraphSamplerFactoryTests { + static let device: MTLDevice? = MTLCreateSystemDefaultDevice() + + @Test("Factory creates argmax sampler for temperature=0") + func factoryCreatesArgmax() throws { + let device = try #require(Self.device) + let config = SamplingConfiguration(temperature: 0) + let sampler = try MPSGraphSamplerFactory.makeSampler(device: device, vocabSize: 32000, config: config) + #expect(sampler is MPSGraphArgmaxSampler) + } + + @Test("Factory creates composite sampler for temperature>0") + func factoryCreatesComposite() throws { + let device = try #require(Self.device) + let config = SamplingConfiguration(temperature: 0.8, topK: 50, topP: 0.9, minP: 0.05) + let sampler = try MPSGraphSamplerFactory.makeSampler(device: device, vocabSize: 32000, config: config) + #expect(sampler is MPSGraphCompositeSampler) + let composite = sampler as! MPSGraphCompositeSampler + #expect(composite.k == 50) + #expect(composite.topP == 0.9) + #expect(composite.minP == 0.05) + } + + @Test("Factory uses K=1000 when only topP is set") + func factoryUsesLargeKForTopP() throws { + let device = try #require(Self.device) + let config = SamplingConfiguration(temperature: 0.8, topP: 0.9) + let sampler = try MPSGraphSamplerFactory.makeSampler(device: device, vocabSize: 32000, config: config) + let composite = sampler as! MPSGraphCompositeSampler + #expect(composite.k == 1000) + } + + @Test("Factory uses K=1000 when only minP is set") + func factoryUsesLargeKForMinP() throws { + let device = try #require(Self.device) + let config = SamplingConfiguration(temperature: 0.8, minP: 0.05) + let sampler = try MPSGraphSamplerFactory.makeSampler(device: device, vocabSize: 32000, config: config) + let composite = sampler as! MPSGraphCompositeSampler + #expect(composite.k == 1000) + } + + @Test("Factory uses K=40 for temperature-only") + func factoryUsesDefaultK() throws { + let device = try #require(Self.device) + let config = SamplingConfiguration(temperature: 0.8) + let sampler = try MPSGraphSamplerFactory.makeSampler(device: device, vocabSize: 32000, config: config) + let composite = sampler as! MPSGraphCompositeSampler + #expect(composite.k == 40) + } +} diff --git a/swift/Tests/LanguageModelsTests/SamplingConfigurationTests.swift b/swift/Tests/LanguageModelsTests/SamplingConfigurationTests.swift index 8301cd7..9affdf9 100644 --- a/swift/Tests/LanguageModelsTests/SamplingConfigurationTests.swift +++ b/swift/Tests/LanguageModelsTests/SamplingConfigurationTests.swift @@ -198,4 +198,54 @@ struct SamplingConfigurationTests { #expect(sampledTokens.contains(1)) #expect(!sampledTokens.contains(2)) } + + // MARK: - MinP Configuration Tests + + @Test("MinP configuration") + func minPConfig() { + let config = SamplingConfiguration(temperature: 0.9, minP: 0.05) + + #expect(config.temperature == 0.9) + #expect(config.minP == 0.05) + #expect(config.isComposite == true) + } + + @Test("MinP=nil means isComposite is false (temp only)") + func minPNilNotComposite() { + let config = SamplingConfiguration(temperature: 0.9) + + #expect(config.minP == nil) + #expect(config.isComposite == false) + } + + @Test("Normalized config removes minP for greedy") + func normalizedRemovesMinPForGreedy() { + let config = SamplingConfiguration(temperature: 0, minP: 0.1) + let normalized = config.normalized() + + #expect(normalized.minP == nil) + } + + @Test("MinP sampling excludes low relative probability tokens") + func minPSamplingExclusion() { + // Logits: [10.0, 9.5, 2.0] + // After temp=1 softmax: token 0 ≈ 0.62, token 1 ≈ 0.38, token 2 ≈ 0.0003 + // minP=0.1 threshold: 0.1 * 0.62 = 0.062 + // Token 2 (0.0003) < 0.062 → filtered out + let config = SamplingConfiguration(temperature: 1.0, minP: 0.1) + let logits: [Float16] = [10.0, 9.5, 2.0] + + var sampledTokens = Set() + for _ in 0..<100 { + var logitsCopy = logits + let token = config.fallbackSampler(from: &logitsCopy) + sampledTokens.insert(token) + } + + // Token 2 should never be sampled + #expect(!sampledTokens.contains(2), "minP should filter out token 2") + // Tokens 0 and 1 should both appear + #expect(sampledTokens.contains(0)) + #expect(sampledTokens.contains(1)) + } }