Add standalone CTC head for custom vocabulary (#435)#450
Conversation
|
Claude finished @Alex-Wengg's task —— View job Code Review Complete ✅Successfully reviewed PR #450: Add standalone CTC head for custom vocabulary Todo List:
Review SummaryThis PR successfully implements a standalone CTC head for custom vocabulary processing, achieving a significant performance improvement (86% speedup) while maintaining accuracy. The architectural changes are well-designed and follow the project's patterns. Key Strengths:✅ Performance Impact: RTFx improved from 25.98x to 48.35x (86% speedup) with identical Dict Recall (99.4%) Issues Found:🔶 Performance Concerns (Minor)
🔶 Error Handling (Minor)
🔶 API Design (Nit)
🔶 Documentation (Nit)
Specific Findings:AsrManager.swift: Well-integrated caching mechanism with proper actor isolation. Consider adding zero-copy accessor for performance. AsrTranscription.swift: Excellent dual-path implementation (unified vs separate CTC). The fallback logic is robust and well-logged. CtcKeywordSpotter.swift: New AsrModels.swift: Clean optional CTC head loading with proper error handling and logging. CtcEarningsBenchmark.swift: Comprehensive benchmark updates with proper path detection for unified vs separate models. Recommendations:
Overall Assessment:APPROVE - This is a well-executed optimization that delivers significant performance gains without compromising functionality. The issues identified are minor and don't block the PR. The architectural approach is sound and follows project conventions. Review submitted successfully! 🎯 |
There was a problem hiding this comment.
Code Review Summary
This PR implements a significant optimization for custom vocabulary processing by introducing a standalone 1MB CTC head that replaces the need for a separate 97.5MB CTC encoder. The implementation is well-architected and delivers impressive performance gains.
Performance Impact ⚡
- 86% speedup: RTFx improved from 25.98x to 48.35x
- Memory reduction: ~96MB saved (97.5MB → 1MB for custom vocabulary)
- Accuracy preserved: Dict Recall remains at 99.4%
Architecture Quality ✅
- Proper actor-based concurrency without
@unchecked Sendable - Clean separation between unified and legacy dual-model approaches
- Graceful fallback when standalone CTC head is unavailable
- Well-designed public APIs for external benchmarking access
Minor Issues Identified
- Performance: Array copying in
getCachedCtcLogProbs()could be optimized with zero-copy accessor - Error Handling: Some silent failures could benefit from better logging
- Documentation: Cache lifecycle and memory behavior could be clearer
Test Coverage
The benchmark results demonstrate thorough validation across 772 earnings call files with comprehensive metrics tracking. The test plan shows good coverage of the new functionality.
This is a solid optimization that significantly improves the custom vocabulary workflow while maintaining code quality and backwards compatibility. The minor issues identified are non-blocking and can be addressed in future iterations.
Recommendation: APPROVE 🎯
VAD Benchmark ResultsPerformance Comparison
Dataset Details
✅: Average F1-Score above 70% |
Parakeet EOU Benchmark Results ✅Status: Benchmark passed Performance Metrics
Streaming Metrics
Test runtime: 0m44s • 03/28/2026, 04:08 PM EST RTFx = Real-Time Factor (higher is better) • Processing includes: Model inference, audio preprocessing, state management, and file I/O |
PocketTTS Smoke Test ✅
Runtime: 0m23s Note: PocketTTS uses CoreML MLState (macOS 15) KV cache + Mimi streaming state. CI VM lacks physical GPU — audio quality may differ from Apple Silicon. |
Qwen3-ASR int8 Smoke Test ✅
Runtime: 3m57s Note: CI VM lacks physical GPU — CoreML MLState (macOS 15) KV cache produces degraded results on virtualized runners. On Apple Silicon: ~1.3% WER / 2.5x RTFx. |
Speaker Diarization Benchmark ResultsSpeaker Diarization PerformanceEvaluating "who spoke when" detection accuracy
Diarization Pipeline Timing BreakdownTime spent in each stage of speaker diarization
Speaker Diarization Research ComparisonResearch baselines typically achieve 18-30% DER on standard datasets
Note: RTFx shown above is from GitHub Actions runner. On Apple Silicon with ANE:
🎯 Speaker Diarization Test • AMI Corpus ES2004a • 1049.0s meeting audio • 43.4s diarization time • Test runtime: 2m 34s • 03/28/2026, 04:43 PM EST |
Offline VBx Pipeline ResultsSpeaker Diarization Performance (VBx Batch Mode)Optimal clustering with Hungarian algorithm for maximum accuracy
Offline VBx Pipeline Timing BreakdownTime spent in each stage of batch diarization
Speaker Diarization Research ComparisonOffline VBx achieves competitive accuracy with batch processing
Pipeline Details:
🎯 Offline VBx Test • AMI Corpus ES2004a • 1049.0s meeting audio • 220.3s processing • Test runtime: 3m 45s • 03/28/2026, 04:28 PM EST |
ASR Benchmark Results ✅Status: All benchmarks passed Parakeet v3 (multilingual)
Parakeet v2 (English-optimized)
Streaming (v3)
Streaming (v2)
Streaming tests use 5 files with 0.5s chunks to simulate real-time audio streaming 25 files per dataset • Test runtime: 5m28s • 03/28/2026, 04:28 PM EST RTFx = Real-Time Factor (higher is better) • Calculated as: Total audio duration ÷ Total processing time Expected RTFx Performance on Physical M1 Hardware:• M1 Mac: ~28x (clean), ~25x (other) Testing methodology follows HuggingFace Open ASR Leaderboard |
Export the CTC decoder head (512→1025 linear projection) as a separate 1MB CoreML model instead of requiring the full 97.5MB CTC encoder. The CtcHead model runs on the existing TDT encoder output, achieving 99.4% Dict Recall at 70.29x RTFx on the earnings benchmark (772 files). - Load optional CtcHead.mlmodelc from model directory in AsrModels - Run CTC head on raw encoder output in AsrTranscription - Add spotKeywordsFromLogProbs() for DP on pre-computed log-probs - Add applyLogSoftmax() for raw logits→log-probs conversion - Expose cached CTC logits via AsrManager for VocabularyRescorer - Update CtcEarningsBenchmark to use standalone CTC head path
Instead of only loading CtcHead.mlmodelc if manually placed in the model directory, download it on demand from FluidInference/parakeet-ctc-110m-coreml via DownloadUtils.loadModels when the tdtCtc110m model version is used.
Try loading CtcHead.mlmodelc from the local TDT model directory first (v1), then fall back to auto-downloading from the parakeet-ctc-110m HF repo (v2). Mark CTC head loading as beta in log messages.
- Update CustomVocabulary.md with dual architecture diagrams (standalone CTC head vs separate CTC encoder) and approach comparison table - Add CTC head section to TDT-CTC-110M.md covering architecture, loading paths, performance, conversion, and beta status - Update benchmarks100.md with standalone CTC head results (70.29x RTFx, 1MB model, 99.4% Dict Recall)
d6e8254 to
4e787f7
Compare
- Skip CTC head caching for multi-chunk audio (>15s) to prevent stale logits from last chunk being used for full-audio rescoring - Clear cachedCtcLogits in resetState() and cleanup() to prevent leak - Rename getCachedCtcLogProbs() to getCachedCtcRawLogits() to accurately reflect that values are raw logits, not log-probabilities - Remove duplicate CTC inference in benchmark by reusing pre-computed logProbs via spotKeywordsFromLogProbs() for both paths
Sortformer High-Latency Benchmark ResultsES2004a Performance (30.4s latency config)
Sortformer High-Latency • ES2004a • Runtime: 2m 12s • 2026-03-28T20:10:04.043Z |
The CTC head guard requires isLastChunk to be true, but the single-chunk path in transcribeWithState did not pass it, causing the CTC head to never execute for single-chunk audio (the primary use case).
There was a problem hiding this comment.
🟡 Streaming chunk path incorrectly caches CTC logits from partial audio as if single-chunk
transcribeStreamingChunk() calls executeMLInferenceWithTimings without passing globalFrameOffset, so it defaults to 0 (AsrTranscription.swift:280-287). When isLastChunk: true, the caching condition isLastChunk && globalFrameOffset == 0 at AsrTranscription.swift:157,166 is satisfied, causing the CTC head to run and cache logits from ONLY the last streaming chunk. The public APIs hasCachedCtcLogits and getCachedCtcRawLogits() then return this partial-chunk data as if it were valid full-audio logits. An external caller who streams multiple chunks and then checks the cache would get incorrect data.
(Refers to lines 280-287)
Was this helpful? React with 👍 or 👎 to provide feedback.
| // Cached CTC logits from fused Preprocessor (unified custom vocabulary) | ||
| internal var cachedCtcLogits: MLMultiArray? | ||
| internal var cachedCtcFrameDuration: Double? | ||
|
|
||
| /// Whether the Preprocessor outputs CTC logits (unified custom vocabulary model). | ||
| public var hasCachedCtcLogits: Bool { cachedCtcLogits != nil } | ||
|
|
||
| /// Get cached CTC raw logits as [[Float]] for external use (e.g. benchmarks). | ||
| /// These are raw logits — callers must apply `CtcKeywordSpotter.applyLogSoftmax()` | ||
| /// to convert to log-probabilities before use in keyword detection. | ||
| /// Returns nil if the CTC head model is not available or audio was multi-chunk. | ||
| public func getCachedCtcRawLogits() -> (rawLogits: [[Float]], frameDuration: Double)? { | ||
| guard let logits = cachedCtcLogits, let duration = cachedCtcFrameDuration else { return nil } | ||
| let shape = logits.shape | ||
| guard shape.count == 3 else { return nil } | ||
| let numFrames = shape[1].intValue | ||
| let vocabSize = shape[2].intValue | ||
| var result: [[Float]] = [] | ||
| result.reserveCapacity(numFrames) | ||
| for t in 0..<numFrames { | ||
| var frame: [Float] = [] | ||
| frame.reserveCapacity(vocabSize) | ||
| for v in 0..<vocabSize { | ||
| frame.append(logits[[0, t, v] as [NSNumber]].floatValue) | ||
| } | ||
| result.append(frame) | ||
| } | ||
| return (rawLogits: result, frameDuration: duration) | ||
| } |
There was a problem hiding this comment.
🔴 No unit tests added for new CTC head functionality (AGENTS.md violation)
AGENTS.md states: "Add unit tests when writing new code." This PR adds significant new functionality — CTC head model loading (AsrModels.swift:219-252), CTC logit caching (AsrManager.swift:57-84), applyLogSoftmax() static method (CtcKeywordSpotter.swift:268-306), spotKeywordsFromLogProbs() (CtcKeywordSpotter.swift:191-254), convertCtcLogitsToArray() (AsrTranscription.swift:654-686), and the cached-logits integration in applyVocabularyRescoring — but no test files were added or modified in the PR.
Prompt for agents
Add unit tests for the new CTC head functionality. At minimum, create tests in Tests/FluidAudioTests/ for:
1. CtcKeywordSpotter.applyLogSoftmax() - verify it produces valid log-probabilities (sum to ~1 after exp), applies temperature scaling correctly, and applies blank bias to the correct index.
2. CtcKeywordSpotter.spotKeywordsFromLogProbs() - verify it produces the same detections as spotKeywordsWithLogProbs when given the same logProbs.
3. AsrManager cached CTC logit lifecycle - verify cachedCtcLogits is nil after resetState(), nil after cleanup(), and that getCachedCtcRawLogits() returns nil when no CTC head is loaded.
4. convertCtcLogitsToArray() - verify correct conversion from MLMultiArray shape [1, T, V] to [[Float]] with proper log-softmax application.
Was this helpful? React with 👍 or 👎 to provide feedback.
| if version == .tdtCtc110m { | ||
| // v1: Check local TDT model directory first | ||
| let repoDir = repoPath(from: directory, version: version) | ||
| let ctcHeadPath = repoDir.appendingPathComponent(Names.ctcHeadFile) | ||
| if FileManager.default.fileExists(atPath: ctcHeadPath.path) { | ||
| let ctcConfig = MLModelConfiguration() | ||
| ctcConfig.computeUnits = config.computeUnits | ||
| ctcHeadModel = try? MLModel(contentsOf: ctcHeadPath, configuration: ctcConfig) | ||
| if ctcHeadModel != nil { | ||
| logger.info("[Beta] Loaded CTC head model from local directory") | ||
| } else { | ||
| logger.warning("CTC head model found but failed to load: \(ctcHeadPath.path)") | ||
| } | ||
| } | ||
|
|
||
| // v2: Fall back to downloading from parakeet-ctc-110m HF repo | ||
| if ctcHeadModel == nil { | ||
| do { | ||
| let ctcModels = try await DownloadUtils.loadModels( | ||
| .parakeetCtc110m, | ||
| modelNames: [Names.ctcHeadFile], | ||
| directory: parentDirectory, | ||
| computeUnits: config.computeUnits, | ||
| progressHandler: progressHandler | ||
| ) | ||
| ctcHeadModel = ctcModels[Names.ctcHeadFile] | ||
| if ctcHeadModel != nil { | ||
| logger.info("[Beta] Loaded CTC head model from HF repo") | ||
| } | ||
| } catch { | ||
| logger.warning("CTC head model not available: \(error.localizedDescription)") | ||
| } | ||
| } | ||
| } |
There was a problem hiding this comment.
🔴 Nested if statements in CTC head loading code violate AGENTS.md control flow rule
AGENTS.md states: "Nested if statements should be absolutely avoided. Use guard statements and inverted conditions to exit early." The new CTC head loading block has 3 levels of nesting: if version == .tdtCtc110m → if FileManager.default.fileExists → if ctcHeadModel != nil. This could be restructured by extracting a helper method or using guard-based early exits.
Prompt for agents
In Sources/FluidAudio/ASR/Parakeet/AsrModels.swift lines 219-252, extract the CTC head loading logic into a separate private static method like `loadCtcHead(from directory: URL, parentDirectory: URL, config: MLModelConfiguration, progressHandler: ...)` that uses guard statements and early returns instead of nested ifs. The outer call site would become: `let ctcHeadModel = version == .tdtCtc110m ? try? await loadCtcHead(...) : nil`. Inside the helper, use guard for the file existence check and return early on failure, avoiding the 3-level nesting.
Was this helpful? React with 👍 or 👎 to provide feedback.
Summary
CtcHead.mlmodelcfrom model directory and run it on existing TDT encoder outputspotKeywordsFromLogProbs()andapplyLogSoftmax()APIs for pre-computed CTC log-probabilitiesBenchmark (772 earnings call files)
Test plan
swift build -c releasepassesCtcHead.mlmodelctoparakeet-tdt-ctc-110mrepo