Skip to content

Commit

Permalink
Getting rid of config template and moving auth to top-level config
Browse files Browse the repository at this point in the history
  • Loading branch information
arpitbbhayani committed Jan 21, 2025
1 parent bc8e90c commit 7fed270
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 191 deletions.
2 changes: 1 addition & 1 deletion cmd/init_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ var initConfigCmd = &cobra.Command{
Short: "creates a config file at dicedb.yaml with default values",
Run: func(cmd *cobra.Command, args []string) {
config.Init(cmd.Flags())
viper.WriteConfigAs("dicedb.yaml")
_ = viper.WriteConfigAs("dicedb.yaml")
fmt.Println("config created at dicedb.yaml")
},
}
Expand Down
97 changes: 6 additions & 91 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,82 +29,6 @@ const (

DefaultKeysLimit int = 200000000
DefaultEvictionRatio float64 = 0.1

defaultConfigTemplate = `# Configuration file for Dicedb
# Version
version = "0.1.0"
# Async Server Configuration
async_server.addr = "0.0.0.0"
async_server.port = 7379
async_server.keepalive = 300
async_server.timeout = 300
async_server.max_conn = 0
# HTTP Configuration
http.enabled = false
http.port = 8082
# WebSocket Configuration
websocket.enabled = false
websocket.port = 8379
websocket.max_write_response_retries = 3
websocket.write_response_timeout = 10s
# Performance Configuration
performance.watch_chan_buf_size = 20000
performance.shard_cron_frequency = 1s
performance.multiplexer_poll_timeout = 100ms
performance.max_clients = 20000
performance.store_map_init_size = 1024000
performance.adhoc_req_chan_buf_size = 20
performance.enable_profiling = false
performance.enable_watch = false
performance.num_shards = -1
# Memory Configuration
memory.max_memory = 0
memory.eviction_policy = "allkeys-lfu"
memory.eviction_ratio = 0.9
memory.keys_limit = 200000000
memory.lfu_log_factor = 10
# Persistence Configuration
persistence.enabled = false
persistence.aof_file = "./dice-master.aof"
persistence.persistence_enabled = true
persistence.write_aof_on_cleanup = false
persistence.wal-dir = "./"
persistence.restore-wal = false
persistence.wal-engine = "aof"
# Logging Configuration
logging.log_level = "info"
logging.log_dir = "/tmp/dicedb"
# Authentication Configuration
auth.username = "dice"
auth.password = ""
# Network Configuration
network.io_buffer_length = 512
network.io_buffer_length_max = 51200
# WAL Configuration
LogDir = "tmp/dicedb-wal"
Enabled = "true"
WalMode = "buffered"
WriteMode = "default"
BufferSizeMB = 1
RotationMode = "segemnt-size"
MaxSegmentSizeMB = 16
MaxSegmentRotationTime = 60s
BufferSyncInterval = 200ms
RetentionMode = "num-segments"
MaxSegmentCount = 10
MaxSegmentRetentionDuration = 600s
RecoveryMode = "strict"`
)

var (
Expand All @@ -115,7 +39,6 @@ var (
type Config struct {
Version string `config:"version" default:"0.1.0"`
InstanceID string `config:"instance_id"`
Auth auth `config:"auth"`
RespServer respServer `config:"async_server"`
HTTP http `config:"http"`
WebSocket websocket `config:"websocket"`
Expand All @@ -127,11 +50,6 @@ type Config struct {
WAL WALConfig `config:"WAL"`
}

type auth struct {
UserName string `config:"username" default:"dice"`
Password string `config:"password"`
}

type respServer struct {
Addr string `config:"addr" default:"0.0.0.0" validate:"ipv4"`
Port int `config:"port" default:"7379" validate:"number,gte=0,lte=65535"`
Expand Down Expand Up @@ -261,10 +179,6 @@ func writeConfigFile(configFilePath string) error {
}
defer file.Close()

if _, err := file.WriteString(defaultConfigTemplate); err != nil {
return err
}

return nil
}

Expand Down Expand Up @@ -311,8 +225,6 @@ func MergeFlags(flags *Config) {
DiceConfig.Persistence.RestoreFromWAL = flags.Persistence.RestoreFromWAL
case "wal-engine":
DiceConfig.Persistence.WALEngine = flags.Persistence.WALEngine
case "require-pass":
DiceConfig.Auth.Password = flags.Auth.Password
case "keys-limit":
DiceConfig.Memory.KeysLimit = flags.Memory.KeysLimit
case "eviction-ratio":
Expand All @@ -322,9 +234,12 @@ func MergeFlags(flags *Config) {
}

type DiceDBConfig struct {
Host string `mapstructure:"host" description:"the host address to bind to" default:"0.0.0.0"`
Port int `mapstructure:"port" description:"the port to bind to" default:"7379"`
EnableHTTP bool `mapstructure:"enable-http" description:"enable http server" default:"false"`
Host string `mapstructure:"host" default:"0.0.0.0" description:"the host address to bind to"`
Port int `mapstructure:"port" default:"7379" description:"the port to bind to"`
EnableHTTP bool `mapstructure:"enable-http" default:"false" description:"enable http server"`

Username string `mapstructure:"username" default:"dicedb" description:"the username to use for authentication"`
Password string `mapstructure:"password" default:"" description:"the password to use for authentication"`
}

var GlobalDiceDBConfig *DiceDBConfig
Expand Down
4 changes: 2 additions & 2 deletions internal/auth/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func NewSession() (session *Session) {
}

func (session *Session) IsActive() (isActive bool) {
if config.DiceConfig.Auth.Password == utils.EmptyStr && session.Status != SessionStatusActive {
if config.GlobalDiceDBConfig.Password == utils.EmptyStr && session.Status != SessionStatusActive {
session.Activate(session.User)
}
isActive = session.Status == SessionStatusActive
Expand All @@ -133,7 +133,7 @@ func (session *Session) Validate(username, password string) error {
if user, err = UserStore.Get(username); err != nil {
return err
}
if username == config.DiceConfig.Auth.UserName && len(user.Passwords) == 0 {
if username == config.GlobalDiceDBConfig.Username && len(user.Passwords) == 0 {
session.Activate(user)
return nil
}
Expand Down
96 changes: 5 additions & 91 deletions internal/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@ import (
"os"
"path/filepath"
"runtime"
"strings"

"github.com/dicedb/dice/config"
"github.com/dicedb/dice/internal/server/utils"
"github.com/fatih/color"
)

Expand Down Expand Up @@ -72,6 +70,11 @@ func render() {

func Execute() {
flagsConfig := config.Config{}
var tempStr string

flag.StringVar(&tempStr, "username", "dicedb", "deleted")
flag.StringVar(&tempStr, "password", "dicedb", "deleted")

flag.StringVar(&flagsConfig.RespServer.Addr, "host", "0.0.0.0", "host for the DiceDB server")

flag.IntVar(&flagsConfig.RespServer.Port, "port", 7379, "port for the DiceDB server")
Expand All @@ -94,7 +97,6 @@ func Execute() {
flag.BoolVar(&flagsConfig.Persistence.RestoreFromWAL, "restore-wal", false, "restore the database from the WAL files")
flag.StringVar(&flagsConfig.Persistence.WALEngine, "wal-engine", "null", "wal engine to use, values: sqlite, aof")

flag.StringVar(&flagsConfig.Auth.Password, "requirepass", utils.EmptyStr, "enable authentication for the default user")
flag.StringVar(&config.CustomConfigFilePath, "o", config.CustomConfigFilePath, "dir path to create the flagsConfig file")
flag.StringVar(&config.CustomConfigDirPath, "c", config.CustomConfigDirPath, "file path of the config file")

Expand Down Expand Up @@ -140,94 +142,6 @@ func Execute() {
}

flag.Parse()

if len(os.Args) > 2 {
switch os.Args[1] {
case "-v", "--version":
fmt.Println("dicedb version", config.DiceDBVersion)
os.Exit(0)

case "-":
parser := config.NewConfigParser()
if err := parser.ParseFromStdin(); err != nil {
log.Fatal(err)
}
if err := parser.Loadconfig(config.DiceConfig); err != nil {
log.Fatal(err)
}
fmt.Println(config.DiceConfig.Version)
case "-o", "--output":
if len(os.Args) < 3 {
log.Fatal("Output file path not provided")
} else {
dirPath := os.Args[2]
if dirPath == "" {
log.Fatal("Output file path not provided")
}

info, err := os.Stat(dirPath)
switch {
case os.IsNotExist(err):
log.Fatal("Output file path does not exist")
case err != nil:
log.Fatalf("Error checking output file path: %v", err)
case !info.IsDir():
log.Fatal("Output file path is not a directory")
}

filePath := filepath.Join(dirPath, config.DefaultConfigName)
if _, err := os.Stat(filePath); err == nil {
slog.Warn("Config file already exists at the specified path", slog.String("path", filePath), slog.String("action", "skipping file creation"))
return
}
if err := config.CreateConfigFile(filePath); err != nil {
log.Fatal(err)
}

config.MergeFlags(&flagsConfig)
render()
}
case "-c", "--config":
if len(os.Args) >= 3 {
filePath := os.Args[2]
if filePath == "" {
log.Fatal("Error: Config file path not provided")
}

info, err := os.Stat(filePath)
switch {
case os.IsNotExist(err):
log.Fatalf("Config file does not exist: %s", filePath)
case err != nil:
log.Fatalf("Unable to check config file: %v", err)
}

if info.IsDir() {
log.Fatalf("Config file path points to a directory: %s", filePath)
}

if !strings.HasSuffix(filePath, ".conf") {
log.Fatalf("Config file must have a .conf extension: %s", filePath)
}

parser := config.NewConfigParser()
if err := parser.ParseFromFile(filePath); err != nil {
log.Fatal(err)
}
if err := parser.Loadconfig(config.DiceConfig); err != nil {
log.Fatal(err)
}

config.MergeFlags(&flagsConfig)
render()
} else {
log.Fatal("Config file path not provided")
}
default:
defaultConfig(&flagsConfig)
}
}

defaultConfig(&flagsConfig)
}

Expand Down
14 changes: 10 additions & 4 deletions internal/commandhandler/commandhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -504,8 +504,14 @@ func (h *BaseCommandHandler) sendResponseToIOThread(resp interface{}, err error)
h.ioThreadWriteChan <- resp
}

func (h *BaseCommandHandler) isAuthenticated(diceDBCmd *cmd.DiceDBCmd) error {
if diceDBCmd.Cmd != auth.Cmd && !h.Session.IsActive() {
func (h *BaseCommandHandler) isAuthenticated(c *cmd.DiceDBCmd) error {
// TODO: Revisit the flow and check the need of explicitly whitelisting PING and CLIENT commands here.
// We might not need this special case handling for other commands.
if c.Cmd == "PING" || c.Cmd == "CLIENT" {
return nil
}

if c.Cmd != auth.Cmd && !h.Session.IsActive() {
return errors.New("NOAUTH Authentication required")
}

Expand All @@ -530,11 +536,11 @@ func (h *BaseCommandHandler) RespAuth(args []string) interface{} {
return diceerrors.ErrWrongArgumentCount("AUTH")
}

if config.DiceConfig.Auth.Password == "" {
if config.GlobalDiceDBConfig.Password == "" {
return diceerrors.ErrAuth
}

username := config.DiceConfig.Auth.UserName
username := config.GlobalDiceDBConfig.Username
var password string

if len(args) == 1 {
Expand Down
4 changes: 2 additions & 2 deletions internal/eval/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,11 @@ func evalECHO(args []string, store *dstore.Store) []byte {
func EvalAUTH(args []string, c *comm.Client) []byte {
var err error

if config.DiceConfig.Auth.Password == "" {
if config.GlobalDiceDBConfig.Password == "" {
return diceerrors.NewErrWithMessage("AUTH <password> called without any password configured for the default user. Are you sure your configuration is correct?")
}

username := config.DiceConfig.Auth.UserName
username := config.GlobalDiceDBConfig.Username
var password string

if len(args) == 1 {
Expand Down
10 changes: 10 additions & 0 deletions server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"syscall"
"time"

"github.com/dicedb/dice/internal/auth"
"github.com/dicedb/dice/internal/server/httpws"

"github.com/dicedb/dice/internal/cli"
Expand All @@ -44,6 +45,13 @@ func Start() {
iid := observability.GetOrCreateInstanceID()
config.DiceConfig.InstanceID = iid

// TODO: Handle the addition of the default user
// and new users in a much better way. Doing this using
// and empty password check is not a good solution.
if config.GlobalDiceDBConfig.Password != "" {
_, _ = auth.UserStore.Add(config.GlobalDiceDBConfig.Username)
}

// This is counter intuitive, but it's the first thing that should be done
// because this function parses the flags and prepares the config,
cli.Execute()
Expand Down Expand Up @@ -71,12 +79,14 @@ func Start() {
if err != nil {
slog.Warn("could not create WAL with", slog.String("wal-engine", config.DiceConfig.Persistence.WALEngine), slog.Any("error", err))
sigs <- syscall.SIGKILL
cancel()
return
}
wl = _wl
} else {
slog.Error("unsupported WAL engine", slog.String("engine", config.DiceConfig.Persistence.WALEngine))
sigs <- syscall.SIGKILL
cancel()
return
}

Expand Down

0 comments on commit 7fed270

Please sign in to comment.