From ff572f1c9b2aa58318ebe45a52b64c2862c64756 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 13 Jan 2025 11:42:20 +0100 Subject: [PATCH] start migrating away from "AnyKey" funcs Signed-off-by: Kristoffer Dalby --- hscontrol/auth.go | 8 ++++++-- hscontrol/db/node.go | 24 ++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/hscontrol/auth.go b/hscontrol/auth.go index b0d632b9e1..a0bf3dfb4b 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -124,7 +124,9 @@ func (h *Headscale) handleRegister( logInfo, logTrace, _ := logAuthFunc(regReq, machineKey, registrationId) now := time.Now().UTC() logTrace("handleRegister called, looking up machine in DB") - node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey) + // TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs + // key refreshes. This will allow us to remove the machineKey from the registration request. + node, err := h.db.GetNodeByNodeKey(regReq.NodeKey) logTrace("handleRegister database lookup has returned") if errors.Is(err, gorm.ErrRecordNotFound) { // If the node has AuthKey set, handle registration via PreAuthKeys @@ -329,7 +331,9 @@ func (h *Headscale) handleAuthKey( // The error is not important, because if it does not // exist, then this is a new node and we will move // on to registration. - node, _ := h.db.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey) + // TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs + // key refreshes. This will allow us to remove the machineKey from the registration request. + node, _ := h.db.GetNodeByNodeKey(registerRequest.NodeKey) if node != nil { log.Trace(). Caller(). diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index b718b9c62c..d7b0864f23 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -158,6 +158,30 @@ func GetNodeByMachineKey( return &mach, nil } +func (hsdb *HSDatabase) GetNodeByNodeKey(nodeKey key.NodePublic) (*types.Node, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { + return GetNodeByNodeKey(rx, nodeKey) + }) +} + +// GetNodeByNodeKey finds a Node by its NodeKey and returns the Node struct. +func GetNodeByNodeKey( + tx *gorm.DB, + nodeKey key.NodePublic, +) (*types.Node, error) { + mach := types.Node{} + if result := tx. + Preload("AuthKey"). + Preload("AuthKey.User"). + Preload("User"). + Preload("Routes"). + First(&mach, "node_key = ?", nodeKey.String()); result.Error != nil { + return nil, result.Error + } + + return &mach, nil +} + func (hsdb *HSDatabase) GetNodeByAnyKey( machineKey key.MachinePublic, nodeKey key.NodePublic,