Skip to content

Commit

Permalink
[v16] kube: properly return the reason for connection disruption (#51455
Browse files Browse the repository at this point in the history
)

* kube: properly return the reason for connection disruption (#51398)

* kube: properly return the reason for connection disruption

There are several cases where connection monitor can terminate an
ongoing connection. Iddle timeout, certificate expiring among others are
some reasons for the connection to be terminated.

For Kubernetes access, the underlying error is never propagated back to
the client so they don't receive the reason for the exec session being
terminated.

This PR fixes that by adding an hook to write the client error response
into the connection error channel for clients to be aware.

Part of #18496

* handle review comments

* handle review comments

* fix slog ref
  • Loading branch information
tigrato authored Jan 24, 2025
1 parent 938750d commit 65feb3f
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 16 deletions.
10 changes: 9 additions & 1 deletion integration/kube_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1150,6 +1150,7 @@ func testKubeDisconnect(t *testing.T, suite *KubeSuite) {
ClientIdleTimeout: types.NewDuration(500 * time.Millisecond),
},
disconnectTimeout: 2 * time.Second,
verifyError: errorContains("Client exceeded idle timeout of"),
},
{
name: "expired cert",
Expand All @@ -1158,6 +1159,7 @@ func testKubeDisconnect(t *testing.T, suite *KubeSuite) {
MaxSessionTTL: types.NewDuration(3 * time.Second),
},
disconnectTimeout: 6 * time.Second,
verifyError: errorContains("client certificate expire"),
},
}

Expand Down Expand Up @@ -1252,9 +1254,15 @@ func runKubeDisconnectTest(t *testing.T, suite *KubeSuite, tc disconnectTestCase
tty: true,
stdin: term,
})
require.NoError(t, err)
require.NoError(t, tc.verifyError(err))
}()

require.Eventually(t, func() bool {
// wait for the shell prompt
return strings.Contains(term.AllOutput(), "#")
}, 5*time.Second, 10*time.Millisecond, "Failed to get shell prompt. "+
"If this fails, the exec command is likely hanging and never reaching the kind cluster")

// lets type something followed by "enter" and then hang the session
require.NoError(t, enterInput(sessionCtx, term, "echo boring platypus\r\n", ".*boring platypus.*"))
time.Sleep(tc.disconnectTimeout)
Expand Down
48 changes: 41 additions & 7 deletions lib/kube/proxy/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,9 @@ type authContext struct {
recordingConfig types.SessionRecordingConfig
// clientIdleTimeout sets information on client idle timeout
clientIdleTimeout time.Duration
// clientIdleTimeoutMessage is the message to be displayed to the user
// when the client idle timeout is reached
clientIdleTimeoutMessage string
// disconnectExpiredCert if set, controls the time when the connection
// should be disconnected because the client cert expires
disconnectExpiredCert time.Time
Expand Down Expand Up @@ -805,13 +808,14 @@ func (f *Forwarder) setupContext(
}

return &authContext{
clientIdleTimeout: roles.AdjustClientIdleTimeout(netConfig.GetClientIdleTimeout()),
sessionTTL: sessionTTL,
Context: authCtx,
recordingConfig: recordingConfig,
kubeClusterName: kubeCluster,
certExpires: identity.Expires,
disconnectExpiredCert: authCtx.GetDisconnectCertExpiry(authPref),
clientIdleTimeout: roles.AdjustClientIdleTimeout(netConfig.GetClientIdleTimeout()),
clientIdleTimeoutMessage: netConfig.GetClientIdleTimeoutMessage(),
sessionTTL: sessionTTL,
Context: authCtx,
recordingConfig: recordingConfig,
kubeClusterName: kubeCluster,
certExpires: identity.Expires,
disconnectExpiredCert: authCtx.GetDisconnectCertExpiry(authPref),
teleportCluster: teleportClusterClient{
name: teleportClusterName,
remoteAddr: utils.NetAddr{AddrNetwork: "tcp", Addr: req.RemoteAddr},
Expand Down Expand Up @@ -1666,6 +1670,8 @@ func (f *Forwarder) exec(authCtx *authContext, w http.ResponseWriter, req *http.

return upgradeRequestToRemoteCommandProxy(request,
func(proxy *remoteCommandProxy) error {
sess.sendErrStatus = proxy.writeStatus

if !sess.isLocalKubernetesCluster {
// We're forwarding this to another kubernetes_service instance, let it handle multiplexing.
return f.remoteExec(authCtx, w, req, p, sess, request, proxy)
Expand Down Expand Up @@ -2286,6 +2292,8 @@ type clusterSession struct {
connCtx context.Context
// connMonitorCancel is the conn monitor connMonitorCancel function.
connMonitorCancel context.CancelCauseFunc
// sendErrStatus is a function that sends an error status to the client.
sendErrStatus func(status *kubeerrors.StatusError) error
}

// close cancels the connection monitor context if available.
Expand Down Expand Up @@ -2324,6 +2332,7 @@ func (s *clusterSession) monitorConn(conn net.Conn, err error, hostID string) (n
LockTargets: lockTargets,
DisconnectExpiredCert: s.disconnectExpiredCert,
ClientIdleTimeout: s.clientIdleTimeout,
IdleTimeoutMessage: s.clientIdleTimeoutMessage,
Clock: s.parent.cfg.Clock,
Tracker: tc,
Conn: tc,
Expand All @@ -2333,6 +2342,7 @@ func (s *clusterSession) monitorConn(conn net.Conn, err error, hostID string) (n
Entry: s.parent.log,
Emitter: s.parent.cfg.AuthClient,
EmitterContext: s.parent.ctx,
MessageWriter: formatForwardResponseError(s.sendErrStatus),
})
if err != nil {
tc.CloseWithCause(err)
Expand Down Expand Up @@ -2694,3 +2704,27 @@ func errorToKubeStatusReason(err error, code int) metav1.StatusReason {
return metav1.StatusReasonUnknown
}
}

// formatForwardResponseError formats the error response from the connection
// monitor to a Kubernetes API error response.
type formatForwardResponseError func(status *kubeerrors.StatusError) error

func (f formatForwardResponseError) WriteString(s string) (int, error) {
if f == nil {
return len(s), nil
}
err := f(
&kubeerrors.StatusError{
ErrStatus: metav1.Status{
Status: metav1.StatusFailure,
Code: http.StatusInternalServerError,
Reason: metav1.StatusReasonInternalError,
Message: s,
},
},
)
if err != nil {
return 0, trace.Wrap(err)
}
return len(s), nil
}
2 changes: 1 addition & 1 deletion lib/kube/proxy/portforward_spdy.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func runPortForwardingHTTPStreams(req portForwardRequest) error {
defer h.Close()

h.Debugf("Setting port forwarding streaming connection idle timeout to %s.", req.idleTimeout)
conn.SetIdleTimeout(req.idleTimeout)
conn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout))

h.run()
return nil
Expand Down
4 changes: 2 additions & 2 deletions lib/kube/proxy/portforward_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func runPortForwardingWebSocket(req portForwardRequest) error {
},
})

conn.SetIdleTimeout(req.idleTimeout)
conn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout))

// Upgrade the request and create the virtual streams.
_, streams, err := conn.Open(
Expand Down Expand Up @@ -355,7 +355,7 @@ func runPortForwardingTunneledHTTPStreams(req portForwardRequest) error {
defer h.Close()

h.Debugf("Setting port forwarding streaming connection idle timeout to %s.", req.idleTimeout)
spdyConn.SetIdleTimeout(req.idleTimeout)
spdyConn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout))

h.run()
return nil
Expand Down
21 changes: 17 additions & 4 deletions lib/kube/proxy/remotecommand.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"io"
"net/http"
"strings"
"sync"
"time"

"github.com/gravitational/trace"
Expand Down Expand Up @@ -157,7 +158,7 @@ func createSPDYStreams(req remoteCommandRequest) (*remoteCommandProxy, error) {
return nil, trace.ConnectionProblem(trace.BadParameter("missing connection"), "missing connection")
}

conn.SetIdleTimeout(req.idleTimeout)
conn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout))

var handler protocolHandler
switch protocol {
Expand Down Expand Up @@ -445,23 +446,35 @@ func waitStreamReply(ctx context.Context, replySent <-chan struct{}, notify chan
// v4WriteStatusFunc returns a WriteStatusFunc that marshals a given api Status
// as json in the error channel.
func v4WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) error {
return func(status *apierrors.StatusError) error {
return writeStatusOnceFunc(func(status *apierrors.StatusError) error {
st := status.Status()
data, err := runtime.Encode(globalKubeCodecs.LegacyCodec(), &st)
if err != nil {
return trace.Wrap(err)
}
_, err = stream.Write(data)
return err
}
})
}

func v1WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) error {
return func(status *apierrors.StatusError) error {
return writeStatusOnceFunc(func(status *apierrors.StatusError) error {
if status.Status().Status == metav1.StatusSuccess {
return nil // send error messages
}
_, err := stream.Write([]byte(status.Error()))
return err
})
}

// writeStatusOnceFunc returns a function that only calls f once, and returns the result of the first call.
func writeStatusOnceFunc(f func(status *apierrors.StatusError) error) func(status *apierrors.StatusError) error {
var once sync.Once
var err error
return func(status *apierrors.StatusError) error {
once.Do(func() {
err = f(status)
})
return trace.Wrap(err)
}
}
20 changes: 19 additions & 1 deletion lib/kube/proxy/remotecommand_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ limitations under the License.
package proxy

import (
"time"

"github.com/go-logr/logr"
"github.com/gravitational/trace"
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
Expand Down Expand Up @@ -110,7 +112,7 @@ func createWebSocketStreams(req remoteCommandRequest) (*remoteCommandProxy, erro
},
})

conn.SetIdleTimeout(req.idleTimeout)
conn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout))

negotiatedProtocol, streams, err := conn.Open(
responsewriter.GetOriginal(req.httpResponseWriter),
Expand Down Expand Up @@ -163,3 +165,19 @@ func createWebSocketStreams(req remoteCommandRequest) (*remoteCommandProxy, erro

return proxy, nil
}

// adjustIdleTimeoutForConn adjusts the idle timeout for the connection
// to be 5 seconds longer than the requested idle timeout.
// This is done to prevent the connection from being closed by the server
// before the connection monitor has a chance to close it and write the
// status code.
// If the idle timeout is 0, this function returns 0 because it means the
// connection will never be closed by the server due to idleness.
func adjustIdleTimeoutForConn(idleTimeout time.Duration) time.Duration {
// If the idle timeout is 0, we don't need to adjust it because it
// means the connection will never be closed by the server due to idleness.
if idleTimeout != 0 {
idleTimeout += 5 * time.Second
}
return idleTimeout
}
11 changes: 11 additions & 0 deletions lib/srv/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ type MonitorConfig struct {
Entry log.FieldLogger
// IdleTimeoutMessage is sent to the client when the idle timeout expires.
IdleTimeoutMessage string
// CertificateExpiredMessage is sent to the client when the certificate expires.
CertificateExpiredMessage string
// MessageWriter wraps a channel to send text messages to the client. Use
// for disconnection messages, etc.
MessageWriter io.StringWriter
Expand Down Expand Up @@ -417,6 +419,15 @@ func (w *Monitor) start(lockWatch types.Watcher) {

func (w *Monitor) disconnectClientOnExpiredCert() {
reason := fmt.Sprintf("client certificate expired at %v", w.Clock.Now().UTC())
if w.MessageWriter != nil {
msg := w.CertificateExpiredMessage
if msg == "" {
msg = reason
}
if _, err := w.MessageWriter.WriteString(msg); err != nil {
w.Entry.WithError(err).Warn("Failed to send certificate expiration message")
}
}
w.disconnectClient(reason)
}

Expand Down

0 comments on commit 65feb3f

Please sign in to comment.