diff --git a/src/k8s/cmd/k8s/k8s_get_join_token.go b/src/k8s/cmd/k8s/k8s_get_join_token.go index f4c32655f..2f32253fa 100644 --- a/src/k8s/cmd/k8s/k8s_get_join_token.go +++ b/src/k8s/cmd/k8s/k8s_get_join_token.go @@ -14,9 +14,12 @@ func newGetJoinTokenCmd(env cmdutil.ExecutionEnvironment) *cobra.Command { Use: "get-join-token ", Short: "Create a token for a node to join the cluster", PreRun: chainPreRunHooks(hookRequireRoot(env)), - Args: cmdutil.ExactArgs(env, 1), + Args: cmdutil.MaximumNArgs(env, 1), Run: func(cmd *cobra.Command, args []string) { - name := args[0] + var name string + if len(args) == 1 { + name = args[0] + } client, err := env.Client(cmd.Context()) if err != nil { diff --git a/src/k8s/pkg/k8sd/api/worker.go b/src/k8s/pkg/k8sd/api/worker.go index aa900d7af..ad94f7f41 100644 --- a/src/k8s/pkg/k8sd/api/worker.go +++ b/src/k8s/pkg/k8sd/api/worker.go @@ -87,8 +87,9 @@ func (e *Endpoints) postWorkerInfo(s *state.State, r *http.Request) response.Res return response.InternalError(fmt.Errorf("add worker node transaction failed: %w", err)) } + workerToken := r.Header.Get("worker-token") if err := s.Database.Transaction(s.Context, func(ctx context.Context, tx *sql.Tx) error { - return database.DeleteWorkerNodeToken(ctx, tx, workerName) + return database.DeleteWorkerNodeToken(ctx, tx, workerToken) }); err != nil { return response.InternalError(fmt.Errorf("delete worker node token transaction failed: %w", err)) } diff --git a/src/k8s/pkg/k8sd/database/schema.go b/src/k8s/pkg/k8sd/database/schema.go index 78490de3d..fa7d9d467 100644 --- a/src/k8s/pkg/k8sd/database/schema.go +++ b/src/k8s/pkg/k8sd/database/schema.go @@ -16,6 +16,7 @@ var ( schemaApplyMigration("kubernetes-auth-tokens", "000-create.sql"), schemaApplyMigration("cluster-configs", "000-create.sql"), schemaApplyMigration("worker-nodes", "000-create.sql"), + schemaApplyMigration("worker-tokens", "000-create.sql"), } //go:embed sql/migrations diff --git a/src/k8s/pkg/k8sd/database/sql/migrations/worker-tokens/000-create.sql b/src/k8s/pkg/k8sd/database/sql/migrations/worker-tokens/000-create.sql new file mode 100644 index 000000000..007c316e3 --- /dev/null +++ b/src/k8s/pkg/k8sd/database/sql/migrations/worker-tokens/000-create.sql @@ -0,0 +1,5 @@ +CREATE TABLE worker_tokens ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + name TEXT NOT NULL, + token TEXT NOT NULL +) diff --git a/src/k8s/pkg/k8sd/database/sql/queries/cluster-configs/delete-worker-token.sql b/src/k8s/pkg/k8sd/database/sql/queries/cluster-configs/delete-worker-token.sql deleted file mode 100644 index e3a609da1..000000000 --- a/src/k8s/pkg/k8sd/database/sql/queries/cluster-configs/delete-worker-token.sql +++ /dev/null @@ -1,4 +0,0 @@ -DELETE FROM - cluster_configs AS c -WHERE - c.key = "worker-token::" || ? diff --git a/src/k8s/pkg/k8sd/database/sql/queries/cluster-configs/insert-worker-token.sql b/src/k8s/pkg/k8sd/database/sql/queries/cluster-configs/insert-worker-token.sql deleted file mode 100644 index 20dcedd4e..000000000 --- a/src/k8s/pkg/k8sd/database/sql/queries/cluster-configs/insert-worker-token.sql +++ /dev/null @@ -1,6 +0,0 @@ -INSERT INTO - cluster_configs(key, value) -VALUES - ( "worker-token::" || ?, ? ) -ON CONFLICT(key) DO - UPDATE SET value = EXCLUDED.value; diff --git a/src/k8s/pkg/k8sd/database/sql/queries/cluster-configs/select-worker-token.sql b/src/k8s/pkg/k8sd/database/sql/queries/cluster-configs/select-worker-token.sql deleted file mode 100644 index 53f017ecf..000000000 --- a/src/k8s/pkg/k8sd/database/sql/queries/cluster-configs/select-worker-token.sql +++ /dev/null @@ -1,6 +0,0 @@ -SELECT - c.value -FROM - cluster_configs AS c -WHERE - ( c.key = "worker-token::" || ? ) diff --git a/src/k8s/pkg/k8sd/database/sql/queries/worker-tokens/delete-by-token.sql b/src/k8s/pkg/k8sd/database/sql/queries/worker-tokens/delete-by-token.sql new file mode 100644 index 000000000..0332d3afe --- /dev/null +++ b/src/k8s/pkg/k8sd/database/sql/queries/worker-tokens/delete-by-token.sql @@ -0,0 +1,4 @@ +DELETE FROM + worker_tokens AS t +WHERE + t.token = ? diff --git a/src/k8s/pkg/k8sd/database/sql/queries/worker-tokens/insert.sql b/src/k8s/pkg/k8sd/database/sql/queries/worker-tokens/insert.sql new file mode 100644 index 000000000..9b37d5c0c --- /dev/null +++ b/src/k8s/pkg/k8sd/database/sql/queries/worker-tokens/insert.sql @@ -0,0 +1,4 @@ +INSERT INTO + worker_tokens(name, token) +VALUES + ( ?, ? ) diff --git a/src/k8s/pkg/k8sd/database/sql/queries/worker-tokens/select.sql b/src/k8s/pkg/k8sd/database/sql/queries/worker-tokens/select.sql new file mode 100644 index 000000000..3f1ea8d78 --- /dev/null +++ b/src/k8s/pkg/k8sd/database/sql/queries/worker-tokens/select.sql @@ -0,0 +1,7 @@ +SELECT + t.name +FROM + worker_tokens AS t +WHERE + ( t.token = ? ) +LIMIT 1 diff --git a/src/k8s/pkg/k8sd/database/worker.go b/src/k8s/pkg/k8sd/database/worker.go index 030222096..ec22f84f0 100644 --- a/src/k8s/pkg/k8sd/database/worker.go +++ b/src/k8s/pkg/k8sd/database/worker.go @@ -18,9 +18,9 @@ var ( "select-by-name": MustPrepareStatement("worker-nodes", "select-by-name.sql"), "delete-node": MustPrepareStatement("worker-nodes", "delete.sql"), - "insert-token": MustPrepareStatement("cluster-configs", "insert-worker-token.sql"), - "select-token": MustPrepareStatement("cluster-configs", "select-worker-token.sql"), - "delete-token": MustPrepareStatement("cluster-configs", "delete-worker-token.sql"), + "insert-token": MustPrepareStatement("worker-tokens", "insert.sql"), + "select-token": MustPrepareStatement("worker-tokens", "select.sql"), + "delete-token": MustPrepareStatement("worker-tokens", "delete-by-token.sql"), } ) @@ -30,9 +30,9 @@ func CheckWorkerNodeToken(ctx context.Context, tx *sql.Tx, nodeName string, toke if err != nil { return false, fmt.Errorf("failed to prepare select statement: %w", err) } - var realToken string - if selectTxStmt.QueryRowContext(ctx, nodeName).Scan(&realToken) == nil { - return subtle.ConstantTimeCompare([]byte(token), []byte(realToken)) == 1, nil + var tokenNodeName string + if selectTxStmt.QueryRowContext(ctx, token).Scan(&tokenNodeName) == nil { + return tokenNodeName == "" || subtle.ConstantTimeCompare([]byte(nodeName), []byte(tokenNodeName)) == 1, nil } return false, nil } @@ -40,13 +40,9 @@ func CheckWorkerNodeToken(ctx context.Context, tx *sql.Tx, nodeName string, toke // GetOrCreateWorkerNodeToken returns a token that can be used to join a worker node on the cluster. // GetOrCreateWorkerNodeToken will return the existing token, if one already exists for the node. func GetOrCreateWorkerNodeToken(ctx context.Context, tx *sql.Tx, nodeName string) (string, error) { - selectTxStmt, err := cluster.Stmt(tx, workerStmts["select-token"]) + insertTxStmt, err := cluster.Stmt(tx, workerStmts["insert-token"]) if err != nil { - return "", fmt.Errorf("failed to prepare select statement: %w", err) - } - var token string - if selectTxStmt.QueryRowContext(ctx, fmt.Sprintf("worker-token::%s", nodeName)).Scan(&token) == nil { - return token, nil + return "", fmt.Errorf("failed to prepare insert statement: %w", err) } // generate random bytes for the token @@ -54,12 +50,7 @@ func GetOrCreateWorkerNodeToken(ctx context.Context, tx *sql.Tx, nodeName string if _, err := rand.Read(b); err != nil { return "", fmt.Errorf("is the system entropy low? failed to get random bytes: %w", err) } - token = fmt.Sprintf("worker::%s", hex.EncodeToString(b)) - - insertTxStmt, err := cluster.Stmt(tx, workerStmts["insert-token"]) - if err != nil { - return "", fmt.Errorf("failed to prepare insert statement: %w", err) - } + token := fmt.Sprintf("worker::%s", hex.EncodeToString(b)) if _, err := insertTxStmt.ExecContext(ctx, nodeName, token); err != nil { return "", fmt.Errorf("insert token query failed: %w", err) } diff --git a/src/k8s/pkg/k8sd/database/worker_test.go b/src/k8s/pkg/k8sd/database/worker_test.go index 413c2a787..2bc14d684 100644 --- a/src/k8s/pkg/k8sd/database/worker_test.go +++ b/src/k8s/pkg/k8sd/database/worker_test.go @@ -11,47 +11,65 @@ import ( func TestWorkerNodeToken(t *testing.T) { WithDB(t, func(ctx context.Context, db DB) { - g := NewWithT(t) - err := db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { - exists, err := database.CheckWorkerNodeToken(ctx, tx, "somenode", "sometoken") - g.Expect(err).To(BeNil()) - g.Expect(exists).To(BeFalse()) - - token, err := database.GetOrCreateWorkerNodeToken(ctx, tx, "somenode") - g.Expect(err).To(BeNil()) - g.Expect(token).To(HaveLen(48)) - - othertoken, err := database.GetOrCreateWorkerNodeToken(ctx, tx, "someothernode") - g.Expect(err).To(BeNil()) - g.Expect(othertoken).To(HaveLen(48)) - g.Expect(othertoken).NotTo(Equal(token)) - - valid, err := database.CheckWorkerNodeToken(ctx, tx, "somenode", token) - g.Expect(err).To(BeNil()) - g.Expect(valid).To(BeTrue()) - - valid, err = database.CheckWorkerNodeToken(ctx, tx, "someothernode", token) - g.Expect(err).To(BeNil()) - g.Expect(valid).To(BeFalse()) - - valid, err = database.CheckWorkerNodeToken(ctx, tx, "someothernode", othertoken) - g.Expect(err).To(BeNil()) - g.Expect(valid).To(BeTrue()) - - err = database.DeleteWorkerNodeToken(ctx, tx, "somenode") - g.Expect(err).To(BeNil()) - - valid, err = database.CheckWorkerNodeToken(ctx, tx, "somenode", token) - g.Expect(err).To(BeNil()) - g.Expect(valid).To(BeFalse()) - - newToken, err := database.GetOrCreateWorkerNodeToken(ctx, tx, "somenode") - g.Expect(err).To(BeNil()) - g.Expect(newToken).To(HaveLen(48)) - g.Expect(newToken).ToNot(Equal(token)) + _ = db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + t.Run("Default", func(t *testing.T) { + g := NewWithT(t) + exists, err := database.CheckWorkerNodeToken(ctx, tx, "somenode", "sometoken") + g.Expect(err).To(BeNil()) + g.Expect(exists).To(BeFalse()) + + token, err := database.GetOrCreateWorkerNodeToken(ctx, tx, "somenode") + g.Expect(err).To(BeNil()) + g.Expect(token).To(HaveLen(48)) + + othertoken, err := database.GetOrCreateWorkerNodeToken(ctx, tx, "someothernode") + g.Expect(err).To(BeNil()) + g.Expect(othertoken).To(HaveLen(48)) + g.Expect(othertoken).NotTo(Equal(token)) + + valid, err := database.CheckWorkerNodeToken(ctx, tx, "somenode", token) + g.Expect(err).To(BeNil()) + g.Expect(valid).To(BeTrue()) + + valid, err = database.CheckWorkerNodeToken(ctx, tx, "someothernode", token) + g.Expect(err).To(BeNil()) + g.Expect(valid).To(BeFalse()) + + valid, err = database.CheckWorkerNodeToken(ctx, tx, "someothernode", othertoken) + g.Expect(err).To(BeNil()) + g.Expect(valid).To(BeTrue()) + + err = database.DeleteWorkerNodeToken(ctx, tx, token) + g.Expect(err).To(BeNil()) + + valid, err = database.CheckWorkerNodeToken(ctx, tx, "somenode", token) + g.Expect(err).To(BeNil()) + g.Expect(valid).To(BeFalse()) + + newToken, err := database.GetOrCreateWorkerNodeToken(ctx, tx, "somenode") + g.Expect(err).To(BeNil()) + g.Expect(newToken).To(HaveLen(48)) + g.Expect(newToken).ToNot(Equal(token)) + }) + + t.Run("AnyNodeName", func(t *testing.T) { + g := NewWithT(t) + token, err := database.GetOrCreateWorkerNodeToken(ctx, tx, "") + g.Expect(err).To(BeNil()) + g.Expect(token).To(HaveLen(48)) + + for _, name := range []string{"", "test", "other"} { + t.Run(name, func(t *testing.T) { + g := NewWithT(t) + + valid, err := database.CheckWorkerNodeToken(ctx, tx, name, token) + g.Expect(err).To(BeNil()) + g.Expect(valid).To(BeTrue()) + }) + } + }) return nil }) - g.Expect(err).To(BeNil()) }) }