Skip to content

Commit

Permalink
implement followup
Browse files Browse the repository at this point in the history
Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby committed Jan 13, 2025
1 parent 3f85dfa commit 6525e49
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 48 deletions.
4 changes: 2 additions & 2 deletions hscontrol/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ type Headscale struct {
mapper *mapper.Mapper
nodeNotifier *notifier.Notifier

registrationCache *zcache.Cache[types.RegistrationID, types.Node]
registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode]

authProvider AuthProvider

Expand All @@ -123,7 +123,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
}

registrationCache := zcache.New[types.RegistrationID, types.Node](
registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
registerCacheExpiration,
registerCacheCleanup,
)
Expand Down
45 changes: 19 additions & 26 deletions hscontrol/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,6 @@ func (h *Headscale) handleRegister(
}

// Check if the node is waiting for interactive login.
//
// TODO(juan): We could use this field to improve our protocol implementation,
// and hold the request until the client closes it, or the interactive
// login is completed (i.e., the user registers the node).
// This is not implemented yet, as it is no strictly required. The only side-effect
// is that the client will hammer headscale with requests until it gets a
// successful RegisterResponse.
if regReq.Followup != "" {
logTrace("register request is a followup")
fu, err := url.Parse(regReq.Followup)
Expand All @@ -124,18 +117,15 @@ func (h *Headscale) handleRegister(

logTrace(fmt.Sprintf("followup URL contains a valid registration ID, looking up in cache: %s", followupReg))

// log.Debug().Interface("regcache", h.registrationCache.Items()).Msg("followup regcache")
log.Debug().Interface("regcache", h.registrationCache.Keys()).Msg("followup regcache")

if _, ok := h.registrationCache.Get(followupReg); ok {
if reg, ok := h.registrationCache.Get(followupReg); ok {
logTrace("Node is waiting for interactive login")

select {
case <-req.Context().Done():
logTrace("node went away before it was successfully registered")
return
case <-time.After(registrationHoldoff):
h.handleNewNode(writer, regReq, followupReg)

case <-reg.Registered:
logTrace("node has successfully registered")
return
}
}
Expand All @@ -147,17 +137,20 @@ func (h *Headscale) handleRegister(
// that we rely on a method that calls back some how (OpenID or CLI)
// We create the node and then keep it around until a callback
// happens
newNode := types.Node{
MachineKey: machineKey,
Hostname: regReq.Hostinfo.Hostname,
NodeKey: regReq.NodeKey,
LastSeen: &now,
Expiry: &time.Time{},
newNode := types.RegisterNode{
Node: types.Node{
MachineKey: machineKey,
Hostname: regReq.Hostinfo.Hostname,
NodeKey: regReq.NodeKey,
LastSeen: &now,
Expiry: &time.Time{},
},
Registered: make(chan struct{}),
}

if !regReq.Expiry.IsZero() {
logTrace("Non-zero expiry time requested")
newNode.Expiry = &regReq.Expiry
newNode.Node.Expiry = &regReq.Expiry
}

h.registrationCache.Set(
Expand Down Expand Up @@ -257,11 +250,11 @@ func (h *Headscale) handleRegister(
// we need to make sure the NodeKey matches the one in the request
// TODO(juan): What happens when using fast user switching between two
// headscale-managed tailnets?
node.NodeKey = regReq.NodeKey
h.registrationCache.Set(
registrationId,
*node,
)
// node.NodeKey = regReq.NodeKey
// h.registrationCache.Set(
// registrationId,
// *node,
// )

return
}
Expand Down
4 changes: 2 additions & 2 deletions hscontrol/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type KV struct {
type HSDatabase struct {
DB *gorm.DB
cfg *types.DatabaseConfig
regCache *zcache.Cache[types.RegistrationID, types.Node]
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode]

baseDomain string
}
Expand All @@ -51,7 +51,7 @@ type HSDatabase struct {
func NewHeadscaleDatabase(
cfg types.DatabaseConfig,
baseDomain string,
regCache *zcache.Cache[types.RegistrationID, types.Node],
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode],
) (*HSDatabase, error) {
dbConn, err := openDB(cfg)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions hscontrol/db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,8 @@ func testCopyOfDatabase(src string) (string, error) {
return dst, err
}

func emptyCache() *zcache.Cache[types.RegistrationID, types.Node] {
return zcache.New[types.RegistrationID, types.Node](time.Minute, time.Hour)
func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] {
return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour)
}

// requireConstraintFailed checks if the error is a constraint failure with
Expand Down
18 changes: 10 additions & 8 deletions hscontrol/db/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
ipv6 *netip.Addr,
) (*types.Node, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
if node, ok := hsdb.regCache.Get(registrationID); ok {
if reg, ok := hsdb.regCache.Get(registrationID); ok {
user, err := GetUserByID(tx, userID)
if err != nil {
return nil, fmt.Errorf(
Expand All @@ -347,29 +347,31 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
// TODO(kradalby): This looks quite wrong? why ID 0?
// Why not always?
// Registration of expired node with different user
if node.ID != 0 &&
node.UserID != user.ID {
if reg.Node.ID != 0 &&
reg.Node.UserID != user.ID {
return nil, ErrDifferentRegisteredUser
}

node.UserID = user.ID
node.User = *user
node.RegisterMethod = registrationMethod
reg.Node.UserID = user.ID
reg.Node.User = *user
reg.Node.RegisterMethod = registrationMethod

if nodeExpiry != nil {
node.Expiry = nodeExpiry
reg.Node.Expiry = nodeExpiry
}

node, err := RegisterNode(
tx,
node,
reg.Node,
ipv4, ipv6,
)

if err == nil {
hsdb.regCache.Delete(registrationID)
}

// Signal to waiting clients that the machine has been registered.
close(reg.Registered)
return node, err
}

Expand Down
19 changes: 11 additions & 8 deletions hscontrol/grpcv1.go
Original file line number Diff line number Diff line change
Expand Up @@ -845,15 +845,18 @@ func (api headscaleV1APIServer) DebugCreateNode(

nodeKey := key.NewNode()

newNode := types.Node{
NodeKey: nodeKey.Public(),
Hostname: request.GetName(),
User: *user,
newNode := types.RegisterNode{
Node: types.Node{
NodeKey: nodeKey.Public(),
Hostname: request.GetName(),
User: *user,

Expiry: &time.Time{},
LastSeen: &time.Time{},
Expiry: &time.Time{},
LastSeen: &time.Time{},

Hostinfo: &hostinfo,
Hostinfo: &hostinfo,
},
Registered: make(chan struct{}),
}

log.Debug().
Expand All @@ -865,7 +868,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
newNode,
)

return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil
return &v1.DebugCreateNodeResponse{Node: newNode.Node.Proto()}, nil
}

func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {}
5 changes: 5 additions & 0 deletions hscontrol/types/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,8 @@ func RegistrationIDFromString(str string) (RegistrationID, error) {
func (r RegistrationID) String() string {
return string(r)
}

type RegisterNode struct {
Node Node
Registered chan struct{}
}

0 comments on commit 6525e49

Please sign in to comment.