diff --git a/Sources/FluidAudio/ASR/AsrManager.swift b/Sources/FluidAudio/ASR/AsrManager.swift index 503a494b3..2e0561e90 100644 --- a/Sources/FluidAudio/ASR/AsrManager.swift +++ b/Sources/FluidAudio/ASR/AsrManager.swift @@ -30,10 +30,21 @@ public actor AsrManager { return asrModels?.version.decoderLayers ?? 2 } + /// Get the max audio samples for the current model's preprocessor. + internal func getMaxModelSamples() -> Int { + return asrModels?.version.maxAudioSamples ?? ASRConstants.maxModelSamples + } + /// Token duration optimization model /// Cached vocabulary loaded once during initialization internal var vocabulary: [Int: String] = [:] + + /// Optional ARPA language model for beam search rescoring + public var arpaLanguageModel: ARPALanguageModel? + + /// Optional CoreML RNN language model for BPE-level beam search rescoring + public var rnnLanguageModel: RnnLanguageModel? #if DEBUG // Test-only setter internal func setVocabularyForTesting(_ vocab: [Int: String]) { @@ -240,6 +251,21 @@ public actor AsrManager { } internal func initializeDecoderState(for source: AudioSource) async throws { + // Zipformer2 uses a stateless decoder — no LSTM state to initialize + if asrModels?.version.hasStatelessDecoder == true { + var state: TdtDecoderState + switch source { + case .microphone: state = microphoneDecoderState + case .system: state = systemDecoderState + } + state.reset() + switch source { + case .microphone: microphoneDecoderState = state + case .system: systemDecoderState = state + } + return + } + guard let decoderModel = decoderModel else { throw ASRError.notInitialized } @@ -303,6 +329,16 @@ public actor AsrManager { return directory.standardizedFileURL } + /// Set the ARPA language model for beam search rescoring. + public func setLanguageModel(_ lm: ARPALanguageModel) { + self.arpaLanguageModel = lm + } + + /// Set the CoreML RNN language model for BPE-level beam search rescoring. + public func setRnnLanguageModel(_ lm: RnnLanguageModel) { + self.rnnLanguageModel = lm + } + public func resetState() { // Use model's decoder layer count, or 2 if models not loaded (v2/v3 default) let layers = asrModels?.version.decoderLayers ?? 2 @@ -385,6 +421,34 @@ public actor AsrManager { isLastChunk: isLastChunk, globalFrameOffset: globalFrameOffset ) + case .zipformer2: + let zipformerDecoder = ZipformerRnntDecoder(config: adaptedConfig) + switch config.decodingMethod { + case .greedy: + return try zipformerDecoder.decode( + encoderOutput: encoderOutput, + encoderSequenceLength: encoderSequenceLength, + decoderModel: decoder_, + joinerModel: joint, + blankId: models.version.blankId, + contextSize: models.version.contextSize + ) + case .beamSearch(let beamWidth, let lmWeight, let tokenCandidates): + return try zipformerDecoder.beamDecode( + encoderOutput: encoderOutput, + encoderSequenceLength: encoderSequenceLength, + decoderModel: decoder_, + joinerModel: joint, + vocabulary: vocabulary, + lm: arpaLanguageModel, + rnnLm: rnnLanguageModel, + blankId: models.version.blankId, + contextSize: models.version.contextSize, + beamWidth: beamWidth, + lmWeight: lmWeight, + tokenCandidates: tokenCandidates + ) + } } } diff --git a/Sources/FluidAudio/ASR/AsrModels.swift b/Sources/FluidAudio/ASR/AsrModels.swift index 67129c6bd..18b73fca8 100644 --- a/Sources/FluidAudio/ASR/AsrModels.swift +++ b/Sources/FluidAudio/ASR/AsrModels.swift @@ -8,27 +8,46 @@ public enum AsrModelVersion: Sendable { case v3 /// 110M parameter hybrid TDT-CTC model with fused preprocessor+encoder case tdtCtc110m + /// Zipformer2 transducer (icefall/sherpa-onnx) with stateless decoder + case zipformer2 var repo: Repo { switch self { case .v2: return .parakeetV2 case .v3: return .parakeet case .tdtCtc110m: return .parakeetTdtCtc110m + case .zipformer2: return .zipformer2 } } /// Whether this model version uses a fused preprocessor+encoder (no separate Encoder model) public var hasFusedEncoder: Bool { switch self { - case .tdtCtc110m: return true + case .tdtCtc110m, .zipformer2: return true default: return false } } + /// Whether this model uses a stateless decoder (context window) vs stateful LSTM + public var hasStatelessDecoder: Bool { + switch self { + case .zipformer2: return true + default: return false + } + } + + /// Decoder context window size (for stateless decoders) + public var contextSize: Int { + switch self { + case .zipformer2: return 2 + default: return 0 // Not applicable for LSTM decoders + } + } + /// Encoder hidden dimension for this model version public var encoderHiddenSize: Int { switch self { - case .tdtCtc110m: return 512 + case .tdtCtc110m, .zipformer2: return 512 default: return 1024 } } @@ -38,6 +57,7 @@ public enum AsrModelVersion: Sendable { switch self { case .v2, .tdtCtc110m: return 1024 case .v3: return 8192 + case .zipformer2: return 0 } } @@ -45,9 +65,20 @@ public enum AsrModelVersion: Sendable { public var decoderLayers: Int { switch self { case .tdtCtc110m: return 1 + case .zipformer2: return 1 // Dummy, not used for stateless decoder default: return 2 } } + + /// Maximum audio samples for the fused preprocessor input. + /// Zipformer2 uses 239120 (produces 1495 mel frames for encoder compatibility). + /// Parakeet uses 240000 (15s at 16kHz). + public var maxAudioSamples: Int { + switch self { + case .zipformer2: return 239_120 + default: return 240_000 + } + } } public struct AsrModels: Sendable { @@ -84,7 +115,7 @@ public struct AsrModels: Sendable { self.version = version } - /// Whether this model uses a separate preprocessor and encoder (true for 0.6B, false for 110m fused) + /// Whether this model uses a separate preprocessor and encoder (true for 0.6B, false for 110m/zipformer2) public var usesSplitFrontend: Bool { !version.hasFusedEncoder } @@ -123,7 +154,7 @@ extension AsrModels { private static func inferredVersion(from directory: URL) -> AsrModelVersion? { let directoryPath = directory.path.lowercased() - let knownVersions: [AsrModelVersion] = [.tdtCtc110m, .v2, .v3] + let knownVersions: [AsrModelVersion] = [.zipformer2, .tdtCtc110m, .v2, .v3] for version in knownVersions { if directoryPath.contains(version.repo.folderName.lowercased()) { @@ -315,6 +346,84 @@ extension AsrModels { return try await load(from: targetDir, configuration: nil) } + /// Load Zipformer2 transducer models directly from a local directory. + /// + /// The directory should contain encoder, decoder, and joiner models as either + /// `.mlpackage` (source) or `.mlmodelc` (compiled) format, plus vocab.json. + /// If `.mlpackage` files are found, they are compiled on the fly. + /// + /// - Parameters: + /// - directory: Directory containing the Zipformer2 model files + /// - configuration: Optional MLModel configuration + /// - Returns: Loaded AsrModels with version set to .zipformer2 + public static func loadZipformer2( + from directory: URL, + configuration: MLModelConfiguration? = nil + ) throws -> AsrModels { + let config: MLModelConfiguration + if let configuration { + config = configuration + } else { + // Zipformer2 models exported with iOS18 target work best with .all compute units. + // The default .cpuAndNeuralEngine triggers very slow ANE compilation for these models. + let c = MLModelConfiguration() + c.computeUnits = .all + config = c + } + + func loadModel(name: String, packageExt: String, compiledExt: String) throws -> MLModel { + let compiledPath = directory.appendingPathComponent(name + compiledExt) + if FileManager.default.fileExists(atPath: compiledPath.path) { + return try MLModel(contentsOf: compiledPath, configuration: config) + } + let packagePath = directory.appendingPathComponent(name + packageExt) + guard FileManager.default.fileExists(atPath: packagePath.path) else { + throw AsrModelsError.modelNotFound( + name + packageExt, packagePath) + } + // Compile .mlpackage → temporary .mlmodelc + let compiledURL = try MLModel.compileModel(at: packagePath) + return try MLModel(contentsOf: compiledURL, configuration: config) + } + + // Zipformer2 requires fused Preprocessor (audio → encoder features) + let encoderModel = try loadModel( + name: "Preprocessor", packageExt: ".mlpackage", compiledExt: ".mlmodelc") + + // Load decoder + let decoderModel = try loadModel( + name: ModelNames.Zipformer2.decoder, packageExt: ".mlpackage", compiledExt: ".mlmodelc") + + // Load joiner + let joinerModel = try loadModel( + name: ModelNames.Zipformer2.joiner, packageExt: ".mlpackage", compiledExt: ".mlmodelc") + + // Load vocabulary (JSON array format) + let vocabPath = directory.appendingPathComponent(ModelNames.Zipformer2.vocabulary) + guard FileManager.default.fileExists(atPath: vocabPath.path) else { + throw AsrModelsError.modelNotFound(ModelNames.Zipformer2.vocabulary, vocabPath) + } + let vocabData = try Data(contentsOf: vocabPath) + guard let vocabArray = try JSONSerialization.jsonObject(with: vocabData) as? [String] else { + throw AsrModelsError.loadingFailed("Vocabulary file has unexpected format") + } + var vocabulary: [Int: String] = [:] + for (index, token) in vocabArray.enumerated() { + vocabulary[index] = token + } + + // The encoder/preprocessor serves as "preprocessor" in AsrModels + return AsrModels( + encoder: nil, + preprocessor: encoderModel, + decoder: decoderModel, + joint: joinerModel, + configuration: config, + vocabulary: vocabulary, + version: .zipformer2 + ) + } + public static func defaultConfiguration() -> MLModelConfiguration { let config = MLModelConfiguration() config.allowLowPrecisionAccumulationOnGPU = true diff --git a/Sources/FluidAudio/ASR/AsrTranscription.swift b/Sources/FluidAudio/ASR/AsrTranscription.swift index 5574abfc8..38e1849a4 100644 --- a/Sources/FluidAudio/ASR/AsrTranscription.swift +++ b/Sources/FluidAudio/ASR/AsrTranscription.swift @@ -22,21 +22,22 @@ extension AsrManager { } // Route to appropriate processing method based on audio length - if audioSamples.count <= ASRConstants.maxModelSamples { + let maxSamples = getMaxModelSamples() + if audioSamples.count <= maxSamples { let originalLength = audioSamples.count let frameAlignedCandidate = ((originalLength + ASRConstants.samplesPerEncoderFrame - 1) / ASRConstants.samplesPerEncoderFrame) * ASRConstants.samplesPerEncoderFrame let frameAlignedLength: Int let alignedSamples: [Float] - if frameAlignedCandidate > originalLength && frameAlignedCandidate <= ASRConstants.maxModelSamples { + if frameAlignedCandidate > originalLength && frameAlignedCandidate <= maxSamples { frameAlignedLength = frameAlignedCandidate alignedSamples = audioSamples + Array(repeating: 0, count: frameAlignedLength - originalLength) } else { frameAlignedLength = originalLength alignedSamples = audioSamples } - let paddedAudio: [Float] = padAudioIfNeeded(alignedSamples, targetLength: ASRConstants.maxModelSamples) + let paddedAudio: [Float] = padAudioIfNeeded(alignedSamples, targetLength: maxSamples) let (hypothesis, encoderSequenceLength) = try await executeMLInferenceWithTimings( paddedAudio, originalLength: frameAlignedLength, @@ -142,13 +143,27 @@ extension AsrManager { encoderOutputProvider = preprocessorOutput } + // Parakeet uses "encoder"/"encoder_length", Zipformer2 fused uses "encoder_out"/"encoder_out_lens" + let encoderKey = encoderOutputProvider.featureValue(for: "encoder") != nil ? "encoder" : "encoder_out" + let lengthKey = + encoderOutputProvider.featureValue(for: "encoder_length") != nil + ? "encoder_length" : "encoder_out_lens" let rawEncoderOutput = try extractFeatureValue( - from: encoderOutputProvider, key: "encoder", errorMessage: "Invalid encoder output") + from: encoderOutputProvider, key: encoderKey, errorMessage: "Invalid encoder output") let encoderLength = try extractFeatureValue( - from: encoderOutputProvider, key: "encoder_length", + from: encoderOutputProvider, key: lengthKey, errorMessage: "Invalid encoder output length") - let encoderSequenceLength = encoderLength[0].intValue + var encoderSequenceLength = encoderLength[0].intValue + + // For Zipformer2 fused models, encoder_out_lens is a traced constant (always max frames). + // Compute the actual valid frame count from the audio length instead. + // Formula: mel_frames = (samples - 200) / 160 + 1, encoder_frames = (mel_frames - 7) / 4 + if asrModels?.version == .zipformer2, let actualLength = originalLength { + let melFrames = max(1, (actualLength - 200) / 160 + 1) + let validFrames = max(1, (melFrames - 7) / 4) + encoderSequenceLength = min(encoderSequenceLength, validFrames) + } // Calculate actual audio frames if not provided using shared constants let actualFrames = @@ -227,6 +242,7 @@ extension AsrManager { ) async throws -> (tokens: [Int], timestamps: [Int], confidences: [Float], encoderSequenceLength: Int) { // Select and copy decoder state for the source var state = (source == .microphone) ? microphoneDecoderState : systemDecoderState + let maxSamples = getMaxModelSamples() let originalLength = chunkSamples.count let frameAlignedCandidate = @@ -236,7 +252,7 @@ extension AsrManager { let alignedSamples: [Float] if previousTokens.isEmpty && frameAlignedCandidate > originalLength - && frameAlignedCandidate <= ASRConstants.maxModelSamples + && frameAlignedCandidate <= maxSamples { frameAlignedLength = frameAlignedCandidate alignedSamples = chunkSamples + Array(repeating: 0, count: frameAlignedLength - originalLength) @@ -244,7 +260,7 @@ extension AsrManager { frameAlignedLength = originalLength alignedSamples = chunkSamples } - let padded = padAudioIfNeeded(alignedSamples, targetLength: ASRConstants.maxModelSamples) + let padded = padAudioIfNeeded(alignedSamples, targetLength: maxSamples) let (hypothesis, encLen) = try await executeMLInferenceWithTimings( padded, originalLength: frameAlignedLength, diff --git a/Sources/FluidAudio/ASR/AsrTypes.swift b/Sources/FluidAudio/ASR/AsrTypes.swift index c4dcf2950..4aef09dda 100644 --- a/Sources/FluidAudio/ASR/AsrTypes.swift +++ b/Sources/FluidAudio/ASR/AsrTypes.swift @@ -2,6 +2,14 @@ import Foundation // MARK: - Configuration +/// Decoding method for transducer models. +public enum DecodingMethod: Sendable { + /// Greedy search: one best token per encoder frame + case greedy + /// Modified beam search with optional LM rescoring + case beamSearch(beamWidth: Int = 4, lmWeight: Float = 0.3, tokenCandidates: Int = 8) +} + public struct ASRConfig: Sendable { public let sampleRate: Int public let tdtConfig: TdtConfig @@ -19,6 +27,9 @@ public struct ASRConfig: Sendable { /// Default: 480,000 samples (~30 seconds at 16kHz) public let streamingThreshold: Int + /// Decoding method for transducer (RNNT/TDT) models + public let decodingMethod: DecodingMethod + public static let `default` = ASRConfig() public init( @@ -26,13 +37,15 @@ public struct ASRConfig: Sendable { tdtConfig: TdtConfig = .default, encoderHiddenSize: Int = ASRConstants.encoderHiddenSize, streamingEnabled: Bool = true, - streamingThreshold: Int = 480_000 + streamingThreshold: Int = 480_000, + decodingMethod: DecodingMethod = .greedy ) { self.sampleRate = sampleRate self.tdtConfig = tdtConfig self.encoderHiddenSize = encoderHiddenSize self.streamingEnabled = streamingEnabled self.streamingThreshold = streamingThreshold + self.decodingMethod = decodingMethod } } diff --git a/Sources/FluidAudio/ASR/CTC/RnnLanguageModel.swift b/Sources/FluidAudio/ASR/CTC/RnnLanguageModel.swift new file mode 100644 index 000000000..09276b1f6 --- /dev/null +++ b/Sources/FluidAudio/ASR/CTC/RnnLanguageModel.swift @@ -0,0 +1,122 @@ +/// CoreML RNN language model for BPE-level beam search rescoring. +/// +/// Wraps a step-by-step LSTM LM that takes one token at a time and +/// maintains hidden state. Used by both RNNT and CTC beam search decoders +/// for token-level scoring (better than word-level ARPA). +/// +/// The CoreML model expects: +/// - Inputs: token_id [1] int32, h_in [layers, 1, hidden] f32, c_in [layers, 1, hidden] f32 +/// - Outputs: log_probs [1, vocab] f32, h_out [layers, 1, hidden] f32, c_out [layers, 1, hidden] f32 + +import CoreML +import Foundation + +public struct RnnLanguageModel { + + private let model: MLModel + public let vocabSize: Int + public let numLayers: Int + public let hiddenDim: Int + private let predictionOptions: MLPredictionOptions + + public init(model: MLModel, vocabSize: Int, numLayers: Int, hiddenDim: Int) { + self.model = model + self.vocabSize = vocabSize + self.numLayers = numLayers + self.hiddenDim = hiddenDim + self.predictionOptions = MLPredictionOptions() + } + + /// Load from a compiled .mlmodelc or .mlpackage directory. + public static func load( + from url: URL, + vocabSize: Int, + numLayers: Int, + hiddenDim: Int, + computeUnits: MLComputeUnits = .all + ) throws -> RnnLanguageModel { + let fm = FileManager.default + let config = MLModelConfiguration() + config.computeUnits = computeUnits + + let compiledURL = url.appendingPathExtension("mlmodelc") + let packageURL = url.pathExtension == "mlmodelc" ? url + : url.pathExtension == "mlpackage" ? url : compiledURL + + let mlModel: MLModel + if fm.fileExists(atPath: url.path), url.pathExtension == "mlmodelc" { + mlModel = try MLModel(contentsOf: url, configuration: config) + } else if fm.fileExists(atPath: url.path), url.pathExtension == "mlpackage" { + let compiled = try MLModel.compileModel(at: url) + mlModel = try MLModel(contentsOf: compiled, configuration: config) + } else { + // Try appending extensions + let mlmodelc = url.appendingPathExtension("mlmodelc") + let mlpackage = url.appendingPathExtension("mlpackage") + if fm.fileExists(atPath: mlmodelc.path) { + mlModel = try MLModel(contentsOf: mlmodelc, configuration: config) + } else if fm.fileExists(atPath: mlpackage.path) { + let compiled = try MLModel.compileModel(at: mlpackage) + mlModel = try MLModel(contentsOf: compiled, configuration: config) + } else { + throw RnnLmError.modelNotFound(url.path) + } + } + + return RnnLanguageModel(model: mlModel, vocabSize: vocabSize, + numLayers: numLayers, hiddenDim: hiddenDim) + } + + /// Create zero-initialized LSTM state. + public func makeInitialState() throws -> (h: MLMultiArray, c: MLMultiArray) { + let shape = [numLayers, 1, hiddenDim] as [NSNumber] + let h = try MLMultiArray(shape: shape, dataType: .float32) + let c = try MLMultiArray(shape: shape, dataType: .float32) + let count = numLayers * hiddenDim + memset(h.dataPointer, 0, count * 4) + memset(c.dataPointer, 0, count * 4) + return (h, c) + } + + /// Score a single token given LSTM state. Returns log_probs pointer and new state. + /// + /// This is the core method used in beam search. Each beam hypothesis carries + /// its own (h, c) state, so this method does not mutate any shared state. + /// + /// - Parameters: + /// - tokenId: BPE token ID to score + /// - h: LSTM hidden state [numLayers, 1, hiddenDim] + /// - c: LSTM cell state [numLayers, 1, hiddenDim] + /// - Returns: (logProbs MLMultiArray [1, vocabSize], h_out, c_out) + public func score( + tokenId: Int, h: MLMultiArray, c: MLMultiArray + ) throws -> (logProbs: MLMultiArray, hOut: MLMultiArray, cOut: MLMultiArray) { + let tokenArray = try MLMultiArray(shape: [1], dataType: .int32) + tokenArray[0] = NSNumber(value: Int32(tokenId)) + + let input = try MLDictionaryFeatureProvider(dictionary: [ + "token_id": MLFeatureValue(multiArray: tokenArray), + "h_in": MLFeatureValue(multiArray: h), + "c_in": MLFeatureValue(multiArray: c), + ]) + + let output = try model.prediction(from: input, options: predictionOptions) + + let logProbs = output.featureValue(for: "log_probs")!.multiArrayValue! + let hOut = output.featureValue(for: "h_out")!.multiArrayValue! + let cOut = output.featureValue(for: "c_out")!.multiArrayValue! + + return (logProbs, hOut, cOut) + } +} + +public enum RnnLmError: Error, LocalizedError { + case modelNotFound(String) + + public var errorDescription: String? { + switch self { + case .modelNotFound(let path): + return "RNN-LM model not found at: \(path) (tried .mlmodelc and .mlpackage)" + } + } +} diff --git a/Sources/FluidAudio/ASR/TDT/ZipformerRnntDecoder.swift b/Sources/FluidAudio/ASR/TDT/ZipformerRnntDecoder.swift new file mode 100644 index 000000000..d70d31033 --- /dev/null +++ b/Sources/FluidAudio/ASR/TDT/ZipformerRnntDecoder.swift @@ -0,0 +1,337 @@ +/// RNN-T decoder for Zipformer2 transducer models (icefall/sherpa-onnx). +/// +/// Supports both greedy and modified beam search decoding, with optional +/// ARPA language model rescoring at word boundaries. +/// +/// Unlike Parakeet TDT, Zipformer2 uses: +/// - **Stateless decoder**: context window of token IDs (no LSTM hidden/cell states) +/// - **Standard RNNT**: no duration prediction, advance one encoder frame per step +/// - **blank_id = 0**: first token in vocabulary is blank +/// +/// The decoder takes the last `contextSize` token IDs as input and produces +/// a decoder embedding. The joiner combines encoder + decoder embeddings to +/// produce logits over the vocabulary. + +import Accelerate +import CoreML +import Foundation +import OSLog + +// MARK: - Beam hypothesis for RNNT modified beam search + +internal struct RnntBeam { + var tokens: [Int] + var context: [Int] + var logProb: Float + var lmScore: Float + var wordPieces: [String] + var prevWord: String? + var timestamps: [Int] + var confidences: [Float] + // RNN-LM LSTM state per beam (nil when using ARPA or no LM) + var rnnLmH: MLMultiArray? + var rnnLmC: MLMultiArray? + + var total: Float { logProb + lmScore } +} + +// MARK: - Decoder + +internal struct ZipformerRnntDecoder { + + private let logger = AppLogger(category: "ZipformerRNNT") + private let config: ASRConfig + private let predictionOptions = AsrModels.optimizedPredictionOptions() + + init(config: ASRConfig) { + self.config = config + } + + // MARK: - Greedy decode (one token per frame) + + /// Decode encoder output using greedy RNNT search. + func decode( + encoderOutput: MLMultiArray, + encoderSequenceLength: Int, + decoderModel: MLModel, + joinerModel: MLModel, + blankId: Int, + contextSize: Int + ) throws -> TdtHypothesis { + let joinerDim = encoderOutput.shape[2].intValue + + var context = [Int](repeating: blankId, count: contextSize) + let hypothesis = TdtDecoderState.make(decoderLayers: 1) + var result = TdtHypothesis(decState: hypothesis) + + let encStride0 = encoderOutput.strides[0].intValue + let encStride1 = encoderOutput.strides[1].intValue + let encStride2 = encoderOutput.strides[2].intValue + let encPtr = encoderOutput.dataPointer.bindMemory( + to: Float.self, capacity: encoderOutput.count) + + let encoderStep = try MLMultiArray( + shape: [1, NSNumber(value: joinerDim)], dataType: .float32) + let decoderInput = try MLMultiArray( + shape: [1, NSNumber(value: contextSize)], dataType: .int32) + let encStepPtr = encoderStep.dataPointer.bindMemory( + to: Float.self, capacity: joinerDim) + + for t in 0.. TdtHypothesis { + let joinerDim = encoderOutput.shape[2].intValue + + let encStride0 = encoderOutput.strides[0].intValue + let encStride1 = encoderOutput.strides[1].intValue + let encStride2 = encoderOutput.strides[2].intValue + let encPtr = encoderOutput.dataPointer.bindMemory( + to: Float.self, capacity: encoderOutput.count) + + let encoderStep = try MLMultiArray( + shape: [1, NSNumber(value: joinerDim)], dataType: .float32) + let decoderInput = try MLMultiArray( + shape: [1, NSNumber(value: contextSize)], dataType: .int32) + let encStepPtr = encoderStep.dataPointer.bindMemory( + to: Float.self, capacity: joinerDim) + + // Initialize with single blank-context beam + let initialContext = [Int](repeating: blankId, count: contextSize) + let initialLmState = try rnnLm?.makeInitialState() + var beams = [RnntBeam( + tokens: [], context: initialContext, logProb: 0.0, lmScore: 0.0, + wordPieces: [], prevWord: nil, timestamps: [], confidences: [], + rnnLmH: initialLmState?.h, rnnLmC: initialLmState?.c + )] + + // Cache decoder outputs for each unique context to avoid redundant calls + var decoderCache: [[Int]: MLMultiArray] = [:] + + for t in 0.. maxLogit { maxLogit = logitsPtr[v] } + } + var sumExp: Float = 0 + for v in 0.. $1.1 } + + // Candidate 2..K+1: top non-blank tokens + for (tokenId, tokenLogProb) in indexed.prefix(tokenCandidates) { + var newContext = beam.context + newContext.removeFirst() + newContext.append(tokenId) + + var newLmScore = beam.lmScore + var newWordPieces = beam.wordPieces + var newPrevWord = beam.prevWord + var newRnnLmH = beam.rnnLmH + var newRnnLmC = beam.rnnLmC + + if let rnnLm = rnnLm, let h = beam.rnnLmH, let c = beam.rnnLmC { + // RNN-LM: token-level scoring (every token) + let lmResult = try rnnLm.score(tokenId: tokenId, h: h, c: c) + let lmLogProb = lmResult.logProbs.dataPointer.bindMemory( + to: Float.self, capacity: rnnLm.vocabSize)[tokenId] + newLmScore += lmWeight * lmLogProb + newRnnLmH = lmResult.hOut + newRnnLmC = lmResult.cOut + } else if let lm = lm, let tokenStr = vocabulary[tokenId] { + // ARPA fallback: word-level scoring at boundaries + newWordPieces.append(tokenStr) + if tokenStr.hasPrefix("\u{2581}") && !beam.wordPieces.isEmpty { + let word = beam.wordPieces.joined() + .replacingOccurrences(of: "\u{2581}", with: "") + if !word.isEmpty { + newLmScore += lmWeight * lm.score( + word: word.lowercased(), prev: beam.prevWord) + newPrevWord = word.lowercased() + } + newWordPieces = [tokenStr] + } + } + + var newTimestamps = beam.timestamps + newTimestamps.append(t) + var newConfidences = beam.confidences + newConfidences.append(exp(tokenLogProb)) + + candidates.append(RnntBeam( + tokens: beam.tokens + [tokenId], + context: newContext, + logProb: beam.logProb + tokenLogProb, + lmScore: newLmScore, + wordPieces: newWordPieces, + prevWord: newPrevWord, + timestamps: newTimestamps, + confidences: newConfidences, + rnnLmH: newRnnLmH, + rnnLmC: newRnnLmC + )) + } + } + + // Prune to top beamWidth by total score + candidates.sort { $0.total > $1.total } + beams = Array(candidates.prefix(beamWidth)) + + // Clear decoder cache for contexts no longer in active beams + let activeContexts = Set(beams.map { $0.context }) + decoderCache = decoderCache.filter { activeContexts.contains($0.key) } + } + + // Score final incomplete word for ARPA LM (RNN-LM already scored per-token) + if rnnLm == nil, let lm = lm { + for i in 0.. = [ + encoderFile, + decoderFile, + joinerFile, + ] + } + /// CTC model names public enum CTC { public static let melSpectrogram = "MelSpectrogram" @@ -642,6 +670,8 @@ public enum ModelNames { return ModelNames.Qwen3ASR.requiredModelsFull case .multilingualG2p: return ModelNames.MultilingualG2P.requiredModels + case .zipformer2: + return ModelNames.Zipformer2.requiredModels } } } diff --git a/Sources/FluidAudioCLI/Commands/ASR/AsrBenchmark.swift b/Sources/FluidAudioCLI/Commands/ASR/AsrBenchmark.swift index 1212df825..56995eb59 100644 --- a/Sources/FluidAudioCLI/Commands/ASR/AsrBenchmark.swift +++ b/Sources/FluidAudioCLI/Commands/ASR/AsrBenchmark.swift @@ -842,6 +842,7 @@ extension ASRBenchmark { case .v2: versionLabel = "v2" case .v3: versionLabel = "v3" case .tdtCtc110m: versionLabel = "tdt-ctc-110m" + case .zipformer2: versionLabel = "zipformer2" } logger.info(" Model version: \(versionLabel)") logger.info(" Debug mode: \(debugMode ? "enabled" : "disabled")") diff --git a/Sources/FluidAudioCLI/Commands/ASR/TranscribeCommand.swift b/Sources/FluidAudioCLI/Commands/ASR/TranscribeCommand.swift index e970c8cb5..22292cc2e 100644 --- a/Sources/FluidAudioCLI/Commands/ASR/TranscribeCommand.swift +++ b/Sources/FluidAudioCLI/Commands/ASR/TranscribeCommand.swift @@ -241,9 +241,12 @@ enum TranscribeCommand { modelVersion = .v3 case "tdt-ctc-110m", "110m": modelVersion = .tdtCtc110m + case "zipformer2", "zipformer": + modelVersion = .zipformer2 default: logger.error( - "Invalid model version: \(arguments[i + 1]). Use 'v2', 'v3', or 'tdt-ctc-110m'") + "Invalid model version: \(arguments[i + 1]). Use 'v2', 'v3', 'tdt-ctc-110m', or 'zipformer2'" + ) exit(1) } i += 1 @@ -290,7 +293,11 @@ enum TranscribeCommand { let models: AsrModels if let modelDir = modelDir { let dir = URL(fileURLWithPath: modelDir) - models = try await AsrModels.load(from: dir, version: modelVersion) + if modelVersion == .zipformer2 { + models = try AsrModels.loadZipformer2(from: dir) + } else { + models = try await AsrModels.load(from: dir, version: modelVersion) + } } else { models = try await AsrModels.downloadAndLoad(version: modelVersion) } @@ -411,6 +418,7 @@ enum TranscribeCommand { case .v2: modelVersionLabel = "v2" case .v3: modelVersionLabel = "v3" case .tdtCtc110m: modelVersionLabel = "tdt-ctc-110m" + case .zipformer2: modelVersionLabel = "zipformer2" } let output = TranscriptionJSONOutput( audioFile: audioFile, @@ -665,6 +673,7 @@ enum TranscribeCommand { case .v2: modelVersionLabel = "v2" case .v3: modelVersionLabel = "v3" case .tdtCtc110m: modelVersionLabel = "tdt-ctc-110m" + case .zipformer2: modelVersionLabel = "zipformer2" } let output = TranscriptionJSONOutput( audioFile: audioFile,