Skip to content

Commit

Permalink
Merge pull request #29685 from teskje/clusterd-hostname-check
Browse files Browse the repository at this point in the history
  • Loading branch information
teskje authored Sep 25, 2024
2 parents c80b754 + 53db8a9 commit 316fc3e
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 128 deletions.
1 change: 1 addition & 0 deletions misc/python/materialize/mzcompose/services/clusterd.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
) -> None:
environment = [
"CLUSTERD_LOG_FILTER",
f"CLUSTERD_GRPC_HOST={name}",
"MZ_SOFT_ASSERTIONS=1",
*environment_extra,
]
Expand Down
7 changes: 7 additions & 0 deletions src/clusterd/ci/entrypoint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,11 @@ export CLUSTERD_INTERNAL_HTTP_LISTEN_ADDR=${CLUSTERD_INTERNAL_HTTP_LISTEN_ADDR:-
export CLUSTERD_SECRETS_READER=${CLUSTERD_SECRETS_READER:-local-file}
export CLUSTERD_SECRETS_READER_LOCAL_FILE_DIR=${CLUSTERD_SECRETS_READER_LOCAL_DIR:-/mzdata/secrets}

# Pass the host's FQDN as the host to be used for GRPC request validation only
# when running in Kubernetes. In other contexts (like when running locally, or
# in Docker), this is likely not desirable.
if [[ "${KUBERNETES_SERVICE_HOST:-}" ]]; then
export CLUSTERD_GRPC_HOST=${CLUSTERD_GRPC_HOST:-$(hostname --fqdn)}
fi

exec clusterd "$@"
9 changes: 9 additions & 0 deletions src/clusterd/src/bin/clusterd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ struct Args {
default_value = "127.0.0.1:6878"
)]
internal_http_listen_addr: SocketAddr,
/// The FQDN of this process, for GRPC request validation.
///
/// Not providing this value or setting it to the empty string disables host validation for
/// GRPC requests.
#[clap(long, env = "GRPC_HOST", value_name = "NAME")]
grpc_host: Option<String>,

// === Storage options. ===
/// The URL for the Persist PubSub service.
Expand Down Expand Up @@ -288,6 +294,7 @@ async fn run(args: Args) -> Result<(), anyhow::Error> {
None,
);

let grpc_host = args.grpc_host.and_then(|h| (!h.is_empty()).then_some(h));
let grpc_server_metrics = GrpcServerMetrics::register_with(&metrics_registry);

// Start storage server.
Expand All @@ -312,6 +319,7 @@ async fn run(args: Args) -> Result<(), anyhow::Error> {
&grpc_server_metrics,
args.storage_controller_listen_addr,
BUILD_INFO.semver_version(),
grpc_host.clone(),
storage_client,
|svc| ProtoStorageServer::new(svc).max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE),
),
Expand Down Expand Up @@ -341,6 +349,7 @@ async fn run(args: Args) -> Result<(), anyhow::Error> {
&grpc_server_metrics,
args.compute_controller_listen_addr,
BUILD_INFO.semver_version(),
grpc_host,
compute_client,
|svc| ProtoComputeServer::new(svc).max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE),
),
Expand Down
116 changes: 82 additions & 34 deletions src/service/src/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
use async_stream::stream;
use async_trait::async_trait;
use futures::future;
use futures::future::{self, BoxFuture};
use futures::stream::{Stream, StreamExt, TryStreamExt};
use http::uri::PathAndQuery;
use hyper_util::rt::TokioIo;
Expand All @@ -21,11 +21,11 @@ use mz_ore::netio::{Listener, SocketAddr, SocketAddrType};
use mz_proto::{ProtoType, RustType};
use prometheus::core::AtomicU64;
use semver::Version;
use std::error::Error;
use std::fmt::{self, Debug};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::LazyLock;
use std::time::UNIX_EPOCH;
use tokio::net::UnixStream;
use tokio::select;
Expand All @@ -34,13 +34,13 @@ use tokio::sync::{oneshot, Mutex};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::body::BoxBody;
use tonic::codegen::InterceptedService;
use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue};
use tonic::metadata::AsciiMetadataValue;
use tonic::server::NamedService;
use tonic::service::Interceptor;
use tonic::transport::{Channel, Endpoint, Server};
use tonic::{IntoStreamingRequest, Request, Response, Status, Streaming};
use tower::Service;
use tracing::{debug, error, info};
use tracing::{debug, error, info, warn};

use crate::client::{GenericClient, Partitionable, Partitioned};
use crate::codec::{StatCodec, StatsCollector};
Expand Down Expand Up @@ -268,6 +268,7 @@ where
metrics: &GrpcServerMetrics,
listen_addr: SocketAddr,
version: Version,
host: Option<String>,
client_builder: F,
service_builder: Fs,
) -> impl Future<Output = Result<(), anyhow::Error>>
Expand All @@ -292,17 +293,20 @@ where
let server = Self {
state: Arc::new(state),
};
let service = InterceptedService::new(
service_builder(server),
VersionCheckExactInterceptor::new(version),
);
let service = service_builder(server);

if host.is_none() {
warn!("no host provided; request destination host checking is disabled");
}
let validation = RequestValidationLayer { version, host };

info!("Starting to listen on {}", listen_addr);

async {
let listener = Listener::bind(listen_addr).await?;

Server::builder()
.layer(validation)
.add_service(service)
.serve_with_incoming(listener)
.await?;
Expand Down Expand Up @@ -453,8 +457,7 @@ struct PerGrpcServerMetrics {
last_command_received: DeleteOnDropGauge<'static, AtomicU64, Vec<&'static str>>,
}

static VERSION_METADATA_KEY: LazyLock<AsciiMetadataKey> =
LazyLock::new(|| AsciiMetadataKey::from_static("x-mz-version"));
const VERSION_HEADER_KEY: &str = "x-mz-version";

/// A gRPC interceptor that attaches a version as metadata to each request.
#[derive(Debug, Clone)]
Expand All @@ -477,40 +480,85 @@ impl Interceptor for VersionAttachInterceptor {
fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
request
.metadata_mut()
.insert(VERSION_METADATA_KEY.clone(), self.version.clone());
.insert(VERSION_HEADER_KEY, self.version.clone());
Ok(request)
}
}

/// A gRPC interceptor that ensures the version attached to the request by the
/// `VersionAttachInterceptor` exactly matches the expected version.
#[derive(Debug, Clone)]
struct VersionCheckExactInterceptor {
version: AsciiMetadataValue,
/// A `tower` layer that validates requests for compatibility with the server.
#[derive(Clone)]
struct RequestValidationLayer {
version: Version,
host: Option<String>,
}

impl VersionCheckExactInterceptor {
fn new(version: Version) -> VersionCheckExactInterceptor {
VersionCheckExactInterceptor {
version: version
.to_string()
.try_into()
.expect("semver versions are valid metadata values"),
impl<S> tower::Layer<S> for RequestValidationLayer {
type Service = RequestValidation<S>;

fn layer(&self, inner: S) -> Self::Service {
let version = self
.version
.to_string()
.try_into()
.expect("version is a valid header value");
RequestValidation {
inner,
version,
host: self.host.clone(),
}
}
}

impl Interceptor for VersionCheckExactInterceptor {
fn call(&mut self, request: Request<()>) -> Result<Request<()>, Status> {
match request.metadata().get(&*VERSION_METADATA_KEY) {
None => Err(Status::permission_denied(
"request missing version metadata",
)),
Some(version) if version == self.version => Ok(request),
Some(version) => Err(Status::permission_denied(format!(
"request version {:?} but {:?} required",
version, self.version
))),
/// A `tower` middleware that validates requests for compatibility with the server.
#[derive(Clone)]
struct RequestValidation<S> {
inner: S,
version: http::HeaderValue,
host: Option<String>,
}

impl<S, B> Service<http::Request<B>> for RequestValidation<S>
where
S: Service<http::Request<B>, Error = Box<dyn Error + Send + Sync + 'static>>,
S::Response: Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = BoxFuture<'static, Result<S::Response, S::Error>>;

fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, req: http::Request<B>) -> Self::Future {
let error = |msg| {
let error: S::Error = Box::new(Status::permission_denied(msg));
Box::pin(future::ready(Err(error)))
};

let Some(req_version) = req.headers().get(VERSION_HEADER_KEY) else {
return error("request missing version header".into());
};
if req_version != self.version {
return error(format!(
"request has version {req_version:?} but {:?} required",
self.version
));
}

let req_host = req.uri().host();
if let (Some(req_host), Some(host)) = (req_host, &self.host) {
if req_host != host {
return error(format!(
"request has host {req_host:?} but {host:?} required"
));
}
}

Box::pin(self.inner.call(req))
}
}
Loading

0 comments on commit 316fc3e

Please sign in to comment.