diff --git a/Documentation/ASR/CustomVocabulary.md b/Documentation/ASR/CustomVocabulary.md index b68b55ef2..715eb9cfc 100644 --- a/Documentation/ASR/CustomVocabulary.md +++ b/Documentation/ASR/CustomVocabulary.md @@ -17,6 +17,58 @@ The paper introduces a dynamic programming algorithm for CTC-based keyword spott ## Architecture Overview +FluidAudio supports two approaches for CTC-based custom vocabulary boosting: + +### Approach 1: Standalone CTC Head (Beta, Recommended for TDT-CTC-110M) + +``` + ┌─────────────────────────────────────────┐ + │ Audio Input │ + │ (16kHz, mono) │ + └─────────────────┬───────────────────────┘ + │ + ▼ + ┌─────────────────┐ + │ TDT-CTC-110M │ + │ Preprocessor │ + │ (fused encoder) │ + └────────┬────────┘ + │ + encoder output [1, 512, T] + │ + ┌──────────────┴──────────────┐ + │ │ + ▼ ▼ + ┌─────────────────┐ ┌─────────────────┐ + │ TDT Decoder │ │ CTC Head │ + │ + Joint Network│ │ (1MB, beta) │ + └────────┬────────┘ └────────┬────────┘ + │ │ + ▼ ctc_logits [1, T, 1025] + ┌─────────────────┐ │ + │ Raw Transcript│ ▼ + │ "in video corp"│ ┌─────────────────┐ + └────────┬────────┘ Custom │ Keyword Spotter │ + │ Vocabulary►│ (DP Algorithm) │ + │ └────────┬────────┘ + └──────────────┬──────────────┘ + ▼ + ┌─────────────────┐ + │ Vocabulary │ + │ Rescorer │ + └────────┬────────┘ + │ + ▼ + ┌─────────────────┐ + │ Final Transcript│ + │ "NVIDIA Corp" │ + └─────────────────┘ +``` + +The standalone CTC head is a single linear projection (512 → 1025) extracted from the hybrid TDT-CTC-110M model. It reuses the TDT encoder output, requiring only ~1MB of additional model weight and no second encoder pass. 
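The shape bookkeeping for the standalone head can be sketched as follows. This is an illustrative NumPy stand-in, not the CoreML model: `W`, `b`, and the random encoder output are hypothetical placeholders, and only the dimensions (512-d features, 1024 BPE tokens + 1 blank, ~40 ms frames) come from the text above.

```python
import numpy as np

# Hypothetical stand-in for the fused encoder output: [1, 512, T]
T = 375  # ~15 s of audio at ~40 ms per frame
encoder_out = np.random.randn(1, 512, T).astype(np.float32)

# The standalone CTC head is a single linear projection 512 -> 1025
# (1024 BPE tokens + 1 blank). W and b here are illustrative placeholders.
W = np.random.randn(512, 1025).astype(np.float32)
b = np.zeros(1025, dtype=np.float32)

# Reorder to [1, T, 512], then project to ctc_logits [1, T, 1025]
features = encoder_out.transpose(0, 2, 1)
ctc_logits = features @ W + b

assert ctc_logits.shape == (1, T, 1025)
```

The size claim checks out: 512 × 1025 weights at fp16 come to roughly 1.05 MB, consistent with the ~1 MB model weight quoted above.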
+ +### Approach 2: Separate CTC Encoder (Original) + ``` ┌─────────────────────────────────────────┐ │ Audio Input │ @@ -58,24 +110,37 @@ The paper introduces a dynamic programming algorithm for CTC-based keyword spott └─────────────────┘ ``` -## Dual Encoder Alignment +### Approach Comparison + +| | Standalone CTC Head (beta) | Separate CTC Encoder | +|---|---|---| +| **Additional model size** | 1 MB | 97.5 MB | +| **Second encoder pass** | No | Yes | +| **RTFx (earnings benchmark)** | 70.29x | 25.98x | +| **Dict Recall** | 99.4% | 99.4% | +| **TDT model requirement** | TDT-CTC-110M only | Any TDT model | +| **Status** | Beta | Stable | + +The standalone CTC head is available only with the TDT-CTC-110M model because both the TDT and CTC heads share the same encoder in the hybrid architecture. For Parakeet TDT v2/v3 (0.6B), the separate CTC encoder approach is required. + +## Encoder Alignment + +### Separate CTC Encoder (Approach 2) The system uses two separate neural network encoders that process the same audio: -### 1. TDT Encoder (Primary Transcription) +#### TDT Encoder (Primary Transcription) - **Model**: Parakeet TDT 0.6B (600M parameters) - **Architecture**: Token Duration Transducer with FastConformer - **Output**: High-quality transcription with word timestamps - **Frame Rate**: ~40ms per frame -### 2. CTC Encoder (Keyword Spotting) +#### CTC Encoder (Keyword Spotting) - **Model**: Parakeet CTC 110M (110M parameters) - **Architecture**: FastConformer with CTC head - **Output**: Per-frame log-probabilities over 1024 tokens - **Frame Rate**: ~40ms per frame (aligned with TDT) -### Frame Alignment - Both encoders use the same audio preprocessing (mel spectrogram with identical parameters), producing frames at the same rate. This enables direct timestamp comparison between: - TDT decoder word timestamps - CTC keyword detection timestamps @@ -88,18 +153,20 @@ CTC Frames: [0] [1] [2] ... 
[374] (375 frames @ 40ms) Aligned timestamps ``` -### Memory Usage +#### Memory Usage Running two encoders in parallel increases peak memory consumption: | Configuration | Peak RAM | Notes | |---------------|----------|-------| | TDT encoder only | ~66 MB | Standard transcription | -| TDT + CTC encoders | ~130 MB | With vocabulary boosting | +| TDT + CTC encoders | ~130 MB | With vocabulary boosting (separate encoder) | +| TDT + CTC head | ~67 MB | With vocabulary boosting (standalone head, beta) | *Measured on iPhone 17 Pro. Memory settles after initial model loading.* -The additional ~64 MB overhead comes from the CTC encoder (Parakeet 110M) being loaded alongside the primary TDT encoder. For memory-constrained scenarios, consider: +The standalone CTC head adds negligible memory (~1MB) since it reuses the existing encoder output. The separate CTC encoder adds ~64MB overhead. For memory-constrained scenarios, consider: +- Using the standalone CTC head with TDT-CTC-110M (beta) - Loading the CTC encoder on-demand rather than at startup - Unloading the CTC encoder after transcription completes - Using vocabulary boosting only for files where domain terms are expected diff --git a/Documentation/ASR/TDT-CTC-110M.md b/Documentation/ASR/TDT-CTC-110M.md index 894efebae..c628a06f4 100644 --- a/Documentation/ASR/TDT-CTC-110M.md +++ b/Documentation/ASR/TDT-CTC-110M.md @@ -465,9 +465,78 @@ Tested on iPhone (iOS 17+): - Highest accuracy required - Extra model size acceptable +## Standalone CTC Head for Custom Vocabulary (Beta) + +The TDT-CTC-110M hybrid model shares one FastConformer encoder between its TDT and CTC decoder heads. FluidAudio exploits this by exporting the CTC decoder head as a standalone 1MB CoreML model (`CtcHead.mlmodelc`) that runs on the existing TDT encoder output, enabling custom vocabulary keyword spotting without a second encoder pass. 
+ +### How It Works + +``` +TDT Preprocessor (fused encoder) + │ + ▼ +encoder output [1, 512, T] + │ + ┌────┴────┐ + │ │ + ▼ ▼ +TDT Decoder CtcHead (1MB, beta) + │ │ + ▼ ▼ +transcript ctc_logits [1, T, 1025] + │ + ▼ + Keyword Spotter / VocabularyRescorer +``` + +The CTC head is a single linear projection (512 → 1025) that maps the 512-dimensional encoder features to log-probabilities over 1024 BPE tokens + 1 blank token. + +### Performance + +Benchmarked on 772 earnings call files (Earnings22-KWS): + +| Approach | Model Size | Dict Recall | RTFx | +|----------|-----------|-------------|------| +| Separate CTC encoder | 97.5 MB | 99.4% | 25.98x | +| **Standalone CTC head** | **1 MB** | **99.4%** | **70.29x** | + +The standalone CTC head achieves identical keyword detection quality at 2.7x the speed, using 97x less model weight. + +### Loading + +The CTC head model auto-downloads from [FluidInference/parakeet-ctc-110m-coreml](https://huggingface.co/FluidInference/parakeet-ctc-110m-coreml) when loading the TDT-CTC-110M model. It also supports manual placement in the TDT model directory. + +Two loading paths are supported: +1. **Local (v1):** Place `CtcHead.mlmodelc` in the TDT model directory (`parakeet-tdt-ctc-110m/`) +2. **Auto-download (v2):** Automatically downloaded from the `parakeet-ctc-110m-coreml` HuggingFace repo + +```swift +// CTC head loads automatically with TDT-CTC-110M models +let models = try await AsrModels.downloadAndLoad(version: .tdtCtc110m) +// models.ctcHead is non-nil when CtcHead.mlmodelc is available +``` + +### Conversion + +The CTC head is exported using the conversion script in the mobius repo: + +```bash +cd mobius/models/stt/parakeet-tdt-ctc-110m/coreml/ +uv run python export-ctc-head.py --output-dir ./ctc-head-build +xcrun coremlcompiler compile ctc-head-build/CtcHead.mlpackage ctc-head-build/ +``` + +See [mobius PR #36](https://github.com/FluidInference/mobius/pull/36) for the conversion script. 
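The raw `ctc_logits` are not yet log-probabilities: per the source comments, callers apply `CtcKeywordSpotter.applyLogSoftmax()`, which does temperature scaling, a numerically stable log-softmax, and a blank-token bias before keyword spotting. A minimal Python sketch of that conversion, using neutral placeholder values for temperature and bias (the library's actual defaults live in `ContextBiasingConstants`):

```python
import numpy as np

def ctc_log_probs(raw_logits, blank_id, temperature=1.0, blank_bias=0.0):
    """Temperature-scaled log-softmax with a bias on the blank token.

    raw_logits: [T, V] raw CTC head outputs. temperature=1.0 and
    blank_bias=0.0 are neutral placeholders, not the library defaults.
    """
    x = np.asarray(raw_logits, dtype=np.float32) / temperature
    # Numerically stable log-softmax per frame
    x -= x.max(axis=-1, keepdims=True)
    log_probs = x - np.log(np.exp(x).sum(axis=-1, keepdims=True))
    # Penalize (or boost) the blank token uniformly across all frames
    log_probs[:, blank_id] += blank_bias
    return log_probs

# With neutral temperature/bias, each frame's probabilities sum to 1
probs_sum = np.exp(ctc_log_probs(np.random.randn(4, 1025), blank_id=1024)).sum(axis=-1)
```

A negative `blank_bias` lowers the blank token's log-probability, which makes the spotter more willing to credit keyword tokens in frames the model would otherwise label blank.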
+ +### Status + +This feature is **beta**. The CTC head produces identical keyword detection results to the separate CTC encoder, but the auto-download pathway and integration are new. See [#435](https://github.com/FluidInference/FluidAudio/issues/435) and [PR #450](https://github.com/FluidInference/FluidAudio/pull/450) for details. + ## Resources - **Model:** [FluidInference/parakeet-tdt-ctc-110m-coreml](https://huggingface.co/FluidInference/parakeet-tdt-ctc-110m-coreml) +- **CTC Head model:** [FluidInference/parakeet-ctc-110m-coreml](https://huggingface.co/FluidInference/parakeet-ctc-110m-coreml) (includes CtcHead.mlmodelc) - **Benchmark results:** See `benchmarks.md` - **PR:** [#433 - Add TDT-CTC-110M support](https://github.com/FluidInference/FluidAudio/pull/433) +- **CTC Head PR:** [#450 - Add standalone CTC head for custom vocabulary](https://github.com/FluidInference/FluidAudio/pull/450) - **Original NVIDIA model:** [nvidia/parakeet-tdt-1.1b](https://huggingface.co/nvidia/parakeet-tdt-1.1b) diff --git a/Documentation/ASR/benchmarks100.md b/Documentation/ASR/benchmarks100.md index be3436ac2..6220c153e 100644 --- a/Documentation/ASR/benchmarks100.md +++ b/Documentation/ASR/benchmarks100.md @@ -41,3 +41,15 @@ Benchmark comparison between `main` and PR #440 (`standardize-asr-directory-stru ## Verdict **No regressions.** WER is identical across all 6 benchmarks. RTFx differences are within normal system noise (M2 thermals, background processes). The directory restructuring is a pure file move with no behavioral changes. + +## Issue #435: Standalone CTC Head for Custom Vocabulary (Beta) + +Benchmark comparing separate CTC encoder vs standalone CTC head extracted from the TDT-CTC-110M hybrid model. +See [#435](https://github.com/FluidInference/FluidAudio/issues/435) and [PR #450](https://github.com/FluidInference/FluidAudio/pull/450). 
+ +| Metric | Separate CTC (v2 TDT) | Separate CTC (110m TDT) | Standalone CTC Head (110m TDT) | +|---|---|---|---| +| Dict Recall | 99.3% | 99.4% | 99.4% | +| RTFx | 43.94x | 25.98x | 70.29x | +| Additional model size | 97.5 MB | 97.5 MB | 1 MB | + diff --git a/Sources/FluidAudio/ASR/Parakeet/AsrManager.swift b/Sources/FluidAudio/ASR/Parakeet/AsrManager.swift index 503a494b3..5d82a8678 100644 --- a/Sources/FluidAudio/ASR/Parakeet/AsrManager.swift +++ b/Sources/FluidAudio/ASR/Parakeet/AsrManager.swift @@ -53,6 +53,37 @@ public actor AsrManager { internal var vocabSizeConfig: ContextBiasingConstants.VocabSizeConfig? internal var vocabBoostingEnabled: Bool { customVocabulary != nil && vocabularyRescorer != nil } + // Cached CTC logits from fused Preprocessor (unified custom vocabulary) + internal var cachedCtcLogits: MLMultiArray? + internal var cachedCtcFrameDuration: Double? + internal var cachedCtcValidFrames: Int? + + /// Whether the Preprocessor outputs CTC logits (unified custom vocabulary model). + public var hasCachedCtcLogits: Bool { cachedCtcLogits != nil } + + /// Get cached CTC raw logits as [[Float]] for external use (e.g. benchmarks). + /// These are raw logits — callers must apply `CtcKeywordSpotter.applyLogSoftmax()` + /// to convert to log-probabilities before use in keyword detection. + /// Returns nil if the CTC head model is not available or audio was multi-chunk. + public func getCachedCtcRawLogits() -> (rawLogits: [[Float]], frameDuration: Double)? { + guard let logits = cachedCtcLogits, let duration = cachedCtcFrameDuration else { return nil } + let shape = logits.shape + guard shape.count == 3 else { return nil } + let numFrames = min(shape[1].intValue, cachedCtcValidFrames ?? shape[1].intValue) + let vocabSize = shape[2].intValue + var result: [[Float]] = [] + result.reserveCapacity(numFrames) + for t in 0.. 
ASRResult { - guard let spotter = ctcSpotter, - let rescorer = vocabularyRescorer, + guard let rescorer = vocabularyRescorer, let vocab = customVocabulary, let tokenTimings = result.tokenTimings, !tokenTimings.isEmpty else { @@ -549,13 +581,30 @@ extension AsrManager { } do { - let spotResult = try await spotter.spotKeywordsWithLogProbs( - audioSamples: audioSamples, - customVocabulary: vocab, - minScore: nil - ) + // Try to use cached CTC logits from unified Preprocessor first + let logProbs: [[Float]] + let frameDuration: Double + + if let cached = cachedCtcLogits, let duration = cachedCtcFrameDuration { + // Convert MLMultiArray to [[Float]] + logProbs = convertCtcLogitsToArray(cached) + frameDuration = duration + logger.debug("Using cached CTC logits from Preprocessor (unified model)") + } else if let spotter = ctcSpotter { + // Fallback: run separate CTC encoder + let spotResult = try await spotter.spotKeywordsWithLogProbs( + audioSamples: audioSamples, + customVocabulary: vocab, + minScore: nil + ) + logProbs = spotResult.logProbs + frameDuration = spotResult.frameDuration + logger.debug("Using separate CTC encoder (legacy dual-model approach)") + } else { + logger.warning("Vocabulary rescoring skipped: no CTC logits available") + return result + } - let logProbs = spotResult.logProbs guard !logProbs.isEmpty else { logger.debug("Vocabulary rescoring skipped: no log probs from CTC") return result @@ -570,7 +619,7 @@ extension AsrManager { transcript: result.text, tokenTimings: tokenTimings, logProbs: logProbs, - frameDuration: spotResult.frameDuration, + frameDuration: frameDuration, cbw: vocabConfig.cbw, marginSeconds: 0.5, minSimilarity: effectiveMinSimilarity @@ -600,4 +649,41 @@ extension AsrManager { } } + /// Convert CTC logits MLMultiArray to log-probabilities [[Float]] for rescoring. + /// Applies log-softmax with temperature scaling and blank bias to match + /// the processing done in `CtcKeywordSpotter.computeLogProbs`. 
+ private func convertCtcLogitsToArray(_ ctcLogits: MLMultiArray) -> [[Float]] { + // Expected shape: [1, T, V] where T = frames, V = vocab size + let shape = ctcLogits.shape + guard shape.count == 3 else { + logger.warning("Unexpected CTC logits shape: \(shape)") + return [] + } + + let numFrames = min(shape[1].intValue, cachedCtcValidFrames ?? shape[1].intValue) + let vocabSize = shape[2].intValue + + // Extract raw logits + var rawLogits: [[Float]] = [] + rawLogits.reserveCapacity(numFrames) + + for t in 0.. SpotKeywordsResult { + let totalFrames = logProbs.count + guard totalFrames > 0 else { + return SpotKeywordsResult(detections: [], logProbs: [], frameDuration: 0, totalFrames: 0) + } + + var results: [KeywordDetection] = [] + + for term in customVocabulary.terms { + guard term.text.count >= customVocabulary.minTermLength else { + if debugMode { + logger.debug( + " Skipping '\(term.text)': too short (\(term.text.count) < \(customVocabulary.minTermLength) chars)" + ) + } + continue + } + + let ids = term.ctcTokenIds ?? term.tokenIds + guard let ids, !ids.isEmpty else { continue } + + let tokenCount = ids.count + let adjustedThreshold: Float = + minScore.map { base in + let extraTokens = max(0, tokenCount - ContextBiasingConstants.baselineTokenCountForThreshold) + return base - Float(extraTokens) * ContextBiasingConstants.thresholdRelaxationPerToken + } ?? 
ContextBiasingConstants.defaultMinSpotterScore + + let multipleDetections = ctcWordSpotMultiple( + logProbs: logProbs, + keywordTokens: ids, + minScore: adjustedThreshold, + mergeOverlap: true + ) + + for (score, start, end) in multipleDetections { + let startTime = TimeInterval(start) * frameDuration + let endTime = TimeInterval(end) * frameDuration + + let detection = KeywordDetection( + term: term, + score: score, + totalFrames: totalFrames, + startFrame: start, + endFrame: end, + startTime: startTime, + endTime: endTime + ) + results.append(detection) + } + } + + return SpotKeywordsResult( + detections: results, + logProbs: logProbs, + frameDuration: frameDuration, + totalFrames: totalFrames + ) + } + + // MARK: - Log-Probability Conversion + + /// Convert raw CTC logits to log-probabilities with temperature scaling and blank bias. + /// Use this to post-process raw logits from a unified Preprocessor before passing to + /// `spotKeywordsFromLogProbs` or `VocabularyRescorer.ctcTokenRescore`. + /// + /// - Parameters: + /// - rawLogits: Raw CTC logits [T, V] (before softmax). + /// - blankId: Index of the blank token in the vocabulary. + /// - temperature: Temperature for softmax scaling (default from ContextBiasingConstants). + /// - blankBias: Penalty applied to blank token log-probability (default from ContextBiasingConstants). + /// - Returns: Log-probabilities [T, V] after log-softmax, temperature, and blank bias. + public static func applyLogSoftmax( + rawLogits: [[Float]], + blankId: Int, + temperature: Float = ContextBiasingConstants.ctcTemperature, + blankBias: Float = ContextBiasingConstants.blankBias + ) -> [[Float]] { + var logProbs = [[Float]]() + logProbs.reserveCapacity(rawLogits.count) + + for logits in rawLogits { + guard !logits.isEmpty else { + logProbs.append([]) + continue + } + + // Temperature scaling + let scaled = temperature != 1.0 ? logits.map { $0 / temperature } : logits + + // Log-softmax + let maxVal = scaled.max() ?? 
0 + var sumExp: Float = 0 + for v in scaled { sumExp += expf(v - maxVal) } + let logSumExp = logf(sumExp) + + var row = [Float](repeating: 0, count: scaled.count) + for i in 0.. = [ preprocessorFile, diff --git a/Sources/FluidAudioCLI/Commands/ASR/Parakeet/SlidingWindow/CtcEarningsBenchmark.swift b/Sources/FluidAudioCLI/Commands/ASR/Parakeet/SlidingWindow/CtcEarningsBenchmark.swift index ec2d3ec35..05d774b0e 100644 --- a/Sources/FluidAudioCLI/Commands/ASR/Parakeet/SlidingWindow/CtcEarningsBenchmark.swift +++ b/Sources/FluidAudioCLI/Commands/ASR/Parakeet/SlidingWindow/CtcEarningsBenchmark.swift @@ -90,8 +90,15 @@ public enum CtcEarningsBenchmark { } case "--tdt-version": if i + 1 < arguments.count { - if arguments[i + 1] == "v2" || arguments[i + 1] == "2" { + switch arguments[i + 1].lowercased() { + case "v2", "2": tdtVersion = .v2 + case "v3", "3": + tdtVersion = .v3 + case "110m", "ctc-110m", "tdt-ctc-110m": + tdtVersion = .tdtCtc110m + default: + break } i += 1 } @@ -144,7 +151,7 @@ public enum CtcEarningsBenchmark { print("Earnings Benchmark (TDT transcription + CTC keyword spotting)") print(" Data directory: \(dataDir ?? "not found")") print(" Output file: \(outputFile)") - print(" TDT version: \(tdtVersion == .v2 ? "v2" : "v3")") + print(" TDT version: \(tdtVersion == .v2 ? "v2" : tdtVersion == .tdtCtc110m ? "110m" : "v3")") print(" CTC variant: \(ctcVariant.displayName)") print(" CTC model: \(ctcModelPath ?? "not found")") print(" Keywords mode: \(keywordsMode.rawValue)") @@ -171,7 +178,9 @@ public enum CtcEarningsBenchmark { do { // Load TDT models for transcription - print("Loading TDT models (\(tdtVersion == .v2 ? "v2" : "v3")) for transcription...") + print( + "Loading TDT models (\(tdtVersion == .v2 ? "v2" : tdtVersion == .tdtCtc110m ? "110m" : "v3")) for transcription..." 
+ ) let tdtModels = try await AsrModels.downloadAndLoad(version: tdtVersion) let asrManager = AsrManager(config: .default) try await asrManager.initialize(models: tdtModels) @@ -499,22 +508,29 @@ public enum CtcEarningsBenchmark { let customVocab = CustomVocabularyContext(terms: vocabTerms) // 3. CTC keyword spotting for high recall dictionary detection - let spotResult = try await spotter.spotKeywordsWithLogProbs( - audioSamples: samples, - customVocabulary: customVocab, - minScore: nil - ) - - // Debug: Show CTC detections with timestamps - if debugTimings && !spotResult.detections.isEmpty { - print(" CTC Detections:") - for detection in spotResult.detections { - print( - " [\(String(format: "%.2f", detection.startTime))-\(String(format: "%.2f", detection.endTime))s] \"\(detection.term.text)\" (score: \(String(format: "%.2f", detection.score)))" - ) - } + // Use cached CTC logits from unified Preprocessor if available (no separate encoder run needed) + let logProbs: [[Float]] + let frameDuration: Double + if let cached = await asrManager.getCachedCtcRawLogits() { + // Cached values are raw logits - apply log-softmax + temperature + blank bias + logProbs = CtcKeywordSpotter.applyLogSoftmax( + rawLogits: cached.rawLogits, + blankId: spotter.blankId + ) + frameDuration = cached.frameDuration + } else { + let spotResult = try await spotter.spotKeywordsWithLogProbs( + audioSamples: samples, + customVocabulary: customVocab, + minScore: nil + ) + logProbs = spotResult.logProbs + frameDuration = spotResult.frameDuration } + // Debug: Show CTC detections with timestamps (only available with separate spotter path) + // When using cached CTC logits, detections are not available + // 4. 
Post-process: Use VocabularyRescorer with timestamp-based matching (NeMo CTC-WS) // Set USE_TIMESTAMP_RESCORING=1 to use timestamp-based matching (default) // Set USE_TIMESTAMP_RESCORING=0 to use legacy string-similarity based matching @@ -558,8 +574,8 @@ public enum CtcEarningsBenchmark { let rescoreResult = rescorer.ctcTokenRescore( transcript: tdtResult.text, tokenTimings: tokenTimings, - logProbs: spotResult.logProbs, - frameDuration: spotResult.frameDuration, + logProbs: logProbs, + frameDuration: frameDuration, cbw: cbw, marginSeconds: 0.5, minSimilarity: minSimilarity @@ -602,19 +618,26 @@ public enum CtcEarningsBenchmark { let checkWordsLowerSet = Set(checkWords.map { $0.lowercased() }) // 1. CTC detections (deduplicate - only count each word once, only if in checkWords) + // Reuse pre-computed logProbs for keyword detection (avoids duplicate CTC inference) + let spotResult = spotter.spotKeywordsFromLogProbs( + logProbs: logProbs, + frameDuration: frameDuration, + customVocabulary: customVocab, + minScore: nil + ) + for detection in spotResult.detections { let detail: [String: Any] = [ "word": detection.term.text, "score": round(Double(detection.score) * 100) / 100, "startTime": round(detection.startTime * 100) / 100, "endTime": round(detection.endTime * 100) / 100, - "source": "ctc", + "source": await asrManager.hasCachedCtcLogits ? "ctc-head" : "ctc", ] detectionDetails.append(detail) if detection.score >= minCtcScore { let wordLower = detection.term.text.lowercased() - // Only count if word is in checkWords and not already counted if checkWordsLowerSet.contains(wordLower) && !ctcFoundWords.contains(wordLower) { dictFound += 1 ctcFoundWords.insert(wordLower)
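For orientation, the spotting step that `spotKeywordsFromLogProbs` and its `ctcWordSpotMultiple` helper perform is a dynamic-programming alignment of each term's token sequence against the `[T, V]` log-probability lattice, in the spirit of the paper cited at the top of CustomVocabulary.md. The sketch below is a deliberately simplified illustration — no blank transitions, no threshold relaxation, no overlap merging, single best detection — not the library's implementation:

```python
import math

def spot_keyword(log_probs, token_ids):
    """Best-path score and end frame for aligning token_ids anywhere in
    log_probs ([T][V] log-probabilities). Each token consumes one or more
    consecutive frames; a match may start at any frame."""
    T, n = len(log_probs), len(token_ids)
    NEG = -math.inf
    # dp[j] = best score with the first j tokens matched so far
    dp = [0.0] + [NEG] * n
    best = (NEG, -1)
    for t in range(T):
        new = [0.0] + [NEG] * n  # dp[0] = 0 lets a match start at frame t
        for j in range(1, n + 1):
            emit = log_probs[t][token_ids[j - 1]]
            # Either advance from j-1 matched tokens, or let token j
            # continue emitting across another frame.
            new[j] = max(dp[j - 1], dp[j]) + emit
        dp = new
        if dp[n] > best[0]:
            best = (dp[n], t)  # full keyword matched, ending at frame t
    return best

# Toy lattice (V = 4): token 2 dominates frame 0, token 3 dominates frame 1,
# so the keyword [2, 3] is detected ending at frame 1.
score, end_frame = spot_keyword(
    [[-5.0, -5.0, -0.1, -5.0], [-5.0, -5.0, -5.0, -0.1]], [2, 3]
)
```

Multiplying `end_frame` (and a tracked start frame) by the ~40 ms frame duration yields the detection timestamps that the rescorer compares against TDT token timings.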