Skip to content

Commit

Permalink
Refactor existing Workload API and SDS services to use new attr protos
Browse files Browse the repository at this point in the history
  • Loading branch information
strideynet committed Jan 8, 2025
1 parent ad5dbf3 commit 85accc5
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 181 deletions.
79 changes: 47 additions & 32 deletions lib/tbot/service_spiffe_workload_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import (

"github.com/gravitational/teleport"
machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1"
workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/observability/metrics"
Expand Down Expand Up @@ -227,13 +228,27 @@ func (s *SPIFFEWorkloadAPIService) Run(ctx context.Context) error {
)
workloadpb.RegisterSpiffeWorkloadAPIServer(srv, s)
sdsHandler := &spiffeSDSHandler{
log: s.log,
cfg: s.cfg,
botCfg: s.botCfg,

trustBundleCache: s.trustBundleCache,
clientAuthenticator: s.authenticateClient,
svidFetcher: s.fetchX509SVIDs,
log: s.log,
botCfg: s.botCfg,
trustBundleCache: s.trustBundleCache,
clientAuthenticator: func(ctx context.Context) (*slog.Logger, svidFetcher, error) {
log, attrs, err := s.authenticateClient(ctx)
if err != nil {
return log, nil, trace.Wrap(err, "authenticating client")
}
fetchSVIDs := func(
ctx context.Context,
localBundle *spiffebundle.Bundle,
) ([]*workloadpb.X509SVID, error) {
return s.fetchX509SVIDs(
ctx,
log,
localBundle,
filterSVIDRequests(ctx, log, s.cfg.SVIDs, attrs),
)
}
return log, fetchSVIDs, nil
},
}
secretv3pb.RegisterSecretDiscoveryServiceServer(srv, sdsHandler)

Expand Down Expand Up @@ -373,7 +388,7 @@ func filterSVIDRequests(
ctx context.Context,
log *slog.Logger,
svidRequests []config.SVIDRequestWithRules,
att workloadattest.Attestation,
att *workloadidentityv1pb.WorkloadAttrs,
) []config.SVIDRequest {
var filtered []config.SVIDRequest
for _, req := range svidRequests {
Expand Down Expand Up @@ -413,67 +428,67 @@ func filterSVIDRequests(
"Evaluating rule against workload attestation",
)
if rule.Unix.UID != nil {
if !att.Unix.Attested {
if !att.GetUnix().GetAttested() {
logNotAttested("unix")
continue
}
if *rule.Unix.UID != att.Unix.UID {
logMismatch("unix.uid", *rule.Unix.UID, att.Unix.UID)
if *rule.Unix.UID != int(att.GetUnix().GetUid()) {
logMismatch("unix.uid", *rule.Unix.UID, att.GetUnix().GetUid())
continue
}
// Rule field matched!
}
if rule.Unix.PID != nil {
if !att.Unix.Attested {
if !att.GetUnix().GetAttested() {
logNotAttested("unix")
continue
}
if *rule.Unix.PID != att.Unix.PID {
logMismatch("unix.pid", *rule.Unix.PID, att.Unix.PID)
if *rule.Unix.PID != int(att.GetUnix().GetPid()) {
logMismatch("unix.pid", *rule.Unix.PID, att.GetUnix().GetPid())
continue
}
// Rule field matched!
}
if rule.Unix.GID != nil {
if !att.Unix.Attested {
if !att.GetUnix().GetAttested() {
logNotAttested("unix")
continue
}
if *rule.Unix.GID != att.Unix.GID {
logMismatch("unix.gid", *rule.Unix.GID, att.Unix.GID)
if *rule.Unix.GID != int(att.GetUnix().GetGid()) {
logMismatch("unix.gid", *rule.Unix.GID, att.GetUnix().GetGid())
continue
}
// Rule field matched!
}
if rule.Kubernetes.Namespace != "" {
if !att.Kubernetes.Attested {
if !att.GetKubernetes().GetAttested() {
logNotAttested("kubernetes")
continue
}
if rule.Kubernetes.Namespace != att.Kubernetes.Namespace {
logMismatch("kubernetes.namespace", rule.Kubernetes.Namespace, att.Kubernetes.Namespace)
if rule.Kubernetes.Namespace != att.GetKubernetes().GetNamespace() {
logMismatch("kubernetes.namespace", rule.Kubernetes.Namespace, att.GetKubernetes().GetNamespace())
continue
}
// Rule field matched!
}
if rule.Kubernetes.PodName != "" {
if !att.Kubernetes.Attested {
if !att.GetKubernetes().GetAttested() {
logNotAttested("kubernetes")
continue
}
if rule.Kubernetes.PodName != att.Kubernetes.PodName {
logMismatch("kubernetes.pod_name", rule.Kubernetes.PodName, att.Kubernetes.PodName)
if rule.Kubernetes.PodName != att.GetKubernetes().GetPodName() {
logMismatch("kubernetes.pod_name", rule.Kubernetes.PodName, att.GetKubernetes().GetPodName())
continue
}
// Rule field matched!
}
if rule.Kubernetes.ServiceAccount != "" {
if !att.Kubernetes.Attested {
if !att.GetKubernetes().GetAttested() {
logNotAttested("kubernetes")
continue
}
if rule.Kubernetes.ServiceAccount != att.Kubernetes.ServiceAccount {
logMismatch("kubernetes.service_account", rule.Kubernetes.ServiceAccount, att.Kubernetes.ServiceAccount)
if rule.Kubernetes.ServiceAccount != att.GetKubernetes().GetServiceAccount() {
logMismatch("kubernetes.service_account", rule.Kubernetes.ServiceAccount, att.GetKubernetes().GetServiceAccount())
continue
}
// Rule field matched!
Expand All @@ -499,10 +514,10 @@ func filterSVIDRequests(

func (s *SPIFFEWorkloadAPIService) authenticateClient(
ctx context.Context,
) (*slog.Logger, workloadattest.Attestation, error) {
) (*slog.Logger, *workloadidentityv1pb.WorkloadAttrs, error) {
p, ok := peer.FromContext(ctx)
if !ok {
return nil, workloadattest.Attestation{}, trace.BadParameter("peer not found in context")
return nil, nil, trace.BadParameter("peer not found in context")
}
log := s.log

Expand All @@ -516,7 +531,7 @@ func (s *SPIFFEWorkloadAPIService) authenticateClient(
// We expect Creds to be nil/unset if the client is connecting via TCP and
// therefore there is no workload attestation that can be completed.
if !ok || authInfo.Creds == nil {
return log, workloadattest.Attestation{}, nil
return log, nil, nil
}

// For a UDS, sometimes we are unable to determine the PID of the calling
Expand All @@ -528,7 +543,7 @@ func (s *SPIFFEWorkloadAPIService) authenticateClient(
if authInfo.Creds.PID == 0 {
log.DebugContext(
ctx, "Failed to determine the PID of the calling workload. TBot may be running in a different process namespace to the workload. Workload attestation will not be completed.")
return log, workloadattest.Attestation{}, nil
return log, nil, nil
}

att, err := s.attestor.Attest(ctx, authInfo.Creds.PID)
Expand All @@ -541,10 +556,10 @@ func (s *SPIFFEWorkloadAPIService) authenticateClient(
"error", err,
"pid", authInfo.Creds.PID,
)
return log, workloadattest.Attestation{}, nil
return log, nil, nil
}
log = log.With(
"workload", slog.LogValuer(att),
"workload", att,
)

return log, att, nil
Expand Down
27 changes: 7 additions & 20 deletions lib/tbot/service_spiffe_workload_api_sds.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ import (

"github.com/gravitational/teleport/lib/tbot/config"
"github.com/gravitational/teleport/lib/tbot/workloadidentity"
"github.com/gravitational/teleport/lib/tbot/workloadidentity/workloadattest"
"github.com/gravitational/teleport/lib/utils"
)

Expand All @@ -63,23 +62,18 @@ type bundleSetGetter interface {
GetBundleSet(ctx context.Context) (*workloadidentity.BundleSet, error)
}

type svidFetcher func(ctx context.Context, localBundle *spiffebundle.Bundle) ([]*workloadpb.X509SVID, error)

// spiffeSDSHandler implements an Envoy SDS API.
//
// This effectively replaces the Workload API for Envoy, but functions in a
// very similar way.
type spiffeSDSHandler struct {
log *slog.Logger
cfg *config.SPIFFEWorkloadAPIService
botCfg *config.BotConfig
trustBundleCache bundleSetGetter

clientAuthenticator func(ctx context.Context) (*slog.Logger, workloadattest.Attestation, error)
svidFetcher func(
ctx context.Context,
log *slog.Logger,
localBundle *spiffebundle.Bundle,
svidRequests []config.SVIDRequest,
) ([]*workloadpb.X509SVID, error)
clientAuthenticator func(ctx context.Context) (*slog.Logger, svidFetcher, error)
}

// FetchSecrets implements
Expand All @@ -97,7 +91,7 @@ func (s *spiffeSDSHandler) FetchSecrets(
return nil, trace.Wrap(err)
}

log, creds, err := s.clientAuthenticator(ctx)
log, fetchSVIDs, err := s.clientAuthenticator(ctx)
if err != nil {
return nil, trace.Wrap(err, "authenticating client")
}
Expand All @@ -114,11 +108,7 @@ func (s *spiffeSDSHandler) FetchSecrets(
return nil, trace.Wrap(err, "getting trust bundle set")
}

// Filter SVIDs down to those accessible to this workload
svids, err := s.svidFetcher(
ctx,
log,
bundleSet.Local, filterSVIDRequests(ctx, log, s.cfg.SVIDs, creds))
svids, err := fetchSVIDs(ctx, bundleSet.Local)
if err != nil {
return nil, trace.Wrap(err, "fetching X509 SVIDs")
}
Expand Down Expand Up @@ -174,7 +164,7 @@ func (s *spiffeSDSHandler) StreamSecrets(
srv secretv3pb.SecretDiscoveryService_StreamSecretsServer,
) error {
ctx := srv.Context()
log, creds, err := s.clientAuthenticator(ctx)
log, fetchSVIDs, err := s.clientAuthenticator(ctx)
if err != nil {
return trace.Wrap(err, "authenticating client")
}
Expand Down Expand Up @@ -216,9 +206,6 @@ func (s *spiffeSDSHandler) StreamSecrets(
renewalTimer.Stop()
defer renewalTimer.Stop()

// Filter SVIDs down to those accessible to this workload
availableSVIDs := filterSVIDRequests(ctx, log, s.cfg.SVIDs, creds)

// Track the last response and last request to allow us to handle ACK/NACK
// and versioning.
var (
Expand Down Expand Up @@ -311,7 +298,7 @@ func (s *spiffeSDSHandler) StreamSecrets(

// Fetch the SVIDs if necessary
if svids == nil {
svids, err = s.svidFetcher(ctx, log, bundleSet.Local, availableSVIDs)
svids, err = fetchSVIDs(ctx, bundleSet.Local)
if err != nil {
return trace.Wrap(err, "fetching X509 SVIDs")
}
Expand Down
91 changes: 16 additions & 75 deletions lib/tbot/service_spiffe_workload_api_sds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ import (
discoveryv3pb "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
secretv3pb "github.com/envoyproxy/go-control-plane/envoy/service/secret/v3"
"github.com/google/go-cmp/cmp"
"github.com/gravitational/trace"
"github.com/spiffe/go-spiffe/v2/bundle/spiffebundle"
workloadpb "github.com/spiffe/go-spiffe/v2/proto/spiffe/workload"
"github.com/spiffe/go-spiffe/v2/spiffeid"
Expand All @@ -51,7 +50,6 @@ import (
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/tbot/config"
"github.com/gravitational/teleport/lib/tbot/workloadidentity"
"github.com/gravitational/teleport/lib/tbot/workloadidentity/workloadattest"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/golden"
"github.com/gravitational/teleport/tool/teleport/testenv"
Expand Down Expand Up @@ -80,14 +78,22 @@ func TestSDS_FetchSecrets(t *testing.T) {
ca, err := x509.ParseCertificate(b.Bytes)
require.NoError(t, err)

uid := 100
notUID := 200
clientAuthenticator := func(ctx context.Context) (*slog.Logger, workloadattest.Attestation, error) {
return log, workloadattest.Attestation{
Unix: workloadattest.UnixAttestation{
Attested: true,
UID: uid,
},
clientAuthenticator := func(ctx context.Context) (*slog.Logger, svidFetcher, error) {
return log, func(ctx context.Context, localBundle *spiffebundle.Bundle) ([]*workloadpb.X509SVID, error) {
return []*workloadpb.X509SVID{
{
SpiffeId: "spiffe://example.com/default",
X509Svid: []byte("CERT-spiffe://example.com/default"),
X509SvidKey: []byte("KEY-spiffe://example.com/default"),
Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()),
},
{
SpiffeId: "spiffe://example.com/second",
X509Svid: []byte("CERT-spiffe://example.com/second"),
X509SvidKey: []byte("KEY-spiffe://example.com/second"),
Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()),
},
}, nil
}, nil
}

Expand All @@ -105,72 +111,9 @@ func TestSDS_FetchSecrets(t *testing.T) {
},
},
}
svidFetcher := func(
ctx context.Context,
log *slog.Logger,
localBundle *spiffebundle.Bundle,
svidRequests []config.SVIDRequest) ([]*workloadpb.X509SVID, error) {
if len(svidRequests) != 2 {
return nil, trace.BadParameter("expected 2 svids requested")
}
return []*workloadpb.X509SVID{
{
SpiffeId: "spiffe://example.com/default",
X509Svid: []byte("CERT-spiffe://example.com/default"),
X509SvidKey: []byte("KEY-spiffe://example.com/default"),
Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()),
},
{
SpiffeId: "spiffe://example.com/second",
X509Svid: []byte("CERT-spiffe://example.com/second"),
X509SvidKey: []byte("KEY-spiffe://example.com/second"),
Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()),
},
}, nil
}
botConfig := &config.BotConfig{
RenewalInterval: time.Minute,
}
cfg := &config.SPIFFEWorkloadAPIService{
SVIDs: []config.SVIDRequestWithRules{
{
SVIDRequest: config.SVIDRequest{
Path: "/default",
},
Rules: []config.SVIDRequestRule{
{
Unix: config.SVIDRequestRuleUnix{
UID: &uid,
},
},
},
},
{
SVIDRequest: config.SVIDRequest{
Path: "/second",
},
Rules: []config.SVIDRequestRule{
{
Unix: config.SVIDRequestRuleUnix{
UID: &uid,
},
},
},
},
{
SVIDRequest: config.SVIDRequest{
Path: "/not-matching",
},
Rules: []config.SVIDRequestRule{
{
Unix: config.SVIDRequestRuleUnix{
UID: &notUID,
},
},
},
},
},
}

tests := []struct {
name string
Expand Down Expand Up @@ -231,12 +174,10 @@ func TestSDS_FetchSecrets(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
sds := &spiffeSDSHandler{
log: log,
cfg: cfg,
botCfg: botConfig,

trustBundleCache: mockBundleCache,
clientAuthenticator: clientAuthenticator,
svidFetcher: svidFetcher,
}

req := &discoveryv3pb.DiscoveryRequest{
Expand Down
Loading

0 comments on commit 85accc5

Please sign in to comment.