Skip to content

Commit

Permalink
Add WebSocket handler for WebUI database sessions (#49749)
Browse files Browse the repository at this point in the history
* feat(web): add websocket handler for database webui sessions

* refactor: move common structs into a separate package

* refactor(web): use ALPN local proxy to dial databases

* feat(repl): add default registry

* refactor(web): code review suggestions

* refactor: update repl config parameters

* refactor: move default getter implementation

* feat(web): add supports_interactive field on dbs

* refactor: code review suggestions

* refactor: update database REPL interfaces

* chore(web): remove debug print

* feat: register postgres repl

* refactor(web): update MakeDatabase to receive access checker and interactive

* chore(web): remove unused function
  • Loading branch information
gabrielcorado committed Dec 16, 2024
1 parent 037b9d0 commit 4c0a9d7
Show file tree
Hide file tree
Showing 18 changed files with 793 additions and 88 deletions.
27 changes: 23 additions & 4 deletions lib/client/alpn.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,33 @@ type ALPNAuthTunnelConfig struct {
// RouteToDatabase contains the destination server that must receive the connection.
// Specific for database proxying.
RouteToDatabase proto.RouteToDatabase

// TLSCert specifies the TLS certificate used on the proxy connection.
TLSCert *tls.Certificate
}

func (c *ALPNAuthTunnelConfig) CheckAndSetDefaults(ctx context.Context) error {
if c.AuthClient == nil {
return trace.BadParameter("missing auth client")
}

if c.TLSCert == nil {
tlsCert, err := getUserCerts(ctx, c.AuthClient, c.MFAResponse, c.Expires, c.RouteToDatabase, c.ConnectionDiagnosticID)
if err != nil {
return trace.BadParameter("failed to parse private key: %v", err)
}

c.TLSCert = &tlsCert
}

return nil
}

// RunALPNAuthTunnel runs a local authenticated ALPN proxy to another service.
// At least one Route (which defines the service) must be defined
func RunALPNAuthTunnel(ctx context.Context, cfg ALPNAuthTunnelConfig) error {
tlsCert, err := getUserCerts(ctx, cfg.AuthClient, cfg.MFAResponse, cfg.Expires, cfg.RouteToDatabase, cfg.ConnectionDiagnosticID)
if err != nil {
return trace.BadParameter("failed to parse private key: %v", err)
if err := cfg.CheckAndSetDefaults(ctx); err != nil {
return trace.Wrap(err)
}

lp, err := alpnproxy.NewLocalProxy(alpnproxy.LocalProxyConfig{
Expand All @@ -101,7 +120,7 @@ func RunALPNAuthTunnel(ctx context.Context, cfg ALPNAuthTunnelConfig) error {
Protocols: []alpn.Protocol{cfg.Protocol},
Listener: cfg.Listener,
ParentContext: ctx,
Cert: tlsCert,
Cert: *cfg.TLSCert,
}, alpnproxy.WithALPNConnUpgradeTest(ctx, getClusterCACertPool(cfg.AuthClient)))
if err != nil {
return trace.Wrap(err)
Expand Down
17 changes: 9 additions & 8 deletions lib/client/db/postgres/repl/repl.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/gravitational/teleport"
clientproto "github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/lib/asciitable"
dbrepl "github.com/gravitational/teleport/lib/client/db/repl"
"github.com/gravitational/teleport/lib/defaults"
)

Expand All @@ -44,13 +45,13 @@ type REPL struct {
commands map[string]*command
}

func New(client io.ReadWriteCloser, serverConn net.Conn, route clientproto.RouteToDatabase) (*REPL, error) {
func New(_ context.Context, cfg *dbrepl.NewREPLConfig) (dbrepl.REPLInstance, error) {
config, err := pgconn.ParseConfig(fmt.Sprintf("postgres://%s", hostnamePlaceholder))
if err != nil {
return nil, trace.Wrap(err)
}
config.User = route.Username
config.Database = route.Database
config.User = cfg.Route.Username
config.Database = cfg.Route.Database
config.ConnectTimeout = defaults.DatabaseConnectTimeout
config.RuntimeParams = map[string]string{
applicationNameParamName: applicationNameParamValue,
Expand All @@ -63,15 +64,15 @@ func New(client io.ReadWriteCloser, serverConn net.Conn, route clientproto.Route
return []string{hostnamePlaceholder}, nil
}
config.DialFunc = func(_ context.Context, _, _ string) (net.Conn, error) {
return serverConn, nil
return cfg.ServerConn, nil
}

return &REPL{
connConfig: config,
client: client,
serverConn: serverConn,
route: route,
term: term.NewTerminal(client, ""),
client: cfg.Client,
serverConn: cfg.ServerConn,
route: cfg.Route,
term: term.NewTerminal(cfg.Client, ""),
commands: initCommands(),
}, nil
}
Expand Down
6 changes: 4 additions & 2 deletions lib/client/db/postgres/repl/repl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (

clientproto "github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/lib/client/db/postgres/repl/testdata"
dbrepl "github.com/gravitational/teleport/lib/client/db/repl"
"github.com/gravitational/teleport/lib/utils/golden"
)

Expand Down Expand Up @@ -338,7 +339,7 @@ func StartWithServer(t *testing.T, ctx context.Context, opts ...testCtxOption) (
}
}(tc)

r, err := New(tc.clientConn, tc.serverConn, tc.route)
instance, err := New(ctx, &dbrepl.NewREPLConfig{Client: tc.clientConn, ServerConn: tc.serverConn, Route: tc.route})
require.NoError(t, err)

if !cfg.skipREPLRun {
Expand All @@ -347,7 +348,7 @@ func StartWithServer(t *testing.T, ctx context.Context, opts ...testCtxOption) (
runCtx, cancelRun := context.WithCancel(ctx)
runErrChan := make(chan error, 1)
go func() {
runErrChan <- r.Run(runCtx)
runErrChan <- instance.Run(runCtx)
}()
t.Cleanup(func() {
cancelRun()
Expand All @@ -363,6 +364,7 @@ func StartWithServer(t *testing.T, ctx context.Context, opts ...testCtxOption) (
})
}

r, _ := instance.(*REPL)
return r, tc
}

Expand Down
80 changes: 80 additions & 0 deletions lib/client/db/repl/repl.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Teleport
// Copyright (C) 2024 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package repl

import (
"context"
"io"
"net"

"github.com/gravitational/trace"

clientproto "github.com/gravitational/teleport/api/client/proto"
)

// NewREPLConfig represents the database REPL constructor config.
type NewREPLConfig struct {
// Client is the user terminal client.
Client io.ReadWriteCloser
// ServerConn is the database server connection.
ServerConn net.Conn
// Route is the session routing information.
Route clientproto.RouteToDatabase
}

// REPLNewFunc defines the constructor function for database REPL
// sessions.
type REPLNewFunc func(context.Context, *NewREPLConfig) (REPLInstance, error)

// REPLInstance represents a REPL instance.
type REPLInstance interface {
// Run executes the REPL. This is a blocking operation.
Run(context.Context) error
}

// REPLRegistry is an interface for initializing REPL instances and checking
// if the database protocol is supported.
type REPLRegistry interface {
// IsSupported returns if a database protocol is supported by any REPL.
IsSupported(protocol string) bool
// NewInstance initializes a new REPL instance given the configuration.
NewInstance(context.Context, *NewREPLConfig) (REPLInstance, error)
}

// NewREPLGetter creates a new REPL getter given the list of supported REPLs.
func NewREPLGetter(replNewFuncs map[string]REPLNewFunc) REPLRegistry {
return &replRegistry{m: replNewFuncs}
}

type replRegistry struct {
m map[string]REPLNewFunc
}

// IsSupported implements REPLGetter.
func (r *replRegistry) IsSupported(protocol string) bool {
_, supported := r.m[protocol]
return supported
}

// NewInstance implements REPLGetter.
func (r *replRegistry) NewInstance(ctx context.Context, cfg *NewREPLConfig) (REPLInstance, error) {
if newFunc, ok := r.m[cfg.Route.Protocol]; ok {
return newFunc(ctx, cfg)
}

return nil, trace.NotImplemented("REPL not supported for protocol %q", cfg.Route.Protocol)
}
4 changes: 4 additions & 0 deletions lib/defaults/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,10 @@ const (

// WebsocketKubeExec provides latency information for a session.
WebsocketKubeExec = "k"

// WebsocketDatabaseSessionRequest is received when a new database session
// is requested.
WebsocketDatabaseSessionRequest = "d"
)

// The following are cryptographic primitives Teleport does not support in
Expand Down
9 changes: 9 additions & 0 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ import (
_ "github.com/gravitational/teleport/lib/backend/pgbk"
"github.com/gravitational/teleport/lib/bpf"
"github.com/gravitational/teleport/lib/cache"
pgrepl "github.com/gravitational/teleport/lib/client/db/postgres/repl"
dbrepl "github.com/gravitational/teleport/lib/client/db/repl"
"github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/cloud/gcp"
"github.com/gravitational/teleport/lib/cloud/imds"
Expand Down Expand Up @@ -1084,6 +1086,12 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) {
cfg.PluginRegistry = plugin.NewRegistry()
}

if cfg.DatabaseREPLRegistry == nil {
cfg.DatabaseREPLRegistry = dbrepl.NewREPLGetter(map[string]dbrepl.REPLNewFunc{
defaults.ProtocolPostgres: pgrepl.New,
})
}

var cloudLabels labels.Importer

// Check if we're on a cloud instance, and if we should override the node's hostname.
Expand Down Expand Up @@ -4652,6 +4660,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
AutomaticUpgradesChannels: cfg.Proxy.AutomaticUpgradesChannels,
IntegrationAppHandler: connectionsHandler,
FeatureWatchInterval: retryutils.HalfJitter(web.DefaultFeatureWatchInterval * 2),
DatabaseREPLRegistry: cfg.DatabaseREPLRegistry,
}
webHandler, err := web.NewHandler(webConfig)
if err != nil {
Expand Down
5 changes: 5 additions & 0 deletions lib/service/servicecfg/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import (
"github.com/gravitational/teleport/lib/auth/state"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/backend/lite"
dbrepl "github.com/gravitational/teleport/lib/client/db/repl"
"github.com/gravitational/teleport/lib/cloud/imds"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
Expand Down Expand Up @@ -265,6 +266,10 @@ type Config struct {
// AccessGraph represents AccessGraph server config
AccessGraph AccessGraphConfig

// DatabaseREPLRegistry is used to retrieve datatabase REPL given the
// protocol.
DatabaseREPLRegistry dbrepl.REPLRegistry

// token is either the token needed to join the auth server, or a path pointing to a file
// that contains the token
//
Expand Down
19 changes: 8 additions & 11 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ import (
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/automaticupgrades"
"github.com/gravitational/teleport/lib/client"
dbrepl "github.com/gravitational/teleport/lib/client/db/repl"
"github.com/gravitational/teleport/lib/client/sso"
"github.com/gravitational/teleport/lib/defaults"
dtconfig "github.com/gravitational/teleport/lib/devicetrust/config"
Expand Down Expand Up @@ -332,6 +333,9 @@ type Config struct {
// FeatureWatchInterval is the interval between pings to the auth server
// to fetch new cluster features
FeatureWatchInterval time.Duration

// DatabaseREPLRegistry is used for retrieving database REPL.
DatabaseREPLRegistry dbrepl.REPLRegistry
}

// SetDefaults ensures proper default values are set if
Expand Down Expand Up @@ -837,6 +841,7 @@ func (h *Handler) bindDefaultEndpoints() {
h.GET("/webapi/sites/:site/sessions", h.WithClusterAuth(h.clusterActiveAndPendingSessionsGet)) // get list of active and pending sessions

h.GET("/webapi/sites/:site/kube/exec/ws", h.WithClusterAuthWebSocket(h.podConnect)) // connect to a pod with exec (via websocket, with auth over websocket)
h.GET("/webapi/sites/:site/db/exec/ws", h.WithClusterAuthWebSocket(h.dbConnect))

// Audit events handlers.
h.GET("/webapi/sites/:site/events/search", h.WithClusterAuth(h.clusterSearchEvents)) // search site events
Expand Down Expand Up @@ -3055,9 +3060,6 @@ func (h *Handler) clusterUnifiedResourcesGet(w http.ResponseWriter, request *htt

getUserGroupLookup := h.getUserGroupLookup(request.Context(), clt)

var dbNames, dbUsers []string
hasFetchedDBUsersAndNames := false

unifiedResources := make([]any, 0, len(page))
for _, enriched := range page {
switch r := enriched.ResourceWithLabels.(type) {
Expand All @@ -3069,14 +3071,7 @@ func (h *Handler) clusterUnifiedResourcesGet(w http.ResponseWriter, request *htt

unifiedResources = append(unifiedResources, ui.MakeServer(site.GetName(), r, logins, enriched.RequiresRequest))
case types.DatabaseServer:
if !hasFetchedDBUsersAndNames {
dbNames, dbUsers, err = getDatabaseUsersAndNames(accessChecker)
if err != nil {
return nil, trace.Wrap(err)
}
hasFetchedDBUsersAndNames = true
}
db := ui.MakeDatabase(r.GetDatabase(), dbUsers, dbNames, enriched.RequiresRequest)
db := ui.MakeDatabase(r.GetDatabase(), accessChecker, h.cfg.DatabaseREPLRegistry, enriched.RequiresRequest)
unifiedResources = append(unifiedResources, db)
case types.AppServer:
allowedAWSRoles, err := calculateAppLogins(accessChecker, r, enriched.Logins)
Expand Down Expand Up @@ -3570,6 +3565,7 @@ func (h *Handler) siteNodeConnect(
}

term, err := NewTerminal(ctx, TerminalHandlerConfig{
Logger: h.logger,
Term: req.Term,
SessionCtx: sessionCtx,
UserAuthClient: clt,
Expand Down Expand Up @@ -3722,6 +3718,7 @@ func (h *Handler) podConnect(
ws: ws,
keepAliveInterval: keepAliveInterval,
log: h.log.WithField(teleport.ComponentKey, "pod"),
logger: h.logger.With(teleport.ComponentKey, "pod"),
userClient: clt,
localCA: hostCA,
configServerAddr: serverAddr,
Expand Down
10 changes: 9 additions & 1 deletion lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ import (
"github.com/gravitational/teleport/lib/bpf"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/client/conntest"
dbrepl "github.com/gravitational/teleport/lib/client/db/repl"
"github.com/gravitational/teleport/lib/cryptosuites"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
Expand Down Expand Up @@ -210,6 +211,9 @@ type webSuiteConfig struct {

// clock to use for all server components
clock clockwork.FakeClock

// databaseREPLGetter allows setting custom database REPLs.
databaseREPLGetter dbrepl.REPLRegistry
}

func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite {
Expand Down Expand Up @@ -509,6 +513,7 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite {
return &proxyClientCert, nil
},
IntegrationAppHandler: &mockIntegrationAppHandler{},
DatabaseREPLRegistry: cfg.databaseREPLGetter,
}

if handlerConfig.HealthCheckAppServer == nil {
Expand Down Expand Up @@ -7437,6 +7442,7 @@ func TestOverwriteDatabase(t *testing.T) {
env := newWebPack(t, 1)
proxy := env.proxies[0]
pack := proxy.authPack(t, "user", nil /* roles */)
accessChecker := services.NewAccessCheckerWithRoleSet(&services.AccessInfo{}, env.server.ClusterName(), nil)

initDb, err := types.NewDatabaseV3(types.Metadata{
Name: "postgres",
Expand Down Expand Up @@ -7477,7 +7483,8 @@ func TestOverwriteDatabase(t *testing.T) {

backendDb, err := env.server.Auth().GetDatabase(context.Background(), req.Name)
require.NoError(t, err)
require.Equal(t, webui.MakeDatabase(backendDb, nil, nil, false), gotDb)

require.Equal(t, webui.MakeDatabase(backendDb, accessChecker, proxy.handler.handler.cfg.DatabaseREPLRegistry, false), gotDb)
},
},
{
Expand Down Expand Up @@ -8390,6 +8397,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula
return &proxyClientCert, nil
},
IntegrationAppHandler: &mockIntegrationAppHandler{},
DatabaseREPLRegistry: &mockDatabaseREPLRegistry{repl: map[string]dbrepl.REPLNewFunc{}},
}, SetSessionStreamPollPeriod(200*time.Millisecond), SetClock(clock))
require.NoError(t, err)

Expand Down
Loading

0 comments on commit 4c0a9d7

Please sign in to comment.