From e24135442b5286dcea486ec4aac43ff53dd27745 Mon Sep 17 00:00:00 2001 From: miro Date: Fri, 27 Mar 2026 11:04:07 +0000 Subject: [PATCH 1/8] feat: add Zipformer2 transducer support for Vosk/sherpa-onnx models Add AsrModelVersion.zipformer2 for icefall Zipformer2 transducer CoreML models. Key differences from Parakeet TDT: - Stateless decoder (context window of token IDs, no LSTM states) - Standard RNNT greedy decode (no duration prediction) - blank_id=0, 80 mel bins, encoder output shape [1,T,D] not [1,D,T] New files: - ZipformerRnntDecoder.swift: greedy RNNT decode with vDSP argmax Modified: - AsrModelVersion: add .zipformer2 with properties (blankId, contextSize, requiresMelInput, hasStatelessDecoder, melBins) - ModelNames: add Zipformer2 enum and Repo.zipformer2 - AsrManager: route .zipformer2 to ZipformerRnntDecoder - CLI switches: handle new enum case --- Sources/FluidAudio/ASR/AsrManager.swift | 10 ++ Sources/FluidAudio/ASR/AsrModels.swift | 41 +++++- .../ASR/TDT/ZipformerRnntDecoder.swift | 133 ++++++++++++++++++ Sources/FluidAudio/ModelNames.swift | 30 ++++ .../Commands/ASR/AsrBenchmark.swift | 1 + .../Commands/ASR/TranscribeCommand.swift | 2 + 6 files changed, 215 insertions(+), 2 deletions(-) create mode 100644 Sources/FluidAudio/ASR/TDT/ZipformerRnntDecoder.swift diff --git a/Sources/FluidAudio/ASR/AsrManager.swift b/Sources/FluidAudio/ASR/AsrManager.swift index 503a494b3..2ca0221b0 100644 --- a/Sources/FluidAudio/ASR/AsrManager.swift +++ b/Sources/FluidAudio/ASR/AsrManager.swift @@ -385,6 +385,16 @@ public actor AsrManager { isLastChunk: isLastChunk, globalFrameOffset: globalFrameOffset ) + case .zipformer2: + let zipformerDecoder = ZipformerRnntDecoder(config: adaptedConfig) + return try zipformerDecoder.decode( + encoderOutput: encoderOutput, + encoderSequenceLength: encoderSequenceLength, + decoderModel: decoder_, + joinerModel: joint, + blankId: models.version.blankId, + contextSize: models.version.contextSize + ) } } diff --git 
a/Sources/FluidAudio/ASR/AsrModels.swift b/Sources/FluidAudio/ASR/AsrModels.swift index 67129c6bd..ddf6e6192 100644 --- a/Sources/FluidAudio/ASR/AsrModels.swift +++ b/Sources/FluidAudio/ASR/AsrModels.swift @@ -8,12 +8,15 @@ public enum AsrModelVersion: Sendable { case v3 /// 110M parameter hybrid TDT-CTC model with fused preprocessor+encoder case tdtCtc110m + /// Zipformer2 transducer (icefall/sherpa-onnx) with stateless decoder + case zipformer2 var repo: Repo { switch self { case .v2: return .parakeetV2 case .v3: return .parakeet case .tdtCtc110m: return .parakeetTdtCtc110m + case .zipformer2: return .zipformer2 } } @@ -25,10 +28,34 @@ public enum AsrModelVersion: Sendable { } } + /// Whether this model takes mel frames as input (true) or raw audio (false) + public var requiresMelInput: Bool { + switch self { + case .zipformer2: return true + default: return false + } + } + + /// Whether this model uses a stateless decoder (context window) vs stateful LSTM + public var hasStatelessDecoder: Bool { + switch self { + case .zipformer2: return true + default: return false + } + } + + /// Decoder context window size (for stateless decoders) + public var contextSize: Int { + switch self { + case .zipformer2: return 2 + default: return 0 // Not applicable for LSTM decoders + } + } + /// Encoder hidden dimension for this model version public var encoderHiddenSize: Int { switch self { - case .tdtCtc110m: return 512 + case .tdtCtc110m, .zipformer2: return 512 default: return 1024 } } @@ -38,6 +65,7 @@ public enum AsrModelVersion: Sendable { switch self { case .v2, .tdtCtc110m: return 1024 case .v3: return 8192 + case .zipformer2: return 0 } } @@ -45,9 +73,18 @@ public enum AsrModelVersion: Sendable { public var decoderLayers: Int { switch self { case .tdtCtc110m: return 1 + case .zipformer2: return 1 // Dummy, not used for stateless decoder default: return 2 } } + + /// Number of mel bins for the encoder input + public var melBins: Int { + switch self { + case .zipformer2: 
return 80 + default: return 128 + } + } } public struct AsrModels: Sendable { @@ -123,7 +160,7 @@ extension AsrModels { private static func inferredVersion(from directory: URL) -> AsrModelVersion? { let directoryPath = directory.path.lowercased() - let knownVersions: [AsrModelVersion] = [.tdtCtc110m, .v2, .v3] + let knownVersions: [AsrModelVersion] = [.zipformer2, .tdtCtc110m, .v2, .v3] for version in knownVersions { if directoryPath.contains(version.repo.folderName.lowercased()) { diff --git a/Sources/FluidAudio/ASR/TDT/ZipformerRnntDecoder.swift b/Sources/FluidAudio/ASR/TDT/ZipformerRnntDecoder.swift new file mode 100644 index 000000000..d799e1698 --- /dev/null +++ b/Sources/FluidAudio/ASR/TDT/ZipformerRnntDecoder.swift @@ -0,0 +1,133 @@ +/// Greedy RNN-T decoder for Zipformer2 transducer models (icefall/sherpa-onnx). +/// +/// Unlike Parakeet TDT, Zipformer2 uses: +/// - **Stateless decoder**: context window of token IDs (no LSTM hidden/cell states) +/// - **Standard RNNT**: no duration prediction, advance one encoder frame per step +/// - **blank_id = 0**: first token in vocabulary is blank +/// +/// The decoder takes the last `contextSize` token IDs as input and produces +/// a decoder embedding. The joiner combines encoder + decoder embeddings to +/// produce logits over the vocabulary. + +import Accelerate +import CoreML +import Foundation +import OSLog + +internal struct ZipformerRnntDecoder { + + private let logger = AppLogger(category: "ZipformerRNNT") + private let config: ASRConfig + private let predictionOptions = AsrModels.optimizedPredictionOptions() + + init(config: ASRConfig) { + self.config = config + } + + /// Decode encoder output using greedy RNNT search. + /// + /// The encoder output shape is `[1, T, joinerDim]` (time-major, unlike Parakeet's `[1, D, T]`). 
+ /// + /// - Parameters: + /// - encoderOutput: Encoder output, shape `[1, T, joinerDim]` + /// - encoderSequenceLength: Number of valid encoder frames + /// - decoderModel: Stateless decoder CoreML model + /// - joinerModel: Joiner CoreML model + /// - blankId: Blank token ID (typically 0 for Zipformer2) + /// - contextSize: Decoder context window size (typically 2) + /// - Returns: Decoded hypothesis with tokens, timestamps, and confidences + func decode( + encoderOutput: MLMultiArray, + encoderSequenceLength: Int, + decoderModel: MLModel, + joinerModel: MLModel, + blankId: Int, + contextSize: Int + ) throws -> TdtHypothesis { + let joinerDim = encoderOutput.shape[2].intValue + let maxSymbolsPerStep = 10 + + // Context buffer: last `contextSize` tokens, initialized with blank + var context = [Int](repeating: blankId, count: contextSize) + + // Use TdtHypothesis for compatibility with existing pipeline + // We create a dummy decoder state since Zipformer2 is stateless + var hypothesis = TdtDecoderState.make(decoderLayers: 1) + var result = TdtHypothesis(decState: hypothesis) + + // Precompute encoder strides for efficient frame extraction + // Shape: [1, T, joinerDim] + let encStride0 = encoderOutput.strides[0].intValue + let encStride1 = encoderOutput.strides[1].intValue + let encStride2 = encoderOutput.strides[2].intValue + let encPtr = encoderOutput.dataPointer.bindMemory( + to: Float.self, capacity: encoderOutput.count) + + // Preallocate reusable arrays + let encoderStep = try MLMultiArray( + shape: [1, NSNumber(value: joinerDim)], dataType: .float32) + let decoderInput = try MLMultiArray( + shape: [1, NSNumber(value: contextSize)], dataType: .int32) + + let encStepPtr = encoderStep.dataPointer.bindMemory( + to: Float.self, capacity: joinerDim) + + for t in 0.. [1, joinerDim] + for d in 0.. 
logits + let joinInput = try MLDictionaryFeatureProvider(dictionary: [ + "encoder_out": MLFeatureValue(multiArray: encoderStep), + "decoder_out": MLFeatureValue(multiArray: decoderOut), + ]) + let joinOutput = try joinerModel.prediction( + from: joinInput, options: predictionOptions) + let logits = joinOutput.featureValue(for: "logit")!.multiArrayValue! + + // Argmax over vocabulary using vDSP + let vocabSize = logits.shape.last!.intValue + let logitsPtr = logits.dataPointer.bindMemory( + to: Float.self, capacity: vocabSize) + var maxVal: Float = 0 + var maxIdx: vDSP_Length = 0 + vDSP_maxvi(logitsPtr, 1, &maxVal, &maxIdx, vDSP_Length(vocabSize)) + let tokenId = Int(maxIdx) + + if tokenId == blankId { + break // Move to next encoder frame + } + + // Emit token + result.ySequence.append(tokenId) + result.timestamps.append(t) + result.tokenDurations.append(1) + result.tokenConfidences.append(maxVal) + result.lastToken = tokenId + + // Update context: shift left, add new token + context.removeFirst() + context.append(tokenId) + symbolsEmitted += 1 + } + } + + return result + } +} diff --git a/Sources/FluidAudio/ModelNames.swift b/Sources/FluidAudio/ModelNames.swift index e243e62cf..a1e1e359d 100644 --- a/Sources/FluidAudio/ModelNames.swift +++ b/Sources/FluidAudio/ModelNames.swift @@ -23,6 +23,7 @@ public enum Repo: String, CaseIterable { case qwen3AsrInt8 = "FluidInference/qwen3-asr-0.6b-coreml/int8" case multilingualG2p = "FluidInference/charsiu-g2p-byt5-coreml" case parakeetTdtCtc110m = "FluidInference/parakeet-tdt-ctc-110m-coreml" + case zipformer2 = "FluidInference/sherpa-onnx-zipformer2-coreml" /// Repository slug (without owner) public var name: String { @@ -69,6 +70,8 @@ public enum Repo: String, CaseIterable { return "charsiu-g2p-byt5-coreml" case .parakeetTdtCtc110m: return "parakeet-tdt-ctc-110m-coreml" + case .zipformer2: + return "sherpa-onnx-zipformer2-coreml" } } @@ -91,6 +94,8 @@ public enum Repo: String, CaseIterable { return 
"FluidInference/qwen3-asr-0.6b-coreml" case .parakeetTdtCtc110m: return "FluidInference/parakeet-tdt-ctc-110m-coreml" + case .zipformer2: + return "FluidInference/sherpa-onnx-zipformer2-coreml" default: return "FluidInference/\(name)" } @@ -151,6 +156,8 @@ public enum Repo: String, CaseIterable { return "charsiu-g2p-byt5" case .parakeetTdtCtc110m: return "parakeet-tdt-ctc-110m" + case .zipformer2: + return "sherpa-onnx-zipformer2" default: return name } @@ -236,12 +243,33 @@ public enum ModelNames { switch repo { case .parakeetTdtCtc110m: return vocabularyFileArray + case .zipformer2: + return "vocab.json" default: return vocabularyFile } } } + /// Zipformer2 transducer model names (icefall/sherpa-onnx) + public enum Zipformer2 { + public static let encoder = "encoder" + public static let decoder = "decoder" + public static let joiner = "joiner" + public static let vocabulary = "vocab.json" + public static let metadata = "metadata.json" + + public static let encoderFile = encoder + ".mlpackage" + public static let decoderFile = decoder + ".mlpackage" + public static let joinerFile = joiner + ".mlpackage" + + public static let requiredModels: Set = [ + encoderFile, + decoderFile, + joinerFile, + ] + } + /// CTC model names public enum CTC { public static let melSpectrogram = "MelSpectrogram" @@ -642,6 +670,8 @@ public enum ModelNames { return ModelNames.Qwen3ASR.requiredModelsFull case .multilingualG2p: return ModelNames.MultilingualG2P.requiredModels + case .zipformer2: + return ModelNames.Zipformer2.requiredModels } } } diff --git a/Sources/FluidAudioCLI/Commands/ASR/AsrBenchmark.swift b/Sources/FluidAudioCLI/Commands/ASR/AsrBenchmark.swift index 1212df825..56995eb59 100644 --- a/Sources/FluidAudioCLI/Commands/ASR/AsrBenchmark.swift +++ b/Sources/FluidAudioCLI/Commands/ASR/AsrBenchmark.swift @@ -842,6 +842,7 @@ extension ASRBenchmark { case .v2: versionLabel = "v2" case .v3: versionLabel = "v3" case .tdtCtc110m: versionLabel = "tdt-ctc-110m" + case .zipformer2: 
versionLabel = "zipformer2" } logger.info(" Model version: \(versionLabel)") logger.info(" Debug mode: \(debugMode ? "enabled" : "disabled")") diff --git a/Sources/FluidAudioCLI/Commands/ASR/TranscribeCommand.swift b/Sources/FluidAudioCLI/Commands/ASR/TranscribeCommand.swift index e970c8cb5..ed5bc09b6 100644 --- a/Sources/FluidAudioCLI/Commands/ASR/TranscribeCommand.swift +++ b/Sources/FluidAudioCLI/Commands/ASR/TranscribeCommand.swift @@ -411,6 +411,7 @@ enum TranscribeCommand { case .v2: modelVersionLabel = "v2" case .v3: modelVersionLabel = "v3" case .tdtCtc110m: modelVersionLabel = "tdt-ctc-110m" + case .zipformer2: modelVersionLabel = "zipformer2" } let output = TranscriptionJSONOutput( audioFile: audioFile, @@ -665,6 +666,7 @@ enum TranscribeCommand { case .v2: modelVersionLabel = "v2" case .v3: modelVersionLabel = "v3" case .tdtCtc110m: modelVersionLabel = "tdt-ctc-110m" + case .zipformer2: modelVersionLabel = "zipformer2" } let output = TranscriptionJSONOutput( audioFile: audioFile, From 39e4196a3e403eb7fcb9a8689b4580e3193ef519 Mon Sep 17 00:00:00 2001 From: miro Date: Fri, 27 Mar 2026 11:12:20 +0000 Subject: [PATCH 2/8] feat: wire up Zipformer2 mel extraction and end-to-end inference Complete the Zipformer2 integration so it works end-to-end: - AsrManager: initialize AudioMelSpectrogram (80-bin kaldi fbank, no preemphasis, periodic window) when model version requires mel input - AsrTranscription: add executeZipformerInference path that computes mel spectrogram from raw audio, feeds it to the encoder, then runs greedy RNNT decode - AsrModels: add loadZipformer2(from:) for loading models directly from a local directory (encoder/decoder/joiner.mlpackage + vocab.json) - Cleanup melSpectrogram on AsrManager.cleanup() --- Sources/FluidAudio/ASR/AsrManager.swift | 27 +++++- Sources/FluidAudio/ASR/AsrModels.swift | 63 +++++++++++++ Sources/FluidAudio/ASR/AsrTranscription.swift | 89 +++++++++++++++++++ 3 files changed, 178 insertions(+), 1 deletion(-) diff --git 
a/Sources/FluidAudio/ASR/AsrManager.swift b/Sources/FluidAudio/ASR/AsrManager.swift index 2ca0221b0..77e492af2 100644 --- a/Sources/FluidAudio/ASR/AsrManager.swift +++ b/Sources/FluidAudio/ASR/AsrManager.swift @@ -19,6 +19,9 @@ public actor AsrManager { internal var decoderModel: MLModel? internal var jointModel: MLModel? + /// Mel spectrogram extractor for models that take mel frames as input (e.g. Zipformer2) + internal var melSpectrogram: AudioMelSpectrogram? + /// The AsrModels instance if initialized with models internal var asrModels: AsrModels? @@ -97,7 +100,10 @@ public actor AsrManager { let decoderReady = decoderModel != nil && jointModel != nil guard decoderReady else { return false } - if asrModels?.usesSplitFrontend == true { + if asrModels?.version.requiresMelInput == true { + // Zipformer2: encoder takes mel frames, no preprocessor needed + return preprocessorModel != nil && melSpectrogram != nil + } else if asrModels?.usesSplitFrontend == true { // Split frontend: need both preprocessor and encoder return preprocessorModel != nil && encoderModel != nil } else { @@ -123,6 +129,24 @@ public actor AsrManager { self.microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers) self.systemDecoderState = TdtDecoderState.make(decoderLayers: layers) + // Initialize mel spectrogram for models that need external mel computation + if models.version.requiresMelInput { + // Kaldi-compatible fbank: 80 bins, no preemphasis, periodic Hann window + self.melSpectrogram = AudioMelSpectrogram( + sampleRate: 16000, + nMels: models.version.melBins, + nFFT: 512, + hopLength: 160, + winLength: 400, + preemph: 0.0, + logFloor: 1.0, + logFloorMode: .clamped, + windowPeriodic: true + ) + } else { + self.melSpectrogram = nil + } + logger.info("AsrManager initialized successfully with provided models") } @@ -319,6 +343,7 @@ public actor AsrManager { encoderModel = nil decoderModel = nil jointModel = nil + melSpectrogram = nil // Reset decoder states using fresh 
allocations for deterministic behavior microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers) systemDecoderState = TdtDecoderState.make(decoderLayers: layers) diff --git a/Sources/FluidAudio/ASR/AsrModels.swift b/Sources/FluidAudio/ASR/AsrModels.swift index ddf6e6192..2f409c562 100644 --- a/Sources/FluidAudio/ASR/AsrModels.swift +++ b/Sources/FluidAudio/ASR/AsrModels.swift @@ -352,6 +352,69 @@ extension AsrModels { return try await load(from: targetDir, configuration: nil) } + /// Load Zipformer2 transducer models directly from a local directory. + /// + /// The directory should contain encoder.mlpackage, decoder.mlpackage, joiner.mlpackage, + /// vocab.json, and optionally metadata.json. + /// + /// - Parameters: + /// - directory: Directory containing the Zipformer2 model files + /// - configuration: Optional MLModel configuration + /// - Returns: Loaded AsrModels with version set to .zipformer2 + public static func loadZipformer2( + from directory: URL, + configuration: MLModelConfiguration? = nil + ) throws -> AsrModels { + let config = configuration ?? 
defaultConfiguration() + let fm = FileManager.default + + // Load encoder (serves as the "preprocessor" in the AsrModels struct) + let encoderPath = directory.appendingPathComponent(ModelNames.Zipformer2.encoderFile) + guard fm.fileExists(atPath: encoderPath.path) else { + throw AsrModelsError.modelNotFound(ModelNames.Zipformer2.encoderFile, encoderPath) + } + let encoderModel = try MLModel(contentsOf: encoderPath, configuration: config) + + // Load decoder + let decoderPath = directory.appendingPathComponent(ModelNames.Zipformer2.decoderFile) + guard fm.fileExists(atPath: decoderPath.path) else { + throw AsrModelsError.modelNotFound(ModelNames.Zipformer2.decoderFile, decoderPath) + } + let decoderModel = try MLModel(contentsOf: decoderPath, configuration: config) + + // Load joiner + let joinerPath = directory.appendingPathComponent(ModelNames.Zipformer2.joinerFile) + guard fm.fileExists(atPath: joinerPath.path) else { + throw AsrModelsError.modelNotFound(ModelNames.Zipformer2.joinerFile, joinerPath) + } + let joinerModel = try MLModel(contentsOf: joinerPath, configuration: config) + + // Load vocabulary (JSON array format) + let vocabPath = directory.appendingPathComponent(ModelNames.Zipformer2.vocabulary) + guard fm.fileExists(atPath: vocabPath.path) else { + throw AsrModelsError.modelNotFound(ModelNames.Zipformer2.vocabulary, vocabPath) + } + let vocabData = try Data(contentsOf: vocabPath) + guard let vocabArray = try JSONSerialization.jsonObject(with: vocabData) as? 
[String] else { + throw AsrModelsError.loadingFailed("Vocabulary file has unexpected format") + } + var vocabulary: [Int: String] = [:] + for (index, token) in vocabArray.enumerated() { + vocabulary[index] = token + } + + // The encoder serves as "preprocessor" in AsrModels (mel → encoder features) + return AsrModels( + encoder: nil, + preprocessor: encoderModel, + decoder: decoderModel, + joint: joinerModel, + configuration: config, + vocabulary: vocabulary, + version: .zipformer2 + ) + } + public static func defaultConfiguration() -> MLModelConfiguration { let config = MLModelConfiguration() config.allowLowPrecisionAccumulationOnGPU = true diff --git a/Sources/FluidAudio/ASR/AsrTranscription.swift b/Sources/FluidAudio/ASR/AsrTranscription.swift index 5574abfc8..cd3b80938 100644 --- a/Sources/FluidAudio/ASR/AsrTranscription.swift +++ b/Sources/FluidAudio/ASR/AsrTranscription.swift @@ -107,6 +107,19 @@ extension AsrManager { globalFrameOffset: Int = 0 ) async throws -> (hypothesis: TdtHypothesis, encoderSequenceLength: Int) { + // Zipformer2 path: compute mel → run encoder → decode + if let models = asrModels, models.version.requiresMelInput { + return try await executeZipformerInference( + paddedAudio, + originalLength: originalLength, + actualAudioFrames: actualAudioFrames, + decoderState: &decoderState, + contextFrameAdjustment: contextFrameAdjustment, + isLastChunk: isLastChunk, + globalFrameOffset: globalFrameOffset + ) + } + let preprocessorInput = try await preparePreprocessorInput( paddedAudio, actualLength: originalLength) @@ -600,4 +613,80 @@ extension AsrManager { } } + // MARK: - Zipformer2 Inference + + /// Execute inference for Zipformer2 models: mel extraction → encoder → RNNT decode. + /// + /// The Zipformer2 encoder takes mel spectrogram frames `[1, T, 80]` as input + /// (unlike Parakeet which takes raw audio). This method computes the mel features + /// in Swift, passes them through the CoreML encoder, then runs greedy RNNT decoding. 
+ internal func executeZipformerInference( + _ audioSamples: [Float], + originalLength: Int?, + actualAudioFrames: Int?, + decoderState: inout TdtDecoderState, + contextFrameAdjustment: Int, + isLastChunk: Bool, + globalFrameOffset: Int + ) async throws -> (hypothesis: TdtHypothesis, encoderSequenceLength: Int) { + guard let melSpec = melSpectrogram, + let encoderModel = preprocessorModel, // For Zipformer2, preprocessor IS the encoder + let models = asrModels + else { + throw ASRError.notInitialized + } + + // Step 1: Compute mel spectrogram from audio samples + // Zipformer2 uses kaldi-style fbank: 80 bins, no preemphasis, periodic window + let melResult = melSpec.computeFlatTransposed(audio: audioSamples) + let melFrames = melResult.numFrames + let melBins = models.version.melBins + + // Step 2: Build encoder input as MLMultiArray [1, melFrames, melBins] + let melArray = try MLMultiArray( + shape: [1, NSNumber(value: melFrames), NSNumber(value: melBins)], + dataType: .float32 + ) + let melPtr = melArray.dataPointer.bindMemory(to: Float.self, capacity: melFrames * melBins) + melResult.mel.withUnsafeBufferPointer { srcPtr in + memcpy( + melPtr, srcPtr.baseAddress!, min(melResult.mel.count, melFrames * melBins) * MemoryLayout.size) + } + + let melLensArray = try MLMultiArray(shape: [1], dataType: .int32) + melLensArray[0] = NSNumber(value: Int32(melResult.melLength)) + + let encoderInput = try MLDictionaryFeatureProvider(dictionary: [ + "x": MLFeatureValue(multiArray: melArray), + "x_lens": MLFeatureValue(multiArray: melLensArray), + ]) + + // Step 3: Run encoder + try Task.checkCancellation() + let encoderOutput = try await encoderModel.compatPrediction( + from: encoderInput, options: predictionOptions) + + let rawEncoderOutput = try extractFeatureValue( + from: encoderOutput, key: "encoder_out", + errorMessage: "Invalid Zipformer2 encoder output") + let encoderLens = try extractFeatureValue( + from: encoderOutput, key: "encoder_out_lens", + errorMessage: "Invalid 
Zipformer2 encoder output length") + + let encoderSequenceLength = encoderLens[0].intValue + + // Step 4: Run RNNT decode + let hypothesis = try await tdtDecodeWithTimings( + encoderOutput: rawEncoderOutput, + encoderSequenceLength: encoderSequenceLength, + actualAudioFrames: actualAudioFrames ?? encoderSequenceLength, + originalAudioSamples: audioSamples, + decoderState: &decoderState, + contextFrameAdjustment: contextFrameAdjustment, + isLastChunk: isLastChunk, + globalFrameOffset: globalFrameOffset + ) + + return (hypothesis, encoderSequenceLength) + } } From 5c0511196a3c3fa5daf81b8b260220072452d65b Mon Sep 17 00:00:00 2001 From: miro Date: Fri, 27 Mar 2026 18:15:52 +0000 Subject: [PATCH 3/8] fix: complete Zipformer2 end-to-end pipeline - Skip LSTM state initialization for stateless Zipformer2 decoder (initializeDecoderState was sending Parakeet-style h_in/c_in inputs to the Zipformer2 model which expects only 'y') - Auto-compile .mlpackage to .mlmodelc on first load - Use .all compute units for Zipformer2 (avoids slow ANE compilation) - Pad/truncate mel frames to encoder's fixed input size (1495 frames) - Add --model-version zipformer2 CLI option with --model-dir support Tested: swift run fluidaudiocli transcribe audio.wav --model-version zipformer2 --model-dir /path/to/vosk-0.62-atc-int8 --- Sources/FluidAudio/ASR/AsrManager.swift | 15 +++++ Sources/FluidAudio/ASR/AsrModels.swift | 55 ++++++++++++------- Sources/FluidAudio/ASR/AsrTranscription.swift | 30 +++++++--- .../Commands/ASR/TranscribeCommand.swift | 11 +++- 4 files changed, 82 insertions(+), 29 deletions(-) diff --git a/Sources/FluidAudio/ASR/AsrManager.swift b/Sources/FluidAudio/ASR/AsrManager.swift index 77e492af2..6fff5412a 100644 --- a/Sources/FluidAudio/ASR/AsrManager.swift +++ b/Sources/FluidAudio/ASR/AsrManager.swift @@ -264,6 +264,21 @@ public actor AsrManager { } internal func initializeDecoderState(for source: AudioSource) async throws { + // Zipformer2 uses a stateless decoder — no LSTM 
state to initialize + if asrModels?.version.hasStatelessDecoder == true { + var state: TdtDecoderState + switch source { + case .microphone: state = microphoneDecoderState + case .system: state = systemDecoderState + } + state.reset() + switch source { + case .microphone: microphoneDecoderState = state + case .system: systemDecoderState = state + } + return + } + guard let decoderModel = decoderModel else { throw ASRError.notInitialized } diff --git a/Sources/FluidAudio/ASR/AsrModels.swift b/Sources/FluidAudio/ASR/AsrModels.swift index 2f409c562..2e4fd4b8f 100644 --- a/Sources/FluidAudio/ASR/AsrModels.swift +++ b/Sources/FluidAudio/ASR/AsrModels.swift @@ -354,8 +354,9 @@ extension AsrModels { /// Load Zipformer2 transducer models directly from a local directory. /// - /// The directory should contain encoder.mlpackage, decoder.mlpackage, joiner.mlpackage, - /// vocab.json, and optionally metadata.json. + /// The directory should contain encoder, decoder, and joiner models as either + /// `.mlpackage` (source) or `.mlmodelc` (compiled) format, plus vocab.json. + /// If `.mlpackage` files are found, they are compiled on the fly. /// /// - Parameters: /// - directory: Directory containing the Zipformer2 model files @@ -365,33 +366,47 @@ extension AsrModels { from directory: URL, configuration: MLModelConfiguration? = nil ) throws -> AsrModels { - let config = configuration ?? defaultConfiguration() - let fm = FileManager.default + let config: MLModelConfiguration + if let configuration { + config = configuration + } else { + // Zipformer2 models exported with iOS18 target work best with .all compute units. + // The default .cpuAndNeuralEngine triggers very slow ANE compilation for these models. 
+ let c = MLModelConfiguration() + c.computeUnits = .all + config = c + } - // Load encoder (serves as the "preprocessor" in the AsrModels struct) - let encoderPath = directory.appendingPathComponent(ModelNames.Zipformer2.encoderFile) - guard fm.fileExists(atPath: encoderPath.path) else { - throw AsrModelsError.modelNotFound(ModelNames.Zipformer2.encoderFile, encoderPath) + func loadModel(name: String, packageExt: String, compiledExt: String) throws -> MLModel { + let compiledPath = directory.appendingPathComponent(name + compiledExt) + if FileManager.default.fileExists(atPath: compiledPath.path) { + return try MLModel(contentsOf: compiledPath, configuration: config) + } + let packagePath = directory.appendingPathComponent(name + packageExt) + guard FileManager.default.fileExists(atPath: packagePath.path) else { + throw AsrModelsError.modelNotFound( + name + packageExt, packagePath) + } + // Compile .mlpackage → temporary .mlmodelc + let compiledURL = try MLModel.compileModel(at: packagePath) + return try MLModel(contentsOf: compiledURL, configuration: config) } - let encoderModel = try MLModel(contentsOf: encoderPath, configuration: config) + + // Load encoder (serves as the "preprocessor" in the AsrModels struct) + let encoderModel = try loadModel( + name: ModelNames.Zipformer2.encoder, packageExt: ".mlpackage", compiledExt: ".mlmodelc") // Load decoder - let decoderPath = directory.appendingPathComponent(ModelNames.Zipformer2.decoderFile) - guard fm.fileExists(atPath: decoderPath.path) else { - throw AsrModelsError.modelNotFound(ModelNames.Zipformer2.decoderFile, decoderPath) - } - let decoderModel = try MLModel(contentsOf: decoderPath, configuration: config) + let decoderModel = try loadModel( + name: ModelNames.Zipformer2.decoder, packageExt: ".mlpackage", compiledExt: ".mlmodelc") // Load joiner - let joinerPath = directory.appendingPathComponent(ModelNames.Zipformer2.joinerFile) - guard fm.fileExists(atPath: joinerPath.path) else { - throw 
AsrModelsError.modelNotFound(ModelNames.Zipformer2.joinerFile, joinerPath) - } - let joinerModel = try MLModel(contentsOf: joinerPath, configuration: config) + let joinerModel = try loadModel( + name: ModelNames.Zipformer2.joiner, packageExt: ".mlpackage", compiledExt: ".mlmodelc") // Load vocabulary (JSON array format) let vocabPath = directory.appendingPathComponent(ModelNames.Zipformer2.vocabulary) - guard fm.fileExists(atPath: vocabPath.path) else { + guard FileManager.default.fileExists(atPath: vocabPath.path) else { throw AsrModelsError.modelNotFound(ModelNames.Zipformer2.vocabulary, vocabPath) } let vocabData = try Data(contentsOf: vocabPath) diff --git a/Sources/FluidAudio/ASR/AsrTranscription.swift b/Sources/FluidAudio/ASR/AsrTranscription.swift index cd3b80938..0b730f517 100644 --- a/Sources/FluidAudio/ASR/AsrTranscription.swift +++ b/Sources/FluidAudio/ASR/AsrTranscription.swift @@ -639,22 +639,38 @@ extension AsrManager { // Step 1: Compute mel spectrogram from audio samples // Zipformer2 uses kaldi-style fbank: 80 bins, no preemphasis, periodic window let melResult = melSpec.computeFlatTransposed(audio: audioSamples) - let melFrames = melResult.numFrames let melBins = models.version.melBins - // Step 2: Build encoder input as MLMultiArray [1, melFrames, melBins] + // The encoder has a fixed input size (mel_frames from conversion, default 1495). + // Read the expected size from the encoder model's input description. 
+ let encoderInputDesc = encoderModel.modelDescription.inputDescriptionsByName["x"] + let expectedMelFrames: Int + if let constraint = encoderInputDesc?.multiArrayConstraint { + expectedMelFrames = constraint.shape[1].intValue // [1, T, 80] + } else { + expectedMelFrames = 1495 // Default from conversion script + } + + let actualMelLength = min(melResult.melLength, expectedMelFrames) + + // Step 2: Build encoder input as MLMultiArray [1, expectedMelFrames, melBins] + // Pad with zeros if audio is shorter, truncate if longer let melArray = try MLMultiArray( - shape: [1, NSNumber(value: melFrames), NSNumber(value: melBins)], + shape: [1, NSNumber(value: expectedMelFrames), NSNumber(value: melBins)], dataType: .float32 ) - let melPtr = melArray.dataPointer.bindMemory(to: Float.self, capacity: melFrames * melBins) + // Zero-initialize (handles padding automatically) + let totalMelElements = expectedMelFrames * melBins + let melPtr = melArray.dataPointer.bindMemory(to: Float.self, capacity: totalMelElements) + memset(melPtr, 0, totalMelElements * MemoryLayout.size) + // Copy actual mel data (may be shorter than expectedMelFrames) + let copyElements = min(melResult.mel.count, actualMelLength * melBins) melResult.mel.withUnsafeBufferPointer { srcPtr in - memcpy( - melPtr, srcPtr.baseAddress!, min(melResult.mel.count, melFrames * melBins) * MemoryLayout.size) + memcpy(melPtr, srcPtr.baseAddress!, copyElements * MemoryLayout.size) } let melLensArray = try MLMultiArray(shape: [1], dataType: .int32) - melLensArray[0] = NSNumber(value: Int32(melResult.melLength)) + melLensArray[0] = NSNumber(value: Int32(expectedMelFrames)) let encoderInput = try MLDictionaryFeatureProvider(dictionary: [ "x": MLFeatureValue(multiArray: melArray), diff --git a/Sources/FluidAudioCLI/Commands/ASR/TranscribeCommand.swift b/Sources/FluidAudioCLI/Commands/ASR/TranscribeCommand.swift index ed5bc09b6..22292cc2e 100644 --- a/Sources/FluidAudioCLI/Commands/ASR/TranscribeCommand.swift +++ 
b/Sources/FluidAudioCLI/Commands/ASR/TranscribeCommand.swift @@ -241,9 +241,12 @@ enum TranscribeCommand { modelVersion = .v3 case "tdt-ctc-110m", "110m": modelVersion = .tdtCtc110m + case "zipformer2", "zipformer": + modelVersion = .zipformer2 default: logger.error( - "Invalid model version: \(arguments[i + 1]). Use 'v2', 'v3', or 'tdt-ctc-110m'") + "Invalid model version: \(arguments[i + 1]). Use 'v2', 'v3', 'tdt-ctc-110m', or 'zipformer2'" + ) exit(1) } i += 1 @@ -290,7 +293,11 @@ enum TranscribeCommand { let models: AsrModels if let modelDir = modelDir { let dir = URL(fileURLWithPath: modelDir) - models = try await AsrModels.load(from: dir, version: modelVersion) + if modelVersion == .zipformer2 { + models = try AsrModels.loadZipformer2(from: dir) + } else { + models = try await AsrModels.load(from: dir, version: modelVersion) + } } else { models = try await AsrModels.downloadAndLoad(version: modelVersion) } From 2a751b645d9ec8794d4d586e3afc3c74066b5e75 Mon Sep 17 00:00:00 2001 From: miro Date: Fri, 27 Mar 2026 18:39:03 +0000 Subject: [PATCH 4/8] =?UTF-8?q?feat:=20support=20fused=20Zipformer2=20prep?= =?UTF-8?q?rocessor=20(audio=20=E2=86=92=20encoder=20features)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When Zipformer2 models are exported with --fuse-mel, the Preprocessor takes raw audio (1, 239120) like Parakeet — no external mel needed. 
- loadZipformer2: auto-detect fused (Preprocessor.mlpackage) vs separate encoder, set hasFusedMel flag - AsrModels.hasFusedMel: controls whether mel extraction is external - usesSplitFrontend: returns false for fused Zipformer2 - Dynamic encoder output key resolution (encoder/encoder_out) - Version-specific maxAudioSamples (239120 for Zipformer2) Tested with real audio via CLI: swift run fluidaudiocli transcribe audio.wav \ --model-version zipformer2 \ --model-dir /path/to/vosk-0.62-atc-fused --- Sources/FluidAudio/ASR/AsrManager.swift | 12 +++-- Sources/FluidAudio/ASR/AsrModels.swift | 46 +++++++++++++++---- Sources/FluidAudio/ASR/AsrTranscription.swift | 26 +++++++---- 3 files changed, 64 insertions(+), 20 deletions(-) diff --git a/Sources/FluidAudio/ASR/AsrManager.swift b/Sources/FluidAudio/ASR/AsrManager.swift index 6fff5412a..99648ee37 100644 --- a/Sources/FluidAudio/ASR/AsrManager.swift +++ b/Sources/FluidAudio/ASR/AsrManager.swift @@ -33,6 +33,11 @@ public actor AsrManager { return asrModels?.version.decoderLayers ?? 2 } + /// Get the max audio samples for the current model's preprocessor. + internal func getMaxModelSamples() -> Int { + return asrModels?.version.maxAudioSamples ?? 
ASRConstants.maxModelSamples + } + /// Token duration optimization model /// Cached vocabulary loaded once during initialization @@ -100,8 +105,8 @@ public actor AsrManager { let decoderReady = decoderModel != nil && jointModel != nil guard decoderReady else { return false } - if asrModels?.version.requiresMelInput == true { - // Zipformer2: encoder takes mel frames, no preprocessor needed + if asrModels?.version.requiresMelInput == true && asrModels?.hasFusedMel != true { + // Non-fused Zipformer2: needs mel spectrogram extractor return preprocessorModel != nil && melSpectrogram != nil } else if asrModels?.usesSplitFrontend == true { // Split frontend: need both preprocessor and encoder @@ -130,7 +135,8 @@ public actor AsrManager { self.systemDecoderState = TdtDecoderState.make(decoderLayers: layers) // Initialize mel spectrogram for models that need external mel computation - if models.version.requiresMelInput { + // (non-fused Zipformer2 only; fused models handle mel internally) + if models.version.requiresMelInput && !models.hasFusedMel { // Kaldi-compatible fbank: 80 bins, no preemphasis, periodic Hann window self.melSpectrogram = AudioMelSpectrogram( sampleRate: 16000, diff --git a/Sources/FluidAudio/ASR/AsrModels.swift b/Sources/FluidAudio/ASR/AsrModels.swift index 2e4fd4b8f..ee002934f 100644 --- a/Sources/FluidAudio/ASR/AsrModels.swift +++ b/Sources/FluidAudio/ASR/AsrModels.swift @@ -85,6 +85,16 @@ public enum AsrModelVersion: Sendable { default: return 128 } } + + /// Maximum audio samples for the fused preprocessor input. + /// Zipformer2 uses 239120 (produces 1495 mel frames for encoder compatibility). + /// Parakeet uses 240000 (15s at 16kHz). 
+ public var maxAudioSamples: Int { + switch self { + case .zipformer2: return 239_120 + default: return 240_000 + } + } } public struct AsrModels: Sendable { @@ -101,6 +111,11 @@ public struct AsrModels: Sendable { public let vocabulary: [Int: String] public let version: AsrModelVersion + /// Whether the preprocessor has fused mel extraction (takes raw audio, not mel frames). + /// When true, the Zipformer2 model works like Parakeet: audio_signal → encoder features. + /// When false, external mel computation is required before feeding the encoder. + public let hasFusedMel: Bool + private static let logger = AppLogger(category: "AsrModels") public init( @@ -110,7 +125,8 @@ public struct AsrModels: Sendable { joint: MLModel, configuration: MLModelConfiguration, vocabulary: [Int: String], - version: AsrModelVersion + version: AsrModelVersion, + hasFusedMel: Bool = false ) { self.encoder = encoder self.preprocessor = preprocessor @@ -119,11 +135,12 @@ public struct AsrModels: Sendable { self.configuration = configuration self.vocabulary = vocabulary self.version = version + self.hasFusedMel = hasFusedMel } - /// Whether this model uses a separate preprocessor and encoder (true for 0.6B, false for 110m fused) + /// Whether this model uses a separate preprocessor and encoder (true for 0.6B, false for 110m/zipformer2 fused) public var usesSplitFrontend: Bool { - !version.hasFusedEncoder + !version.hasFusedEncoder && !hasFusedMel } } @@ -392,9 +409,21 @@ extension AsrModels { return try MLModel(contentsOf: compiledURL, configuration: config) } - // Load encoder (serves as the "preprocessor" in the AsrModels struct) - let encoderModel = try loadModel( - name: ModelNames.Zipformer2.encoder, packageExt: ".mlpackage", compiledExt: ".mlmodelc") + // Check for fused Preprocessor first (--fuse-mel export), then fall back to separate encoder + let isFused = + FileManager.default.fileExists( + atPath: directory.appendingPathComponent("Preprocessor.mlpackage").path) + || 
FileManager.default.fileExists( + atPath: directory.appendingPathComponent("Preprocessor.mlmodelc").path) + + let encoderModel: MLModel + if isFused { + encoderModel = try loadModel( + name: "Preprocessor", packageExt: ".mlpackage", compiledExt: ".mlmodelc") + } else { + encoderModel = try loadModel( + name: ModelNames.Zipformer2.encoder, packageExt: ".mlpackage", compiledExt: ".mlmodelc") + } // Load decoder let decoderModel = try loadModel( @@ -418,7 +447,7 @@ extension AsrModels { vocabulary[index] = token } - // The encoder serves as "preprocessor" in AsrModels (mel → encoder features) + // The encoder/preprocessor serves as "preprocessor" in AsrModels return AsrModels( encoder: nil, preprocessor: encoderModel, @@ -426,7 +455,8 @@ extension AsrModels { joint: joinerModel, configuration: config, vocabulary: vocabulary, - version: .zipformer2 + version: .zipformer2, + hasFusedMel: isFused ) } diff --git a/Sources/FluidAudio/ASR/AsrTranscription.swift b/Sources/FluidAudio/ASR/AsrTranscription.swift index 0b730f517..2f63f14ce 100644 --- a/Sources/FluidAudio/ASR/AsrTranscription.swift +++ b/Sources/FluidAudio/ASR/AsrTranscription.swift @@ -22,21 +22,22 @@ extension AsrManager { } // Route to appropriate processing method based on audio length - if audioSamples.count <= ASRConstants.maxModelSamples { + let maxSamples = getMaxModelSamples() + if audioSamples.count <= maxSamples { let originalLength = audioSamples.count let frameAlignedCandidate = ((originalLength + ASRConstants.samplesPerEncoderFrame - 1) / ASRConstants.samplesPerEncoderFrame) * ASRConstants.samplesPerEncoderFrame let frameAlignedLength: Int let alignedSamples: [Float] - if frameAlignedCandidate > originalLength && frameAlignedCandidate <= ASRConstants.maxModelSamples { + if frameAlignedCandidate > originalLength && frameAlignedCandidate <= maxSamples { frameAlignedLength = frameAlignedCandidate alignedSamples = audioSamples + Array(repeating: 0, count: frameAlignedLength - originalLength) } else { 
frameAlignedLength = originalLength alignedSamples = audioSamples } - let paddedAudio: [Float] = padAudioIfNeeded(alignedSamples, targetLength: ASRConstants.maxModelSamples) + let paddedAudio: [Float] = padAudioIfNeeded(alignedSamples, targetLength: maxSamples) let (hypothesis, encoderSequenceLength) = try await executeMLInferenceWithTimings( paddedAudio, originalLength: frameAlignedLength, @@ -107,8 +108,9 @@ extension AsrManager { globalFrameOffset: Int = 0 ) async throws -> (hypothesis: TdtHypothesis, encoderSequenceLength: Int) { - // Zipformer2 path: compute mel → run encoder → decode - if let models = asrModels, models.version.requiresMelInput { + // Zipformer2 with separate mel path (non-fused): compute mel → run encoder → decode + // For fused Zipformer2, fall through to the normal Parakeet path below (same audio_signal interface) + if let models = asrModels, models.version.requiresMelInput, !models.hasFusedMel { return try await executeZipformerInference( paddedAudio, originalLength: originalLength, @@ -155,10 +157,15 @@ extension AsrManager { encoderOutputProvider = preprocessorOutput } + // Parakeet uses "encoder"/"encoder_length", Zipformer2 fused uses "encoder_out"/"encoder_out_lens" + let encoderKey = encoderOutputProvider.featureValue(for: "encoder") != nil ? "encoder" : "encoder_out" + let lengthKey = + encoderOutputProvider.featureValue(for: "encoder_length") != nil + ? 
"encoder_length" : "encoder_out_lens" let rawEncoderOutput = try extractFeatureValue( - from: encoderOutputProvider, key: "encoder", errorMessage: "Invalid encoder output") + from: encoderOutputProvider, key: encoderKey, errorMessage: "Invalid encoder output") let encoderLength = try extractFeatureValue( - from: encoderOutputProvider, key: "encoder_length", + from: encoderOutputProvider, key: lengthKey, errorMessage: "Invalid encoder output length") let encoderSequenceLength = encoderLength[0].intValue @@ -240,6 +247,7 @@ extension AsrManager { ) async throws -> (tokens: [Int], timestamps: [Int], confidences: [Float], encoderSequenceLength: Int) { // Select and copy decoder state for the source var state = (source == .microphone) ? microphoneDecoderState : systemDecoderState + let maxSamples = getMaxModelSamples() let originalLength = chunkSamples.count let frameAlignedCandidate = @@ -249,7 +257,7 @@ extension AsrManager { let alignedSamples: [Float] if previousTokens.isEmpty && frameAlignedCandidate > originalLength - && frameAlignedCandidate <= ASRConstants.maxModelSamples + && frameAlignedCandidate <= maxSamples { frameAlignedLength = frameAlignedCandidate alignedSamples = chunkSamples + Array(repeating: 0, count: frameAlignedLength - originalLength) @@ -257,7 +265,7 @@ extension AsrManager { frameAlignedLength = originalLength alignedSamples = chunkSamples } - let padded = padAudioIfNeeded(alignedSamples, targetLength: ASRConstants.maxModelSamples) + let padded = padAudioIfNeeded(alignedSamples, targetLength: maxSamples) let (hypothesis, encLen) = try await executeMLInferenceWithTimings( padded, originalLength: frameAlignedLength, From 382bb95bcd5211ceabd9c06dff9405434152577d Mon Sep 17 00:00:00 2001 From: miro Date: Fri, 27 Mar 2026 18:53:36 +0000 Subject: [PATCH 5/8] refactor: simplify Zipformer2 to require fused preprocessor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove non-fused mel path — Zipformer2 now 
requires fused Preprocessor.mlpackage (audio_signal → encoder features), same interface as Parakeet tdtCtc110m. Removed: - requiresMelInput, hasFusedMel, melBins properties - melSpectrogram instance and init/cleanup on AsrManager - executeZipformerInference method (80+ lines) - Complex isAvailable branching for mel mode Zipformer2 is now just another fused encoder variant: hasFusedEncoder=true, same preprocessor flow as Parakeet. --- Sources/FluidAudio/ASR/AsrManager.swift | 28 +---- Sources/FluidAudio/ASR/AsrModels.swift | 52 ++------- Sources/FluidAudio/ASR/AsrTranscription.swift | 106 ------------------ 3 files changed, 9 insertions(+), 177 deletions(-) diff --git a/Sources/FluidAudio/ASR/AsrManager.swift b/Sources/FluidAudio/ASR/AsrManager.swift index 99648ee37..962b307e6 100644 --- a/Sources/FluidAudio/ASR/AsrManager.swift +++ b/Sources/FluidAudio/ASR/AsrManager.swift @@ -19,9 +19,6 @@ public actor AsrManager { internal var decoderModel: MLModel? internal var jointModel: MLModel? - /// Mel spectrogram extractor for models that take mel frames as input (e.g. Zipformer2) - internal var melSpectrogram: AudioMelSpectrogram? - /// The AsrModels instance if initialized with models internal var asrModels: AsrModels? 
@@ -105,10 +102,7 @@ public actor AsrManager { let decoderReady = decoderModel != nil && jointModel != nil guard decoderReady else { return false } - if asrModels?.version.requiresMelInput == true && asrModels?.hasFusedMel != true { - // Non-fused Zipformer2: needs mel spectrogram extractor - return preprocessorModel != nil && melSpectrogram != nil - } else if asrModels?.usesSplitFrontend == true { + if asrModels?.usesSplitFrontend == true { // Split frontend: need both preprocessor and encoder return preprocessorModel != nil && encoderModel != nil } else { @@ -134,25 +128,6 @@ public actor AsrManager { self.microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers) self.systemDecoderState = TdtDecoderState.make(decoderLayers: layers) - // Initialize mel spectrogram for models that need external mel computation - // (non-fused Zipformer2 only; fused models handle mel internally) - if models.version.requiresMelInput && !models.hasFusedMel { - // Kaldi-compatible fbank: 80 bins, no preemphasis, periodic Hann window - self.melSpectrogram = AudioMelSpectrogram( - sampleRate: 16000, - nMels: models.version.melBins, - nFFT: 512, - hopLength: 160, - winLength: 400, - preemph: 0.0, - logFloor: 1.0, - logFloorMode: .clamped, - windowPeriodic: true - ) - } else { - self.melSpectrogram = nil - } - logger.info("AsrManager initialized successfully with provided models") } @@ -364,7 +339,6 @@ public actor AsrManager { encoderModel = nil decoderModel = nil jointModel = nil - melSpectrogram = nil // Reset decoder states using fresh allocations for deterministic behavior microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers) systemDecoderState = TdtDecoderState.make(decoderLayers: layers) diff --git a/Sources/FluidAudio/ASR/AsrModels.swift b/Sources/FluidAudio/ASR/AsrModels.swift index ee002934f..18b73fca8 100644 --- a/Sources/FluidAudio/ASR/AsrModels.swift +++ b/Sources/FluidAudio/ASR/AsrModels.swift @@ -23,15 +23,7 @@ public enum AsrModelVersion: Sendable 
{ /// Whether this model version uses a fused preprocessor+encoder (no separate Encoder model) public var hasFusedEncoder: Bool { switch self { - case .tdtCtc110m: return true - default: return false - } - } - - /// Whether this model takes mel frames as input (true) or raw audio (false) - public var requiresMelInput: Bool { - switch self { - case .zipformer2: return true + case .tdtCtc110m, .zipformer2: return true default: return false } } @@ -78,14 +70,6 @@ public enum AsrModelVersion: Sendable { } } - /// Number of mel bins for the encoder input - public var melBins: Int { - switch self { - case .zipformer2: return 80 - default: return 128 - } - } - /// Maximum audio samples for the fused preprocessor input. /// Zipformer2 uses 239120 (produces 1495 mel frames for encoder compatibility). /// Parakeet uses 240000 (15s at 16kHz). @@ -111,11 +95,6 @@ public struct AsrModels: Sendable { public let vocabulary: [Int: String] public let version: AsrModelVersion - /// Whether the preprocessor has fused mel extraction (takes raw audio, not mel frames). - /// When true, the Zipformer2 model works like Parakeet: audio_signal → encoder features. - /// When false, external mel computation is required before feeding the encoder. 
- public let hasFusedMel: Bool - private static let logger = AppLogger(category: "AsrModels") public init( @@ -125,8 +104,7 @@ public struct AsrModels: Sendable { joint: MLModel, configuration: MLModelConfiguration, vocabulary: [Int: String], - version: AsrModelVersion, - hasFusedMel: Bool = false + version: AsrModelVersion ) { self.encoder = encoder self.preprocessor = preprocessor @@ -135,12 +113,11 @@ public struct AsrModels: Sendable { self.configuration = configuration self.vocabulary = vocabulary self.version = version - self.hasFusedMel = hasFusedMel } - /// Whether this model uses a separate preprocessor and encoder (true for 0.6B, false for 110m/zipformer2 fused) + /// Whether this model uses a separate preprocessor and encoder (true for 0.6B, false for 110m/zipformer2) public var usesSplitFrontend: Bool { - !version.hasFusedEncoder && !hasFusedMel + !version.hasFusedEncoder } } @@ -409,21 +386,9 @@ extension AsrModels { return try MLModel(contentsOf: compiledURL, configuration: config) } - // Check for fused Preprocessor first (--fuse-mel export), then fall back to separate encoder - let isFused = - FileManager.default.fileExists( - atPath: directory.appendingPathComponent("Preprocessor.mlpackage").path) - || FileManager.default.fileExists( - atPath: directory.appendingPathComponent("Preprocessor.mlmodelc").path) - - let encoderModel: MLModel - if isFused { - encoderModel = try loadModel( - name: "Preprocessor", packageExt: ".mlpackage", compiledExt: ".mlmodelc") - } else { - encoderModel = try loadModel( - name: ModelNames.Zipformer2.encoder, packageExt: ".mlpackage", compiledExt: ".mlmodelc") - } + // Zipformer2 requires fused Preprocessor (audio → encoder features) + let encoderModel = try loadModel( + name: "Preprocessor", packageExt: ".mlpackage", compiledExt: ".mlmodelc") // Load decoder let decoderModel = try loadModel( @@ -455,8 +420,7 @@ extension AsrModels { joint: joinerModel, configuration: config, vocabulary: vocabulary, - version: 
.zipformer2, - hasFusedMel: isFused + version: .zipformer2 ) } diff --git a/Sources/FluidAudio/ASR/AsrTranscription.swift b/Sources/FluidAudio/ASR/AsrTranscription.swift index 2f63f14ce..0a1568e89 100644 --- a/Sources/FluidAudio/ASR/AsrTranscription.swift +++ b/Sources/FluidAudio/ASR/AsrTranscription.swift @@ -108,20 +108,6 @@ extension AsrManager { globalFrameOffset: Int = 0 ) async throws -> (hypothesis: TdtHypothesis, encoderSequenceLength: Int) { - // Zipformer2 with separate mel path (non-fused): compute mel → run encoder → decode - // For fused Zipformer2, fall through to the normal Parakeet path below (same audio_signal interface) - if let models = asrModels, models.version.requiresMelInput, !models.hasFusedMel { - return try await executeZipformerInference( - paddedAudio, - originalLength: originalLength, - actualAudioFrames: actualAudioFrames, - decoderState: &decoderState, - contextFrameAdjustment: contextFrameAdjustment, - isLastChunk: isLastChunk, - globalFrameOffset: globalFrameOffset - ) - } - let preprocessorInput = try await preparePreprocessorInput( paddedAudio, actualLength: originalLength) @@ -621,96 +607,4 @@ extension AsrManager { } } - // MARK: - Zipformer2 Inference - - /// Execute inference for Zipformer2 models: mel extraction → encoder → RNNT decode. - /// - /// The Zipformer2 encoder takes mel spectrogram frames `[1, T, 80]` as input - /// (unlike Parakeet which takes raw audio). This method computes the mel features - /// in Swift, passes them through the CoreML encoder, then runs greedy RNNT decoding. 
- internal func executeZipformerInference( - _ audioSamples: [Float], - originalLength: Int?, - actualAudioFrames: Int?, - decoderState: inout TdtDecoderState, - contextFrameAdjustment: Int, - isLastChunk: Bool, - globalFrameOffset: Int - ) async throws -> (hypothesis: TdtHypothesis, encoderSequenceLength: Int) { - guard let melSpec = melSpectrogram, - let encoderModel = preprocessorModel, // For Zipformer2, preprocessor IS the encoder - let models = asrModels - else { - throw ASRError.notInitialized - } - - // Step 1: Compute mel spectrogram from audio samples - // Zipformer2 uses kaldi-style fbank: 80 bins, no preemphasis, periodic window - let melResult = melSpec.computeFlatTransposed(audio: audioSamples) - let melBins = models.version.melBins - - // The encoder has a fixed input size (mel_frames from conversion, default 1495). - // Read the expected size from the encoder model's input description. - let encoderInputDesc = encoderModel.modelDescription.inputDescriptionsByName["x"] - let expectedMelFrames: Int - if let constraint = encoderInputDesc?.multiArrayConstraint { - expectedMelFrames = constraint.shape[1].intValue // [1, T, 80] - } else { - expectedMelFrames = 1495 // Default from conversion script - } - - let actualMelLength = min(melResult.melLength, expectedMelFrames) - - // Step 2: Build encoder input as MLMultiArray [1, expectedMelFrames, melBins] - // Pad with zeros if audio is shorter, truncate if longer - let melArray = try MLMultiArray( - shape: [1, NSNumber(value: expectedMelFrames), NSNumber(value: melBins)], - dataType: .float32 - ) - // Zero-initialize (handles padding automatically) - let totalMelElements = expectedMelFrames * melBins - let melPtr = melArray.dataPointer.bindMemory(to: Float.self, capacity: totalMelElements) - memset(melPtr, 0, totalMelElements * MemoryLayout.size) - // Copy actual mel data (may be shorter than expectedMelFrames) - let copyElements = min(melResult.mel.count, actualMelLength * melBins) - 
melResult.mel.withUnsafeBufferPointer { srcPtr in - memcpy(melPtr, srcPtr.baseAddress!, copyElements * MemoryLayout.size) - } - - let melLensArray = try MLMultiArray(shape: [1], dataType: .int32) - melLensArray[0] = NSNumber(value: Int32(expectedMelFrames)) - - let encoderInput = try MLDictionaryFeatureProvider(dictionary: [ - "x": MLFeatureValue(multiArray: melArray), - "x_lens": MLFeatureValue(multiArray: melLensArray), - ]) - - // Step 3: Run encoder - try Task.checkCancellation() - let encoderOutput = try await encoderModel.compatPrediction( - from: encoderInput, options: predictionOptions) - - let rawEncoderOutput = try extractFeatureValue( - from: encoderOutput, key: "encoder_out", - errorMessage: "Invalid Zipformer2 encoder output") - let encoderLens = try extractFeatureValue( - from: encoderOutput, key: "encoder_out_lens", - errorMessage: "Invalid Zipformer2 encoder output length") - - let encoderSequenceLength = encoderLens[0].intValue - - // Step 4: Run RNNT decode - let hypothesis = try await tdtDecodeWithTimings( - encoderOutput: rawEncoderOutput, - encoderSequenceLength: encoderSequenceLength, - actualAudioFrames: actualAudioFrames ?? 
encoderSequenceLength, - originalAudioSamples: audioSamples, - decoderState: &decoderState, - contextFrameAdjustment: contextFrameAdjustment, - isLastChunk: isLastChunk, - globalFrameOffset: globalFrameOffset - ) - - return (hypothesis, encoderSequenceLength) - } } From f1b39bf857e7a2131c38d4c00d714ce544a3fd7d Mon Sep 17 00:00:00 2001 From: miro Date: Sat, 28 Mar 2026 02:48:04 +0000 Subject: [PATCH 6/8] fix: correct Zipformer2 frame count and greedy decode - Compute valid encoder frames from audio length for traced models (encoder_out_lens is a traced constant, not actual frame count) - Switch to single-token-per-frame greedy decode matching Python reference (prevents token repetition loops) Co-Authored-By: Claude Opus 4.6 (1M context) --- Sources/FluidAudio/ASR/AsrTranscription.swift | 11 ++- .../ASR/TDT/ZipformerRnntDecoder.swift | 68 +++++++++---------- 2 files changed, 41 insertions(+), 38 deletions(-) diff --git a/Sources/FluidAudio/ASR/AsrTranscription.swift b/Sources/FluidAudio/ASR/AsrTranscription.swift index 0a1568e89..38e1849a4 100644 --- a/Sources/FluidAudio/ASR/AsrTranscription.swift +++ b/Sources/FluidAudio/ASR/AsrTranscription.swift @@ -154,7 +154,16 @@ extension AsrManager { from: encoderOutputProvider, key: lengthKey, errorMessage: "Invalid encoder output length") - let encoderSequenceLength = encoderLength[0].intValue + var encoderSequenceLength = encoderLength[0].intValue + + // For Zipformer2 fused models, encoder_out_lens is a traced constant (always max frames). + // Compute the actual valid frame count from the audio length instead. 
+ // Formula: mel_frames = (samples - 200) / 160 + 1, encoder_frames = (mel_frames - 7) / 4 + if asrModels?.version == .zipformer2, let actualLength = originalLength { + let melFrames = max(1, (actualLength - 200) / 160 + 1) + let validFrames = max(1, (melFrames - 7) / 4) + encoderSequenceLength = min(encoderSequenceLength, validFrames) + } // Calculate actual audio frames if not provided using shared constants let actualFrames = diff --git a/Sources/FluidAudio/ASR/TDT/ZipformerRnntDecoder.swift b/Sources/FluidAudio/ASR/TDT/ZipformerRnntDecoder.swift index d799e1698..22656bbb1 100644 --- a/Sources/FluidAudio/ASR/TDT/ZipformerRnntDecoder.swift +++ b/Sources/FluidAudio/ASR/TDT/ZipformerRnntDecoder.swift @@ -45,7 +45,6 @@ internal struct ZipformerRnntDecoder { contextSize: Int ) throws -> TdtHypothesis { let joinerDim = encoderOutput.shape[2].intValue - let maxSymbolsPerStep = 10 // Context buffer: last `contextSize` tokens, initialized with blank var context = [Int](repeating: blankId, count: contextSize) @@ -78,42 +77,38 @@ internal struct ZipformerRnntDecoder { encStepPtr[d] = encPtr[0 * encStride0 + t * encStride1 + d * encStride2] } - var symbolsEmitted = 0 - while symbolsEmitted < maxSymbolsPerStep { - // Run stateless decoder with context tokens - for i in 0.. logits - let joinInput = try MLDictionaryFeatureProvider(dictionary: [ - "encoder_out": MLFeatureValue(multiArray: encoderStep), - "decoder_out": MLFeatureValue(multiArray: decoderOut), - ]) - let joinOutput = try joinerModel.prediction( - from: joinInput, options: predictionOptions) - let logits = joinOutput.featureValue(for: "logit")!.multiArrayValue! 
- - // Argmax over vocabulary using vDSP - let vocabSize = logits.shape.last!.intValue - let logitsPtr = logits.dataPointer.bindMemory( - to: Float.self, capacity: vocabSize) - var maxVal: Float = 0 - var maxIdx: vDSP_Length = 0 - vDSP_maxvi(logitsPtr, 1, &maxVal, &maxIdx, vDSP_Length(vocabSize)) - let tokenId = Int(maxIdx) - - if tokenId == blankId { - break // Move to next encoder frame - } + // One prediction per encoder frame (matches Python reference decoder). + // Run stateless decoder with context tokens + for i in 0.. logits + let joinInput = try MLDictionaryFeatureProvider(dictionary: [ + "encoder_out": MLFeatureValue(multiArray: encoderStep), + "decoder_out": MLFeatureValue(multiArray: decoderOut), + ]) + let joinOutput = try joinerModel.prediction( + from: joinInput, options: predictionOptions) + let logits = joinOutput.featureValue(for: "logit")!.multiArrayValue! + + // Argmax over vocabulary using vDSP + let vocabSize = logits.shape.last!.intValue + let logitsPtr = logits.dataPointer.bindMemory( + to: Float.self, capacity: vocabSize) + var maxVal: Float = 0 + var maxIdx: vDSP_Length = 0 + vDSP_maxvi(logitsPtr, 1, &maxVal, &maxIdx, vDSP_Length(vocabSize)) + let tokenId = Int(maxIdx) + + if tokenId != blankId { // Emit token result.ySequence.append(tokenId) result.timestamps.append(t) @@ -124,7 +119,6 @@ internal struct ZipformerRnntDecoder { // Update context: shift left, add new token context.removeFirst() context.append(tokenId) - symbolsEmitted += 1 } } From 39dce8c40d46445502c19a599fe24c479f1cbc35 Mon Sep 17 00:00:00 2001 From: miro Date: Sat, 28 Mar 2026 14:06:39 +0000 Subject: [PATCH 7/8] feat: add RNNT modified beam search with ARPA LM support - ZipformerRnntDecoder: add beamDecode() with configurable beam width, LM weight, and top-K token candidates per frame - Word-level LM scoring at SentencePiece boundaries - DecodingMethod enum (.greedy, .beamSearch) in ASRConfig - AsrManager: route zipformer2 to beam/greedy based on config, add 
arpaLanguageModel property and setLanguageModel() Co-Authored-By: Claude Opus 4.6 (1M context) --- Sources/FluidAudio/ASR/AsrManager.swift | 41 ++- Sources/FluidAudio/ASR/AsrTypes.swift | 15 +- .../ASR/TDT/ZipformerRnntDecoder.swift | 264 ++++++++++++++++-- 3 files changed, 283 insertions(+), 37 deletions(-) diff --git a/Sources/FluidAudio/ASR/AsrManager.swift b/Sources/FluidAudio/ASR/AsrManager.swift index 962b307e6..864347f19 100644 --- a/Sources/FluidAudio/ASR/AsrManager.swift +++ b/Sources/FluidAudio/ASR/AsrManager.swift @@ -39,6 +39,9 @@ public actor AsrManager { /// Cached vocabulary loaded once during initialization internal var vocabulary: [Int: String] = [:] + + /// Optional ARPA language model for beam search rescoring + public var arpaLanguageModel: ARPALanguageModel? #if DEBUG // Test-only setter internal func setVocabularyForTesting(_ vocab: [Int: String]) { @@ -323,6 +326,11 @@ public actor AsrManager { return directory.standardizedFileURL } + /// Set the ARPA language model for beam search rescoring. + public func setLanguageModel(_ lm: ARPALanguageModel) { + self.arpaLanguageModel = lm + } + public func resetState() { // Use model's decoder layer count, or 2 if models not loaded (v2/v3 default) let layers = asrModels?.version.decoderLayers ?? 
2 @@ -407,14 +415,31 @@ public actor AsrManager { ) case .zipformer2: let zipformerDecoder = ZipformerRnntDecoder(config: adaptedConfig) - return try zipformerDecoder.decode( - encoderOutput: encoderOutput, - encoderSequenceLength: encoderSequenceLength, - decoderModel: decoder_, - joinerModel: joint, - blankId: models.version.blankId, - contextSize: models.version.contextSize - ) + switch config.decodingMethod { + case .greedy: + return try zipformerDecoder.decode( + encoderOutput: encoderOutput, + encoderSequenceLength: encoderSequenceLength, + decoderModel: decoder_, + joinerModel: joint, + blankId: models.version.blankId, + contextSize: models.version.contextSize + ) + case .beamSearch(let beamWidth, let lmWeight, let tokenCandidates): + return try zipformerDecoder.beamDecode( + encoderOutput: encoderOutput, + encoderSequenceLength: encoderSequenceLength, + decoderModel: decoder_, + joinerModel: joint, + vocabulary: vocabulary, + lm: arpaLanguageModel, + blankId: models.version.blankId, + contextSize: models.version.contextSize, + beamWidth: beamWidth, + lmWeight: lmWeight, + tokenCandidates: tokenCandidates + ) + } } } diff --git a/Sources/FluidAudio/ASR/AsrTypes.swift b/Sources/FluidAudio/ASR/AsrTypes.swift index c4dcf2950..4aef09dda 100644 --- a/Sources/FluidAudio/ASR/AsrTypes.swift +++ b/Sources/FluidAudio/ASR/AsrTypes.swift @@ -2,6 +2,14 @@ import Foundation // MARK: - Configuration +/// Decoding method for transducer models. 
+public enum DecodingMethod: Sendable { + /// Greedy search: one best token per encoder frame + case greedy + /// Modified beam search with optional LM rescoring + case beamSearch(beamWidth: Int = 4, lmWeight: Float = 0.3, tokenCandidates: Int = 8) +} + public struct ASRConfig: Sendable { public let sampleRate: Int public let tdtConfig: TdtConfig @@ -19,6 +27,9 @@ public struct ASRConfig: Sendable { /// Default: 480,000 samples (~30 seconds at 16kHz) public let streamingThreshold: Int + /// Decoding method for transducer (RNNT/TDT) models + public let decodingMethod: DecodingMethod + public static let `default` = ASRConfig() public init( @@ -26,13 +37,15 @@ public struct ASRConfig: Sendable { tdtConfig: TdtConfig = .default, encoderHiddenSize: Int = ASRConstants.encoderHiddenSize, streamingEnabled: Bool = true, - streamingThreshold: Int = 480_000 + streamingThreshold: Int = 480_000, + decodingMethod: DecodingMethod = .greedy ) { self.sampleRate = sampleRate self.tdtConfig = tdtConfig self.encoderHiddenSize = encoderHiddenSize self.streamingEnabled = streamingEnabled self.streamingThreshold = streamingThreshold + self.decodingMethod = decodingMethod } } diff --git a/Sources/FluidAudio/ASR/TDT/ZipformerRnntDecoder.swift b/Sources/FluidAudio/ASR/TDT/ZipformerRnntDecoder.swift index 22656bbb1..f3a6989c3 100644 --- a/Sources/FluidAudio/ASR/TDT/ZipformerRnntDecoder.swift +++ b/Sources/FluidAudio/ASR/TDT/ZipformerRnntDecoder.swift @@ -1,4 +1,7 @@ -/// Greedy RNN-T decoder for Zipformer2 transducer models (icefall/sherpa-onnx). +/// RNN-T decoder for Zipformer2 transducer models (icefall/sherpa-onnx). +/// +/// Supports both greedy and modified beam search decoding, with optional +/// ARPA language model rescoring at word boundaries. 
/// /// Unlike Parakeet TDT, Zipformer2 uses: /// - **Stateless decoder**: context window of token IDs (no LSTM hidden/cell states) @@ -14,6 +17,23 @@ import CoreML import Foundation import OSLog +// MARK: - Beam hypothesis for RNNT modified beam search + +internal struct RnntBeam { + var tokens: [Int] + var context: [Int] + var logProb: Float + var lmScore: Float + var wordPieces: [String] + var prevWord: String? + var timestamps: [Int] + var confidences: [Float] + + var total: Float { logProb + lmScore } +} + +// MARK: - Decoder + internal struct ZipformerRnntDecoder { private let logger = AppLogger(category: "ZipformerRNNT") @@ -24,18 +44,9 @@ internal struct ZipformerRnntDecoder { self.config = config } + // MARK: - Greedy decode (one token per frame) + /// Decode encoder output using greedy RNNT search. - /// - /// The encoder output shape is `[1, T, joinerDim]` (time-major, unlike Parakeet's `[1, D, T]`). - /// - /// - Parameters: - /// - encoderOutput: Encoder output, shape `[1, T, joinerDim]` - /// - encoderSequenceLength: Number of valid encoder frames - /// - decoderModel: Stateless decoder CoreML model - /// - joinerModel: Joiner CoreML model - /// - blankId: Blank token ID (typically 0 for Zipformer2) - /// - contextSize: Decoder context window size (typically 2) - /// - Returns: Decoded hypothesis with tokens, timestamps, and confidences func decode( encoderOutput: MLMultiArray, encoderSequenceLength: Int, @@ -46,39 +57,28 @@ internal struct ZipformerRnntDecoder { ) throws -> TdtHypothesis { let joinerDim = encoderOutput.shape[2].intValue - // Context buffer: last `contextSize` tokens, initialized with blank var context = [Int](repeating: blankId, count: contextSize) - - // Use TdtHypothesis for compatibility with existing pipeline - // We create a dummy decoder state since Zipformer2 is stateless - var hypothesis = TdtDecoderState.make(decoderLayers: 1) + let hypothesis = TdtDecoderState.make(decoderLayers: 1) var result = TdtHypothesis(decState: 
hypothesis) - // Precompute encoder strides for efficient frame extraction - // Shape: [1, T, joinerDim] let encStride0 = encoderOutput.strides[0].intValue let encStride1 = encoderOutput.strides[1].intValue let encStride2 = encoderOutput.strides[2].intValue let encPtr = encoderOutput.dataPointer.bindMemory( to: Float.self, capacity: encoderOutput.count) - // Preallocate reusable arrays let encoderStep = try MLMultiArray( shape: [1, NSNumber(value: joinerDim)], dataType: .float32) let decoderInput = try MLMultiArray( shape: [1, NSNumber(value: contextSize)], dataType: .int32) - let encStepPtr = encoderStep.dataPointer.bindMemory( to: Float.self, capacity: joinerDim) for t in 0.. [1, joinerDim] for d in 0.. logits let joinInput = try MLDictionaryFeatureProvider(dictionary: [ "encoder_out": MLFeatureValue(multiArray: encoderStep), "decoder_out": MLFeatureValue(multiArray: decoderOut), @@ -99,7 +98,6 @@ internal struct ZipformerRnntDecoder { from: joinInput, options: predictionOptions) let logits = joinOutput.featureValue(for: "logit")!.multiArrayValue! - // Argmax over vocabulary using vDSP let vocabSize = logits.shape.last!.intValue let logitsPtr = logits.dataPointer.bindMemory( to: Float.self, capacity: vocabSize) @@ -109,14 +107,12 @@ internal struct ZipformerRnntDecoder { let tokenId = Int(maxIdx) if tokenId != blankId { - // Emit token result.ySequence.append(tokenId) result.timestamps.append(t) result.tokenDurations.append(1) result.tokenConfidences.append(maxVal) result.lastToken = tokenId - // Update context: shift left, add new token context.removeFirst() context.append(tokenId) } @@ -124,4 +120,216 @@ internal struct ZipformerRnntDecoder { return result } + + // MARK: - Modified beam search with optional LM + + /// Decode encoder output using modified beam search with optional ARPA LM. + /// + /// Maintains `beamWidth` hypotheses. 
At each encoder frame, expands each + /// hypothesis by trying blank + top-K non-blank tokens, then prunes to + /// the best `beamWidth` hypotheses by score. + /// + /// Word-level LM scores are applied at SentencePiece word boundaries (▁ prefix). + /// + /// - Parameters: + /// - encoderOutput: Encoder output, shape `[1, T, joinerDim]` + /// - encoderSequenceLength: Number of valid encoder frames + /// - decoderModel: Stateless decoder CoreML model + /// - joinerModel: Joiner CoreML model + /// - vocabulary: Token ID → string mapping for LM word boundary detection + /// - lm: Optional ARPA language model + /// - blankId: Blank token ID (typically 0) + /// - contextSize: Decoder context window size (typically 2) + /// - beamWidth: Number of hypotheses to maintain (default 4) + /// - lmWeight: LM score scaling factor (default 0.3) + /// - tokenCandidates: Top-K non-blank tokens to consider per frame (default 8) + func beamDecode( + encoderOutput: MLMultiArray, + encoderSequenceLength: Int, + decoderModel: MLModel, + joinerModel: MLModel, + vocabulary: [Int: String], + lm: ARPALanguageModel?, + blankId: Int, + contextSize: Int, + beamWidth: Int = 4, + lmWeight: Float = 0.3, + tokenCandidates: Int = 8 + ) throws -> TdtHypothesis { + let joinerDim = encoderOutput.shape[2].intValue + + let encStride0 = encoderOutput.strides[0].intValue + let encStride1 = encoderOutput.strides[1].intValue + let encStride2 = encoderOutput.strides[2].intValue + let encPtr = encoderOutput.dataPointer.bindMemory( + to: Float.self, capacity: encoderOutput.count) + + let encoderStep = try MLMultiArray( + shape: [1, NSNumber(value: joinerDim)], dataType: .float32) + let decoderInput = try MLMultiArray( + shape: [1, NSNumber(value: contextSize)], dataType: .int32) + let encStepPtr = encoderStep.dataPointer.bindMemory( + to: Float.self, capacity: joinerDim) + + // Initialize with single blank-context beam + let initialContext = [Int](repeating: blankId, count: contextSize) + var beams = 
[RnntBeam( + tokens: [], context: initialContext, logProb: 0.0, lmScore: 0.0, + wordPieces: [], prevWord: nil, timestamps: [], confidences: [] + )] + + // Cache decoder outputs for each unique context to avoid redundant calls + var decoderCache: [[Int]: MLMultiArray] = [:] + + for t in 0.. maxLogit { maxLogit = logitsPtr[v] } + } + var sumExp: Float = 0 + for v in 0.. $1.1 } + + // Candidate 2..K+1: top non-blank tokens + for (tokenId, tokenLogProb) in indexed.prefix(tokenCandidates) { + var newContext = beam.context + newContext.removeFirst() + newContext.append(tokenId) + + // LM scoring at word boundaries + var newLmScore = beam.lmScore + var newWordPieces = beam.wordPieces + var newPrevWord = beam.prevWord + + if let lm = lm, let tokenStr = vocabulary[tokenId] { + newWordPieces.append(tokenStr) + // Check for word boundary: SentencePiece ▁ prefix on NEXT token + if tokenStr.hasPrefix("\u{2581}") && !beam.wordPieces.isEmpty { + // Previous word pieces form a complete word + let word = beam.wordPieces.joined() + .replacingOccurrences(of: "\u{2581}", with: "") + if !word.isEmpty { + newLmScore += lmWeight * lm.score( + word: word.lowercased(), prev: beam.prevWord) + newPrevWord = word.lowercased() + } + newWordPieces = [tokenStr] + } + } + + var newTimestamps = beam.timestamps + newTimestamps.append(t) + var newConfidences = beam.confidences + newConfidences.append(exp(tokenLogProb)) + + candidates.append(RnntBeam( + tokens: beam.tokens + [tokenId], + context: newContext, + logProb: beam.logProb + tokenLogProb, + lmScore: newLmScore, + wordPieces: newWordPieces, + prevWord: newPrevWord, + timestamps: newTimestamps, + confidences: newConfidences + )) + } + } + + // Prune to top beamWidth by total score + candidates.sort { $0.total > $1.total } + beams = Array(candidates.prefix(beamWidth)) + + // Clear decoder cache for contexts no longer in active beams + let activeContexts = Set(beams.map { $0.context }) + decoderCache = decoderCache.filter { 
activeContexts.contains($0.key) } + } + + // Score final incomplete word for LM + if let lm = lm { + for i in 0.. Date: Sat, 28 Mar 2026 20:03:38 +0000 Subject: [PATCH 8/8] feat: CoreML RNN-LM for BPE-level beam search rescoring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - RnnLanguageModel.swift: CoreML wrapper for step-by-step LSTM LM scoring (token_id, h, c) → (log_probs, h_out, c_out) - ZipformerRnntDecoder: beam search uses RNN-LM for per-token scoring when available, with per-beam LSTM state tracking. Falls back to ARPA word-level scoring when only ARPA LM is provided. - AsrManager: add rnnLanguageModel property + setRnnLanguageModel() Works with all models sharing the same BPE vocabulary (Parakeet, Zipformer). Co-Authored-By: Claude Opus 4.6 (1M context) --- Sources/FluidAudio/ASR/AsrManager.swift | 9 ++ .../FluidAudio/ASR/CTC/RnnLanguageModel.swift | 122 ++++++++++++++++++ .../ASR/TDT/ZipformerRnntDecoder.swift | 60 ++++----- 3 files changed, 162 insertions(+), 29 deletions(-) create mode 100644 Sources/FluidAudio/ASR/CTC/RnnLanguageModel.swift diff --git a/Sources/FluidAudio/ASR/AsrManager.swift b/Sources/FluidAudio/ASR/AsrManager.swift index 864347f19..2e0561e90 100644 --- a/Sources/FluidAudio/ASR/AsrManager.swift +++ b/Sources/FluidAudio/ASR/AsrManager.swift @@ -42,6 +42,9 @@ public actor AsrManager { /// Optional ARPA language model for beam search rescoring public var arpaLanguageModel: ARPALanguageModel? + + /// Optional CoreML RNN language model for BPE-level beam search rescoring + public var rnnLanguageModel: RnnLanguageModel? #if DEBUG // Test-only setter internal func setVocabularyForTesting(_ vocab: [Int: String]) { @@ -331,6 +334,11 @@ public actor AsrManager { self.arpaLanguageModel = lm } + /// Set the CoreML RNN language model for BPE-level beam search rescoring. 
+    public func setRnnLanguageModel(_ lm: RnnLanguageModel) {
+        self.rnnLanguageModel = lm
+    }
+
     public func resetState() {
         // Use model's decoder layer count, or 2 if models not loaded (v2/v3 default)
         let layers = asrModels?.version.decoderLayers ?? 2
@@ -433,6 +441,7 @@
                 joinerModel: joint,
                 vocabulary: vocabulary,
                 lm: arpaLanguageModel,
+                rnnLm: rnnLanguageModel,
                 blankId: models.version.blankId,
                 contextSize: models.version.contextSize,
                 beamWidth: beamWidth,
diff --git a/Sources/FluidAudio/ASR/CTC/RnnLanguageModel.swift b/Sources/FluidAudio/ASR/CTC/RnnLanguageModel.swift
new file mode 100644
index 000000000..09276b1f6
--- /dev/null
+++ b/Sources/FluidAudio/ASR/CTC/RnnLanguageModel.swift
@@ -0,0 +1,122 @@
+/// CoreML RNN language model for BPE-level beam search rescoring.
+///
+/// Wraps a step-by-step LSTM LM that takes one token at a time and
+/// maintains hidden state. Used by both RNNT and CTC beam search decoders
+/// for token-level scoring (better than word-level ARPA).
+///
+/// The CoreML model expects:
+/// - Inputs: token_id [1] int32, h_in [layers, 1, hidden] f32, c_in [layers, 1, hidden] f32
+/// - Outputs: log_probs [1, vocab] f32, h_out [layers, 1, hidden] f32, c_out [layers, 1, hidden] f32
+
+import CoreML
+import Foundation
+
+public struct RnnLanguageModel {
+
+    private let model: MLModel
+    public let vocabSize: Int
+    public let numLayers: Int
+    public let hiddenDim: Int
+    private let predictionOptions: MLPredictionOptions
+
+    public init(model: MLModel, vocabSize: Int, numLayers: Int, hiddenDim: Int) {
+        self.model = model
+        self.vocabSize = vocabSize
+        self.numLayers = numLayers
+        self.hiddenDim = hiddenDim
+        self.predictionOptions = MLPredictionOptions()
+    }
+
+    /// Load from a compiled .mlmodelc or an .mlpackage; a bare base path also works (both extensions are tried).
+    public static func load(
+        from url: URL,
+        vocabSize: Int,
+        numLayers: Int,
+        hiddenDim: Int,
+        computeUnits: MLComputeUnits = .all
+    ) throws -> RnnLanguageModel {
+        let fm = FileManager.default
+        let config = MLModelConfiguration()
+        config.computeUnits = computeUnits
+
+        let compiledURL = url.appendingPathExtension("mlmodelc")
+        let packageURL = url.pathExtension == "mlmodelc" ? url
+            : url.pathExtension == "mlpackage" ? url : compiledURL
+
+        let mlModel: MLModel
+        if fm.fileExists(atPath: url.path), url.pathExtension == "mlmodelc" {
+            mlModel = try MLModel(contentsOf: url, configuration: config)
+        } else if fm.fileExists(atPath: url.path), url.pathExtension == "mlpackage" {
+            let compiled = try MLModel.compileModel(at: url)
+            mlModel = try MLModel(contentsOf: compiled, configuration: config)
+        } else {
+            // Try appending extensions
+            let mlmodelc = url.appendingPathExtension("mlmodelc")
+            let mlpackage = url.appendingPathExtension("mlpackage")
+            if fm.fileExists(atPath: mlmodelc.path) {
+                mlModel = try MLModel(contentsOf: mlmodelc, configuration: config)
+            } else if fm.fileExists(atPath: mlpackage.path) {
+                let compiled = try MLModel.compileModel(at: mlpackage)
+                mlModel = try MLModel(contentsOf: compiled, configuration: config)
+            } else {
+                throw RnnLmError.modelNotFound(url.path)
+            }
+        }
+
+        return RnnLanguageModel(model: mlModel, vocabSize: vocabSize,
+                                numLayers: numLayers, hiddenDim: hiddenDim)
+    }
+
+    /// Create zero-initialized LSTM state.
+    public func makeInitialState() throws -> (h: MLMultiArray, c: MLMultiArray) {
+        let shape = [numLayers, 1, hiddenDim] as [NSNumber]
+        let h = try MLMultiArray(shape: shape, dataType: .float32)
+        let c = try MLMultiArray(shape: shape, dataType: .float32)
+        let count = numLayers * hiddenDim
+        memset(h.dataPointer, 0, count * 4)
+        memset(c.dataPointer, 0, count * 4)
+        return (h, c)
+    }
+
+    /// Score a single token given LSTM state. Returns the log-prob array and updated state.
+    ///
+    /// This is the core method used in beam search. Each beam hypothesis carries
+    /// its own (h, c) state, so this method does not mutate any shared state.
+    ///
+    /// - Parameters:
+    ///   - tokenId: BPE token ID to score
+    ///   - h: LSTM hidden state [numLayers, 1, hiddenDim]
+    ///   - c: LSTM cell state [numLayers, 1, hiddenDim]
+    /// - Returns: (logProbs MLMultiArray [1, vocabSize], h_out, c_out)
+    public func score(
+        tokenId: Int, h: MLMultiArray, c: MLMultiArray
+    ) throws -> (logProbs: MLMultiArray, hOut: MLMultiArray, cOut: MLMultiArray) {
+        let tokenArray = try MLMultiArray(shape: [1], dataType: .int32)
+        tokenArray[0] = NSNumber(value: Int32(tokenId))
+
+        let input = try MLDictionaryFeatureProvider(dictionary: [
+            "token_id": MLFeatureValue(multiArray: tokenArray),
+            "h_in": MLFeatureValue(multiArray: h),
+            "c_in": MLFeatureValue(multiArray: c),
+        ])
+
+        let output = try model.prediction(from: input, options: predictionOptions)
+
+        let logProbs = output.featureValue(for: "log_probs")!.multiArrayValue!
+        let hOut = output.featureValue(for: "h_out")!.multiArrayValue!
+        let cOut = output.featureValue(for: "c_out")!.multiArrayValue!
+
+        return (logProbs, hOut, cOut)
+    }
+}
+
+public enum RnnLmError: Error, LocalizedError {
+    case modelNotFound(String)
+
+    public var errorDescription: String? {
+        switch self {
+        case .modelNotFound(let path):
+            return "RNN-LM model not found at: \(path) (tried .mlmodelc and .mlpackage)"
+        }
+    }
+}
diff --git a/Sources/FluidAudio/ASR/TDT/ZipformerRnntDecoder.swift b/Sources/FluidAudio/ASR/TDT/ZipformerRnntDecoder.swift
index f3a6989c3..d70d31033 100644
--- a/Sources/FluidAudio/ASR/TDT/ZipformerRnntDecoder.swift
+++ b/Sources/FluidAudio/ASR/TDT/ZipformerRnntDecoder.swift
@@ -28,6 +28,9 @@ internal struct RnntBeam {
     var prevWord: String?
     var timestamps: [Int]
     var confidences: [Float]
+    // RNN-LM LSTM state per beam (nil when using ARPA or no LM)
+    var rnnLmH: MLMultiArray?
+    var rnnLmC: MLMultiArray?
var total: Float { logProb + lmScore } } @@ -123,26 +126,11 @@ internal struct ZipformerRnntDecoder { // MARK: - Modified beam search with optional LM - /// Decode encoder output using modified beam search with optional ARPA LM. + /// Decode encoder output using modified beam search with optional LM. /// - /// Maintains `beamWidth` hypotheses. At each encoder frame, expands each - /// hypothesis by trying blank + top-K non-blank tokens, then prunes to - /// the best `beamWidth` hypotheses by score. - /// - /// Word-level LM scores are applied at SentencePiece word boundaries (▁ prefix). - /// - /// - Parameters: - /// - encoderOutput: Encoder output, shape `[1, T, joinerDim]` - /// - encoderSequenceLength: Number of valid encoder frames - /// - decoderModel: Stateless decoder CoreML model - /// - joinerModel: Joiner CoreML model - /// - vocabulary: Token ID → string mapping for LM word boundary detection - /// - lm: Optional ARPA language model - /// - blankId: Blank token ID (typically 0) - /// - contextSize: Decoder context window size (typically 2) - /// - beamWidth: Number of hypotheses to maintain (default 4) - /// - lmWeight: LM score scaling factor (default 0.3) - /// - tokenCandidates: Top-K non-blank tokens to consider per frame (default 8) + /// Supports two LM types (RNN-LM takes precedence if both provided): + /// - **RNN-LM**: BPE token-level scoring via CoreML LSTM (best quality) + /// - **ARPA**: Word-level n-gram scoring at SentencePiece boundaries func beamDecode( encoderOutput: MLMultiArray, encoderSequenceLength: Int, @@ -150,6 +138,7 @@ internal struct ZipformerRnntDecoder { joinerModel: MLModel, vocabulary: [Int: String], lm: ARPALanguageModel?, + rnnLm: RnnLanguageModel?, blankId: Int, contextSize: Int, beamWidth: Int = 4, @@ -173,9 +162,11 @@ internal struct ZipformerRnntDecoder { // Initialize with single blank-context beam let initialContext = [Int](repeating: blankId, count: contextSize) + let initialLmState = try rnnLm?.makeInitialState() 
var beams = [RnntBeam( tokens: [], context: initialContext, logProb: 0.0, lmScore: 0.0, - wordPieces: [], prevWord: nil, timestamps: [], confidences: [] + wordPieces: [], prevWord: nil, timestamps: [], confidences: [], + rnnLmH: initialLmState?.h, rnnLmC: initialLmState?.c )] // Cache decoder outputs for each unique context to avoid redundant calls @@ -241,7 +232,8 @@ internal struct ZipformerRnntDecoder { logProb: beam.logProb + logProbs[blankId], lmScore: beam.lmScore, wordPieces: beam.wordPieces, prevWord: beam.prevWord, - timestamps: beam.timestamps, confidences: beam.confidences + timestamps: beam.timestamps, confidences: beam.confidences, + rnnLmH: beam.rnnLmH, rnnLmC: beam.rnnLmC )) // Find top-K non-blank tokens @@ -258,16 +250,24 @@ internal struct ZipformerRnntDecoder { newContext.removeFirst() newContext.append(tokenId) - // LM scoring at word boundaries var newLmScore = beam.lmScore var newWordPieces = beam.wordPieces var newPrevWord = beam.prevWord - - if let lm = lm, let tokenStr = vocabulary[tokenId] { + var newRnnLmH = beam.rnnLmH + var newRnnLmC = beam.rnnLmC + + if let rnnLm = rnnLm, let h = beam.rnnLmH, let c = beam.rnnLmC { + // RNN-LM: token-level scoring (every token) + let lmResult = try rnnLm.score(tokenId: tokenId, h: h, c: c) + let lmLogProb = lmResult.logProbs.dataPointer.bindMemory( + to: Float.self, capacity: rnnLm.vocabSize)[tokenId] + newLmScore += lmWeight * lmLogProb + newRnnLmH = lmResult.hOut + newRnnLmC = lmResult.cOut + } else if let lm = lm, let tokenStr = vocabulary[tokenId] { + // ARPA fallback: word-level scoring at boundaries newWordPieces.append(tokenStr) - // Check for word boundary: SentencePiece ▁ prefix on NEXT token if tokenStr.hasPrefix("\u{2581}") && !beam.wordPieces.isEmpty { - // Previous word pieces form a complete word let word = beam.wordPieces.joined() .replacingOccurrences(of: "\u{2581}", with: "") if !word.isEmpty { @@ -292,7 +292,9 @@ internal struct ZipformerRnntDecoder { wordPieces: newWordPieces, 
prevWord: newPrevWord, timestamps: newTimestamps, - confidences: newConfidences + confidences: newConfidences, + rnnLmH: newRnnLmH, + rnnLmC: newRnnLmC )) } } @@ -306,8 +308,8 @@ internal struct ZipformerRnntDecoder { decoderCache = decoderCache.filter { activeContexts.contains($0.key) } } - // Score final incomplete word for LM - if let lm = lm { + // Score final incomplete word for ARPA LM (RNN-LM already scored per-token) + if rnnLm == nil, let lm = lm { for i in 0..