Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions speaktype/Services/ModelDownloadService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class ModelDownloadService: ObservableObject {
@Published var isDownloading: [String: Bool] = [:]

private var activeTasks: [String: Task<Void, Never>] = [:] // Track running download tasks
private let selectedModelKey = "selectedModelVariant"

private init() {
// Force a custom cache directory to avoid "Multiple models found" conflicts
Expand All @@ -18,7 +19,6 @@ class ModelDownloadService: ObservableObject {
// Check for already-downloaded models on launch
Task { @MainActor in
await refreshDownloadedModels()
// Don't auto-select - let user explicitly pick a model which will load it
}
}

Expand Down Expand Up @@ -99,15 +99,19 @@ class ModelDownloadService: ObservableObject {
}
}

let orderedDownloadedVariants = orderedModelVariants(from: foundModels)

await MainActor.run {
// Clear all previous progress
self.downloadProgress.removeAll()

// Only mark models that actually exist
for variant in foundModels {
for variant in orderedDownloadedVariants {
self.downloadProgress[variant] = 1.0
print("✅ Marked as downloaded: \(variant)")
}

self.synchronizeSelectedModelSelection(with: orderedDownloadedVariants)

if foundModels.isEmpty {
print("❌ No models found - all will show as 'Download' buttons")
Expand Down Expand Up @@ -158,6 +162,7 @@ class ModelDownloadService: ObservableObject {
DispatchQueue.main.async {
self.isDownloading[variant] = false
self.downloadProgress[variant] = 1.0
self.synchronizeSelectedModelSelection(preferredVariant: variant)
self.activeTasks[variant] = nil // Cleanup task
}
} catch {
Expand Down Expand Up @@ -203,6 +208,7 @@ class ModelDownloadService: ObservableObject {
self.isDownloading[variant] = false
self.downloadProgress[variant] = 1.0
self.downloadError[variant] = nil
self.synchronizeSelectedModelSelection(preferredVariant: variant)
self.activeTasks[variant] = nil
}
} catch {
Expand Down Expand Up @@ -297,6 +303,7 @@ class ModelDownloadService: ObservableObject {
await MainActor.run {
self.downloadProgress[variant] = 0.0
self.isDownloading[variant] = false
self.synchronizeSelectedModelSelection()
}
return "Deleted \(deletedCount) items"
} else {
Expand Down Expand Up @@ -382,4 +389,42 @@ class ModelDownloadService: ObservableObject {
formatter.countStyle = .file
return formatter.string(fromByteCount: bytes)
}

private func orderedModelVariants(from variants: Set<String>) -> [String] {
AIModel.availableModels.map(\.variant).filter { variants.contains($0) }
}

private func synchronizeSelectedModelSelection(with downloadedVariants: [String]? = nil, preferredVariant: String? = nil) {
let resolvedDownloadedVariants: [String]
if let downloadedVariants {
resolvedDownloadedVariants = downloadedVariants
} else {
let downloadedSet = Set(
downloadProgress.compactMap { variant, progress in
progress >= 1.0 ? variant : nil
}
)
resolvedDownloadedVariants = orderedModelVariants(from: downloadedSet)
}

let currentSelection = UserDefaults.standard.string(forKey: selectedModelKey) ?? ""
let preferredSelection = preferredVariant.flatMap { variant in
resolvedDownloadedVariants.contains(variant) ? variant : nil
}

let nextSelection = preferredSelection
?? (resolvedDownloadedVariants.contains(currentSelection) ? currentSelection : resolvedDownloadedVariants.first)

if let nextSelection {
if currentSelection != nextSelection {
print("🔁 Syncing selected model to \(nextSelection)")
}
UserDefaults.standard.set(nextSelection, forKey: selectedModelKey)
} else {
if !currentSelection.isEmpty {
print("⚠️ Clearing selected model because no downloaded models are available")
}
UserDefaults.standard.removeObject(forKey: selectedModelKey)
}
}
}
10 changes: 10 additions & 0 deletions speaktype/Services/WhisperService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class WhisperService {
case fileNotFound
case alreadyLoading
case loadingTimeout
case invalidModelSelection(String)

var errorDescription: String? {
switch self {
Expand All @@ -71,6 +72,11 @@ class WhisperService {
case .alreadyLoading: return "Model loading already in progress"
case .loadingTimeout:
return "Model loading timed out — your Mac may not have enough RAM for this model"
case .invalidModelSelection(let variant):
if variant.isEmpty {
return "No AI model selected. Download or choose a model in AI Models first."
}
return "Selected AI model '\(variant)' is no longer available. Choose a downloaded model in AI Models."
}
}
}
Expand All @@ -85,6 +91,10 @@ class WhisperService {

// Dynamic model loading with optimized WhisperKitConfig
func loadModel(variant: String) async throws {
guard AIModel.availableModels.contains(where: { $0.variant == variant }) else {
throw TranscriptionError.invalidModelSelection(variant)
}

// Already loaded this exact model
if isInitialized && variant == currentModelVariant && pipe != nil {
print("✅ Model \(variant) already loaded, skipping")
Expand Down
2 changes: 1 addition & 1 deletion speaktype/Views/Screens/Dashboard/DashboardView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ struct DashboardView: View {
@EnvironmentObject var trialManager: TrialManager
@EnvironmentObject var licenseManager: LicenseManager

@AppStorage("selectedModelVariant") private var selectedModel: String = "openai_whisper-base"
@AppStorage("selectedModelVariant") private var selectedModel: String = ""
@AppStorage("transcriptionLanguage") private var transcriptionLanguage: String = "auto"
@State private var showFileImporter = false
@State private var isTranscribing = false
Expand Down
21 changes: 0 additions & 21 deletions speaktype/Views/Screens/Settings/AIModelsView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,27 +42,6 @@ struct AIModelsView: View {
// Refresh model download status when view appears
Task {
await downloadService.refreshDownloadedModels()

// Auto-fallback: If selected model isn't downloaded, switch to first available
if !selectedModel.isEmpty {
let isSelectedModelDownloaded =
downloadService.downloadProgress[selectedModel] ?? 0.0 >= 1.0

if !isSelectedModelDownloaded {
// Find first downloaded model
if let firstDownloaded = downloadService.downloadProgress.first(where: {
$0.value >= 1.0
})?.key {
print(
"⚠️ Selected model '\(selectedModel)' not found. Auto-switching to '\(firstDownloaded)'"
)
selectedModel = firstDownloaded
} else {
print("⚠️ No models downloaded. Please download a model to use the app.")
selectedModel = "" // Clear invalid selection
}
}
}
}
}
}
Expand Down