Skip to content

Commit

Permalink
start migrating away from "AnyKey" funcs
Browse files Browse the repository at this point in the history
Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby committed Jan 13, 2025
1 parent d85c3ae commit ff572f1
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
8 changes: 6 additions & 2 deletions hscontrol/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ func (h *Headscale) handleRegister(
logInfo, logTrace, _ := logAuthFunc(regReq, machineKey, registrationId)
now := time.Now().UTC()
logTrace("handleRegister called, looking up machine in DB")
node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey)
// TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs
// key refreshes. This will allow us to remove the machineKey from the registration request.
node, err := h.db.GetNodeByNodeKey(regReq.NodeKey)
logTrace("handleRegister database lookup has returned")
if errors.Is(err, gorm.ErrRecordNotFound) {
// If the node has AuthKey set, handle registration via PreAuthKeys
Expand Down Expand Up @@ -329,7 +331,9 @@ func (h *Headscale) handleAuthKey(
// The error is not important, because if it does not
// exist, then this is a new node and we will move
// on to registration.
node, _ := h.db.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
// TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs
// key refreshes. This will allow us to remove the machineKey from the registration request.
node, _ := h.db.GetNodeByNodeKey(registerRequest.NodeKey)
if node != nil {
log.Trace().
Caller().
Expand Down
24 changes: 24 additions & 0 deletions hscontrol/db/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,30 @@ func GetNodeByMachineKey(
return &mach, nil
}

func (hsdb *HSDatabase) GetNodeByNodeKey(nodeKey key.NodePublic) (*types.Node, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return GetNodeByNodeKey(rx, nodeKey)
})
}

// GetNodeByNodeKey finds a Node by its NodeKey and returns the Node struct.
func GetNodeByNodeKey(
tx *gorm.DB,
nodeKey key.NodePublic,
) (*types.Node, error) {
mach := types.Node{}
if result := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
First(&mach, "node_key = ?", nodeKey.String()); result.Error != nil {
return nil, result.Error
}

return &mach, nil
}

func (hsdb *HSDatabase) GetNodeByAnyKey(
machineKey key.MachinePublic,
nodeKey key.NodePublic,
Expand Down

0 comments on commit ff572f1

Please sign in to comment.