Skip to content

Commit

Permalink
Add possibility to customize requests
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw committed Oct 14, 2024
1 parent 7853874 commit dec486d
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 13 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ object Main extends IOApp {

If you want to make use of other effects, you have to use `OpenAI` and pass the chosen backend directly to `request.send(backend)` function.

To customize a request when using the `OpenAISyncClient`, e.g. by adding a header, or changing the timeout (via request options), you can use the `.customizeRequest` method on the client.

Example below uses `HttpClientCatsBackend` as a backend, make sure to [add it to the dependencies](https://sttp.softwaremill.com/en/latest/backends/catseffect.html)
or use backend of your choice.

Expand Down
46 changes: 38 additions & 8 deletions core/src/main/scala/sttp/openai/OpenAISyncClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,13 @@ import sttp.openai.requests.vectorstore.file.VectorStoreFileResponseData.{

import java.io.File

class OpenAISyncClient private (authToken: String, backend: SyncBackend, closeClient: Boolean, baseUri: Uri) {
class OpenAISyncClient private (
authToken: String,
backend: SyncBackend,
closeClient: Boolean,
baseUri: Uri,
customizeRequest: CustomizeOpenAIRequest
) {

private val openAI = new OpenAI(authToken, baseUri)

Expand Down Expand Up @@ -801,20 +807,44 @@ class OpenAISyncClient private (authToken: String, backend: SyncBackend, closeCl
def deleteVectorStoreFile(vectorStoreId: String, fileId: String): DeleteVectorStoreFileResponse =
sendOrThrow(openAI.deleteVectorStoreFile(vectorStoreId, fileId))

/** Closes and releases resources of http client if was not provided explicitly, otherwise works no-op.
*/
/** Closes and releases resources of http client if was not provided explicitly, otherwise works no-op. */
def close(): Unit = if (closeClient) backend.close() else ()

/** Specifies a function, which will be applied to the generated request before sending it. If a function has been specified before, it
* will be applied before the given one.
*/
def customizeRequest(customize: CustomizeOpenAIRequest): OpenAISyncClient =
new OpenAISyncClient(authToken, backend, closeClient, baseUri, customizeRequest.andThen(customize))

private def sendOrThrow[A](request: Request[Either[OpenAIException, A]]): A =
request.send(backend).body match {
customizeRequest.apply(request).send(backend).body match {
case Right(value) => value
case Left(exception) => throw exception
}
}

object OpenAISyncClient {
def apply(authToken: String) = new OpenAISyncClient(authToken, DefaultSyncBackend(), true, OpenAIUris.OpenAIBaseUri)
def apply(authToken: String, backend: SyncBackend) = new OpenAISyncClient(authToken, backend, false, OpenAIUris.OpenAIBaseUri)
def apply(authToken: String, backend: SyncBackend, baseUrl: Uri) = new OpenAISyncClient(authToken, backend, false, baseUrl)
def apply(authToken: String, baseUrl: Uri) = new OpenAISyncClient(authToken, DefaultSyncBackend(), true, baseUrl)
def apply(authToken: String) =
new OpenAISyncClient(authToken, DefaultSyncBackend(), true, OpenAIUris.OpenAIBaseUri, CustomizeOpenAIRequest.Identity)
def apply(authToken: String, backend: SyncBackend) =
new OpenAISyncClient(authToken, backend, false, OpenAIUris.OpenAIBaseUri, CustomizeOpenAIRequest.Identity)
def apply(authToken: String, backend: SyncBackend, baseUrl: Uri) =
new OpenAISyncClient(authToken, backend, false, baseUrl, CustomizeOpenAIRequest.Identity)
def apply(authToken: String, baseUrl: Uri) =
new OpenAISyncClient(authToken, DefaultSyncBackend(), true, baseUrl, CustomizeOpenAIRequest.Identity)
}

trait CustomizeOpenAIRequest {
def apply[A](request: Request[Either[OpenAIException, A]]): Request[Either[OpenAIException, A]]

def andThen(customize: CustomizeOpenAIRequest): CustomizeOpenAIRequest = new CustomizeOpenAIRequest {
override def apply[A](request: Request[Either[OpenAIException, A]]): Request[Either[OpenAIException, A]] =
customize.apply(CustomizeOpenAIRequest.this(request))
}
}

object CustomizeOpenAIRequest {
val Identity: CustomizeOpenAIRequest = new CustomizeOpenAIRequest {
override def apply[A](request: Request[Either[OpenAIException, A]]): Request[Either[OpenAIException, A]] = request
}
}
39 changes: 34 additions & 5 deletions core/src/test/scala/sttp/openai/client/SyncClientSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import sttp.openai.OpenAIExceptions.OpenAIException
import sttp.openai.OpenAISyncClient
import sttp.openai.fixtures.ErrorFixture
import sttp.openai.requests.models.ModelsResponseData._
import sttp.openai.CustomizeOpenAIRequest
import java.util.concurrent.atomic.AtomicReference
import sttp.client4.testing.ResponseStub

class SyncClientSpec extends AnyFlatSpec with Matchers with EitherValues {
for ((statusCode, expectedError) <- ErrorFixture.testData)
Expand All @@ -21,11 +24,11 @@ class SyncClientSpec extends AnyFlatSpec with Matchers with EitherValues {
val caught = intercept[OpenAIException](syncClient.getModels)

// then
caught.getClass shouldBe expectedError.getClass
caught.message shouldBe expectedError.message
caught.cause shouldBe expectedError.cause
caught.code shouldBe expectedError.code
caught.param shouldBe expectedError.param
caught.getClass shouldBe expectedError.getClass: Unit
caught.message shouldBe expectedError.message: Unit
caught.cause shouldBe expectedError.cause: Unit
caught.code shouldBe expectedError.code: Unit
caught.param shouldBe expectedError.param: Unit
caught.`type` shouldBe expectedError.`type`
}

Expand Down Expand Up @@ -67,4 +70,30 @@ class SyncClientSpec extends AnyFlatSpec with Matchers with EitherValues {
// when & then
syncClient.getModels shouldBe deserializedModels
}

"Customizing the request" should "be additive" in {
// given
val capturedRequest = new AtomicReference[GenericRequest[_, _]](null)
val syncBackendStub = DefaultSyncBackend.stub.whenAnyRequest.thenRespondF { request =>
capturedRequest.set(request)
ResponseStub.ok(sttp.openai.fixtures.ModelsGetResponse.singleModelResponse)
}
val syncClient = OpenAISyncClient(authToken = "test-token", backend = syncBackendStub)

// when
syncClient
.customizeRequest(new CustomizeOpenAIRequest {
override def apply[A](request: Request[Either[OpenAIException, A]]): Request[Either[OpenAIException, A]] =
request.header("X-Test", "test")
})
.customizeRequest(new CustomizeOpenAIRequest {
override def apply[A](request: Request[Either[OpenAIException, A]]): Request[Either[OpenAIException, A]] =
request.header("X-Test-2", "test-2")
})
.getModels: Unit

// then
capturedRequest.get().headers.find(_.is("X-Test")).map(_.value) shouldBe Some("test"): Unit
capturedRequest.get().headers.find(_.is("X-Test-2")).map(_.value) shouldBe Some("test-2")
}
}

0 comments on commit dec486d

Please sign in to comment.