diff --git a/Sources/SpeziLLMLocalDownload/LLMLocalDownloadManager.swift b/Sources/SpeziLLMLocalDownload/LLMLocalDownloadManager.swift index a022e2b..e2e2415 100644 --- a/Sources/SpeziLLMLocalDownload/LLMLocalDownloadManager.swift +++ b/Sources/SpeziLLMLocalDownload/LLMLocalDownloadManager.swift @@ -23,7 +23,7 @@ import SpeziViews /// is of type ``LLMLocalDownloadManager/DownloadState``, containing states such as ``LLMLocalDownloadManager/DownloadState/downloading(progress:)`` /// which includes the progress of the download or ``LLMLocalDownloadManager/DownloadState/downloaded(storageUrl:)`` which indicates that the download has finished. @Observable -public final class LLMLocalDownloadManager: NSObject { +public final class LLMLocalDownloadManager: NSObject, @unchecked Sendable { /// An enum containing all possible states of the ``LLMLocalDownloadManager``. public enum DownloadState: Equatable { case idle @@ -79,11 +79,10 @@ public final class LLMLocalDownloadManager: NSObject { } /// Starts a `URLSessionDownloadTask` to download the specified model. + @MainActor public func startDownload() async { if modelExist { - Task { @MainActor in - self.state = .downloaded - } + state = .downloaded return } @@ -91,37 +90,36 @@ public final class LLMLocalDownloadManager: NSObject { downloadTask = Task(priority: .userInitiated) { do { try await downloadWithHub() - await MainActor.run { - self.state = .downloaded - } + state = .downloaded } catch { - await MainActor.run { - self.state = .error( - AnyLocalizedError( - error: error, - defaultErrorDescription: LocalizedStringResource("LLM_DOWNLOAD_FAILED_ERROR", bundle: .atURL(from: .module)) + state = .error( + AnyLocalizedError( + error: error, + defaultErrorDescription: LocalizedStringResource( + "LLM_DOWNLOAD_FAILED_ERROR", + bundle: .atURL(from: .module) ) ) - } + ) } } } /// Cancels the download of a specified model via a `URLSessionDownloadTask`. + @MainActor public func cancelDownload() async { downloadTask?.cancel() - await MainActor.run { - self.state = .idle - } + state = .idle } - - @MainActor + private func downloadWithHub() async throws { let repo = Hub.Repo(id: model.hubID) let modelFiles = ["*.safetensors", "config.json"] try await HubApi.shared.snapshot(from: repo, matching: modelFiles) { progress in - self.state = .downloading(progress: progress) + Task { @MainActor in + self.state = .downloading(progress: progress) + } } } }