diff --git a/lib/tbot/service_spiffe_workload_api.go b/lib/tbot/service_spiffe_workload_api.go index 748e2e3cbd13f..b8a5675673657 100644 --- a/lib/tbot/service_spiffe_workload_api.go +++ b/lib/tbot/service_spiffe_workload_api.go @@ -52,6 +52,7 @@ import ( "github.com/gravitational/teleport" machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/observability/metrics" @@ -227,13 +228,27 @@ func (s *SPIFFEWorkloadAPIService) Run(ctx context.Context) error { ) workloadpb.RegisterSpiffeWorkloadAPIServer(srv, s) sdsHandler := &spiffeSDSHandler{ - log: s.log, - cfg: s.cfg, - botCfg: s.botCfg, - - trustBundleCache: s.trustBundleCache, - clientAuthenticator: s.authenticateClient, - svidFetcher: s.fetchX509SVIDs, + log: s.log, + botCfg: s.botCfg, + trustBundleCache: s.trustBundleCache, + clientAuthenticator: func(ctx context.Context) (*slog.Logger, svidFetcher, error) { + log, attrs, err := s.authenticateClient(ctx) + if err != nil { + return log, nil, trace.Wrap(err, "authenticating client") + } + fetchSVIDs := func( + ctx context.Context, + localBundle *spiffebundle.Bundle, + ) ([]*workloadpb.X509SVID, error) { + return s.fetchX509SVIDs( + ctx, + log, + localBundle, + filterSVIDRequests(ctx, log, s.cfg.SVIDs, attrs), + ) + } + return log, fetchSVIDs, nil + }, } secretv3pb.RegisterSecretDiscoveryServiceServer(srv, sdsHandler) @@ -373,7 +388,7 @@ func filterSVIDRequests( ctx context.Context, log *slog.Logger, svidRequests []config.SVIDRequestWithRules, - att workloadattest.Attestation, + att *workloadidentityv1pb.WorkloadAttrs, ) []config.SVIDRequest { var filtered []config.SVIDRequest for _, req := range svidRequests { @@ -413,67 +428,67 @@ func filterSVIDRequests( "Evaluating rule against workload attestation", ) if rule.Unix.UID != nil { - if !att.Unix.Attested { + if !att.GetUnix().GetAttested() { logNotAttested("unix") continue } - if *rule.Unix.UID != att.Unix.UID { - logMismatch("unix.uid", *rule.Unix.UID, att.Unix.UID) + if *rule.Unix.UID != int(att.GetUnix().GetUid()) { + logMismatch("unix.uid", *rule.Unix.UID, att.GetUnix().GetUid()) continue } // Rule field matched! } if rule.Unix.PID != nil { - if !att.Unix.Attested { + if !att.GetUnix().GetAttested() { logNotAttested("unix") continue } - if *rule.Unix.PID != att.Unix.PID { - logMismatch("unix.pid", *rule.Unix.PID, att.Unix.PID) + if *rule.Unix.PID != int(att.GetUnix().GetPid()) { + logMismatch("unix.pid", *rule.Unix.PID, att.GetUnix().GetPid()) continue } // Rule field matched! } if rule.Unix.GID != nil { - if !att.Unix.Attested { + if !att.GetUnix().GetAttested() { logNotAttested("unix") continue } - if *rule.Unix.GID != att.Unix.GID { - logMismatch("unix.gid", *rule.Unix.GID, att.Unix.GID) + if *rule.Unix.GID != int(att.GetUnix().GetGid()) { + logMismatch("unix.gid", *rule.Unix.GID, att.GetUnix().GetGid()) continue } // Rule field matched! } if rule.Kubernetes.Namespace != "" { - if !att.Kubernetes.Attested { + if !att.GetKubernetes().GetAttested() { logNotAttested("kubernetes") continue } - if rule.Kubernetes.Namespace != att.Kubernetes.Namespace { - logMismatch("kubernetes.namespace", rule.Kubernetes.Namespace, att.Kubernetes.Namespace) + if rule.Kubernetes.Namespace != att.GetKubernetes().GetNamespace() { + logMismatch("kubernetes.namespace", rule.Kubernetes.Namespace, att.GetKubernetes().GetNamespace()) continue } // Rule field matched! } if rule.Kubernetes.PodName != "" { - if !att.Kubernetes.Attested { + if !att.GetKubernetes().GetAttested() { logNotAttested("kubernetes") continue } - if rule.Kubernetes.PodName != att.Kubernetes.PodName { - logMismatch("kubernetes.pod_name", rule.Kubernetes.PodName, att.Kubernetes.PodName) + if rule.Kubernetes.PodName != att.GetKubernetes().GetPodName() { + logMismatch("kubernetes.pod_name", rule.Kubernetes.PodName, att.GetKubernetes().GetPodName()) continue } // Rule field matched! } if rule.Kubernetes.ServiceAccount != "" { - if !att.Kubernetes.Attested { + if !att.GetKubernetes().GetAttested() { logNotAttested("kubernetes") continue } - if rule.Kubernetes.ServiceAccount != att.Kubernetes.ServiceAccount { - logMismatch("kubernetes.service_account", rule.Kubernetes.ServiceAccount, att.Kubernetes.ServiceAccount) + if rule.Kubernetes.ServiceAccount != att.GetKubernetes().GetServiceAccount() { + logMismatch("kubernetes.service_account", rule.Kubernetes.ServiceAccount, att.GetKubernetes().GetServiceAccount()) continue } // Rule field matched! @@ -499,10 +514,10 @@ func filterSVIDRequests( func (s *SPIFFEWorkloadAPIService) authenticateClient( ctx context.Context, -) (*slog.Logger, workloadattest.Attestation, error) { +) (*slog.Logger, *workloadidentityv1pb.WorkloadAttrs, error) { p, ok := peer.FromContext(ctx) if !ok { - return nil, workloadattest.Attestation{}, trace.BadParameter("peer not found in context") + return nil, nil, trace.BadParameter("peer not found in context") } log := s.log @@ -516,7 +531,7 @@ func (s *SPIFFEWorkloadAPIService) authenticateClient( // We expect Creds to be nil/unset if the client is connecting via TCP and // therefore there is no workload attestation that can be completed. if !ok || authInfo.Creds == nil { - return log, workloadattest.Attestation{}, nil + return log, nil, nil } // For a UDS, sometimes we are unable to determine the PID of the calling @@ -528,7 +543,7 @@ func (s *SPIFFEWorkloadAPIService) authenticateClient( if authInfo.Creds.PID == 0 { log.DebugContext( ctx, "Failed to determine the PID of the calling workload. TBot may be running in a different process namespace to the workload. Workload attestation will not be completed.") - return log, workloadattest.Attestation{}, nil + return log, nil, nil } att, err := s.attestor.Attest(ctx, authInfo.Creds.PID) @@ -541,10 +556,10 @@ func (s *SPIFFEWorkloadAPIService) authenticateClient( "error", err, "pid", authInfo.Creds.PID, ) - return log, workloadattest.Attestation{}, nil + return log, nil, nil } log = log.With( - "workload", slog.LogValuer(att), + "workload", att, ) return log, att, nil diff --git a/lib/tbot/service_spiffe_workload_api_sds.go b/lib/tbot/service_spiffe_workload_api_sds.go index a74379e52383c..23bd84ad512d5 100644 --- a/lib/tbot/service_spiffe_workload_api_sds.go +++ b/lib/tbot/service_spiffe_workload_api_sds.go @@ -40,7 +40,6 @@ import ( "github.com/gravitational/teleport/lib/tbot/config" "github.com/gravitational/teleport/lib/tbot/workloadidentity" - "github.com/gravitational/teleport/lib/tbot/workloadidentity/workloadattest" "github.com/gravitational/teleport/lib/utils" ) @@ -63,23 +62,18 @@ type bundleSetGetter interface { GetBundleSet(ctx context.Context) (*workloadidentity.BundleSet, error) } +type svidFetcher func(ctx context.Context, localBundle *spiffebundle.Bundle) ([]*workloadpb.X509SVID, error) + // spiffeSDSHandler implements an Envoy SDS API. // // This effectively replaces the Workload API for Envoy, but functions in a // very similar way. type spiffeSDSHandler struct { log *slog.Logger - cfg *config.SPIFFEWorkloadAPIService botCfg *config.BotConfig trustBundleCache bundleSetGetter - clientAuthenticator func(ctx context.Context) (*slog.Logger, workloadattest.Attestation, error) - svidFetcher func( - ctx context.Context, - log *slog.Logger, - localBundle *spiffebundle.Bundle, - svidRequests []config.SVIDRequest, - ) ([]*workloadpb.X509SVID, error) + clientAuthenticator func(ctx context.Context) (*slog.Logger, svidFetcher, error) } // FetchSecrets implements @@ -97,7 +91,7 @@ func (s *spiffeSDSHandler) FetchSecrets( return nil, trace.Wrap(err) } - log, creds, err := s.clientAuthenticator(ctx) + log, fetchSVIDs, err := s.clientAuthenticator(ctx) if err != nil { return nil, trace.Wrap(err, "authenticating client") } @@ -114,11 +108,7 @@ func (s *spiffeSDSHandler) FetchSecrets( return nil, trace.Wrap(err, "getting trust bundle set") } - // Filter SVIDs down to those accessible to this workload - svids, err := s.svidFetcher( - ctx, - log, - bundleSet.Local, filterSVIDRequests(ctx, log, s.cfg.SVIDs, creds)) + svids, err := fetchSVIDs(ctx, bundleSet.Local) if err != nil { return nil, trace.Wrap(err, "fetching X509 SVIDs") } @@ -174,7 +164,7 @@ func (s *spiffeSDSHandler) StreamSecrets( srv secretv3pb.SecretDiscoveryService_StreamSecretsServer, ) error { ctx := srv.Context() - log, creds, err := s.clientAuthenticator(ctx) + log, fetchSVIDs, err := s.clientAuthenticator(ctx) if err != nil { return trace.Wrap(err, "authenticating client") } @@ -216,9 +206,6 @@ func (s *spiffeSDSHandler) StreamSecrets( renewalTimer.Stop() defer renewalTimer.Stop() - // Filter SVIDs down to those accessible to this workload - availableSVIDs := filterSVIDRequests(ctx, log, s.cfg.SVIDs, creds) - // Track the last response and last request to allow us to handle ACK/NACK // and versioning. var ( @@ -311,7 +298,7 @@ func (s *spiffeSDSHandler) StreamSecrets( // Fetch the SVIDs if necessary if svids == nil { - svids, err = s.svidFetcher(ctx, log, bundleSet.Local, availableSVIDs) + svids, err = fetchSVIDs(ctx, bundleSet.Local) if err != nil { return trace.Wrap(err, "fetching X509 SVIDs") } diff --git a/lib/tbot/service_spiffe_workload_api_sds_test.go b/lib/tbot/service_spiffe_workload_api_sds_test.go index 0ed4ad5c7cddf..b8a5304620c57 100644 --- a/lib/tbot/service_spiffe_workload_api_sds_test.go +++ b/lib/tbot/service_spiffe_workload_api_sds_test.go @@ -35,7 +35,6 @@ import ( discoveryv3pb "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" secretv3pb "github.com/envoyproxy/go-control-plane/envoy/service/secret/v3" "github.com/google/go-cmp/cmp" - "github.com/gravitational/trace" "github.com/spiffe/go-spiffe/v2/bundle/spiffebundle" workloadpb "github.com/spiffe/go-spiffe/v2/proto/spiffe/workload" "github.com/spiffe/go-spiffe/v2/spiffeid" @@ -51,7 +50,6 @@ import ( "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/tbot/config" "github.com/gravitational/teleport/lib/tbot/workloadidentity" - "github.com/gravitational/teleport/lib/tbot/workloadidentity/workloadattest" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/golden" "github.com/gravitational/teleport/tool/teleport/testenv" @@ -80,14 +78,22 @@ func TestSDS_FetchSecrets(t *testing.T) { ca, err := x509.ParseCertificate(b.Bytes) require.NoError(t, err) - uid := 100 - notUID := 200 - clientAuthenticator := func(ctx context.Context) (*slog.Logger, workloadattest.Attestation, error) { - return log, workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ - Attested: true, - UID: uid, - }, + clientAuthenticator := func(ctx context.Context) (*slog.Logger, svidFetcher, error) { + return log, func(ctx context.Context, localBundle *spiffebundle.Bundle) ([]*workloadpb.X509SVID, error) { + return []*workloadpb.X509SVID{ + { + SpiffeId: "spiffe://example.com/default", + X509Svid: []byte("CERT-spiffe://example.com/default"), + X509SvidKey: []byte("KEY-spiffe://example.com/default"), + Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()), + }, + { + SpiffeId: "spiffe://example.com/second", + X509Svid: []byte("CERT-spiffe://example.com/second"), + X509SvidKey: []byte("KEY-spiffe://example.com/second"), + Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()), + }, + }, nil }, nil } @@ -105,72 +111,9 @@ func TestSDS_FetchSecrets(t *testing.T) { }, }, } - svidFetcher := func( - ctx context.Context, - log *slog.Logger, - localBundle *spiffebundle.Bundle, - svidRequests []config.SVIDRequest) ([]*workloadpb.X509SVID, error) { - if len(svidRequests) != 2 { - return nil, trace.BadParameter("expected 2 svids requested") - } - return []*workloadpb.X509SVID{ - { - SpiffeId: "spiffe://example.com/default", - X509Svid: []byte("CERT-spiffe://example.com/default"), - X509SvidKey: []byte("KEY-spiffe://example.com/default"), - Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()), - }, - { - SpiffeId: "spiffe://example.com/second", - X509Svid: []byte("CERT-spiffe://example.com/second"), - X509SvidKey: []byte("KEY-spiffe://example.com/second"), - Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()), - }, - }, nil - } botConfig := &config.BotConfig{ RenewalInterval: time.Minute, } - cfg := &config.SPIFFEWorkloadAPIService{ - SVIDs: []config.SVIDRequestWithRules{ - { - SVIDRequest: config.SVIDRequest{ - Path: "/default", - }, - Rules: []config.SVIDRequestRule{ - { - Unix: config.SVIDRequestRuleUnix{ - UID: &uid, - }, - }, - }, - }, - { - SVIDRequest: config.SVIDRequest{ - Path: "/second", - }, - Rules: []config.SVIDRequestRule{ - { - Unix: config.SVIDRequestRuleUnix{ - UID: &uid, - }, - }, - }, - }, - { - SVIDRequest: config.SVIDRequest{ - Path: "/not-matching", - }, - Rules: []config.SVIDRequestRule{ - { - Unix: config.SVIDRequestRuleUnix{ - UID: ¬UID, - }, - }, - }, - }, - }, - } tests := []struct { name string @@ -231,12 +174,10 @@ func TestSDS_FetchSecrets(t *testing.T) { t.Run(tt.name, func(t *testing.T) { sds := &spiffeSDSHandler{ log: log, - cfg: cfg, botCfg: botConfig, trustBundleCache: mockBundleCache, clientAuthenticator: clientAuthenticator, - svidFetcher: svidFetcher, } req := &discoveryv3pb.DiscoveryRequest{ diff --git a/lib/tbot/service_spiffe_workload_api_test.go b/lib/tbot/service_spiffe_workload_api_test.go index 3c4c10927b994..1a2b4227c9572 100644 --- a/lib/tbot/service_spiffe_workload_api_test.go +++ b/lib/tbot/service_spiffe_workload_api_test.go @@ -34,9 +34,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/tbot/config" - "github.com/gravitational/teleport/lib/tbot/workloadidentity/workloadattest" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/tool/teleport/testenv" ) @@ -52,7 +52,7 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests(t *testing.T) { log := utils.NewSlogLoggerForTests() tests := []struct { name string - att workloadattest.Attestation + att *workloadidentityv1pb.WorkloadAttrs in []config.SVIDRequestWithRules want []config.SVIDRequest }{ @@ -81,12 +81,12 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests(t *testing.T) { }, { name: "no rules with attestation", - att: workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ + att: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ Attested: true, - UID: 1000, - GID: 1001, - PID: 1002, + Uid: 1000, + Gid: 1001, + Pid: 1002, }, }, in: []config.SVIDRequestWithRules{ @@ -112,15 +112,15 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests(t *testing.T) { }, { name: "no rules with attestation", - att: workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ + att: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ // We don't expect that workloadattest will ever return // Attested: false and include UID/PID/GID but we want to // ensure we handle this by failing regardless. Attested: false, - UID: 1000, - GID: 1001, - PID: 1002, + Uid: 1000, + Gid: 1001, + Pid: 1002, }, }, in: []config.SVIDRequestWithRules{ @@ -141,12 +141,12 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests(t *testing.T) { }, { name: "no matching rules with attestation", - att: workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ + att: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ Attested: true, - UID: 1000, - GID: 1001, - PID: 1002, + Uid: 1000, + Gid: 1001, + Pid: 1002, }, }, in: []config.SVIDRequestWithRules{ @@ -220,12 +220,12 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests(t *testing.T) { }, { name: "some matching rules with uds", - att: workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ + att: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ Attested: true, - UID: 1000, - GID: 1001, - PID: 1002, + Uid: 1000, + Gid: 1001, + Pid: 1002, }, }, in: []config.SVIDRequestWithRules{ @@ -290,8 +290,8 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests_field(t *testing.T) { log := utils.NewSlogLoggerForTests() tests := []struct { field string - matching workloadattest.Attestation - nonMatching workloadattest.Attestation + matching *workloadidentityv1pb.WorkloadAttrs + nonMatching *workloadidentityv1pb.WorkloadAttrs rule config.SVIDRequestRule }{ { @@ -301,16 +301,16 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests_field(t *testing.T) { PID: ptr(1000), }, }, - matching: workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ + matching: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ Attested: true, - PID: 1000, + Pid: 1000, }, }, - nonMatching: workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ + nonMatching: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ Attested: true, - PID: 200, + Pid: 200, }, }, }, @@ -321,16 +321,16 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests_field(t *testing.T) { UID: ptr(1000), }, }, - matching: workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ + matching: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ Attested: true, - UID: 1000, + Uid: 1000, }, }, - nonMatching: workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ + nonMatching: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ Attested: true, - UID: 200, + Uid: 200, }, }, }, @@ -341,16 +341,16 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests_field(t *testing.T) { GID: ptr(1000), }, }, - matching: workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ + matching: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ Attested: true, - GID: 1000, + Gid: 1000, }, }, - nonMatching: workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ + nonMatching: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ Attested: true, - GID: 200, + Gid: 200, }, }, }, @@ -361,14 +361,14 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests_field(t *testing.T) { Namespace: "foo", }, }, - matching: workloadattest.Attestation{ - Kubernetes: workloadattest.KubernetesAttestation{ + matching: &workloadidentityv1pb.WorkloadAttrs{ + Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{ Attested: true, Namespace: "foo", }, }, - nonMatching: workloadattest.Attestation{ - Kubernetes: workloadattest.KubernetesAttestation{ + nonMatching: &workloadidentityv1pb.WorkloadAttrs{ + Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{ Attested: true, Namespace: "bar", }, @@ -381,14 +381,14 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests_field(t *testing.T) { ServiceAccount: "foo", }, }, - matching: workloadattest.Attestation{ - Kubernetes: workloadattest.KubernetesAttestation{ + matching: &workloadidentityv1pb.WorkloadAttrs{ + Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{ Attested: true, ServiceAccount: "foo", }, }, - nonMatching: workloadattest.Attestation{ - Kubernetes: workloadattest.KubernetesAttestation{ + nonMatching: &workloadidentityv1pb.WorkloadAttrs{ + Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{ Attested: true, ServiceAccount: "bar", }, @@ -401,14 +401,14 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests_field(t *testing.T) { PodName: "foo", }, }, - matching: workloadattest.Attestation{ - Kubernetes: workloadattest.KubernetesAttestation{ + matching: &workloadidentityv1pb.WorkloadAttrs{ + Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{ Attested: true, PodName: "foo", }, }, - nonMatching: workloadattest.Attestation{ - Kubernetes: workloadattest.KubernetesAttestation{ + nonMatching: &workloadidentityv1pb.WorkloadAttrs{ + Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{ Attested: true, PodName: "bar", },