Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert lib/srv to use slog #49913

Merged
merged 1 commit into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 38 additions & 44 deletions lib/srv/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import (
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport"
Expand Down Expand Up @@ -265,7 +264,7 @@ type ServerContext struct {
// ConnectionContext is the parent context which manages connection-level
// resources.
*sshutils.ConnectionContext
*log.Entry
Logger *slog.Logger

mu sync.RWMutex

Expand Down Expand Up @@ -434,17 +433,14 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s
ServerSubKind: srv.TargetMetadata().ServerSubKind,
}

fields := log.Fields{
"local": child.ServerConn.LocalAddr(),
"remote": child.ServerConn.RemoteAddr(),
"login": child.Identity.Login,
"teleportUser": child.Identity.TeleportUser,
"id": child.id,
}
child.Entry = log.WithFields(log.Fields{
teleport.ComponentKey: child.srv.Component(),
teleport.ComponentFields: fields,
})
child.Logger = slog.With(
teleport.ComponentKey, srv.Component(),
"local_addr", child.ServerConn.LocalAddr(),
"remote_addr", child.ServerConn.RemoteAddr(),
"login", child.Identity.Login,
"teleport_user", child.Identity.TeleportUser,
"id", child.id,
)

if identityContext.Login == teleport.SSHSessionJoinPrincipal {
child.JoinOnly = true
Expand All @@ -462,15 +458,11 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s

// Update log entry fields.
if !child.disconnectExpiredCert.IsZero() {
fields["cert"] = child.disconnectExpiredCert
child.Logger = child.Logger.With("cert", child.disconnectExpiredCert)
}
if child.clientIdleTimeout != 0 {
fields["idle"] = child.clientIdleTimeout
child.Logger = child.Logger.With("idle", child.clientIdleTimeout)
}
child.Entry = log.WithFields(log.Fields{
teleport.ComponentKey: srv.Component(),
teleport.ComponentFields: fields,
})

clusterName, err := srv.GetAccessPoint().GetClusterName()
if err != nil {
Expand All @@ -491,11 +483,9 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s
TeleportUser: child.Identity.TeleportUser,
Login: child.Identity.Login,
ServerID: child.srv.ID(),
// TODO(tross) update this to use the child logger
// once ServerContext is converted to use a slog.Logger
Logger: slog.Default(),
Emitter: child.srv,
EmitterContext: ctx,
Logger: child.Logger,
Emitter: child.srv,
EmitterContext: ctx,
}
for _, opt := range monitorOpts {
opt(&monitorConfig)
Expand Down Expand Up @@ -573,15 +563,15 @@ func (c *ServerContext) GetServer() Server {

// CreateOrJoinSession will look in the SessionRegistry for the session ID. If
// no session is found, a new one is created. If one is found, it is returned.
func (c *ServerContext) CreateOrJoinSession(reg *SessionRegistry) error {
func (c *ServerContext) CreateOrJoinSession(ctx context.Context, reg *SessionRegistry) error {
c.mu.Lock()
defer c.mu.Unlock()
// As SSH conversation progresses, at some point a session will be created and
// its ID will be added to the environment
ssid, found := c.getEnvLocked(sshutils.SessionEnvVar)
if !found {
c.sessionID = rsession.NewID()
c.Logger.Debugf("Will create new session for SSH connection %v.", c.ServerConn.RemoteAddr())
c.Logger.DebugContext(ctx, "Will create new session for SSH connection")
return nil
}

Expand All @@ -595,7 +585,7 @@ func (c *ServerContext) CreateOrJoinSession(reg *SessionRegistry) error {
if sess, found := reg.findSession(*id); found {
c.sessionID = *id
c.session = sess
c.Logger.Debugf("Will join session %v for SSH connection %v.", c.session.id, c.ServerConn.RemoteAddr())
c.Logger.DebugContext(ctx, "Joining active SSH session", "session_id", c.session.id)
} else {
// TODO(capnspacehook): DELETE IN 17.0.0 - by then all supported
// clients should only set TELEPORT_SESSION when they want to
Expand All @@ -605,7 +595,7 @@ func (c *ServerContext) CreateOrJoinSession(reg *SessionRegistry) error {
// to prevent the user from controlling the session ID, generate
// a new one
c.sessionID = rsession.NewID()
c.Logger.Debugf("Will create new session for SSH connection %v.", c.ServerConn.RemoteAddr())
c.Logger.DebugContext(ctx, "Creating new SSH session")
}

return nil
Expand Down Expand Up @@ -676,18 +666,18 @@ func (c *ServerContext) getEnvLocked(key string) (string, bool) {
}

// setSession sets the context's session
func (c *ServerContext) setSession(sess *session, ch ssh.Channel) {
func (c *ServerContext) setSession(ctx context.Context, sess *session, ch ssh.Channel) {
c.mu.Lock()
defer c.mu.Unlock()
c.session = sess

// inform the client of the session ID that is being used in a new
// goroutine to reduce latency
go func() {
c.Logger.Debug("Sending current session ID.")
c.Logger.DebugContext(ctx, "Sending current session ID")
_, err := ch.SendRequest(teleport.CurrentSessionIDRequest, false, []byte(sess.ID()))
if err != nil {
c.Logger.WithError(err).Debug("Failed to send the current session ID.")
c.Logger.DebugContext(ctx, "Failed to send the current session ID", "error", err)
}
}()
}
Expand Down Expand Up @@ -754,7 +744,7 @@ func (c *ServerContext) CheckSFTPAllowed(registry *SessionRegistry) error {
}

// OpenXServerListener opens a new XServer unix listener.
func (c *ServerContext) HandleX11Listener(l net.Listener, singleConnection bool) error {
func (c *ServerContext) HandleX11Listener(ctx context.Context, l net.Listener, singleConnection bool) error {
display, err := x11.ParseDisplayFromUnixSocket(l.Addr().String())
if err != nil {
return trace.Wrap(err)
Expand All @@ -780,7 +770,7 @@ func (c *ServerContext) HandleX11Listener(l net.Listener, singleConnection bool)
xconn, err := l.Accept()
if err != nil {
if !utils.IsOKNetworkError(err) {
c.Logger.WithError(err).Debug("Encountered error accepting XServer connection")
c.Logger.DebugContext(ctx, "Encountered error accepting XServer connection", "error", err)
}
return
}
Expand All @@ -790,7 +780,7 @@ func (c *ServerContext) HandleX11Listener(l net.Listener, singleConnection bool)

xchan, sin, err := c.ServerConn.OpenChannel(x11.ChannelRequest, x11ChannelReqPayload)
if err != nil {
c.Logger.WithError(err).Debug("Failed to open a new X11 channel")
c.Logger.DebugContext(ctx, "Failed to open a new X11 channel", "error", err)
return
}
defer xchan.Close()
Expand All @@ -802,12 +792,12 @@ func (c *ServerContext) HandleX11Listener(l net.Listener, singleConnection bool)
go func() {
err := sshutils.ForwardRequests(ctx, sin, c.RemoteSession)
if err != nil {
c.Logger.WithError(err).Debug("Failed to forward ssh request from server during X11 forwarding")
c.Logger.DebugContext(ctx, "Failed to forward ssh request from server during X11 forwarding", "error", err)
}
}()

if err := utils.ProxyConn(ctx, xconn, xchan); err != nil {
c.Logger.WithError(err).Debug("Encountered error during X11 forwarding")
c.Logger.DebugContext(ctx, "Encountered error during X11 forwarding", "error", err)
}
}()

Expand Down Expand Up @@ -884,7 +874,7 @@ func (c *ServerContext) reportStats(conn utils.Stater) {
sessionDataEvent.ConnectionMetadata.LocalAddr = c.ServerConn.LocalAddr().String()
}
if err := c.GetServer().EmitAuditEvent(c.GetServer().Context(), sessionDataEvent); err != nil {
c.WithError(err).Warn("Failed to emit session data event.")
c.Logger.WarnContext(c.GetServer().Context(), "Failed to emit session data event", "error", err)
}

// Emit TX and RX bytes to their respective Prometheus counters.
Expand Down Expand Up @@ -926,21 +916,21 @@ func (c *ServerContext) CancelFunc() context.CancelFunc {

// SendExecResult sends the result of execution of the "exec" command over the
// ExecResultCh.
func (c *ServerContext) SendExecResult(r ExecResult) {
func (c *ServerContext) SendExecResult(ctx context.Context, r ExecResult) {
select {
case c.ExecResultCh <- r:
default:
c.Infof("Blocked on sending exec result %v.", r)
c.Logger.InfoContext(ctx, "Blocked on sending exec result", "code", r.Code, "command", r.Command)
}
}

// SendSubsystemResult sends the result of running the subsystem over the
// SubsystemResultCh.
func (c *ServerContext) SendSubsystemResult(r SubsystemResult) {
func (c *ServerContext) SendSubsystemResult(ctx context.Context, r SubsystemResult) {
select {
case c.SubsystemResultCh <- r:
default:
c.Info("Blocked on sending subsystem result.")
c.Logger.InfoContext(ctx, "Blocked on sending subsystem result")
}
}

Expand Down Expand Up @@ -1005,7 +995,11 @@ func getPAMConfig(c *ServerContext) (*PAMConfig, error) {
// If the trait isn't passed by the IdP due to misconfiguration
// we fallback to setting a value which will indicate this.
if trace.IsNotFound(err) {
c.Logger.WithError(err).Warnf("Attempted to interpolate custom PAM environment with external trait but received SAML response does not contain claim")
c.Logger.WarnContext(
c.CancelContext(),
"Attempted to interpolate custom PAM environment with external trait but received SAML response does not contain claim",
"error", err,
)
continue
}

Expand Down Expand Up @@ -1120,11 +1114,11 @@ func buildEnvironment(ctx *ServerContext) []string {
// SSH_CONNECTION environment variables.
remoteHost, remotePort, err := net.SplitHostPort(ctx.ServerConn.RemoteAddr().String())
if err != nil {
ctx.Logger.Debugf("Failed to split remote address: %v.", err)
ctx.Logger.DebugContext(ctx.CancelContext(), "Failed to split remote address", "error", err)
} else {
localHost, localPort, err := net.SplitHostPort(ctx.ServerConn.LocalAddr().String())
if err != nil {
ctx.Logger.Debugf("Failed to split local address: %v.", err)
ctx.Logger.DebugContext(ctx.CancelContext(), "Failed to split local address", "error", err)
} else {
env.AddTrusted("SSH_CLIENT", fmt.Sprintf("%s %s %s", remoteHost, remotePort, localPort))
env.AddTrusted("SSH_CONNECTION", fmt.Sprintf("%s %s %s %s", remoteHost, remotePort, localHost, localPort))
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/ctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ func TestCreateOrJoinSession(t *testing.T) {
ctx.SetEnv(sshutils.SessionEnvVar, tt.sessionID)
}

err = ctx.CreateOrJoinSession(registry)
err = ctx.CreateOrJoinSession(context.Background(), registry)
require.NoError(t, err)
require.False(t, ctx.sessionID.IsZero())
if tt.wantSameSessionID {
Expand Down
Loading
Loading