Skip to content

Commit

Permalink
fix(dgw): Allow video-streamer to properly send code 1000 on stream
Browse files Browse the repository at this point in the history
finishes
  • Loading branch information
irvingoujAtDevolution committed Jan 11, 2025
1 parent f891342 commit dfe4cd9
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 17 deletions.
14 changes: 14 additions & 0 deletions crates/transport/src/ws.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
Expand Down Expand Up @@ -129,3 +130,16 @@ fn to_io_result<E: std::error::Error + Send + Sync + 'static>(res: Result<(), E>
Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)),
}
}

pub trait CloseStream {
fn close_stream(&mut self) -> impl std::future::Future<Output = ()> + Send;
}

impl<S> CloseStream for WsStream<S>
where
S: CloseStream + Send,
{
async fn close_stream(&mut self) {
self.inner.close_stream().await
}
}
2 changes: 1 addition & 1 deletion crates/video-streamer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ webm-iterable = { version = "0.6", features = ["futures"] }
cadeau = { version = "0.5", features = ["dlopen"] }
thiserror = "2"
num_cpus = "1.16"
transport = { path = "../transport" }

[dev-dependencies]
tracing-subscriber = "0.3"
Expand All @@ -35,7 +36,6 @@ tokio = { version = "1.42", features = [
axum = { version = "0.7", features = ["ws"] }
futures = "0.3"
tokio-tungstenite = "0.24"
transport = { path = "../transport" }

[lints]
workspace = true
9 changes: 6 additions & 3 deletions crates/video-streamer/src/streamer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use tag_writers::{EncodeWriterConfig, HeaderWriter, WriterResult};
use tokio::sync::{mpsc, oneshot::error::RecvError, Mutex, Notify};
use tokio_util::codec::Framed;
use tracing::Instrument;
use transport::CloseStream;
use webm_iterable::{
errors::{TagIteratorError, TagWriterError},
matroska_spec::{Master, MatroskaSpec},
Expand All @@ -26,8 +27,8 @@ use crate::{reopenable::Reopenable, StreamingConfig};

#[instrument(skip_all)]
pub fn webm_stream(
output_stream: impl tokio::io::AsyncWrite + tokio::io::AsyncRead + Unpin + Send + 'static, // A websocket usually
input_stream: impl std::io::Read + Reopenable, // A file usually
output_stream: impl tokio::io::AsyncWrite + tokio::io::AsyncRead + CloseStream + Unpin + Send + 'static, // A websocket usually
input_stream: impl std::io::Read + Reopenable, // A file usually
shutdown_signal: Arc<Notify>,
config: StreamingConfig,
when_new_chunk_appended: impl Fn() -> tokio::sync::oneshot::Receiver<()>,
Expand Down Expand Up @@ -181,7 +182,7 @@ fn spawn_sending_task<W>(
mut error_receiver: mpsc::Receiver<UserFriendlyError>,
stop_notifier: Arc<Notify>,
) where
W: tokio::io::AsyncWrite + tokio::io::AsyncRead + Unpin + Send + 'static,
W: tokio::io::AsyncWrite + CloseStream + tokio::io::AsyncRead + Unpin + Send + 'static,
{
use futures_util::stream::StreamExt;
let ws_frame = Arc::new(Mutex::new(ws_frame));
Expand Down Expand Up @@ -226,6 +227,7 @@ fn spawn_sending_task<W>(
},
}
}
ws_frame.lock().await.get_mut().close_stream().await;
Ok::<_, anyhow::Error>(())
});

Expand All @@ -249,6 +251,7 @@ fn spawn_sending_task<W>(
}
}
info!("Stopping streaming task");
ws_frame_clone.lock().await.get_mut().close_stream().await;
handle.abort();
stop_notifier.notify_waiters();
Ok::<_, anyhow::Error>(())
Expand Down
99 changes: 86 additions & 13 deletions devolutions-gateway/src/ws.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,91 @@
use std::{borrow::Cow, pin::Pin};

use axum::extract::ws::{self, WebSocket};
use futures::{SinkExt as _, StreamExt as _};
use futures::{sink::With, Sink, SinkExt as _, Stream, StreamExt as _};
use tokio::io::{AsyncRead, AsyncWrite};
use transport::CloseStream;

pub fn websocket_compat(ws: WebSocket) -> impl AsyncRead + AsyncWrite + Unpin + Send + 'static {
pub fn websocket_compat(ws: WebSocket) -> impl AsyncRead + AsyncWrite + CloseStream + Unpin + Send + 'static {
let ws_compat = ws
.map(|item| {
item.map(|msg| match msg {
ws::Message::Text(s) => transport::WsMessage::Payload(s.into_bytes()),
ws::Message::Binary(data) => transport::WsMessage::Payload(data),
ws::Message::Ping(_) | ws::Message::Pong(_) => transport::WsMessage::Ignored,
ws::Message::Close(_) => transport::WsMessage::Close,
})
})
.with(|item| futures::future::ready(Ok::<_, axum::Error>(ws::Message::Binary(item))));

transport::WsStream::new(ws_compat)
.map(map_ws_message as fn(Result<ws::Message, axum::Error>) -> Result<transport::WsMessage, axum::Error>)
.with(with_binary as fn(Vec<u8>) -> futures::future::Ready<Result<ws::Message, axum::Error>>);

let res = transport::WsStream::new(CloseableWebSocket(ws_compat));

res
}

impl Stream for CloseableWebSocket {
type Item = Result<transport::WsMessage, axum::Error>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
self.0.poll_next_unpin(cx)
}
}

impl Sink<Vec<u8>> for CloseableWebSocket {
type Error = axum::Error;

fn poll_ready(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.0.poll_ready_unpin(cx)
}

fn start_send(mut self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
self.0.start_send_unpin(item)
}

fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.0.poll_flush_unpin(cx)
}

fn poll_close(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.0.poll_close_unpin(cx)
}
}

pub struct CloseableWebSocket(WithExplicit);

impl CloseStream for CloseableWebSocket {
async fn close_stream(&mut self) {
warn!("Closing WebSocket stream");
let res = self
.0
.get_mut()
.send(ws::Message::Close(Some(ws::CloseFrame {
code: 1000,
reason: Cow::Borrowed("EOF"),
})))
.await;
warn!(?res, "WebSocket stream closed");
}
}

type WithExplicit = With<
futures::stream::Map<WebSocket, fn(Result<ws::Message, axum::Error>) -> Result<transport::WsMessage, axum::Error>>,
ws::Message,
Vec<u8>,
futures::future::Ready<Result<ws::Message, axum::Error>>,
fn(Vec<u8>) -> futures::future::Ready<Result<ws::Message, axum::Error>>,
>;

fn map_ws_message(item: Result<ws::Message, axum::Error>) -> Result<transport::WsMessage, axum::Error> {
item.map(|msg| match msg {
ws::Message::Text(s) => transport::WsMessage::Payload(s.into_bytes()),
ws::Message::Binary(data) => transport::WsMessage::Payload(data),
ws::Message::Ping(_) | ws::Message::Pong(_) => transport::WsMessage::Ignored,
ws::Message::Close(_) => transport::WsMessage::Close,
})
}

fn with_binary(item: Vec<u8>) -> futures::future::Ready<Result<ws::Message, axum::Error>> {
futures::future::ready(Ok::<_, axum::Error>(ws::Message::Binary(item)))
}

0 comments on commit dfe4cd9

Please sign in to comment.