diff --git a/speaktype/Services/ModelDownloadService.swift b/speaktype/Services/ModelDownloadService.swift index 2ae73ea..76a3852 100644 --- a/speaktype/Services/ModelDownloadService.swift +++ b/speaktype/Services/ModelDownloadService.swift @@ -2,6 +2,153 @@ import Foundation import Combine import WhisperKit +struct ModelCacheLocations { + let documentsDirectory: URL? + let applicationSupportDirectory: URL? + let cachesDirectory: URL? + let homeDirectory: URL + let temporaryDirectory: URL + + static func systemDefault(fileManager: FileManager = .default) -> ModelCacheLocations { + ModelCacheLocations( + documentsDirectory: fileManager.urls(for: .documentDirectory, in: .userDomainMask).first, + applicationSupportDirectory: fileManager.urls( + for: .applicationSupportDirectory, + in: .userDomainMask + ).first, + cachesDirectory: fileManager.urls(for: .cachesDirectory, in: .userDomainMask).first, + homeDirectory: fileManager.homeDirectoryForCurrentUser, + temporaryDirectory: fileManager.temporaryDirectory + ) + } +} + +struct ModelCacheCleanupReport { + let deletedPaths: [URL] + let checkedPaths: [URL] +} + +enum ModelCachePathResolver { + private static let directRepoPath = "huggingface/models/argmaxinc/whisperkit-coreml" + private static let hubRepoDirectoryName = "models--argmaxinc--whisperkit-coreml" + + static func candidatePaths( + for variant: String, + locations: ModelCacheLocations, + fileManager: FileManager = .default + ) -> [URL] { + var candidates = Set<URL>() + + for root in directModelRoots(from: locations) { + candidates.insert(root.appendingPathComponent(variant, isDirectory: true)) + } + + for root in hubRepoRoots(from: locations) { + for match in exactVariantDirectories(named: variant, under: root, fileManager: fileManager) + { + candidates.insert(match) + } + } + + return candidates.sorted { $0.path < $1.path } + } + + static func removeVariantDirectories( + for variant: String, + locations: ModelCacheLocations, + fileManager: FileManager = .default, + 
log: ((String) -> Void)? = nil + ) -> ModelCacheCleanupReport { + let candidates = candidatePaths(for: variant, locations: locations, fileManager: fileManager) + var deletedPaths: [URL] = [] + + for candidate in candidates where fileManager.fileExists(atPath: candidate.path) { + do { + try fileManager.removeItem(at: candidate) + deletedPaths.append(candidate) + log?("✅ Deleted cache: \(candidate.path)") + } catch { + log?("❌ Failed to delete \(candidate.path): \(error)") + } + } + + return ModelCacheCleanupReport(deletedPaths: deletedPaths, checkedPaths: candidates) + } + + private static func directModelRoots(from locations: ModelCacheLocations) -> [URL] { + let baseDirectories = [ + locations.documentsDirectory, + locations.applicationSupportDirectory, + locations.cachesDirectory, + ].compactMap { $0 } + + var roots = baseDirectories.map { + $0.appendingPathComponent(directRepoPath, isDirectory: true) + } + roots.append( + locations.homeDirectory + .appendingPathComponent(".cache", isDirectory: true) + .appendingPathComponent(directRepoPath, isDirectory: true) + ) + roots.append( + locations.temporaryDirectory + .appendingPathComponent(directRepoPath, isDirectory: true) + ) + return roots + } + + private static func hubRepoRoots(from locations: ModelCacheLocations) -> [URL] { + let baseDirectories = [ + locations.documentsDirectory, + locations.applicationSupportDirectory, + locations.cachesDirectory, + ].compactMap { $0 } + + var roots = baseDirectories.map { + $0.appendingPathComponent("huggingface/hub/\(hubRepoDirectoryName)", isDirectory: true) + } + roots.append( + locations.homeDirectory + .appendingPathComponent(".cache/huggingface/hub/\(hubRepoDirectoryName)", isDirectory: true) + ) + roots.append( + locations.temporaryDirectory + .appendingPathComponent("huggingface/hub/\(hubRepoDirectoryName)", isDirectory: true) + ) + return roots + } + + private static func exactVariantDirectories( + named variant: String, + under root: URL, + fileManager: FileManager 
= .default + ) -> [URL] { + guard fileManager.fileExists(atPath: root.path) else { return [] } + + guard + let enumerator = fileManager.enumerator( + at: root, + includingPropertiesForKeys: [.isDirectoryKey], + options: [.skipsHiddenFiles] + ) + else { + return [] + } + + var matches: [URL] = [] + + for case let url as URL in enumerator { + guard url.lastPathComponent == variant else { continue } + let values = try? url.resourceValues(forKeys: [.isDirectoryKey]) + guard values?.isDirectory == true else { continue } + matches.append(url) + enumerator.skipDescendants() + } + + return matches + } +} + class ModelDownloadService: ObservableObject { static let shared = ModelDownloadService() @@ -233,66 +380,19 @@ class ModelDownloadService: ObservableObject { // Aggressively deletes any potential cache for this variant func deleteModel(variant: String) async -> String { let fileManager = FileManager.default - let searchDirs: [FileManager.SearchPathDirectory] = [.documentDirectory, .applicationSupportDirectory, .cachesDirectory] - - // Parse variant: "openai/whisper-medium" or "openai_whisper-medium" - let variantParts = variant.split(separator: "/") - let modelName = variantParts.last ?? Substring(variant) - - // Also search for underscore version: openai_whisper-medium - let underscoreVariant = variant.replacingOccurrences(of: "/", with: "_") - - var deletedCount = 0 - var checkedPaths: [String] = [] - - print("🗑️ Searching for model caches matching: '\(modelName)' or '\(underscoreVariant)'") - - // 1. 
Check Standard macOS Paths - for searchDir in searchDirs { - guard let baseDir = fileManager.urls(for: searchDir, in: .userDomainMask).first else { continue } - - // Check ./huggingface/models (HuggingFace cache) - let hfModelsDir = baseDir.appendingPathComponent("huggingface/models") - checkedPaths.append(hfModelsDir.path) - deletedCount += cleanupDirectory(hfModelsDir, matchAny: [String(modelName), underscoreVariant]) - - // Check ./huggingface/hub (Alternative HF structure) - let hfHubDir = baseDir.appendingPathComponent("huggingface/hub") - checkedPaths.append(hfHubDir.path) - deletedCount += cleanupDirectory(hfHubDir, matchAny: [String(modelName), underscoreVariant]) - - // Skip the old SpeakType-specific directory (no longer used) - - // Check root directory (sometimes models are here) - deletedCount += cleanupDirectory(baseDir, matchAny: [String(modelName), underscoreVariant]) - } - - // 2. Check ~/.cache (Common for Python/Unix HF tools) - let homeDir = fileManager.homeDirectoryForCurrentUser - let dotCacheModels = homeDir.appendingPathComponent(".cache/huggingface/models") - checkedPaths.append(dotCacheModels.path) - deletedCount += cleanupDirectory(dotCacheModels, matchAny: [String(modelName), underscoreVariant]) - - let dotCacheHub = homeDir.appendingPathComponent(".cache/huggingface/hub") - checkedPaths.append(dotCacheHub.path) - deletedCount += cleanupDirectory(dotCacheHub, matchAny: [String(modelName), underscoreVariant]) - - // 3. Check Temporary Directory - let tempDir = fileManager.temporaryDirectory - let tempHf = tempDir.appendingPathComponent("huggingface") - checkedPaths.append(tempHf.path) - deletedCount += cleanupDirectory(tempHf, matchAny: [String(modelName), underscoreVariant]) - deletedCount += cleanupDirectory(tempDir, matchAny: [String(modelName), underscoreVariant]) - - // 4. 
Check Documents/huggingface/models/argmaxinc/whisperkit-coreml (standard location) - if let documentsDir = fileManager.urls(for: .documentDirectory, in: .userDomainMask).first { - let whisperKitModels = documentsDir.appendingPathComponent("huggingface/models/argmaxinc/whisperkit-coreml") - checkedPaths.append(whisperKitModels.path) - deletedCount += cleanupDirectory(whisperKitModels, matchAny: [String(modelName), underscoreVariant]) - } - - print("🗑️ Cleanup complete. Deleted \(deletedCount) items from \(checkedPaths.count) locations") - + let locations = ModelCacheLocations.systemDefault(fileManager: fileManager) + let cleanupReport = ModelCachePathResolver.removeVariantDirectories( + for: variant, + locations: locations, + fileManager: fileManager, + log: { print($0) } + ) + + let deletedCount = cleanupReport.deletedPaths.count + let checkedPaths = cleanupReport.checkedPaths.map(\.path) + + print("🗑️ Cleanup complete. Deleted \(deletedCount) repo-owned model caches") + if deletedCount > 0 { await MainActor.run { self.downloadProgress[variant] = 0.0 @@ -304,34 +404,11 @@ class ModelDownloadService: ObservableObject { self.downloadProgress[variant] = 0.0 self.isDownloading[variant] = false } - return "No match for '\(modelName)' in \(checkedPaths.count) locations. checked: \(checkedPaths.map { $0.replacingOccurrences(of: homeDir.path, with: "~") }.joined(separator: ", "))" + let homePath = locations.homeDirectory.path + return "No repo-owned cache found for '\(variant)'. Checked: \(checkedPaths.map { $0.replacingOccurrences(of: homePath, with: "~") }.joined(separator: ", "))" } } - - private func cleanupDirectory(_ dir: URL, matchAny patterns: [String]) -> Int { - let fileManager = FileManager.default - guard let contents = try? 
fileManager.contentsOfDirectory(at: dir, includingPropertiesForKeys: nil) else { return 0 } - - var count = 0 - for url in contents { - let fileName = url.lastPathComponent - // Check if any pattern matches - let matches = patterns.contains { pattern in - fileName.contains(pattern) || fileName.contains(pattern.replacingOccurrences(of: "/", with: "--")) - } - - if matches { - do { - try fileManager.removeItem(at: url) - print("✅ Deleted cache: \(url.lastPathComponent)") - count += 1 - } catch { - print("❌ Failed to delete \(url.lastPathComponent): \(error)") - } - } - } - return count - } + func cancelDownload(for variant: String) { if let task = activeTasks[variant] { task.cancel() diff --git a/speaktypeTests/ModelDownloadServiceTests.swift b/speaktypeTests/ModelDownloadServiceTests.swift index 5e5b92a..08001bd 100644 --- a/speaktypeTests/ModelDownloadServiceTests.swift +++ b/speaktypeTests/ModelDownloadServiceTests.swift @@ -2,6 +2,19 @@ import XCTest @testable import speaktype final class ModelDownloadServiceTests: XCTestCase { + + private var tempRoot: URL! + + override func setUpWithError() throws { + tempRoot = FileManager.default.temporaryDirectory.appendingPathComponent(UUID().uuidString) + try FileManager.default.createDirectory(at: tempRoot, withIntermediateDirectories: true) + } + + override func tearDownWithError() throws { + if let tempRoot { + try? FileManager.default.removeItem(at: tempRoot) + } + } func testInitialState() { let service = ModelDownloadService.shared @@ -13,7 +26,120 @@ final class ModelDownloadServiceTests: XCTestCase { XCTAssertNotNil(service.downloadProgress) XCTAssertNotNil(service.isDownloading) } - - // Real download tests require mocking backend connections or WhisperKit, which is out of scope for basic unit tests without dependency injection. - // We verified the model IDs in the previous verification step. 
+ + func testCandidatePathsStayWithinRepoOwnedRoots() throws { + let locations = makeLocations() + let variant = "openai_whisper-medium" + + let directPath = try createDirectory( + at: locations.documentsDirectory! + .appendingPathComponent("huggingface/models/argmaxinc/whisperkit-coreml/\(variant)") + ) + let hubPath = try createDirectory( + at: locations.cachesDirectory! + .appendingPathComponent( + "huggingface/hub/models--argmaxinc--whisperkit-coreml/snapshots/123/\(variant)" + ) + ) + _ = try createDirectory(at: locations.documentsDirectory!.appendingPathComponent(variant)) + _ = try createDirectory( + at: locations.documentsDirectory! + .appendingPathComponent("huggingface/models/random-repo/\(variant)") + ) + + let candidatePaths = ModelCachePathResolver.candidatePaths( + for: variant, + locations: locations + ) + let normalizedCandidatePaths = Set(candidatePaths.map(normalizedPath)) + + XCTAssertTrue(normalizedCandidatePaths.contains(normalizedPath(directPath))) + XCTAssertTrue(normalizedCandidatePaths.contains(normalizedPath(hubPath))) + XCTAssertFalse( + normalizedCandidatePaths.contains( + normalizedPath(locations.documentsDirectory!.appendingPathComponent(variant)) + ) + ) + XCTAssertFalse( + normalizedCandidatePaths.contains( + normalizedPath( + locations.documentsDirectory! + .appendingPathComponent("huggingface/models/random-repo/\(variant)") + ) + ) + ) + } + + func testRemoveVariantDirectoriesOnlyDeletesExactRepoOwnedMatches() throws { + let locations = makeLocations() + let variant = "openai_whisper-medium" + + let directPath = try createDirectory( + at: locations.documentsDirectory! + .appendingPathComponent("huggingface/models/argmaxinc/whisperkit-coreml/\(variant)") + ) + let hubPath = try createDirectory( + at: locations.cachesDirectory! + .appendingPathComponent( + "huggingface/hub/models--argmaxinc--whisperkit-coreml/snapshots/abc123/\(variant)" + ) + ) + let backupPath = try createDirectory( + at: locations.documentsDirectory! 
+ .appendingPathComponent( + "huggingface/models/argmaxinc/whisperkit-coreml/\(variant)-backup" + ) + ) + let unrelatedPath = try createDirectory( + at: locations.documentsDirectory!.appendingPathComponent("\(variant)-notes") + ) + + let report = ModelCachePathResolver.removeVariantDirectories( + for: variant, + locations: locations + ) + + XCTAssertEqual( + Set(report.deletedPaths.map(normalizedPath)), + Set([directPath, hubPath].map(normalizedPath)) + ) + XCTAssertFalse(FileManager.default.fileExists(atPath: directPath.path)) + XCTAssertFalse(FileManager.default.fileExists(atPath: hubPath.path)) + XCTAssertTrue(FileManager.default.fileExists(atPath: backupPath.path)) + XCTAssertTrue(FileManager.default.fileExists(atPath: unrelatedPath.path)) + } + + private func makeLocations() -> ModelCacheLocations { + let documents = tempRoot.appendingPathComponent("Documents", isDirectory: true) + let appSupport = tempRoot.appendingPathComponent("Application Support", isDirectory: true) + let caches = tempRoot.appendingPathComponent("Caches", isDirectory: true) + let home = tempRoot.appendingPathComponent("Home", isDirectory: true) + let temp = tempRoot.appendingPathComponent("Temp", isDirectory: true) + + return ModelCacheLocations( + documentsDirectory: documents, + applicationSupportDirectory: appSupport, + cachesDirectory: caches, + homeDirectory: home, + temporaryDirectory: temp + ) + } + + @discardableResult + private func createDirectory(at url: URL) throws -> URL { + try FileManager.default.createDirectory(at: url, withIntermediateDirectories: true) + return url + } + + private func normalizedPath(_ url: URL) -> String { + let path = url.standardizedFileURL.path + if path.hasPrefix("/private/var/") { + return path.replacingOccurrences( + of: "/private/var/", + with: "/var/", + options: [.anchored] + ) + } + return path + } }