diff --git a/cmd/init_config.go b/cmd/init_config.go index 17e180adc..4a433c5f0 100644 --- a/cmd/init_config.go +++ b/cmd/init_config.go @@ -16,7 +16,7 @@ var initConfigCmd = &cobra.Command{ Short: "creates a config file at dicedb.yaml with default values", Run: func(cmd *cobra.Command, args []string) { config.Init(cmd.Flags()) - viper.WriteConfigAs("dicedb.yaml") + _ = viper.WriteConfigAs("dicedb.yaml") fmt.Println("config created at dicedb.yaml") }, } diff --git a/config/config.go b/config/config.go index b808808cf..5f517f2d0 100644 --- a/config/config.go +++ b/config/config.go @@ -29,82 +29,6 @@ const ( DefaultKeysLimit int = 200000000 DefaultEvictionRatio float64 = 0.1 - - defaultConfigTemplate = `# Configuration file for Dicedb - -# Version -version = "0.1.0" - -# Async Server Configuration -async_server.addr = "0.0.0.0" -async_server.port = 7379 -async_server.keepalive = 300 -async_server.timeout = 300 -async_server.max_conn = 0 - -# HTTP Configuration -http.enabled = false -http.port = 8082 - -# WebSocket Configuration -websocket.enabled = false -websocket.port = 8379 -websocket.max_write_response_retries = 3 -websocket.write_response_timeout = 10s - -# Performance Configuration -performance.watch_chan_buf_size = 20000 -performance.shard_cron_frequency = 1s -performance.multiplexer_poll_timeout = 100ms -performance.max_clients = 20000 -performance.store_map_init_size = 1024000 -performance.adhoc_req_chan_buf_size = 20 -performance.enable_profiling = false -performance.enable_watch = false -performance.num_shards = -1 - -# Memory Configuration -memory.max_memory = 0 -memory.eviction_policy = "allkeys-lfu" -memory.eviction_ratio = 0.9 -memory.keys_limit = 200000000 -memory.lfu_log_factor = 10 - -# Persistence Configuration -persistence.enabled = false -persistence.aof_file = "./dice-master.aof" -persistence.persistence_enabled = true -persistence.write_aof_on_cleanup = false -persistence.wal-dir = "./" -persistence.restore-wal = false -persistence.wal-engine = "aof" - -# Logging Configuration -logging.log_level = "info" -logging.log_dir = "/tmp/dicedb" - -# Authentication Configuration -auth.username = "dice" -auth.password = "" - -# Network Configuration -network.io_buffer_length = 512 -network.io_buffer_length_max = 51200 - -# WAL Configuration -LogDir = "tmp/dicedb-wal" -Enabled = "true" -WalMode = "buffered" -WriteMode = "default" -BufferSizeMB = 1 -RotationMode = "segemnt-size" -MaxSegmentSizeMB = 16 -MaxSegmentRotationTime = 60s -BufferSyncInterval = 200ms -RetentionMode = "num-segments" -MaxSegmentCount = 10 -MaxSegmentRetentionDuration = 600s -RecoveryMode = "strict"` ) var ( @@ -115,7 +39,6 @@ var ( type Config struct { Version string `config:"version" default:"0.1.0"` InstanceID string `config:"instance_id"` - Auth auth `config:"auth"` RespServer respServer `config:"async_server"` HTTP http `config:"http"` WebSocket websocket `config:"websocket"` @@ -127,11 +50,6 @@ type Config struct { WAL WALConfig `config:"WAL"` } -type auth struct { - UserName string `config:"username" default:"dice"` - Password string `config:"password"` -} - type respServer struct { Addr string `config:"addr" default:"0.0.0.0" validate:"ipv4"` Port int `config:"port" default:"7379" validate:"number,gte=0,lte=65535"` @@ -261,10 +179,6 @@ func writeConfigFile(configFilePath string) error { } defer file.Close() - if _, err := file.WriteString(defaultConfigTemplate); err != nil { - return err - } - return nil } @@ -311,8 +225,6 @@ func MergeFlags(flags *Config) { DiceConfig.Persistence.RestoreFromWAL = flags.Persistence.RestoreFromWAL case "wal-engine": DiceConfig.Persistence.WALEngine = flags.Persistence.WALEngine - case "require-pass": - DiceConfig.Auth.Password = flags.Auth.Password case "keys-limit": DiceConfig.Memory.KeysLimit = flags.Memory.KeysLimit case "eviction-ratio": @@ -322,9 +234,12 @@ func MergeFlags(flags *Config) { } type DiceDBConfig struct { - Host string `mapstructure:"host" description:"the host address to bind to" default:"0.0.0.0"` - Port int `mapstructure:"port" description:"the port to bind to" default:"7379"` - EnableHTTP bool `mapstructure:"enable-http" description:"enable http server" default:"false"` + Host string `mapstructure:"host" default:"0.0.0.0" description:"the host address to bind to"` + Port int `mapstructure:"port" default:"7379" description:"the port to bind to"` + EnableHTTP bool `mapstructure:"enable-http" default:"false" description:"enable http server"` + + Username string `mapstructure:"username" default:"dicedb" description:"the username to use for authentication"` + Password string `mapstructure:"password" default:"" description:"the password to use for authentication"` } var GlobalDiceDBConfig *DiceDBConfig diff --git a/internal/auth/session.go b/internal/auth/session.go index a49051381..491a8f807 100644 --- a/internal/auth/session.go +++ b/internal/auth/session.go @@ -108,7 +108,7 @@ func NewSession() (session *Session) { } func (session *Session) IsActive() (isActive bool) { - if config.DiceConfig.Auth.Password == utils.EmptyStr && session.Status != SessionStatusActive { + if config.GlobalDiceDBConfig.Password == utils.EmptyStr && session.Status != SessionStatusActive { session.Activate(session.User) } isActive = session.Status == SessionStatusActive @@ -133,7 +133,7 @@ func (session *Session) Validate(username, password string) error { if user, err = UserStore.Get(username); err != nil { return err } - if username == config.DiceConfig.Auth.UserName && len(user.Passwords) == 0 { + if username == config.GlobalDiceDBConfig.Username && len(user.Passwords) == 0 { session.Activate(user) return nil } diff --git a/internal/cli/cli.go b/internal/cli/cli.go index a48fb527a..1e5e627c6 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -11,10 +11,8 @@ import ( "os" "path/filepath" "runtime" - "strings" "github.com/dicedb/dice/config" - "github.com/dicedb/dice/internal/server/utils" "github.com/fatih/color" ) @@ -72,6 +70,11 @@ func render() { func Execute() { flagsConfig := config.Config{} + var tempStr string + + flag.StringVar(&tempStr, "username", "dicedb", "deleted") + flag.StringVar(&tempStr, "password", "dicedb", "deleted") + flag.StringVar(&flagsConfig.RespServer.Addr, "host", "0.0.0.0", "host for the DiceDB server") flag.IntVar(&flagsConfig.RespServer.Port, "port", 7379, "port for the DiceDB server") @@ -94,7 +97,6 @@ func Execute() { flag.BoolVar(&flagsConfig.Persistence.RestoreFromWAL, "restore-wal", false, "restore the database from the WAL files") flag.StringVar(&flagsConfig.Persistence.WALEngine, "wal-engine", "null", "wal engine to use, values: sqlite, aof") - flag.StringVar(&flagsConfig.Auth.Password, "requirepass", utils.EmptyStr, "enable authentication for the default user") flag.StringVar(&config.CustomConfigFilePath, "o", config.CustomConfigFilePath, "dir path to create the flagsConfig file") flag.StringVar(&config.CustomConfigDirPath, "c", config.CustomConfigDirPath, "file path of the config file") @@ -140,94 +142,6 @@ func Execute() { } flag.Parse() - - if len(os.Args) > 2 { - switch os.Args[1] { - case "-v", "--version": - fmt.Println("dicedb version", config.DiceDBVersion) - os.Exit(0) - - case "-": - parser := config.NewConfigParser() - if err := parser.ParseFromStdin(); err != nil { - log.Fatal(err) - } - if err := parser.Loadconfig(config.DiceConfig); err != nil { - log.Fatal(err) - } - fmt.Println(config.DiceConfig.Version) - case "-o", "--output": - if len(os.Args) < 3 { - log.Fatal("Output file path not provided") - } else { - dirPath := os.Args[2] - if dirPath == "" { - log.Fatal("Output file path not provided") - } - - info, err := os.Stat(dirPath) - switch { - case os.IsNotExist(err): - log.Fatal("Output file path does not exist") - case err != nil: - log.Fatalf("Error checking output file path: %v", err) - case !info.IsDir(): - log.Fatal("Output file path is not a directory") - } - - filePath := filepath.Join(dirPath, config.DefaultConfigName) - if _, err := os.Stat(filePath); err == nil { - slog.Warn("Config file already exists at the specified path", slog.String("path", filePath), slog.String("action", "skipping file creation")) - return - } - if err := config.CreateConfigFile(filePath); err != nil { - log.Fatal(err) - } - - config.MergeFlags(&flagsConfig) - render() - } - case "-c", "--config": - if len(os.Args) >= 3 { - filePath := os.Args[2] - if filePath == "" { - log.Fatal("Error: Config file path not provided") - } - - info, err := os.Stat(filePath) - switch { - case os.IsNotExist(err): - log.Fatalf("Config file does not exist: %s", filePath) - case err != nil: - log.Fatalf("Unable to check config file: %v", err) - } - - if info.IsDir() { - log.Fatalf("Config file path points to a directory: %s", filePath) - } - - if !strings.HasSuffix(filePath, ".conf") { - log.Fatalf("Config file must have a .conf extension: %s", filePath) - } - - parser := config.NewConfigParser() - if err := parser.ParseFromFile(filePath); err != nil { - log.Fatal(err) - } - if err := parser.Loadconfig(config.DiceConfig); err != nil { - log.Fatal(err) - } - - config.MergeFlags(&flagsConfig) - render() - } else { - log.Fatal("Config file path not provided") - } - default: - defaultConfig(&flagsConfig) - } - } - defaultConfig(&flagsConfig) } diff --git a/internal/commandhandler/commandhandler.go b/internal/commandhandler/commandhandler.go index 569c9b8a1..3c7821848 100644 --- a/internal/commandhandler/commandhandler.go +++ b/internal/commandhandler/commandhandler.go @@ -504,8 +504,14 @@ func (h *BaseCommandHandler) sendResponseToIOThread(resp interface{}, err error) h.ioThreadWriteChan <- resp } -func (h *BaseCommandHandler) isAuthenticated(diceDBCmd *cmd.DiceDBCmd) error { - if diceDBCmd.Cmd != auth.Cmd && !h.Session.IsActive() { +func (h *BaseCommandHandler) isAuthenticated(c *cmd.DiceDBCmd) error { + // TODO: Revisit the flow and check the need of explicitly whitelisting PING and CLIENT commands here. + // We might not need this special case handling for other commands. + if c.Cmd == "PING" || c.Cmd == "CLIENT" { + return nil + } + + if c.Cmd != auth.Cmd && !h.Session.IsActive() { return errors.New("NOAUTH Authentication required") } @@ -530,11 +536,11 @@ func (h *BaseCommandHandler) RespAuth(args []string) interface{} { return diceerrors.ErrWrongArgumentCount("AUTH") } - if config.DiceConfig.Auth.Password == "" { + if config.GlobalDiceDBConfig.Password == "" { return diceerrors.ErrAuth } - username := config.DiceConfig.Auth.UserName + username := config.GlobalDiceDBConfig.Username var password string if len(args) == 1 { diff --git a/internal/eval/eval.go b/internal/eval/eval.go index 5a77cee0f..e2e9898bc 100644 --- a/internal/eval/eval.go +++ b/internal/eval/eval.go @@ -117,11 +117,11 @@ func evalECHO(args []string, store *dstore.Store) []byte { func EvalAUTH(args []string, c *comm.Client) []byte { var err error - if config.DiceConfig.Auth.Password == "" { + if config.GlobalDiceDBConfig.Password == "" { return diceerrors.NewErrWithMessage("AUTH called without any password configured for the default user. Are you sure your configuration is correct?") } - username := config.DiceConfig.Auth.UserName + username := config.GlobalDiceDBConfig.Username var password string if len(args) == 1 { diff --git a/server/main.go b/server/main.go index 971060f04..1bcb336a6 100644 --- a/server/main.go +++ b/server/main.go @@ -18,6 +18,7 @@ import ( "syscall" "time" + "github.com/dicedb/dice/internal/auth" "github.com/dicedb/dice/internal/server/httpws" "github.com/dicedb/dice/internal/cli" @@ -44,6 +45,13 @@ func Start() { iid := observability.GetOrCreateInstanceID() config.DiceConfig.InstanceID = iid + // TODO: Handle the addition of the default user + // and new users in a much better way. Doing this using + // and empty password check is not a good solution. + if config.GlobalDiceDBConfig.Password != "" { + _, _ = auth.UserStore.Add(config.GlobalDiceDBConfig.Username) + } + // This is counter intuitive, but it's the first thing that should be done // because this function parses the flags and prepares the config, cli.Execute() @@ -71,12 +79,14 @@ func Start() { if err != nil { slog.Warn("could not create WAL with", slog.String("wal-engine", config.DiceConfig.Persistence.WALEngine), slog.Any("error", err)) sigs <- syscall.SIGKILL + cancel() return } wl = _wl } else { slog.Error("unsupported WAL engine", slog.String("engine", config.DiceConfig.Persistence.WALEngine)) sigs <- syscall.SIGKILL + cancel() return }