diff --git a/Sources/NIOHTTP2/ConnectionStateMachine/ConnectionStreamsState.swift b/Sources/NIOHTTP2/ConnectionStateMachine/ConnectionStreamsState.swift index 7b0c68b7..0b100d4c 100644 --- a/Sources/NIOHTTP2/ConnectionStateMachine/ConnectionStreamsState.swift +++ b/Sources/NIOHTTP2/ConnectionStateMachine/ConnectionStreamsState.swift @@ -77,10 +77,11 @@ struct ConnectionStreamState { /// - parameters: /// - streamID: The ID of the pushed stream. /// - remoteInitialWindowSize: The initial window size of the remote peer. + /// - requestVerb: the HTTP method used on the request /// - throws: If the stream ID is invalid. - mutating func createRemotelyPushedStream(streamID: HTTP2StreamID, remoteInitialWindowSize: UInt32) throws { + mutating func createRemotelyPushedStream(streamID: HTTP2StreamID, remoteInitialWindowSize: UInt32, requestVerb: String?) throws { try self.reserveServerStreamID(streamID) - let streamState = HTTP2StreamStateMachine(receivedPushPromiseCreatingStreamID: streamID, remoteInitialWindowSize: remoteInitialWindowSize) + let streamState = HTTP2StreamStateMachine(receivedPushPromiseCreatingStreamID: streamID, remoteInitialWindowSize: remoteInitialWindowSize, requestVerb: requestVerb) self.activeStreams.insert(streamState) } @@ -93,9 +94,9 @@ struct ConnectionStreamState { /// - streamID: The ID of the pushed stream. /// - localInitialWindowSize: Our initial window size.. /// - throws: If the stream ID is invalid. - mutating func createLocallyPushedStream(streamID: HTTP2StreamID, localInitialWindowSize: UInt32) throws { + mutating func createLocallyPushedStream(streamID: HTTP2StreamID, localInitialWindowSize: UInt32, requestVerb: String?) throws { try self.reserveServerStreamID(streamID) - let streamState = HTTP2StreamStateMachine(sentPushPromiseCreatingStreamID: streamID, localInitialWindowSize: localInitialWindowSize) + let streamState = HTTP2StreamStateMachine(sentPushPromiseCreatingStreamID: streamID, localInitialWindowSize: localInitialWindowSize, requestVerb: requestVerb) self.activeStreams.insert(streamState) } diff --git a/Sources/NIOHTTP2/ConnectionStateMachine/FrameReceivingStates/ReceivingPushPromiseState.swift b/Sources/NIOHTTP2/ConnectionStateMachine/FrameReceivingStates/ReceivingPushPromiseState.swift index 9079486c..9037669d 100644 --- a/Sources/NIOHTTP2/ConnectionStateMachine/FrameReceivingStates/ReceivingPushPromiseState.swift +++ b/Sources/NIOHTTP2/ConnectionStateMachine/FrameReceivingStates/ReceivingPushPromiseState.swift @@ -50,10 +50,10 @@ extension ReceivingPushPromiseState { } let validateHeaderBlock = self.headerBlockValidation == .enabled - + let requestVerb = headers.first(name: ":method") do { try self.streamState.createRemotelyPushedStream(streamID: childStreamID, - remoteInitialWindowSize: self.remoteInitialWindowSize) + remoteInitialWindowSize: self.remoteInitialWindowSize, requestVerb: requestVerb) let result = self.streamState.modifyStreamState(streamID: originalStreamID, ignoreRecentlyReset: true) { $0.receivePushPromise(headers: headers, validateHeaderBlock: validateHeaderBlock) diff --git a/Sources/NIOHTTP2/ConnectionStateMachine/FrameSendingStates/SendingPushPromiseState.swift b/Sources/NIOHTTP2/ConnectionStateMachine/FrameSendingStates/SendingPushPromiseState.swift index 6306c7e3..b97f311f 100644 --- a/Sources/NIOHTTP2/ConnectionStateMachine/FrameSendingStates/SendingPushPromiseState.swift +++ b/Sources/NIOHTTP2/ConnectionStateMachine/FrameSendingStates/SendingPushPromiseState.swift @@ -59,8 +59,8 @@ extension SendingPushPromiseState { guard case .succeed = result.result else { return result } - - try self.streamState.createLocallyPushedStream(streamID: childStreamID, localInitialWindowSize: self.localInitialWindowSize) + let requestVerb = headers.first(name: ":method") + try self.streamState.createLocallyPushedStream(streamID: childStreamID, localInitialWindowSize: self.localInitialWindowSize, requestVerb: requestVerb) return result } catch { return StateMachineResultWithEffect(result: .connectionError(underlyingError: error, type: .protocolError), effect: nil) diff --git a/Sources/NIOHTTP2/ContentLengthVerifier.swift b/Sources/NIOHTTP2/ContentLengthVerifier.swift index 2f612dc9..fdb89f9c 100644 --- a/Sources/NIOHTTP2/ContentLengthVerifier.swift +++ b/Sources/NIOHTTP2/ContentLengthVerifier.swift @@ -48,7 +48,16 @@ extension ContentLengthVerifier { } extension ContentLengthVerifier { - internal init(_ headers: HPACKHeaders) throws { + internal init(_ headers: HPACKHeaders, requestMethod: String?) throws { + if let requestMethod = requestMethod { + if let status = headers.first(name: ":status"), status == "304" { + self.expectedContentLength = 0 + return + } else if requestMethod == "HEAD" { + self.expectedContentLength = 0 + return + } + } let contentLengths = headers.values(forHeader: "content-length", canonicalForm: true) var iterator = contentLengths.makeIterator() guard let first = iterator.next() else { @@ -82,3 +91,4 @@ extension ContentLengthVerifier: CustomStringConvertible { return "ContentLengthVerifier(length: \(String(describing: self.expectedContentLength)))" } } + diff --git a/Sources/NIOHTTP2/StreamStateMachine.swift b/Sources/NIOHTTP2/StreamStateMachine.swift index 5ab48dc6..696c7d95 100644 --- a/Sources/NIOHTTP2/StreamStateMachine.swift +++ b/Sources/NIOHTTP2/StreamStateMachine.swift @@ -110,12 +110,12 @@ struct HTTP2StreamStateMachine { /// In the reservedRemote state, the stream has been opened by the remote peer emitting a /// PUSH_PROMISE frame. We are expecting to receive a HEADERS frame for the pushed response. In this /// state we are definitionally a client. - case reservedRemote(remoteWindow: HTTP2FlowControlWindow) + case reservedRemote(remoteWindow: HTTP2FlowControlWindow, requestVerb: String?) /// In the reservedLocal state, the stream has been opened by the local user sending a PUSH_PROMISE /// frame. We now need to send a HEADERS frame for the pushed response. In this state we are definitionally /// a server. - case reservedLocal(localWindow: HTTP2FlowControlWindow) + case reservedLocal(localWindow: HTTP2FlowControlWindow, requestVerb: String?) /// This state does not exist on the diagram above. It encodes the notion that this stream has /// been opened by the local user sending a HEADERS frame, but we have not yet received the remote @@ -123,14 +123,14 @@ struct HTTP2StreamStateMachine { /// from the remote peer in this state, however. If we are in this state, we must be a client: servers /// initiating streams put them into reservedLocal, and then sending HEADERS transfers them directly to /// halfClosedRemoteLocalActive. - case halfOpenLocalPeerIdle(localWindow: HTTP2FlowControlWindow, localContentLength: ContentLengthVerifier, remoteWindow: HTTP2FlowControlWindow) + case halfOpenLocalPeerIdle(localWindow: HTTP2FlowControlWindow, localContentLength: ContentLengthVerifier, remoteWindow: HTTP2FlowControlWindow, requestVerb: String?) /// This state does not exist on the diagram above. It encodes the notion that this stream has /// been opened by the remote user sending a HEADERS frame, but we have not yet sent our HEADERS frame /// in response. If we are in this state, we must be a server: clients receiving streams that were opened /// by servers put them into reservedRemote, and then receiving the response HEADERS transitions them directly /// to halfClosedLocalPeerActive. - case halfOpenRemoteLocalIdle(localWindow: HTTP2FlowControlWindow, remoteContentLength: ContentLengthVerifier, remoteWindow: HTTP2FlowControlWindow) + case halfOpenRemoteLocalIdle(localWindow: HTTP2FlowControlWindow, remoteContentLength: ContentLengthVerifier, remoteWindow: HTTP2FlowControlWindow, requestVerb: String?) /// This state is when both peers have sent a HEADERS frame, but neither has sent a frame with END_STREAM /// set. Both peers may exchange data fully. In this state we keep track of whether we are a client or a @@ -146,7 +146,7 @@ struct HTTP2StreamStateMachine { /// END_STREAM before we receive HEADERS. This cannot happen to a server, as we must have initiated /// this stream to have half closed it before we receive HEADERS, and if we had initiated the stream via /// PUSH_PROMISE (as a server must), the stream would be halfClosedRemote, not halfClosedLocal. - case halfClosedLocalPeerIdle(remoteWindow: HTTP2FlowControlWindow) + case halfClosedLocalPeerIdle(remoteWindow: HTTP2FlowControlWindow, requestVerb: String?) /// In the halfClosedLocalPeerActive state, the local user has sent END_STREAM, and the remote peer has /// sent its HEADERS frame. This happens when we send END_STREAM from the fullyOpen state, or when we @@ -171,7 +171,7 @@ struct HTTP2StreamStateMachine { /// END_STREAM before we send HEADERS. This cannot happen to a client, as the remote peer must have initiated /// this stream to have half closed it before we send HEADERS, and that will cause a client to enter halfClosedLocal, /// not halfClosedRemote. - case halfClosedRemoteLocalIdle(localWindow: HTTP2FlowControlWindow) + case halfClosedRemoteLocalIdle(localWindow: HTTP2FlowControlWindow, requestVerb: String?) /// In the halfClosedRemoteLocalActive state, the remote peer has sent END_STREAM, and the local user has /// sent its HEADERS frame. This happens when we receive END_STREAM in the fullyOpen state, or when we @@ -237,16 +237,16 @@ struct HTTP2StreamStateMachine { /// Creates a new HTTP/2 stream for a stream that was created by receiving a PUSH_PROMISE frame /// on another stream. - init(receivedPushPromiseCreatingStreamID streamID: HTTP2StreamID, remoteInitialWindowSize: UInt32) { + init(receivedPushPromiseCreatingStreamID streamID: HTTP2StreamID, remoteInitialWindowSize: UInt32, requestVerb: String?) { self.streamID = streamID - self.state = .reservedRemote(remoteWindow: HTTP2FlowControlWindow(initialValue: remoteInitialWindowSize)) + self.state = .reservedRemote(remoteWindow: HTTP2FlowControlWindow(initialValue: remoteInitialWindowSize), requestVerb: requestVerb) } /// Creates a new HTTP/2 stream for a stream that was created by sending a PUSH_PROMISE frame on /// another stream. - init(sentPushPromiseCreatingStreamID streamID: HTTP2StreamID, localInitialWindowSize: UInt32) { + init(sentPushPromiseCreatingStreamID streamID: HTTP2StreamID, localInitialWindowSize: UInt32, requestVerb: String?) { self.streamID = streamID - self.state = .reservedLocal(localWindow: HTTP2FlowControlWindow(initialValue: localInitialWindowSize)) + self.state = .reservedLocal(localWindow: HTTP2FlowControlWindow(initialValue: localInitialWindowSize), requestVerb: requestVerb) } } @@ -284,13 +284,14 @@ extension HTTP2StreamStateMachine { switch self.state { case .idle(.client, localWindow: let localWindow, remoteWindow: let remoteWindow): let targetState: State - let localContentLength = validateContentLength ? try ContentLengthVerifier(headers) : .disabled + let localContentLength = validateContentLength ? try ContentLengthVerifier(headers, requestMethod: nil) : .disabled + let requestVerb = headers.first(name: ":method") if endStream { try localContentLength.endOfStream() - targetState = .halfClosedLocalPeerIdle(remoteWindow: remoteWindow) + targetState = .halfClosedLocalPeerIdle(remoteWindow: remoteWindow, requestVerb: requestVerb) } else { - targetState = .halfOpenLocalPeerIdle(localWindow: localWindow, localContentLength: localContentLength, remoteWindow: remoteWindow) + targetState = .halfOpenLocalPeerIdle(localWindow: localWindow, localContentLength: localContentLength, remoteWindow: remoteWindow, requestVerb: requestVerb) } let targetEffect: StreamStateChange = .streamCreated(.init(streamID: self.streamID, localStreamWindowSize: Int(localWindow), remoteStreamWindowSize: Int(remoteWindow))) @@ -299,9 +300,9 @@ extension HTTP2StreamStateMachine { targetState: targetState, targetEffect: targetEffect) - case .halfOpenRemoteLocalIdle(localWindow: let localWindow, remoteContentLength: let remoteContentLength, remoteWindow: let remoteWindow): + case .halfOpenRemoteLocalIdle(localWindow: let localWindow, remoteContentLength: let remoteContentLength, remoteWindow: let remoteWindow, requestVerb: let requestVerb): let targetState: State - let localContentLength = validateContentLength ? try ContentLengthVerifier(headers) : .disabled + let localContentLength = validateContentLength ? try ContentLengthVerifier(headers, requestMethod: requestVerb) : .disabled if endStream { try localContentLength.endOfStream() @@ -315,18 +316,18 @@ extension HTTP2StreamStateMachine { targetStateIfFinal: targetState, targetEffectIfFinal: nil) - case .halfOpenLocalPeerIdle(localWindow: _, localContentLength: let localContentLength, remoteWindow: let remoteWindow): + case .halfOpenLocalPeerIdle(localWindow: _, localContentLength: let localContentLength, remoteWindow: let remoteWindow, requestVerb: let requestVerb): try localContentLength.endOfStream() return self.processTrailers(headers, validateHeaderBlock: validateHeaderBlock, isEndStreamSet: endStream, - targetState: .halfClosedLocalPeerIdle(remoteWindow: remoteWindow), + targetState: .halfClosedLocalPeerIdle(remoteWindow: remoteWindow, requestVerb: requestVerb), targetEffect: nil) - case .reservedLocal(let localWindow): + case .reservedLocal(let localWindow, requestVerb: let requestVerb): let targetState: State let targetEffect: StreamStateChange - let localContentLength = validateContentLength ? try ContentLengthVerifier(headers) : .disabled + let localContentLength = validateContentLength ? try ContentLengthVerifier(headers, requestMethod: requestVerb) : .disabled if endStream { try localContentLength.endOfStream() @@ -350,10 +351,10 @@ extension HTTP2StreamStateMachine { targetState: .halfClosedLocalPeerActive(localRole: localRole, initiatedBy: .client, remoteContentLength: remoteContentLength, remoteWindow: remoteWindow), targetEffect: nil) - case .halfClosedRemoteLocalIdle(let localWindow): + case .halfClosedRemoteLocalIdle(let localWindow, requestVerb: let requestVerb): let targetState: State let targetEffect: StreamStateChange? - let localContentLength = validateContentLength ? try ContentLengthVerifier(headers) : .disabled + let localContentLength = validateContentLength ? try ContentLengthVerifier(headers, requestMethod: requestVerb) : .disabled if endStream { try localContentLength.endOfStream() @@ -414,13 +415,14 @@ extension HTTP2StreamStateMachine { switch self.state { case .idle(.server, localWindow: let localWindow, remoteWindow: let remoteWindow): let targetState: State - let remoteContentLength = validateContentLength ? try ContentLengthVerifier(headers) : .disabled + let requestVerb = headers.first(name: ":method") + let remoteContentLength = validateContentLength ? try ContentLengthVerifier(headers, requestMethod: nil) : .disabled if endStream { try remoteContentLength.endOfStream() - targetState = .halfClosedRemoteLocalIdle(localWindow: localWindow) + targetState = .halfClosedRemoteLocalIdle(localWindow: localWindow, requestVerb: requestVerb) } else { - targetState = .halfOpenRemoteLocalIdle(localWindow: localWindow, remoteContentLength: remoteContentLength, remoteWindow: remoteWindow) + targetState = .halfOpenRemoteLocalIdle(localWindow: localWindow, remoteContentLength: remoteContentLength, remoteWindow: remoteWindow, requestVerb: requestVerb) } let targetEffect: StreamStateChange = .streamCreated(.init(streamID: self.streamID, localStreamWindowSize: Int(localWindow), remoteStreamWindowSize: Int(remoteWindow))) @@ -429,9 +431,9 @@ extension HTTP2StreamStateMachine { targetState: targetState, targetEffect: targetEffect) - case .halfOpenLocalPeerIdle(localWindow: let localWindow, localContentLength: let localContentLength, remoteWindow: let remoteWindow): + case .halfOpenLocalPeerIdle(localWindow: let localWindow, localContentLength: let localContentLength, remoteWindow: let remoteWindow, requestVerb: let requestVerb): let targetState: State - let remoteContentLength = validateContentLength ? try ContentLengthVerifier(headers) : .disabled + let remoteContentLength = validateContentLength ? try ContentLengthVerifier(headers, requestMethod: requestVerb) : .disabled if endStream { try remoteContentLength.endOfStream() @@ -445,18 +447,18 @@ extension HTTP2StreamStateMachine { targetStateIfFinal: targetState, targetEffectIfFinal: nil) - case .halfOpenRemoteLocalIdle(localWindow: let localWindow, remoteContentLength: let remoteContentLength, remoteWindow: _): + case .halfOpenRemoteLocalIdle(localWindow: let localWindow, remoteContentLength: let remoteContentLength, remoteWindow: _, requestVerb: let requestVerb): try remoteContentLength.endOfStream() return self.processTrailers(headers, validateHeaderBlock: validateHeaderBlock, isEndStreamSet: endStream, - targetState: .halfClosedRemoteLocalIdle(localWindow: localWindow), + targetState: .halfClosedRemoteLocalIdle(localWindow: localWindow, requestVerb: requestVerb), targetEffect: nil) - case .reservedRemote(let remoteWindow): + case .reservedRemote(let remoteWindow, requestVerb: let requestVerb): let targetState: State let targetEffect: StreamStateChange - let remoteContentLength = validateContentLength ? try ContentLengthVerifier(headers) : .disabled + let remoteContentLength = validateContentLength ? try ContentLengthVerifier(headers, requestMethod: requestVerb) : .disabled if endStream { try remoteContentLength.endOfStream() @@ -480,10 +482,10 @@ extension HTTP2StreamStateMachine { targetState: .halfClosedRemoteLocalActive(localRole: localRole, initiatedBy: .client, localContentLength: localContentLength, localWindow: localWindow), targetEffect: nil) - case .halfClosedLocalPeerIdle(let remoteWindow): + case .halfClosedLocalPeerIdle(let remoteWindow, requestVerb: let requestVerb): let targetState: State let targetEffect: StreamStateChange? - let remoteContentLength = validateContentLength ? try ContentLengthVerifier(headers) : .disabled + let remoteContentLength = validateContentLength ? try ContentLengthVerifier(headers, requestMethod: requestVerb) : .disabled if endStream { try remoteContentLength.endOfStream() @@ -537,17 +539,17 @@ extension HTTP2StreamStateMachine { // // Valid data frames always have a stream effect, because they consume flow control windows. switch self.state { - case .halfOpenLocalPeerIdle(localWindow: var localWindow, localContentLength: var localContentLength, remoteWindow: let remoteWindow): + case .halfOpenLocalPeerIdle(localWindow: var localWindow, localContentLength: var localContentLength, remoteWindow: let remoteWindow, requestVerb: let requestVerb): try localWindow.consume(flowControlledBytes: flowControlledBytes) try localContentLength.receivedDataChunk(length: contentLength) let effect: StreamStateChange if endStream { try localContentLength.endOfStream() - self.state = .halfClosedLocalPeerIdle(remoteWindow: remoteWindow) + self.state = .halfClosedLocalPeerIdle(remoteWindow: remoteWindow, requestVerb: requestVerb) effect = .windowSizeChange(.init(streamID: self.streamID, localStreamWindowSize: nil, remoteStreamWindowSize: Int(remoteWindow))) } else { - self.state = .halfOpenLocalPeerIdle(localWindow: localWindow, localContentLength: localContentLength, remoteWindow: remoteWindow) + self.state = .halfOpenLocalPeerIdle(localWindow: localWindow, localContentLength: localContentLength, remoteWindow: remoteWindow, requestVerb: requestVerb) effect = .windowSizeChange(.init(streamID: self.streamID, localStreamWindowSize: Int(localWindow), remoteStreamWindowSize: Int(remoteWindow))) } @@ -607,17 +609,16 @@ extension HTTP2StreamStateMachine { // - fullyOpen, where we could be either a client or a server using a fully bi-directional stream. // - halfClosedLocalPeerActive, whe have completed our data, but the remote peer has more to send. switch self.state { - case .halfOpenRemoteLocalIdle(localWindow: let localWindow, remoteContentLength: var remoteContentLength, remoteWindow: var remoteWindow): + case .halfOpenRemoteLocalIdle(localWindow: let localWindow, remoteContentLength: var remoteContentLength, remoteWindow: var remoteWindow, requestVerb: let requestVerb): try remoteWindow.consume(flowControlledBytes: flowControlledBytes) try remoteContentLength.receivedDataChunk(length: contentLength) - let effect: StreamStateChange if endStream { try remoteContentLength.endOfStream() - self.state = .halfClosedRemoteLocalIdle(localWindow: localWindow) + self.state = .halfClosedRemoteLocalIdle(localWindow: localWindow, requestVerb: requestVerb) effect = .windowSizeChange(.init(streamID: self.streamID, localStreamWindowSize: Int(localWindow), remoteStreamWindowSize: nil)) } else { - self.state = .halfOpenRemoteLocalIdle(localWindow: localWindow, remoteContentLength: remoteContentLength, remoteWindow: remoteWindow) + self.state = .halfOpenRemoteLocalIdle(localWindow: localWindow, remoteContentLength: remoteContentLength, remoteWindow: remoteWindow, requestVerb: requestVerb) effect = .windowSizeChange(.init(streamID: self.streamID, localStreamWindowSize: Int(localWindow), remoteStreamWindowSize: Int(remoteWindow))) } @@ -682,7 +683,7 @@ extension HTTP2StreamStateMachine { // PUSH_PROMISE frames never have stream effects: they cannot create or close streams, or affect flow control state. switch self.state { case .fullyOpen(localRole: .server, localContentLength: _, remoteContentLength: _, localWindow: _, remoteWindow: _), - .halfOpenRemoteLocalIdle(localWindow: _, remoteContentLength: _, remoteWindow: _), + .halfOpenRemoteLocalIdle(localWindow: _, remoteContentLength: _, remoteWindow: _, requestVerb: _), .halfClosedRemoteLocalIdle(localWindow: _), .halfClosedRemoteLocalActive(localRole: .server, initiatedBy: .client, localContentLength: _, localWindow: _): return self.processRequestHeaders(headers, validateHeaderBlock: validateHeaderBlock, targetState: self.state, targetEffect: nil) @@ -709,7 +710,7 @@ extension HTTP2StreamStateMachine { // RFC 7540 ยง 6.6 forbids receiving PUSH_PROMISE frames on remotely-initiated streams. switch self.state { case .fullyOpen(localRole: .client, localContentLength: _, remoteContentLength: _, localWindow: _, remoteWindow: _), - .halfOpenLocalPeerIdle(localWindow: _, localContentLength: _, remoteWindow: _), + .halfOpenLocalPeerIdle(localWindow: _, localContentLength: _, remoteWindow: _, requestVerb: _), .halfClosedLocalPeerIdle(remoteWindow: _), .halfClosedLocalPeerActive(localRole: .client, initiatedBy: .client, remoteContentLength: _, remoteWindow: _): return self.processRequestHeaders(headers, validateHeaderBlock: validateHeaderBlock, targetState: self.state, targetEffect: nil) @@ -739,19 +740,19 @@ extension HTTP2StreamStateMachine { // can send no further data // - closed, because the entire stream is closed now switch self.state { - case .reservedRemote(remoteWindow: var remoteWindow): + case .reservedRemote(remoteWindow: var remoteWindow, requestVerb: let requestVerb): try remoteWindow.windowUpdate(by: windowIncrement) - self.state = .reservedRemote(remoteWindow: remoteWindow) + self.state = .reservedRemote(remoteWindow: remoteWindow, requestVerb: requestVerb) windowEffect = .windowSizeChange(.init(streamID: self.streamID, localStreamWindowSize: nil, remoteStreamWindowSize: Int(remoteWindow))) - case .halfOpenLocalPeerIdle(localWindow: let localWindow, localContentLength: let localContentLength, remoteWindow: var remoteWindow): + case .halfOpenLocalPeerIdle(localWindow: let localWindow, localContentLength: let localContentLength, remoteWindow: var remoteWindow, requestVerb: let requestVerb): try remoteWindow.windowUpdate(by: windowIncrement) - self.state = .halfOpenLocalPeerIdle(localWindow: localWindow, localContentLength: localContentLength, remoteWindow: remoteWindow) + self.state = .halfOpenLocalPeerIdle(localWindow: localWindow, localContentLength: localContentLength, remoteWindow: remoteWindow, requestVerb: requestVerb) windowEffect = .windowSizeChange(.init(streamID: self.streamID, localStreamWindowSize: Int(localWindow), remoteStreamWindowSize: Int(remoteWindow))) - case .halfOpenRemoteLocalIdle(localWindow: let localWindow, remoteContentLength: let remoteContentLength, remoteWindow: var remoteWindow): + case .halfOpenRemoteLocalIdle(localWindow: let localWindow, remoteContentLength: let remoteContentLength, remoteWindow: var remoteWindow, requestVerb: let requestVerb): try remoteWindow.windowUpdate(by: windowIncrement) - self.state = .halfOpenRemoteLocalIdle(localWindow: localWindow, remoteContentLength: remoteContentLength, remoteWindow: remoteWindow) + self.state = .halfOpenRemoteLocalIdle(localWindow: localWindow, remoteContentLength: remoteContentLength, remoteWindow: remoteWindow, requestVerb: requestVerb) windowEffect = .windowSizeChange(.init(streamID: self.streamID, localStreamWindowSize: Int(localWindow), remoteStreamWindowSize: Int(remoteWindow))) case .fullyOpen(localRole: let localRole, localContentLength: let localContentLength, remoteContentLength: let remoteContentLength, localWindow: let localWindow, remoteWindow: var remoteWindow): @@ -759,9 +760,9 @@ extension HTTP2StreamStateMachine { self.state = .fullyOpen(localRole: localRole, localContentLength: localContentLength, remoteContentLength: remoteContentLength, localWindow: localWindow, remoteWindow: remoteWindow) windowEffect = .windowSizeChange(.init(streamID: self.streamID, localStreamWindowSize: Int(localWindow), remoteStreamWindowSize: Int(remoteWindow))) - case .halfClosedLocalPeerIdle(remoteWindow: var remoteWindow): + case .halfClosedLocalPeerIdle(remoteWindow: var remoteWindow, requestVerb: let requestVerb): try remoteWindow.windowUpdate(by: windowIncrement) - self.state = .halfClosedLocalPeerIdle(remoteWindow: remoteWindow) + self.state = .halfClosedLocalPeerIdle(remoteWindow: remoteWindow, requestVerb: requestVerb) windowEffect = .windowSizeChange(.init(streamID: self.streamID, localStreamWindowSize: nil, remoteStreamWindowSize: Int(remoteWindow))) case .halfClosedLocalPeerActive(localRole: let localRole, initiatedBy: let initiatedBy, remoteContentLength: let remoteContentLength, remoteWindow: var remoteWindow): @@ -799,19 +800,19 @@ extension HTTP2StreamStateMachine { // it is possible that those frames may have been in flight when we were closing the stream, and so we shouldn't cause // the stream to explode simply for that reason. In this case, we just ignore the data. switch self.state { - case .reservedLocal(localWindow: var localWindow): + case .reservedLocal(localWindow: var localWindow, requestVerb: let requestVerb): try localWindow.windowUpdate(by: windowIncrement) - self.state = .reservedLocal(localWindow: localWindow) + self.state = .reservedLocal(localWindow: localWindow, requestVerb: requestVerb) windowEffect = .windowSizeChange(.init(streamID: self.streamID, localStreamWindowSize: Int(localWindow), remoteStreamWindowSize: nil)) - case .halfOpenLocalPeerIdle(localWindow: var localWindow, localContentLength: let localContentLength, remoteWindow: let remoteWindow): + case .halfOpenLocalPeerIdle(localWindow: var localWindow, localContentLength: let localContentLength, remoteWindow: let remoteWindow, requestVerb: let requestVerb): try localWindow.windowUpdate(by: windowIncrement) - self.state = .halfOpenLocalPeerIdle(localWindow: localWindow, localContentLength: localContentLength, remoteWindow: remoteWindow) + self.state = .halfOpenLocalPeerIdle(localWindow: localWindow, localContentLength: localContentLength, remoteWindow: remoteWindow, requestVerb: requestVerb) windowEffect = .windowSizeChange(.init(streamID: self.streamID, localStreamWindowSize: Int(localWindow), remoteStreamWindowSize: Int(remoteWindow))) - case .halfOpenRemoteLocalIdle(localWindow: var localWindow, remoteContentLength: let remoteContentLength, remoteWindow: let remoteWindow): + case .halfOpenRemoteLocalIdle(localWindow: var localWindow, remoteContentLength: let remoteContentLength, remoteWindow: let remoteWindow, requestVerb: let requestVerb): try localWindow.windowUpdate(by: windowIncrement) - self.state = .halfOpenRemoteLocalIdle(localWindow: localWindow, remoteContentLength: remoteContentLength, remoteWindow: remoteWindow) + self.state = .halfOpenRemoteLocalIdle(localWindow: localWindow, remoteContentLength: remoteContentLength, remoteWindow: remoteWindow, requestVerb: requestVerb) windowEffect = .windowSizeChange(.init(streamID: self.streamID, localStreamWindowSize: Int(localWindow), remoteStreamWindowSize: Int(remoteWindow))) case .fullyOpen(localRole: let localRole, localContentLength: let localContentLength, remoteContentLength: let remoteContentLength, localWindow: var localWindow, remoteWindow: let remoteWindow): @@ -819,9 +820,9 @@ extension HTTP2StreamStateMachine { self.state = .fullyOpen(localRole: localRole, localContentLength: localContentLength, remoteContentLength: remoteContentLength, localWindow: localWindow, remoteWindow: remoteWindow) windowEffect = .windowSizeChange(.init(streamID: self.streamID, localStreamWindowSize: Int(localWindow), remoteStreamWindowSize: Int(remoteWindow))) - case .halfClosedRemoteLocalIdle(localWindow: var localWindow): + case .halfClosedRemoteLocalIdle(localWindow: var localWindow, requestVerb: let requestVerb): try localWindow.windowUpdate(by: windowIncrement) - self.state = .halfClosedRemoteLocalIdle(localWindow: localWindow) + self.state = .halfClosedRemoteLocalIdle(localWindow: localWindow, requestVerb: requestVerb) windowEffect = .windowSizeChange(.init(streamID: self.streamID, localStreamWindowSize: Int(localWindow), remoteStreamWindowSize: nil)) case .halfClosedRemoteLocalActive(localRole: let localRole, initiatedBy: let initiatedBy, localContentLength: let localContentLength, localWindow: var localWindow): @@ -873,25 +874,25 @@ extension HTTP2StreamStateMachine { try remoteWindow.initialSizeChanged(by: change) self.state = .idle(localRole: role, localWindow: localWindow, remoteWindow: remoteWindow) - case .reservedRemote(remoteWindow: var remoteWindow): + case .reservedRemote(remoteWindow: var remoteWindow, requestVerb: let requestVerb): try remoteWindow.initialSizeChanged(by: change) - self.state = .reservedRemote(remoteWindow: remoteWindow) + self.state = .reservedRemote(remoteWindow: remoteWindow, requestVerb: requestVerb) - case .halfOpenLocalPeerIdle(localWindow: let localWindow, localContentLength: let localContentLength, remoteWindow: var remoteWindow): + case .halfOpenLocalPeerIdle(localWindow: let localWindow, localContentLength: let localContentLength, remoteWindow: var remoteWindow, requestVerb: let requestVerb): try remoteWindow.initialSizeChanged(by: change) - self.state = .halfOpenLocalPeerIdle(localWindow: localWindow, localContentLength: localContentLength, remoteWindow: remoteWindow) + self.state = .halfOpenLocalPeerIdle(localWindow: localWindow, localContentLength: localContentLength, remoteWindow: remoteWindow, requestVerb: requestVerb) - case .halfOpenRemoteLocalIdle(localWindow: let localWindow, remoteContentLength: let remoteContentLength, remoteWindow: var remoteWindow): + case .halfOpenRemoteLocalIdle(localWindow: let localWindow, remoteContentLength: let remoteContentLength, remoteWindow: var remoteWindow, requestVerb: let requestVerb): try remoteWindow.initialSizeChanged(by: change) - self.state = .halfOpenRemoteLocalIdle(localWindow: localWindow, remoteContentLength: remoteContentLength, remoteWindow: remoteWindow) + self.state = .halfOpenRemoteLocalIdle(localWindow: localWindow, remoteContentLength: remoteContentLength, remoteWindow: remoteWindow, requestVerb: requestVerb) case .fullyOpen(localRole: let localRole, localContentLength: let localContentLength, remoteContentLength: let remoteContentLength, localWindow: let localWindow, remoteWindow: var remoteWindow): try remoteWindow.initialSizeChanged(by: change) self.state = .fullyOpen(localRole: localRole, localContentLength: localContentLength, remoteContentLength: remoteContentLength, localWindow: localWindow, remoteWindow: remoteWindow) - case .halfClosedLocalPeerIdle(remoteWindow: var remoteWindow): + case .halfClosedLocalPeerIdle(remoteWindow: var remoteWindow, requestVerb: let requestVerb): try remoteWindow.initialSizeChanged(by: change) - self.state = .halfClosedLocalPeerIdle(remoteWindow: remoteWindow) + self.state = .halfClosedLocalPeerIdle(remoteWindow: remoteWindow, requestVerb: requestVerb) case .halfClosedLocalPeerActive(localRole: let localRole, initiatedBy: let initiatedBy, remoteContentLength: let remoteContentLength, remoteWindow: var remoteWindow): try remoteWindow.initialSizeChanged(by: change) @@ -917,25 +918,25 @@ extension HTTP2StreamStateMachine { try localWindow.initialSizeChanged(by: change) self.state = .idle(localRole: role, localWindow: localWindow, remoteWindow: remoteWindow) - case .reservedLocal(localWindow: var localWindow): + case .reservedLocal(localWindow: var localWindow, requestVerb: let requestVerb): try localWindow.initialSizeChanged(by: change) - self.state = .reservedLocal(localWindow: localWindow) + self.state = .reservedLocal(localWindow: localWindow, requestVerb: requestVerb) - case .halfOpenLocalPeerIdle(localWindow: var localWindow, localContentLength: let localContentLength, remoteWindow: let remoteWindow): + case .halfOpenLocalPeerIdle(localWindow: var localWindow, localContentLength: let localContentLength, remoteWindow: let remoteWindow, requestVerb: let requestVerb): try localWindow.initialSizeChanged(by: change) - self.state = .halfOpenLocalPeerIdle(localWindow: localWindow, localContentLength: localContentLength, remoteWindow: remoteWindow) + self.state = .halfOpenLocalPeerIdle(localWindow: localWindow, localContentLength: localContentLength, remoteWindow: remoteWindow, requestVerb: requestVerb) - case .halfOpenRemoteLocalIdle(localWindow: var localWindow, remoteContentLength: let remoteContentLength, remoteWindow: let remoteWindow): + case .halfOpenRemoteLocalIdle(localWindow: var localWindow, remoteContentLength: let remoteContentLength, remoteWindow: let remoteWindow, requestVerb: let requestVerb): try localWindow.initialSizeChanged(by: change) - self.state = .halfOpenRemoteLocalIdle(localWindow: localWindow, remoteContentLength: remoteContentLength, remoteWindow: remoteWindow) + self.state = .halfOpenRemoteLocalIdle(localWindow: localWindow, remoteContentLength: remoteContentLength, remoteWindow: remoteWindow, requestVerb: requestVerb) case .fullyOpen(localRole: let localRole, localContentLength: let localContentLength, remoteContentLength: let remoteContentLength, localWindow: var localWindow, remoteWindow: let remoteWindow): try localWindow.initialSizeChanged(by: change) self.state = .fullyOpen(localRole: localRole, localContentLength: localContentLength, remoteContentLength: remoteContentLength, localWindow: localWindow, remoteWindow: remoteWindow) - case .halfClosedRemoteLocalIdle(localWindow: var localWindow): + case .halfClosedRemoteLocalIdle(localWindow: var localWindow, requestVerb: let requestVerb): try localWindow.initialSizeChanged(by: change) - self.state = .halfClosedRemoteLocalIdle(localWindow: localWindow) + self.state = .halfClosedRemoteLocalIdle(localWindow: localWindow, requestVerb: requestVerb) case .halfClosedRemoteLocalActive(localRole: let localRole, initiatedBy: let initiatedBy, localContentLength: let localContentLength, localWindow: var localWindow): try localWindow.initialSizeChanged(by: change) diff --git a/Sources/NIOHTTP2Server/main.swift b/Sources/NIOHTTP2Server/main.swift index f86db3be..908d90fd 100644 --- a/Sources/NIOHTTP2Server/main.swift +++ b/Sources/NIOHTTP2Server/main.swift @@ -25,11 +25,23 @@ final class HTTP1TestServer: ChannelInboundHandler { public typealias InboundIn = HTTPServerRequestPart public typealias OutboundOut = HTTPServerResponsePart + private var head: HTTPRequestHead? = nil + public func channelRead(context: ChannelHandlerContext, data: NIOAny) { - guard case .end = self.unwrapInboundIn(data) else { + switch self.unwrapInboundIn(data) { + case .head(let head): + self.head = head + return + case .body: return + case .end: + // Deliberate fallthrough + () } + let requestHead = self.head! + self.head = nil + // Insert an event loop tick here. This more accurately represents real workloads in SwiftNIO, which will not // re-entrantly write their response frames. context.eventLoop.execute { @@ -39,9 +51,12 @@ final class HTTP1TestServer: ChannelInboundHandler { headers.add(name: "x-stream-id", value: String(Int(streamID))) context.channel.write(self.wrapOutboundOut(HTTPServerResponsePart.head(HTTPResponseHead(version: .init(major: 2, minor: 0), status: .ok, headers: headers))), promise: nil) - var buffer = context.channel.allocator.buffer(capacity: 12) - buffer.writeStaticString("hello") - context.channel.write(self.wrapOutboundOut(HTTPServerResponsePart.body(.byteBuffer(buffer))), promise: nil) + if requestHead.method != .HEAD { + var buffer = context.channel.allocator.buffer(capacity: 12) + buffer.writeStaticString("hello") + context.channel.write(self.wrapOutboundOut(HTTPServerResponsePart.body(.byteBuffer(buffer))), promise: nil) + } + return context.channel.writeAndFlush(self.wrapOutboundOut(HTTPServerResponsePart.end(nil))) }.whenComplete { _ in context.close(promise: nil) diff --git a/Tests/NIOHTTP2Tests/ConnectionStateMachineTests+XCTest.swift b/Tests/NIOHTTP2Tests/ConnectionStateMachineTests+XCTest.swift index 7777408b..17ca96d2 100644 --- a/Tests/NIOHTTP2Tests/ConnectionStateMachineTests+XCTest.swift +++ b/Tests/NIOHTTP2Tests/ConnectionStateMachineTests+XCTest.swift @@ -145,6 +145,12 @@ extension ConnectionStateMachineTests { ("testWeTolerateOneStreamBeingResetTwice", testWeTolerateOneStreamBeingResetTwice), ("testReceivedAltServiceFramesAreIgnored", testReceivedAltServiceFramesAreIgnored), ("testReceivedOriginFramesAreIgnored", testReceivedOriginFramesAreIgnored), + ("testContentLengthForStatus304", testContentLengthForStatus304), + ("testContentLengthForStatus304Failure", testContentLengthForStatus304Failure), + ("testContentLengthForMethodHead", testContentLengthForMethodHead), + ("testContentLengthForHeadFailure", testContentLengthForHeadFailure), + ("testPushHeadRequestFailure", testPushHeadRequestFailure), + ("testPushHeadRequest", testPushHeadRequest), ] } } diff --git a/Tests/NIOHTTP2Tests/ConnectionStateMachineTests.swift b/Tests/NIOHTTP2Tests/ConnectionStateMachineTests.swift index 4f4023aa..138db197 100644 --- a/Tests/NIOHTTP2Tests/ConnectionStateMachineTests.swift +++ b/Tests/NIOHTTP2Tests/ConnectionStateMachineTests.swift @@ -3001,6 +3001,149 @@ class ConnectionStateMachineTests: XCTestCase { assertIgnored(self.client.receiveOrigin(origins: ["one", "two"])) assertIgnored(self.server.receiveOrigin(origins: ["one", "two"])) } + + func testContentLengthForStatus304() { + let streamOne = HTTP2StreamID(1) + + self.server = .init(role: .server) + self.client = .init(role: .client) + + self.exchangePreamble() + + let responseHeaders = HPACKHeaders([(":status", "304"), ("content-length", "25")]) + + // Set up the connection + assertSucceeds(self.client.sendHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: true)) + assertSucceeds(self.server.receiveHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: true)) + + // The server responds + assertSucceeds(self.server.sendHeaders(streamID: streamOne, headers: responseHeaders, isEndStreamSet: false)) + assertSucceeds(self.client.receiveHeaders(streamID: streamOne, headers: responseHeaders, isEndStreamSet: false)) + + // Send in 0 bytes over two sets + assertSucceeds(self.server.sendData(streamID: streamOne, contentLength: 0, flowControlledBytes: 0, isEndStreamSet: true)) + assertSucceeds(self.client.receiveData(streamID: streamOne, contentLength: 0, flowControlledBytes: 0, isEndStreamSet: true)) + } + + func testContentLengthForStatus304Failure() { + let streamOne = HTTP2StreamID(1) + + self.server = .init(role: .server) + self.client = .init(role: .client) + + self.exchangePreamble() + + let responseHeaders = HPACKHeaders([(":status", "304"), ("content-length", "25")]) + + // Set up the connection + assertSucceeds(self.client.sendHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: true)) + assertSucceeds(self.server.receiveHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: true)) + + // The server responds + assertSucceeds(self.server.sendHeaders(streamID: streamOne, headers: responseHeaders, isEndStreamSet: false)) + assertSucceeds(self.client.receiveHeaders(streamID: streamOne, headers: responseHeaders, isEndStreamSet: false)) + + // Send in 1 byte over one frame + assertStreamError(type: HTTP2ErrorCode.protocolError, self.server.sendData(streamID: streamOne, contentLength: 1, flowControlledBytes: 1, isEndStreamSet: true)) + assertStreamError(type: HTTP2ErrorCode.protocolError, self.client.receiveData(streamID: streamOne, contentLength: 1, flowControlledBytes: 1, isEndStreamSet: true)) + } + + func testContentLengthForMethodHead() { + let streamOne = HTTP2StreamID(1) + + self.server = .init(role: .server) + self.client = .init(role: .client) + + self.exchangePreamble() + + let requestHeaders = HPACKHeaders([(":method", "HEAD"), (":authority", "localhost"), (":scheme", "https"), (":path", "/"), ("user-agent", "test")]) + let responseHeaders = HPACKHeaders([(":status", "200"), ("content-length", "25")]) + + // Set up the connection + assertSucceeds(self.client.sendHeaders(streamID: streamOne, headers: requestHeaders, isEndStreamSet: true)) + assertSucceeds(self.server.receiveHeaders(streamID: streamOne, headers: requestHeaders, isEndStreamSet: true)) + + // The server responds + assertSucceeds(self.server.sendHeaders(streamID: streamOne, headers: responseHeaders, isEndStreamSet: false)) + assertSucceeds(self.client.receiveHeaders(streamID: streamOne, headers: responseHeaders, isEndStreamSet: false)) + + // Send in 0 bytes over one frame + assertSucceeds(self.server.sendData(streamID: streamOne, contentLength: 0, flowControlledBytes: 0, isEndStreamSet: true)) + assertSucceeds(self.client.receiveData(streamID: streamOne, contentLength: 0, flowControlledBytes: 0, isEndStreamSet: true)) + } + + func testContentLengthForHeadFailure() { + let streamOne = HTTP2StreamID(1) + + self.server = .init(role: .server) + self.client = .init(role: .client) + + self.exchangePreamble() + + let requestHeaders = HPACKHeaders([(":method", "HEAD"), (":authority", "localhost"), (":scheme", "https"), (":path", "/"), ("user-agent", "test")]) + let responseHeaders = HPACKHeaders([(":status", "200"), ("content-length", "25")]) + + // Set up the connection + assertSucceeds(self.client.sendHeaders(streamID: streamOne, headers: requestHeaders, isEndStreamSet: true)) + assertSucceeds(self.server.receiveHeaders(streamID: streamOne, headers: requestHeaders, isEndStreamSet: true)) + + // The server responds + assertSucceeds(self.server.sendHeaders(streamID: streamOne, headers: responseHeaders, isEndStreamSet: false)) + assertSucceeds(self.client.receiveHeaders(streamID: streamOne, headers: responseHeaders, isEndStreamSet: false)) + + // Send in 1 byte over 1 frame + assertStreamError(type: HTTP2ErrorCode.protocolError, self.server.sendData(streamID: streamOne, contentLength: 1, flowControlledBytes: 1, isEndStreamSet: true)) + assertStreamError(type: HTTP2ErrorCode.protocolError, self.client.receiveData(streamID: streamOne, contentLength: 1, flowControlledBytes: 1, isEndStreamSet: true)) + } + + func testPushHeadRequestFailure() { + let streamOne = HTTP2StreamID(1) + let streamTwo = HTTP2StreamID(2) + + self.exchangePreamble() + + let requestHeaders = HPACKHeaders([(":method", "HEAD"), (":authority", "localhost"), (":scheme", "https"), (":path", "/"), ("user-agent", "test")]) + let responseHeaders = HPACKHeaders([(":status", "200"), ("content-length", "25")]) + + assertSucceeds(self.client.sendHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: true)) + assertSucceeds(self.server.receiveHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: true)) + + // Server can push right away + assertSucceeds(self.server.sendPushPromise(originalStreamID: streamOne, childStreamID: streamTwo, headers: requestHeaders)) + assertSucceeds(self.client.receivePushPromise(originalStreamID: streamOne, childStreamID: streamTwo, headers: requestHeaders)) + + // The server responds + assertSucceeds(self.server.sendHeaders(streamID: streamTwo, headers: responseHeaders, isEndStreamSet: false)) + assertSucceeds(self.client.receiveHeaders(streamID: streamTwo, headers: responseHeaders, isEndStreamSet: false)) + + // Send in 1 byte over one frame + assertStreamError(type: HTTP2ErrorCode.protocolError, self.server.sendData(streamID: streamTwo, contentLength: 1, flowControlledBytes: 1, isEndStreamSet: true)) + assertStreamError(type: HTTP2ErrorCode.protocolError, self.client.receiveData(streamID: streamTwo, contentLength: 1, flowControlledBytes: 1, isEndStreamSet: true)) + } + + func testPushHeadRequest() { + let streamOne = HTTP2StreamID(1) + let streamTwo = HTTP2StreamID(2) + + self.exchangePreamble() + + let requestHeaders = HPACKHeaders([(":method", "HEAD"), (":authority", "localhost"), (":scheme", "https"), (":path", "/"), ("user-agent", "test")]) + let responseHeaders = HPACKHeaders([(":status", "200"), ("content-length", "25")]) + + assertSucceeds(self.client.sendHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: true)) + assertSucceeds(self.server.receiveHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: true)) + + // Server can push right away + assertSucceeds(self.server.sendPushPromise(originalStreamID: streamOne, childStreamID: streamTwo, headers: requestHeaders)) + assertSucceeds(self.client.receivePushPromise(originalStreamID: streamOne, childStreamID: streamTwo, headers: requestHeaders)) + + // The server responds + assertSucceeds(self.server.sendHeaders(streamID: streamTwo, headers: responseHeaders, isEndStreamSet: false)) + assertSucceeds(self.client.receiveHeaders(streamID: streamTwo, headers: responseHeaders, isEndStreamSet: false)) + + // Send in 0 bytes over one frame + assertSucceeds(self.client.receiveData(streamID: streamTwo, contentLength: 0, flowControlledBytes: 0, isEndStreamSet: true)) + } } diff --git a/Tests/NIOHTTP2Tests/ContentLengthVerifierTests+XCTest.swift b/Tests/NIOHTTP2Tests/ContentLengthVerifierTests+XCTest.swift index 5f334751..c3e6829e 100644 --- a/Tests/NIOHTTP2Tests/ContentLengthVerifierTests+XCTest.swift +++ b/Tests/NIOHTTP2Tests/ContentLengthVerifierTests+XCTest.swift @@ -34,6 +34,8 @@ extension ContentLengthVerifierTests { ("testMinIntLengthHeaderDoesntPanic", testMinIntLengthHeaderDoesntPanic), ("testMaxIntLengthHeaderDoesntPanic", testMaxIntLengthHeaderDoesntPanic), ("testInvalidLengthHeaderValuesThrow", testInvalidLengthHeaderValuesThrow), + ("testContentLengthVerifier_whenResponseStatusIs304", testContentLengthVerifier_whenResponseStatusIs304), + ("testContentLengthVerifier_whenRequestMethodIsHead", testContentLengthVerifier_whenRequestMethodIsHead), ] } } diff --git a/Tests/NIOHTTP2Tests/ContentLengthVerifierTests.swift b/Tests/NIOHTTP2Tests/ContentLengthVerifierTests.swift index 44f600bc..35665382 100644 --- a/Tests/NIOHTTP2Tests/ContentLengthVerifierTests.swift +++ b/Tests/NIOHTTP2Tests/ContentLengthVerifierTests.swift @@ -19,83 +19,96 @@ import NIOHPACK class ContentLengthVerifierTests: XCTestCase { func testDuplicatedLengthHeadersPermitted() throws { var headers = HPACKHeaders([("Host", "apple.com"), ("content-length", "1834"), ("User-Agent", "myCoolClient/1.0")]) - XCTAssertNoThrow(try ContentLengthVerifier(headers)) - var verifier = try assertNoThrowWithValue(try ContentLengthVerifier(headers)) + XCTAssertNoThrow(try ContentLengthVerifier(headers, requestMethod: nil)) + var verifier = try assertNoThrowWithValue(try ContentLengthVerifier(headers, requestMethod: nil)) XCTAssertEqual(1834, verifier.expectedContentLength) headers.add(contentsOf: [("content-length", "1834")]) - verifier = try assertNoThrowWithValue(try ContentLengthVerifier(headers)) + verifier = try assertNoThrowWithValue(try ContentLengthVerifier(headers, requestMethod: nil)) XCTAssertEqual(1834, verifier.expectedContentLength) headers.add(contentsOf: [("content-length", "1834")]) - XCTAssertNoThrow(try ContentLengthVerifier(headers)) - verifier = try assertNoThrowWithValue(try ContentLengthVerifier(headers)) + XCTAssertNoThrow(try ContentLengthVerifier(headers, requestMethod: nil)) + verifier = try assertNoThrowWithValue(try ContentLengthVerifier(headers, requestMethod: nil)) XCTAssertEqual(1834, verifier.expectedContentLength) headers.add(contentsOf: [("Content-Length", "1834")]) - XCTAssertNoThrow(try ContentLengthVerifier(headers)) - verifier = try assertNoThrowWithValue(try ContentLengthVerifier(headers)) + XCTAssertNoThrow(try ContentLengthVerifier(headers, requestMethod: nil)) + verifier = try assertNoThrowWithValue(try ContentLengthVerifier(headers, requestMethod: nil)) XCTAssertEqual(1834, verifier.expectedContentLength) } func testDuplicatedConflictingLengthHeadersThrow() throws { var headers = HPACKHeaders([("Host", "apple.com"), ("content-length", "1834"), ("User-Agent", "myCoolClient/1.0")]) - let verifier = try assertNoThrowWithValue(try ContentLengthVerifier(headers)) + let verifier = try assertNoThrowWithValue(try ContentLengthVerifier(headers, requestMethod: nil)) XCTAssertEqual(1834, verifier.expectedContentLength) headers.add(contentsOf: [("Content-Length", "4381")]) - XCTAssertThrowsError(try ContentLengthVerifier(headers)) { error in + XCTAssertThrowsError(try ContentLengthVerifier(headers, requestMethod: nil)) { error in XCTAssertTrue(error is NIOHTTP2Errors.ContentLengthHeadersMismatch) } } func testNumericallyEquivalentButConflictingLengthHeadersThrow() throws { var headers = HPACKHeaders([("Host", "apple.com"), ("content-length", "1834"), ("User-Agent", "myCoolClient/1.0")]) - let verifier = try assertNoThrowWithValue(try ContentLengthVerifier(headers)) + let verifier = try assertNoThrowWithValue(try ContentLengthVerifier(headers, requestMethod: nil)) XCTAssertEqual(1834, verifier.expectedContentLength) headers.add(contentsOf: [("Content-Length", "01834")]) - XCTAssertThrowsError(try ContentLengthVerifier(headers)) { error in + XCTAssertThrowsError(try ContentLengthVerifier(headers, requestMethod: nil)) { error in XCTAssertTrue(error is NIOHTTP2Errors.ContentLengthHeadersMismatch) } } func testNegativeLengthHeaderThrows() throws { let headers = HPACKHeaders([("Host", "apple.com"), ("content-length", "-1"), ("User-Agent", "myCoolClient/1.0")]) - XCTAssertThrowsError(try ContentLengthVerifier(headers)) { error in + XCTAssertThrowsError(try ContentLengthVerifier(headers, requestMethod: nil)) { error in XCTAssertTrue(error is NIOHTTP2Errors.ContentLengthHeaderNegative) } } func testMinIntLengthHeaderDoesntPanic() throws { let headers = HPACKHeaders([("Host", "apple.com"), ("content-length", String(Int.min)), ("User-Agent", "myCoolClient/1.0")]) - XCTAssertThrowsError(try ContentLengthVerifier(headers)) { error in + XCTAssertThrowsError(try ContentLengthVerifier(headers, requestMethod: nil)) { error in XCTAssertTrue(error is NIOHTTP2Errors.ContentLengthHeaderNegative) } } func testMaxIntLengthHeaderDoesntPanic() throws { let headers = HPACKHeaders([("Host", "apple.com"), ("content-length", String(Int.max)), ("User-Agent", "myCoolClient/1.0")]) - let verifier = try assertNoThrowWithValue(try ContentLengthVerifier(headers)) + let verifier = try assertNoThrowWithValue(try ContentLengthVerifier(headers, requestMethod: nil)) XCTAssertEqual(Int.max, verifier.expectedContentLength) } func testInvalidLengthHeaderValuesThrow() throws { var headers = HPACKHeaders([("Host", "apple.com"), ("content-length", "0xFF"), ("User-Agent", "myCoolClient/1.0")]) - XCTAssertThrowsError(try ContentLengthVerifier(headers)) { error in + XCTAssertThrowsError(try ContentLengthVerifier(headers, requestMethod: nil)) { error in XCTAssertTrue(error is NIOHTTP2Errors.ContentLengthHeaderMalformedValue) } // Int.min - 1 headers = HPACKHeaders([("Host", "apple.com"), ("content-length", "-9223372036854775809"), ("User-Agent", "myCoolClient/1.0")]) - XCTAssertThrowsError(try ContentLengthVerifier(headers)) { error in + XCTAssertThrowsError(try ContentLengthVerifier(headers, requestMethod: nil)) { error in XCTAssertTrue(error is NIOHTTP2Errors.ContentLengthHeaderMalformedValue) } // Int.max + 1 headers = HPACKHeaders([("Host", "apple.com"), ("content-length", "9223372036854775809"), ("User-Agent", "myCoolClient/1.0")]) - XCTAssertThrowsError(try ContentLengthVerifier(headers)) { error in + XCTAssertThrowsError(try ContentLengthVerifier(headers, requestMethod: nil)) { error in XCTAssertTrue(error is NIOHTTP2Errors.ContentLengthHeaderMalformedValue) } } + + func testContentLengthVerifier_whenResponseStatusIs304() throws { + let headers = HPACKHeaders([(":status", "304"), ("Host", "apple.com"), ("content-length", "1834"), ("User-Agent", "myCoolClient/1.0")]) + let verifier = try assertNoThrowWithValue(try ContentLengthVerifier(headers, requestMethod: "GET")) + XCTAssertEqual(0, verifier.expectedContentLength) + } + + + func testContentLengthVerifier_whenRequestMethodIsHead() throws { + let headers = HPACKHeaders([("Host", "apple.com"), ("content-length", "1834"), ("User-Agent", "myCoolClient/1.0")]) + let verifier = try assertNoThrowWithValue(try ContentLengthVerifier(headers, requestMethod: "HEAD")) + XCTAssertEqual(0, verifier.expectedContentLength) + } }