Skip to content
64 changes: 64 additions & 0 deletions Sources/FluidAudio/ASR/AsrManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,21 @@ public actor AsrManager {
return asrModels?.version.decoderLayers ?? 2
}

/// Maximum number of audio samples the active model's preprocessor accepts.
/// Falls back to `ASRConstants.maxModelSamples` when no models are loaded yet.
internal func getMaxModelSamples() -> Int {
    guard let models = asrModels else {
        return ASRConstants.maxModelSamples
    }
    return models.version.maxAudioSamples
}

/// Token duration optimization model

/// Cached vocabulary loaded once during initialization
internal var vocabulary: [Int: String] = [:]

/// Optional ARPA language model for beam search rescoring
public var arpaLanguageModel: ARPALanguageModel?

/// Optional CoreML RNN language model for BPE-level beam search rescoring
public var rnnLanguageModel: RnnLanguageModel?
#if DEBUG
// Test-only setter
internal func setVocabularyForTesting(_ vocab: [Int: String]) {
Expand Down Expand Up @@ -240,6 +251,21 @@ public actor AsrManager {
}

internal func initializeDecoderState(for source: AudioSource) async throws {
// Zipformer2 uses a stateless decoder — no LSTM state to initialize
if asrModels?.version.hasStatelessDecoder == true {
var state: TdtDecoderState
switch source {
case .microphone: state = microphoneDecoderState
case .system: state = systemDecoderState
}
state.reset()
switch source {
case .microphone: microphoneDecoderState = state
case .system: systemDecoderState = state
}
return
}

guard let decoderModel = decoderModel else {
throw ASRError.notInitialized
}
Expand Down Expand Up @@ -303,6 +329,16 @@ public actor AsrManager {
return directory.standardizedFileURL
}

/// Attach an ARPA language model used to rescore hypotheses during beam search.
public func setLanguageModel(_ model: ARPALanguageModel) {
    arpaLanguageModel = model
}

/// Attach a CoreML RNN language model used for BPE-level beam-search rescoring.
public func setRnnLanguageModel(_ model: RnnLanguageModel) {
    rnnLanguageModel = model
}

public func resetState() {
// Use model's decoder layer count, or 2 if models not loaded (v2/v3 default)
let layers = asrModels?.version.decoderLayers ?? 2
Expand Down Expand Up @@ -385,6 +421,34 @@ public actor AsrManager {
isLastChunk: isLastChunk,
globalFrameOffset: globalFrameOffset
)
case .zipformer2:
let zipformerDecoder = ZipformerRnntDecoder(config: adaptedConfig)
switch config.decodingMethod {
case .greedy:
return try zipformerDecoder.decode(
encoderOutput: encoderOutput,
encoderSequenceLength: encoderSequenceLength,
decoderModel: decoder_,
joinerModel: joint,
blankId: models.version.blankId,
contextSize: models.version.contextSize
)
case .beamSearch(let beamWidth, let lmWeight, let tokenCandidates):
return try zipformerDecoder.beamDecode(
encoderOutput: encoderOutput,
encoderSequenceLength: encoderSequenceLength,
decoderModel: decoder_,
joinerModel: joint,
vocabulary: vocabulary,
lm: arpaLanguageModel,
rnnLm: rnnLanguageModel,
blankId: models.version.blankId,
contextSize: models.version.contextSize,
beamWidth: beamWidth,
lmWeight: lmWeight,
tokenCandidates: tokenCandidates
)
}
}
}

Expand Down
117 changes: 113 additions & 4 deletions Sources/FluidAudio/ASR/AsrModels.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,46 @@ public enum AsrModelVersion: Sendable {
case v3
/// 110M parameter hybrid TDT-CTC model with fused preprocessor+encoder
case tdtCtc110m
/// Zipformer2 transducer (icefall/sherpa-onnx) with stateless decoder
case zipformer2

/// Model repository associated with each supported version.
var repo: Repo {
    switch self {
    case .v2:
        return .parakeetV2
    case .v3:
        return .parakeet
    case .tdtCtc110m:
        return .parakeetTdtCtc110m
    case .zipformer2:
        return .zipformer2
    }
}

/// True when the preprocessor and encoder ship as one fused model,
/// i.e. there is no separate Encoder artifact for this version.
public var hasFusedEncoder: Bool {
    if case .tdtCtc110m = self { return true }
    if case .zipformer2 = self { return true }
    return false
}

/// True when the decoder is stateless (fixed context window) rather than a
/// stateful LSTM prediction network.
public var hasStatelessDecoder: Bool {
    if case .zipformer2 = self { return true }
    return false
}

/// Context window size consumed by a stateless decoder.
/// Returns 0 for versions whose decoder is an LSTM (no context window).
public var contextSize: Int {
    if case .zipformer2 = self { return 2 }
    return 0
}

/// Encoder hidden dimension for this model version.
public var encoderHiddenSize: Int {
    if case .tdtCtc110m = self { return 512 }
    if case .zipformer2 = self { return 512 }
    return 1024
}
Expand All @@ -38,16 +57,28 @@ public enum AsrModelVersion: Sendable {
switch self {
case .v2, .tdtCtc110m: return 1024
case .v3: return 8192
case .zipformer2: return 0
}
}

/// Number of LSTM layers in the decoder prediction network.
/// For the stateless Zipformer2 decoder the value is a placeholder and unused.
public var decoderLayers: Int {
    if case .tdtCtc110m = self { return 1 }
    if case .zipformer2 = self { return 1 }
    return 2
}

/// Maximum audio samples accepted by the fused preprocessor input.
/// Zipformer2 expects 239_120 samples (yielding 1495 mel frames for encoder
/// compatibility); Parakeet models expect 240_000 (15 s at 16 kHz).
public var maxAudioSamples: Int {
    if case .zipformer2 = self {
        return 239_120
    }
    return 240_000
}
}

public struct AsrModels: Sendable {
Expand Down Expand Up @@ -84,7 +115,7 @@ public struct AsrModels: Sendable {
self.version = version
}

/// True when the model loads a separate preprocessor and encoder
/// (the 0.6B variants); false for fused frontends (110m, zipformer2).
public var usesSplitFrontend: Bool {
    version.hasFusedEncoder == false
}
Expand Down Expand Up @@ -123,7 +154,7 @@ extension AsrModels {

private static func inferredVersion(from directory: URL) -> AsrModelVersion? {
let directoryPath = directory.path.lowercased()
let knownVersions: [AsrModelVersion] = [.tdtCtc110m, .v2, .v3]
let knownVersions: [AsrModelVersion] = [.zipformer2, .tdtCtc110m, .v2, .v3]

for version in knownVersions {
if directoryPath.contains(version.repo.folderName.lowercased()) {
Expand Down Expand Up @@ -315,6 +346,84 @@ extension AsrModels {
return try await load(from: targetDir, configuration: nil)
}

/// Load Zipformer2 transducer models directly from a local directory.
///
/// The directory must contain encoder, decoder, and joiner models as either
/// `.mlpackage` (source, compiled on the fly) or `.mlmodelc` (compiled),
/// plus a `vocab.json` token list stored as a JSON string array.
///
/// - Parameters:
///   - directory: Directory containing the Zipformer2 model files.
///   - configuration: Optional MLModel configuration; when nil a default
///     tuned for these models is used.
/// - Returns: Loaded `AsrModels` with version set to `.zipformer2`.
/// - Throws: `AsrModelsError.modelNotFound` when a required file is missing,
///   or `AsrModelsError.loadingFailed` when the vocabulary cannot be parsed.
public static func loadZipformer2(
    from directory: URL,
    configuration: MLModelConfiguration? = nil
) throws -> AsrModels {
    let resolvedConfig: MLModelConfiguration
    if let configuration {
        resolvedConfig = configuration
    } else {
        // Zipformer2 models exported with iOS18 target work best with .all compute units.
        // The default .cpuAndNeuralEngine triggers very slow ANE compilation for these models.
        let fallback = MLModelConfiguration()
        fallback.computeUnits = .all
        resolvedConfig = fallback
    }

    // Prefer a pre-compiled .mlmodelc if present; otherwise compile the
    // .mlpackage on the fly.
    func openModel(named baseName: String, packageExt: String, compiledExt: String) throws -> MLModel {
        let fileManager = FileManager.default
        let compiledURL = directory.appendingPathComponent(baseName + compiledExt)
        if fileManager.fileExists(atPath: compiledURL.path) {
            return try MLModel(contentsOf: compiledURL, configuration: resolvedConfig)
        }
        let packageURL = directory.appendingPathComponent(baseName + packageExt)
        guard fileManager.fileExists(atPath: packageURL.path) else {
            throw AsrModelsError.modelNotFound(
                baseName + packageExt, packageURL)
        }
        // Compile .mlpackage → temporary .mlmodelc
        // NOTE(review): the compiled artifact lands in a temp location and is
        // not cached or removed here — confirm whether callers expect caching.
        let tempCompiledURL = try MLModel.compileModel(at: packageURL)
        return try MLModel(contentsOf: tempCompiledURL, configuration: resolvedConfig)
    }

    // Zipformer2 requires fused Preprocessor (audio → encoder features)
    let encoderModel = try openModel(
        named: "Preprocessor", packageExt: ".mlpackage", compiledExt: ".mlmodelc")

    // Load decoder
    let decoderModel = try openModel(
        named: ModelNames.Zipformer2.decoder, packageExt: ".mlpackage", compiledExt: ".mlmodelc")

    // Load joiner
    let joinerModel = try openModel(
        named: ModelNames.Zipformer2.joiner, packageExt: ".mlpackage", compiledExt: ".mlmodelc")

    // Load vocabulary (JSON array format): index in the array is the token id.
    let vocabURL = directory.appendingPathComponent(ModelNames.Zipformer2.vocabulary)
    guard FileManager.default.fileExists(atPath: vocabURL.path) else {
        throw AsrModelsError.modelNotFound(ModelNames.Zipformer2.vocabulary, vocabURL)
    }
    let vocabData = try Data(contentsOf: vocabURL)
    guard let tokens = try JSONSerialization.jsonObject(with: vocabData) as? [String] else {
        throw AsrModelsError.loadingFailed("Vocabulary file has unexpected format")
    }
    let vocabulary = Dictionary(
        uniqueKeysWithValues: tokens.enumerated().map { ($0.offset, $0.element) })

    // The encoder/preprocessor serves as "preprocessor" in AsrModels
    return AsrModels(
        encoder: nil,
        preprocessor: encoderModel,
        decoder: decoderModel,
        joint: joinerModel,
        configuration: resolvedConfig,
        vocabulary: vocabulary,
        version: .zipformer2
    )
}

public static func defaultConfiguration() -> MLModelConfiguration {
let config = MLModelConfiguration()
config.allowLowPrecisionAccumulationOnGPU = true
Expand Down
32 changes: 24 additions & 8 deletions Sources/FluidAudio/ASR/AsrTranscription.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,22 @@ extension AsrManager {
}

// Route to appropriate processing method based on audio length
if audioSamples.count <= ASRConstants.maxModelSamples {
let maxSamples = getMaxModelSamples()
if audioSamples.count <= maxSamples {
let originalLength = audioSamples.count
let frameAlignedCandidate =
((originalLength + ASRConstants.samplesPerEncoderFrame - 1)
/ ASRConstants.samplesPerEncoderFrame) * ASRConstants.samplesPerEncoderFrame
let frameAlignedLength: Int
let alignedSamples: [Float]
if frameAlignedCandidate > originalLength && frameAlignedCandidate <= ASRConstants.maxModelSamples {
if frameAlignedCandidate > originalLength && frameAlignedCandidate <= maxSamples {
frameAlignedLength = frameAlignedCandidate
alignedSamples = audioSamples + Array(repeating: 0, count: frameAlignedLength - originalLength)
} else {
frameAlignedLength = originalLength
alignedSamples = audioSamples
}
let paddedAudio: [Float] = padAudioIfNeeded(alignedSamples, targetLength: ASRConstants.maxModelSamples)
let paddedAudio: [Float] = padAudioIfNeeded(alignedSamples, targetLength: maxSamples)
let (hypothesis, encoderSequenceLength) = try await executeMLInferenceWithTimings(
paddedAudio,
originalLength: frameAlignedLength,
Expand Down Expand Up @@ -142,13 +143,27 @@ extension AsrManager {
encoderOutputProvider = preprocessorOutput
}

// Parakeet uses "encoder"/"encoder_length", Zipformer2 fused uses "encoder_out"/"encoder_out_lens"
let encoderKey = encoderOutputProvider.featureValue(for: "encoder") != nil ? "encoder" : "encoder_out"
let lengthKey =
encoderOutputProvider.featureValue(for: "encoder_length") != nil
? "encoder_length" : "encoder_out_lens"
let rawEncoderOutput = try extractFeatureValue(
from: encoderOutputProvider, key: "encoder", errorMessage: "Invalid encoder output")
from: encoderOutputProvider, key: encoderKey, errorMessage: "Invalid encoder output")
let encoderLength = try extractFeatureValue(
from: encoderOutputProvider, key: "encoder_length",
from: encoderOutputProvider, key: lengthKey,
errorMessage: "Invalid encoder output length")

let encoderSequenceLength = encoderLength[0].intValue
var encoderSequenceLength = encoderLength[0].intValue

// For Zipformer2 fused models, encoder_out_lens is a traced constant (always max frames).
// Compute the actual valid frame count from the audio length instead.
// Formula: mel_frames = (samples - 200) / 160 + 1, encoder_frames = (mel_frames - 7) / 4
if asrModels?.version == .zipformer2, let actualLength = originalLength {
let melFrames = max(1, (actualLength - 200) / 160 + 1)
let validFrames = max(1, (melFrames - 7) / 4)
encoderSequenceLength = min(encoderSequenceLength, validFrames)
}

// Calculate actual audio frames if not provided using shared constants
let actualFrames =
Expand Down Expand Up @@ -227,6 +242,7 @@ extension AsrManager {
) async throws -> (tokens: [Int], timestamps: [Int], confidences: [Float], encoderSequenceLength: Int) {
// Select and copy decoder state for the source
var state = (source == .microphone) ? microphoneDecoderState : systemDecoderState
let maxSamples = getMaxModelSamples()

let originalLength = chunkSamples.count
let frameAlignedCandidate =
Expand All @@ -236,15 +252,15 @@ extension AsrManager {
let alignedSamples: [Float]
if previousTokens.isEmpty
&& frameAlignedCandidate > originalLength
&& frameAlignedCandidate <= ASRConstants.maxModelSamples
&& frameAlignedCandidate <= maxSamples
{
frameAlignedLength = frameAlignedCandidate
alignedSamples = chunkSamples + Array(repeating: 0, count: frameAlignedLength - originalLength)
} else {
frameAlignedLength = originalLength
alignedSamples = chunkSamples
}
let padded = padAudioIfNeeded(alignedSamples, targetLength: ASRConstants.maxModelSamples)
let padded = padAudioIfNeeded(alignedSamples, targetLength: maxSamples)
let (hypothesis, encLen) = try await executeMLInferenceWithTimings(
padded,
originalLength: frameAlignedLength,
Expand Down
Loading