Skip to content

Commit

Permalink
add ssh identity object (#50787)
Browse files Browse the repository at this point in the history
  • Loading branch information
fspmarshall committed Jan 10, 2025
1 parent 65663eb commit d86055a
Show file tree
Hide file tree
Showing 14 changed files with 725 additions and 370 deletions.
66 changes: 34 additions & 32 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2656,7 +2656,7 @@ func (a *Server) AugmentContextUserCertificates(

// submitCertificateIssuedEvent submits a certificate issued usage event to the
// usage reporting service.
func (a *Server) submitCertificateIssuedEvent(req *certRequest, params services.UserCertParams) {
func (a *Server) submitCertificateIssuedEvent(req *certRequest, params sshca.UserCertificateRequest) {
var database, app, kubernetes, desktop bool

if req.dbService != "" {
Expand Down Expand Up @@ -2699,7 +2699,7 @@ func (a *Server) submitCertificateIssuedEvent(req *certRequest, params services.
UsageApp: app,
UsageKubernetes: kubernetes,
UsageDesktop: desktop,
PrivateKeyPolicy: string(params.PrivateKeyPolicy),
PrivateKeyPolicy: string(params.Identity.PrivateKeyPolicy),
})
}

Expand Down Expand Up @@ -2902,36 +2902,38 @@ func generateCert(ctx context.Context, a *Server, req certRequest, caType types.
return nil, trace.Wrap(err)
}

params := services.UserCertParams{
CASigner: sshSigner,
PublicUserKey: req.publicKey,
Username: req.user.GetName(),
Impersonator: req.impersonator,
AllowedLogins: allowedLogins,
TTL: sessionTTL,
Roles: req.checker.RoleNames(),
CertificateFormat: certificateFormat,
PermitPortForwarding: req.checker.CanPortForward(),
PermitAgentForwarding: req.checker.CanForwardAgents(),
PermitX11Forwarding: req.checker.PermitX11Forwarding(),
RouteToCluster: req.routeToCluster,
Traits: req.traits,
ActiveRequests: req.activeRequests,
MFAVerified: req.mfaVerified,
PreviousIdentityExpires: req.previousIdentityExpires,
LoginIP: req.loginIP,
PinnedIP: pinnedIP,
DisallowReissue: req.disallowReissue,
Renewable: req.renewable,
Generation: req.generation,
BotName: req.botName,
CertificateExtensions: req.checker.CertificateExtensions(),
AllowedResourceIDs: requestedResourcesStr,
ConnectionDiagnosticID: req.connectionDiagnosticID,
PrivateKeyPolicy: attestedKeyPolicy,
DeviceID: req.deviceExtensions.DeviceID,
DeviceAssetTag: req.deviceExtensions.AssetTag,
DeviceCredentialID: req.deviceExtensions.CredentialID,
params := sshca.UserCertificateRequest{
CASigner: sshSigner,
PublicUserKey: req.publicKey,
TTL: sessionTTL,
CertificateFormat: certificateFormat,
Identity: sshca.Identity{
Username: req.user.GetName(),
Impersonator: req.impersonator,
AllowedLogins: allowedLogins,
Roles: req.checker.RoleNames(),
PermitPortForwarding: req.checker.CanPortForward(),
PermitAgentForwarding: req.checker.CanForwardAgents(),
PermitX11Forwarding: req.checker.PermitX11Forwarding(),
RouteToCluster: req.routeToCluster,
Traits: req.traits,
ActiveRequests: req.activeRequests,
MFAVerified: req.mfaVerified,
PreviousIdentityExpires: req.previousIdentityExpires,
LoginIP: req.loginIP,
PinnedIP: pinnedIP,
DisallowReissue: req.disallowReissue,
Renewable: req.renewable,
Generation: req.generation,
BotName: req.botName,
CertificateExtensions: req.checker.CertificateExtensions(),
AllowedResourceIDs: requestedResourcesStr,
ConnectionDiagnosticID: req.connectionDiagnosticID,
PrivateKeyPolicy: attestedKeyPolicy,
DeviceID: req.deviceExtensions.DeviceID,
DeviceAssetTag: req.deviceExtensions.AssetTag,
DeviceCredentialID: req.deviceExtensions.CredentialID,
},
}
signedSSHCert, err := a.GenerateUserCert(params)
if err != nil {
Expand Down
162 changes: 38 additions & 124 deletions lib/auth/keygen/keygen.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"context"
"crypto/rand"
"fmt"
"strings"
"time"

"github.com/gravitational/trace"
Expand All @@ -31,13 +30,12 @@ import (
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/wrappers"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/lib/auth/native"
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/sshca"
"github.com/gravitational/teleport/lib/utils"
)

Expand Down Expand Up @@ -144,149 +142,65 @@ func (k *Keygen) GenerateHostCertWithoutValidation(c services.HostCertParams) ([

// GenerateUserCert generates a user ssh certificate with the passed in parameters.
// The private key of the CA to sign the certificate must be provided.
func (k *Keygen) GenerateUserCert(c services.UserCertParams) ([]byte, error) {
if err := c.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err, "error validating UserCertParams")
func (k *Keygen) GenerateUserCert(req sshca.UserCertificateRequest) ([]byte, error) {
if err := req.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err, "error validating user certificate request")
}
return k.GenerateUserCertWithoutValidation(c)
return k.GenerateUserCertWithoutValidation(req)
}

// GenerateUserCertWithoutValidation generates a user ssh certificate with the
// passed in parameters without validating them.
func (k *Keygen) GenerateUserCertWithoutValidation(c services.UserCertParams) ([]byte, error) {
pubKey, _, _, _, err := ssh.ParseAuthorizedKey(c.PublicUserKey)
func (k *Keygen) GenerateUserCertWithoutValidation(req sshca.UserCertificateRequest) ([]byte, error) {
pubKey, _, _, _, err := ssh.ParseAuthorizedKey(req.PublicUserKey)
if err != nil {
return nil, trace.Wrap(err)
}
validBefore := uint64(ssh.CertTimeInfinity)
if c.TTL != 0 {
b := k.clock.Now().UTC().Add(c.TTL)
validBefore = uint64(b.Unix())
log.Debugf("generated user key for %v with expiry on (%v) %v", c.AllowedLogins, validBefore, b)
}
cert := &ssh.Certificate{
// we have to use key id to identify teleport user
KeyId: c.Username,
ValidPrincipals: c.AllowedLogins,
Key: pubKey,
ValidAfter: uint64(k.clock.Now().UTC().Add(-1 * time.Minute).Unix()),
ValidBefore: validBefore,
CertType: ssh.UserCert,
}
cert.Permissions.Extensions = map[string]string{
teleport.CertExtensionPermitPTY: "",
}
if c.PermitX11Forwarding {
cert.Permissions.Extensions[teleport.CertExtensionPermitX11Forwarding] = ""
}
if c.PermitAgentForwarding {
cert.Permissions.Extensions[teleport.CertExtensionPermitAgentForwarding] = ""
}
if c.PermitPortForwarding {
cert.Permissions.Extensions[teleport.CertExtensionPermitPortForwarding] = ""
}
if c.MFAVerified != "" {
cert.Permissions.Extensions[teleport.CertExtensionMFAVerified] = c.MFAVerified
}
if !c.PreviousIdentityExpires.IsZero() {
cert.Permissions.Extensions[teleport.CertExtensionPreviousIdentityExpires] = c.PreviousIdentityExpires.Format(time.RFC3339)
}
if c.LoginIP != "" {
cert.Permissions.Extensions[teleport.CertExtensionLoginIP] = c.LoginIP
}
if c.Impersonator != "" {
cert.Permissions.Extensions[teleport.CertExtensionImpersonator] = c.Impersonator
}
if c.DisallowReissue {
cert.Permissions.Extensions[teleport.CertExtensionDisallowReissue] = ""
}
if c.Renewable {
cert.Permissions.Extensions[teleport.CertExtensionRenewable] = ""
}
if c.Generation > 0 {
cert.Permissions.Extensions[teleport.CertExtensionGeneration] = fmt.Sprint(c.Generation)
}
if c.BotName != "" {
cert.Permissions.Extensions[teleport.CertExtensionBotName] = c.BotName
}
if c.AllowedResourceIDs != "" {
cert.Permissions.Extensions[teleport.CertExtensionAllowedResources] = c.AllowedResourceIDs
}
if c.ConnectionDiagnosticID != "" {
cert.Permissions.Extensions[teleport.CertExtensionConnectionDiagnosticID] = c.ConnectionDiagnosticID
}
if c.PrivateKeyPolicy != "" {
cert.Permissions.Extensions[teleport.CertExtensionPrivateKeyPolicy] = string(c.PrivateKeyPolicy)
}
if devID := c.DeviceID; devID != "" {
cert.Permissions.Extensions[teleport.CertExtensionDeviceID] = devID

// create shallow copy of identity since we want to make some local changes
ident := req.Identity

// since this method ignores the supplied values for ValidBefore/ValidAfter, avoid confusing by
// rejecting identities where they are set.
if ident.ValidBefore != 0 {
return nil, trace.BadParameter("ValidBefore should not be set in calls to GenerateUserCert")
}
if assetTag := c.DeviceAssetTag; assetTag != "" {
cert.Permissions.Extensions[teleport.CertExtensionDeviceAssetTag] = assetTag

if ident.ValidAfter != 0 {
return nil, trace.BadParameter("ValidAfter should not be set in calls to GenerateUserCert")
}
if credID := c.DeviceCredentialID; credID != "" {
cert.Permissions.Extensions[teleport.CertExtensionDeviceCredentialID] = credID

// calculate ValidBefore based on the outer request TTL
ident.ValidBefore = uint64(ssh.CertTimeInfinity)
if req.TTL != 0 {
b := k.clock.Now().UTC().Add(req.TTL)
ident.ValidBefore = uint64(b.Unix())
log.Debugf("generated user key for %v with expiry on (%v) %v", ident.AllowedLogins, ident.ValidBefore, b)
}

if c.PinnedIP != "" {
// set ValidAfter to be 1 minute in the past
ident.ValidAfter = uint64(k.clock.Now().UTC().Add(-1 * time.Minute).Unix())

// if the provided identity is attempting to perform IP pinning, make sure modules are enforced
if ident.PinnedIP != "" {
if modules.GetModules().BuildType() != modules.BuildEnterprise {
return nil, trace.AccessDenied("source IP pinning is only supported in Teleport Enterprise")
}
if cert.CriticalOptions == nil {
cert.CriticalOptions = make(map[string]string)
}
//IPv4, all bits matter
ip := c.PinnedIP + "/32"
if strings.Contains(c.PinnedIP, ":") {
//IPv6
ip = c.PinnedIP + "/128"
}
cert.CriticalOptions[teleport.CertCriticalOptionSourceAddress] = ip
}

for _, extension := range c.CertificateExtensions {
// TODO(lxea): update behavior when non ssh, non extensions are supported.
if extension.Mode != types.CertExtensionMode_EXTENSION ||
extension.Type != types.CertExtensionType_SSH {
continue
}
cert.Extensions[extension.Name] = extension.Value
// encode the identity into a certificate
cert, err := ident.Encode(req.CertificateFormat)
if err != nil {
return nil, trace.Wrap(err)
}

// Add roles, traits, and route to cluster in the certificate extensions if
// the standard format was requested. Certificate extensions are not included
// legacy SSH certificates due to a bug in OpenSSH <= OpenSSH 7.1:
// https://bugzilla.mindrot.org/show_bug.cgi?id=2387
if c.CertificateFormat == constants.CertificateFormatStandard {
traits, err := wrappers.MarshalTraits(&c.Traits)
if err != nil {
return nil, trace.Wrap(err)
}
if len(traits) > 0 {
cert.Permissions.Extensions[teleport.CertExtensionTeleportTraits] = string(traits)
}
if len(c.Roles) != 0 {
roles, err := services.MarshalCertRoles(c.Roles)
if err != nil {
return nil, trace.Wrap(err)
}
cert.Permissions.Extensions[teleport.CertExtensionTeleportRoles] = roles
}
if c.RouteToCluster != "" {
cert.Permissions.Extensions[teleport.CertExtensionTeleportRouteToCluster] = c.RouteToCluster
}
if !c.ActiveRequests.IsEmpty() {
requests, err := c.ActiveRequests.Marshal()
if err != nil {
return nil, trace.Wrap(err)
}
cert.Permissions.Extensions[teleport.CertExtensionTeleportActiveRequests] = string(requests)
}
}
// set the public key of the certificate
cert.Key = pubKey

if err := cert.SignCert(rand.Reader, c.CASigner); err != nil {
if err := cert.SignCert(rand.Reader, req.CASigner); err != nil {
return nil, trace.Wrap(err)
}

return ssh.MarshalAuthorizedKey(cert), nil
}

Expand Down
34 changes: 17 additions & 17 deletions lib/auth/keygen/keygen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"github.com/gravitational/teleport/lib/auth/native"
"github.com/gravitational/teleport/lib/auth/test"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/sshca"
)

type nativeContext struct {
Expand Down Expand Up @@ -191,7 +192,7 @@ func TestUserCertCompatibility(t *testing.T) {

tt := setupNativeContext(context.Background(), t)

priv, pub, err := native.GenerateKeyPair()
priv, _, err := native.GenerateKeyPair()
require.NoError(t, err)

caSigner, err := ssh.ParsePrivateKey(priv)
Expand All @@ -217,23 +218,22 @@ func TestUserCertCompatibility(t *testing.T) {
for i, tc := range tests {
comment := fmt.Sprintf("Test %v", i)

userCertificateBytes, err := tt.suite.A.GenerateUserCert(services.UserCertParams{
CASigner: caSigner,
PublicUserKey: pub,
Username: "user",
AllowedLogins: []string{"centos", "root"},
TTL: time.Hour,
Roles: []string{"foo"},
CertificateExtensions: []*types.CertExtension{{
Type: types.CertExtensionType_SSH,
Mode: types.CertExtensionMode_EXTENSION,
Name: "[email protected]",
Value: "hello",
userCertificateBytes, err := tt.suite.A.GenerateUserCert(sshca.UserCertificateRequest{
CASigner: caSigner,
PublicUserKey: ssh.MarshalAuthorizedKey(caSigner.PublicKey()),
TTL: time.Hour,
CertificateFormat: tc.inCompatibility,
Identity: sshca.Identity{
Username: "user",
AllowedLogins: []string{"centos", "root"},
Roles: []string{"foo"},
CertificateExtensions: []*types.CertExtension{{
Type: types.CertExtensionType_SSH,
Mode: types.CertExtensionMode_EXTENSION,
Name: "[email protected]",
Value: "hello",
}},
},
},
CertificateFormat: tc.inCompatibility,
PermitAgentForwarding: true,
PermitPortForwarding: true,
})
require.NoError(t, err, comment)

Expand Down
Loading

0 comments on commit d86055a

Please sign in to comment.