From 72ea0d258835b76ca88c4576086f3bc89447c3db Mon Sep 17 00:00:00 2001
From: adamw
Date: Thu, 5 Sep 2024 15:13:28 +0200
Subject: [PATCH] ChatProxy example

---
 build.sbt                                   | 16 ++++
 examples/src/main/resources/logback.xml     | 12 +++
 .../src/main/scala/examples/ChatProxy.scala | 89 +++++++++++++++++++
 project/Dependencies.scala                  |  2 +-
 4 files changed, 118 insertions(+), 1 deletion(-)
 create mode 100644 examples/src/main/resources/logback.xml
 create mode 100644 examples/src/main/scala/examples/ChatProxy.scala

diff --git a/build.sbt b/build.sbt
index 4e48cc2..e5e507c 100644
--- a/build.sbt
+++ b/build.sbt
@@ -23,6 +23,7 @@ lazy val allAgregates = core.projectRefs ++
   pekko.projectRefs ++
   akka.projectRefs ++
   ox.projectRefs ++
+  examples.projectRefs ++
   docs.projectRefs
 
 lazy val core = (projectMatrix in file("core"))
@@ -84,6 +85,21 @@ lazy val ox = (projectMatrix in file("streaming/ox"))
   )
   .dependsOn(core % "compile->compile;test->test")
 
+lazy val examples = (projectMatrix in file("examples"))
+  .jvmPlatform(
+    scalaVersions = scala3
+  )
+  .settings(commonSettings)
+  .settings(
+    libraryDependencies ++= Seq(
+      "com.softwaremill.sttp.tapir" %% "tapir-netty-server-sync" % "1.11.2",
+      "com.softwaremill.sttp.client4" %% "ox" % "4.0.0-M17",
+      "ch.qos.logback" % "logback-classic" % "1.5.6"
+    ),
+    publish / skip := true
+  )
+  .dependsOn(ox)
+
 val compileDocs: TaskKey[Unit] = taskKey[Unit]("Compiles docs module throwing away its output")
 compileDocs := {
   (docs.jvm(scala2.head) / mdoc).toTask(" --out target/sttp-openai-docs").value
diff --git a/examples/src/main/resources/logback.xml b/examples/src/main/resources/logback.xml
new file mode 100644
index 0000000..e6cee15
--- /dev/null
+++ b/examples/src/main/resources/logback.xml
@@ -0,0 +1,12 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<configuration>
+    <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
+        <encoder>
+            <pattern>%date [%thread] %-5level %logger{36} - %msg%n</pattern>
+        </encoder>
+    </appender>
+
+    <root level="INFO">
+        <appender-ref ref="STDOUT"/>
+    </root>
+</configuration>
\ No newline at end of file
diff --git a/examples/src/main/scala/examples/ChatProxy.scala b/examples/src/main/scala/examples/ChatProxy.scala
new file mode 100644
index 0000000..0ffaa73
--- /dev/null
+++ b/examples/src/main/scala/examples/ChatProxy.scala
@@ -0,0 +1,89 @@
+package examples
+
+import org.slf4j.{Logger, LoggerFactory}
+import ox.*
+import ox.channels.Channel
+import ox.either.orThrow
+import sttp.client4.{DefaultSyncBackend, SyncBackend}
+import sttp.openai.OpenAI
+import sttp.openai.requests.completions.chat.ChatRequestBody.{ChatBody, ChatCompletionModel}
+import sttp.openai.requests.completions.chat.message.{Content, Message}
+import sttp.openai.streaming.ox.*
+import sttp.tapir.*
+import sttp.tapir.CodecFormat.*
+import sttp.tapir.server.netty.sync.{NettySyncServer, OxStreams}
+
+import scala.annotation.tailrec
+
+//
+
+val logger: Logger = LoggerFactory.getLogger("ChatProxy")
+
+// model for sending & receiving chat messages to/from the end-user
+case class ChatMessage(message: String)
+given Codec[String, ChatMessage, TextPlain] = Codec.string.map(ChatMessage(_))(_.message)
+
+// the description of the endpoint that will be exposed: GET /chat -> WS (consuming & producing ChatMessage-s)
+val chatProxyEndpoint = infallibleEndpoint.get
+  .in("chat")
+  .out(webSocketBody[ChatMessage, TextPlain, ChatMessage, TextPlain](OxStreams))
+
+def chat(sttpBackend: SyncBackend, openAI: OpenAI)(using IO): OxStreams.Pipe[ChatMessage, ChatMessage] =
+  ox ?=> // running within a concurrency scope
+    incoming => {
+      val outgoing = Channel.bufferedDefault[ChatMessage]
+
+      // incoming - messages sent by the end-user over the web socket
+      // outgoing - messages to be sent to the end-user over the web socket
+
+      // main processing loop: receives messages from the WS and queries OpenAI with the chat's history
+      @tailrec
+      def loop(history: Vector[Message]): Unit = {
+        val nextMessage = incoming.receive()
+        val nextHistory = history :+ Message.UserMessage(content = Content.TextContent(nextMessage.message))
+
+        // querying OpenAI with the entire chat history, as each request is stateless
+        val chatRequestBody: ChatBody = ChatBody(
+          model = ChatCompletionModel.GPT4oMini,
+          messages = nextHistory
+        )
+
+        // requesting a streaming completion, so that we can get back to the user as the answer is being generated
+        val source = openAI
+          .createStreamedChatCompletion(chatRequestBody)
+          .send(sttpBackend)
+          .body
+          .orThrow // there might be an OpenAI HTTP-error
+
+        // a side-channel onto which we'll collect the whole response, to store it in the history for subsequent messages
+        val gatherResponse = Channel.bufferedDefault[ChatMessage]
+        val gatherResponseFork = fork(gatherResponse.toList.map(_.message).mkString) // collecting the response in the background
+
+        // extracting the response increments, sending them to the outgoing channel, as well as to the side-channel
+        source
+          .mapAsView(_.orThrow.choices.head.delta.content)
+          .collectAsView { case Some(msg) => ChatMessage(msg) }
+          .alsoTo(gatherResponse)
+          .pipeTo(outgoing, propagateDone = false)
+
+        val gatheredResponse = gatherResponseFork.join()
+        val nextNextHistory = nextHistory :+ Message.AssistantMessage(content = gatheredResponse)
+
+        loop(nextNextHistory)
+      }
+
+      // running the processing in the background, so that we can return the outgoing channel to the library ...
+      fork(loop(Vector.empty))
+
+      // ... so that the messages can be sent over the WS
+      outgoing
+    }
+
+object ChatProxy extends OxApp:
+  override def run(args: Vector[String])(using Ox, IO): ExitCode =
+    val openAI = new OpenAI(System.getenv("OPENAI_KEY"))
+    val sttpBackend = useCloseableInScope(DefaultSyncBackend())
+    val chatProxyServerEndpoint = chatProxyEndpoint.handleSuccess(_ => chat(sttpBackend, openAI))
+    val binding = NettySyncServer().addEndpoint(chatProxyServerEndpoint).start()
+    logger.info(s"Server started at ${binding.hostName}:${binding.port}")
+    never
diff --git a/project/Dependencies.scala b/project/Dependencies.scala
index ecddb0e..b4abd72 100644
--- a/project/Dependencies.scala
+++ b/project/Dependencies.scala
@@ -40,7 +40,7 @@ object Dependencies {
 
   val sttpClientOx = Seq(
     "com.softwaremill.sttp.client4" %% "ox" % V.sttpClient,
-    "com.softwaremill.ox" %% "core" % "0.3.5"
+    "com.softwaremill.ox" %% "core" % "0.3.6"
   )
 
   val uPickle = "com.lihaoyi" %% "upickle" % V.uPickle
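
As a companion sketch (not included in the patch itself), this is one way the new GET /chat WebSocket endpoint could be exercised by hand once the ChatProxy server is running. It assumes the server listens on NettySyncServer's default localhost:8080, uses only the JDK's built-in java.net.http.WebSocket client, and the object name ChatClientSketch is illustrative.

package examples

import java.net.URI
import java.net.http.{HttpClient, WebSocket}
import java.util.concurrent.CompletionStage

object ChatClientSketch:
  def main(args: Array[String]): Unit =
    // print every incoming text frame as it arrives - the proxy streams the answer incrementally
    val listener = new WebSocket.Listener:
      override def onText(ws: WebSocket, data: CharSequence, last: Boolean): CompletionStage[?] =
        print(data)
        ws.request(1) // ask for the next frame
        null

    // assumption: the ChatProxy server is running locally on the default port 8080
    val ws = HttpClient
      .newHttpClient()
      .newWebSocketBuilder()
      .buildAsync(URI.create("ws://localhost:8080/chat"), listener)
      .join()

    ws.sendText("Hello! What can you do?", true) // 'true' marks the end of this message
    scala.io.StdIn.readLine() // keep the JVM alive while the streamed response is printed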