-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add PostgreSQL REPL implementation (#49598)
* feat(repl): add postgres * refactor(repl): change repl to use a single Run function * test(repl): reduce usage of require.Eventually blocks * refactor(repl): code review suggestions * refactor(repl): code review suggestions * test(repl): increase timeout values * fix(repl): commands formatting * refactor(repl): send close pgconn using a different context * fix(repl): add proper spacing between multi queries * test(repl): add fuzz test for processing commands
- Loading branch information
1 parent
2fe3803
commit 037b9d0
Showing
10 changed files
with
1,246 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
// Teleport | ||
// Copyright (C) 2024 Gravitational, Inc. | ||
// | ||
// This program is free software: you can redistribute it and/or modify | ||
// it under the terms of the GNU Affero General Public License as published by | ||
// the Free Software Foundation, either version 3 of the License, or | ||
// (at your option) any later version. | ||
// | ||
// This program is distributed in the hope that it will be useful, | ||
// but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
// GNU Affero General Public License for more details. | ||
// | ||
// You should have received a copy of the GNU Affero General Public License | ||
// along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
|
||
package repl | ||
|
||
import ( | ||
"fmt" | ||
"strings" | ||
|
||
"github.com/gravitational/teleport" | ||
"github.com/gravitational/teleport/lib/asciitable" | ||
) | ||
|
||
// processCommand receives a command call and return the reply and if the | ||
// command terminates the session. | ||
func (r *REPL) processCommand(line string) (string, bool) { | ||
cmdStr, args, _ := strings.Cut(strings.TrimPrefix(line, commandPrefix), " ") | ||
cmd, ok := r.commands[cmdStr] | ||
if !ok { | ||
return "Unknown command. Try \\? to show the list of supported commands." + lineBreak, false | ||
} | ||
|
||
return cmd.ExecFunc(r, args) | ||
} | ||
|
||
// commandType specify the command category. This is used to organize the | ||
// commands, for example, when showing them in the help command. | ||
type commandType string | ||
|
||
const ( | ||
// commandTypeGeneral represents a general-purpose command type. | ||
commandTypeGeneral commandType = "General" | ||
// commandTypeConnection represents a command type related to connection | ||
// operations. | ||
commandTypeConnection = "Connection" | ||
) | ||
|
||
// command represents a command that can be executed in the REPL. | ||
type command struct { | ||
// Type specifies the type of the command. | ||
Type commandType | ||
// Description provides a user-friendly explanation of what the command | ||
// does. | ||
Description string | ||
// ExecFunc is the function to execute the command. The commands can either | ||
// return a reply (that will be sent back to the client) as a string. It can | ||
// terminate the REPL by returning bool on the second argument. | ||
ExecFunc func(r *REPL, args string) (reply string, exit bool) | ||
} | ||
|
||
func initCommands() map[string]*command { | ||
return map[string]*command{ | ||
"q": { | ||
Type: commandTypeGeneral, | ||
Description: "Terminates the session.", | ||
ExecFunc: func(_ *REPL, _ string) (string, bool) { return "", true }, | ||
}, | ||
"teleport": { | ||
Type: commandTypeGeneral, | ||
Description: "Show Teleport interactive shell information, such as execution limitations.", | ||
ExecFunc: func(_ *REPL, _ string) (string, bool) { | ||
// Formats limitiations in a dash list. Example: | ||
// - hello | ||
// multi line | ||
// - another item | ||
var limitations strings.Builder | ||
for _, l := range descriptiveLimitations { | ||
limitations.WriteString("- " + strings.Join(strings.Split(l, "\n"), "\n ") + lineBreak) | ||
} | ||
|
||
return fmt.Sprintf( | ||
"Teleport PostgreSQL interactive shell (v%s)\n\nLimitations: \n%s", | ||
teleport.Version, | ||
limitations.String(), | ||
), false | ||
}, | ||
}, | ||
"?": { | ||
Type: commandTypeGeneral, | ||
Description: "Show the list of supported commands.", | ||
ExecFunc: func(r *REPL, _ string) (string, bool) { | ||
typesTable := make(map[commandType]*asciitable.Table) | ||
for cmdStr, cmd := range r.commands { | ||
if _, ok := typesTable[cmd.Type]; !ok { | ||
table := asciitable.MakeHeadlessTable(2) | ||
typesTable[cmd.Type] = &table | ||
} | ||
|
||
typesTable[cmd.Type].AddRow([]string{"\\" + cmdStr, cmd.Description}) | ||
} | ||
|
||
var res strings.Builder | ||
for cmdType, output := range typesTable { | ||
res.WriteString(string(cmdType) + lineBreak) | ||
output.AsBuffer().WriteTo(&res) | ||
res.WriteString(lineBreak) | ||
} | ||
|
||
return res.String(), false | ||
}, | ||
}, | ||
"session": { | ||
Type: commandTypeConnection, | ||
Description: "Display information about the current session, like user, and database instance.", | ||
ExecFunc: func(r *REPL, _ string) (string, bool) { | ||
return fmt.Sprintf("Connected to %q instance at %q database as %q user.", r.route.ServiceName, r.route.Database, r.route.Username), false | ||
}, | ||
}, | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
// Teleport | ||
// Copyright (C) 2024 Gravitational, Inc. | ||
// | ||
// This program is free software: you can redistribute it and/or modify | ||
// it under the terms of the GNU Affero General Public License as published by | ||
// the Free Software Foundation, either version 3 of the License, or | ||
// (at your option) any later version. | ||
// | ||
// This program is distributed in the hope that it will be useful, | ||
// but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
// GNU Affero General Public License for more details. | ||
// | ||
// You should have received a copy of the GNU Affero General Public License | ||
// along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
|
||
package repl | ||
|
||
import ( | ||
"context" | ||
"io" | ||
"strings" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
|
||
"github.com/gravitational/teleport" | ||
clientproto "github.com/gravitational/teleport/api/client/proto" | ||
) | ||
|
||
func TestCommandExecution(t *testing.T) { | ||
ctx := context.Background() | ||
|
||
for name, tt := range map[string]struct { | ||
line string | ||
commandResult string | ||
expectedArgs string | ||
expectUnknown bool | ||
commandExit bool | ||
}{ | ||
"execute": {line: "\\test", commandResult: "test"}, | ||
"execute with additional arguments": {line: "\\test a b", commandResult: "test", expectedArgs: "a b"}, | ||
"execute with exit": {line: "\\test", commandExit: true}, | ||
"execute with leading and trailing whitespace": {line: " \\test ", commandResult: "test"}, | ||
"unknown command with semicolon": {line: "\\test;", expectUnknown: true}, | ||
"unknown command": {line: "\\wrong", expectUnknown: true}, | ||
"with special characters": {line: "\\special_chars_!@#$%^&*()}", expectUnknown: true}, | ||
"empty command": {line: "\\", expectUnknown: true}, | ||
} { | ||
t.Run(name, func(t *testing.T) { | ||
commandArgsChan := make(chan string, 1) | ||
instance, tc := StartWithServer(t, ctx, WithSkipREPLRun()) | ||
ctx, cancel := context.WithCancel(ctx) | ||
defer cancel() | ||
|
||
runErrChan := make(chan error) | ||
go func() { | ||
runErrChan <- instance.Run(ctx) | ||
}() | ||
|
||
// Consume the REPL banner. | ||
_ = readUntilNextLead(t, tc) | ||
|
||
// Reset available commands and add a test command so we can assert | ||
// the command execution flow without relying in commands | ||
// implementation or test server capabilities. | ||
instance.commands = map[string]*command{ | ||
"test": { | ||
ExecFunc: func(r *REPL, args string) (string, bool) { | ||
commandArgsChan <- args | ||
return tt.commandResult, tt.commandExit | ||
}, | ||
}, | ||
} | ||
|
||
writeLine(t, tc, tt.line) | ||
if tt.expectUnknown { | ||
reply := readUntilNextLead(t, tc) | ||
require.True(t, strings.HasPrefix(strings.ToLower(reply), "unknown command")) | ||
return | ||
} | ||
|
||
select { | ||
case args := <-commandArgsChan: | ||
require.Equal(t, tt.expectedArgs, args) | ||
case <-time.After(time.Second): | ||
require.Fail(t, "expected to command args from test server but got nothing") | ||
} | ||
|
||
// When the command exits, the REPL and the connections will be | ||
// closed. | ||
if tt.commandExit { | ||
require.EventuallyWithT(t, func(t *assert.CollectT) { | ||
var buf []byte | ||
_, err := tc.conn.Read(buf[0:]) | ||
assert.ErrorIs(t, err, io.EOF) | ||
}, 5*time.Second, time.Millisecond) | ||
|
||
select { | ||
case err := <-runErrChan: | ||
require.NoError(t, err, "expected the REPL instance exit gracefully") | ||
case <-time.After(5 * time.Second): | ||
require.Fail(t, "expected REPL run to terminate but got nothing") | ||
} | ||
return | ||
} | ||
|
||
reply := readUntilNextLead(t, tc) | ||
require.Equal(t, tt.commandResult, reply) | ||
|
||
// Terminate the REPL run session and wait for the Run results. | ||
cancel() | ||
select { | ||
case err := <-runErrChan: | ||
require.ErrorIs(t, err, context.Canceled, "expected the REPL instance to finish running with error due to cancelation") | ||
case <-time.After(5 * time.Second): | ||
require.Fail(t, "expected REPL run to terminate but got nothing") | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestCommands(t *testing.T) { | ||
availableCmds := initCommands() | ||
for cmdName, tc := range map[string]struct { | ||
repl *REPL | ||
args string | ||
expectExit bool | ||
assertCommandReply require.ValueAssertionFunc | ||
}{ | ||
"q": {expectExit: true}, | ||
"teleport": { | ||
assertCommandReply: func(t require.TestingT, val interface{}, _ ...interface{}) { | ||
require.Contains(t, val, teleport.Version, "expected \\teleport command to include current Teleport version") | ||
}, | ||
}, | ||
"?": { | ||
repl: &REPL{commands: availableCmds}, | ||
assertCommandReply: func(t require.TestingT, val interface{}, _ ...interface{}) { | ||
for cmd := range availableCmds { | ||
require.Contains(t, val, cmd, "expected \\? command to include information about \\%s", cmd) | ||
} | ||
}, | ||
}, | ||
"session": { | ||
repl: &REPL{route: clientproto.RouteToDatabase{ | ||
ServiceName: "service", | ||
Username: "username", | ||
Database: "database", | ||
}}, | ||
assertCommandReply: func(t require.TestingT, val interface{}, _ ...interface{}) { | ||
require.Contains(t, val, "service", "expected \\session command to contain service name") | ||
require.Contains(t, val, "username", "expected \\session command to contain username") | ||
require.Contains(t, val, "database", "expected \\session command to contain database name") | ||
}, | ||
}, | ||
} { | ||
t.Run(cmdName, func(t *testing.T) { | ||
cmd, ok := availableCmds[cmdName] | ||
require.True(t, ok, "expected command %q to be available at commands", cmdName) | ||
reply, exit := cmd.ExecFunc(tc.repl, tc.args) | ||
if tc.expectExit { | ||
require.True(t, exit, "expected command to exit the REPL") | ||
return | ||
} | ||
tc.assertCommandReply(t, reply) | ||
}) | ||
} | ||
} | ||
|
||
func FuzzCommands(f *testing.F) { | ||
f.Add("q") | ||
f.Add("?") | ||
f.Add("session") | ||
f.Add("teleport") | ||
|
||
repl := &REPL{commands: make(map[string]*command)} | ||
f.Fuzz(func(t *testing.T, line string) { | ||
require.NotPanics(t, func() { | ||
_, _ = repl.processCommand(line) | ||
}) | ||
}) | ||
} |
Oops, something went wrong.