Skip to content

Commit

Permalink
Remove ping/pong handling, tungstenite does it automatically
Browse files Browse the repository at this point in the history
  • Loading branch information
aqrln committed Jan 16, 2025
1 parent 3b4aea2 commit a7afff8
Showing 1 changed file with 22 additions and 69 deletions.
91 changes: 22 additions & 69 deletions quaint/src/connector/postgres/native/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
};

use bytes::Bytes;
use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
use futures::{FutureExt, Sink, SinkExt, Stream};
use pin_project::pin_project;
use postgres_native_tls::TlsConnector;
use prisma_metrics::WithMetricsInstrumentation;
Expand Down Expand Up @@ -114,33 +114,17 @@ impl From<TungsteniteError> for error::Error {
struct WsTunnel(#[pin] StreamReader<WsBytesStream, Bytes>);

#[pin_project]
struct WsBytesStream {
state: WsBytesStreamState,
#[pin]
inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
}

enum WsBytesStreamState {
Reading,
SendingPong(Bytes),
}
struct WsBytesStream(#[pin] WebSocketStream<MaybeTlsStream<TcpStream>>);

impl WsTunnel {
fn new(stream: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
WsTunnel(StreamReader::new(WsBytesStream::new(stream)))
WsTunnel(StreamReader::new(WsBytesStream(stream)))
}
}

impl WsBytesStream {
fn new(inner: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
WsBytesStream {
state: WsBytesStreamState::Reading,
inner,
}
}

fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut WebSocketStream<MaybeTlsStream<TcpStream>>> {
self.project().inner
self.project().0
}
}

Expand All @@ -162,7 +146,7 @@ impl AsyncBufRead for WsTunnel {

impl AsyncWrite for WsTunnel {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
let stream = &mut self.get_mut().0.get_mut().inner;
let stream = &mut self.get_mut().0.get_mut().0;
ready!(stream
.poll_ready_unpin(cx)
.map_err(|err| IoError::new(IoErrorKind::Other, err)))?;
Expand Down Expand Up @@ -195,58 +179,27 @@ impl Stream for WsBytesStream {
type Item = Result<Bytes, IoError>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();

match this.state {
WsBytesStreamState::Reading => match this.inner.poll_next_unpin(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Ok(msg))) => match msg {
Message::Binary(b) => Poll::Ready(Some(Ok(b))),
Message::Close(_) => Poll::Ready(None),
Message::Text(_) => Poll::Ready(Some(Err(IoError::new(
IoErrorKind::Other,
"TCP tunneling requires binary frames, got text",
)))),
Message::Ping(b) => {
this.state = WsBytesStreamState::SendingPong(b);
cx.waker().wake_by_ref();
Poll::Pending
}
Message::Pong(_) => {
cx.waker().wake_by_ref();
Poll::Pending
}
Message::Frame(_) => {
Poll::Ready(Some(Err(IoError::new(IoErrorKind::Other, "unexpected raw frame"))))
}
},
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(IoError::new(IoErrorKind::Other, err)))),
},

WsBytesStreamState::SendingPong(_) => {
if let Err(err) = ready!(this.inner.poll_ready_unpin(cx)) {
return Poll::Ready(Some(Err(IoError::new(IoErrorKind::Other, err))));
}

let WsBytesStreamState::SendingPong(b) =
std::mem::replace(&mut this.state, WsBytesStreamState::Reading)
else {
unreachable!()
};

match this.inner.start_send_unpin(Message::Pong(b)) {
Ok(()) => {
cx.waker().wake_by_ref();
Poll::Pending
}
Err(err) => Poll::Ready(Some(Err(IoError::new(IoErrorKind::Other, err)))),
match self.get_pin_mut().poll_next(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Ok(msg))) => match msg {
Message::Binary(b) => Poll::Ready(Some(Ok(b))),
Message::Close(_) => Poll::Ready(None),
Message::Text(_) => Poll::Ready(Some(Err(IoError::new(
IoErrorKind::Other,
"TCP tunneling requires binary frames, got text",
)))),
Message::Ping(_) | Message::Pong(_) => {
cx.waker().wake_by_ref();
Poll::Pending
}
}
Message::Frame(_) => Poll::Ready(Some(Err(IoError::new(IoErrorKind::Other, "unexpected raw frame")))),
},
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(IoError::new(IoErrorKind::Other, err)))),
}
}

fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
self.0.size_hint()
}
}

0 comments on commit a7afff8

Please sign in to comment.