Skip to content

Commit

Permalink
View state and Doc improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
philippzagar committed Nov 28, 2023
1 parent c845d6b commit 27f631a
Show file tree
Hide file tree
Showing 22 changed files with 354 additions and 56 deletions.
10 changes: 7 additions & 3 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,17 @@ let package = Package(
.package(url: "https://github.com/StanfordSpezi/SpeziStorage", .upToNextMinor(from: "0.5.0")),
.package(url: "https://github.com/StanfordSpezi/SpeziOnboarding", .upToNextMinor(from: "0.7.0")),
.package(url: "https://github.com/StanfordSpezi/SpeziSpeech", .upToNextMinor(from: "0.1.1")),
.package(url: "https://github.com/StanfordSpezi/SpeziChat", .upToNextMinor(from: "0.1.1"))
.package(url: "https://github.com/StanfordSpezi/SpeziChat", .upToNextMinor(from: "0.1.1")),
// .package(url: "https://github.com/StanfordSpezi/SpeziViews", .upToNextMinor(from: "0.6.2"))
.package(url: "https://github.com/StanfordSpezi/SpeziViews", branch: "feature/view-state-mapper")
],
targets: [
.target(
name: "SpeziLLM",
dependencies: [
.product(name: "Spezi", package: "Spezi"),
.product(name: "SpeziChat", package: "SpeziChat")
.product(name: "SpeziChat", package: "SpeziChat"),
.product(name: "SpeziViews", package: "SpeziViews")
]
),
.target(
Expand Down Expand Up @@ -67,7 +70,8 @@ let package = Package(
.target(
name: "SpeziLLMLocalDownload",
dependencies: [
.product(name: "SpeziOnboarding", package: "SpeziOnboarding")
.product(name: "SpeziOnboarding", package: "SpeziOnboarding"),
.product(name: "SpeziViews", package: "SpeziViews")
]
),
.target(
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ SPDX-License-Identifier: MIT

[![Build and Test](https://github.com/StanfordSpezi/SpeziLLM/actions/workflows/build-and-test.yml/badge.svg)](https://github.com/StanfordSpezi/SpeziLLM/actions/workflows/build-and-test.yml)
[![codecov](https://codecov.io/gh/StanfordSpezi/SpeziLLM/branch/main/graph/badge.svg?token=pptLyqtoNR)](https://codecov.io/gh/StanfordSpezi/SpeziLLM)
[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.7954213.svg)](https://doi.org/10.5281/zenodo.7954213)
[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.7954213.svg)](https://doi.org/10.5281/zenodo.7954213) <!-- TODO: Is this DOI still valid after the naming switch to SpeziLLM? -->
[![](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2FStanfordSpezi%2FSpeziLLM%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/StanfordSpezi/SpeziLLM)
[![](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2FStanfordSpezi%2FSpeziLLM%2Fbadge%3Ftype%3Dplatforms)](https://swiftpackageindex.com/StanfordSpezi/SpeziLLM)

Expand Down
2 changes: 1 addition & 1 deletion Sources/SpeziLLM/LLM.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public protocol LLM {
/// The type of the ``LLM`` as represented by the ``LLMHostingType``.
var type: LLMHostingType { get async }
/// The state of the ``LLM`` indicated by the ``LLMState``.
var state: LLMState { get async }
@MainActor var state: LLMState { get }


/// Performs any setup-related actions for the ``LLM``.
Expand Down
2 changes: 1 addition & 1 deletion Sources/SpeziLLM/LLMError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import Foundation


/// The ``LLMError`` describes possible errors that occure during the execution of the ``LLM`` via the ``LLMRunner``.
/// The ``LLMError`` describes possible errors that occur during the execution of the ``LLM`` via the ``LLMRunner``.
public enum LLMError: LocalizedError {
/// Indicates that the local model file is not found.
case modelNotFound
Expand Down
26 changes: 26 additions & 0 deletions Sources/SpeziLLM/LLMState+OperationState.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//
// This source file is part of the Stanford Spezi open source project
//
// SPDX-FileCopyrightText: 2022 Stanford University and the project authors (see CONTRIBUTORS.md)
//
// SPDX-License-Identifier: MIT
//

import Foundation
import SpeziViews

// This extension must live in its own file: declaring it alongside ``LLMState``
// triggers the "Circular reference resolving attached macro 'Observable'" compiler
// error (see https://github.com/apple/swift/issues/66450).
/// Conformance of ``LLMState`` to the SpeziViews `OperationState` protocol,
/// mapping the LLM lifecycle states onto the SpeziViews `ViewState`.
extension LLMState: OperationState {
    /// The SpeziViews `ViewState` corresponding to the current ``LLMState``.
    public var viewState: ViewState {
        switch self {
        case .uninitialized, .ready:
            return .idle
        case .loading, .generating:
            return .processing
        case .error(let error):
            return .error(error)
        }
    }
}
7 changes: 4 additions & 3 deletions Sources/SpeziLLM/Mock/LLMMock.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@ import Foundation
/// A mock SpeziLLM ``LLM`` that is used for testing and preview purposes.
public actor LLMMock: LLM {
public let type: LLMHostingType = .local
public var state: LLMState = .uninitialized
@MainActor public var state: LLMState = .uninitialized


public init() {}


public func setup(runnerConfig: LLMRunnerConfiguration) async throws {
/// Set ``LLMState`` to ready
self.state = .ready
await MainActor.run {
self.state = .ready
}
}

public func generate(prompt: String, continuation: AsyncThrowingStream<String, Error>.Continuation) async {
Expand Down
23 changes: 23 additions & 0 deletions Sources/SpeziLLM/SpeziLLM.docc/SpeziLLM.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# ``SpeziLLM``

<!--
#
# This source file is part of the Stanford Spezi open source project
#
# SPDX-FileCopyrightText: 2023 Stanford University and the project authors (see CONTRIBUTORS.md)
#
# SPDX-License-Identifier: MIT
#
-->

Provides base LLM execution capabilities within the Spezi ecosystem.

## Overview

<!--@START_MENU_TOKEN@-->Text<!--@END_MENU_TOKEN@-->

## Topics

### <!--@START_MENU_TOKEN@-->Group<!--@END_MENU_TOKEN@-->

- <!--@START_MENU_TOKEN@-->``Symbol``<!--@END_MENU_TOKEN@-->
8 changes: 6 additions & 2 deletions Sources/SpeziLLM/Views/LLMChatView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//

import SpeziChat
import SpeziViews
import SwiftUI


Expand All @@ -16,9 +17,10 @@ public struct LLMChatView: View {
@Environment(LLMRunner.self) private var runner
/// Represents the chat content that is displayed.
@State private var chat: Chat = []
/// Indicates if the input field is disabled
/// Indicates if the input field is disabled.
@State private var inputDisabled = false

/// Indicates the state of the view; it is derived from the ``LLM/state``.
@State private var viewState: ViewState = .idle

/// A SpeziLLM ``LLM`` that is used for the text generation within the chat view
private let model: any LLM
Expand Down Expand Up @@ -46,6 +48,8 @@ public struct LLMChatView: View {
}
}
}
.map(state: model.state, to: $viewState)
.viewStateAlert(state: $viewState)
}


Expand Down
18 changes: 13 additions & 5 deletions Sources/SpeziLLMLocal/LLMLlama+Generation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ extension LLMLlama {
func _generate( // swiftlint:disable:this identifier_name function_body_length
prompt: String,
continuation: AsyncThrowingStream<String, Error>.Continuation
) {
self.state = .generating
) async {
await MainActor.run {
self.state = .generating
}

// Log the most important parameters of the LLM
Self.logger.debug("n_length = \(self.parameters.maxOutputLength, privacy: .public), n_ctx = \(self.contextParameters.contextWindowSize, privacy: .public), n_batch = \(self.contextParameters.batchSize, privacy: .public), n_kv_req = \(self.parameters.maxOutputLength, privacy: .public)")
Expand Down Expand Up @@ -115,7 +117,9 @@ extension LLMLlama {
if nextTokenId == llama_token_eos(self.model) || batchTokenIndex == self.parameters.maxOutputLength {
self.generatedText.append(self.EOS)
continuation.finish()
self.state = .ready
await MainActor.run {
self.state = .ready
}
return
}

Expand All @@ -136,7 +140,9 @@ extension LLMLlama {
let decodeOutput = llama_decode(self.context, batch)
if decodeOutput != 0 { // = 0 Success, > 0 Warning, < 0 Error
Self.logger.error("Decoding of generated output failed. Output: \(decodeOutput, privacy: .public)")
self.state = .error(error: .generationError)
await MainActor.run {
self.state = .error(error: .generationError)
}
continuation.finish(throwing: LLMError.generationError)
return
}
Expand All @@ -149,6 +155,8 @@ extension LLMLlama {
llama_print_timings(self.context)

continuation.finish()
self.state = .ready
await MainActor.run {
self.state = .ready
}
}
}
20 changes: 14 additions & 6 deletions Sources/SpeziLLMLocal/LLMLlama.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public actor LLMLlama: LLM {
/// A Swift Logger that logs important information from the ``LLMLlama``.
static let logger = Logger(subsystem: "edu.stanford.spezi", category: "SpeziLLM")
public let type: LLMHostingType = .local
public var state: LLMState = .uninitialized
@MainActor public var state: LLMState = .uninitialized

/// Parameters of the llama.cpp ``LLM``.
let parameters: LLMParameters
Expand Down Expand Up @@ -60,10 +60,14 @@ public actor LLMLlama: LLM {


public func setup(runnerConfig: LLMRunnerConfiguration) async throws {
self.state = .loading
await MainActor.run {
self.state = .loading
}

guard let model = llama_load_model_from_file(modelPath.path().cString(using: .utf8), parameters.llamaCppRepresentation) else {
self.state = .error(error: LLMError.modelNotFound)
await MainActor.run {
self.state = .error(error: LLMError.modelNotFound)
}
throw LLMError.modelNotFound
}
self.model = model
Expand All @@ -72,15 +76,19 @@ public actor LLMLlama: LLM {
let trainingContextWindow = llama_n_ctx_train(model)
guard self.contextParameters.contextWindowSize <= trainingContextWindow else {
Self.logger.warning("Model was trained on only \(trainingContextWindow, privacy: .public) context tokens, not the configured \(self.contextParameters.contextWindowSize, privacy: .public) context tokens")
self.state = .error(error: LLMError.generationError)
await MainActor.run {
self.state = .error(error: LLMError.generationError)
}
throw LLMError.modelNotFound
}

self.state = .ready
await MainActor.run {
self.state = .ready
}
}

public func generate(prompt: String, continuation: AsyncThrowingStream<String, Error>.Continuation) async {
_generate(prompt: prompt, continuation: continuation)
await _generate(prompt: prompt, continuation: continuation)
}


Expand Down
23 changes: 23 additions & 0 deletions Sources/SpeziLLMLocal/SpeziLLMLocal.docc/SpeziLLMLocal.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# ``SpeziLLMLocal``

<!--
#
# This source file is part of the Stanford Spezi open source project
#
# SPDX-FileCopyrightText: 2023 Stanford University and the project authors (see CONTRIBUTORS.md)
#
# SPDX-License-Identifier: MIT
#
-->

Provides Large Language Model (LLM) execution capabilities on the local device.

## Overview

<!--@START_MENU_TOKEN@-->Text<!--@END_MENU_TOKEN@-->

## Topics

### <!--@START_MENU_TOKEN@-->Group<!--@END_MENU_TOKEN@-->

- <!--@START_MENU_TOKEN@-->``Symbol``<!--@END_MENU_TOKEN@-->
29 changes: 29 additions & 0 deletions Sources/SpeziLLMLocalDownload/LLMLocalDownloadError.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
//
// This source file is part of the Stanford Spezi open source project
//
// SPDX-FileCopyrightText: 2022 Stanford University and the project authors (see CONTRIBUTORS.md)
//
// SPDX-License-Identifier: MIT
//

import Foundation


/// The ``LLMLocalDownloadError`` describes possible errors that occur during downloading models via the ``LLMLocalDownloadManager``.
public enum LLMLocalDownloadError: LocalizedError {
    /// Indicates an unknown error during downloading the model
    case unknownError


    /// A localized, user-facing description of the download error.
    /// Resolved from the module's localization table via the "LLM_DOWNLOAD_ERROR_DESCRIPTION" key.
    public var errorDescription: String? {
        String(localized: LocalizedStringResource("LLM_DOWNLOAD_ERROR_DESCRIPTION", bundle: .atURL(from: .module)))
    }

    /// A localized suggestion telling the user how to recover from the download error.
    /// Resolved from the module's localization table via the "LLM_DOWNLOAD_ERROR_RECOVERY_SUGGESTION" key.
    public var recoverySuggestion: String? {
        String(localized: LocalizedStringResource("LLM_DOWNLOAD_ERROR_RECOVERY_SUGGESTION", bundle: .atURL(from: .module)))
    }

    /// A localized explanation of why the download failed.
    /// Resolved from the module's localization table via the "LLM_DOWNLOAD_ERROR_FAILURE_REASON" key.
    // NOTE(review): these computed properties do not switch over `self`; revisit them
    // once additional error cases beyond `unknownError` are introduced.
    public var failureReason: String? {
        String(localized: LocalizedStringResource("LLM_DOWNLOAD_ERROR_FAILURE_REASON", bundle: .atURL(from: .module)))
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//
// This source file is part of the Stanford Spezi open source project
//
// SPDX-FileCopyrightText: 2022 Stanford University and the project authors (see CONTRIBUTORS.md)
//
// SPDX-License-Identifier: MIT
//

import Foundation
import SpeziViews

// This extension must live in its own file: declaring it alongside the
// ``LLMLocalDownloadManager`` triggers the "Circular reference resolving attached
// macro 'Observable'" compiler error (see https://github.com/apple/swift/issues/66450).
/// Conformance of ``LLMLocalDownloadManager/DownloadState`` to the SpeziViews `OperationState` protocol,
/// mapping the download lifecycle states onto the SpeziViews `ViewState`.
extension LLMLocalDownloadManager.DownloadState: OperationState {
    /// The SpeziViews `ViewState` corresponding to the current download state.
    public var viewState: ViewState {
        switch self {
        case .idle, .downloaded:
            return .idle
        case .downloading:
            return .processing
        case .error(let error):
            return .error(error)
        }
    }
}
30 changes: 23 additions & 7 deletions Sources/SpeziLLMLocalDownload/LLMLocalDownloadManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,21 @@
//

import Foundation
import Observation
import SpeziViews


/// Manages the download of an LLM to the local device.
public final class LLMLocalDownloadManager: NSObject, ObservableObject {
/// The ``LLMLocalDownloadManager`` manages the download and storage of Large Language Models (LLM) to the local device.
///
/// One configures the ``LLMLocalDownloadManager`` via the ``LLMLocalDownloadManager/init(llmDownloadUrl:llmStorageUrl:)`` initializer,
/// passing a download `URL` as well as a storage `URL` to the ``LLMLocalDownloadManager``.
/// The download of a model is started via ``LLMLocalDownloadManager/startDownload()`` and can be cancelled (early) via ``LLMLocalDownloadManager/cancelDownload()``.
///
/// The current state of the ``LLMLocalDownloadManager`` is exposed via the ``LLMLocalDownloadManager/state`` property which
/// is of type ``LLMLocalDownloadManager/DownloadState``, containing cases such as ``LLMLocalDownloadManager/DownloadState/downloading(progress:)``
/// which includes the progress of the download or ``LLMLocalDownloadManager/DownloadState/downloaded(storageUrl:)`` which indicates that the download has finished.
@Observable
public final class LLMLocalDownloadManager: NSObject {
/// Defaults of possible LLMs to download via the ``LLMLocalDownloadManager``.
public enum LLMUrlDefaults {
/// LLama 2 7B model in its chat variation (~3.5GB)
Expand Down Expand Up @@ -51,8 +62,8 @@ public final class LLMLocalDownloadManager: NSObject, ObservableObject {
public enum DownloadState: Equatable {
case idle
case downloading(progress: Double)
case downloaded
case error(Error?)
case downloaded(storageUrl: URL)
case error(LocalizedError)


public static func == (lhs: LLMLocalDownloadManager.DownloadState, rhs: LLMLocalDownloadManager.DownloadState) -> Bool {
Expand All @@ -67,15 +78,15 @@ public final class LLMLocalDownloadManager: NSObject, ObservableObject {
}

/// The delegate handling the download manager tasks.
private var downloadDelegate: LLMLocalDownloadManagerDelegate? // swiftlint:disable:this weak_delegate
@ObservationIgnored private var downloadDelegate: LLMLocalDownloadManagerDelegate? // swiftlint:disable:this weak_delegate
/// The `URLSessionDownloadTask` that handles the download of the model.
private var downloadTask: URLSessionDownloadTask?
@ObservationIgnored private var downloadTask: URLSessionDownloadTask?
/// Remote `URL` from where the LLM file should be downloaded.
private let llmDownloadUrl: URL
/// Local `URL` where the downloaded model is stored.
let llmStorageUrl: URL
/// Indicates the current state of the ``LLMLocalDownloadManager``.
@MainActor @Published public var state: DownloadState = .idle
@MainActor public var state: DownloadState = .idle


/// Creates a ``LLMLocalDownloadManager`` that helps with downloading LLM files from remote servers.
Expand All @@ -102,4 +113,9 @@ public final class LLMLocalDownloadManager: NSObject, ObservableObject {

downloadTask?.resume()
}

/// Cancels the download of a specified model via a `URLSessionDownloadTask`.
public func cancelDownload() {
downloadTask?.cancel()
}
}
Loading

0 comments on commit 27f631a

Please sign in to comment.