Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix logging in with the same node with a new user #2337

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test-integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ jobs:
- TestOIDCExpireNodesBasedOnTokenExpiry
- TestOIDC024UserCreation
- TestOIDCAuthenticationWithPKCE
- TestOIDCReloginSameNode
- TestAuthWebFlowAuthenticationPingAll
- TestAuthWebFlowLogoutAndRelogin
- TestUserCommand
Expand Down
6 changes: 3 additions & 3 deletions hscontrol/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ type Headscale struct {
mapper *mapper.Mapper
nodeNotifier *notifier.Notifier

registrationCache *zcache.Cache[string, types.Node]
registrationCache *zcache.Cache[types.RegistrationID, types.Node]

authProvider AuthProvider

Expand All @@ -123,7 +123,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
}

registrationCache := zcache.New[string, types.Node](
registrationCache := zcache.New[types.RegistrationID, types.Node](
registerCacheExpiration,
registerCacheCleanup,
)
Expand Down Expand Up @@ -462,7 +462,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {

router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
router.HandleFunc("/register/{mkey}", h.authProvider.RegisterHandler).Methods(http.MethodGet)
router.HandleFunc("/register/{registration_id}", h.authProvider.RegisterHandler).Methods(http.MethodGet)

if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet)
Expand Down
44 changes: 30 additions & 14 deletions hscontrol/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,18 @@ import (

type AuthProvider interface {
RegisterHandler(http.ResponseWriter, *http.Request)
AuthURL(key.MachinePublic) string
AuthURL(types.RegistrationID) string
}

func logAuthFunc(
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
registrationId types.RegistrationID,
) (func(string), func(string), func(error, string)) {
return func(msg string) {
log.Info().
Caller().
Str("registration_id", registrationId.String()).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Expand All @@ -41,6 +43,7 @@ func logAuthFunc(
func(msg string) {
log.Trace().
Caller().
Str("registration_id", registrationId.String()).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Expand All @@ -52,6 +55,7 @@ func logAuthFunc(
func(err error, msg string) {
log.Error().
Caller().
Str("registration_id", registrationId.String()).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Expand All @@ -70,7 +74,18 @@ func (h *Headscale) handleRegister(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) {
logInfo, logTrace, _ := logAuthFunc(regReq, machineKey)
registrationId, err := types.NewRegistrationID()
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to generate registration ID")
http.Error(writer, "Internal server error", http.StatusInternalServerError)

return
}

logInfo, logTrace, _ := logAuthFunc(regReq, machineKey, registrationId)
now := time.Now().UTC()
logTrace("handleRegister called, looking up machine in DB")
node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey)
Expand All @@ -93,14 +108,14 @@ func (h *Headscale) handleRegister(
// successful RegisterResponse.
if regReq.Followup != "" {
logTrace("register request is a followup")
if _, ok := h.registrationCache.Get(machineKey.String()); ok {
if _, ok := h.registrationCache.Get(registrationId); ok {
logTrace("Node is waiting for interactive login")

select {
case <-req.Context().Done():
return
case <-time.After(registrationHoldoff):
h.handleNewNode(writer, regReq, machineKey)
h.handleNewNode(writer, regReq, registrationId)

return
}
Expand All @@ -127,11 +142,11 @@ func (h *Headscale) handleRegister(
}

h.registrationCache.Set(
machineKey.String(),
registrationId,
newNode,
)

h.handleNewNode(writer, regReq, machineKey)
h.handleNewNode(writer, regReq, registrationId)

return
}
Expand Down Expand Up @@ -214,7 +229,7 @@ func (h *Headscale) handleRegister(
}

// The node has expired or it is logged out
h.handleNodeExpiredOrLoggedOut(writer, regReq, *node, machineKey)
h.handleNodeExpiredOrLoggedOut(writer, regReq, *node, machineKey, registrationId)

// TODO(juan): RegisterRequest includes an Expiry time, that we could optionally use
node.Expiry = &time.Time{}
Expand All @@ -225,7 +240,7 @@ func (h *Headscale) handleRegister(
// headscale-managed tailnets?
node.NodeKey = regReq.NodeKey
h.registrationCache.Set(
machineKey.String(),
registrationId,
*node,
)

Expand Down Expand Up @@ -444,16 +459,16 @@ func (h *Headscale) handleAuthKey(
func (h *Headscale) handleNewNode(
writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
registrationId types.RegistrationID,
) {
logInfo, logTrace, logErr := logAuthFunc(registerRequest, machineKey)
logInfo, logTrace, logErr := logAuthFunc(registerRequest, key.MachinePublic{}, registrationId)

resp := tailcfg.RegisterResponse{}

// The node registration is new, redirect the client to the registration URL
logTrace("The node seems to be new, sending auth url")

resp.AuthURL = h.authProvider.AuthURL(machineKey)
resp.AuthURL = h.authProvider.AuthURL(registrationId)

respBody, err := json.Marshal(resp)
if err != nil {
Expand Down Expand Up @@ -660,6 +675,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
regReq tailcfg.RegisterRequest,
node types.Node,
machineKey key.MachinePublic,
registrationId types.RegistrationID,
) {
resp := tailcfg.RegisterResponse{}

Expand All @@ -673,12 +689,12 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
log.Trace().
Caller().
Str("node", node.Hostname).
Str("machine_key", machineKey.ShortString()).
Str("registration_id", registrationId.String()).
Str("node_key", regReq.NodeKey.ShortString()).
Str("node_key_old", regReq.OldNodeKey.ShortString()).
Msg("Node registration has expired or logged out. Sending a auth url to register")

resp.AuthURL = h.authProvider.AuthURL(machineKey)
resp.AuthURL = h.authProvider.AuthURL(registrationId)

respBody, err := json.Marshal(resp)
if err != nil {
Expand All @@ -703,7 +719,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(

log.Trace().
Caller().
Str("machine_key", machineKey.ShortString()).
Str("registration_id", registrationId.String()).
Str("node_key", regReq.NodeKey.ShortString()).
Str("node_key_old", regReq.OldNodeKey.ShortString()).
Str("node", node.Hostname).
Expand Down
56 changes: 0 additions & 56 deletions hscontrol/auth_noise.go

This file was deleted.

4 changes: 2 additions & 2 deletions hscontrol/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type KV struct {
type HSDatabase struct {
DB *gorm.DB
cfg *types.DatabaseConfig
regCache *zcache.Cache[string, types.Node]
regCache *zcache.Cache[types.RegistrationID, types.Node]

baseDomain string
}
Expand All @@ -51,7 +51,7 @@ type HSDatabase struct {
func NewHeadscaleDatabase(
cfg types.DatabaseConfig,
baseDomain string,
regCache *zcache.Cache[string, types.Node],
regCache *zcache.Cache[types.RegistrationID, types.Node],
) (*HSDatabase, error) {
dbConn, err := openDB(cfg)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions hscontrol/db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,8 @@ func testCopyOfDatabase(src string) (string, error) {
return dst, err
}

func emptyCache() *zcache.Cache[string, types.Node] {
return zcache.New[string, types.Node](time.Minute, time.Hour)
func emptyCache() *zcache.Cache[types.RegistrationID, types.Node] {
return zcache.New[types.RegistrationID, types.Node](time.Minute, time.Hour)
}

// requireConstraintFailed checks if the error is a constraint failure with
Expand Down
8 changes: 4 additions & 4 deletions hscontrol/db/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,15 +320,15 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
}

func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
mkey key.MachinePublic,
registrationID types.RegistrationID,
userID types.UserID,
nodeExpiry *time.Time,
registrationMethod string,
ipv4 *netip.Addr,
ipv6 *netip.Addr,
) (*types.Node, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
if node, ok := hsdb.regCache.Get(mkey.String()); ok {
if node, ok := hsdb.regCache.Get(registrationID); ok {
user, err := GetUserByID(tx, userID)
if err != nil {
return nil, fmt.Errorf(
Expand All @@ -338,7 +338,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
}

log.Debug().
Str("machine_key", mkey.ShortString()).
Str("registration_id", registrationID.String()).
Str("username", user.Username()).
Str("registrationMethod", registrationMethod).
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
Expand All @@ -365,7 +365,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
)

if err == nil {
hsdb.regCache.Delete(mkey.String())
hsdb.regCache.Delete(registrationID)
}

return node, err
Expand Down
21 changes: 9 additions & 12 deletions hscontrol/grpcv1.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,10 @@ func (api headscaleV1APIServer) RegisterNode(
) (*v1.RegisterNodeResponse, error) {
log.Trace().
Str("user", request.GetUser()).
Str("machine_key", request.GetKey()).
Str("registration_id", request.GetKey()).
Msg("Registering node")

var mkey key.MachinePublic
err := mkey.UnmarshalText([]byte(request.GetKey()))
registrationId, err := types.RegistrationIDFromString(request.GetKey())
if err != nil {
return nil, err
}
Expand All @@ -247,7 +246,7 @@ func (api headscaleV1APIServer) RegisterNode(
}

node, err := api.h.db.RegisterNodeFromAuthCallback(
mkey,
registrationId,
types.UserID(user.ID),
nil,
util.RegisterMethodCLI,
Expand Down Expand Up @@ -839,19 +838,17 @@ func (api headscaleV1APIServer) DebugCreateNode(
Hostname: "DebugTestNode",
}

var mkey key.MachinePublic
err = mkey.UnmarshalText([]byte(request.GetKey()))
registrationId, err := types.RegistrationIDFromString(request.GetKey())
if err != nil {
return nil, err
}

nodeKey := key.NewNode()

newNode := types.Node{
MachineKey: mkey,
NodeKey: nodeKey.Public(),
Hostname: request.GetName(),
User: *user,
NodeKey: nodeKey.Public(),
Hostname: request.GetName(),
User: *user,

Expiry: &time.Time{},
LastSeen: &time.Time{},
Expand All @@ -860,11 +857,11 @@ func (api headscaleV1APIServer) DebugCreateNode(
}

log.Debug().
Str("machine_key", mkey.ShortString()).
Str("registration_id", registrationId.String()).
Msg("adding debug machine via CLI, appending to registration cache")

api.h.registrationCache.Set(
mkey.String(),
registrationId,
newNode,
)

Expand Down
5 changes: 3 additions & 2 deletions hscontrol/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/chasefleming/elem-go/styles"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/templates"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
Expand Down Expand Up @@ -239,11 +240,11 @@ func NewAuthProviderWeb(serverURL string) *AuthProviderWeb {
}
}

func (a *AuthProviderWeb) AuthURL(mKey key.MachinePublic) string {
func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string {
return fmt.Sprintf(
"%s/register/%s",
strings.TrimSuffix(a.serverURL, "/"),
mKey.String())
registrationId.String())
}

// RegisterWebAPI shows a simple message in the browser to point to the CLI
Expand Down
Loading
Loading