diff --git a/quaint/src/connector/postgres/native/websocket.rs b/quaint/src/connector/postgres/native/websocket.rs index 631094a9315..6c69ac4f5a7 100644 --- a/quaint/src/connector/postgres/native/websocket.rs +++ b/quaint/src/connector/postgres/native/websocket.rs @@ -111,14 +111,26 @@ impl From for error::Error { } #[pin_project] -struct WsTunnel(#[pin] StreamReader); +struct WsTunnel { + #[pin] + inner: StreamReader, + write_state: WriteState, +} + +enum WriteState { + Free, + Writing(usize), +} #[pin_project] struct WsBytesStream(#[pin] WebSocketStream>); impl WsTunnel { fn new(stream: WebSocketStream>) -> Self { - WsTunnel(StreamReader::new(WsBytesStream(stream))) + WsTunnel { + inner: StreamReader::new(WsBytesStream(stream)), + write_state: WriteState::Free, + } } } @@ -130,34 +142,46 @@ impl WsBytesStream { impl AsyncRead for WsTunnel { fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { - self.project().0.poll_read(cx, buf) + self.project().inner.poll_read(cx, buf) } } impl AsyncBufRead for WsTunnel { fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().0.poll_fill_buf(cx) + self.project().inner.poll_fill_buf(cx) } fn consume(self: Pin<&mut Self>, amt: usize) { - self.project().0.consume(amt) + self.project().inner.consume(amt) } } impl AsyncWrite for WsTunnel { fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { - let stream = &mut self.get_mut().0.get_mut().0; + let this = self.get_mut(); + let sink = &mut this.inner.get_mut().0; let to_io_err = |err| IoError::new(IoErrorKind::Other, err); - ready!(stream.poll_ready_unpin(cx).map_err(to_io_err))?; - stream - .start_send_unpin(Message::Binary(Bytes::copy_from_slice(buf))) - .map_err(to_io_err)?; - Poll::Ready(Ok(buf.len())) + + match this.write_state { + WriteState::Free => { + ready!(sink.poll_ready_unpin(cx)).map_err(to_io_err)?; + sink.start_send_unpin(Message::Binary(Bytes::copy_from_slice(buf))) + .map_err(to_io_err)?; + this.write_state = WriteState::Writing(buf.len()); + cx.waker().wake_by_ref(); + Poll::Pending + } + WriteState::Writing(len) => { + ready!(sink.poll_flush_unpin(cx)).map_err(to_io_err)?; + this.write_state = WriteState::Free; + Poll::Ready(Ok(len)) + } + } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project() - .0 + .inner .get_pin_mut() .get_pin_mut() .poll_flush(cx) @@ -166,7 +190,7 @@ impl AsyncWrite for WsTunnel { fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project() - .0 + .inner .get_pin_mut() .get_pin_mut() .poll_close(cx)