Direct-style streaming using ox (#193)
adamw authored Sep 4, 2024
1 parent 1cc5e2c commit 80fd405
Showing 11 changed files with 341 additions and 23 deletions.
98 changes: 87 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
- ![sttp-model](https://github.com/softwaremill/sttp-openai/raw/master/banner.jpg)
+ ![sttp-openai](https://github.com/softwaremill/sttp-openai/raw/master/banner.jpg)

[![Ideas, suggestions, problems, questions](https://img.shields.io/badge/Discourse-ask%20question-blue)](https://softwaremill.community/c/tapir)
[![CI](https://github.com/softwaremill/sttp-openai/workflows/CI/badge.svg)](https://github.com/softwaremill/sttp-openai/actions?query=workflow%3ACI+branch%3Amaster)

[//]: # ([![Maven Central](https://maven-badges.herokuapp.com/maven-central/com.softwaremill.sttp.openai.svg)(https://maven-badges.herokuapp.com/maven-central/com.softwaremill.sttp.openai))

sttp is a family of Scala HTTP-related projects, and currently includes:

* [sttp client](https://github.com/softwaremill/sttp): The Scala HTTP client you always wanted!
* [sttp tapir](https://github.com/softwaremill/tapir): Typed API descriptions
* sttp openai: this project. Non-official Scala client wrapper for OpenAI (and OpenAI-compatible) API. Use the power of ChatGPT inside your code!

## Intro
- Sttp-openai uses sttp client to describe requests and responses used in OpenAI (and OpenAI-compatible) endpoints.
+ sttp-openai uses sttp client to describe requests and responses used in OpenAI (and OpenAI-compatible) endpoints.

## Quickstart with sbt

@@ -22,17 +23,21 @@ Add the following dependency:
"com.softwaremill.sttp.openai" %% "core" % "0.2.1"
```

- sttp openai is available for Scala 2.13 and Scala 3
+ sttp-openai is available for Scala 2.13 and Scala 3

## Project content

OpenAI API Official Documentation https://platform.openai.com/docs/api-reference/completions

## Example

Examples are runnable using [scala-cli](https://scala-cli.virtuslab.org).

### To use ChatGPT

```scala mdoc:compile-only
//> using dep com.softwaremill.sttp.openai::core:0.2.1

import sttp.openai.OpenAISyncClient
import sttp.openai.requests.completions.chat.ChatRequestResponseData.ChatResponse
import sttp.openai.requests.completions.chat.ChatRequestBody.{ChatBody, ChatCompletionModel}
@@ -82,7 +87,9 @@ object Main extends App {

Ollama with sync backend:

- ```scala mdoc:compile-only
+ ```scala mdoc:compile-only
+ //> using dep com.softwaremill.sttp.openai::core:0.2.1

import sttp.model.Uri._
import sttp.openai.OpenAISyncClient
import sttp.openai.requests.completions.chat.ChatRequestResponseData.ChatResponse
@@ -134,7 +141,10 @@ object Main extends App {

Grok with cats-effect based backend:

- ```scala mdoc:compile-only
+ ```scala mdoc:compile-only
+ //> using dep com.softwaremill.sttp.openai::core:0.2.1
+ //> using dep com.softwaremill.sttp.client4::cats:4.0.0-M17

import cats.effect.{ExitCode, IO, IOApp}
import sttp.client4.httpclient.cats.HttpClientCatsBackend

@@ -208,7 +218,10 @@ If you want to make use of other effects, you have to use `OpenAI` and pass the
Example below uses `HttpClientCatsBackend` as a backend, make sure to [add it to the dependencies](https://sttp.softwaremill.com/en/latest/backends/catseffect.html)
or use backend of your choice.

- ```scala mdoc:compile-only
+ ```scala mdoc:compile-only
+ //> using dep com.softwaremill.sttp.openai::core:0.2.1
+ //> using dep com.softwaremill.sttp.client4::cats:4.0.0-M17

import cats.effect.{ExitCode, IO, IOApp}
import sttp.client4.httpclient.cats.HttpClientCatsBackend

@@ -271,17 +284,23 @@ object Main extends IOApp {
#### Create completion with streaming:

To enable streaming support for the Chat Completion API using server-sent events, you must include the appropriate
- dependency for your chosen streaming library. We provide support for the following libraries: _Fs2_, _ZIO_, _Akka / Pekko Streams_
+ dependency for your chosen streaming library. We provide support for the following libraries: _fs2_, _ZIO_, _Akka / Pekko Streams_ and _Ox_.

- For example, to use `Fs2` add the following import:
+ For example, to use `fs2` add the following dependency & import:

```scala
// sbt dependency
"com.softwaremill.sttp.openai" %% "fs2" % "0.2.1"

// import
import sttp.openai.streaming.fs2._
```

- Example below uses `HttpClientFs2Backend` as a backend.
+ Example below uses `HttpClientFs2Backend` as a backend:

```scala mdoc:compile-only
//> using dep com.softwaremill.sttp.openai::fs2:0.2.1

import cats.effect.{ExitCode, IO, IOApp}
import fs2.Stream
import sttp.client4.httpclient.fs2.HttpClientFs2Backend
@@ -361,6 +380,63 @@ object Main extends IOApp {
}
```
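As a follow-up sketch (not part of the original README), the chunk stream can be folded into the complete reply. The type parameter on `createStreamedChatCompletion` and the `choices`/`delta`/`content` accessors mirror the OpenAI wire format and are assumptions here, so field names may differ slightly in the actual Scala model:

```scala
//> using dep com.softwaremill.sttp.openai::fs2:0.2.1

import cats.effect.IO
import sttp.client4.httpclient.fs2.HttpClientFs2Backend
import sttp.openai.OpenAI
import sttp.openai.requests.completions.chat.ChatRequestBody.ChatBody
import sttp.openai.streaming.fs2._

// Sketch: concatenate the streamed content deltas into a single String
def completeReply(openAI: OpenAI, chatRequestBody: ChatBody): IO[String] =
  HttpClientFs2Backend.resource[IO]().use { backend =>
    openAI
      .createStreamedChatCompletion[IO](chatRequestBody)
      .send(backend)
      .map(_.body)
      .rethrow // raise any OpenAIException into the IO error channel
      .flatMap {
        // keep only the text deltas; empty string for chunks without content
        _.map(_.choices.headOption.flatMap(_.delta.content).getOrElse(""))
          .compile
          .string
      }
  }
```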

To use direct-style streaming (requires Scala 3) add the following dependency & import:

```scala
// sbt dependency
"com.softwaremill.sttp.openai" %% "ox" % "0.2.1"

// import
import sttp.openai.streaming.ox.*
```

Example code:

```scala
//> using dep com.softwaremill.sttp.openai::ox:0.2.1
//> using dep com.softwaremill.ox::core:0.3.5

import ox.*
import ox.either.orThrow
import sttp.client4.DefaultSyncBackend
import sttp.openai.OpenAI
import sttp.openai.requests.completions.chat.ChatRequestBody.{ChatBody, ChatCompletionModel}
import sttp.openai.requests.completions.chat.message.*
import sttp.openai.streaming.ox.*

object Main extends OxApp:
override def run(args: Vector[String])(using Ox, IO): ExitCode =
// Read your API secret-key from env variables
val apiKey = System.getenv("openai-key")

// Create an instance of OpenAISyncClient providing your API secret-key
val openAI: OpenAI = new OpenAI(apiKey)

val bodyMessages: Seq[Message] = Seq(
Message.UserMessage(
content = Content.TextContent("Hello!")
)
)

val chatRequestBody: ChatBody = ChatBody(
model = ChatCompletionModel.GPT35Turbo,
messages = bodyMessages
)

val backend = useCloseableInScope(DefaultSyncBackend())
supervised {
val source = openAI
.createStreamedChatCompletion(chatRequestBody)
.send(backend)
.body // this gives us an Either[OpenAIException, Source[ChatChunkResponse]]
      .orThrow // we choose to throw any exceptions and fail the whole app

source.foreach(el => println(el.orThrow))
}

ExitCode.Success
```
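A possible variation (a sketch, not from the original README): inside the concurrency scope, the chunk source can be folded into the complete reply. The `choices`/`delta`/`content` accessors are assumptions mirroring the OpenAI wire format:

```scala
// Sketch: within the same Main as above, collect all content deltas into one
// String using Source.fold from ox. Field names are assumed, not verified.
supervised {
  val reply: String = openAI
    .createStreamedChatCompletion(chatRequestBody)
    .send(backend)
    .body
    .orThrow
    .fold("") { (acc, chunk) =>
      acc + chunk.orThrow.choices.headOption.flatMap(_.delta.content).getOrElse("")
    }
  println(reply)
}
```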

## Contributing

If you have a question, or hit a problem, feel free to post on our community https://softwaremill.community/c/open-source/
@@ -373,4 +449,4 @@ We offer commercial support for sttp and related technologies, as well as development services.

## Copyright

- Copyright (C) 2023 SoftwareMill [https://softwaremill.com](https://softwaremill.com).
+ Copyright (C) 2023-2024 SoftwareMill [https://softwaremill.com](https://softwaremill.com).
11 changes: 11 additions & 0 deletions build.sbt
@@ -22,6 +22,7 @@ lazy val allAgregates = core.projectRefs ++
zio.projectRefs ++
pekko.projectRefs ++
akka.projectRefs ++
ox.projectRefs ++
docs.projectRefs

lazy val core = (projectMatrix in file("core"))
@@ -73,6 +74,16 @@ lazy val akka = (projectMatrix in file("streaming/akka"))
)
.dependsOn(core % "compile->compile;test->test")

lazy val ox = (projectMatrix in file("streaming/ox"))
.jvmPlatform(
scalaVersions = scala3
)
.settings(commonSettings)
.settings(
libraryDependencies ++= Libraries.sttpClientOx
)
.dependsOn(core % "compile->compile;test->test")

val compileDocs: TaskKey[Unit] = taskKey[Unit]("Compiles docs module throwing away its output")
compileDocs := {
(docs.jvm(scala2.head) / mdoc).toTask(" --out target/sttp-openai-docs").value
32 changes: 28 additions & 4 deletions core/src/main/scala/sttp/openai/OpenAI.scala
@@ -3,7 +3,13 @@ package sttp.openai
import sttp.client4._
import sttp.model.{Header, Uri}
import sttp.openai.OpenAIExceptions.OpenAIException
- import sttp.openai.json.SttpUpickleApiExtension.{asJsonSnake, asStreamSnake, asStringEither, upickleBodySerializer}
+ import sttp.openai.json.SttpUpickleApiExtension.{
+   asInputStreamStreamSnake,
+   asJsonSnake,
+   asStreamSnake,
+   asStringEither,
+   upickleBodySerializer
+ }
import sttp.openai.requests.assistants.AssistantsRequestBody.{CreateAssistantBody, ModifyAssistantBody}
import sttp.openai.requests.assistants.AssistantsResponseData.{AssistantData, DeleteAssistantResponse, ListAssistantsResponse}
import sttp.openai.requests.completions.CompletionsRequestBody.CompletionsBody
@@ -56,7 +62,7 @@ import sttp.openai.requests.vectorstore.file.VectorStoreFileResponseData.{
VectorStoreFile
}

- import java.io.File
+ import java.io.{File, InputStream}
import java.nio.file.Paths

class OpenAI(authToken: String, baseUri: Uri = OpenAIUris.OpenAIBaseUri) {
@@ -265,7 +271,10 @@ class OpenAI(authToken: String, baseUri: Uri = OpenAIUris.OpenAIBaseUri) {
.body(chatBody)
.response(asJsonSnake[ChatResponse])

- /** Creates and streams a model response as chunk objects for the given chat conversation defined in chatBody.
+ /** Creates a model response for the given chat conversation defined in chatBody.
  *
+ * The response is streamed in chunks as server-sent events, which are returned unparsed as a binary stream, using the given streams
+ * implementation.
*
* [[https://platform.openai.com/docs/api-reference/chat/create]]
*
Expand All @@ -274,12 +283,27 @@ class OpenAI(authToken: String, baseUri: Uri = OpenAIUris.OpenAIBaseUri) {
* @param chatBody
* Chat request body.
*/
- def createChatCompletion[S](s: Streams[S], chatBody: ChatBody): StreamRequest[Either[OpenAIException, s.BinaryStream], S] =
+ def createChatCompletionAsBinaryStream[S](s: Streams[S], chatBody: ChatBody): StreamRequest[Either[OpenAIException, s.BinaryStream], S] =
openAIAuthRequest
.post(openAIUris.ChatCompletions)
.body(ChatBody.withStreaming(chatBody))
.response(asStreamSnake(s))

/** Creates a model response for the given chat conversation defined in chatBody.
*
* The response is streamed in chunks as server-sent events, which are returned unparsed as an [[InputStream]].
*
* [[https://platform.openai.com/docs/api-reference/chat/create]]
*
* @param chatBody
* Chat request body.
*/
def createChatCompletionAsInputStream(chatBody: ChatBody): Request[Either[OpenAIException, InputStream]] =
openAIAuthRequest
.post(openAIUris.ChatCompletions)
.body(ChatBody.withStreaming(chatBody))
.response(asInputStreamStreamSnake)

/** Returns a list of files that belong to the user's organization.
*
* [[https://platform.openai.com/docs/api-reference/files]]
@@ -9,6 +9,8 @@ import sttp.openai.OpenAIExceptions.OpenAIException
import sttp.openai.OpenAIExceptions.OpenAIException._
import sttp.capabilities.Streams

import java.io.InputStream

/** An sttp upickle api extension that deserializes JSON with snake_case keys into case classes with fields corresponding to keys in
* camelCase and maps errors to OpenAIException subclasses.
*/
@@ -20,6 +22,11 @@ object SttpUpickleApiExtension extends SttpUpickleApi {
body.left.map(errorBody => httpToOpenAIError(HttpError(errorBody, meta.code)))
}

def asInputStreamStreamSnake: ResponseAs[Either[OpenAIException, InputStream]] =
asInputStreamUnsafe.mapWithMetadata { (body, meta) =>
body.left.map(errorBody => httpToOpenAIError(HttpError(errorBody, meta.code)))
}

def asJsonSnake[B: upickleApi.Reader: IsOption]: ResponseAs[Either[OpenAIException, B]] =
asString.mapWithMetadata(deserializeRightWithMappedExceptions(deserializeJsonSnake)).showAsJson

5 changes: 5 additions & 0 deletions project/Dependencies.scala
@@ -38,6 +38,11 @@ object Dependencies {
"com.typesafe.akka" %% "akka-stream" % V.akkaStreams
)

val sttpClientOx = Seq(
"com.softwaremill.sttp.client4" %% "ox" % V.sttpClient,
"com.softwaremill.ox" %% "core" % "0.3.5"
)

val uPickle = "com.lihaoyi" %% "upickle" % V.uPickle

}
@@ -17,7 +17,8 @@ package object akka {

implicit class extension(val client: OpenAI) {

- /** Creates and streams a model response as chunk objects for the given chat conversation defined in chatBody.
+ /** Creates and streams a model response as chunk objects for the given chat conversation defined in chatBody. The request will complete
+ * and the connection close only once the source is fully consumed.
*
* [[https://platform.openai.com/docs/api-reference/chat/create]]
*
@@ -28,7 +29,7 @@ package object akka {
chatBody: ChatBody
): StreamRequest[Either[OpenAIException, Source[ChatChunkResponse, Any]], AkkaStreams] =
client
-       .createChatCompletion(AkkaStreams, chatBody)
+       .createChatCompletionAsBinaryStream(AkkaStreams, chatBody)
.mapResponse(mapEventToResponse)
}

@@ -16,7 +16,8 @@ package object fs2 {

implicit class extension(val client: OpenAI) {

- /** Creates and streams a model response as chunk objects for the given chat conversation defined in chatBody.
+ /** Creates and streams a model response as chunk objects for the given chat conversation defined in chatBody. The request will complete
+ * and the connection close only once the source is fully consumed.
*
* [[https://platform.openai.com/docs/api-reference/chat/create]]
*
@@ -27,7 +28,7 @@ package object fs2 {
chatBody: ChatBody
): StreamRequest[Either[OpenAIException, Stream[F, ChatChunkResponse]], Fs2Streams[F]] =
client
-       .createChatCompletion(Fs2Streams[F], chatBody)
+       .createChatCompletionAsBinaryStream(Fs2Streams[F], chatBody)
.mapResponse(mapEventToResponse[F])
}

@@ -0,0 +1,48 @@
package sttp.openai.streaming.ox

import ox.channels.Source
import ox.{IO, Ox}
import sttp.client4.Request
import sttp.client4.impl.ox.sse.OxServerSentEvents
import sttp.model.sse.ServerSentEvent
import sttp.openai.OpenAI
import sttp.openai.OpenAIExceptions.OpenAIException
import sttp.openai.OpenAIExceptions.OpenAIException.DeserializationOpenAIException
import sttp.openai.json.SttpUpickleApiExtension.deserializeJsonSnake
import sttp.openai.requests.completions.chat.ChatChunkRequestResponseData.ChatChunkResponse
import sttp.openai.requests.completions.chat.ChatChunkRequestResponseData.ChatChunkResponse.DoneEvent
import sttp.openai.requests.completions.chat.ChatRequestBody.ChatBody

import java.io.InputStream

extension (client: OpenAI)
/** Creates and streams a model response as chunk objects for the given chat conversation defined in chatBody.
*
* The chunk [[Source]] can be obtained from the response within a concurrency scope (e.g. [[ox.supervised]]), and the [[IO]] capability
* must be provided. The request will complete and the connection close only once the source is fully consumed.
*
* [[https://platform.openai.com/docs/api-reference/chat/create]]
*
* @param chatBody
* Chat request body.
*/
def createStreamedChatCompletion(
chatBody: ChatBody
): Request[Either[OpenAIException, Ox ?=> IO ?=> Source[Either[DeserializationOpenAIException, ChatChunkResponse]]]] =
client
.createChatCompletionAsInputStream(chatBody)
.mapResponse(mapEventToResponse)

private def mapEventToResponse(
response: Either[OpenAIException, InputStream]
): Either[OpenAIException, Ox ?=> IO ?=> Source[Either[DeserializationOpenAIException, ChatChunkResponse]]] =
response.map(s =>
OxServerSentEvents
.parse(s)
.transform {
_.takeWhile(_ != DoneEvent)
.collect { case ServerSentEvent(Some(data), _, _, _) =>
deserializeJsonSnake[ChatChunkResponse].apply(data)
}
}
)
