Skip to content

Commit

Permalink
Actually poll write till completion
Browse files Browse the repository at this point in the history
  • Loading branch information
aqrln committed Jan 16, 2025
1 parent db22e55 commit eeba3e5
Showing 1 changed file with 37 additions and 13 deletions.
50 changes: 37 additions & 13 deletions quaint/src/connector/postgres/native/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,26 @@ impl From<TungsteniteError> for error::Error {
}

#[pin_project]
struct WsTunnel(#[pin] StreamReader<WsBytesStream, Bytes>);
struct WsTunnel {
#[pin]
inner: StreamReader<WsBytesStream, Bytes>,
write_state: WriteState,
}

enum WriteState {
Free,
Writing(usize),
}

#[pin_project]
struct WsBytesStream(#[pin] WebSocketStream<MaybeTlsStream<TcpStream>>);

impl WsTunnel {
fn new(stream: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
WsTunnel(StreamReader::new(WsBytesStream(stream)))
WsTunnel {
inner: StreamReader::new(WsBytesStream(stream)),
write_state: WriteState::Free,
}
}
}

Expand All @@ -130,34 +142,46 @@ impl WsBytesStream {

impl AsyncRead for WsTunnel {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
self.project().0.poll_read(cx, buf)
self.project().inner.poll_read(cx, buf)
}
}

impl AsyncBufRead for WsTunnel {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
self.project().0.poll_fill_buf(cx)
self.project().inner.poll_fill_buf(cx)
}

fn consume(self: Pin<&mut Self>, amt: usize) {
self.project().0.consume(amt)
self.project().inner.consume(amt)
}
}

impl AsyncWrite for WsTunnel {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
let stream = &mut self.get_mut().0.get_mut().0;
let this = self.get_mut();
let sink = &mut this.inner.get_mut().0;
let to_io_err = |err| IoError::new(IoErrorKind::Other, err);
ready!(stream.poll_ready_unpin(cx).map_err(to_io_err))?;
stream
.start_send_unpin(Message::Binary(Bytes::copy_from_slice(buf)))
.map_err(to_io_err)?;
Poll::Ready(Ok(buf.len()))

match this.write_state {
WriteState::Free => {
ready!(sink.poll_ready_unpin(cx)).map_err(to_io_err)?;
sink.start_send_unpin(Message::Binary(Bytes::copy_from_slice(buf)))
.map_err(to_io_err)?;
this.write_state = WriteState::Writing(buf.len());
cx.waker().wake_by_ref();
Poll::Pending
}
WriteState::Writing(len) => {
ready!(sink.poll_flush_unpin(cx)).map_err(to_io_err)?;
this.write_state = WriteState::Free;
Poll::Ready(Ok(len))
}
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
self.project()
.0
.inner
.get_pin_mut()
.get_pin_mut()
.poll_flush(cx)
Expand All @@ -166,7 +190,7 @@ impl AsyncWrite for WsTunnel {

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
self.project()
.0
.inner
.get_pin_mut()
.get_pin_mut()
.poll_close(cx)
Expand Down

0 comments on commit eeba3e5

Please sign in to comment.