diff --git a/speaktype/Services/ModelDownloadService.swift b/speaktype/Services/ModelDownloadService.swift index 2ae73ea..c27e21f 100644 --- a/speaktype/Services/ModelDownloadService.swift +++ b/speaktype/Services/ModelDownloadService.swift @@ -10,6 +10,7 @@ class ModelDownloadService: ObservableObject { @Published var isDownloading: [String: Bool] = [:] private var activeTasks: [String: Task] = [:] // Track running download tasks + private let selectedModelKey = "selectedModelVariant" private init() { // Force a custom cache directory to avoid "Multiple models found" conflicts @@ -18,7 +19,6 @@ class ModelDownloadService: ObservableObject { // Check for already-downloaded models on launch Task { @MainActor in await refreshDownloadedModels() - // Don't auto-select - let user explicitly pick a model which will load it } } @@ -99,15 +99,19 @@ class ModelDownloadService: ObservableObject { } } + let orderedDownloadedVariants = orderedModelVariants(from: foundModels) + await MainActor.run { // Clear all previous progress self.downloadProgress.removeAll() // Only mark models that actually exist - for variant in foundModels { + for variant in orderedDownloadedVariants { self.downloadProgress[variant] = 1.0 print("✅ Marked as downloaded: \(variant)") } + + self.synchronizeSelectedModelSelection(with: orderedDownloadedVariants) if foundModels.isEmpty { print("❌ No models found - all will show as 'Download' buttons") @@ -158,6 +162,7 @@ class ModelDownloadService: ObservableObject { DispatchQueue.main.async { self.isDownloading[variant] = false self.downloadProgress[variant] = 1.0 + self.synchronizeSelectedModelSelection(preferredVariant: variant) self.activeTasks[variant] = nil // Cleanup task } } catch { @@ -203,6 +208,7 @@ class ModelDownloadService: ObservableObject { self.isDownloading[variant] = false self.downloadProgress[variant] = 1.0 self.downloadError[variant] = nil + self.synchronizeSelectedModelSelection(preferredVariant: variant) self.activeTasks[variant] = nil } } catch { @@ -297,6 +303,7 @@ class ModelDownloadService: ObservableObject { await MainActor.run { self.downloadProgress[variant] = 0.0 self.isDownloading[variant] = false + self.synchronizeSelectedModelSelection() } return "Deleted \(deletedCount) items" } else { @@ -382,4 +389,42 @@ class ModelDownloadService: ObservableObject { formatter.countStyle = .file return formatter.string(fromByteCount: bytes) } + + private func orderedModelVariants(from variants: Set) -> [String] { + AIModel.availableModels.map(\.variant).filter { variants.contains($0) } + } + + private func synchronizeSelectedModelSelection(with downloadedVariants: [String]? = nil, preferredVariant: String? = nil) { + let resolvedDownloadedVariants: [String] + if let downloadedVariants { + resolvedDownloadedVariants = downloadedVariants + } else { + let downloadedSet = Set( + downloadProgress.compactMap { variant, progress in + progress >= 1.0 ? variant : nil + } + ) + resolvedDownloadedVariants = orderedModelVariants(from: downloadedSet) + } + + let currentSelection = UserDefaults.standard.string(forKey: selectedModelKey) ?? "" + let preferredSelection = preferredVariant.flatMap { variant in + resolvedDownloadedVariants.contains(variant) ? variant : nil + } + + let nextSelection = preferredSelection + ?? (resolvedDownloadedVariants.contains(currentSelection) ? currentSelection : resolvedDownloadedVariants.first) + + if let nextSelection { + if currentSelection != nextSelection { + print("🔁 Syncing selected model to \(nextSelection)") + } + UserDefaults.standard.set(nextSelection, forKey: selectedModelKey) + } else { + if !currentSelection.isEmpty { + print("⚠️ Clearing selected model because no downloaded models are available") + } + UserDefaults.standard.removeObject(forKey: selectedModelKey) + } + } } diff --git a/speaktype/Services/WhisperService.swift b/speaktype/Services/WhisperService.swift index f19659f..1f36c8a 100644 --- a/speaktype/Services/WhisperService.swift +++ b/speaktype/Services/WhisperService.swift @@ -63,6 +63,7 @@ class WhisperService { case fileNotFound case alreadyLoading case loadingTimeout + case invalidModelSelection(String) var errorDescription: String? { switch self { @@ -71,6 +72,11 @@ class WhisperService { case .alreadyLoading: return "Model loading already in progress" case .loadingTimeout: return "Model loading timed out — your Mac may not have enough RAM for this model" + case .invalidModelSelection(let variant): + if variant.isEmpty { + return "No AI model selected. Download or choose a model in AI Models first." + } + return "Selected AI model '\(variant)' is no longer available. Choose a downloaded model in AI Models." } } } @@ -85,6 +91,10 @@ class WhisperService { // Dynamic model loading with optimized WhisperKitConfig func loadModel(variant: String) async throws { + guard AIModel.availableModels.contains(where: { $0.variant == variant }) else { + throw TranscriptionError.invalidModelSelection(variant) + } + // Already loaded this exact model if isInitialized && variant == currentModelVariant && pipe != nil { print("✅ Model \(variant) already loaded, skipping") diff --git a/speaktype/Views/Screens/Dashboard/DashboardView.swift b/speaktype/Views/Screens/Dashboard/DashboardView.swift index 6255905..08cc8b4 100644 --- a/speaktype/Views/Screens/Dashboard/DashboardView.swift +++ b/speaktype/Views/Screens/Dashboard/DashboardView.swift @@ -15,7 +15,7 @@ struct DashboardView: View { @EnvironmentObject var trialManager: TrialManager @EnvironmentObject var licenseManager: LicenseManager - @AppStorage("selectedModelVariant") private var selectedModel: String = "openai_whisper-base" + @AppStorage("selectedModelVariant") private var selectedModel: String = "" @AppStorage("transcriptionLanguage") private var transcriptionLanguage: String = "auto" @State private var showFileImporter = false @State private var isTranscribing = false diff --git a/speaktype/Views/Screens/Settings/AIModelsView.swift b/speaktype/Views/Screens/Settings/AIModelsView.swift index 05e00b6..a63d2e0 100644 --- a/speaktype/Views/Screens/Settings/AIModelsView.swift +++ b/speaktype/Views/Screens/Settings/AIModelsView.swift @@ -42,27 +42,6 @@ struct AIModelsView: View { // Refresh model download status when view appears Task { await downloadService.refreshDownloadedModels() - - // Auto-fallback: If selected model isn't downloaded, switch to first available - if !selectedModel.isEmpty { - let isSelectedModelDownloaded = - downloadService.downloadProgress[selectedModel] ?? 0.0 >= 1.0 - - if !isSelectedModelDownloaded { - // Find first downloaded model - if let firstDownloaded = downloadService.downloadProgress.first(where: { - $0.value >= 1.0 - })?.key { - print( - "⚠️ Selected model '\(selectedModel)' not found. Auto-switching to '\(firstDownloaded)'" - ) - selectedModel = firstDownloaded - } else { - print("⚠️ No models downloaded. Please download a model to use the app.") - selectedModel = "" // Clear invalid selection - } - } - } } } }