Skip to content

Commit

Permalink
Add internal/grpc/serverauth
Browse files Browse the repository at this point in the history
This is the server-side equivalent of internal/grpc/clientauth

PiperOrigin-RevId: 716465764
Change-Id: Ib9d0b4bbfa3341e8fbbe4aaab37ac851e1500298
  • Loading branch information
laramiel authored and copybara-github committed Jan 17, 2025
1 parent 5f59990 commit df80595
Show file tree
Hide file tree
Showing 18 changed files with 286 additions and 63 deletions.
2 changes: 2 additions & 0 deletions tensorstore/internal/grpc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ tensorstore_cc_library(
hdrs = ["server_credentials.h"],
deps = [
"//tensorstore:context",
"//tensorstore/internal/grpc/serverauth:default_strategy",
"//tensorstore/internal/grpc/serverauth:strategy",
"//tensorstore/internal/json_binding",
"//tensorstore/internal/json_binding:bindable",
"//tensorstore/util:result",
Expand Down
32 changes: 22 additions & 10 deletions tensorstore/internal/grpc/server_credentials.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
#include "absl/base/attributes.h"
#include "absl/base/const_init.h"
#include "absl/synchronization/mutex.h"
#include "grpcpp/security/server_credentials.h" // third_party
#include "tensorstore/context.h"
#include "tensorstore/context_resource_provider.h"
#include "tensorstore/internal/grpc/serverauth/default_strategy.h"
#include "tensorstore/internal/grpc/serverauth/strategy.h"
#include "tensorstore/util/result.h"

namespace tensorstore {
Expand All @@ -37,23 +40,32 @@ const internal::ContextResourceRegistration<GrpcServerCredentials>
// of grpc credentials. See grpcpp/security/credentials.h for options, such as:
// ::grpc::experimental::LocalServerCredentials(LOCAL_TCP);

std::shared_ptr<internal_grpc::ServerAuthenticationStrategy>
GrpcServerCredentials::Resource::GetAuthenticationStrategy() {
absl::MutexLock l(&credentials_mu);
if (strategy_) return strategy_;
return internal_grpc::CreateInsecureServerAuthenticationStrategy();
}

/* static */
bool GrpcServerCredentials::Use(
tensorstore::Context context,
std::shared_ptr<::grpc::ServerCredentials> credentials) {
auto resource = context.GetResource<GrpcServerCredentials>().value();
// NOTE: We really want std::atomic<std::shared_ptr<>>.
absl::MutexLock l(&credentials_mu);
bool result = (resource->credentials_ == nullptr);
resource->credentials_ = std::move(credentials);
return result;
return Use(
context,
std::make_shared<internal_grpc::DefaultServerAuthenticationStrategy>(
std::move(credentials)));
}

std::shared_ptr<::grpc::ServerCredentials>
GrpcServerCredentials::Resource::GetCredentials() {
/* static */
bool GrpcServerCredentials::Use(
tensorstore::Context context,
std::shared_ptr<internal_grpc::ServerAuthenticationStrategy> credentials) {
auto resource = context.GetResource<GrpcServerCredentials>().value();
absl::MutexLock l(&credentials_mu);
if (credentials_) return credentials_;
return grpc::InsecureServerCredentials();
bool result = (resource->strategy_ == nullptr);
resource->strategy_ = std::move(credentials);
return result;
}

} // namespace tensorstore
9 changes: 7 additions & 2 deletions tensorstore/internal/grpc/server_credentials.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "grpcpp/security/server_credentials.h" // third_party
#include "tensorstore/context.h"
#include "tensorstore/context_resource_provider.h"
#include "tensorstore/internal/grpc/serverauth/strategy.h"
#include "tensorstore/internal/json_binding/bindable.h"
#include "tensorstore/internal/json_binding/json_binding.h"
#include "tensorstore/util/result.h"
Expand Down Expand Up @@ -48,11 +49,12 @@ struct GrpcServerCredentials final

struct Resource {
// Returns either the owned credentials or a new default credential.
std::shared_ptr<::grpc::ServerCredentials> GetCredentials();
std::shared_ptr<internal_grpc::ServerAuthenticationStrategy>
GetAuthenticationStrategy();

private:
friend struct GrpcServerCredentials;
std::shared_ptr<::grpc::ServerCredentials> credentials_;
std::shared_ptr<internal_grpc::ServerAuthenticationStrategy> strategy_;
};

static constexpr Spec Default() { return {}; }
Expand All @@ -72,6 +74,9 @@ struct GrpcServerCredentials final
/// Returns true when prior credentials were nullptr.
static bool Use(tensorstore::Context context,
std::shared_ptr<::grpc::ServerCredentials> credentials);
static bool Use(
tensorstore::Context context,
std::shared_ptr<internal_grpc::ServerAuthenticationStrategy> credentials);
};

} // namespace tensorstore
Expand Down
19 changes: 14 additions & 5 deletions tensorstore/internal/grpc/server_credentials_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "tensorstore/internal/grpc/server_credentials.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "grpcpp/security/server_credentials.h" // third_party
#include "tensorstore/context.h"
Expand All @@ -22,21 +23,29 @@
namespace {

using ::tensorstore::GrpcServerCredentials;
using ::testing::Eq;
using ::testing::Ne;

TEST(GrpcServerCredentials, Use) {
auto use = grpc::experimental::LocalServerCredentials(LOCAL_TCP);
auto ctx = tensorstore::Context::Default();

EXPECT_TRUE(GrpcServerCredentials::Use(ctx, use));
auto a = ctx.GetResource<GrpcServerCredentials>().value()->GetCredentials();
EXPECT_EQ(a.get(), use.get());
auto a = ctx.GetResource<GrpcServerCredentials>()
.value()
->GetAuthenticationStrategy();
EXPECT_THAT(a->GetServerCredentials().get(), Eq(use.get()));
}

TEST(GrpcServerCredentials, Default) {
auto ctx = tensorstore::Context::Default();
auto a = ctx.GetResource<GrpcServerCredentials>().value()->GetCredentials();
auto b = ctx.GetResource<GrpcServerCredentials>().value()->GetCredentials();
EXPECT_NE(a.get(), b.get());
auto a = ctx.GetResource<GrpcServerCredentials>()
.value()
->GetAuthenticationStrategy();
auto b = ctx.GetResource<GrpcServerCredentials>()
.value()
->GetAuthenticationStrategy();
EXPECT_THAT(a.get(), Ne(b.get()));
}

} // namespace
19 changes: 19 additions & 0 deletions tensorstore/internal/grpc/serverauth/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
load("//bazel:tensorstore.bzl", "tensorstore_cc_library")

package(default_visibility = ["//tensorstore:internal_packages"])

tensorstore_cc_library(
name = "strategy",
hdrs = ["strategy.h"],
deps = ["@com_github_grpc_grpc//:grpc++"],
)

tensorstore_cc_library(
name = "default_strategy",
srcs = ["default_strategy.cc"],
hdrs = ["default_strategy.h"],
deps = [
":strategy",
"@com_github_grpc_grpc//:grpc++",
],
)
32 changes: 32 additions & 0 deletions tensorstore/internal/grpc/serverauth/default_strategy.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright 2025 The TensorStore Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "tensorstore/internal/grpc/serverauth/default_strategy.h"

#include <memory>

#include "grpcpp/security/server_credentials.h" // third_party
#include "tensorstore/internal/grpc/serverauth/strategy.h"

namespace tensorstore {
namespace internal_grpc {

std::shared_ptr<ServerAuthenticationStrategy>
CreateInsecureServerAuthenticationStrategy() {
return std::make_shared<DefaultServerAuthenticationStrategy>(
grpc::InsecureServerCredentials());
}

} // namespace internal_grpc
} // namespace tensorstore
55 changes: 55 additions & 0 deletions tensorstore/internal/grpc/serverauth/default_strategy.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright 2025 The TensorStore Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef TENSORSTORE_INTERNAL_GRPC_SERVERAUTH_DEFAULT_STRATEGY_H_
#define TENSORSTORE_INTERNAL_GRPC_SERVERAUTH_DEFAULT_STRATEGY_H_

#include <memory>
#include <utility>
#include <vector>

#include "grpcpp/security/server_credentials.h" // third_party
#include "grpcpp/server_builder.h" // third_party
#include "tensorstore/internal/grpc/serverauth/strategy.h"

namespace tensorstore {
namespace internal_grpc {

class DefaultServerAuthenticationStrategy
: public ServerAuthenticationStrategy {
public:
DefaultServerAuthenticationStrategy(
std::shared_ptr<grpc::ServerCredentials> credentials)
: credentials_(std::move(credentials)) {}

~DefaultServerAuthenticationStrategy() override = default;

std::shared_ptr<grpc::ServerCredentials> GetServerCredentials()
const override {
return credentials_;
}

void AddBuilderParameters(grpc::ServerBuilder& builder) const override {}

std::shared_ptr<grpc::ServerCredentials> credentials_;
};

/// Creates an "insecure" server authentication strategy.
std::shared_ptr<ServerAuthenticationStrategy>
CreateInsecureServerAuthenticationStrategy();

} // namespace internal_grpc
} // namespace tensorstore

#endif // TENSORSTORE_INTERNAL_GRPC_SERVERAUTH_DEFAULT_STRATEGY_H_
50 changes: 50 additions & 0 deletions tensorstore/internal/grpc/serverauth/strategy.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright 2025 The TensorStore Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef TENSORSTORE_INTERNAL_GRPC_SERVERAUTH_STRATEGY_H_
#define TENSORSTORE_INTERNAL_GRPC_SERVERAUTH_STRATEGY_H_

#include <memory>

#include "grpcpp/security/server_credentials.h" // third_party
#include "grpcpp/server_builder.h" // third_party

namespace tensorstore {
namespace internal_grpc {

/// Installs gRPC Server authentication strategies.
///
/// Usage:
/// auto strategy = ...;
/// grpc::ServerBuilder builder;
/// builder.RegisterService(...);
/// strategy->AddBuilderParameters(builder);
/// builder.AddListeningPort(bind_addresses,
/// strategy->GetServerCredentials(),
/// &bound_port);
/// auto server = builder.BuildAndStart();
class ServerAuthenticationStrategy {
public:
virtual ~ServerAuthenticationStrategy() = default;

virtual std::shared_ptr<grpc::ServerCredentials> GetServerCredentials()
const = 0;

virtual void AddBuilderParameters(grpc::ServerBuilder& builder) const = 0;
};

} // namespace internal_grpc
} // namespace tensorstore

#endif // TENSORSTORE_INTERNAL_GRPC_SERVERAUTH_STRATEGY_H_
6 changes: 5 additions & 1 deletion tensorstore/kvstore/ocdbt/distributed/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ tensorstore_cc_library(
"//tensorstore/internal/container:heterogeneous_container",
"//tensorstore/internal/container:intrusive_red_black_tree",
"//tensorstore/internal/grpc:peer_address",
"//tensorstore/internal/grpc:utils",
"//tensorstore/internal/grpc/serverauth:default_strategy",
"//tensorstore/internal/grpc/serverauth:strategy",
"//tensorstore/internal/json_binding",
"//tensorstore/internal/json_binding:bindable",
"//tensorstore/internal/log:verbose_flag",
Expand Down Expand Up @@ -208,6 +209,7 @@ tensorstore_cc_library(
"//tensorstore/kvstore/ocdbt/non_distributed:create_new_manifest",
"//tensorstore/kvstore/ocdbt/non_distributed:storage_generation",
"//tensorstore/kvstore/ocdbt/non_distributed:write_nodes",
"//tensorstore/util:bit_span",
"//tensorstore/util:bit_vec",
"//tensorstore/util:division",
"//tensorstore/util:executor",
Expand Down Expand Up @@ -338,6 +340,8 @@ tensorstore_cc_library(
"//tensorstore/internal/cache_key",
"//tensorstore/internal/grpc/clientauth:authentication_strategy",
"//tensorstore/internal/grpc/clientauth:channel_authentication",
"//tensorstore/internal/grpc/serverauth:default_strategy",
"//tensorstore/internal/grpc/serverauth:strategy",
"//tensorstore/internal/json_binding",
"//tensorstore/internal/json_binding:bindable",
"@com_github_grpc_grpc//:grpc++",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,7 @@ grpc::ServerUnaryReactor* Cooperator::GetOrCreateManifest(
const grpc_gen::GetOrCreateManifestRequest* request,
grpc_gen::GetOrCreateManifestResponse* response) {
auto* reactor = context->DefaultReactor();
if (auto status = security_->ValidateServerRequest(context); !status.ok()) {
reactor->Finish(internal::AbslStatusToGrpcStatus(status));
return reactor;
}

if (!internal::IncrementReferenceCountIfNonZero(*this)) {
// Shutting down
reactor->Finish(
Expand Down
3 changes: 3 additions & 0 deletions tensorstore/kvstore/ocdbt/distributed/cooperator_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#ifndef TENSORSTORE_KVSTORE_OCDBT_DISTRIBUTED_COOPERATOR_IMPL_H_
#define TENSORSTORE_KVSTORE_OCDBT_DISTRIBUTED_COOPERATOR_IMPL_H_

#include <stddef.h>

#include <atomic>
#include <memory>
#include <string>
Expand All @@ -39,6 +41,7 @@
#include "tensorstore/kvstore/ocdbt/distributed/cooperator.h"
#include "tensorstore/kvstore/ocdbt/distributed/cooperator.pb.h"
#include "tensorstore/kvstore/ocdbt/distributed/lease_cache_for_cooperator.h"
#include "tensorstore/kvstore/ocdbt/distributed/rpc_security.h"
#include "tensorstore/kvstore/ocdbt/format/btree.h"
#include "tensorstore/kvstore/ocdbt/format/version_tree.h"
#include "tensorstore/kvstore/ocdbt/io_handle.h"
Expand Down
9 changes: 6 additions & 3 deletions tensorstore/kvstore/ocdbt/distributed/cooperator_start.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,15 @@ Result<CooperatorPtr> Start(Options&& options) {
impl->clock_ = [] { return absl::Now(); };
}
impl->io_handle_ = std::move(options.io_handle);
impl->security_ = options.security;
impl->storage_identifier_ = std::move(options.storage_identifier);

auto server_auth = impl->security_->GetServerAuthenticationStrategy();

grpc::ServerBuilder builder;
builder.RegisterService(impl.get());
auto creds = options.security->GetServerCredentials();
server_auth->AddBuilderParameters(builder);
auto creds = server_auth->GetServerCredentials();
const auto add_listening_port = [&](const std::string& address) {
builder.AddListeningPort(address, creds, &impl->listening_port_);
};
Expand All @@ -56,9 +61,7 @@ Result<CooperatorPtr> Start(Options&& options) {
add_listening_port(bind_address);
}
}
impl->security_ = options.security;
impl->server_ = builder.BuildAndStart();
impl->storage_identifier_ = std::move(options.storage_identifier);

auto auth_strategy = impl->security_->GetClientAuthenticationStrategy();

Expand Down
Loading

0 comments on commit df80595

Please sign in to comment.