Skip to content

Commit

Permalink
MaxContentLength support for Netty pt 2 (#3337)
Browse files Browse the repository at this point in the history
  • Loading branch information
kciesielski authored Dec 4, 2023
1 parent efcbf9b commit abeb5d7
Show file tree
Hide file tree
Showing 68 changed files with 971 additions and 504 deletions.
20 changes: 20 additions & 0 deletions doc/endpoint/security.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,26 @@ will show you a password prompt.
Optional and multiple authentication inputs have some additional rules as to how hey map to documentation, see the
["Authentication inputs and security requirements"](../docs/openapi.md) section in the OpenAPI docs for details.

## Limiting request body length

*Supported backends*:
Feature enabled only for Netty-based servers. More backends will be added in the near future.

Individual endpoints can be annotated with content length limit:

```scala mdoc:compile-only
import sttp.tapir._
import sttp.tapir.server.model.EndpointExtensions._

val limitedEndpoint = endpoint.maxRequestBodyLength(maxBytes = 163484L)
```

The `EndpointsExtensions` utility is available in `tapir-server` core module.
Such protection prevents loading all the input data if it exceeds the limit. Instead, it will result in a `HTTP 413`
response to the client.
Please note that in case of endpoints with `streamBody` input type, the server logic receives a reference to a lazily
evaluated stream, so actual length verification will happen only when the logic performs streams processing, not earlier.

## Next

Read on about [streaming support](streaming.md).
4 changes: 4 additions & 0 deletions doc/migrating.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Migrating

## From 1.9.3 to 1.9.4

- `NettyConfig.defaultNoStreaming` has been removed, use `NettyConfig.default`.

## From 1.4 to 1.5

- `badRequestOnPathErrorIfPathShapeMatches` and `badRequestOnPathInvalidIfPathShapeMatches` have been removed from `DefaultDecodeFailureHandler`. These flags were causing confusion and incosistencies caused by specifics of ZIO and Play backends. Before tapir 1.5, keeping defaults (`false` and `true` respectively for these flags) meant that some path segment decoding failures (specifically, errors - when an exception has been thrown during decoding, but not for e.g. enumeration mismatches) were translated to a "no-match", meaning that the next endpoint was attempted. From 1.5, tapir defaults to a 400 Bad Request response to be sent instead, on all path decoding failures.
Expand Down
6 changes: 3 additions & 3 deletions doc/server/netty.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ NettyFutureServer().port(9090).addEndpoints(???)
NettyFutureServer(NettyFutureServerOptions.customiseInterceptors.serverLog(None).options)

// customise Netty config
NettyFutureServer(NettyConfig.defaultNoStreaming.socketBacklog(256))
NettyFutureServer(NettyConfig.default.socketBacklog(256))
```

## Graceful shutdown
Expand All @@ -93,9 +93,9 @@ import sttp.tapir.server.netty.NettyConfig
import scala.concurrent.duration._

// adjust the waiting time to your needs
val config = NettyConfig.defaultNoStreaming.withGracefulShutdownTimeout(5.seconds)
val config = NettyConfig.default.withGracefulShutdownTimeout(5.seconds)
// or if you don't want the server to wait for in-flight requests
val config2 = NettyConfig.defaultNoStreaming.noGracefulShutdown
val config2 = NettyConfig.default.noGracefulShutdown
```

## Domain socket support
Expand Down
6 changes: 3 additions & 3 deletions generated-doc/out/server/netty.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ NettyFutureServer().port(9090).addEndpoints(???)
NettyFutureServer(NettyFutureServerOptions.customiseInterceptors.serverLog(None).options)

// customise Netty config
NettyFutureServer(NettyConfig.defaultNoStreaming.socketBacklog(256))
NettyFutureServer(NettyConfig.default.socketBacklog(256))
```

## Graceful shutdown
Expand All @@ -93,9 +93,9 @@ import sttp.tapir.server.netty.NettyConfig
import scala.concurrent.duration._

// adjust the waiting time to your needs
val config = NettyConfig.defaultNoStreaming.withGracefulShutdownTimeout(5.seconds)
val config = NettyConfig.default.withGracefulShutdownTimeout(5.seconds)
// or if you don't want the server to wait for in-flight requests
val config2 = NettyConfig.defaultNoStreaming.noGracefulShutdown
val config2 = NettyConfig.default.noGracefulShutdown
```

## Domain socket support
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ object ZTapirTest extends ZIOSpecDefault with ZTapir {

private val exampleRequestBody = new RequestBody[TestEffect, RequestBodyType] {
override val streams: Streams[RequestBodyType] = null.asInstanceOf[Streams[RequestBodyType]]
override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): TestEffect[RawValue[R]] = ???
override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): TestEffect[RawValue[R]] = ???
override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = ???
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ object ZTapirTest extends DefaultRunnableSpec with ZTapir {

private val exampleRequestBody = new RequestBody[TestEffect, RequestBodyType] {
override val streams: Streams[RequestBodyType] = null.asInstanceOf[Streams[RequestBodyType]]
override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): TestEffect[RawValue[R]] = ???
override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): TestEffect[RawValue[R]] = ???
override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = ???
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ private[akkagrpc] class AkkaGrpcRequestBody(serverOptions: AkkaHttpServerOptions
private val grpcProtocol = GrpcProtocolNative.newReader(Identity)

override val streams: AkkaStreams = AkkaStreams
override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R]): Future[RawValue[R]] =
override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] =
toRawFromEntity(request, akkaRequestEntity(request), bodyType)

override def toStream(request: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = ???
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ private[akkahttp] class AkkaRequestBody(serverOptions: AkkaHttpServerOptions)(im
ec: ExecutionContext
) extends RequestBody[Future, AkkaStreams] {
override val streams: AkkaStreams = AkkaStreams
override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R]): Future[RawValue[R]] =
override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] =
toRawFromEntity(request, akkeRequestEntity(request), bodyType)
override def toStream(request: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = {
val stream = akkeRequestEntity(request).dataBytes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ private[armeria] final class ArmeriaRequestBody[F[_], S <: Streams[S]](
.asInstanceOf[streams.BinaryStream]
}

override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]] = {
override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] = {
val ctx = armeriaCtx(serverRequest)
val request = ctx.request()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import sttp.tapir.server.model.ValuedEndpointOutput
import sttp.tapir.{DecodeResult, EndpointIO, EndpointInput, ValidationError, Validator, server, _}

import scala.annotation.tailrec
import sttp.capabilities.StreamMaxLengthExceededException

trait DecodeFailureHandler[F[_]] {

Expand Down Expand Up @@ -122,11 +123,12 @@ object DefaultDecodeFailureHandler {
case (_: EndpointIO.Header[_], _) => respondBadRequest
case (fh: EndpointIO.FixedHeader[_], _: DecodeResult.Mismatch) if fh.h.name == HeaderNames.ContentType =>
respondUnsupportedMediaType
case (_: EndpointIO.FixedHeader[_], _) => respondBadRequest
case (_: EndpointIO.Headers[_], _) => respondBadRequest
case (_: EndpointIO.Body[_, _], _) => respondBadRequest
case (_: EndpointIO.OneOfBody[_, _], _: DecodeResult.Mismatch) => respondUnsupportedMediaType
case (_: EndpointIO.StreamBodyWrapper[_, _], _) => respondBadRequest
case (_: EndpointIO.FixedHeader[_], _) => respondBadRequest
case (_: EndpointIO.Headers[_], _) => respondBadRequest
case (_, DecodeResult.Error(_, _: StreamMaxLengthExceededException)) => respondPayloadTooLarge
case (_: EndpointIO.Body[_, _], _) => respondBadRequest
case (_: EndpointIO.OneOfBody[_, _], _: DecodeResult.Mismatch) => respondUnsupportedMediaType
case (_: EndpointIO.StreamBodyWrapper[_, _], _) => respondBadRequest
// we assume that the only decode failure that might happen during path segment decoding is an error
// a non-standard path decoder might return Missing/Multiple/Mismatch, but that would be indistinguishable from
// a path shape mismatch
Expand All @@ -143,6 +145,7 @@ object DefaultDecodeFailureHandler {
}
private val respondBadRequest = Some(onlyStatus(StatusCode.BadRequest))
private val respondUnsupportedMediaType = Some(onlyStatus(StatusCode.UnsupportedMediaType))
private val respondPayloadTooLarge = Some(onlyStatus(StatusCode.PayloadTooLarge))

def respondNotFoundIfHasAuth(
ctx: DecodeFailureContext,
Expand Down Expand Up @@ -224,10 +227,12 @@ object DefaultDecodeFailureHandler {
}
.mkString(", ")
)
case Missing => Some("missing")
case Multiple(_) => Some("multiple values")
case Mismatch(_, _) => Some("value mismatch")
case _ => None
case Missing => Some("missing")
case Multiple(_) => Some("multiple values")
case Mismatch(_, _) => Some("value mismatch")
case Error(_, StreamMaxLengthExceededException(maxBytes)) => Some(s"Content length limit: $maxBytes bytes")
case _: Error => None
case _: InvalidValue => None
}

def combineSourceAndDetail(source: String, detail: Option[String]): String =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,12 @@ package sttp.tapir.server.interpreter
import sttp.capabilities.Streams
import sttp.model.Part
import sttp.tapir.model.ServerRequest
import sttp.tapir.AttributeKey
import sttp.tapir.EndpointInfo
import sttp.tapir.{FileRange, RawBodyType, RawPart}

case class MaxContentLength(value: Long)

trait RequestBody[F[_], S] {
val streams: Streams[S]
def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]]
def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]]
def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream

}

case class RawValue[R](value: R, createdFiles: Seq[FileRange] = Nil)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
package sttp.tapir.server.interpreter

import sttp.capabilities.StreamMaxLengthExceededException
import sttp.model.{Headers, StatusCode}
import sttp.monad.MonadError
import sttp.monad.syntax._
import sttp.tapir.internal.{Params, ParamsAsAny, RichOneOfBody}
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.{model, _}
import sttp.tapir.server.interceptor._
import sttp.tapir.server.model.{ServerResponse, ValuedEndpointOutput}
import sttp.tapir.server.model.{MaxContentLength, ServerResponse, ValuedEndpointOutput}
import sttp.tapir.server.{model, _}
import sttp.tapir.{DecodeResult, EndpointIO, EndpointInput, TapirFile}
import sttp.tapir.EndpointInfo
import sttp.tapir.AttributeKey
Expand Down Expand Up @@ -157,7 +158,7 @@ class ServerInterpreter[R, F[_], B, S](
values.bodyInputWithIndex match {
case Some((Left(oneOfBodyInput), _)) =>
oneOfBodyInput.chooseBodyToDecode(request.contentTypeParsed) match {
case Some(Left(body)) => decodeBody(request, values, body)
case Some(Left(body)) => decodeBody(request, values, body, maxBodyLength)
case Some(Right(body: EndpointIO.StreamBodyWrapper[Any, Any])) => decodeStreamingBody(request, values, body, maxBodyLength)
case None => unsupportedInputMediaTypeResponse(request, oneOfBodyInput)
}
Expand All @@ -182,17 +183,23 @@ class ServerInterpreter[R, F[_], B, S](
private def decodeBody[RAW, T](
request: ServerRequest,
values: DecodeBasicInputsResult.Values,
bodyInput: EndpointIO.Body[RAW, T]
bodyInput: EndpointIO.Body[RAW, T],
maxBodyLength: Option[Long]
): F[DecodeBasicInputsResult] = {
requestBody.toRaw(request, bodyInput.bodyType).flatMap { v =>
bodyInput.codec.decode(v.value) match {
case DecodeResult.Value(bodyV) => (values.setBodyInputValue(bodyV): DecodeBasicInputsResult).unit
case failure: DecodeResult.Failure =>
v.createdFiles
.foldLeft(monad.unit(()))((u, f) => u.flatMap(_ => deleteFile(f.file)))
.map(_ => DecodeBasicInputsResult.Failure(bodyInput, failure): DecodeBasicInputsResult)
requestBody
.toRaw(request, bodyInput.bodyType, maxBodyLength)
.flatMap { v =>
bodyInput.codec.decode(v.value) match {
case DecodeResult.Value(bodyV) => (values.setBodyInputValue(bodyV): DecodeBasicInputsResult).unit
case failure: DecodeResult.Failure =>
v.createdFiles
.foldLeft(monad.unit(()))((u, f) => u.flatMap(_ => deleteFile(f.file)))
.map(_ => DecodeBasicInputsResult.Failure(bodyInput, failure): DecodeBasicInputsResult)
}
}
.handleError { case e: StreamMaxLengthExceededException =>
(DecodeBasicInputsResult.Failure(bodyInput, DecodeResult.Error("", e)): DecodeBasicInputsResult).unit
}
}
}

private def unsupportedInputMediaTypeResponse(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package sttp.tapir.server.model

import sttp.tapir.EndpointInfoOps
import sttp.tapir.AttributeKey

/** Can be used as an endpoint attribute.
* @example
* {{{
* endpoint.attribute(MaxContentLength.attributeKey, MaxContentLength(16384L))
* }}}
*/
case class MaxContentLength(value: Long) extends AnyVal

object MaxContentLength {
val attributeKey: AttributeKey[MaxContentLength] = AttributeKey[MaxContentLength]
}

object EndpointExtensions {

implicit class RichServerEndpoint[E <: EndpointInfoOps[_]](e: E) {

/** Enables checks that prevent loading full request body into memory if it exceeds given limit. Otherwise causes endpoint to reply with
* HTTP 413 Payload Too Loarge.
*
* Please refer to Tapir docs to ensure which backends are supported: https://tapir.softwaremill.com/en/latest/endpoint/security.html
* @example
* {{{
* endpoint.maxRequestBodyLength(16384L)
* }}}
* @param maxBytes
* maximum allowed size of request body in bytes.
*/
def maxRequestBodyLength(maxBytes: Long): E =
e.attribute(MaxContentLength.attributeKey, MaxContentLength(maxBytes)).asInstanceOf[E]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import scala.util.{Success, Try}
object TestUtil {
object TestRequestBody extends RequestBody[Id, NoStreams] {
override val streams: Streams[NoStreams] = NoStreams
override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): Id[RawValue[R]] = ???
override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Id[RawValue[R]] = ???
override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = ???
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import scala.collection.immutable.Seq
class FinatraRequestBody(serverOptions: FinatraServerOptions) extends RequestBody[Future, NoStreams] {
override val streams: NoStreams = NoStreams

override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): Future[RawValue[R]] = {
override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] = {
val request = finatraRequest(serverRequest)
toRaw(request, bodyType, request.content, request.charset.map(Charset.forName))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ private[http4s] class Http4sRequestBody[F[_]: Async](
serverOptions: Http4sServerOptions[F]
) extends RequestBody[F, Fs2Streams[F]] {
override val streams: Fs2Streams[F] = Fs2Streams[F]
override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]] = {
override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] = {
val r = http4sRequest(serverRequest)
toRawFromStream(serverRequest, r.body, bodyType, r.charset)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ private[jdkhttp] class JdkHttpRequestBody(createFile: ServerRequest => TapirFile
extends RequestBody[Id, NoStreams] {
override val streams: capabilities.Streams[NoStreams] = NoStreams

override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW]): RawValue[RAW] = {
override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): RawValue[RAW] = {
val request = jdkHttpRequest(serverRequest)
toRaw(serverRequest, bodyType, request.getRequestBody)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty
val channelFuture =
NettyBootstrap(
config,
new NettyServerHandler(route, unsafeRunAsync, config.maxContentLength, channelGroup, isShuttingDown),
new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown),
eventLoopGroup,
socketOverride
)
Expand Down Expand Up @@ -123,9 +123,9 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty

object NettyCatsServer {
def apply[F[_]: Async](dispatcher: Dispatcher[F]): NettyCatsServer[F] =
NettyCatsServer(Vector.empty, NettyCatsServerOptions.default(dispatcher), NettyConfig.defaultWithStreaming)
NettyCatsServer(Vector.empty, NettyCatsServerOptions.default(dispatcher), NettyConfig.default)
def apply[F[_]: Async](options: NettyCatsServerOptions[F]): NettyCatsServer[F] =
NettyCatsServer(Vector.empty, options, NettyConfig.defaultWithStreaming)
NettyCatsServer(Vector.empty, options, NettyConfig.default)
def apply[F[_]: Async](dispatcher: Dispatcher[F], config: NettyConfig): NettyCatsServer[F] =
NettyCatsServer(Vector.empty, NettyCatsServerOptions.default(dispatcher), config)
def apply[F[_]: Async](options: NettyCatsServerOptions[F], config: NettyConfig): NettyCatsServer[F] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sttp.tapir.server.netty.cats

import cats.effect.Async
import cats.effect.std.Dispatcher
import internal.Fs2StreamCompatible
import sttp.capabilities.fs2.Fs2Streams
import sttp.monad.MonadError
import sttp.monad.syntax._
Expand All @@ -11,6 +12,7 @@ import sttp.tapir.server.interceptor.RequestResult
import sttp.tapir.server.interceptor.reject.RejectInterceptor
import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, ServerInterpreter}
import sttp.tapir.server.netty.internal.{NettyBodyListener, RunAsync, _}
import sttp.tapir.server.netty.cats.internal.NettyCatsRequestBody
import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route}

trait NettyCatsServerInterpreter[F[_]] {
Expand All @@ -31,8 +33,8 @@ trait NettyCatsServerInterpreter[F[_]] {

val serverInterpreter = new ServerInterpreter[Fs2Streams[F], F, NettyResponse, Fs2Streams[F]](
FilterServerEndpoints(ses),
new NettyCatsRequestBody(createFile),
new NettyCatsToResponseBody(nettyServerOptions.dispatcher, delegate = new NettyToResponseBody),
new NettyCatsRequestBody(createFile, Fs2StreamCompatible[F](nettyServerOptions.dispatcher)),
new NettyToStreamsResponseBody(Fs2StreamCompatible[F](nettyServerOptions.dispatcher)),
RejectInterceptor.disableWhenSingleEndpoint(interceptors, ses),
deleteFile
)
Expand Down
Loading

0 comments on commit abeb5d7

Please sign in to comment.