-
Notifications
You must be signed in to change notification settings - Fork 244
Add standalone CTC head for custom vocabulary (#435) #450
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2f8a6d5
d7fa888
6a2e4d2
55941bf
4e787f7
d83a893
f0e3dab
e8c0a71
d9fbbb0
cb4f293
e093ed2
1adfd8f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -53,6 +53,37 @@ public actor AsrManager { | |
| internal var vocabSizeConfig: ContextBiasingConstants.VocabSizeConfig? | ||
| internal var vocabBoostingEnabled: Bool { customVocabulary != nil && vocabularyRescorer != nil } | ||
|
|
||
| // Cached CTC logits from fused Preprocessor (unified custom vocabulary) | ||
| internal var cachedCtcLogits: MLMultiArray? | ||
| internal var cachedCtcFrameDuration: Double? | ||
| internal var cachedCtcValidFrames: Int? | ||
|
|
||
| /// Whether the Preprocessor outputs CTC logits (unified custom vocabulary model). | ||
| public var hasCachedCtcLogits: Bool { cachedCtcLogits != nil } | ||
|
|
||
| /// Get cached CTC raw logits as [[Float]] for external use (e.g. benchmarks). | ||
| /// These are raw logits — callers must apply `CtcKeywordSpotter.applyLogSoftmax()` | ||
| /// to convert to log-probabilities before use in keyword detection. | ||
| /// Returns nil if the CTC head model is not available or audio was multi-chunk. | ||
| public func getCachedCtcRawLogits() -> (rawLogits: [[Float]], frameDuration: Double)? { | ||
| guard let logits = cachedCtcLogits, let duration = cachedCtcFrameDuration else { return nil } | ||
| let shape = logits.shape | ||
| guard shape.count == 3 else { return nil } | ||
| let numFrames = min(shape[1].intValue, cachedCtcValidFrames ?? shape[1].intValue) | ||
| let vocabSize = shape[2].intValue | ||
| var result: [[Float]] = [] | ||
| result.reserveCapacity(numFrames) | ||
| for t in 0..<numFrames { | ||
| var frame: [Float] = [] | ||
| frame.reserveCapacity(vocabSize) | ||
| for v in 0..<vocabSize { | ||
| frame.append(logits[[0, t, v] as [NSNumber]].floatValue) | ||
| } | ||
| result.append(frame) | ||
| } | ||
| return (rawLogits: result, frameDuration: duration) | ||
| } | ||
|
Comment on lines
+56
to
+85
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🔴 No unit tests added for new CTC head functionality (AGENTS.md violation) AGENTS.md states: "Add unit tests when writing new code." This PR adds significant new functionality — CTC head model loading ( Prompt for agentsWas this helpful? React with 👍 or 👎 to provide feedback. |
||
|
|
||
| // Cached prediction options for reuse | ||
| internal lazy var predictionOptions: MLPredictionOptions = { | ||
| AsrModels.optimizedPredictionOptions() | ||
|
|
@@ -308,6 +339,9 @@ public actor AsrManager { | |
| let layers = asrModels?.version.decoderLayers ?? 2 | ||
| microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers) | ||
| systemDecoderState = TdtDecoderState.make(decoderLayers: layers) | ||
| cachedCtcLogits = nil | ||
| cachedCtcFrameDuration = nil | ||
| cachedCtcValidFrames = nil | ||
| Task { await sharedMLArrayCache.clear() } | ||
| } | ||
|
|
||
|
|
@@ -322,7 +356,10 @@ public actor AsrManager { | |
| // Reset decoder states using fresh allocations for deterministic behavior | ||
| microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers) | ||
| systemDecoderState = TdtDecoderState.make(decoderLayers: layers) | ||
| // Release vocabulary boosting resources | ||
| // Release vocabulary boosting resources and cached CTC data | ||
| cachedCtcLogits = nil | ||
| cachedCtcFrameDuration = nil | ||
| cachedCtcValidFrames = nil | ||
| disableVocabularyBoosting() | ||
| Task { await sharedMLArrayCache.clear() } | ||
| logger.info("AsrManager resources cleaned up") | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -60,6 +60,8 @@ public struct AsrModels: Sendable { | |
| public let preprocessor: MLModel | ||
| public let decoder: MLModel | ||
| public let joint: MLModel | ||
| /// Optional CTC decoder head for custom vocabulary (encoder features → CTC logits) | ||
| public let ctcHead: MLModel? | ||
| public let configuration: MLModelConfiguration | ||
| public let vocabulary: [Int: String] | ||
| public let version: AsrModelVersion | ||
|
|
@@ -71,6 +73,7 @@ public struct AsrModels: Sendable { | |
| preprocessor: MLModel, | ||
| decoder: MLModel, | ||
| joint: MLModel, | ||
| ctcHead: MLModel? = nil, | ||
| configuration: MLModelConfiguration, | ||
| vocabulary: [Int: String], | ||
| version: AsrModelVersion | ||
|
|
@@ -79,6 +82,7 @@ public struct AsrModels: Sendable { | |
| self.preprocessor = preprocessor | ||
| self.decoder = decoder | ||
| self.joint = joint | ||
| self.ctcHead = ctcHead | ||
| self.configuration = configuration | ||
| self.vocabulary = vocabulary | ||
| self.version = version | ||
|
|
@@ -207,11 +211,52 @@ extension AsrModels { | |
| throw AsrModelsError.loadingFailed("Failed to load decoder or joint model") | ||
| } | ||
|
|
||
| // [Beta] Optionally load CTC head model for custom vocabulary. | ||
| // Supports two paths: | ||
| // v1: CtcHead.mlmodelc placed manually in the TDT model directory | ||
| // v2: Auto-download from FluidInference/parakeet-ctc-110m-coreml HF repo | ||
| var ctcHeadModel: MLModel? | ||
| if version == .tdtCtc110m { | ||
| // v1: Check local TDT model directory first | ||
| let repoDir = repoPath(from: directory, version: version) | ||
| let ctcHeadPath = repoDir.appendingPathComponent(Names.ctcHeadFile) | ||
| if FileManager.default.fileExists(atPath: ctcHeadPath.path) { | ||
| let ctcConfig = MLModelConfiguration() | ||
| ctcConfig.computeUnits = config.computeUnits | ||
| ctcHeadModel = try? MLModel(contentsOf: ctcHeadPath, configuration: ctcConfig) | ||
| if ctcHeadModel != nil { | ||
| logger.info("[Beta] Loaded CTC head model from local directory") | ||
| } else { | ||
| logger.warning("CTC head model found but failed to load: \(ctcHeadPath.path)") | ||
| } | ||
| } | ||
|
|
||
| // v2: Fall back to downloading from parakeet-ctc-110m HF repo | ||
| if ctcHeadModel == nil { | ||
| do { | ||
| let ctcModels = try await DownloadUtils.loadModels( | ||
| .parakeetCtc110m, | ||
| modelNames: [Names.ctcHeadFile], | ||
| directory: parentDirectory, | ||
| computeUnits: config.computeUnits, | ||
| progressHandler: progressHandler | ||
| ) | ||
| ctcHeadModel = ctcModels[Names.ctcHeadFile] | ||
| if ctcHeadModel != nil { | ||
| logger.info("[Beta] Loaded CTC head model from HF repo") | ||
| } | ||
| } catch { | ||
| logger.warning("CTC head model not available: \(error.localizedDescription)") | ||
| } | ||
| } | ||
| } | ||
|
Comment on lines
+219
to
+252
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🔴 Nested if statements in CTC head loading code violate AGENTS.md control flow rule AGENTS.md states: "Nested if statements should be absolutely avoided. Use guard statements and inverted conditions to exit early." The new CTC head loading block has 3 levels of nesting: Prompt for agentsWas this helpful? React with 👍 or 👎 to provide feedback. |
||
|
|
||
| let asrModels = AsrModels( | ||
| encoder: encoderModel, | ||
| preprocessor: preprocessorModel, | ||
| decoder: decoderModel, | ||
| joint: jointModel, | ||
| ctcHead: ctcHeadModel, | ||
| configuration: config, | ||
| vocabulary: try loadVocabulary(from: directory, version: version), | ||
| version: version | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.