Skip to content

Commit

Permalink
Make LLMLocalDownloadManager @unchecked Sendable
Browse files Browse the repository at this point in the history
  • Loading branch information
jdisho committed Dec 22, 2024
1 parent 451ea40 commit ae6eacd
Showing 1 changed file with 17 additions and 19 deletions.
36 changes: 17 additions & 19 deletions Sources/SpeziLLMLocalDownload/LLMLocalDownloadManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import SpeziViews
/// is of type ``LLMLocalDownloadManager/DownloadState``, containing states such as ``LLMLocalDownloadManager/DownloadState/downloading(progress:)``
/// which includes the progress of the download or ``LLMLocalDownloadManager/DownloadState/downloaded(storageUrl:)`` which indicates that the download has finished.
@Observable
public final class LLMLocalDownloadManager: NSObject {
public final class LLMLocalDownloadManager: NSObject, @unchecked Sendable {
/// An enum containing all possible states of the ``LLMLocalDownloadManager``.
public enum DownloadState: Equatable {
case idle
Expand Down Expand Up @@ -79,49 +79,47 @@ public final class LLMLocalDownloadManager: NSObject {
}

/// Starts a `URLSessionDownloadTask` to download the specified model.
@MainActor
public func startDownload() async {
if modelExist {
Task { @MainActor in
self.state = .downloaded
}
state = .downloaded
return
}

await cancelDownload()
downloadTask = Task(priority: .userInitiated) {
do {
try await downloadWithHub()
await MainActor.run {
self.state = .downloaded
}
state = .downloaded
} catch {
await MainActor.run {
self.state = .error(
AnyLocalizedError(
error: error,
defaultErrorDescription: LocalizedStringResource("LLM_DOWNLOAD_FAILED_ERROR", bundle: .atURL(from: .module))
state = .error(
AnyLocalizedError(
error: error,
defaultErrorDescription: LocalizedStringResource(
"LLM_DOWNLOAD_FAILED_ERROR",
bundle: .atURL(from: .module)
)
)
}
)
}
}
}

/// Cancels the download of a specified model via a `URLSessionDownloadTask`.
@MainActor
public func cancelDownload() async {
downloadTask?.cancel()
await MainActor.run {
self.state = .idle
}
state = .idle
}

@MainActor

private func downloadWithHub() async throws {
let repo = Hub.Repo(id: model.hubID)
let modelFiles = ["*.safetensors", "config.json"]

try await HubApi.shared.snapshot(from: repo, matching: modelFiles) { progress in
self.state = .downloading(progress: progress)
Task { @MainActor in
self.state = .downloading(progress: progress)
}
}
}
}

0 comments on commit ae6eacd

Please sign in to comment.