From f97473349ffc5f52640609c6cddbb61bfa96921f Mon Sep 17 00:00:00 2001 From: Jacob LeGrone Date: Thu, 29 Sep 2022 13:52:02 -0400 Subject: [PATCH] Test creating default db directory (#144) --- cmd/temporalite/main.go | 32 ++++--- cmd/temporalite/main_test.go | 156 ++++++++++++++++++++++++++++++++ cmd/temporalite/mtls_test.go | 34 ++++--- internal/liteconfig/config.go | 18 ++-- internal/liteconfig/freeport.go | 16 ++-- temporaltest/server.go | 2 +- temporaltest/server_test.go | 12 +-- 7 files changed, 225 insertions(+), 45 deletions(-) diff --git a/cmd/temporalite/main.go b/cmd/temporalite/main.go index 5eac1fa8..40b517bc 100644 --- a/cmd/temporalite/main.go +++ b/cmd/temporalite/main.go @@ -33,10 +33,6 @@ import ( // as a dependency when building with the `headless` tag enabled. const uiServerModule = "github.com/temporalio/ui-server/v2" -var ( - defaultCfg *liteconfig.Config -) - const ( ephemeralFlag = "ephemeral" dbPathFlag = "filename" @@ -54,10 +50,6 @@ const ( dynamicConfigValueFlag = "dynamic-config-value" ) -func init() { - defaultCfg, _ = liteconfig.NewDefaultConfig() -} - func main() { if err := buildCLI().Run(os.Args); err != nil { goLog.Fatal(err) @@ -68,6 +60,8 @@ func main() { var version string func buildCLI() *cli.App { + defaultCfg, _ := liteconfig.NewDefaultConfig() + if version == "" { version = "(devel)" } @@ -177,7 +171,7 @@ func buildCLI() *cli.App { } switch c.String(logFormatFlag) { - case "json", "pretty": + case "json", "pretty", "noop": default: return cli.Exit(fmt.Sprintf("bad value %q passed for flag %q", c.String(logFormatFlag), logFormatFlag), 1) } @@ -237,6 +231,17 @@ func buildCLI() *cli.App { } } + interruptChan := make(chan interface{}, 1) + go func() { + if doneChan := c.Done(); doneChan != nil { + s := <-doneChan + interruptChan <- s + } else { + s := <-temporal.InterruptCh() + interruptChan <- s + } + }() + opts := []temporalite.ServerOption{ temporalite.WithDynamicPorts(), temporalite.WithFrontendPort(serverPort), @@ -246,7 +251,7 @@ func buildCLI() *cli.App { temporalite.WithNamespaces(c.StringSlice(namespaceFlag)...), temporalite.WithSQLitePragmas(pragmas), temporalite.WithUpstreamOptions( - temporal.InterruptOn(temporal.InterruptCh()), + temporal.InterruptOn(interruptChan), ), temporalite.WithBaseConfig(baseConfig), } @@ -265,7 +270,8 @@ func buildCLI() *cli.App { } var logger log.Logger - if c.String(logFormatFlag) == "pretty" { + switch c.String(logFormatFlag) { + case "pretty": lcfg := zap.NewDevelopmentConfig() switch c.String(logLevelFlag) { case "debug": @@ -288,7 +294,9 @@ func buildCLI() *cli.App { return err } logger = log.NewZapLogger(l) - } else { + case "noop": + logger = log.NewNoopLogger() + default: logger = log.NewZapLogger(log.BuildZapLogger(log.Config{ Stdout: true, Level: c.String(logLevelFlag), diff --git a/cmd/temporalite/main_test.go b/cmd/temporalite/main_test.go index 3e3998d0..1edb312e 100644 --- a/cmd/temporalite/main_test.go +++ b/cmd/temporalite/main_test.go @@ -25,8 +25,22 @@ package main import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" "reflect" + "strconv" + "strings" "testing" + "time" + + "github.com/urfave/cli/v2" + "go.temporal.io/api/enums/v1" + "go.temporal.io/sdk/client" + + "github.com/temporalio/temporalite/internal/liteconfig" ) func TestGetDynamicConfigValues(t *testing.T) { @@ -63,3 +77,145 @@ func TestGetDynamicConfigValues(t *testing.T) { "foo=123", `bar="baz"`, "qux=true", `foo=["123", false]`, ) } + +func newServerAndClientOpts(port int, customArgs ...string) ([]string, client.Options) { + args := []string{ + "temporalite", + "start", + "--namespace", "default", + // Use noop logger to avoid fatal logs failing tests on shutdown signal. + "--log-format", "noop", + "--headless", + "--port", strconv.Itoa(port), + } + + return append(args, customArgs...), client.Options{ + HostPort: fmt.Sprintf("localhost:%d", port), + Namespace: "temporal-system", + } +} + +func assertServerHealth(t *testing.T, ctx context.Context, opts client.Options) { + var ( + c client.Client + clientErr error + ) + for i := 0; i < 50; i++ { + if c, clientErr = client.Dial(opts); clientErr == nil { + break + } + time.Sleep(100 * time.Millisecond) + } + if clientErr != nil { + t.Error(clientErr) + } + + if _, err := c.CheckHealth(ctx, nil); err != nil { + t.Error(err) + } + + // Check for pollers on a system task queue to ensure that the worker service is running. + for { + if ctx.Err() != nil { + t.Error(ctx.Err()) + break + } + resp, err := c.DescribeTaskQueue(ctx, "temporal-sys-tq-scanner-taskqueue-0", enums.TASK_QUEUE_TYPE_WORKFLOW) + if err != nil { + t.Error(err) + } + if len(resp.GetPollers()) > 0 { + break + } + time.Sleep(time.Millisecond * 100) + } +} + +func TestCreateDataDirectory(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + testUserHome := filepath.Join(os.TempDir(), "temporalite_test", t.Name()) + t.Cleanup(func() { + if err := os.RemoveAll(testUserHome); err != nil { + fmt.Println("error cleaning up temp dir:", err) + } + }) + // Set user home for all supported operating systems + t.Setenv("AppData", testUserHome) // Windows + t.Setenv("HOME", testUserHome) // macOS + t.Setenv("XDG_CONFIG_HOME", testUserHome) // linux + // Verify that worked + configDir, _ := os.UserConfigDir() + if !strings.HasPrefix(configDir, testUserHome) { + t.Fatalf("expected config dir %q to be inside user home directory %q", configDir, testUserHome) + } + + temporaliteCLI := buildCLI() + // Don't call os.Exit + temporaliteCLI.ExitErrHandler = func(_ *cli.Context, _ error) {} + + portProvider := liteconfig.NewPortProvider() + var ( + port1 = portProvider.MustGetFreePort() + port2 = portProvider.MustGetFreePort() + port3 = portProvider.MustGetFreePort() + ) + portProvider.Close() + + t.Run("default db path", func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + args, clientOpts := newServerAndClientOpts(port1) + + go func() { + if err := temporaliteCLI.RunContext(ctx, args); err != nil { + fmt.Println("Server closed with error:", err) + } + }() + + assertServerHealth(t, ctx, clientOpts) + + // If the rest of this test case passes but this assertion fails, + // there may have been a breaking change in the liteconfig package + // related to how the default db file path is calculated. + if _, err := os.Stat(filepath.Join(configDir, "temporalite", "db", "default.db")); err != nil { + t.Errorf("error checking for default db file: %s", err) + } + }) + + t.Run("custom db path -- missing directory", func(t *testing.T) { + customDBPath := filepath.Join(testUserHome, "foo", "bar", "baz.db") + args, _ := newServerAndClientOpts( + port2, "-f", customDBPath, + ) + if err := temporaliteCLI.RunContext(ctx, args); err != nil { + if !errors.Is(err, os.ErrNotExist) { + t.Errorf("expected error %q, got %q", os.ErrNotExist, err) + } + if !strings.Contains(err.Error(), filepath.Dir(customDBPath)) { + t.Errorf("expected error %q to contain string %q", err, filepath.Dir(customDBPath)) + } + } else { + t.Error("no error when directory missing") + } + }) + + t.Run("custom db path -- existing directory", func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + args, clientOpts := newServerAndClientOpts( + port3, "-f", filepath.Join(testUserHome, "foo.db"), + ) + + go func() { + if err := temporaliteCLI.RunContext(ctx, args); err != nil { + fmt.Println("Server closed with error:", err) + } + }() + + assertServerHealth(t, ctx, clientOpts) + }) +} diff --git a/cmd/temporalite/mtls_test.go b/cmd/temporalite/mtls_test.go index 1e11f949..816e9ec5 100644 --- a/cmd/temporalite/mtls_test.go +++ b/cmd/temporalite/mtls_test.go @@ -29,6 +29,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "fmt" "io" "net/http" "os" @@ -40,9 +41,12 @@ import ( "text/template" "time" + "github.com/urfave/cli/v2" "go.temporal.io/api/enums/v1" "go.temporal.io/api/workflowservice/v1" "go.temporal.io/sdk/client" + + "github.com/temporalio/temporalite/internal/liteconfig" ) func TestMTLSConfig(t *testing.T) { @@ -52,11 +56,7 @@ func TestMTLSConfig(t *testing.T) { mtlsDir := filepath.Join(thisFile, "../../../internal/examples/mtls") // Create temp config dir - confDir, err := os.MkdirTemp("", "temporalite-conf-") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(confDir) + confDir := t.TempDir() // Run templated config and put in temp dir var buf bytes.Buffer @@ -82,6 +82,13 @@ func TestMTLSConfig(t *testing.T) { t.Fatal(err) } + portProvider := liteconfig.NewPortProvider() + var ( + frontendPort = portProvider.MustGetFreePort() + webUIPort = portProvider.MustGetFreePort() + ) + portProvider.Close() + // Run ephemerally using temp config args := []string{ "temporalite", @@ -89,12 +96,17 @@ func TestMTLSConfig(t *testing.T) { "--ephemeral", "--config", confDir, "--namespace", "default", - "--log-format", "pretty", - "--port", "10233", + "--log-format", "noop", + "--port", strconv.Itoa(frontendPort), + "--ui-port", strconv.Itoa(webUIPort), } go func() { - if err := buildCLI().RunContext(ctx, args); err != nil { - t.Logf("CLI failed: %v", err) + temporaliteCLI := buildCLI() + // Don't call os.Exit + temporaliteCLI.ExitErrHandler = func(_ *cli.Context, _ error) {} + + if err := temporaliteCLI.RunContext(ctx, args); err != nil { + fmt.Printf("CLI failed: %s\n", err) } }() @@ -116,7 +128,7 @@ func TestMTLSConfig(t *testing.T) { // Build client options and try to connect client every 100ms for 5s options := client.Options{ - HostPort: "localhost:10233", + HostPort: fmt.Sprintf("localhost:%d", frontendPort), ConnectionOptions: client.ConnectionOptions{ TLS: &tls.Config{ Certificates: []tls.Certificate{clientCert}, @@ -151,7 +163,7 @@ func TestMTLSConfig(t *testing.T) { } // Pretend to be a browser to invoke the UI API - res, err := http.Get("http://localhost:11233/api/v1/namespaces?") + res, err := http.Get(fmt.Sprintf("http://localhost:%d/api/v1/namespaces?", webUIPort)) if err != nil { t.Fatal(err) } diff --git a/internal/liteconfig/config.go b/internal/liteconfig/config.go index 20534ba0..9b466a17 100644 --- a/internal/liteconfig/config.go +++ b/internal/liteconfig/config.go @@ -56,7 +56,7 @@ type Config struct { SQLitePragmas map[string]string Logger log.Logger UpstreamOptions []temporal.ServerOption - portProvider *portProvider + portProvider *PortProvider FrontendIP string UIServer UIServer BaseConfig *config.Config @@ -85,7 +85,7 @@ func NewDefaultConfig() (*Config, error) { return &Config{ Ephemeral: false, - DatabaseFilePath: filepath.Join(userConfigDir, "temporalite/db/default.db"), + DatabaseFilePath: filepath.Join(userConfigDir, "temporalite", "db", "default.db"), FrontendPort: 0, MetricsPort: 0, UIServer: noopUIServer{}, @@ -97,7 +97,7 @@ func NewDefaultConfig() (*Config, error) { Level: "info", OutputFile: "", })), - portProvider: &portProvider{}, + portProvider: NewPortProvider(), FrontendIP: "", BaseConfig: &config.Config{}, }, nil @@ -105,7 +105,7 @@ func NewDefaultConfig() (*Config, error) { func Convert(cfg *Config) *config.Config { defer func() { - if err := cfg.portProvider.close(); err != nil { + if err := cfg.portProvider.Close(); err != nil { panic(err) } }() @@ -130,12 +130,12 @@ func Convert(cfg *Config) *config.Config { var pprofPort int if cfg.DynamicPorts { if cfg.FrontendPort == 0 { - cfg.FrontendPort = cfg.portProvider.mustGetFreePort() + cfg.FrontendPort = cfg.portProvider.MustGetFreePort() } if cfg.MetricsPort == 0 { - cfg.MetricsPort = cfg.portProvider.mustGetFreePort() + cfg.MetricsPort = cfg.portProvider.MustGetFreePort() } - pprofPort = cfg.portProvider.mustGetFreePort() + pprofPort = cfg.portProvider.MustGetFreePort() } else { if cfg.FrontendPort == 0 { cfg.FrontendPort = DefaultFrontendPort @@ -229,9 +229,9 @@ func (cfg *Config) mustGetService(frontendPortOffset int) config.Service { // Assign any open port when configured to use dynamic ports if cfg.DynamicPorts { if frontendPortOffset != 0 { - svc.RPC.GRPCPort = cfg.portProvider.mustGetFreePort() + svc.RPC.GRPCPort = cfg.portProvider.MustGetFreePort() } - svc.RPC.MembershipPort = cfg.portProvider.mustGetFreePort() + svc.RPC.MembershipPort = cfg.portProvider.MustGetFreePort() } // Optionally bind frontend to IPv4 address diff --git a/internal/liteconfig/freeport.go b/internal/liteconfig/freeport.go index 88db990f..f3f077de 100644 --- a/internal/liteconfig/freeport.go +++ b/internal/liteconfig/freeport.go @@ -11,12 +11,16 @@ import ( // Modified from https://github.com/phayes/freeport/blob/95f893ade6f232a5f1511d61735d89b1ae2df543/freeport.go -type portProvider struct { +func NewPortProvider() *PortProvider { + return &PortProvider{} +} + +type PortProvider struct { listeners []*net.TCPListener } -// getFreePort asks the kernel for a free open port that is ready to use. -func (p *portProvider) getFreePort() (int, error) { +// GetFreePort asks the kernel for a free open port that is ready to use. +func (p *PortProvider) GetFreePort() (int, error) { addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0") if err != nil { if addr, err = net.ResolveTCPAddr("tcp6", "[::1]:0"); err != nil { @@ -34,15 +38,15 @@ func (p *portProvider) getFreePort() (int, error) { return l.Addr().(*net.TCPAddr).Port, nil } -func (p *portProvider) mustGetFreePort() int { - port, err := p.getFreePort() +func (p *PortProvider) MustGetFreePort() int { + port, err := p.GetFreePort() if err != nil { panic(err) } return port } -func (p *portProvider) close() error { +func (p *PortProvider) Close() error { for _, l := range p.listeners { if err := l.Close(); err != nil { return err diff --git a/temporaltest/server.go b/temporaltest/server.go index 9cc2836f..c84186af 100644 --- a/temporaltest/server.go +++ b/temporaltest/server.go @@ -92,7 +92,7 @@ func (ts *TestServer) NewClientWithOptions(opts client.Options) client.Client { opts.Logger = &testLogger{ts.t} } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() c, err := ts.server.NewClientWithOptions(ctx, opts) diff --git a/temporaltest/server_test.go b/temporaltest/server_test.go index c0daf3c9..1ea07f4c 100644 --- a/temporaltest/server_test.go +++ b/temporaltest/server_test.go @@ -61,7 +61,7 @@ func TestNewServer(t *testing.T) { helloworld.RegisterWorkflowsAndActivities(registry) }) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() wfr, err := ts.DefaultClient().ExecuteWorkflow( @@ -98,7 +98,7 @@ func TestNewWorkerWithOptions(t *testing.T) { }, ) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() wfr, err := ts.DefaultClient().ExecuteWorkflow( @@ -136,7 +136,7 @@ func TestDefaultWorkerOptions(t *testing.T) { ts.NewWorker("hello_world", func(registry worker.Registry) { helloworld.RegisterWorkflowsAndActivities(registry) }) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() wfr, err := ts.DefaultClient().ExecuteWorkflow( @@ -174,7 +174,7 @@ func TestClientWithDefaultInterceptor(t *testing.T) { }, ) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() wfr, err := ts.DefaultClient().ExecuteWorkflow( @@ -198,7 +198,7 @@ func TestClientWithDefaultInterceptor(t *testing.T) { } func TestSearchAttributeCacheDisabled(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() ts := temporaltest.NewServer(temporaltest.WithT(t)) @@ -233,7 +233,7 @@ func BenchmarkRunWorkflow(b *testing.B) { for i := 0; i < b.N; i++ { func(b *testing.B) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() wfr, err := c.ExecuteWorkflow(