Skip to content

Commit

Permalink
ContentLengthVerifier Update (#369)
Browse files Browse the repository at this point in the history
Motivation:
Currently http2 ignores content length when handling a response to a HEAD request and handling a 304 response.

Modifications:
Changing ContentLengthVerifier and any functions called

---------

Co-authored-by: George Barnett <[email protected]>
Co-authored-by: Cory Benfield <[email protected]>
  • Loading branch information
3 people authored Feb 2, 2023
1 parent 5485390 commit b408ca7
Show file tree
Hide file tree
Showing 10 changed files with 294 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,11 @@ struct ConnectionStreamState {
/// - parameters:
/// - streamID: The ID of the pushed stream.
/// - remoteInitialWindowSize: The initial window size of the remote peer.
/// - requestVerb: the HTTP method used on the request
/// - throws: If the stream ID is invalid.
mutating func createRemotelyPushedStream(streamID: HTTP2StreamID, remoteInitialWindowSize: UInt32) throws {
mutating func createRemotelyPushedStream(streamID: HTTP2StreamID, remoteInitialWindowSize: UInt32, requestVerb: String?) throws {
try self.reserveServerStreamID(streamID)
let streamState = HTTP2StreamStateMachine(receivedPushPromiseCreatingStreamID: streamID, remoteInitialWindowSize: remoteInitialWindowSize)
let streamState = HTTP2StreamStateMachine(receivedPushPromiseCreatingStreamID: streamID, remoteInitialWindowSize: remoteInitialWindowSize, requestVerb: requestVerb)
self.activeStreams.insert(streamState)
}

Expand All @@ -93,9 +94,9 @@ struct ConnectionStreamState {
/// - streamID: The ID of the pushed stream.
/// - localInitialWindowSize: Our initial window size..
/// - throws: If the stream ID is invalid.
mutating func createLocallyPushedStream(streamID: HTTP2StreamID, localInitialWindowSize: UInt32) throws {
mutating func createLocallyPushedStream(streamID: HTTP2StreamID, localInitialWindowSize: UInt32, requestVerb: String?) throws {
try self.reserveServerStreamID(streamID)
let streamState = HTTP2StreamStateMachine(sentPushPromiseCreatingStreamID: streamID, localInitialWindowSize: localInitialWindowSize)
let streamState = HTTP2StreamStateMachine(sentPushPromiseCreatingStreamID: streamID, localInitialWindowSize: localInitialWindowSize, requestVerb: requestVerb)
self.activeStreams.insert(streamState)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ extension ReceivingPushPromiseState {
}

let validateHeaderBlock = self.headerBlockValidation == .enabled

let requestVerb = headers.first(name: ":method")
do {
try self.streamState.createRemotelyPushedStream(streamID: childStreamID,
remoteInitialWindowSize: self.remoteInitialWindowSize)
remoteInitialWindowSize: self.remoteInitialWindowSize, requestVerb: requestVerb)

let result = self.streamState.modifyStreamState(streamID: originalStreamID, ignoreRecentlyReset: true) {
$0.receivePushPromise(headers: headers, validateHeaderBlock: validateHeaderBlock)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ extension SendingPushPromiseState {
guard case .succeed = result.result else {
return result
}

try self.streamState.createLocallyPushedStream(streamID: childStreamID, localInitialWindowSize: self.localInitialWindowSize)
let requestVerb = headers.first(name: ":method")
try self.streamState.createLocallyPushedStream(streamID: childStreamID, localInitialWindowSize: self.localInitialWindowSize, requestVerb: requestVerb)
return result
} catch {
return StateMachineResultWithEffect(result: .connectionError(underlyingError: error, type: .protocolError), effect: nil)
Expand Down
12 changes: 11 additions & 1 deletion Sources/NIOHTTP2/ContentLengthVerifier.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,16 @@ extension ContentLengthVerifier {
}

extension ContentLengthVerifier {
internal init(_ headers: HPACKHeaders) throws {
internal init(_ headers: HPACKHeaders, requestMethod: String?) throws {
if let requestMethod = requestMethod {
if let status = headers.first(name: ":status"), status == "304" {
self.expectedContentLength = 0
return
} else if requestMethod == "HEAD" {
self.expectedContentLength = 0
return
}
}
let contentLengths = headers.values(forHeader: "content-length", canonicalForm: true)
var iterator = contentLengths.makeIterator()
guard let first = iterator.next() else {
Expand Down Expand Up @@ -82,3 +91,4 @@ extension ContentLengthVerifier: CustomStringConvertible {
return "ContentLengthVerifier(length: \(String(describing: self.expectedContentLength)))"
}
}

147 changes: 74 additions & 73 deletions Sources/NIOHTTP2/StreamStateMachine.swift

Large diffs are not rendered by default.

23 changes: 19 additions & 4 deletions Sources/NIOHTTP2Server/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,23 @@ final class HTTP1TestServer: ChannelInboundHandler {
public typealias InboundIn = HTTPServerRequestPart
public typealias OutboundOut = HTTPServerResponsePart

private var head: HTTPRequestHead? = nil

public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
guard case .end = self.unwrapInboundIn(data) else {
switch self.unwrapInboundIn(data) {
case .head(let head):
self.head = head
return
case .body:
return
case .end:
// Deliberate fallthrough
()
}

let requestHead = self.head!
self.head = nil

// Insert an event loop tick here. This more accurately represents real workloads in SwiftNIO, which will not
// re-entrantly write their response frames.
context.eventLoop.execute {
Expand All @@ -39,9 +51,12 @@ final class HTTP1TestServer: ChannelInboundHandler {
headers.add(name: "x-stream-id", value: String(Int(streamID)))
context.channel.write(self.wrapOutboundOut(HTTPServerResponsePart.head(HTTPResponseHead(version: .init(major: 2, minor: 0), status: .ok, headers: headers))), promise: nil)

var buffer = context.channel.allocator.buffer(capacity: 12)
buffer.writeStaticString("hello")
context.channel.write(self.wrapOutboundOut(HTTPServerResponsePart.body(.byteBuffer(buffer))), promise: nil)
if requestHead.method != .HEAD {
var buffer = context.channel.allocator.buffer(capacity: 12)
buffer.writeStaticString("hello")
context.channel.write(self.wrapOutboundOut(HTTPServerResponsePart.body(.byteBuffer(buffer))), promise: nil)
}

return context.channel.writeAndFlush(self.wrapOutboundOut(HTTPServerResponsePart.end(nil)))
}.whenComplete { _ in
context.close(promise: nil)
Expand Down
6 changes: 6 additions & 0 deletions Tests/NIOHTTP2Tests/ConnectionStateMachineTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ extension ConnectionStateMachineTests {
("testWeTolerateOneStreamBeingResetTwice", testWeTolerateOneStreamBeingResetTwice),
("testReceivedAltServiceFramesAreIgnored", testReceivedAltServiceFramesAreIgnored),
("testReceivedOriginFramesAreIgnored", testReceivedOriginFramesAreIgnored),
("testContentLengthForStatus304", testContentLengthForStatus304),
("testContentLengthForStatus304Failure", testContentLengthForStatus304Failure),
("testContentLengthForMethodHead", testContentLengthForMethodHead),
("testContentLengthForHeadFailure", testContentLengthForHeadFailure),
("testPushHeadRequestFailure", testPushHeadRequestFailure),
("testPushHeadRequest", testPushHeadRequest),
]
}
}
Expand Down
143 changes: 143 additions & 0 deletions Tests/NIOHTTP2Tests/ConnectionStateMachineTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3001,6 +3001,149 @@ class ConnectionStateMachineTests: XCTestCase {
assertIgnored(self.client.receiveOrigin(origins: ["one", "two"]))
assertIgnored(self.server.receiveOrigin(origins: ["one", "two"]))
}

func testContentLengthForStatus304() {
let streamOne = HTTP2StreamID(1)

self.server = .init(role: .server)
self.client = .init(role: .client)

self.exchangePreamble()

let responseHeaders = HPACKHeaders([(":status", "304"), ("content-length", "25")])

// Set up the connection
assertSucceeds(self.client.sendHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: true))
assertSucceeds(self.server.receiveHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: true))

// The server responds
assertSucceeds(self.server.sendHeaders(streamID: streamOne, headers: responseHeaders, isEndStreamSet: false))
assertSucceeds(self.client.receiveHeaders(streamID: streamOne, headers: responseHeaders, isEndStreamSet: false))

// Send in 0 bytes over two sets
assertSucceeds(self.server.sendData(streamID: streamOne, contentLength: 0, flowControlledBytes: 0, isEndStreamSet: true))
assertSucceeds(self.client.receiveData(streamID: streamOne, contentLength: 0, flowControlledBytes: 0, isEndStreamSet: true))
}

func testContentLengthForStatus304Failure() {
let streamOne = HTTP2StreamID(1)

self.server = .init(role: .server)
self.client = .init(role: .client)

self.exchangePreamble()

let responseHeaders = HPACKHeaders([(":status", "304"), ("content-length", "25")])

// Set up the connection
assertSucceeds(self.client.sendHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: true))
assertSucceeds(self.server.receiveHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: true))

// The server responds
assertSucceeds(self.server.sendHeaders(streamID: streamOne, headers: responseHeaders, isEndStreamSet: false))
assertSucceeds(self.client.receiveHeaders(streamID: streamOne, headers: responseHeaders, isEndStreamSet: false))

// Send in 1 byte over one frame
assertStreamError(type: HTTP2ErrorCode.protocolError, self.server.sendData(streamID: streamOne, contentLength: 1, flowControlledBytes: 1, isEndStreamSet: true))
assertStreamError(type: HTTP2ErrorCode.protocolError, self.client.receiveData(streamID: streamOne, contentLength: 1, flowControlledBytes: 1, isEndStreamSet: true))
}

func testContentLengthForMethodHead() {
let streamOne = HTTP2StreamID(1)

self.server = .init(role: .server)
self.client = .init(role: .client)

self.exchangePreamble()

let requestHeaders = HPACKHeaders([(":method", "HEAD"), (":authority", "localhost"), (":scheme", "https"), (":path", "/"), ("user-agent", "test")])
let responseHeaders = HPACKHeaders([(":status", "200"), ("content-length", "25")])

// Set up the connection
assertSucceeds(self.client.sendHeaders(streamID: streamOne, headers: requestHeaders, isEndStreamSet: true))
assertSucceeds(self.server.receiveHeaders(streamID: streamOne, headers: requestHeaders, isEndStreamSet: true))

// The server responds
assertSucceeds(self.server.sendHeaders(streamID: streamOne, headers: responseHeaders, isEndStreamSet: false))
assertSucceeds(self.client.receiveHeaders(streamID: streamOne, headers: responseHeaders, isEndStreamSet: false))

// Send in 0 bytes over one frame
assertSucceeds(self.server.sendData(streamID: streamOne, contentLength: 0, flowControlledBytes: 0, isEndStreamSet: true))
assertSucceeds(self.client.receiveData(streamID: streamOne, contentLength: 0, flowControlledBytes: 0, isEndStreamSet: true))
}

func testContentLengthForHeadFailure() {
let streamOne = HTTP2StreamID(1)

self.server = .init(role: .server)
self.client = .init(role: .client)

self.exchangePreamble()

let requestHeaders = HPACKHeaders([(":method", "HEAD"), (":authority", "localhost"), (":scheme", "https"), (":path", "/"), ("user-agent", "test")])
let responseHeaders = HPACKHeaders([(":status", "200"), ("content-length", "25")])

// Set up the connection
assertSucceeds(self.client.sendHeaders(streamID: streamOne, headers: requestHeaders, isEndStreamSet: true))
assertSucceeds(self.server.receiveHeaders(streamID: streamOne, headers: requestHeaders, isEndStreamSet: true))

// The server responds
assertSucceeds(self.server.sendHeaders(streamID: streamOne, headers: responseHeaders, isEndStreamSet: false))
assertSucceeds(self.client.receiveHeaders(streamID: streamOne, headers: responseHeaders, isEndStreamSet: false))

// Send in 1 byte over 1 frame
assertStreamError(type: HTTP2ErrorCode.protocolError, self.server.sendData(streamID: streamOne, contentLength: 1, flowControlledBytes: 1, isEndStreamSet: true))
assertStreamError(type: HTTP2ErrorCode.protocolError, self.client.receiveData(streamID: streamOne, contentLength: 1, flowControlledBytes: 1, isEndStreamSet: true))
}

func testPushHeadRequestFailure() {
let streamOne = HTTP2StreamID(1)
let streamTwo = HTTP2StreamID(2)

self.exchangePreamble()

let requestHeaders = HPACKHeaders([(":method", "HEAD"), (":authority", "localhost"), (":scheme", "https"), (":path", "/"), ("user-agent", "test")])
let responseHeaders = HPACKHeaders([(":status", "200"), ("content-length", "25")])

assertSucceeds(self.client.sendHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: true))
assertSucceeds(self.server.receiveHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: true))

// Server can push right away
assertSucceeds(self.server.sendPushPromise(originalStreamID: streamOne, childStreamID: streamTwo, headers: requestHeaders))
assertSucceeds(self.client.receivePushPromise(originalStreamID: streamOne, childStreamID: streamTwo, headers: requestHeaders))

// The server responds
assertSucceeds(self.server.sendHeaders(streamID: streamTwo, headers: responseHeaders, isEndStreamSet: false))
assertSucceeds(self.client.receiveHeaders(streamID: streamTwo, headers: responseHeaders, isEndStreamSet: false))

// Send in 1 byte over one frame
assertStreamError(type: HTTP2ErrorCode.protocolError, self.server.sendData(streamID: streamTwo, contentLength: 1, flowControlledBytes: 1, isEndStreamSet: true))
assertStreamError(type: HTTP2ErrorCode.protocolError, self.client.receiveData(streamID: streamTwo, contentLength: 1, flowControlledBytes: 1, isEndStreamSet: true))
}

func testPushHeadRequest() {
let streamOne = HTTP2StreamID(1)
let streamTwo = HTTP2StreamID(2)

self.exchangePreamble()

let requestHeaders = HPACKHeaders([(":method", "HEAD"), (":authority", "localhost"), (":scheme", "https"), (":path", "/"), ("user-agent", "test")])
let responseHeaders = HPACKHeaders([(":status", "200"), ("content-length", "25")])

assertSucceeds(self.client.sendHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: true))
assertSucceeds(self.server.receiveHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: true))

// Server can push right away
assertSucceeds(self.server.sendPushPromise(originalStreamID: streamOne, childStreamID: streamTwo, headers: requestHeaders))
assertSucceeds(self.client.receivePushPromise(originalStreamID: streamOne, childStreamID: streamTwo, headers: requestHeaders))

// The server responds
assertSucceeds(self.server.sendHeaders(streamID: streamTwo, headers: responseHeaders, isEndStreamSet: false))
assertSucceeds(self.client.receiveHeaders(streamID: streamTwo, headers: responseHeaders, isEndStreamSet: false))

// Send in 0 bytes over one frame
assertSucceeds(self.client.receiveData(streamID: streamTwo, contentLength: 0, flowControlledBytes: 0, isEndStreamSet: true))
}
}


Expand Down
2 changes: 2 additions & 0 deletions Tests/NIOHTTP2Tests/ContentLengthVerifierTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ extension ContentLengthVerifierTests {
("testMinIntLengthHeaderDoesntPanic", testMinIntLengthHeaderDoesntPanic),
("testMaxIntLengthHeaderDoesntPanic", testMaxIntLengthHeaderDoesntPanic),
("testInvalidLengthHeaderValuesThrow", testInvalidLengthHeaderValuesThrow),
("testContentLengthVerifier_whenResponseStatusIs304", testContentLengthVerifier_whenResponseStatusIs304),
("testContentLengthVerifier_whenRequestMethodIsHead", testContentLengthVerifier_whenRequestMethodIsHead),
]
}
}
Expand Down
Loading

0 comments on commit b408ca7

Please sign in to comment.