Skip to content

Commit

Permalink
Merge pull request #22331 from MaterializeInc/roshan/internal-api-ws
Browse files Browse the repository at this point in the history
Add a websocket route to the Internal HTTP API
  • Loading branch information
rjobanp authored Oct 12, 2023
2 parents cf61951 + 6da8c8b commit 100eacd
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 49 deletions.
114 changes: 80 additions & 34 deletions src/environmentd/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ pub enum TlsMode {
#[derive(Clone)]
pub struct WsState {
frontegg: Arc<Option<FronteggAuthentication>>,
adapter_client: mz_adapter::Client,
adapter_client_rx: Delayed<mz_adapter::Client>,
active_connection_count: SharedConnectionCounter,
}

Expand Down Expand Up @@ -135,13 +135,13 @@ impl HttpServer {
adapter_client_tx
.send(adapter_client.clone())
.expect("rx known to be live");

let adapter_client_rx = adapter_client_rx.shared();
let base_router = base_router(BaseRouterConfig { profiling: false })
.layer(middleware::from_fn(move |req, next| {
let base_frontegg = Arc::clone(&base_frontegg);
async move { http_auth(req, next, tls_mode, &base_frontegg).await }
async move { http_auth(req, next, tls_mode, base_frontegg.as_ref().as_ref()).await }
}))
.layer(Extension(adapter_client_rx.shared()))
.layer(Extension(adapter_client_rx.clone()))
.layer(Extension(Arc::clone(&active_connection_count)))
.layer(
CorsLayer::new()
Expand All @@ -156,12 +156,11 @@ impl HttpServer {
.expose_headers(Any)
.max_age(Duration::from_secs(60) * 60),
);

let ws_router = Router::new()
.route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
.with_state(WsState {
frontegg,
adapter_client: adapter_client.clone(),
adapter_client_rx,
active_connection_count,
});

Expand Down Expand Up @@ -375,6 +374,7 @@ impl InternalHttpServer {
"/internal-console".to_string(),
));

let adapter_client_rx = adapter_client_rx.shared();
let router = base_router(BaseRouterConfig { profiling: true })
.route(
"/metrics",
Expand Down Expand Up @@ -439,9 +439,22 @@ impl InternalHttpServer {
routing::get(console::handle_internal_console),
)
.layer(middleware::from_fn(internal_http_auth))
.layer(Extension(adapter_client_rx.shared()))
.layer(Extension(adapter_client_rx.clone()))
.layer(Extension(console_config))
.layer(Extension(active_connection_count));
.layer(Extension(Arc::clone(&active_connection_count)));

let ws_router = Router::new()
.route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
// This middleware extracts the MZ user from the x-materialize-user http header.
// Normally, browser-initiated websocket requests do not support headers, however for the
// Internal HTTP Server the browser would be connecting through teleport, which should
// attach the x-materialize-user header to all requests it proxies to this api.
.layer(middleware::from_fn(internal_http_auth))
.with_state(WsState {
frontegg: Arc::new(None),
adapter_client_rx,
active_connection_count,
});

let leader_router = Router::new()
.route("/api/leader/status", routing::get(handle_leader_status))
Expand All @@ -451,7 +464,10 @@ impl InternalHttpServer {
ready_to_promote,
})));

let router = router.merge(leader_router).apply_default_layers(metrics);
let router = router
.merge(ws_router)
.merge(leader_router)
.apply_default_layers(metrics);

InternalHttpServer { router }
}
Expand Down Expand Up @@ -484,7 +500,10 @@ impl Server for InternalHttpServer {
let router = self.router.clone();
Box::pin(async {
let http = hyper::server::conn::Http::new();
http.serve_connection(conn, router).err_into().await
http.serve_connection(conn, router)
.with_upgrades()
.err_into()
.await
})
}
}
Expand All @@ -500,7 +519,7 @@ enum ConnProtocol {
}

#[derive(Clone, Debug)]
struct AuthedUser(User);
pub struct AuthedUser(User);

pub struct AuthedClient {
pub client: SessionClient,
Expand Down Expand Up @@ -642,7 +661,7 @@ async fn http_auth<B>(
mut req: Request<B>,
next: Next<B>,
tls_mode: TlsMode,
frontegg: &Option<FronteggAuthentication>,
frontegg: Option<&FronteggAuthentication>,
) -> impl IntoResponse {
// First, extract the username from the certificate, validating that the
// connection matches the TLS configuration along the way.
Expand Down Expand Up @@ -685,9 +704,10 @@ async fn http_auth<B>(
async fn init_ws(
WsState {
frontegg,
adapter_client,
adapter_client_rx,
active_connection_count,
}: &WsState,
existing_user: Option<AuthedUser>,
ws: &mut WebSocket,
) -> Result<AuthedClient, anyhow::Error> {
// TODO: Add a timeout here to prevent resource leaks by clients that
Expand All @@ -709,33 +729,59 @@ async fn init_ws(
}
}
};
let (creds, options) = if frontegg.is_some() {
match ws_auth {
let (user, options) = match (frontegg.as_ref(), existing_user, ws_auth) {
(Some(frontegg), None, ws_auth) => {
let (creds, options) = match ws_auth {
WebSocketAuth::Basic {
user,
password,
options,
} => {
let creds = Credentials::Password {
username: user,
password,
};
(creds, options)
}
WebSocketAuth::Bearer { token, options } => {
let creds = Credentials::Token { token };
(creds, options)
}
WebSocketAuth::OptionsOnly { options: _ } => {
anyhow::bail!("expected auth information");
}
};
(auth(Some(frontegg), creds).await?, options)
}
(
None,
None,
WebSocketAuth::Basic {
user,
password,
password: _,
options,
} => {
let creds = Credentials::Password {
username: user,
password,
};
(creds, options)
}
WebSocketAuth::Bearer { token, options } => {
let creds = Credentials::Token { token };
(creds, options)
}
},
) => (auth(None, Credentials::User(user)).await?, options),
// No frontegg, specified existing user, we only accept options only.
(None, Some(existing_user), WebSocketAuth::OptionsOnly { options }) => {
(existing_user, options)
}
// No frontegg, specified existing user, we do not expect basic or bearer auth.
(None, Some(_), WebSocketAuth::Basic { .. } | WebSocketAuth::Bearer { .. }) => {
warn!("Unexpected bearer or basic auth provided when using user header");
anyhow::bail!("unexpected")
}
// Specifying both frontegg and an existing user should not be possible.
(Some(_), Some(_), _) => anyhow::bail!("unexpected"),
// No frontegg, no existing user, and no passed username.
(None, None, WebSocketAuth::Bearer { .. } | WebSocketAuth::OptionsOnly { .. }) => {
warn!("Unexpected auth type when not using frontegg or user header");
anyhow::bail!("unexpected")
}
} else if let WebSocketAuth::Basic { user, options, .. } = ws_auth {
(Credentials::User(user), options)
} else {
anyhow::bail!("unexpected")
};
let user = auth(frontegg, creds).await?;

let client = AuthedClient::new(
adapter_client,
&adapter_client_rx.clone().await?,
user,
Arc::clone(active_connection_count),
options,
Expand All @@ -753,7 +799,7 @@ enum Credentials {
}

async fn auth(
frontegg: &Option<FronteggAuthentication>,
frontegg: Option<&FronteggAuthentication>,
creds: Credentials,
) -> Result<AuthedUser, AuthError> {
// There are three places a username may be specified:
Expand Down
17 changes: 12 additions & 5 deletions src/environmentd/src/http/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use async_trait::async_trait;
use axum::extract::ws::{CloseFrame, Message, WebSocket};
use axum::extract::{State, WebSocketUpgrade};
use axum::response::IntoResponse;
use axum::Json;
use axum::{Extension, Json};
use futures::future::BoxFuture;
use futures::Future;
use http::StatusCode;
Expand All @@ -43,7 +43,7 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::debug;
use tungstenite::protocol::frame::coding::CloseCode;

use crate::http::{init_ws, AuthedClient, WsState, MAX_REQUEST_SIZE};
use crate::http::{init_ws, AuthedClient, AuthedUser, WsState, MAX_REQUEST_SIZE};

pub async fn handle_sql(
mut client: AuthedClient,
Expand All @@ -67,10 +67,13 @@ struct ErrorResponse {

pub async fn handle_sql_ws(
State(state): State<WsState>,
existing_user: Option<Extension<AuthedUser>>,
ws: WebSocketUpgrade,
) -> impl IntoResponse {
// An upstream middleware may have already provided the user for us
let user = existing_user.and_then(|Extension(user)| Some(user));
ws.max_message_size(MAX_REQUEST_SIZE)
.on_upgrade(|ws| async move { run_ws(&state, ws).await })
.on_upgrade(|ws| async move { run_ws(&state, user, ws).await })
}

#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
Expand All @@ -87,10 +90,14 @@ pub enum WebSocketAuth {
#[serde(default)]
options: BTreeMap<String, String>,
},
OptionsOnly {
#[serde(default)]
options: BTreeMap<String, String>,
},
}

async fn run_ws(state: &WsState, mut ws: WebSocket) {
let mut client = match init_ws(state, &mut ws).await {
async fn run_ws(state: &WsState, user: Option<AuthedUser>, mut ws: WebSocket) {
let mut client = match init_ws(state, user, &mut ws).await {
Ok(client) => client,
Err(e) => {
// We omit most detail from the error message we send to the client, to
Expand Down
78 changes: 76 additions & 2 deletions src/environmentd/tests/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,12 @@ use std::{iter, thread};
use anyhow::bail;
use chrono::{DateTime, Utc};
use futures::FutureExt;
use http::StatusCode;
use http::{Request, StatusCode};
use itertools::Itertools;
use mz_environmentd::http::{
BecomeLeaderResponse, BecomeLeaderResult, LeaderStatus, LeaderStatusResponse,
};
use mz_environmentd::WebSocketResponse;
use mz_environmentd::{WebSocketAuth, WebSocketResponse};
use mz_ore::cast::CastFrom;
use mz_ore::cast::CastLossy;
use mz_ore::cast::TryCastFrom;
Expand Down Expand Up @@ -2110,6 +2110,80 @@ fn test_internal_http_auth() {
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "{:?}", res.text());
}

#[mz_ore::test]
#[cfg_attr(miri, ignore)] // too slow
fn test_internal_ws_auth() {
let server = util::start_server(util::Config::default()).unwrap();

// Create our WebSocket.
let ws_url = server.internal_ws_addr();
let make_req = || {
Request::builder()
.uri(ws_url.as_str())
.method("GET")
.header("Host", ws_url.host_str().unwrap())
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.header("Sec-WebSocket-Version", "13")
.header("Sec-WebSocket-Key", "foobar")
// Set our user to the mz_support user
.header("x-materialize-user", "mz_support")
.body(())
.unwrap()
};

let (mut ws, _resp) = tungstenite::connect(make_req()).unwrap();
let options = BTreeMap::from([(
"application_name".to_string(),
"billion_dollar_idea".to_string(),
)]);
// We should receive error if sending the standard bearer auth, since that is unexpected
// for the Internal HTTP API
assert_eq!(util::auth_with_ws(&mut ws, options.clone()).is_err(), true);

// Recreate the websocket
let (mut ws, _resp) = tungstenite::connect(make_req()).unwrap();
// Auth with OptionsOnly
util::auth_with_ws_impl(
&mut ws,
Message::Text(serde_json::to_string(&WebSocketAuth::OptionsOnly { options }).unwrap()),
)
.unwrap();

// Query to make sure we get back the correct user, which should be
// set from the headers passed with the websocket request.
let json = "{\"query\":\"SELECT current_user;\"}";
let json: serde_json::Value = serde_json::from_str(json).unwrap();
ws.send(Message::Text(json.to_string())).unwrap();

let mut read_msg = || -> WebSocketResponse {
let msg = ws.read().unwrap();
let msg = msg.into_text().expect("response should be text");
serde_json::from_str(&msg).unwrap()
};
let starting = read_msg();
let columns = read_msg();
let row_val = read_msg();

if !matches!(starting, WebSocketResponse::CommandStarting(_)) {
panic!("wrong message!, {starting:?}");
};

if let WebSocketResponse::Rows(rows) = columns {
let names: Vec<&str> = rows.columns.iter().map(|c| c.name.as_str()).collect();
assert_eq!(names, ["current_user"]);
} else {
panic!("wrong message!, {columns:?}");
};

if let WebSocketResponse::Row(row) = row_val {
let expected = serde_json::Value::String("mz_support".to_string());
assert_eq!(row, [expected]);
} else {
panic!("wrong message!, {row_val:?}");
}
}

#[mz_ore::test]
#[cfg_attr(miri, ignore)] // too slow
fn test_leader_promotion() {
Expand Down
Loading

0 comments on commit 100eacd

Please sign in to comment.