From abeb5d72e4e7a16c4da3830a59eb58862dfda69b Mon Sep 17 00:00:00 2001 From: Krzysiek Ciesielski Date: Mon, 4 Dec 2023 16:51:10 +0100 Subject: [PATCH] MaxContentLength support for Netty pt 2 (#3337) --- doc/endpoint/security.md | 20 ++++ doc/migrating.md | 4 + doc/server/netty.md | 6 +- generated-doc/out/server/netty.md | 6 +- .../scala/sttp/tapir/ztapir/ZTapirTest.scala | 2 +- .../scala/sttp/tapir/ztapir/ZTapirTest.scala | 2 +- .../server/akkagrpc/AkkaGrpcRequestBody.scala | 2 +- .../server/akkahttp/AkkaRequestBody.scala | 2 +- .../server/armeria/ArmeriaRequestBody.scala | 2 +- .../decodefailure/DecodeFailureHandler.scala | 23 ++-- .../server/interpreter/RequestBody.scala | 7 +- .../interpreter/ServerInterpreter.scala | 31 +++--- .../server/model/EndpointExtensions.scala | 36 +++++++ .../scala/sttp/tapir/server/TestUtil.scala | 2 +- .../server/finatra/FinatraRequestBody.scala | 2 +- .../server/http4s/Http4sRequestBody.scala | 2 +- .../jdkhttp/internal/JdkHttpRequestBody.scala | 2 +- .../server/netty/cats/NettyCatsServer.scala | 6 +- .../cats/NettyCatsServerInterpreter.scala | 6 +- .../cats/internal/Fs2StreamCompatible.scala | 74 +++++++++++++ .../cats/internal/NettyCatsRequestBody.scala | 33 ++++++ .../netty/internal/NettyCatsRequestBody.scala | 65 ------------ .../internal/NettyCatsToResponseBody.scala | 91 ---------------- .../netty/cats/NettyCatsServerTest.scala | 2 +- .../cats/NettyCatsTestServerInterpreter.scala | 7 +- .../netty/loom/NettyIdRequestBody.scala | 30 ++++++ .../server/netty/loom/NettyIdServer.scala | 5 +- .../netty/loom/NettyIdServerInterpreter.scala | 5 +- .../server/netty/loom/NettyIdServerTest.scala | 6 +- .../loom/NettyIdTestServerInterpreter.scala | 2 +- .../sttp/tapir/server/netty/NettyConfig.scala | 30 +----- .../server/netty/NettyFutureServer.scala | 6 +- .../netty/NettyFutureServerInterpreter.scala | 5 +- .../internal/NettyFutureRequestBody.scala | 32 ++++++ .../netty/internal/NettyRequestBody.scala | 100 ++++++++++++------ 
.../netty/internal/NettyServerHandler.scala | 64 +++-------- .../internal/NettyServerInterpreter.scala | 10 +- .../internal/NettyStreamingRequestBody.scala | 24 +++++ .../netty/internal/NettyToResponseBody.scala | 67 +++++------- .../internal/NettyToStreamsResponseBody.scala | 41 +++---- .../netty/internal/StreamCompatible.scala | 24 +++-- .../reactivestreams/FileRangePublisher.scala | 95 +++++++++++++++++ .../FileWriterSubscriber.scala | 82 ++++++++++++++ .../InputStreamPublisher.scala | 85 +++++++++++++++ .../LimitedLengthSubscriber.scala | 38 +++++++ .../reactivestreams/PromisingSubscriber.scala | 9 ++ .../reactivestreams/SimpleSubscriber.scala | 67 ++++++++++++ .../server/netty/NettyFutureServerTest.scala | 5 +- .../NettyFutureTestServerInterpreter.scala | 2 +- .../netty/internal/NettyZioRequestBody.scala | 61 ----------- .../server/netty/zio/NettyZioServer.scala | 5 +- .../netty/zio/NettyZioServerInterpreter.scala | 6 +- .../zio/internal/NettyZioRequestBody.scala | 27 +++++ .../zio/internal/ZioStreamCompatible.scala | 26 +++-- .../server/netty/zio/NettyZioServerTest.scala | 9 +- .../zio/NettyZioTestServerInterpreter.scala | 2 +- .../nima/internal/NimaRequestBody.scala | 2 +- .../pekkogrpc/PekkoGrpcRequestBody.scala | 2 +- .../server/pekkohttp/PekkoRequestBody.scala | 2 +- .../tapir/server/play/PlayRequestBody.scala | 2 +- .../tapir/server/stub/SttpRequestBody.scala | 2 +- .../tapir/server/tests/AllServerTests.scala | 2 +- .../tapir/server/tests/ServerBasicTests.scala | 48 +++++++-- .../server/tests/ServerStreamingTests.scala | 2 +- .../vertx/decoders/VertxRequestBody.scala | 2 +- .../server/ziohttp/ZioHttpRequestBody.scala | 2 +- .../server/ziohttp/ZioHttpRequestBody.scala | 2 +- .../aws/lambda/AwsRequestBody.scala | 2 +- 68 files changed, 971 insertions(+), 504 deletions(-) create mode 100644 server/core/src/main/scala/sttp/tapir/server/model/EndpointExtensions.scala create mode 100644 
server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala create mode 100644 server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala delete mode 100644 server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala delete mode 100644 server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala create mode 100644 server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyStreamingRequestBody.scala create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileRangePublisher.scala create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/LimitedLengthSubscriber.scala create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/PromisingSubscriber.scala create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala delete mode 100644 server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/internal/NettyZioRequestBody.scala create mode 100644 server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/NettyZioRequestBody.scala diff --git a/doc/endpoint/security.md b/doc/endpoint/security.md index e94f2c1fa2..cb677720bb 100644 --- a/doc/endpoint/security.md +++ b/doc/endpoint/security.md @@ 
-49,6 +49,26 @@ will show you a password prompt. Optional and multiple authentication inputs have some additional rules as to how hey map to documentation, see the ["Authentication inputs and security requirements"](../docs/openapi.md) section in the OpenAPI docs for details. +## Limiting request body length + +*Supported backends*: +Feature enabled only for Netty-based servers. More backends will be added in the near future. + +Individual endpoints can be annotated with content length limit: + +```scala mdoc:compile-only +import sttp.tapir._ +import sttp.tapir.server.model.EndpointExtensions._ + +val limitedEndpoint = endpoint.maxRequestBodyLength(maxBytes = 163484L) +``` + +The `EndpointsExtensions` utility is available in `tapir-server` core module. +Such protection prevents loading all the input data if it exceeds the limit. Instead, it will result in a `HTTP 413` +response to the client. +Please note that in case of endpoints with `streamBody` input type, the server logic receives a reference to a lazily +evaluated stream, so actual length verification will happen only when the logic performs streams processing, not earlier. + ## Next Read on about [streaming support](streaming.md). diff --git a/doc/migrating.md b/doc/migrating.md index 6d6a59210d..3d61874a17 100644 --- a/doc/migrating.md +++ b/doc/migrating.md @@ -1,5 +1,9 @@ # Migrating +## From 1.9.3 to 1.9.4 + +- `NettyConfig.defaultNoStreaming` has been removed, use `NettyConfig.default`. + ## From 1.4 to 1.5 - `badRequestOnPathErrorIfPathShapeMatches` and `badRequestOnPathInvalidIfPathShapeMatches` have been removed from `DefaultDecodeFailureHandler`. These flags were causing confusion and incosistencies caused by specifics of ZIO and Play backends. Before tapir 1.5, keeping defaults (`false` and `true` respectively for these flags) meant that some path segment decoding failures (specifically, errors - when an exception has been thrown during decoding, but not for e.g. 
enumeration mismatches) were translated to a "no-match", meaning that the next endpoint was attempted. From 1.5, tapir defaults to a 400 Bad Request response to be sent instead, on all path decoding failures. diff --git a/doc/server/netty.md b/doc/server/netty.md index ba3717d841..d3a350cdc6 100644 --- a/doc/server/netty.md +++ b/doc/server/netty.md @@ -80,7 +80,7 @@ NettyFutureServer().port(9090).addEndpoints(???) NettyFutureServer(NettyFutureServerOptions.customiseInterceptors.serverLog(None).options) // customise Netty config -NettyFutureServer(NettyConfig.defaultNoStreaming.socketBacklog(256)) +NettyFutureServer(NettyConfig.default.socketBacklog(256)) ``` ## Graceful shutdown @@ -93,9 +93,9 @@ import sttp.tapir.server.netty.NettyConfig import scala.concurrent.duration._ // adjust the waiting time to your needs -val config = NettyConfig.defaultNoStreaming.withGracefulShutdownTimeout(5.seconds) +val config = NettyConfig.default.withGracefulShutdownTimeout(5.seconds) // or if you don't want the server to wait for in-flight requests -val config2 = NettyConfig.defaultNoStreaming.noGracefulShutdown +val config2 = NettyConfig.default.noGracefulShutdown ``` ## Domain socket support diff --git a/generated-doc/out/server/netty.md b/generated-doc/out/server/netty.md index 89ff016370..08c17b5843 100644 --- a/generated-doc/out/server/netty.md +++ b/generated-doc/out/server/netty.md @@ -80,7 +80,7 @@ NettyFutureServer().port(9090).addEndpoints(???) 
NettyFutureServer(NettyFutureServerOptions.customiseInterceptors.serverLog(None).options) // customise Netty config -NettyFutureServer(NettyConfig.defaultNoStreaming.socketBacklog(256)) +NettyFutureServer(NettyConfig.default.socketBacklog(256)) ``` ## Graceful shutdown @@ -93,9 +93,9 @@ import sttp.tapir.server.netty.NettyConfig import scala.concurrent.duration._ // adjust the waiting time to your needs -val config = NettyConfig.defaultNoStreaming.withGracefulShutdownTimeout(5.seconds) +val config = NettyConfig.default.withGracefulShutdownTimeout(5.seconds) // or if you don't want the server to wait for in-flight requests -val config2 = NettyConfig.defaultNoStreaming.noGracefulShutdown +val config2 = NettyConfig.default.noGracefulShutdown ``` ## Domain socket support diff --git a/integrations/zio/src/test/scala/sttp/tapir/ztapir/ZTapirTest.scala b/integrations/zio/src/test/scala/sttp/tapir/ztapir/ZTapirTest.scala index 77dc35f54e..05fbba3cee 100644 --- a/integrations/zio/src/test/scala/sttp/tapir/ztapir/ZTapirTest.scala +++ b/integrations/zio/src/test/scala/sttp/tapir/ztapir/ZTapirTest.scala @@ -28,7 +28,7 @@ object ZTapirTest extends ZIOSpecDefault with ZTapir { private val exampleRequestBody = new RequestBody[TestEffect, RequestBodyType] { override val streams: Streams[RequestBodyType] = null.asInstanceOf[Streams[RequestBodyType]] - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): TestEffect[RawValue[R]] = ??? + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): TestEffect[RawValue[R]] = ??? override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = ??? 
} diff --git a/integrations/zio1/src/test/scala/sttp/tapir/ztapir/ZTapirTest.scala b/integrations/zio1/src/test/scala/sttp/tapir/ztapir/ZTapirTest.scala index fce0ad61ad..d5dd2f81c2 100644 --- a/integrations/zio1/src/test/scala/sttp/tapir/ztapir/ZTapirTest.scala +++ b/integrations/zio1/src/test/scala/sttp/tapir/ztapir/ZTapirTest.scala @@ -30,7 +30,7 @@ object ZTapirTest extends DefaultRunnableSpec with ZTapir { private val exampleRequestBody = new RequestBody[TestEffect, RequestBodyType] { override val streams: Streams[RequestBodyType] = null.asInstanceOf[Streams[RequestBodyType]] - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): TestEffect[RawValue[R]] = ??? + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): TestEffect[RawValue[R]] = ??? override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = ??? } diff --git a/server/akka-grpc-server/src/main/scala/sttp/tapir/server/akkagrpc/AkkaGrpcRequestBody.scala b/server/akka-grpc-server/src/main/scala/sttp/tapir/server/akkagrpc/AkkaGrpcRequestBody.scala index 315d7d8e71..9097491d57 100644 --- a/server/akka-grpc-server/src/main/scala/sttp/tapir/server/akkagrpc/AkkaGrpcRequestBody.scala +++ b/server/akka-grpc-server/src/main/scala/sttp/tapir/server/akkagrpc/AkkaGrpcRequestBody.scala @@ -22,7 +22,7 @@ private[akkagrpc] class AkkaGrpcRequestBody(serverOptions: AkkaHttpServerOptions private val grpcProtocol = GrpcProtocolNative.newReader(Identity) override val streams: AkkaStreams = AkkaStreams - override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R]): Future[RawValue[R]] = + override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] = toRawFromEntity(request, akkaRequestEntity(request), bodyType) override def toStream(request: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = ??? 
diff --git a/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaRequestBody.scala b/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaRequestBody.scala index 31231631b7..19654402f7 100644 --- a/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaRequestBody.scala +++ b/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaRequestBody.scala @@ -21,7 +21,7 @@ private[akkahttp] class AkkaRequestBody(serverOptions: AkkaHttpServerOptions)(im ec: ExecutionContext ) extends RequestBody[Future, AkkaStreams] { override val streams: AkkaStreams = AkkaStreams - override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R]): Future[RawValue[R]] = + override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] = toRawFromEntity(request, akkeRequestEntity(request), bodyType) override def toStream(request: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = { val stream = akkeRequestEntity(request).dataBytes diff --git a/server/armeria-server/src/main/scala/sttp/tapir/server/armeria/ArmeriaRequestBody.scala b/server/armeria-server/src/main/scala/sttp/tapir/server/armeria/ArmeriaRequestBody.scala index 8a41ad4b86..637513d4a5 100644 --- a/server/armeria-server/src/main/scala/sttp/tapir/server/armeria/ArmeriaRequestBody.scala +++ b/server/armeria-server/src/main/scala/sttp/tapir/server/armeria/ArmeriaRequestBody.scala @@ -29,7 +29,7 @@ private[armeria] final class ArmeriaRequestBody[F[_], S <: Streams[S]]( .asInstanceOf[streams.BinaryStream] } - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]] = { + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] = { val ctx = armeriaCtx(serverRequest) val request = ctx.request() diff --git a/server/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureHandler.scala 
b/server/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureHandler.scala index 72c3ac1e56..ab07eacc3c 100644 --- a/server/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureHandler.scala +++ b/server/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureHandler.scala @@ -10,6 +10,7 @@ import sttp.tapir.server.model.ValuedEndpointOutput import sttp.tapir.{DecodeResult, EndpointIO, EndpointInput, ValidationError, Validator, server, _} import scala.annotation.tailrec +import sttp.capabilities.StreamMaxLengthExceededException trait DecodeFailureHandler[F[_]] { @@ -122,11 +123,12 @@ object DefaultDecodeFailureHandler { case (_: EndpointIO.Header[_], _) => respondBadRequest case (fh: EndpointIO.FixedHeader[_], _: DecodeResult.Mismatch) if fh.h.name == HeaderNames.ContentType => respondUnsupportedMediaType - case (_: EndpointIO.FixedHeader[_], _) => respondBadRequest - case (_: EndpointIO.Headers[_], _) => respondBadRequest - case (_: EndpointIO.Body[_, _], _) => respondBadRequest - case (_: EndpointIO.OneOfBody[_, _], _: DecodeResult.Mismatch) => respondUnsupportedMediaType - case (_: EndpointIO.StreamBodyWrapper[_, _], _) => respondBadRequest + case (_: EndpointIO.FixedHeader[_], _) => respondBadRequest + case (_: EndpointIO.Headers[_], _) => respondBadRequest + case (_, DecodeResult.Error(_, _: StreamMaxLengthExceededException)) => respondPayloadTooLarge + case (_: EndpointIO.Body[_, _], _) => respondBadRequest + case (_: EndpointIO.OneOfBody[_, _], _: DecodeResult.Mismatch) => respondUnsupportedMediaType + case (_: EndpointIO.StreamBodyWrapper[_, _], _) => respondBadRequest // we assume that the only decode failure that might happen during path segment decoding is an error // a non-standard path decoder might return Missing/Multiple/Mismatch, but that would be indistinguishable from // a path shape mismatch @@ -143,6 +145,7 @@ object DefaultDecodeFailureHandler { } private val 
respondBadRequest = Some(onlyStatus(StatusCode.BadRequest)) private val respondUnsupportedMediaType = Some(onlyStatus(StatusCode.UnsupportedMediaType)) + private val respondPayloadTooLarge = Some(onlyStatus(StatusCode.PayloadTooLarge)) def respondNotFoundIfHasAuth( ctx: DecodeFailureContext, @@ -224,10 +227,12 @@ object DefaultDecodeFailureHandler { } .mkString(", ") ) - case Missing => Some("missing") - case Multiple(_) => Some("multiple values") - case Mismatch(_, _) => Some("value mismatch") - case _ => None + case Missing => Some("missing") + case Multiple(_) => Some("multiple values") + case Mismatch(_, _) => Some("value mismatch") + case Error(_, StreamMaxLengthExceededException(maxBytes)) => Some(s"Content length limit: $maxBytes bytes") + case _: Error => None + case _: InvalidValue => None } def combineSourceAndDetail(source: String, detail: Option[String]): String = diff --git a/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBody.scala b/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBody.scala index fc712dbe24..7503040a36 100644 --- a/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBody.scala +++ b/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBody.scala @@ -3,17 +3,12 @@ package sttp.tapir.server.interpreter import sttp.capabilities.Streams import sttp.model.Part import sttp.tapir.model.ServerRequest -import sttp.tapir.AttributeKey -import sttp.tapir.EndpointInfo import sttp.tapir.{FileRange, RawBodyType, RawPart} -case class MaxContentLength(value: Long) - trait RequestBody[F[_], S] { val streams: Streams[S] - def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]] + def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream - } case class RawValue[R](value: R, createdFiles: Seq[FileRange] = Nil) diff --git 
a/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala b/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala index c20322d604..0656677c5c 100644 --- a/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala +++ b/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala @@ -1,13 +1,14 @@ package sttp.tapir.server.interpreter +import sttp.capabilities.StreamMaxLengthExceededException import sttp.model.{Headers, StatusCode} import sttp.monad.MonadError import sttp.monad.syntax._ import sttp.tapir.internal.{Params, ParamsAsAny, RichOneOfBody} import sttp.tapir.model.ServerRequest -import sttp.tapir.server.{model, _} import sttp.tapir.server.interceptor._ -import sttp.tapir.server.model.{ServerResponse, ValuedEndpointOutput} +import sttp.tapir.server.model.{MaxContentLength, ServerResponse, ValuedEndpointOutput} +import sttp.tapir.server.{model, _} import sttp.tapir.{DecodeResult, EndpointIO, EndpointInput, TapirFile} import sttp.tapir.EndpointInfo import sttp.tapir.AttributeKey @@ -157,7 +158,7 @@ class ServerInterpreter[R, F[_], B, S]( values.bodyInputWithIndex match { case Some((Left(oneOfBodyInput), _)) => oneOfBodyInput.chooseBodyToDecode(request.contentTypeParsed) match { - case Some(Left(body)) => decodeBody(request, values, body) + case Some(Left(body)) => decodeBody(request, values, body, maxBodyLength) case Some(Right(body: EndpointIO.StreamBodyWrapper[Any, Any])) => decodeStreamingBody(request, values, body, maxBodyLength) case None => unsupportedInputMediaTypeResponse(request, oneOfBodyInput) } @@ -182,17 +183,23 @@ class ServerInterpreter[R, F[_], B, S]( private def decodeBody[RAW, T]( request: ServerRequest, values: DecodeBasicInputsResult.Values, - bodyInput: EndpointIO.Body[RAW, T] + bodyInput: EndpointIO.Body[RAW, T], + maxBodyLength: Option[Long] ): F[DecodeBasicInputsResult] = { - requestBody.toRaw(request, bodyInput.bodyType).flatMap { v => - 
bodyInput.codec.decode(v.value) match { - case DecodeResult.Value(bodyV) => (values.setBodyInputValue(bodyV): DecodeBasicInputsResult).unit - case failure: DecodeResult.Failure => - v.createdFiles - .foldLeft(monad.unit(()))((u, f) => u.flatMap(_ => deleteFile(f.file))) - .map(_ => DecodeBasicInputsResult.Failure(bodyInput, failure): DecodeBasicInputsResult) + requestBody + .toRaw(request, bodyInput.bodyType, maxBodyLength) + .flatMap { v => + bodyInput.codec.decode(v.value) match { + case DecodeResult.Value(bodyV) => (values.setBodyInputValue(bodyV): DecodeBasicInputsResult).unit + case failure: DecodeResult.Failure => + v.createdFiles + .foldLeft(monad.unit(()))((u, f) => u.flatMap(_ => deleteFile(f.file))) + .map(_ => DecodeBasicInputsResult.Failure(bodyInput, failure): DecodeBasicInputsResult) + } + } + .handleError { case e: StreamMaxLengthExceededException => + (DecodeBasicInputsResult.Failure(bodyInput, DecodeResult.Error("", e)): DecodeBasicInputsResult).unit } - } } private def unsupportedInputMediaTypeResponse( diff --git a/server/core/src/main/scala/sttp/tapir/server/model/EndpointExtensions.scala b/server/core/src/main/scala/sttp/tapir/server/model/EndpointExtensions.scala new file mode 100644 index 0000000000..d2f0b31609 --- /dev/null +++ b/server/core/src/main/scala/sttp/tapir/server/model/EndpointExtensions.scala @@ -0,0 +1,36 @@ +package sttp.tapir.server.model + +import sttp.tapir.EndpointInfoOps +import sttp.tapir.AttributeKey + +/** Can be used as an endpoint attribute. 
+ * @example + * {{{ + * endpoint.attribute(MaxContentLength.attributeKey, MaxContentLength(16384L)) + * }}} + */ +case class MaxContentLength(value: Long) extends AnyVal + +object MaxContentLength { + val attributeKey: AttributeKey[MaxContentLength] = AttributeKey[MaxContentLength] +} + +object EndpointExtensions { + + implicit class RichServerEndpoint[E <: EndpointInfoOps[_]](e: E) { + + /** Enables checks that prevent loading full request body into memory if it exceeds given limit. Otherwise causes endpoint to reply with + * HTTP 413 Payload Too Loarge. + * + * Please refer to Tapir docs to ensure which backends are supported: https://tapir.softwaremill.com/en/latest/endpoint/security.html + * @example + * {{{ + * endpoint.maxRequestBodyLength(16384L) + * }}} + * @param maxBytes + * maximum allowed size of request body in bytes. + */ + def maxRequestBodyLength(maxBytes: Long): E = + e.attribute(MaxContentLength.attributeKey, MaxContentLength(maxBytes)).asInstanceOf[E] + } +} diff --git a/server/core/src/test/scala/sttp/tapir/server/TestUtil.scala b/server/core/src/test/scala/sttp/tapir/server/TestUtil.scala index 58483a6074..4cb6f5b882 100644 --- a/server/core/src/test/scala/sttp/tapir/server/TestUtil.scala +++ b/server/core/src/test/scala/sttp/tapir/server/TestUtil.scala @@ -14,7 +14,7 @@ import scala.util.{Success, Try} object TestUtil { object TestRequestBody extends RequestBody[Id, NoStreams] { override val streams: Streams[NoStreams] = NoStreams - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): Id[RawValue[R]] = ??? + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Id[RawValue[R]] = ??? override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = ??? 
} diff --git a/server/finatra-server/src/main/scala/sttp/tapir/server/finatra/FinatraRequestBody.scala b/server/finatra-server/src/main/scala/sttp/tapir/server/finatra/FinatraRequestBody.scala index 9a6071bc14..7e3d12fa6a 100644 --- a/server/finatra-server/src/main/scala/sttp/tapir/server/finatra/FinatraRequestBody.scala +++ b/server/finatra-server/src/main/scala/sttp/tapir/server/finatra/FinatraRequestBody.scala @@ -20,7 +20,7 @@ import scala.collection.immutable.Seq class FinatraRequestBody(serverOptions: FinatraServerOptions) extends RequestBody[Future, NoStreams] { override val streams: NoStreams = NoStreams - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): Future[RawValue[R]] = { + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] = { val request = finatraRequest(serverRequest) toRaw(request, bodyType, request.content, request.charset.map(Charset.forName)) } diff --git a/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sRequestBody.scala b/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sRequestBody.scala index e7db5466a0..eba4433a89 100644 --- a/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sRequestBody.scala +++ b/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sRequestBody.scala @@ -18,7 +18,7 @@ private[http4s] class Http4sRequestBody[F[_]: Async]( serverOptions: Http4sServerOptions[F] ) extends RequestBody[F, Fs2Streams[F]] { override val streams: Fs2Streams[F] = Fs2Streams[F] - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]] = { + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] = { val r = http4sRequest(serverRequest) toRawFromStream(serverRequest, r.body, bodyType, r.charset) } diff --git 
a/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/JdkHttpRequestBody.scala b/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/JdkHttpRequestBody.scala index 824d447394..5164a29f66 100644 --- a/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/JdkHttpRequestBody.scala +++ b/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/JdkHttpRequestBody.scala @@ -18,7 +18,7 @@ private[jdkhttp] class JdkHttpRequestBody(createFile: ServerRequest => TapirFile extends RequestBody[Id, NoStreams] { override val streams: capabilities.Streams[NoStreams] = NoStreams - override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW]): RawValue[RAW] = { + override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): RawValue[RAW] = { val request = jdkHttpRequest(serverRequest) toRaw(serverRequest, bodyType, request.getRequestBody) } diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala index 340ebd76dc..b9e2b958a3 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala @@ -73,7 +73,7 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty val channelFuture = NettyBootstrap( config, - new NettyServerHandler(route, unsafeRunAsync, config.maxContentLength, channelGroup, isShuttingDown), + new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown), eventLoopGroup, socketOverride ) @@ -123,9 +123,9 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty object NettyCatsServer { def apply[F[_]: Async](dispatcher: Dispatcher[F]): NettyCatsServer[F] = - NettyCatsServer(Vector.empty, 
NettyCatsServerOptions.default(dispatcher), NettyConfig.defaultWithStreaming) + NettyCatsServer(Vector.empty, NettyCatsServerOptions.default(dispatcher), NettyConfig.default) def apply[F[_]: Async](options: NettyCatsServerOptions[F]): NettyCatsServer[F] = - NettyCatsServer(Vector.empty, options, NettyConfig.defaultWithStreaming) + NettyCatsServer(Vector.empty, options, NettyConfig.default) def apply[F[_]: Async](dispatcher: Dispatcher[F], config: NettyConfig): NettyCatsServer[F] = NettyCatsServer(Vector.empty, NettyCatsServerOptions.default(dispatcher), config) def apply[F[_]: Async](options: NettyCatsServerOptions[F], config: NettyConfig): NettyCatsServer[F] = diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala index d6409f0fd0..e9e223461c 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala @@ -2,6 +2,7 @@ package sttp.tapir.server.netty.cats import cats.effect.Async import cats.effect.std.Dispatcher +import internal.Fs2StreamCompatible import sttp.capabilities.fs2.Fs2Streams import sttp.monad.MonadError import sttp.monad.syntax._ @@ -11,6 +12,7 @@ import sttp.tapir.server.interceptor.RequestResult import sttp.tapir.server.interceptor.reject.RejectInterceptor import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, ServerInterpreter} import sttp.tapir.server.netty.internal.{NettyBodyListener, RunAsync, _} +import sttp.tapir.server.netty.cats.internal.NettyCatsRequestBody import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} trait NettyCatsServerInterpreter[F[_]] { @@ -31,8 +33,8 @@ trait NettyCatsServerInterpreter[F[_]] { val serverInterpreter = new ServerInterpreter[Fs2Streams[F], F, 
NettyResponse, Fs2Streams[F]]( FilterServerEndpoints(ses), - new NettyCatsRequestBody(createFile), - new NettyCatsToResponseBody(nettyServerOptions.dispatcher, delegate = new NettyToResponseBody), + new NettyCatsRequestBody(createFile, Fs2StreamCompatible[F](nettyServerOptions.dispatcher)), + new NettyToStreamsResponseBody(Fs2StreamCompatible[F](nettyServerOptions.dispatcher)), RejectInterceptor.disableWhenSingleEndpoint(interceptors, ses), deleteFile ) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala new file mode 100644 index 0000000000..6a2a1fe69b --- /dev/null +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala @@ -0,0 +1,74 @@ +package sttp.tapir.server.netty.cats.internal + +import io.netty.buffer.Unpooled +import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent} +import org.reactivestreams.Publisher +import sttp.tapir.FileRange +import sttp.tapir.server.netty.internal._ + +import java.io.InputStream +import cats.effect.std.Dispatcher +import sttp.capabilities.fs2.Fs2Streams +import fs2.io.file.Path +import fs2.io.file.Files +import cats.effect.kernel.Async +import fs2.io.file.Flags +import fs2.interop.reactivestreams.StreamUnicastPublisher +import cats.effect.kernel.Sync +import fs2.Chunk +import fs2.interop.reactivestreams.StreamSubscriber + +object Fs2StreamCompatible { + +private[cats] def apply[F[_]: Async](dispatcher: Dispatcher[F]): StreamCompatible[Fs2Streams[F]] = { + new StreamCompatible[Fs2Streams[F]] { + override val streams: Fs2Streams[F] = Fs2Streams[F] + + override def fromFile(fileRange: FileRange, chunkSize: Int): streams.BinaryStream = { + val path = Path.fromNioPath(fileRange.file.toPath) + fileRange.range + .flatMap(r => + r.startAndEnd.map(s => Files[F](Files.forAsync[F]).readRange(path, chunkSize, 
s._1, s._2)) + ) + .getOrElse(Files[F](Files.forAsync[F]).readAll(path, chunkSize, Flags.Read)) + } + + override def fromInputStream(is: () => InputStream, chunkSize: Int, length: Option[Long]): streams.BinaryStream = + length match { + case Some(limitedLength) => inputStreamToFs2(is, chunkSize).take(limitedLength) + case None => inputStreamToFs2(is, chunkSize) + } + + override def asPublisher(stream: fs2.Stream[F, Byte]): Publisher[HttpContent] = + // Deprecated constructor, but the proposed one does roughly the same, forcing a dedicated + // dispatcher, which results in a Resource[], which is hard to afford here + StreamUnicastPublisher( + stream.mapChunks { chunk => + val bytes: Chunk.ArraySlice[Byte] = chunk.compact + Chunk.singleton(new DefaultHttpContent(Unpooled.wrappedBuffer(bytes.values, bytes.offset, bytes.length))) + }, + dispatcher + ) + + override def fromPublisher(publisher: Publisher[HttpContent], maxBytes: Option[Long]): streams.BinaryStream = { + val stream = fs2.Stream + .eval(StreamSubscriber[F, HttpContent](bufferSize = 2)) + .flatMap(s => s.sub.stream(Sync[F].delay(publisher.subscribe(s)))) + .flatMap(httpContent => fs2.Stream.chunk(Chunk.byteBuffer(httpContent.content.nioBuffer()))) + maxBytes.map(Fs2Streams.limitBytes(stream, _)).getOrElse(stream) + } + + override def failedStream(e: => Throwable): streams.BinaryStream = + fs2.Stream.raiseError(e) + + override def emptyStream: streams.BinaryStream = + fs2.Stream.empty + + private def inputStreamToFs2(inputStream: () => InputStream, chunkSize: Int) = + fs2.io.readInputStream( + Sync[F].blocking(inputStream()), + chunkSize + ) + } + } +} diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala new file mode 100644 index 0000000000..e1a762ae70 --- /dev/null +++ 
b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala @@ -0,0 +1,33 @@ +package sttp.tapir.server.netty.cats.internal + +import cats.effect.Async +import cats.syntax.all._ +import fs2.Chunk +import fs2.io.file.{Files, Path} +import io.netty.handler.codec.http.HttpContent +import org.reactivestreams.Publisher +import sttp.capabilities.fs2.Fs2Streams +import sttp.monad.MonadError +import sttp.tapir.TapirFile +import sttp.tapir.integ.cats.effect.CatsMonadError +import sttp.tapir.model.ServerRequest +import sttp.tapir.server.netty.internal.{NettyStreamingRequestBody, StreamCompatible} + +private[cats] class NettyCatsRequestBody[F[_]: Async]( + val createFile: ServerRequest => F[TapirFile], + val streamCompatible: StreamCompatible[Fs2Streams[F]] +) extends NettyStreamingRequestBody[F, Fs2Streams[F]] { + + override implicit val monad: MonadError[F] = new CatsMonadError() + + override def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): F[Array[Byte]] = + streamCompatible.fromPublisher(publisher, maxBytes).compile.to(Chunk).map(_.toArray[Byte]) + + override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit] = + (toStream(serverRequest, maxBytes).asInstanceOf[streamCompatible.streams.BinaryStream]) + .through( + Files[F](Files.forAsync[F]).writeAll(Path.fromNioPath(file.toPath)) + ) + .compile + .drain +} diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala deleted file mode 100644 index 3c3eef738d..0000000000 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala +++ /dev/null @@ -1,65 +0,0 @@ -package sttp.tapir.server.netty.internal - -import cats.effect.{Async, Sync} -import cats.syntax.all._ -import 
org.playframework.netty.http.StreamedHttpRequest -import fs2.Chunk -import fs2.interop.reactivestreams.StreamSubscriber -import fs2.io.file.{Files, Path} -import io.netty.buffer.ByteBufUtil -import io.netty.handler.codec.http.{FullHttpRequest, HttpContent} -import sttp.capabilities.fs2.Fs2Streams -import sttp.tapir.model.ServerRequest -import sttp.tapir.server.interpreter.{RawValue, RequestBody} -import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} - -import java.io.ByteArrayInputStream -import java.nio.ByteBuffer - -private[netty] class NettyCatsRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit val monad: Async[F]) - extends RequestBody[F, Fs2Streams[F]] { - - override val streams: Fs2Streams[F] = Fs2Streams[F] - - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]] = { - - bodyType match { - case RawBodyType.StringBody(charset) => nettyRequestBytes(serverRequest).map(bs => RawValue(new String(bs, charset))) - case RawBodyType.ByteArrayBody => - nettyRequestBytes(serverRequest).map(RawValue(_)) - case RawBodyType.ByteBufferBody => - nettyRequestBytes(serverRequest).map(bs => RawValue(ByteBuffer.wrap(bs))) - case RawBodyType.InputStreamBody => - nettyRequestBytes(serverRequest).map(bs => RawValue(new ByteArrayInputStream(bs))) - case RawBodyType.InputStreamRangeBody => - nettyRequestBytes(serverRequest).map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs)))) - case RawBodyType.FileBody => - createFile(serverRequest) - .flatMap(tapirFile => { - toStream(serverRequest, maxBytes = None) // TODO - .through( - Files[F](Files.forAsync[F]).writeAll(Path.fromNioPath(tapirFile.toPath)) - ) - .compile - .drain - .map(_ => RawValue(FileRange(tapirFile), Seq(FileRange(tapirFile)))) - }) - case _: RawBodyType.MultipartBody => ??? 
- } - } - - override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = { - val nettyRequest = serverRequest.underlying.asInstanceOf[StreamedHttpRequest] - val stream = fs2.Stream - .eval(StreamSubscriber[F, HttpContent](NettyRequestBody.DefaultChunkSize)) - .flatMap(s => s.sub.stream(Sync[F].delay(nettyRequest.subscribe(s)))) - .flatMap(httpContent => fs2.Stream.chunk(Chunk.byteBuffer(httpContent.content.nioBuffer()))) - maxBytes.map(Fs2Streams.limitBytes(stream, _)).getOrElse(stream) - } - - private def nettyRequestBytes(serverRequest: ServerRequest): F[Array[Byte]] = serverRequest.underlying match { - case req: FullHttpRequest => monad.delay(ByteBufUtil.getBytes(req.content())) - case _: StreamedHttpRequest => toStream(serverRequest, maxBytes = None).compile.to(Chunk).map(_.toArray[Byte]) // TODO - case other => monad.raiseError(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass().getName()}")) - } -} diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala deleted file mode 100644 index 757053ae36..0000000000 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala +++ /dev/null @@ -1,91 +0,0 @@ -package sttp.tapir.server.netty.internal - -import cats.effect.kernel.{Async, Sync} -import cats.effect.std.Dispatcher -import fs2.Chunk -import fs2.interop.reactivestreams._ -import fs2.io.file.{Files, Flags, Path} -import io.netty.buffer.Unpooled -import io.netty.channel.ChannelHandlerContext -import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent} -import org.reactivestreams.Publisher -import sttp.capabilities.fs2.Fs2Streams -import sttp.model.HasHeaders -import sttp.tapir.server.interpreter.ToResponseBody -import sttp.tapir.server.netty.NettyResponse -import 
sttp.tapir.server.netty.NettyResponseContent._ -import sttp.tapir.{CodecFormat, RawBodyType, WebSocketBodyOutput} - -import java.io.InputStream -import java.nio.charset.Charset - -class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F], delegate: NettyToResponseBody) - extends ToResponseBody[NettyResponse, Fs2Streams[F]] { - override val streams: Fs2Streams[F] = Fs2Streams[F] - - override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): NettyResponse = { - bodyType match { - - case RawBodyType.InputStreamBody => - val stream = inputStreamToFs2(() => v) - (ctx: ChannelHandlerContext) => new ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(stream)) - - case RawBodyType.InputStreamRangeBody => - val stream = v.range - .map(range => inputStreamToFs2(v.inputStreamFromRangeStart).take(range.contentLength)) - .getOrElse(inputStreamToFs2(v.inputStream)) - (ctx: ChannelHandlerContext) => new ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(stream)) - - case RawBodyType.FileBody => - val tapirFile = v - val path = Path.fromNioPath(tapirFile.file.toPath) - val stream = tapirFile.range - .flatMap(r => - r.startAndEnd.map(s => Files[F](Files.forAsync[F]).readRange(path, NettyToResponseBody.DefaultChunkSize, s._1, s._2)) - ) - .getOrElse(Files[F](Files.forAsync[F]).readAll(path, NettyToResponseBody.DefaultChunkSize, Flags.Read)) - - (ctx: ChannelHandlerContext) => new ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(stream)) - - case _: RawBodyType.MultipartBody => throw new UnsupportedOperationException - - case _ => delegate.fromRawValue(v, headers, format, bodyType) - } - } - - private def inputStreamToFs2(inputStream: () => InputStream) = - fs2.io.readInputStream( - Sync[F].blocking(inputStream()), - NettyToResponseBody.DefaultChunkSize - ) - - private def fs2StreamToPublisher(stream: streams.BinaryStream): Publisher[HttpContent] = { 
- // Deprecated constructor, but the proposed one does roughly the same, forcing a dedicated - // dispatcher, which results in a Resource[], which is hard to afford here - StreamUnicastPublisher( - stream - .chunkLimit(NettyToResponseBody.DefaultChunkSize) - .map { chunk => - val bytes: Chunk.ArraySlice[Byte] = chunk.compact - - new DefaultHttpContent(Unpooled.wrappedBuffer(bytes.values, bytes.offset, bytes.length)) - }, - dispatcher - ) - } - - override def fromStreamValue( - v: streams.BinaryStream, - headers: HasHeaders, - format: CodecFormat, - charset: Option[Charset] - ): NettyResponse = - (ctx: ChannelHandlerContext) => { - new ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(v)) - } - - override def fromWebSocketPipe[REQ, RESP]( - pipe: streams.Pipe[REQ, RESP], - o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, Fs2Streams[F]] - ): NettyResponse = throw new UnsupportedOperationException -} diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala index dd9c9f865e..52d86ae1d4 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala @@ -35,7 +35,7 @@ class NettyCatsServerTest extends TestSuite with EitherValues { interpreter, backend, multipart = false, - maxContentLength = Some(NettyCatsTestServerInterpreter.maxContentLength) + maxContentLength = true ) .tests() ++ new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(Fs2Streams[IO])(drainFs2) ++ diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala index 521e6a342d..61be7e6f4d 100644 
--- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala @@ -24,11 +24,10 @@ class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatch routes: NonEmptyList[Route[IO]], gracefulShutdownTimeout: Option[FiniteDuration] = None ): Resource[IO, (Port, KillSwitch)] = { - val config = NettyConfig.defaultWithStreaming + val config = NettyConfig.default .eventLoopGroup(eventLoopGroup) .randomPort .withDontShutdownEventLoopGroupOnClose - .maxContentLength(NettyCatsTestServerInterpreter.maxContentLength) .noGracefulShutdown val customizedConfig = gracefulShutdownTimeout.map(config.withGracefulShutdownTimeout).getOrElse(config) @@ -39,7 +38,3 @@ class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatch .make(bind.map(b => (b.port, b.stop()))) { case (_, release) => release } } } - -object NettyCatsTestServerInterpreter { - val maxContentLength = 10000 -} diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala new file mode 100644 index 0000000000..5b1aaf8980 --- /dev/null +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala @@ -0,0 +1,30 @@ +package sttp.tapir.server.netty.loom + +import io.netty.handler.codec.http.HttpContent +import org.playframework.netty.http.StreamedHttpRequest +import org.reactivestreams.Publisher +import sttp.capabilities +import sttp.monad.MonadError +import sttp.tapir.TapirFile +import sttp.tapir.capabilities.NoStreams +import sttp.tapir.model.ServerRequest +import sttp.tapir.server.netty.internal.NettyRequestBody +import sttp.tapir.server.netty.internal.reactivestreams.{FileWriterSubscriber, SimpleSubscriber} + +private[netty] class 
NettyIdRequestBody(val createFile: ServerRequest => TapirFile) extends NettyRequestBody[Id, NoStreams] { + + override implicit val monad: MonadError[Id] = idMonad + override val streams: capabilities.Streams[NoStreams] = NoStreams + + override def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): Array[Byte] = + SimpleSubscriber.processAllBlocking(publisher, maxBytes) + + override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Unit = + serverRequest.underlying match { + case r: StreamedHttpRequest => FileWriterSubscriber.processAllBlocking(r, file.toPath, maxBytes) + case _ => () // Empty request + } + + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = + throw new UnsupportedOperationException() +} diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala index 8609b49a36..aef0b99bdc 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala @@ -95,7 +95,6 @@ case class NettyIdServer(routes: Vector[IdRoute], options: NettyIdServerOptions, new NettyServerHandler( route, unsafeRunF, - config.maxContentLength, channelGroup, isShuttingDown ), @@ -142,10 +141,10 @@ case class NettyIdServer(routes: Vector[IdRoute], options: NettyIdServerOptions, } object NettyIdServer { - def apply(): NettyIdServer = NettyIdServer(Vector.empty, NettyIdServerOptions.default, NettyConfig.defaultNoStreaming) + def apply(): NettyIdServer = NettyIdServer(Vector.empty, NettyIdServerOptions.default, NettyConfig.default) def apply(serverOptions: NettyIdServerOptions): NettyIdServer = - NettyIdServer(Vector.empty, serverOptions, NettyConfig.defaultNoStreaming) + NettyIdServer(Vector.empty, serverOptions, 
NettyConfig.default) def apply(config: NettyConfig): NettyIdServer = NettyIdServer(Vector.empty, NettyIdServerOptions.default, config) diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala index 1f89ab6c09..4cd225e0eb 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala @@ -1,7 +1,7 @@ package sttp.tapir.server.netty.loom import sttp.tapir.server.ServerEndpoint -import sttp.tapir.server.netty.internal.{NettyServerInterpreter, RunAsync} +import sttp.tapir.server.netty.internal.{NettyToResponseBody, NettyServerInterpreter, RunAsync} trait NettyIdServerInterpreter { def nettyServerOptions: NettyIdServerOptions @@ -12,7 +12,8 @@ trait NettyIdServerInterpreter { NettyServerInterpreter.toRoute[Id]( ses, nettyServerOptions.interceptors, - nettyServerOptions.createFile, + new NettyIdRequestBody(nettyServerOptions.createFile), + new NettyToResponseBody[Id], nettyServerOptions.deleteFile, new RunAsync[Id] { override def apply[T](f: => Id[T]): Unit = { diff --git a/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettyIdServerTest.scala b/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettyIdServerTest.scala index 339a30d8a4..73fe2cdd20 100644 --- a/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettyIdServerTest.scala +++ b/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettyIdServerTest.scala @@ -21,8 +21,10 @@ class NettyIdServerTest extends TestSuite with EitherValues { val createServerTest = new DefaultCreateServerTest(backend, interpreter) val sleeper: Sleeper[Id] = (duration: FiniteDuration) => Thread.sleep(duration.toMillis) - val tests = new 
AllServerTests(createServerTest, interpreter, backend, staticContent = false, multipart = false).tests() ++ - new ServerGracefulShutdownTests(createServerTest, sleeper).tests() + val tests = + new AllServerTests(createServerTest, interpreter, backend, staticContent = false, multipart = false, maxContentLength = true) + .tests() ++ + new ServerGracefulShutdownTests(createServerTest, sleeper).tests() (tests, eventLoopGroup) }) { case (_, eventLoopGroup) => diff --git a/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettyIdTestServerInterpreter.scala b/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettyIdTestServerInterpreter.scala index 8d6e940a68..c4535bbc61 100644 --- a/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettyIdTestServerInterpreter.scala +++ b/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettyIdTestServerInterpreter.scala @@ -22,7 +22,7 @@ class NettyIdTestServerInterpreter(eventLoopGroup: NioEventLoopGroup) gracefulShutdownTimeout: Option[FiniteDuration] = None ): Resource[IO, (Port, IO[Unit])] = { val config = - NettyConfig.defaultNoStreaming.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose.noGracefulShutdown + NettyConfig.default.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose.noGracefulShutdown val customizedConfig = gracefulShutdownTimeout.map(config.withGracefulShutdownTimeout).getOrElse(config) val options = NettyIdServerOptions.default val bind = IO.blocking(NettyIdServer(options, customizedConfig).addRoutes(routes.toList).start()) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala index fbadd899fe..8ae9ac0523 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala +++ 
b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala @@ -1,15 +1,14 @@ package sttp.tapir.server.netty -import org.playframework.netty.http.HttpStreamsServerHandler import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollServerSocketChannel} import io.netty.channel.kqueue.{KQueue, KQueueEventLoopGroup, KQueueServerSocketChannel} import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.channel.{ChannelHandler, ChannelPipeline, EventLoopGroup, ServerChannel} -import io.netty.handler.codec.http.{HttpObjectAggregator, HttpServerCodec} +import io.netty.handler.codec.http.HttpServerCodec import io.netty.handler.logging.LoggingHandler import io.netty.handler.ssl.SslContext -import io.netty.handler.stream.ChunkedWriteHandler +import org.playframework.netty.http.HttpStreamsServerHandler import sttp.tapir.server.netty.NettyConfig.EventLoopConfig import scala.concurrent.duration._ @@ -17,9 +16,6 @@ import scala.concurrent.duration._ /** Netty configuration, used by [[NettyFutureServer]] and other server implementations to configure the networking layer, the Netty * processing pipeline, and start & stop the server. * - * @param maxContentLength - * The max content length passed to the [[io.netty.handler.codec.http.HttpObjectAggregator]] handler. - * * @param maxConnections * The maximum number of concurrent connections allowed by the server. Any connections above this limit will be closed right after they * are opened. 
@@ -56,7 +52,6 @@ case class NettyConfig( host: String, port: Int, shutdownEventLoopGroupOnClose: Boolean, - maxContentLength: Option[Int], maxConnections: Option[Int], socketBacklog: Int, requestTimeout: Option[FiniteDuration], @@ -79,9 +74,6 @@ case class NettyConfig( def withShutdownEventLoopGroupOnClose: NettyConfig = copy(shutdownEventLoopGroupOnClose = true) def withDontShutdownEventLoopGroupOnClose: NettyConfig = copy(shutdownEventLoopGroupOnClose = false) - def maxContentLength(m: Int): NettyConfig = copy(maxContentLength = Some(m)) - def noMaxContentLength: NettyConfig = copy(maxContentLength = None) - def maxConnections(m: Int): NettyConfig = copy(maxConnections = Some(m)) def socketBacklog(s: Int): NettyConfig = copy(socketBacklog = s) @@ -113,7 +105,7 @@ case class NettyConfig( } object NettyConfig { - def defaultNoStreaming: NettyConfig = NettyConfig( + def default: NettyConfig = NettyConfig( host = "localhost", port = 8080, shutdownEventLoopGroupOnClose = true, @@ -124,25 +116,15 @@ object NettyConfig { socketTimeout = Some(60.seconds), lingerTimeout = Some(60.seconds), gracefulShutdownTimeout = Some(10.seconds), - maxContentLength = None, maxConnections = None, addLoggingHandler = false, sslContext = None, eventLoopConfig = EventLoopConfig.auto, socketConfig = NettySocketConfig.default, - initPipeline = cfg => defaultInitPipelineNoStreaming(cfg)(_, _) + initPipeline = cfg => defaultInitPipeline(cfg)(_, _) ) - def defaultInitPipelineNoStreaming(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { - cfg.sslContext.foreach(s => pipeline.addLast(s.newHandler(pipeline.channel().alloc()))) - pipeline.addLast(new HttpServerCodec()) - pipeline.addLast(new HttpObjectAggregator(cfg.maxContentLength.getOrElse(Integer.MAX_VALUE))) - pipeline.addLast(new ChunkedWriteHandler()) - pipeline.addLast(handler) - () - } - - def defaultInitPipelineStreaming(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { + def 
defaultInitPipeline(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { cfg.sslContext.foreach(s => pipeline.addLast(s.newHandler(pipeline.channel().alloc()))) pipeline.addLast(new HttpServerCodec()) pipeline.addLast(new HttpStreamsServerHandler()) @@ -151,8 +133,6 @@ object NettyConfig { () } - def defaultWithStreaming: NettyConfig = defaultNoStreaming.copy(initPipeline = cfg => defaultInitPipelineStreaming(cfg)(_, _)) - case class EventLoopConfig(initEventLoopGroup: () => EventLoopGroup, serverChannel: Class[_ <: ServerChannel]) object EventLoopConfig { diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala index 7699e20d90..0b6a4fec98 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala @@ -70,7 +70,7 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe val channelFuture = NettyBootstrap( config, - new NettyServerHandler(route, unsafeRunAsync, config.maxContentLength, channelGroup, isShuttingDown), + new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown), eventLoopGroup, socketOverride ) @@ -121,10 +121,10 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe object NettyFutureServer { def apply()(implicit ec: ExecutionContext): NettyFutureServer = - NettyFutureServer(Vector.empty, NettyFutureServerOptions.default, NettyConfig.defaultNoStreaming) + NettyFutureServer(Vector.empty, NettyFutureServerOptions.default, NettyConfig.default) def apply(serverOptions: NettyFutureServerOptions)(implicit ec: ExecutionContext): NettyFutureServer = - NettyFutureServer(Vector.empty, serverOptions, NettyConfig.defaultNoStreaming) + NettyFutureServer(Vector.empty, serverOptions, NettyConfig.default) def apply(config: 
NettyConfig)(implicit ec: ExecutionContext): NettyFutureServer = NettyFutureServer(Vector.empty, NettyFutureServerOptions.default, config) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala index 7c2ef53f7e..a2255216dd 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala @@ -3,7 +3,7 @@ package sttp.tapir.server.netty import sttp.monad.FutureMonad import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.netty.NettyFutureServerInterpreter.FutureRunAsync -import sttp.tapir.server.netty.internal.{NettyServerInterpreter, RunAsync} +import sttp.tapir.server.netty.internal.{NettyFutureRequestBody, NettyServerInterpreter, NettyToResponseBody, RunAsync} import scala.concurrent.{ExecutionContext, Future} @@ -21,7 +21,8 @@ trait NettyFutureServerInterpreter { NettyServerInterpreter.toRoute( ses, nettyServerOptions.interceptors, - nettyServerOptions.createFile, + new NettyFutureRequestBody(nettyServerOptions.createFile), + new NettyToResponseBody[Future](), nettyServerOptions.deleteFile, FutureRunAsync ) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala new file mode 100644 index 0000000000..c6dcbf0a9d --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala @@ -0,0 +1,32 @@ +package sttp.tapir.server.netty.internal + +import io.netty.handler.codec.http.HttpContent +import org.playframework.netty.http.StreamedHttpRequest +import org.reactivestreams.Publisher +import sttp.capabilities +import sttp.monad.{FutureMonad, MonadError} +import 
sttp.tapir.TapirFile +import sttp.tapir.capabilities.NoStreams +import sttp.tapir.model.ServerRequest +import sttp.tapir.server.netty.internal.reactivestreams._ + +import scala.concurrent.{ExecutionContext, Future} + +private[netty] class NettyFutureRequestBody(val createFile: ServerRequest => Future[TapirFile])(implicit ec: ExecutionContext) + extends NettyRequestBody[Future, NoStreams] { + + override val streams: capabilities.Streams[NoStreams] = NoStreams + override implicit val monad: MonadError[Future] = new FutureMonad() + + override def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): Future[Array[Byte]] = + SimpleSubscriber.processAll(publisher, maxBytes) + + override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Future[Unit] = + serverRequest.underlying match { + case r: StreamedHttpRequest => FileWriterSubscriber.processAll(r, file.toPath, maxBytes) + case _ => monad.unit(()) // Empty request + } + + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = + throw new UnsupportedOperationException() +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala index 41c99a99ec..9d1375e7a5 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala @@ -1,52 +1,82 @@ package sttp.tapir.server.netty.internal -import io.netty.buffer.{ByteBufInputStream, ByteBufUtil} -import io.netty.handler.codec.http.FullHttpRequest -import sttp.capabilities +import io.netty.buffer.Unpooled +import io.netty.handler.codec.http.{FullHttpRequest, HttpContent} +import org.playframework.netty.http.StreamedHttpRequest +import org.reactivestreams.Publisher import sttp.monad.MonadError -import 
sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} -import sttp.tapir.model.ServerRequest import sttp.monad.syntax._ -import sttp.tapir.capabilities.NoStreams -import sttp.tapir.server.interpreter.{RawValue, RequestBody} - +import sttp.tapir.model.ServerRequest +import sttp.tapir.server.interpreter.RequestBody +import sttp.tapir.RawBodyType +import sttp.tapir.TapirFile +import sttp.tapir.server.interpreter.RawValue +import sttp.tapir.FileRange +import sttp.tapir.InputStreamRange +import java.io.ByteArrayInputStream import java.nio.ByteBuffer -import java.nio.file.Files +import sttp.capabilities.Streams + +/** Common logic for processing request body in all Netty backends. It requires particular backends to implement a few operations. */ +private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody[F, S] { -class NettyRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit - monadError: MonadError[F] -) extends RequestBody[F, NoStreams] { + implicit def monad: MonadError[F] - override val streams: capabilities.Streams[NoStreams] = NoStreams + /** Backend-specific implementation for creating a file. */ + def createFile: ServerRequest => F[TapirFile] - override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW]): F[RawValue[RAW]] = { + /** Backend-specific way to process all elements emitted by a Publisher[HttpContent] into a raw array of bytes. + * + * @param publisher + * reactive publisher emitting byte chunks. + * @param maxBytes + * optional request length limit. If exceeded, the effect `F` is failed with a [[sttp.capabilities.StreamMaxLengthExceededException]] + * @return + * An effect which finishes with a single array of all collected bytes. + */ + def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): F[Array[Byte]] - /** [[ByteBufUtil.getBytes(io.netty.buffer.ByteBuf)]] copies buffer without affecting reader index of the original.
*/ - def requestContentAsByteArray = ByteBufUtil.getBytes(nettyRequest(serverRequest).content()) + /** Backend-specific way to process all elements emitted by a Publisher[HttpContent] and write their bytes into a file. + * + * @param serverRequest + * can have underlying `Publisher[HttpContent]` or an empty `FullHttpRequest` + * @param file + * an empty file where bytes should be stored. + * @param maxBytes + * optional request length limit. If exceeded, the effect `F` is failed with a [[sttp.capabilities.StreamMaxLengthExceededException]] + * @return + * an effect which finishes when all data is written to the file. + */ + def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit] + override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): F[RawValue[RAW]] = { bodyType match { - case RawBodyType.StringBody(charset) => monadError.unit(RawValue(nettyRequest(serverRequest).content().toString(charset))) - case RawBodyType.ByteArrayBody => monadError.unit(RawValue(requestContentAsByteArray)) - case RawBodyType.ByteBufferBody => monadError.unit(RawValue(ByteBuffer.wrap(requestContentAsByteArray))) - case RawBodyType.InputStreamBody => monadError.unit(RawValue(new ByteBufInputStream(nettyRequest(serverRequest).content()))) + case RawBodyType.StringBody(charset) => readAllBytes(serverRequest, maxBytes).map(bs => RawValue(new String(bs, charset))) + case RawBodyType.ByteArrayBody => + readAllBytes(serverRequest, maxBytes).map(RawValue(_)) + case RawBodyType.ByteBufferBody => + readAllBytes(serverRequest, maxBytes).map(bs => RawValue(ByteBuffer.wrap(bs))) + case RawBodyType.InputStreamBody => + // Possibly can be optimized to avoid loading all data eagerly into memory + readAllBytes(serverRequest, maxBytes).map(bs => RawValue(new ByteArrayInputStream(bs))) case RawBodyType.InputStreamRangeBody => - monadError.unit(RawValue(InputStreamRange(() => new
ByteBufInputStream(nettyRequest(serverRequest).content())))) + // Possibly can be optimized to avoid loading all data eagerly into memory + readAllBytes(serverRequest, maxBytes).map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs)))) case RawBodyType.FileBody => - createFile(serverRequest) - .map(file => { - Files.write(file.toPath, requestContentAsByteArray) - RawValue(FileRange(file), Seq(FileRange(file))) - }) - case _: RawBodyType.MultipartBody => ??? + for { + file <- createFile(serverRequest) + _ <- writeToFile(serverRequest, file, maxBytes) + } yield RawValue(FileRange(file), Seq(FileRange(file))) + case _: RawBodyType.MultipartBody => monad.error(new UnsupportedOperationException()) } } - override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = - throw new UnsupportedOperationException() - - private def nettyRequest(serverRequest: ServerRequest): FullHttpRequest = serverRequest.underlying.asInstanceOf[FullHttpRequest] -} - -private[internal] object NettyRequestBody { - val DefaultChunkSize = 8192 + private def readAllBytes(serverRequest: ServerRequest, maxBytes: Option[Long]): F[Array[Byte]] = + serverRequest.underlying match { + case r: FullHttpRequest if r.content() == Unpooled.EMPTY_BUFFER => // Empty request + monad.unit(Array.empty[Byte]) + case req: StreamedHttpRequest => + publisherToBytes(req, maxBytes) + case other => monad.error(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass().getName()}")) + } } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index bef744142d..c7e32baf9b 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -4,7 +4,6 @@ import 
com.typesafe.scalalogging.Logger import io.netty.buffer.{ByteBuf, Unpooled} import io.netty.channel._ import io.netty.channel.group.ChannelGroup -import io.netty.handler.codec.http.HttpHeaderNames.{CONNECTION, CONTENT_LENGTH} import io.netty.handler.codec.http._ import io.netty.handler.stream.{ChunkedFile, ChunkedStream} import org.playframework.netty.http.{DefaultStreamedHttpResponse, StreamedHttpRequest} @@ -35,7 +34,6 @@ import scala.util.{Failure, Success} class NettyServerHandler[F[_]]( route: Route[F], unsafeRunAsync: (() => F[ServerResponse[NettyResponse]]) => (Future[ServerResponse[NettyResponse]], () => Future[Unit]), - maxContentLength: Option[Int], channelGroup: ChannelGroup, isShuttingDown: AtomicBoolean )(implicit @@ -64,19 +62,6 @@ class NettyServerHandler[F[_]]( private val logger = Logger[NettyServerHandler[F]] - private val EntityTooLarge: FullHttpResponse = { - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, Unpooled.EMPTY_BUFFER) - res.headers().set(CONTENT_LENGTH, 0) - res - } - - private val EntityTooLargeClose: FullHttpResponse = { - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, Unpooled.EMPTY_BUFFER) - res.headers().set(CONTENT_LENGTH, 0) - res.headers().set(CONNECTION, HttpHeaderValues.CLOSE) - res - } - override def handlerAdded(ctx: ChannelHandlerContext): Unit = if (ctx.channel.isActive) { initHandler(ctx) @@ -177,16 +162,11 @@ class NettyServerHandler[F[_]]( serverResponse.handle( ctx = ctx, byteBufHandler = (channelPromise, byteBuf) => { - - if (maxContentLength.exists(_ < byteBuf.readableBytes)) - writeEntityTooLargeResponse(ctx, req) - else { - val res = new DefaultFullHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), byteBuf) - res.setHeadersFrom(serverResponse) - res.handleContentLengthAndChunkedHeaders(Option(byteBuf.readableBytes())) - res.handleCloseAndKeepAliveHeaders(req) - 
ctx.writeAndFlush(res, channelPromise).closeIfNeeded(req) - } + val res = new DefaultFullHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), byteBuf) + res.setHeadersFrom(serverResponse) + res.handleContentLengthAndChunkedHeaders(Option(byteBuf.readableBytes())) + res.handleCloseAndKeepAliveHeaders(req) + ctx.writeAndFlush(res, channelPromise).closeIfNeeded(req) }, chunkedStreamHandler = (channelPromise, chunkedStream) => { val resHeader: DefaultHttpResponse = @@ -216,6 +196,14 @@ class NettyServerHandler[F[_]]( res.setHeadersFrom(serverResponse) res.handleCloseAndKeepAliveHeaders(req) + + channelPromise.addListener((future: ChannelFuture) => { + // A reactive publisher silently closes the channel and fails the channel promise, so we need + // to listen on it and log failure details + if (!future.isSuccess()) { + logger.error("Error when streaming HTTP response", future.cause()) + } + }) ctx.writeAndFlush(res, channelPromise).closeIfNeeded(req) }, @@ -234,32 +222,6 @@ class NettyServerHandler[F[_]]( } ) - private def writeEntityTooLargeResponse(ctx: ChannelHandlerContext, req: HttpRequest): Unit = { - - if (!HttpUtil.is100ContinueExpected(req) && !HttpUtil.isKeepAlive(req)) { - val future: ChannelFuture = ctx.writeAndFlush(EntityTooLargeClose.retainedDuplicate()) - val _ = future.addListener(new ChannelFutureListener() { - override def operationComplete(future: ChannelFuture) = { - if (!future.isSuccess()) { - logger.warn("Failed to send a 413 Request Entity Too Large.", future.cause()) - } - val _ = ctx.close() - } - }) - } else { - val _ = ctx - .writeAndFlush(EntityTooLarge.retainedDuplicate()) - .addListener(new ChannelFutureListener() { - override def operationComplete(future: ChannelFuture) = { - if (!future.isSuccess()) { - logger.warn("Failed to send a 413 Request Entity Too Large.", future.cause()) - val _ = ctx.close() - } - } - }) - } - } - private implicit class RichServerNettyResponse(val r: 
ServerResponse[NettyResponse]) { def handle( ctx: ChannelHandlerContext, diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerInterpreter.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerInterpreter.scala index eb3dd3d02f..8b87f78f47 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerInterpreter.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerInterpreter.scala @@ -4,26 +4,28 @@ import sttp.monad.MonadError import sttp.monad.syntax._ import sttp.tapir.TapirFile import sttp.tapir.capabilities.NoStreams -import sttp.tapir.model.ServerRequest import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.interceptor.reject.RejectInterceptor import sttp.tapir.server.interceptor.{Interceptor, RequestResult} import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, ServerInterpreter} import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} +import sttp.tapir.server.interpreter.RequestBody +import sttp.tapir.server.interpreter.ToResponseBody object NettyServerInterpreter { def toRoute[F[_]: MonadError]( ses: List[ServerEndpoint[Any, F]], interceptors: List[Interceptor[F]], - createFile: ServerRequest => F[TapirFile], + requestBody: RequestBody[F, NoStreams], + toResponseBody: ToResponseBody[NettyResponse, NoStreams], deleteFile: TapirFile => F[Unit], runAsync: RunAsync[F] ): Route[F] = { implicit val bodyListener: BodyListener[F, NettyResponse] = new NettyBodyListener(runAsync) val serverInterpreter = new ServerInterpreter[Any, F, NettyResponse, NoStreams]( FilterServerEndpoints(ses), - new NettyRequestBody(createFile), - new NettyToResponseBody, + requestBody, + toResponseBody, RejectInterceptor.disableWhenSingleEndpoint(interceptors, ses), deleteFile ) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyStreamingRequestBody.scala 
b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyStreamingRequestBody.scala new file mode 100644 index 0000000000..cccb1a0fce --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyStreamingRequestBody.scala @@ -0,0 +1,24 @@ +package sttp.tapir.server.netty.internal + +import io.netty.buffer.Unpooled +import io.netty.handler.codec.http.FullHttpRequest +import org.playframework.netty.http.StreamedHttpRequest +import sttp.capabilities.Streams +import sttp.tapir.model.ServerRequest + +/** Common logic for processing streaming request body in all Netty backends which support streaming. */ +private[netty] trait NettyStreamingRequestBody[F[_], S <: Streams[S]] extends NettyRequestBody[F, S] { + + val streamCompatible: StreamCompatible[S] + override val streams = streamCompatible.streams + + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = + (serverRequest.underlying match { + case r: FullHttpRequest if r.content() == Unpooled.EMPTY_BUFFER => // means EmptyHttpRequest, but that class is not public + streamCompatible.emptyStream + case publisher: StreamedHttpRequest => + streamCompatible.fromPublisher(publisher, maxBytes) + case other => + streamCompatible.failedStream(new UnsupportedOperationException(s"Unexpected Netty request of type: ${other.getClass().getName()}")) + }).asInstanceOf[streams.BinaryStream] // Scala can't figure out that it's the same type as streamCompatible.streams.BinaryStream +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala index dd39ebef31..e0b2c0b35e 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala @@ -2,30 +2,28 @@ package 
sttp.tapir.server.netty.internal import io.netty.buffer.Unpooled import io.netty.channel.ChannelHandlerContext -import io.netty.handler.stream.{ChunkedFile, ChunkedStream} +import io.netty.handler.codec.http.HttpContent +import org.reactivestreams.Publisher import sttp.capabilities import sttp.model.HasHeaders +import sttp.monad.MonadError import sttp.tapir.capabilities.NoStreams import sttp.tapir.server.interpreter.ToResponseBody import sttp.tapir.server.netty.NettyResponse -import sttp.tapir.server.netty.NettyResponseContent.{ - ByteBufNettyResponseContent, - ChunkedFileNettyResponseContent, - ChunkedStreamNettyResponseContent -} +import sttp.tapir.server.netty.NettyResponseContent.{ByteBufNettyResponseContent, ReactivePublisherNettyResponseContent} +import sttp.tapir.server.netty.internal.NettyToResponseBody.DefaultChunkSize +import sttp.tapir.server.netty.internal.reactivestreams.{FileRangePublisher, InputStreamPublisher} import sttp.tapir.{CodecFormat, FileRange, InputStreamRange, RawBodyType, WebSocketBodyOutput} -import java.io.{InputStream, RandomAccessFile} +import java.io.InputStream import java.nio.ByteBuffer import java.nio.charset.Charset -private[internal] class RangedChunkedStream(raw: InputStream, length: Long) extends ChunkedStream(raw) { - - override def isEndOfInput(): Boolean = - super.isEndOfInput || transferredBytes == length -} - -class NettyToResponseBody extends ToResponseBody[NettyResponse, NoStreams] { +/** Common logic for producing response body from responses in all Netty backends that don't support streaming. These backends use our custom reactive + * Publishers to integrate responses like InputStreamBody, InputStreamRangeBody or FileBody with Netty reactive extensions. Other kinds of + * raw responses like directly available String, ByteArray or ByteBuffer can be returned without wrapping into a Publisher. 
+ */ +private[netty] class NettyToResponseBody[F[_]](implicit me: MonadError[F]) extends ToResponseBody[NettyResponse, NoStreams] { override val streams: capabilities.Streams[NoStreams] = NoStreams override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): NettyResponse = { @@ -43,43 +41,28 @@ class NettyToResponseBody extends ToResponseBody[NettyResponse, NoStreams] { (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(byteBuffer)) case RawBodyType.InputStreamBody => - (ctx: ChannelHandlerContext) => ChunkedStreamNettyResponseContent(ctx.newPromise(), wrap(v)) + (ctx: ChannelHandlerContext) => ReactivePublisherNettyResponseContent(ctx.newPromise(), wrap(v)) case RawBodyType.InputStreamRangeBody => - (ctx: ChannelHandlerContext) => ChunkedStreamNettyResponseContent(ctx.newPromise(), wrap(v)) - - case RawBodyType.FileBody => - (ctx: ChannelHandlerContext) => ChunkedFileNettyResponseContent(ctx.newPromise(), wrap(v)) + (ctx: ChannelHandlerContext) => ReactivePublisherNettyResponseContent(ctx.newPromise(), wrap(v)) + case RawBodyType.FileBody => { (ctx: ChannelHandlerContext) => + ReactivePublisherNettyResponseContent(ctx.newPromise(), wrap(v)) + } case _: RawBodyType.MultipartBody => throw new UnsupportedOperationException } } - private def wrap(streamRange: InputStreamRange): ChunkedStream = { - streamRange.range - .map(r => new RangedChunkedStream(streamRange.inputStreamFromRangeStart(), r.contentLength)) - .getOrElse(new ChunkedStream(streamRange.inputStream())) + private def wrap(streamRange: InputStreamRange): Publisher[HttpContent] = { + new InputStreamPublisher[F](streamRange, DefaultChunkSize) } - private def wrap(content: InputStream): ChunkedStream = { - new ChunkedStream(content) + private def wrap(fileRange: FileRange): Publisher[HttpContent] = { + new FileRangePublisher(fileRange, DefaultChunkSize) } - private def wrap(content: FileRange): ChunkedFile = { - val 
file = content.file - val maybeRange = for { - range <- content.range - start <- range.start - end <- range.end - } yield (start, end + NettyToResponseBody.IncludingLastOffset) - - maybeRange match { - case Some((start, end)) => { - val randomAccessFile = new RandomAccessFile(file, NettyToResponseBody.ReadOnlyAccessMode) - new ChunkedFile(randomAccessFile, start, end - start, NettyToResponseBody.DefaultChunkSize) - } - case None => new ChunkedFile(file) - } + private def wrap(content: InputStream): Publisher[HttpContent] = { + wrap(InputStreamRange(() => content, range = None)) } override def fromStreamValue( @@ -95,8 +78,6 @@ class NettyToResponseBody extends ToResponseBody[NettyResponse, NoStreams] { ): NettyResponse = throw new UnsupportedOperationException } -private[internal] object NettyToResponseBody { +private[netty] object NettyToResponseBody { val DefaultChunkSize = 8192 - val IncludingLastOffset = 1 - val ReadOnlyAccessMode = "r" } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala index 8c45f8285c..0f335a7b14 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala @@ -2,49 +2,54 @@ package sttp.tapir.server.netty.internal import io.netty.buffer.Unpooled import io.netty.channel.ChannelHandlerContext -import io.netty.handler.stream.{ChunkedFile, ChunkedStream} -import sttp.capabilities import sttp.capabilities.Streams import sttp.model.HasHeaders -import sttp.tapir.capabilities.NoStreams import sttp.tapir.server.interpreter.ToResponseBody import sttp.tapir.server.netty.NettyResponse -import sttp.tapir.server.netty.NettyResponseContent.{ - ByteBufNettyResponseContent, - ChunkedFileNettyResponseContent, - 
ChunkedStreamNettyResponseContent, - ReactivePublisherNettyResponseContent -} -import sttp.tapir.{CodecFormat, FileRange, InputStreamRange, RawBodyType, WebSocketBodyOutput} +import sttp.tapir.server.netty.internal.NettyToResponseBody._ +import sttp.tapir.server.netty.NettyResponseContent.{ByteBufNettyResponseContent, ReactivePublisherNettyResponseContent} +import sttp.tapir.{CodecFormat, RawBodyType, WebSocketBodyOutput} -import java.io.{InputStream, RandomAccessFile} import java.nio.ByteBuffer import java.nio.charset.Charset -class NettyToStreamsResponseBody[S <: Streams[S]](delegate: NettyToResponseBody, streamCompatible: StreamCompatible[S]) - extends ToResponseBody[NettyResponse, S] { +/** Common logic for producing response body in all Netty backends that support streaming. These backends use streaming libraries + * like fs2 or zio-streams to obtain reactive Publishers representing responses like InputStreamBody, InputStreamRangeBody or FileBody. + * Other kinds of raw responses like directly available String, ByteArray or ByteBuffer can be returned without wrapping into a Publisher. 
+ */ +private[netty] class NettyToStreamsResponseBody[S <: Streams[S]](streamCompatible: StreamCompatible[S]) extends ToResponseBody[NettyResponse, S] { override val streams: S = streamCompatible.streams override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): NettyResponse = { bodyType match { + case RawBodyType.StringBody(charset) => + val bytes = v.asInstanceOf[String].getBytes(charset) + (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(bytes)) + + case RawBodyType.ByteArrayBody => + val bytes = v.asInstanceOf[Array[Byte]] + (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(bytes)) + + case RawBodyType.ByteBufferBody => + val byteBuffer = v.asInstanceOf[ByteBuffer] + (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(byteBuffer)) + case RawBodyType.InputStreamBody => (ctx: ChannelHandlerContext) => - new ReactivePublisherNettyResponseContent(ctx.newPromise(), streamCompatible.publisherFromInputStream(() => v, length = None)) + new ReactivePublisherNettyResponseContent(ctx.newPromise(), streamCompatible.publisherFromInputStream(() => v, DefaultChunkSize, length = None)) case RawBodyType.InputStreamRangeBody => (ctx: ChannelHandlerContext) => new ReactivePublisherNettyResponseContent( ctx.newPromise(), - streamCompatible.publisherFromInputStream(v.inputStreamFromRangeStart, length = v.range.map(_.contentLength)) + streamCompatible.publisherFromInputStream(v.inputStreamFromRangeStart, DefaultChunkSize, length = v.range.map(_.contentLength)) ) case RawBodyType.FileBody => - (ctx: ChannelHandlerContext) => new ReactivePublisherNettyResponseContent(ctx.newPromise(), streamCompatible.publisherFromFile(v)) + (ctx: ChannelHandlerContext) => new ReactivePublisherNettyResponseContent(ctx.newPromise(), streamCompatible.publisherFromFile(v, DefaultChunkSize)) case _: 
RawBodyType.MultipartBody => throw new UnsupportedOperationException - - case _ => delegate.fromRawValue(v, headers, format, bodyType) } } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala index a64d92dfca..6d5da177bd 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala @@ -3,20 +3,28 @@ package sttp.tapir.server.netty.internal import io.netty.handler.codec.http.HttpContent import org.reactivestreams.Publisher import sttp.capabilities.Streams -import sttp.tapir.{FileRange, TapirFile} +import sttp.tapir.FileRange import java.io.InputStream +/** + * Operations on streams that have to be implemented for each streaming integration (fs2, zio-streams, etc) used by Netty backends. + * This includes conversions like building a stream from a `File`, an `InputStream`, or a reactive `Publisher`. + * We also need implementation of a failed (errored) stream, as well as an empty stream (for handling empty requests). 
+ */ private[netty] trait StreamCompatible[S <: Streams[S]] { val streams: S - def fromFile(file: FileRange): streams.BinaryStream - def fromInputStream(is: () => InputStream, length: Option[Long]): streams.BinaryStream - def fromNettyStream(s: Publisher[HttpContent]): streams.BinaryStream + def fromFile(file: FileRange, chunkSize: Int): streams.BinaryStream + def fromInputStream(is: () => InputStream, chunkSize: Int, length: Option[Long]): streams.BinaryStream + def fromPublisher(publisher: Publisher[HttpContent], maxBytes: Option[Long]): streams.BinaryStream def asPublisher(s: streams.BinaryStream): Publisher[HttpContent] - def publisherFromFile(file: FileRange): Publisher[HttpContent] = - asPublisher(fromFile(file)) + def failedStream(e: => Throwable): streams.BinaryStream + def emptyStream: streams.BinaryStream - def publisherFromInputStream(is: () => InputStream, length: Option[Long]): Publisher[HttpContent] = - asPublisher(fromInputStream(is, length)) + def publisherFromFile(file: FileRange, chunkSize: Int): Publisher[HttpContent] = + asPublisher(fromFile(file, chunkSize)) + + def publisherFromInputStream(is: () => InputStream, chunkSize: Int, length: Option[Long]): Publisher[HttpContent] = + asPublisher(fromInputStream(is, chunkSize, length)) } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileRangePublisher.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileRangePublisher.scala new file mode 100644 index 0000000000..9e84324b63 --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileRangePublisher.scala @@ -0,0 +1,95 @@ +package sttp.tapir.server.netty.internal.reactivestreams + +import io.netty.buffer.Unpooled +import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent} +import org.reactivestreams.{Publisher, Subscriber, Subscription} +import sttp.tapir.FileRange + +import java.nio.ByteBuffer +import 
java.nio.channels.{AsynchronousFileChannel, CompletionHandler} +import java.nio.file.StandardOpenOption +import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} + +/** A Reactive Streams publisher which emits chunks of HttpContent read from a given file. + */ +class FileRangePublisher(fileRange: FileRange, chunkSize: Int) extends Publisher[HttpContent] { + override def subscribe(subscriber: Subscriber[_ >: HttpContent]): Unit = { + if (subscriber == null) throw new NullPointerException("Subscriber cannot be null") + val subscription = new FileRangeSubscription(subscriber, fileRange, chunkSize) + subscriber.onSubscribe(subscription) + } + + private class FileRangeSubscription(subscriber: Subscriber[_ >: HttpContent], fileRange: FileRange, chunkSize: Int) extends Subscription { + private lazy val channel: AsynchronousFileChannel = AsynchronousFileChannel.open(fileRange.file.toPath(), StandardOpenOption.READ) + private val demand = new AtomicLong(0L) + private val position = new AtomicLong(fileRange.range.flatMap(_.start).getOrElse(0L)) + private val buffer: ByteBuffer = ByteBuffer.allocate(chunkSize) + private val isCompleted = new AtomicBoolean(false) + private val readingInProgress = new AtomicBoolean(false) + + override def request(n: Long): Unit = { + if (n <= 0) subscriber.onError(new IllegalArgumentException("§3.9: n must be greater than 0")) + else { + demand.addAndGet(n) + readNextChunkIfNeeded() + } + } + + /** Can be called multiple times by request(n), or concurrently by channel.read() callback. The readingInProgress check ensures that + * calls are serialized. A channel.read() operation will be started only if another isn't running. This method is non-blocking. 
+ */ + private def readNextChunkIfNeeded(): Unit = { + if (demand.get() > 0 && !isCompleted.get() && readingInProgress.compareAndSet(false, true)) { + val pos = position.get() + val expectedBytes: Int = fileRange.range.flatMap(_.end) match { + case Some(endPos) if pos + chunkSize > endPos => (endPos - pos + 1).toInt + case _ => chunkSize + } + buffer.clear() + // Async call, so readNextChunkIfNeeded() finishes immediately after firing this + channel.read( + buffer, + pos, + null, + new CompletionHandler[Integer, Void] { + override def completed(bytesRead: Integer, attachment: Void): Unit = { + if (bytesRead == -1) { + cancel() + subscriber.onComplete() + } else { + val bytesToRead = Math.min(bytesRead, expectedBytes) + // The buffer is modified only by one thread at a time, because only one channel.read() + // is running at a time, and because buffer.clear() calls before the read are guarded + // by readingInProgress.compareAndSet. + buffer.flip() + val bytes = new Array[Byte](bytesToRead) + buffer.get(bytes) + position.addAndGet(bytesToRead.toLong) + subscriber.onNext(new DefaultHttpContent(Unpooled.wrappedBuffer(bytes))) + if (bytesToRead < expectedBytes) { + cancel() + subscriber.onComplete() + } else { + demand.decrementAndGet() + readingInProgress.set(false) + // Either this call, or a call from request(n) will win the race to + // actually start a new read. 
+ readNextChunkIfNeeded() + } + } + } + + override def failed(exc: Throwable, attachment: Void): Unit = { + subscriber.onError(exc) + } + } + ) + } + } + + override def cancel(): Unit = { + isCompleted.set(true) + channel.close() + } + } +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala new file mode 100644 index 0000000000..e7c4ca0479 --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala @@ -0,0 +1,82 @@ +package sttp.tapir.server.netty.internal.reactivestreams + +import io.netty.handler.codec.http.HttpContent +import org.reactivestreams.{Publisher, Subscription} + +import java.nio.channels.AsynchronousFileChannel +import java.nio.file.{Path, StandardOpenOption} +import scala.concurrent.{Future, Promise} +import java.util.concurrent.LinkedBlockingQueue + +/** A Reactive Streams subscriber which receives chunks of bytes and writes them to a file. 
+ */ +class FileWriterSubscriber(path: Path) extends PromisingSubscriber[Unit, HttpContent] { + private var subscription: Subscription = _ + + /** JDK interface to write asynchronously to a file */ + private var fileChannel: AsynchronousFileChannel = _ + + /** Current position in the file */ + @volatile private var position: Long = 0 + + /** Used to signal completion, so that external code can represent writing to a file as Future[Unit] */ + private val resultPromise = Promise[Unit]() + + /** An alternative way to signal completion, so that non-effectful servers can await on the response (like netty-loom) */ + private val resultBlockingQueue = new LinkedBlockingQueue[Either[Throwable, Unit]]() + + override def future: Future[Unit] = resultPromise.future + private def waitForResultBlocking(): Either[Throwable, Unit] = resultBlockingQueue.take() + + override def onSubscribe(s: Subscription): Unit = { + this.subscription = s + fileChannel = AsynchronousFileChannel.open(path, StandardOpenOption.WRITE, StandardOpenOption.CREATE) + s.request(1) + } + + override def onNext(httpContent: HttpContent): Unit = { + val byteBuffer = httpContent.content().nioBuffer() + fileChannel.write( + byteBuffer, + position, + (), + new java.nio.channels.CompletionHandler[Integer, Unit] { + override def completed(result: Integer, attachment: Unit): Unit = { + position += result + subscription.request(1) + } + + override def failed(exc: Throwable, attachment: Unit): Unit = { + subscription.cancel() + onError(exc) + } + } + ) + } + + override def onError(t: Throwable): Unit = { + fileChannel.close() + resultBlockingQueue.add(Left(t)) + resultPromise.failure(t) + } + + override def onComplete(): Unit = { + fileChannel.close() + val _ = resultBlockingQueue.add(Right(())) + resultPromise.success(()) + } +} + +object FileWriterSubscriber { + def processAll(publisher: Publisher[HttpContent], path: Path, maxBytes: Option[Long]): Future[Unit] = { + val subscriber = new FileWriterSubscriber(path) + 
publisher.subscribe(maxBytes.map(new LimitedLengthSubscriber(_, subscriber)).getOrElse(subscriber)) + subscriber.future + } + + def processAllBlocking(publisher: Publisher[HttpContent], path: Path, maxBytes: Option[Long]): Unit = { + val subscriber = new FileWriterSubscriber(path) + publisher.subscribe(maxBytes.map(new LimitedLengthSubscriber(_, subscriber)).getOrElse(subscriber)) + subscriber.waitForResultBlocking().left.foreach(e => throw e) + } +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala new file mode 100644 index 0000000000..7f16c0a108 --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala @@ -0,0 +1,85 @@ +package sttp.tapir.server.netty.internal.reactivestreams + +import io.netty.buffer.Unpooled +import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent} +import org.reactivestreams.{Publisher, Subscriber, Subscription} +import sttp.tapir.InputStreamRange + +import java.io.InputStream +import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} +import scala.util.Try +import sttp.monad.MonadError +import sttp.monad.syntax._ + +class InputStreamPublisher[F[_]](range: InputStreamRange, chunkSize: Int)(implicit monad: MonadError[F]) extends Publisher[HttpContent] { + override def subscribe(subscriber: Subscriber[_ >: HttpContent]): Unit = { + if (subscriber == null) throw new NullPointerException("Subscriber cannot be null") + val subscription = new InputStreamSubscription(subscriber, range, chunkSize) + subscriber.onSubscribe(subscription) + } + + private class InputStreamSubscription(subscriber: Subscriber[_ >: HttpContent], range: InputStreamRange, chunkSize: Int) + extends Subscription { + private lazy val stream: InputStream = range.inputStreamFromRangeStart() + private val 
demand = new AtomicLong(0L) + private val position = new AtomicLong(range.range.flatMap(_.start).getOrElse(0L)) + private val isCompleted = new AtomicBoolean(false) + private val readingInProgress = new AtomicBoolean(false) + + override def request(n: Long): Unit = { + if (n <= 0) subscriber.onError(new IllegalArgumentException("§3.9: n must be greater than 0")) + else { + demand.addAndGet(n) + readNextChunkIfNeeded() + } + } + + /** Non-blocking by itself, starts an asynchronous operation with blocking stream.readNBytes. Can be called multiple times by + * request(n), or concurrently by onComplete callback. The readingInProgress check ensures that calls are serialized. A + * stream.readNBytes operation will be started only if another isn't running. + */ + private def readNextChunkIfNeeded(): Unit = { + if (demand.get() > 0 && !isCompleted.get() && readingInProgress.compareAndSet(false, true)) { + val pos = position.get() + val expectedBytes: Int = range.range.flatMap(_.end) match { + case Some(endPos) if pos + chunkSize > endPos => (endPos - pos + 1).toInt + case _ => chunkSize + } + + val _ = monad + .blocking( + stream.readNBytes(expectedBytes) + ) + .map { bytes => + val bytesRead = bytes.length + if (bytesRead == 0) { + cancel() + subscriber.onComplete() + } else { + position.addAndGet(bytesRead.toLong) + subscriber.onNext(new DefaultHttpContent(Unpooled.wrappedBuffer(bytes))) + if (bytesRead < expectedBytes) { + cancel() + subscriber.onComplete() + } else { + demand.decrementAndGet() + readingInProgress.set(false) + readNextChunkIfNeeded() + } + } + } + .handleError { + case e => { + val _ = Try(stream.close()) + monad.unit(subscriber.onError(e)) + } + } + } + } + + override def cancel(): Unit = { + isCompleted.set(true) + val _ = Try(stream.close()) + } + } +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/LimitedLengthSubscriber.scala 
b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/LimitedLengthSubscriber.scala new file mode 100644 index 0000000000..7a670ebf9e --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/LimitedLengthSubscriber.scala @@ -0,0 +1,38 @@ +package sttp.tapir.server.netty.internal.reactivestreams + +import io.netty.handler.codec.http.HttpContent +import org.reactivestreams.{Subscriber, Subscription} +import sttp.capabilities.StreamMaxLengthExceededException + +import scala.collection.JavaConverters._ + +// based on org.asynchttpclient.request.body.generator.ReactiveStreamsBodyGenerator.SimpleSubscriber +private[netty] class LimitedLengthSubscriber[R](maxBytes: Long, delegate: Subscriber[HttpContent]) extends Subscriber[HttpContent] { + private var subscription: Subscription = _ + private var bytesReadSoFar = 0L + + override def onSubscribe(s: Subscription): Unit = { + subscription = s + delegate.onSubscribe(s) + } + + override def onNext(content: HttpContent): Unit = { + bytesReadSoFar = bytesReadSoFar + content.content.readableBytes() + if (bytesReadSoFar > maxBytes) { + subscription.cancel() + onError(StreamMaxLengthExceededException(maxBytes)) + subscription = null + } else + delegate.onNext(content) + } + + override def onError(t: Throwable): Unit = { + if (subscription != null) + delegate.onError(t) + } + + override def onComplete(): Unit = { + if (subscription != null) + delegate.onComplete() + } +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/PromisingSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/PromisingSubscriber.scala new file mode 100644 index 0000000000..5c0bef2545 --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/PromisingSubscriber.scala @@ -0,0 +1,9 @@ +package sttp.tapir.server.netty.internal.reactivestreams + 
+import org.reactivestreams.Subscriber + +import scala.concurrent.Future + +trait PromisingSubscriber[R, A] extends Subscriber[A] { + def future: Future[R] +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala new file mode 100644 index 0000000000..40138b3614 --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala @@ -0,0 +1,67 @@ +package sttp.tapir.server.netty.internal.reactivestreams + +import io.netty.buffer.ByteBufUtil +import io.netty.handler.codec.http.HttpContent +import org.reactivestreams.{Publisher, Subscription} + +import java.util.concurrent.ConcurrentLinkedQueue +import scala.collection.JavaConverters._ +import scala.concurrent.{Future, Promise} +import java.util.concurrent.LinkedBlockingQueue + +private[netty] class SimpleSubscriber() extends PromisingSubscriber[Array[Byte], HttpContent] { + private var subscription: Subscription = _ + private val chunks = new ConcurrentLinkedQueue[Array[Byte]]() + private var size = 0 + private val resultPromise = Promise[Array[Byte]]() + private val resultBlockingQueue = new LinkedBlockingQueue[Either[Throwable, Array[Byte]]]() + + override def future: Future[Array[Byte]] = resultPromise.future + def resultBlocking(): Either[Throwable, Array[Byte]] = resultBlockingQueue.take() + + override def onSubscribe(s: Subscription): Unit = { + subscription = s + s.request(1) + } + + override def onNext(content: HttpContent): Unit = { + val a = ByteBufUtil.getBytes(content.content()) + size += a.length + chunks.add(a) + subscription.request(1) + } + + override def onError(t: Throwable): Unit = { + chunks.clear() + resultBlockingQueue.add(Left(t)) + resultPromise.failure(t) + } + + override def onComplete(): Unit = { + val result = new Array[Byte](size) + val _ = 
chunks.asScala.foldLeft(0)((currentPosition, array) => { + System.arraycopy(array, 0, result, currentPosition, array.length) + currentPosition + array.length + }) + chunks.clear() + resultBlockingQueue.add(Right(result)) + resultPromise.success(result) + } +} + +object SimpleSubscriber { + def processAll(publisher: Publisher[HttpContent], maxBytes: Option[Long]): Future[Array[Byte]] = { + val subscriber = new SimpleSubscriber() + publisher.subscribe(maxBytes.map(max => new LimitedLengthSubscriber(max, subscriber)).getOrElse(subscriber)) + subscriber.future + } + + def processAllBlocking(publisher: Publisher[HttpContent], maxBytes: Option[Long]): Array[Byte] = { + val subscriber = new SimpleSubscriber() + publisher.subscribe(maxBytes.map(max => new LimitedLengthSubscriber(max, subscriber)).getOrElse(subscriber)) + subscriber.resultBlocking() match { + case Right(result) => result + case Left(e) => throw e + } + } +} diff --git a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala index b7d86c5e4a..5125a86532 100644 --- a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala +++ b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala @@ -21,8 +21,9 @@ class NettyFutureServerTest extends TestSuite with EitherValues { val interpreter = new NettyFutureTestServerInterpreter(eventLoopGroup) val createServerTest = new DefaultCreateServerTest(backend, interpreter) - val tests = new AllServerTests(createServerTest, interpreter, backend, multipart = false).tests() ++ - new ServerGracefulShutdownTests(createServerTest, Sleeper.futureSleeper).tests() + val tests = + new AllServerTests(createServerTest, interpreter, backend, multipart = false, maxContentLength = true).tests() ++ + new ServerGracefulShutdownTests(createServerTest, Sleeper.futureSleeper).tests() (tests, eventLoopGroup) }) { 
case (_, eventLoopGroup) => diff --git a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala index 0073136a5e..7eb0867ff1 100644 --- a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala +++ b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala @@ -24,7 +24,7 @@ class NettyFutureTestServerInterpreter(eventLoopGroup: NioEventLoopGroup)(implic gracefulShutdownTimeout: Option[FiniteDuration] = None ): Resource[IO, (Port, KillSwitch)] = { val config = - NettyConfig.defaultNoStreaming + NettyConfig.default .eventLoopGroup(eventLoopGroup) .randomPort .withDontShutdownEventLoopGroupOnClose diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/internal/NettyZioRequestBody.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/internal/NettyZioRequestBody.scala deleted file mode 100644 index a28da7007a..0000000000 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/internal/NettyZioRequestBody.scala +++ /dev/null @@ -1,61 +0,0 @@ -package sttp.tapir.server.netty.internal - -import org.playframework.netty.http.StreamedHttpRequest -import io.netty.buffer.ByteBufUtil -import io.netty.handler.codec.http.FullHttpRequest -import sttp.capabilities.zio.ZioStreams -import sttp.tapir.RawBodyType._ -import sttp.tapir.model.ServerRequest -import sttp.tapir.server.interpreter.{RawValue, RequestBody} -import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} -import zio.interop.reactivestreams._ -import zio.stream.{ZStream, _} -import zio.{Chunk, RIO, ZIO} - -import java.io.ByteArrayInputStream -import java.nio.ByteBuffer - -private[netty] class NettyZioRequestBody[Env](createFile: ServerRequest => RIO[Env, TapirFile]) - extends RequestBody[RIO[Env, *], ZioStreams] { - - override val 
streams: ZioStreams = ZioStreams - - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): RIO[Env, RawValue[R]] = { - bodyType match { - case StringBody(charset) => nettyRequestBytes(serverRequest).map(bs => RawValue(new String(bs, charset))) - - case ByteArrayBody => - nettyRequestBytes(serverRequest).map(RawValue(_)) - case ByteBufferBody => - nettyRequestBytes(serverRequest).map(bs => RawValue(ByteBuffer.wrap(bs))) - case InputStreamBody => - nettyRequestBytes(serverRequest).map(bs => RawValue(new ByteArrayInputStream(bs))) - case InputStreamRangeBody => - nettyRequestBytes(serverRequest).map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs)))) - case FileBody => - createFile(serverRequest) - .flatMap(tapirFile => { - toStream(serverRequest, maxBytes = None) // TODO createFile() should also have maxBytes - .run(ZSink.fromFile(tapirFile)) - .map(_ => RawValue(FileRange(tapirFile), Seq(FileRange(tapirFile)))) - }) - case MultipartBody(partTypes, defaultType) => - throw new java.lang.UnsupportedOperationException() - } - } - - override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = { - val stream = serverRequest.underlying - .asInstanceOf[StreamedHttpRequest] - .toZIOStream() - .flatMap(httpContent => ZStream.fromChunk(Chunk.fromByteBuffer(httpContent.content.nioBuffer()))) - maxBytes.map(ZioStreams.limitBytes(stream, _)).getOrElse(stream) - } - - private def nettyRequestBytes(serverRequest: ServerRequest): RIO[Env, Array[Byte]] = serverRequest.underlying match { - case req: FullHttpRequest => ZIO.succeed(ByteBufUtil.getBytes(req.content())) - case _: StreamedHttpRequest => toStream(serverRequest, maxBytes = None).run(ZSink.collectAll[Byte]).map(_.toArray) // TODO - case other => ZIO.fail(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass().getName()}")) - } - -} diff --git 
a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala index 968e0d1013..be553f0ccf 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala @@ -89,7 +89,6 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: new NettyServerHandler[RIO[R, *]]( route, unsafeRunAsync(runtime), - config.maxContentLength, channelGroup, isShuttingDown ), @@ -141,9 +140,9 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: } object NettyZioServer { - def apply[R](): NettyZioServer[R] = NettyZioServer(Vector.empty, NettyZioServerOptions.default[R], NettyConfig.defaultWithStreaming) + def apply[R](): NettyZioServer[R] = NettyZioServer(Vector.empty, NettyZioServerOptions.default[R], NettyConfig.default) def apply[R](options: NettyZioServerOptions[R]): NettyZioServer[R] = - NettyZioServer(Vector.empty, options, NettyConfig.defaultWithStreaming) + NettyZioServer(Vector.empty, options, NettyConfig.default) def apply[R](config: NettyConfig): NettyZioServer[R] = NettyZioServer(Vector.empty, NettyZioServerOptions.default[R], config) def apply[R](options: NettyZioServerOptions[R], config: NettyConfig): NettyZioServer[R] = NettyZioServer(Vector.empty, options, config) } diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServerInterpreter.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServerInterpreter.scala index 56a16bc622..64a91e3cff 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServerInterpreter.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServerInterpreter.scala @@ -6,7 +6,7 @@ import 
sttp.tapir.server.interceptor.reject.RejectInterceptor import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, ServerInterpreter} import sttp.tapir.server.netty.internal.{NettyBodyListener, RunAsync, _} import sttp.tapir.server.netty.zio.NettyZioServerInterpreter.ZioRunAsync -import sttp.tapir.server.netty.zio.internal.ZioStreamCompatible +import sttp.tapir.server.netty.zio.internal.{NettyZioRequestBody, ZioStreamCompatible} import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} import sttp.tapir.ztapir.{RIOMonadError, ZServerEndpoint, _} import zio._ @@ -26,8 +26,8 @@ trait NettyZioServerInterpreter[R] { implicit val bodyListener: BodyListener[F, NettyResponse] = new NettyBodyListener(runAsync) val serverInterpreter = new ServerInterpreter[ZioStreams, F, NettyResponse, ZioStreams]( FilterServerEndpoints(widenedSes), - new NettyZioRequestBody(widenedServerOptions.createFile), - new NettyToStreamsResponseBody[ZioStreams](delegate = new NettyToResponseBody(), ZioStreamCompatible(runtime)), + new NettyZioRequestBody(widenedServerOptions.createFile, ZioStreamCompatible(runtime)), + new NettyToStreamsResponseBody[ZioStreams](ZioStreamCompatible(runtime)), RejectInterceptor.disableWhenSingleEndpoint(widenedServerOptions.interceptors, widenedSes), widenedServerOptions.deleteFile ) diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/NettyZioRequestBody.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/NettyZioRequestBody.scala new file mode 100644 index 0000000000..2e551cad81 --- /dev/null +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/NettyZioRequestBody.scala @@ -0,0 +1,27 @@ +package sttp.tapir.server.netty.zio.internal + +import io.netty.handler.codec.http.HttpContent +import org.reactivestreams.Publisher +import sttp.capabilities.zio.ZioStreams +import sttp.monad.MonadError +import sttp.tapir.TapirFile +import 
sttp.tapir.model.ServerRequest +import sttp.tapir.server.netty.internal.{NettyStreamingRequestBody, StreamCompatible} +import sttp.tapir.ztapir.RIOMonadError +import zio.RIO +import zio.stream._ + +private[zio] class NettyZioRequestBody[Env]( + val createFile: ServerRequest => RIO[Env, TapirFile], + val streamCompatible: StreamCompatible[ZioStreams] +) extends NettyStreamingRequestBody[RIO[Env, *], ZioStreams] { + + override val streams: ZioStreams = ZioStreams + override implicit val monad: MonadError[RIO[Env, *]] = new RIOMonadError[Env] + + override def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): RIO[Env, Array[Byte]] = + streamCompatible.fromPublisher(publisher, maxBytes).run(ZSink.collectAll[Byte]).map(_.toArray) + + override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): RIO[Env, Unit] = + toStream(serverRequest, maxBytes).run(ZSink.fromFile(file)).map(_ => ()) +} diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala index c5ecd41fde..7ec3ae4fa9 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala @@ -18,12 +18,12 @@ private[zio] object ZioStreamCompatible { new StreamCompatible[ZioStreams] { override val streams: ZioStreams = ZioStreams - override def fromFile(fileRange: FileRange): streams.BinaryStream = { + override def fromFile(fileRange: FileRange, chunkSize: Int): streams.BinaryStream = { fileRange.range .flatMap(r => r.startAndEnd.map { case (fStart, _) => ZStream - .fromPath(fileRange.file.toPath) + .fromPath(fileRange.file.toPath, chunkSize) .drop(fStart.toInt) .take(r.contentLength) } @@ -33,10 +33,10 @@ private[zio] object ZioStreamCompatible { ) 
} - override def fromInputStream(is: () => InputStream, length: Option[Long]): streams.BinaryStream = + override def fromInputStream(is: () => InputStream, chunkSize: Int, length: Option[Long]): streams.BinaryStream = length match { - case Some(limitedLength) => ZStream.fromInputStream(is()).take(limitedLength.toInt) - case None => ZStream.fromInputStream(is()) + case Some(limitedLength) => ZStream.fromInputStream(is(), chunkSize).take(limitedLength.toInt) + case None => ZStream.fromInputStream(is(), chunkSize) } override def asPublisher(stream: Stream[Throwable, Byte]): Publisher[HttpContent] = @@ -46,8 +46,20 @@ private[zio] object ZioStreamCompatible { .getOrThrowFiberFailure() ) - override def fromNettyStream(publisher: Publisher[HttpContent]): Stream[Throwable, Byte] = - publisher.toZIOStream().mapConcatChunk(httpContent => Chunk.fromByteBuffer(httpContent.content.nioBuffer())) + override def fromPublisher(publisher: Publisher[HttpContent], maxBytes: Option[Long]): streams.BinaryStream = { + val stream = + Adapters + .publisherToStream(publisher, bufferSize = 2) + .map(httpContent => Chunk.fromByteBuffer(httpContent.content.nioBuffer())) + .flattenChunks + maxBytes.map(ZioStreams.limitBytes(stream, _)).getOrElse(stream) + } + + override def failedStream(e: => Throwable): streams.BinaryStream = + ZStream.fail(e) + + override def emptyStream: streams.BinaryStream = + ZStream.empty } } } diff --git a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala index e5f65a8173..134de29376 100644 --- a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala +++ b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala @@ -34,7 +34,14 @@ class NettyZioServerTest extends TestSuite with EitherValues { } val tests = - new AllServerTests(createServerTest, interpreter, 
backend, staticContent = false, multipart = false).tests() ++ + new AllServerTests( + createServerTest, + interpreter, + backend, + staticContent = false, + multipart = false, + maxContentLength = true + ).tests() ++ new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(ZioStreams)(drainZStream) ++ new ServerCancellationTests(createServerTest)(monadError, asyncInstance).tests() ++ new ServerGracefulShutdownTests(createServerTest, zioSleeper).tests() diff --git a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioTestServerInterpreter.scala b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioTestServerInterpreter.scala index 2f17fdefaf..ed6ad1cb7c 100644 --- a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioTestServerInterpreter.scala +++ b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioTestServerInterpreter.scala @@ -25,7 +25,7 @@ class NettyZioTestServerInterpreter[R](eventLoopGroup: NioEventLoopGroup) routes: NonEmptyList[Task[Route[Task]]], gracefulShutdownTimeout: Option[FiniteDuration] = None ): Resource[IO, (Port, KillSwitch)] = { - val config = NettyConfig.defaultWithStreaming + val config = NettyConfig.default .eventLoopGroup(eventLoopGroup) .randomPort .withDontShutdownEventLoopGroupOnClose diff --git a/server/nima-server/src/main/scala/sttp/tapir/server/nima/internal/NimaRequestBody.scala b/server/nima-server/src/main/scala/sttp/tapir/server/nima/internal/NimaRequestBody.scala index 8d2fd4972e..a9d31f11c2 100644 --- a/server/nima-server/src/main/scala/sttp/tapir/server/nima/internal/NimaRequestBody.scala +++ b/server/nima-server/src/main/scala/sttp/tapir/server/nima/internal/NimaRequestBody.scala @@ -14,7 +14,7 @@ import java.nio.file.{Files, StandardCopyOption} private[nima] class NimaRequestBody(createFile: ServerRequest => TapirFile) extends RequestBody[Id, NoStreams] { override val streams: capabilities.Streams[NoStreams] = 
NoStreams - override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW]): RawValue[RAW] = { + override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): RawValue[RAW] = { def asInputStream = nimaRequest(serverRequest).content().inputStream() def asByteArray = asInputStream.readAllBytes() diff --git a/server/pekko-grpc-server/src/main/scala/sttp/tapir/server/pekkogrpc/PekkoGrpcRequestBody.scala b/server/pekko-grpc-server/src/main/scala/sttp/tapir/server/pekkogrpc/PekkoGrpcRequestBody.scala index 206aa429f7..3dac104b2e 100644 --- a/server/pekko-grpc-server/src/main/scala/sttp/tapir/server/pekkogrpc/PekkoGrpcRequestBody.scala +++ b/server/pekko-grpc-server/src/main/scala/sttp/tapir/server/pekkogrpc/PekkoGrpcRequestBody.scala @@ -22,7 +22,7 @@ private[pekkogrpc] class PekkoGrpcRequestBody(serverOptions: PekkoHttpServerOpti private val grpcProtocol = GrpcProtocolNative.newReader(Identity) override val streams: PekkoStreams = PekkoStreams - override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R]): Future[RawValue[R]] = + override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] = toRawFromEntity(request, akkaRequestEntity(request), bodyType) override def toStream(request: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = ??? 
diff --git a/server/pekko-http-server/src/main/scala/sttp/tapir/server/pekkohttp/PekkoRequestBody.scala b/server/pekko-http-server/src/main/scala/sttp/tapir/server/pekkohttp/PekkoRequestBody.scala index 629ff26467..2f37d6e28c 100644 --- a/server/pekko-http-server/src/main/scala/sttp/tapir/server/pekkohttp/PekkoRequestBody.scala +++ b/server/pekko-http-server/src/main/scala/sttp/tapir/server/pekkohttp/PekkoRequestBody.scala @@ -21,7 +21,7 @@ private[pekkohttp] class PekkoRequestBody(serverOptions: PekkoHttpServerOptions) ec: ExecutionContext ) extends RequestBody[Future, PekkoStreams] { override val streams: PekkoStreams = PekkoStreams - override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R]): Future[RawValue[R]] = + override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] = toRawFromEntity(request, akkeRequestEntity(request), bodyType) override def toStream(request: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = { val stream = akkeRequestEntity(request).dataBytes diff --git a/server/play-server/src/main/scala/sttp/tapir/server/play/PlayRequestBody.scala b/server/play-server/src/main/scala/sttp/tapir/server/play/PlayRequestBody.scala index 56614ec266..e3feeaaf05 100644 --- a/server/play-server/src/main/scala/sttp/tapir/server/play/PlayRequestBody.scala +++ b/server/play-server/src/main/scala/sttp/tapir/server/play/PlayRequestBody.scala @@ -23,7 +23,7 @@ private[play] class PlayRequestBody(serverOptions: PlayServerOptions)(implicit override val streams: PekkoStreams = PekkoStreams - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): Future[RawValue[R]] = { + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] = { import mat.executionContext val request = playRequest(serverRequest) val charset = request.charset.map(Charset.forName) diff --git 
a/server/sttp-stub-server/src/main/scala/sttp/tapir/server/stub/SttpRequestBody.scala b/server/sttp-stub-server/src/main/scala/sttp/tapir/server/stub/SttpRequestBody.scala index 24743fb797..c8adc730d1 100644 --- a/server/sttp-stub-server/src/main/scala/sttp/tapir/server/stub/SttpRequestBody.scala +++ b/server/sttp-stub-server/src/main/scala/sttp/tapir/server/stub/SttpRequestBody.scala @@ -14,7 +14,7 @@ import scala.annotation.tailrec class SttpRequestBody[F[_]](implicit ME: MonadError[F]) extends RequestBody[F, AnyStreams] { override val streams: AnyStreams = AnyStreams - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]] = + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] = body(serverRequest) match { case Left(bytes) => bodyType match { diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala index 3e6b762346..922e6973eb 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala @@ -28,7 +28,7 @@ class AllServerTests[F[_], OPTIONS, ROUTE]( oneOfBody: Boolean = true, cors: Boolean = true, options: Boolean = true, - maxContentLength: Option[Int] = None + maxContentLength: Boolean = false // TODO let's work towards making this true by default )(implicit m: MonadError[F] ) { diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala index ecda573c9c..10d24d60f7 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala @@ -15,6 +15,8 @@ import sttp.tapir.codec.enumeratum.TapirCodecEnumeratum import sttp.tapir.generic.auto._ import 
sttp.tapir.json.circe._ import sttp.tapir.server.ServerEndpoint +import sttp.tapir.server.model.EndpointExtensions._ +import sttp.tapir.server.model._ import sttp.tapir.server.interceptor.decodefailure.DefaultDecodeFailureHandler import sttp.tapir.tests.Basic._ import sttp.tapir.tests.TestUtil._ @@ -23,6 +25,7 @@ import sttp.tapir.tests.data.{FruitAmount, FruitError} import java.io.{ByteArrayInputStream, InputStream} import java.nio.ByteBuffer +import sttp.tapir.tests.Files.in_file_out_file class ServerBasicTests[F[_], OPTIONS, ROUTE]( createServerTest: CreateServerTest[F, Any, OPTIONS, ROUTE], @@ -32,7 +35,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( supportsUrlEncodedPathSegments: Boolean = true, supportsMultipleSetCookieHeaders: Boolean = true, invulnerableToUnsanitizedHeaders: Boolean = true, - maxContentLength: Option[Int] = None + maxContentLength: Boolean = false )(implicit m: MonadError[F] ) { @@ -47,7 +50,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( customiseDecodeFailureHandlerTests() ++ serverSecurityLogicTests() ++ (if (inputStreamSupport) inputStreamTests() else Nil) ++ - (if (maxContentLength.nonEmpty) maxContentLengthTests() else Nil) ++ + (if (maxContentLength) maxContentLengthTests else Nil) ++ exceptionTests() def basicTests(): List[Test] = List( @@ -744,11 +747,42 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( } ) - def maxContentLengthTests(): List[Test] = List( - testServer(in_string_out_string, "returns 413 on exceeded max content length")(_ => - pureResult(List.fill(maxContentLength.getOrElse(0) + 1)('x').mkString.asRight[Unit]) - ) { (backend, baseUri) => basicRequest.post(uri"$baseUri/api/echo").body("irrelevant").send(backend).map(_.code.code shouldBe 413) } - ) + def testPayloadTooLarge[I]( + testedEndpoint: PublicEndpoint[I, Unit, I, Any], + maxLength: Int + ) = testServer( + testedEndpoint.maxRequestBodyLength(maxLength.toLong), + "returns 413 on exceeded max content length (request)" + )(i => pureResult(i.asRight[Unit])) 
{ (backend, baseUri) => + val tooLargeBody: String = List.fill(maxLength + 1)('x').mkString + basicRequest.post(uri"$baseUri/api/echo").body(tooLargeBody).send(backend).map(_.code shouldBe StatusCode.PayloadTooLarge) + } + def testPayloadWithinLimit[I]( + testedEndpoint: PublicEndpoint[I, Unit, I, Any], + maxLength: Int + ) = testServer( + testedEndpoint.attribute(AttributeKey[MaxContentLength], MaxContentLength(maxLength.toLong)), + "returns OK on content length below or equal max (request)" + )(i => pureResult(i.asRight[Unit])) { (backend, baseUri) => + val fineBody: String = List.fill(maxLength)('x').mkString + basicRequest.post(uri"$baseUri/api/echo").body(fineBody).send(backend).map(_.code shouldBe StatusCode.Ok) + } + + def maxContentLengthTests: List[Test] = { + val maxLength = 17000 // To generate a few chunks of default size 8192 + some extra bytes + List( + testPayloadTooLarge(in_string_out_string, maxLength), + testPayloadTooLarge(in_byte_array_out_byte_array, maxLength), + testPayloadTooLarge(in_file_out_file, maxLength), + testPayloadTooLarge(in_input_stream_out_input_stream, maxLength), + testPayloadTooLarge(in_byte_buffer_out_byte_buffer, maxLength), + testPayloadWithinLimit(in_string_out_string, maxLength), + testPayloadWithinLimit(in_input_stream_out_input_stream, maxLength), + testPayloadWithinLimit(in_byte_array_out_byte_array, maxLength), + testPayloadWithinLimit(in_file_out_file, maxLength), + testPayloadWithinLimit(in_byte_buffer_out_byte_buffer, maxLength) + ) + } def exceptionTests(): List[Test] = List( testServer(endpoint, "handle exceptions")(_ => throw new RuntimeException()) { (backend, baseUri) => diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala index 1d0043188e..2a66e8326e 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala +++ 
b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala @@ -9,7 +9,7 @@ import sttp.monad.MonadError import sttp.monad.syntax._ import sttp.tapir.tests.Test import sttp.tapir.tests.Streaming._ -import sttp.tapir.server.interpreter.MaxContentLength +import sttp.tapir.server.model.MaxContentLength import sttp.tapir.AttributeKey import cats.effect.IO import sttp.capabilities.fs2.Fs2Streams diff --git a/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/decoders/VertxRequestBody.scala b/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/decoders/VertxRequestBody.scala index 5d551b9109..a2f10720e7 100644 --- a/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/decoders/VertxRequestBody.scala +++ b/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/decoders/VertxRequestBody.scala @@ -26,7 +26,7 @@ class VertxRequestBody[F[_], S <: Streams[S]]( extends RequestBody[F, S] { override val streams: Streams[S] = readStreamCompatible.streams - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]] = { + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] = { val rc = routingContext(serverRequest) fromVFuture(bodyType match { case RawBodyType.StringBody(defaultCharset) => diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala index 7fbfb1a01e..0631d4a0b9 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala @@ -17,7 +17,7 @@ import java.nio.ByteBuffer class ZioHttpRequestBody[R](serverOptions: ZioHttpServerOptions[R]) extends RequestBody[RIO[R, *], ZioStreams] { override val streams: capabilities.Streams[ZioStreams] = ZioStreams - override def 
toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW]): Task[RawValue[RAW]] = bodyType match { + override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): Task[RawValue[RAW]] = bodyType match { case RawBodyType.StringBody(defaultCharset) => asByteArray(serverRequest).map(new String(_, defaultCharset)).map(RawValue(_)) case RawBodyType.ByteArrayBody => asByteArray(serverRequest).map(RawValue(_)) case RawBodyType.ByteBufferBody => asByteArray(serverRequest).map(bytes => ByteBuffer.wrap(bytes)).map(RawValue(_)) diff --git a/server/zio1-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala b/server/zio1-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala index 2797590752..4b178e546a 100644 --- a/server/zio1-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala +++ b/server/zio1-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala @@ -18,7 +18,7 @@ import java.nio.ByteBuffer class ZioHttpRequestBody[R](serverOptions: ZioHttpServerOptions[R]) extends RequestBody[RIO[R, *], ZioStreams] { override val streams: capabilities.Streams[ZioStreams] = ZioStreams - override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW]): Task[RawValue[RAW]] = bodyType match { + override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): Task[RawValue[RAW]] = bodyType match { case RawBodyType.StringBody(defaultCharset) => asByteArray(serverRequest).map(new String(_, defaultCharset)).map(RawValue(_)) case RawBodyType.ByteArrayBody => asByteArray(serverRequest).map(RawValue(_)) case RawBodyType.ByteBufferBody => asByteArray(serverRequest).map(bytes => ByteBuffer.wrap(bytes)).map(RawValue(_)) diff --git a/serverless/aws/lambda-core/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsRequestBody.scala 
b/serverless/aws/lambda-core/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsRequestBody.scala index b498fa60ba..077229f5ce 100644 --- a/serverless/aws/lambda-core/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsRequestBody.scala +++ b/serverless/aws/lambda-core/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsRequestBody.scala @@ -15,7 +15,7 @@ import java.util.Base64 private[lambda] class AwsRequestBody[F[_]: MonadError]() extends RequestBody[F, NoStreams] { override val streams: capabilities.Streams[NoStreams] = NoStreams - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]] = { + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] = { val request = awsRequest(serverRequest) val decoded = if (request.isBase64Encoded) Left(Base64.getDecoder.decode(request.body.getOrElse(""))) else Right(request.body.getOrElse(""))