diff --git a/cmd/headscale/cli/debug.go b/cmd/headscale/cli/debug.go index 72cde32d76..41b46fb07e 100644 --- a/cmd/headscale/cli/debug.go +++ b/cmd/headscale/cli/debug.go @@ -4,10 +4,10 @@ import ( "fmt" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" "github.com/spf13/cobra" "google.golang.org/grpc/status" - "tailscale.com/types/key" ) const ( @@ -79,7 +79,7 @@ var createNodeCmd = &cobra.Command{ ) } - machineKey, err := cmd.Flags().GetString("key") + registrationID, err := cmd.Flags().GetString("key") if err != nil { ErrorOutput( err, @@ -88,8 +88,7 @@ var createNodeCmd = &cobra.Command{ ) } - var mkey key.MachinePublic - err = mkey.UnmarshalText([]byte(machineKey)) + _, err = types.RegistrationIDFromString(registrationID) if err != nil { ErrorOutput( err, @@ -108,7 +107,7 @@ var createNodeCmd = &cobra.Command{ } request := &v1.DebugCreateNodeRequest{ - Key: machineKey, + Key: registrationID, Name: name, User: user, Routes: routes, diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index 8ffc85f6ad..d65814135f 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -122,7 +122,7 @@ var registerNodeCmd = &cobra.Command{ defer cancel() defer conn.Close() - machineKey, err := cmd.Flags().GetString("key") + registrationID, err := cmd.Flags().GetString("key") if err != nil { ErrorOutput( err, @@ -132,7 +132,7 @@ var registerNodeCmd = &cobra.Command{ } request := &v1.RegisterNodeRequest{ - Key: machineKey, + Key: registrationID, User: user, } diff --git a/hscontrol/app.go b/hscontrol/app.go index 641f5d421a..263342d769 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -96,7 +96,7 @@ type Headscale struct { mapper *mapper.Mapper nodeNotifier *notifier.Notifier - registrationCache *zcache.Cache[string, types.Node] + registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode] authProvider AuthProvider @@ -123,7 +123,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err) } - registrationCache := zcache.New[string, types.Node]( + registrationCache := zcache.New[types.RegistrationID, types.RegisterNode]( registerCacheExpiration, registerCacheCleanup, ) @@ -462,7 +462,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet) router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet) - router.HandleFunc("/register/{mkey}", h.authProvider.RegisterHandler).Methods(http.MethodGet) + router.HandleFunc("/register/{registration_id}", h.authProvider.RegisterHandler).Methods(http.MethodGet) if provider, ok := h.authProvider.(*AuthProviderOIDC); ok { router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet) diff --git a/hscontrol/auth.go b/hscontrol/auth.go index b4923ccb5c..9e22660d46 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -6,6 +6,8 @@ import ( "errors" "fmt" "net/http" + "net/url" + "strings" "time" "github.com/juanfont/headscale/hscontrol/db" @@ -20,16 +22,18 @@ import ( type AuthProvider interface { RegisterHandler(http.ResponseWriter, *http.Request) - AuthURL(key.MachinePublic) string + AuthURL(types.RegistrationID) string } func logAuthFunc( registerRequest tailcfg.RegisterRequest, machineKey key.MachinePublic, + registrationId types.RegistrationID, ) (func(string), func(string), func(error, string)) { return func(msg string) { log.Info(). Caller(). + Str("registration_id", registrationId.String()). Str("machine_key", machineKey.ShortString()). Str("node_key", registerRequest.NodeKey.ShortString()). Str("node_key_old", registerRequest.OldNodeKey.ShortString()). @@ -41,6 +45,7 @@ func logAuthFunc( func(msg string) { log.Trace(). Caller(). + Str("registration_id", registrationId.String()). Str("machine_key", machineKey.ShortString()). Str("node_key", registerRequest.NodeKey.ShortString()). Str("node_key_old", registerRequest.OldNodeKey.ShortString()). @@ -52,6 +57,7 @@ func logAuthFunc( func(err error, msg string) { log.Error(). Caller(). + Str("registration_id", registrationId.String()). Str("machine_key", machineKey.ShortString()). Str("node_key", registerRequest.NodeKey.ShortString()). Str("node_key_old", registerRequest.OldNodeKey.ShortString()). @@ -63,6 +69,40 @@ func logAuthFunc( } } +func (h *Headscale) waitForFollowup( + req *http.Request, + regReq tailcfg.RegisterRequest, + logTrace func(string), +) { + logTrace("register request is a followup") + fu, err := url.Parse(regReq.Followup) + if err != nil { + logTrace("failed to parse followup URL") + return + } + + followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", "")) + if err != nil { + logTrace("followup URL does not contains a valid registration ID") + return + } + + logTrace(fmt.Sprintf("followup URL contains a valid registration ID, looking up in cache: %s", followupReg)) + + if reg, ok := h.registrationCache.Get(followupReg); ok { + logTrace("Node is waiting for interactive login") + + select { + case <-req.Context().Done(): + logTrace("node went away before it was registered") + return + case <-reg.Registered: + logTrace("node has successfully registered") + return + } + } +} + // handleRegister is the logic for registering a client. func (h *Headscale) handleRegister( writer http.ResponseWriter, @@ -70,9 +110,23 @@ func (h *Headscale) handleRegister( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, ) { - logInfo, logTrace, _ := logAuthFunc(regReq, machineKey) + registrationId, err := types.NewRegistrationID() + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to generate registration ID") + http.Error(writer, "Internal server error", http.StatusInternalServerError) + + return + } + + logInfo, logTrace, _ := logAuthFunc(regReq, machineKey, registrationId) now := time.Now().UTC() logTrace("handleRegister called, looking up machine in DB") + + // TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs + // key refreshes. This will allow us to remove the machineKey from the registration request. node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey) logTrace("handleRegister database lookup has returned") if errors.Is(err, gorm.ErrRecordNotFound) { @@ -84,27 +138,9 @@ func (h *Headscale) handleRegister( } // Check if the node is waiting for interactive login. - // - // TODO(juan): We could use this field to improve our protocol implementation, - // and hold the request until the client closes it, or the interactive - // login is completed (i.e., the user registers the node). - // This is not implemented yet, as it is no strictly required. The only side-effect - // is that the client will hammer headscale with requests until it gets a - // successful RegisterResponse. if regReq.Followup != "" { - logTrace("register request is a followup") - if _, ok := h.registrationCache.Get(machineKey.String()); ok { - logTrace("Node is waiting for interactive login") - - select { - case <-req.Context().Done(): - return - case <-time.After(registrationHoldoff): - h.handleNewNode(writer, regReq, machineKey) - - return - } - } + h.waitForFollowup(req, regReq, logTrace) + return } logInfo("Node not found in database, creating new") @@ -113,25 +149,28 @@ func (h *Headscale) handleRegister( // that we rely on a method that calls back some how (OpenID or CLI) // We create the node and then keep it around until a callback // happens - newNode := types.Node{ - MachineKey: machineKey, - Hostname: regReq.Hostinfo.Hostname, - NodeKey: regReq.NodeKey, - LastSeen: &now, - Expiry: &time.Time{}, + newNode := types.RegisterNode{ + Node: types.Node{ + MachineKey: machineKey, + Hostname: regReq.Hostinfo.Hostname, + NodeKey: regReq.NodeKey, + LastSeen: &now, + Expiry: &time.Time{}, + }, + Registered: make(chan struct{}), } if !regReq.Expiry.IsZero() { logTrace("Non-zero expiry time requested") - newNode.Expiry = ®Req.Expiry + newNode.Node.Expiry = ®Req.Expiry } h.registrationCache.Set( - machineKey.String(), + registrationId, newNode, ) - h.handleNewNode(writer, regReq, machineKey) + h.handleNewNode(writer, regReq, registrationId) return } @@ -206,27 +245,28 @@ func (h *Headscale) handleRegister( } if regReq.Followup != "" { - select { - case <-req.Context().Done(): - return - case <-time.After(registrationHoldoff): - } + h.waitForFollowup(req, regReq, logTrace) + return } // The node has expired or it is logged out - h.handleNodeExpiredOrLoggedOut(writer, regReq, *node, machineKey) + h.handleNodeExpiredOrLoggedOut(writer, regReq, *node, machineKey, registrationId) // TODO(juan): RegisterRequest includes an Expiry time, that we could optionally use node.Expiry = &time.Time{} + // TODO(kradalby): do we need to rethink this as part of authflow? // If we are here it means the client needs to be reauthorized, // we need to make sure the NodeKey matches the one in the request // TODO(juan): What happens when using fast user switching between two // headscale-managed tailnets? node.NodeKey = regReq.NodeKey h.registrationCache.Set( - machineKey.String(), - *node, + registrationId, + types.RegisterNode{ + Node: *node, + Registered: make(chan struct{}), + }, ) return @@ -296,6 +336,8 @@ func (h *Headscale) handleAuthKey( // The error is not important, because if it does not // exist, then this is a new node and we will move // on to registration. + // TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs + // key refreshes. This will allow us to remove the machineKey from the registration request. node, _ := h.db.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey) if node != nil { log.Trace(). @@ -444,16 +486,16 @@ func (h *Headscale) handleAuthKey( func (h *Headscale) handleNewNode( writer http.ResponseWriter, registerRequest tailcfg.RegisterRequest, - machineKey key.MachinePublic, + registrationId types.RegistrationID, ) { - logInfo, logTrace, logErr := logAuthFunc(registerRequest, machineKey) + logInfo, logTrace, logErr := logAuthFunc(registerRequest, key.MachinePublic{}, registrationId) resp := tailcfg.RegisterResponse{} // The node registration is new, redirect the client to the registration URL - logTrace("The node seems to be new, sending auth url") + logTrace("The node is new, sending auth url") - resp.AuthURL = h.authProvider.AuthURL(machineKey) + resp.AuthURL = h.authProvider.AuthURL(registrationId) respBody, err := json.Marshal(resp) if err != nil { @@ -660,6 +702,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut( regReq tailcfg.RegisterRequest, node types.Node, machineKey key.MachinePublic, + registrationId types.RegistrationID, ) { resp := tailcfg.RegisterResponse{} @@ -673,12 +716,12 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut( log.Trace(). Caller(). Str("node", node.Hostname). - Str("machine_key", machineKey.ShortString()). + Str("registration_id", registrationId.String()). Str("node_key", regReq.NodeKey.ShortString()). Str("node_key_old", regReq.OldNodeKey.ShortString()). Msg("Node registration has expired or logged out. Sending a auth url to register") - resp.AuthURL = h.authProvider.AuthURL(machineKey) + resp.AuthURL = h.authProvider.AuthURL(registrationId) respBody, err := json.Marshal(resp) if err != nil { @@ -703,7 +746,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut( log.Trace(). Caller(). - Str("machine_key", machineKey.ShortString()). + Str("registration_id", registrationId.String()). Str("node_key", regReq.NodeKey.ShortString()). Str("node_key_old", regReq.OldNodeKey.ShortString()). Str("node", node.Hostname). diff --git a/hscontrol/auth_noise.go b/hscontrol/auth_noise.go deleted file mode 100644 index 6659dfa527..0000000000 --- a/hscontrol/auth_noise.go +++ /dev/null @@ -1,56 +0,0 @@ -package hscontrol - -import ( - "encoding/json" - "io" - "net/http" - - "github.com/rs/zerolog/log" - "tailscale.com/tailcfg" -) - -// // NoiseRegistrationHandler handles the actual registration process of a node. -func (ns *noiseServer) NoiseRegistrationHandler( - writer http.ResponseWriter, - req *http.Request, -) { - log.Trace().Caller().Msgf("Noise registration handler for client %s", req.RemoteAddr) - if req.Method != http.MethodPost { - http.Error(writer, "Wrong method", http.StatusMethodNotAllowed) - - return - } - - log.Trace(). - Any("headers", req.Header). - Caller(). - Msg("Headers") - - body, _ := io.ReadAll(req.Body) - registerRequest := tailcfg.RegisterRequest{} - if err := json.Unmarshal(body, ®isterRequest); err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot parse RegisterRequest") - http.Error(writer, "Internal error", http.StatusInternalServerError) - - return - } - - // Reject unsupported versions - if registerRequest.Version < MinimumCapVersion { - log.Info(). - Caller(). - Int("min_version", int(MinimumCapVersion)). - Int("client_version", int(registerRequest.Version)). - Msg("unsupported client connected") - http.Error(writer, "Internal error", http.StatusBadRequest) - - return - } - - ns.nodeKey = registerRequest.NodeKey - - ns.headscale.handleRegister(writer, req, registerRequest, ns.conn.Peer()) -} diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 36955e229e..6c3493b8bc 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -41,7 +41,7 @@ type KV struct { type HSDatabase struct { DB *gorm.DB cfg *types.DatabaseConfig - regCache *zcache.Cache[string, types.Node] + regCache *zcache.Cache[types.RegistrationID, types.RegisterNode] baseDomain string } @@ -51,7 +51,7 @@ type HSDatabase struct { func NewHeadscaleDatabase( cfg types.DatabaseConfig, baseDomain string, - regCache *zcache.Cache[string, types.Node], + regCache *zcache.Cache[types.RegistrationID, types.RegisterNode], ) (*HSDatabase, error) { dbConn, err := openDB(cfg) if err != nil { diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 0672c2523e..8ca773033c 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -260,8 +260,8 @@ func testCopyOfDatabase(src string) (string, error) { return dst, err } -func emptyCache() *zcache.Cache[string, types.Node] { - return zcache.New[string, types.Node](time.Minute, time.Hour) +func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] { + return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour) } // requireConstraintFailed checks if the error is a constraint failure with diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index ce9c90e916..f722d9ab16 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -158,6 +158,30 @@ func GetNodeByMachineKey( return &mach, nil } +func (hsdb *HSDatabase) GetNodeByNodeKey(nodeKey key.NodePublic) (*types.Node, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { + return GetNodeByNodeKey(rx, nodeKey) + }) +} + +// GetNodeByNodeKey finds a Node by its NodeKey and returns the Node struct. +func GetNodeByNodeKey( + tx *gorm.DB, + nodeKey key.NodePublic, +) (*types.Node, error) { + mach := types.Node{} + if result := tx. + Preload("AuthKey"). + Preload("AuthKey.User"). + Preload("User"). + Preload("Routes"). + First(&mach, "node_key = ?", nodeKey.String()); result.Error != nil { + return nil, result.Error + } + + return &mach, nil +} + func (hsdb *HSDatabase) GetNodeByAnyKey( machineKey key.MachinePublic, nodeKey key.NodePublic, @@ -319,60 +343,83 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error { return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error } -func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( - mkey key.MachinePublic, +// HandleNodeFromAuthPath is called from the OIDC or CLI auth path +// with a registrationID to register or reauthenticate a node. +// If the node found in the registration cache is not already registered, +// it will be registered with the user and the node will be removed from the cache. +// If the node is already registered, the expiry will be updated. +// The node, and a boolean indicating if it was a new node or not, will be returned. +func (hsdb *HSDatabase) HandleNodeFromAuthPath( + registrationID types.RegistrationID, userID types.UserID, nodeExpiry *time.Time, registrationMethod string, ipv4 *netip.Addr, ipv6 *netip.Addr, -) (*types.Node, error) { - return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) { - if node, ok := hsdb.regCache.Get(mkey.String()); ok { - user, err := GetUserByID(tx, userID) - if err != nil { - return nil, fmt.Errorf( - "failed to find user in register node from auth callback, %w", - err, +) (*types.Node, bool, error) { + var newNode bool + node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) { + if reg, ok := hsdb.regCache.Get(registrationID); ok { + if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil { + user, err := GetUserByID(tx, userID) + if err != nil { + return nil, fmt.Errorf( + "failed to find user in register node from auth callback, %w", + err, + ) + } + + log.Debug(). + Str("registration_id", registrationID.String()). + Str("username", user.Username()). + Str("registrationMethod", registrationMethod). + Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)). + Msg("Registering node from API/CLI or auth callback") + + // TODO(kradalby): This looks quite wrong? why ID 0? + // Why not always? + // Registration of expired node with different user + if reg.Node.ID != 0 && + reg.Node.UserID != user.ID { + return nil, ErrDifferentRegisteredUser + } + + reg.Node.UserID = user.ID + reg.Node.User = *user + reg.Node.RegisterMethod = registrationMethod + + if nodeExpiry != nil { + reg.Node.Expiry = nodeExpiry + } + + node, err := RegisterNode( + tx, + reg.Node, + ipv4, ipv6, ) - } - - log.Debug(). - Str("machine_key", mkey.ShortString()). - Str("username", user.Username()). - Str("registrationMethod", registrationMethod). - Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)). - Msg("Registering node from API/CLI or auth callback") - - // Registration of expired node with different user - if node.ID != 0 && - node.UserID != user.ID { - return nil, ErrDifferentRegisteredUser - } - - node.UserID = user.ID - node.User = *user - node.RegisterMethod = registrationMethod - - if nodeExpiry != nil { - node.Expiry = nodeExpiry - } - node, err := RegisterNode( - tx, - node, - ipv4, ipv6, - ) - - if err == nil { - hsdb.regCache.Delete(mkey.String()) + if err == nil { + hsdb.regCache.Delete(registrationID) + } + + // Signal to waiting clients that the machine has been registered. + close(reg.Registered) + newNode = true + return node, err + } else { + // If the node is already registered, this is a refresh. + err := NodeSetExpiry(tx, node.ID, *nodeExpiry) + if err != nil { + return nil, err + } + return node, nil } - - return node, err } return nil, ErrNodeNotFoundRegistrationCache }) + + return node, newNode, err } func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) { diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index b7c7e50e8d..7b1c658115 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -227,11 +227,10 @@ func (api headscaleV1APIServer) RegisterNode( ) (*v1.RegisterNodeResponse, error) { log.Trace(). Str("user", request.GetUser()). - Str("machine_key", request.GetKey()). + Str("registration_id", request.GetKey()). Msg("Registering node") - var mkey key.MachinePublic - err := mkey.UnmarshalText([]byte(request.GetKey())) + registrationId, err := types.RegistrationIDFromString(request.GetKey()) if err != nil { return nil, err } @@ -246,8 +245,8 @@ func (api headscaleV1APIServer) RegisterNode( return nil, fmt.Errorf("looking up user: %w", err) } - node, err := api.h.db.RegisterNodeFromAuthCallback( - mkey, + node, _, err := api.h.db.HandleNodeFromAuthPath( + registrationId, types.UserID(user.ID), nil, util.RegisterMethodCLI, @@ -839,36 +838,36 @@ func (api headscaleV1APIServer) DebugCreateNode( Hostname: "DebugTestNode", } - var mkey key.MachinePublic - err = mkey.UnmarshalText([]byte(request.GetKey())) + registrationId, err := types.RegistrationIDFromString(request.GetKey()) if err != nil { return nil, err } - nodeKey := key.NewNode() + newNode := types.RegisterNode{ + Node: types.Node{ + NodeKey: key.NewNode().Public(), + MachineKey: key.NewMachine().Public(), + Hostname: request.GetName(), + User: *user, - newNode := types.Node{ - MachineKey: mkey, - NodeKey: nodeKey.Public(), - Hostname: request.GetName(), - User: *user, + Expiry: &time.Time{}, + LastSeen: &time.Time{}, - Expiry: &time.Time{}, - LastSeen: &time.Time{}, - - Hostinfo: &hostinfo, + Hostinfo: &hostinfo, + }, + Registered: make(chan struct{}), } log.Debug(). - Str("machine_key", mkey.ShortString()). + Str("registration_id", registrationId.String()). Msg("adding debug machine via CLI, appending to registration cache") api.h.registrationCache.Set( - mkey.String(), + registrationId, newNode, ) - return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil + return &v1.DebugCreateNodeResponse{Node: newNode.Node.Proto()}, nil } func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {} diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index 3858df9339..edebae4a13 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -8,16 +8,13 @@ import ( "net/http" "strconv" "strings" - "time" - "github.com/chasefleming/elem-go" - "github.com/chasefleming/elem-go/attrs" "github.com/chasefleming/elem-go/styles" "github.com/gorilla/mux" "github.com/juanfont/headscale/hscontrol/templates" + "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" "tailscale.com/tailcfg" - "tailscale.com/types/key" ) const ( @@ -32,8 +29,6 @@ const ( // See also https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go NoiseCapabilityVersion = 39 - // TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed. - registrationHoldoff = time.Second * 5 reservedResponseHeaderSize = 4 ) @@ -204,31 +199,6 @@ var codeStyleRegisterWebAPI = styles.Props{ styles.BackgroundColor: "#eee", } -func registerWebHTML(key string) *elem.Element { - return elem.Html(nil, - elem.Head( - nil, - elem.Title(nil, elem.Text("Registration - Headscale")), - elem.Meta(attrs.Props{ - attrs.Name: "viewport", - attrs.Content: "width=device-width, initial-scale=1", - }), - ), - elem.Body(attrs.Props{ - attrs.Style: styles.Props{ - styles.FontFamily: "sans", - }.ToInline(), - }, - elem.H1(nil, elem.Text("headscale")), - elem.H2(nil, elem.Text("Machine registration")), - elem.P(nil, elem.Text("Run the command below in the headscale server to add this machine to your network:")), - elem.Code(attrs.Props{attrs.Style: codeStyleRegisterWebAPI.ToInline()}, - elem.Text(fmt.Sprintf("headscale nodes register --user USERNAME --key %s", key)), - ), - ), - ) -} - type AuthProviderWeb struct { serverURL string } @@ -239,15 +209,15 @@ func NewAuthProviderWeb(serverURL string) *AuthProviderWeb { } } -func (a *AuthProviderWeb) AuthURL(mKey key.MachinePublic) string { +func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string { return fmt.Sprintf( "%s/register/%s", strings.TrimSuffix(a.serverURL, "/"), - mKey.String()) + registrationId.String()) } // RegisterWebAPI shows a simple message in the browser to point to the CLI -// Listens in /register/:nkey. +// Listens in /register/:registration_id. // // This is not part of the Tailscale control API, as we could send whatever URL // in the RegisterResponse.AuthURL field. @@ -256,39 +226,23 @@ func (a *AuthProviderWeb) RegisterHandler( req *http.Request, ) { vars := mux.Vars(req) - machineKeyStr := vars["mkey"] + registrationIdStr := vars["registration_id"] // We need to make sure we dont open for XSS style injections, if the parameter that // is passed as a key is not parsable/validated as a NodePublic key, then fail to render // the template and log an error. - var machineKey key.MachinePublic - err := machineKey.UnmarshalText( - []byte(machineKeyStr), - ) + registrationId, err := types.RegistrationIDFromString(registrationIdStr) if err != nil { - log.Warn().Err(err).Msg("Failed to parse incoming machinekey") - - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusBadRequest) - _, err := writer.Write([]byte("Wrong params")) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - + http.Error(writer, "invalid registration ID", http.StatusBadRequest) return } writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) - if _, err := writer.Write([]byte(registerWebHTML(machineKey.String()).Render())); err != nil { - if _, err := writer.Write([]byte(templates.RegisterWeb(machineKey.String()).Render())); err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } + if _, err := writer.Write([]byte(templates.RegisterWeb(registrationId).Render())); err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") } } diff --git a/hscontrol/noise.go b/hscontrol/noise.go index 393b9608f3..d1b0baa5ca 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -3,6 +3,7 @@ package hscontrol import ( "encoding/binary" "encoding/json" + "fmt" "io" "net/http" @@ -115,18 +116,8 @@ func (h *Headscale) NoiseUpgradeHandler( } func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error { - log.Trace(). - Caller(). - Int("protocol_version", protocolVersion). - Str("challenge", ns.challenge.Public().String()). - Msg("earlyNoise called") - - if protocolVersion < earlyNoiseCapabilityVersion { - log.Trace(). - Caller(). - Msgf("protocol version %d does not support early noise", protocolVersion) - - return nil + if !isSupportedVersion(tailcfg.CapabilityVersion(protocolVersion)) { + return fmt.Errorf("unsupported client version: %d", protocolVersion) } earlyJSON, err := json.Marshal(&tailcfg.EarlyNoise{ @@ -162,6 +153,26 @@ const ( MinimumCapVersion tailcfg.CapabilityVersion = 82 ) +func isSupportedVersion(version tailcfg.CapabilityVersion) bool { + return version >= MinimumCapVersion +} + +func rejectUnsupported(writer http.ResponseWriter, version tailcfg.CapabilityVersion) bool { + // Reject unsupported versions + if !isSupportedVersion(version) { + log.Info(). + Caller(). + Int("min_version", int(MinimumCapVersion)). + Int("client_version", int(version)). + Msg("unsupported client connected") + http.Error(writer, "unsupported client version", http.StatusBadRequest) + + return true + } + + return false +} + // NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol // // This is the busiest endpoint, as it keeps the HTTP long poll that updates @@ -177,7 +188,7 @@ func (ns *noiseServer) NoisePollNetMapHandler( ) { body, _ := io.ReadAll(req.Body) - mapRequest := tailcfg.MapRequest{} + var mapRequest tailcfg.MapRequest if err := json.Unmarshal(body, &mapRequest); err != nil { log.Error(). Caller(). @@ -197,14 +208,7 @@ func (ns *noiseServer) NoisePollNetMapHandler( Msg("PollNetMapHandler called") // Reject unsupported versions - if mapRequest.Version < MinimumCapVersion { - log.Info(). - Caller(). - Int("min_version", int(MinimumCapVersion)). - Int("client_version", int(mapRequest.Version)). - Msg("unsupported client connected") - http.Error(writer, "Internal error", http.StatusBadRequest) - + if rejectUnsupported(writer, mapRequest.Version) { return } @@ -232,3 +236,42 @@ func (ns *noiseServer) NoisePollNetMapHandler( sess.serveLongPoll() } } + +// NoiseRegistrationHandler handles the actual registration process of a node. +func (ns *noiseServer) NoiseRegistrationHandler( + writer http.ResponseWriter, + req *http.Request, +) { + log.Trace().Caller().Msgf("Noise registration handler for client %s", req.RemoteAddr) + if req.Method != http.MethodPost { + http.Error(writer, "Wrong method", http.StatusMethodNotAllowed) + + return + } + + log.Trace(). + Any("headers", req.Header). + Caller(). + Msg("Headers") + + body, _ := io.ReadAll(req.Body) + var registerRequest tailcfg.RegisterRequest + if err := json.Unmarshal(body, ®isterRequest); err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Cannot parse RegisterRequest") + http.Error(writer, "Internal error", http.StatusInternalServerError) + + return + } + + // Reject unsupported versions + if rejectUnsupported(writer, registerRequest.Version) { + return + } + + ns.nodeKey = registerRequest.NodeKey + + ns.headscale.handleRegister(writer, req, registerRequest, ns.conn.Peer()) +} diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 4470ba41b9..5bc548d076 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -21,7 +21,6 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "golang.org/x/oauth2" - "tailscale.com/types/key" "zgo.at/zcache/v2" ) @@ -49,8 +48,8 @@ var ( // RegistrationInfo contains both machine key and verifier information for OIDC validation. type RegistrationInfo struct { - MachineKey key.MachinePublic - Verifier *string + RegistrationID types.RegistrationID + Verifier *string } type AuthProviderOIDC struct { @@ -112,11 +111,11 @@ func NewAuthProviderOIDC( }, nil } -func (a *AuthProviderOIDC) AuthURL(mKey key.MachinePublic) string { +func (a *AuthProviderOIDC) AuthURL(registrationID types.RegistrationID) string { return fmt.Sprintf( "%s/register/%s", strings.TrimSuffix(a.serverURL, "/"), - mKey.String()) + registrationID.String()) } func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time { @@ -129,32 +128,29 @@ func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time // RegisterOIDC redirects to the OIDC provider for authentication // Puts NodeKey in cache so the callback can retrieve it using the oidc state param -// Listens in /register/:mKey. +// Listens in /register/:registration_id. func (a *AuthProviderOIDC) RegisterHandler( writer http.ResponseWriter, req *http.Request, ) { vars := mux.Vars(req) - machineKeyStr, ok := vars["mkey"] - - log.Debug(). - Caller(). - Str("machine_key", machineKeyStr). - Bool("ok", ok). - Msg("Received oidc register call") + registrationIdStr, ok := vars["registration_id"] // We need to make sure we dont open for XSS style injections, if the parameter that // is passed as a key is not parsable/validated as a NodePublic key, then fail to render // the template and log an error. - var machineKey key.MachinePublic - err := machineKey.UnmarshalText( - []byte(machineKeyStr), - ) + registrationId, err := types.RegistrationIDFromString(registrationIdStr) if err != nil { - http.Error(writer, err.Error(), http.StatusBadRequest) + http.Error(writer, "invalid registration ID", http.StatusBadRequest) return } + log.Debug(). + Caller(). + Str("registration_id", registrationId.String()). + Bool("ok", ok). + Msg("Received oidc register call") + // Set the state and nonce cookies to protect against CSRF attacks state, err := setCSRFCookie(writer, req, "state") if err != nil { @@ -171,7 +167,7 @@ func (a *AuthProviderOIDC) RegisterHandler( // Initialize registration info with machine key registrationInfo := RegistrationInfo{ - MachineKey: machineKey, + RegistrationID: registrationId, } extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount) @@ -290,49 +286,27 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( return } - // Retrieve the node and the machine key from the state cache and - // database. + // TODO(kradalby): Is this comment right? // If the node exists, then the node should be reauthenticated, // if the node does not exist, and the machine key exists, then // this is a new node that should be registered. - node, mKey := a.getMachineKeyFromState(state) + registrationId := a.getRegistrationIDFromState(state) - // Reauthenticate the node if it does exists. - if node != nil { - err := a.reauthenticateNode(node, nodeExpiry) + // Register the node if it does not exist. + if registrationId != nil { + verb := "Reauthenticated" + newNode, err := a.handleRegistrationID(user, *registrationId, nodeExpiry) if err != nil { http.Error(writer, err.Error(), http.StatusInternalServerError) return } - // TODO(kradalby): replace with go-elem - var content bytes.Buffer - if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ - User: user.DisplayNameOrUsername(), - Verb: "Reauthenticated", - }); err != nil { - http.Error(writer, fmt.Errorf("rendering OIDC callback template: %w", err).Error(), http.StatusInternalServerError) - return - } - - writer.Header().Set("Content-Type", "text/html; charset=utf-8") - writer.WriteHeader(http.StatusOK) - _, err = writer.Write(content.Bytes()) - if err != nil { - util.LogErr(err, "Failed to write response") + if newNode { + verb = "Authenticated" } - return - } - - // Register the node if it does not exist. - if mKey != nil { - if err := a.registerNode(user, mKey, nodeExpiry); err != nil { - http.Error(writer, err.Error(), http.StatusInternalServerError) - return - } - - content, err := renderOIDCCallbackTemplate(user) + // TODO(kradalby): replace with go-elem + content, err := renderOIDCCallbackTemplate(user, verb) if err != nil { http.Error(writer, err.Error(), http.StatusInternalServerError) return @@ -456,49 +430,14 @@ func validateOIDCAllowedUsers( return nil } -// getMachineKeyFromState retrieves the machine key from the state -// cache. If the machine key is found, it will try retrieve the -// node information from the database. -func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *key.MachinePublic) { +// getRegistrationIDFromState retrieves the registration ID from the state. +func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.RegistrationID { regInfo, ok := a.registrationCache.Get(state) if !ok { - return nil, nil - } - - // retrieve node information if it exist - // The error is not important, because if it does not - // exist, then this is a new node and we will move - // on to registration. - node, _ := a.db.GetNodeByMachineKey(regInfo.MachineKey) - - return node, ®Info.MachineKey -} - -// reauthenticateNode updates the node expiry in the database -// and notifies the node and its peers about the change. -func (a *AuthProviderOIDC) reauthenticateNode( - node *types.Node, - expiry time.Time, -) error { - err := a.db.NodeSetExpiry(node.ID, expiry) - if err != nil { - return err + return nil } - ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname) - a.notifier.NotifyByNodeID( - ctx, - types.StateUpdate{ - Type: types.StateSelfUpdate, - ChangeNodes: []types.NodeID{node.ID}, - }, - node.ID, - ) - - ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname) - a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID) - - return nil + return ®Info.RegistrationID } func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( @@ -556,43 +495,63 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( return user, nil } -func (a *AuthProviderOIDC) registerNode( +func (a *AuthProviderOIDC) handleRegistrationID( user *types.User, - machineKey *key.MachinePublic, + registrationID types.RegistrationID, expiry time.Time, -) error { +) (bool, error) { ipv4, ipv6, err := a.ipAlloc.Next() if err != nil { - return err + return false, err } - if _, err := a.db.RegisterNodeFromAuthCallback( - *machineKey, + node, newNode, err := a.db.HandleNodeFromAuthPath( + registrationID, types.UserID(user.ID), &expiry, util.RegisterMethodOIDC, ipv4, ipv6, - ); err != nil { - return fmt.Errorf("could not register node: %w", err) - } - - err = nodesChangedHook(a.db, a.polMan, a.notifier) + ) if err != nil { - return fmt.Errorf("updating resources using node: %w", err) + return false, fmt.Errorf("could not register node: %w", err) } - return nil + // Send an update to all nodes if this is a new node that they need to know + // about. + // If this is a refresh, just send new expiry updates. + if newNode { + err = nodesChangedHook(a.db, a.polMan, a.notifier) + if err != nil { + return false, fmt.Errorf("updating resources using node: %w", err) + } + } else { + ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname) + a.notifier.NotifyByNodeID( + ctx, + types.StateUpdate{ + Type: types.StateSelfUpdate, + ChangeNodes: []types.NodeID{node.ID}, + }, + node.ID, + ) + + ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname) + a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID) + } + + return newNode, nil } // TODO(kradalby): // Rewrite in elem-go. func renderOIDCCallbackTemplate( user *types.User, + verb string, ) (*bytes.Buffer, error) { var content bytes.Buffer if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ User: user.DisplayNameOrUsername(), - Verb: "Authenticated", + Verb: verb, }); err != nil { return nil, fmt.Errorf("rendering OIDC callback template: %w", err) } diff --git a/hscontrol/templates/register_web.go b/hscontrol/templates/register_web.go index 8361048a77..271f4e7d78 100644 --- a/hscontrol/templates/register_web.go +++ b/hscontrol/templates/register_web.go @@ -6,6 +6,7 @@ import ( "github.com/chasefleming/elem-go" "github.com/chasefleming/elem-go/attrs" "github.com/chasefleming/elem-go/styles" + "github.com/juanfont/headscale/hscontrol/types" ) var codeStyleRegisterWebAPI = styles.Props{ @@ -15,7 +16,7 @@ var codeStyleRegisterWebAPI = styles.Props{ styles.BackgroundColor: "#eee", } -func RegisterWeb(key string) *elem.Element { +func RegisterWeb(registrationID types.RegistrationID) *elem.Element { return HtmlStructure( elem.Title(nil, elem.Text("Registration - Headscale")), elem.Body(attrs.Props{ @@ -27,7 +28,7 @@ func RegisterWeb(key string) *elem.Element { elem.H2(nil, elem.Text("Machine registration")), elem.P(nil, elem.Text("Run the command below in the headscale server to add this machine to your network: ")), elem.Code(attrs.Props{attrs.Style: codeStyleRegisterWebAPI.ToInline()}, - elem.Text(fmt.Sprintf("headscale nodes register --user USERNAME --key %s", key)), + elem.Text(fmt.Sprintf("headscale nodes register --user USERNAME --key %s", registrationID.String())), ), ), ) diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index 32ad8a67db..3b6c1be11f 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -3,8 +3,10 @@ package types import ( "context" "errors" + "fmt" "time" + "github.com/juanfont/headscale/hscontrol/util" "tailscale.com/tailcfg" "tailscale.com/util/ctxkey" ) @@ -123,3 +125,40 @@ func NotifyCtx(ctx context.Context, origin, hostname string) context.Context { ctx2 = NotifyHostnameKey.WithValue(ctx2, hostname) return ctx2 } + +const RegistrationIDLength = 24 + +type RegistrationID string + +func NewRegistrationID() (RegistrationID, error) { + rid, err := util.GenerateRandomStringURLSafe(RegistrationIDLength) + if err != nil { + return "", err + } + + return RegistrationID(rid), nil +} + +func MustRegistrationID() RegistrationID { + rid, err := NewRegistrationID() + if err != nil { + panic(err) + } + return rid +} + +func RegistrationIDFromString(str string) (RegistrationID, error) { + if len(str) != RegistrationIDLength { + return "", fmt.Errorf("registration ID must be %d characters long", RegistrationIDLength) + } + return RegistrationID(str), nil +} + +func (r RegistrationID) String() string { + return string(r) +} + +type RegisterNode struct { + Node Node + Registered chan struct{} +} diff --git a/hscontrol/util/string.go b/hscontrol/util/string.go index ce38b82e87..08769060bc 100644 --- a/hscontrol/util/string.go +++ b/hscontrol/util/string.go @@ -32,7 +32,8 @@ func GenerateRandomBytes(n int) ([]byte, error) { func GenerateRandomStringURLSafe(n int) (string, error) { b, err := GenerateRandomBytes(n) - return base64.RawURLEncoding.EncodeToString(b), err + uenc := base64.RawURLEncoding.EncodeToString(b) + return uenc[:n], err } // GenerateRandomStringDNSSafe returns a DNS-safe diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index e74eae56ab..22790f91b5 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -11,8 +11,8 @@ import ( "net" "net/http" "net/http/cookiejar" - "net/http/httptest" "net/netip" + "net/url" "sort" "strconv" "testing" @@ -56,7 +56,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { scenario := AuthOIDCScenario{ Scenario: baseScenario, } - // defer scenario.ShutdownAssertNoPanics(t) + defer scenario.ShutdownAssertNoPanics(t) // Logins to MockOIDC is served by a queue with a strict order, // if we use more than one node per user, the order of the logins @@ -91,7 +91,6 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { hsic.WithTestName("oidcauthping"), hsic.WithConfigEnv(oidcMap), hsic.WithTLS(), - hsic.WithHostnameAsServerURL(), hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), ) assertNoErrHeadscaleEnv(t, err) @@ -206,7 +205,6 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { spec, hsic.WithTestName("oidcexpirenodes"), hsic.WithConfigEnv(oidcMap), - hsic.WithHostnameAsServerURL(), ) assertNoErrHeadscaleEnv(t, err) @@ -497,7 +495,6 @@ func TestOIDC024UserCreation(t *testing.T) { hsic.WithTestName("oidcmigration"), hsic.WithConfigEnv(oidcMap), hsic.WithTLS(), - hsic.WithHostnameAsServerURL(), hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), ) assertNoErrHeadscaleEnv(t, err) @@ -576,7 +573,6 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) { hsic.WithTestName("oidcauthpkce"), hsic.WithConfigEnv(oidcMap), hsic.WithTLS(), - hsic.WithHostnameAsServerURL(), hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), ) assertNoErrHeadscaleEnv(t, err) @@ -770,11 +766,6 @@ func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error func (s *AuthOIDCScenario) runTailscaleUp( userStr, loginServer string, ) error { - headscale, err := s.Headscale() - if err != nil { - return err - } - log.Printf("running tailscale up for user %s", userStr) if user, ok := s.users[userStr]; ok { for _, client := range user.Clients { @@ -785,59 +776,11 @@ func (s *AuthOIDCScenario) runTailscaleUp( log.Printf("%s failed to run tailscale up: %s", tsc.Hostname(), err) } - loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetHostname()) - loginURL.Scheme = "http" - - if len(headscale.GetCert()) > 0 { - loginURL.Scheme = "https" - } - - httptest.NewRecorder() - hc := &http.Client{ - Transport: LoggingRoundTripper{}, - } - hc.Jar, err = cookiejar.New(nil) - if err != nil { - log.Printf("failed to create cookie jar: %s", err) - } - - log.Printf("%s login url: %s\n", tsc.Hostname(), loginURL.String()) - - log.Printf("%s logging in with url", tsc.Hostname()) - ctx := context.Background() - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) - resp, err := hc.Do(req) + _, err = doLoginURL(tsc.Hostname(), loginURL) if err != nil { - log.Printf( - "%s failed to login using url %s: %s", - tsc.Hostname(), - loginURL, - err, - ) - return err } - log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL)) - - if resp.StatusCode != http.StatusOK { - log.Printf("%s response code of oidc login request was %s", tsc.Hostname(), resp.Status) - body, _ := io.ReadAll(resp.Body) - log.Printf("body: %s", body) - - return errStatusCodeNotOK - } - - defer resp.Body.Close() - - _, err = io.ReadAll(resp.Body) - if err != nil { - log.Printf("%s failed to read response body: %s", tsc.Hostname(), err) - - return err - } - - log.Printf("Finished request for %s to join tailnet", tsc.Hostname()) return nil }) @@ -865,6 +808,49 @@ func (s *AuthOIDCScenario) runTailscaleUp( return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable) } +// doLoginURL visits the given login URL and returns the body as a +// string. +func doLoginURL(hostname string, loginURL *url.URL) (string, error) { + log.Printf("%s login url: %s\n", hostname, loginURL.String()) + + var err error + hc := &http.Client{ + Transport: LoggingRoundTripper{}, + } + hc.Jar, err = cookiejar.New(nil) + if err != nil { + return "", fmt.Errorf("%s failed to create cookiejar : %w", hostname, err) + } + + log.Printf("%s logging in with url", hostname) + ctx := context.Background() + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) + resp, err := hc.Do(req) + if err != nil { + return "", fmt.Errorf("%s failed to send http request: %w", hostname, err) + } + + log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL)) + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + log.Printf("body: %s", body) + + return "", fmt.Errorf("%s response code of login request was %w", hostname, err) + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Printf("%s failed to read response body: %s", hostname, err) + + return "", fmt.Errorf("%s failed to read response body: %w", hostname, err) + } + + return string(body), nil +} + func (s *AuthOIDCScenario) Shutdown() { err := s.pool.Purge(s.mockOIDC) if err != nil { diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index 3ef3142245..72703e953e 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -1,13 +1,9 @@ package integration import ( - "context" - "crypto/tls" "errors" "fmt" - "io" "log" - "net/http" "net/netip" "net/url" "strings" @@ -47,7 +43,6 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) { hsic.WithTestName("webauthping"), hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), - hsic.WithHostnameAsServerURL(), ) assertNoErrHeadscaleEnv(t, err) @@ -87,7 +82,10 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { "user2": len(MustTestVersions), } - err = scenario.CreateHeadscaleEnv(spec, hsic.WithTestName("weblogout")) + err = scenario.CreateHeadscaleEnv(spec, + hsic.WithTestName("weblogout"), + hsic.WithTLS(), + ) assertNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() @@ -135,7 +133,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { for userName := range spec { err = scenario.runTailscaleUp(userName, headscale.GetEndpoint()) if err != nil { - t.Fatalf("failed to run tailscale up: %s", err) + t.Fatalf("failed to run tailscale up (%q): %s", headscale.GetEndpoint(), err) } } @@ -227,11 +225,12 @@ func (s *AuthWebFlowScenario) CreateHeadscaleEnv( func (s *AuthWebFlowScenario) runTailscaleUp( userStr, loginServer string, ) error { - log.Printf("running tailscale up for user %s", userStr) + log.Printf("running tailscale up for user %q", userStr) if user, ok := s.users[userStr]; ok { for _, client := range user.Clients { c := client user.joinWaitGroup.Go(func() error { + log.Printf("logging %q into %q", c.Hostname(), loginServer) loginURL, err := c.LoginWithURL(loginServer) if err != nil { log.Printf("failed to run tailscale up (%s): %s", c.Hostname(), err) @@ -273,39 +272,11 @@ func (s *AuthWebFlowScenario) runTailscaleUp( } func (s *AuthWebFlowScenario) runHeadscaleRegister(userStr string, loginURL *url.URL) error { - headscale, err := s.Headscale() - if err != nil { - return err - } - - log.Printf("loginURL: %s", loginURL) - loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetIP()) - loginURL.Scheme = "http" - - if len(headscale.GetCert()) > 0 { - loginURL.Scheme = "https" - } - - insecureTransport := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint - } - httpClient := &http.Client{ - Transport: insecureTransport, - } - ctx := context.Background() - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) - resp, err := httpClient.Do(req) - if err != nil { - return err - } - - body, err := io.ReadAll(resp.Body) + body, err := doLoginURL("web-auth-not-set", loginURL) if err != nil { return err } - defer resp.Body.Close() - // see api.go HTML template codeSep := strings.Split(string(body), "") if len(codeSep) != 2 { diff --git a/integration/cli_test.go b/integration/cli_test.go index 08d5937cfc..59d39278fa 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -12,6 +12,7 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" @@ -544,7 +545,6 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { hsic.WithTestName("clipak"), hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), - hsic.WithHostnameAsServerURL(), ) assertNoErr(t, err) @@ -812,14 +812,14 @@ func TestNodeTagCommand(t *testing.T) { headscale, err := scenario.Headscale() assertNoErr(t, err) - machineKeys := []string{ - "mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", - "mkey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", + regIDs := []string{ + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), } - nodes := make([]*v1.Node, len(machineKeys)) + nodes := make([]*v1.Node, len(regIDs)) assert.Nil(t, err) - for index, machineKey := range machineKeys { + for index, regID := range regIDs { _, err := headscale.Execute( []string{ "headscale", @@ -830,7 +830,7 @@ func TestNodeTagCommand(t *testing.T) { "--user", "user1", "--key", - machineKey, + regID, "--output", "json", }, @@ -847,7 +847,7 @@ func TestNodeTagCommand(t *testing.T) { "user1", "register", "--key", - machineKey, + regID, "--output", "json", }, @@ -857,7 +857,7 @@ func TestNodeTagCommand(t *testing.T) { nodes[index] = &node } - assert.Len(t, nodes, len(machineKeys)) + assert.Len(t, nodes, len(regIDs)) var node v1.Node err = executeAndUnmarshal( @@ -889,7 +889,7 @@ func TestNodeTagCommand(t *testing.T) { assert.ErrorContains(t, err, "tag must start with the string 'tag:'") // Test list all nodes after added seconds - resultMachines := make([]*v1.Node, len(machineKeys)) + resultMachines := make([]*v1.Node, len(regIDs)) err = executeAndUnmarshal( headscale, []string{ @@ -1054,18 +1054,17 @@ func TestNodeCommand(t *testing.T) { headscale, err := scenario.Headscale() assertNoErr(t, err) - // Pregenerated machine keys - machineKeys := []string{ - "mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", - "mkey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", - "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", - "mkey:8bc13285cee598acf76b1824a6f4490f7f2e3751b201e28aeb3b07fe81d5b4a1", - "mkey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", + regIDs := []string{ + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), } - nodes := make([]*v1.Node, len(machineKeys)) + nodes := make([]*v1.Node, len(regIDs)) assert.Nil(t, err) - for index, machineKey := range machineKeys { + for index, regID := range regIDs { _, err := headscale.Execute( []string{ "headscale", @@ -1076,7 +1075,7 @@ func TestNodeCommand(t *testing.T) { "--user", "node-user", "--key", - machineKey, + regID, "--output", "json", }, @@ -1093,7 +1092,7 @@ func TestNodeCommand(t *testing.T) { "node-user", "register", "--key", - machineKey, + regID, "--output", "json", }, @@ -1104,7 +1103,7 @@ func TestNodeCommand(t *testing.T) { nodes[index] = &node } - assert.Len(t, nodes, len(machineKeys)) + assert.Len(t, nodes, len(regIDs)) // Test list all nodes after added seconds var listAll []v1.Node @@ -1135,14 +1134,14 @@ func TestNodeCommand(t *testing.T) { assert.Equal(t, "node-4", listAll[3].GetName()) assert.Equal(t, "node-5", listAll[4].GetName()) - otherUserMachineKeys := []string{ - "mkey:b5b444774186d4217adcec407563a1223929465ee2c68a4da13af0d0185b4f8e", - "mkey:dc721977ac7415aafa87f7d4574cbe07c6b171834a6d37375782bdc1fb6b3584", + otherUserRegIDs := []string{ + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), } - otherUserMachines := make([]*v1.Node, len(otherUserMachineKeys)) + otherUserMachines := make([]*v1.Node, len(otherUserRegIDs)) assert.Nil(t, err) - for index, machineKey := range otherUserMachineKeys { + for index, regID := range otherUserRegIDs { _, err := headscale.Execute( []string{ "headscale", @@ -1153,7 +1152,7 @@ func TestNodeCommand(t *testing.T) { "--user", "other-user", "--key", - machineKey, + regID, "--output", "json", }, @@ -1170,7 +1169,7 @@ func TestNodeCommand(t *testing.T) { "other-user", "register", "--key", - machineKey, + regID, "--output", "json", }, @@ -1181,7 +1180,7 @@ func TestNodeCommand(t *testing.T) { otherUserMachines[index] = &node } - assert.Len(t, otherUserMachines, len(otherUserMachineKeys)) + assert.Len(t, otherUserMachines, len(otherUserRegIDs)) // Test list all nodes after added otherUser var listAllWithotherUser []v1.Node @@ -1294,17 +1293,16 @@ func TestNodeExpireCommand(t *testing.T) { headscale, err := scenario.Headscale() assertNoErr(t, err) - // Pregenerated machine keys - machineKeys := []string{ - "mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", - "mkey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", - "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", - "mkey:8bc13285cee598acf76b1824a6f4490f7f2e3751b201e28aeb3b07fe81d5b4a1", - "mkey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", + regIDs := []string{ + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), } - nodes := make([]*v1.Node, len(machineKeys)) + nodes := make([]*v1.Node, len(regIDs)) - for index, machineKey := range machineKeys { + for index, regID := range regIDs { _, err := headscale.Execute( []string{ "headscale", @@ -1315,7 +1313,7 @@ func TestNodeExpireCommand(t *testing.T) { "--user", "node-expire-user", "--key", - machineKey, + regID, "--output", "json", }, @@ -1332,7 +1330,7 @@ func TestNodeExpireCommand(t *testing.T) { "node-expire-user", "register", "--key", - machineKey, + regID, "--output", "json", }, @@ -1343,7 +1341,7 @@ func TestNodeExpireCommand(t *testing.T) { nodes[index] = &node } - assert.Len(t, nodes, len(machineKeys)) + assert.Len(t, nodes, len(regIDs)) var listAll []v1.Node err = executeAndUnmarshal( @@ -1421,18 +1419,17 @@ func TestNodeRenameCommand(t *testing.T) { headscale, err := scenario.Headscale() assertNoErr(t, err) - // Pregenerated machine keys - machineKeys := []string{ - "mkey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", - "mkey:8bc13285cee598acf76b1824a6f4490f7f2e3751b201e28aeb3b07fe81d5b4a1", - "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", - "mkey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", - "mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", + regIDs := []string{ + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), } - nodes := make([]*v1.Node, len(machineKeys)) + nodes := make([]*v1.Node, len(regIDs)) assert.Nil(t, err) - for index, machineKey := range machineKeys { + for index, regID := range regIDs { _, err := headscale.Execute( []string{ "headscale", @@ -1443,7 +1440,7 @@ func TestNodeRenameCommand(t *testing.T) { "--user", "node-rename-command", "--key", - machineKey, + regID, "--output", "json", }, @@ -1460,7 +1457,7 @@ func TestNodeRenameCommand(t *testing.T) { "node-rename-command", "register", "--key", - machineKey, + regID, "--output", "json", }, @@ -1471,7 +1468,7 @@ func TestNodeRenameCommand(t *testing.T) { nodes[index] = &node } - assert.Len(t, nodes, len(machineKeys)) + assert.Len(t, nodes, len(regIDs)) var listAll []v1.Node err = executeAndUnmarshal( @@ -1589,7 +1586,7 @@ func TestNodeMoveCommand(t *testing.T) { assertNoErr(t, err) // Randomly generated node key - machineKey := "mkey:688411b767663479632d44140f08a9fde87383adc7cdeb518f62ce28a17ef0aa" + regID := types.MustRegistrationID() _, err = headscale.Execute( []string{ @@ -1601,7 +1598,7 @@ func TestNodeMoveCommand(t *testing.T) { "--user", "old-user", "--key", - machineKey, + regID.String(), "--output", "json", }, @@ -1618,7 +1615,7 @@ func TestNodeMoveCommand(t *testing.T) { "old-user", "register", "--key", - machineKey, + regID.String(), "--output", "json", }, diff --git a/integration/derp_verify_endpoint_test.go b/integration/derp_verify_endpoint_test.go index adad5b6a49..bc7a0a7d15 100644 --- a/integration/derp_verify_endpoint_test.go +++ b/integration/derp_verify_endpoint_test.go @@ -69,7 +69,6 @@ func TestDERPVerifyEndpoint(t *testing.T) { hsic.WithHostname(hostname), hsic.WithPort(headscalePort), hsic.WithCustomTLS(certHeadscale, keyHeadscale), - hsic.WithHostnameAsServerURL(), hsic.WithDERPConfig(derpMap)) assertNoErrHeadscaleEnv(t, err) diff --git a/integration/dns_test.go b/integration/dns_test.go index d1693441b6..05e272f5c7 100644 --- a/integration/dns_test.go +++ b/integration/dns_test.go @@ -123,7 +123,6 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { hsic.WithFileInContainer(erPath, b), hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), - hsic.WithHostnameAsServerURL(), ) assertNoErrHeadscaleEnv(t, err) diff --git a/integration/embedded_derp_test.go b/integration/embedded_derp_test.go index d5fdb1612a..e17bbacbd6 100644 --- a/integration/embedded_derp_test.go +++ b/integration/embedded_derp_test.go @@ -105,7 +105,6 @@ func derpServerScenario( hsic.WithEmbeddedDERPServerOnly(), hsic.WithPort(443), hsic.WithTLS(), - hsic.WithHostnameAsServerURL(), hsic.WithConfigEnv(map[string]string{ "HEADSCALE_DERP_AUTO_UPDATE_ENABLED": "true", "HEADSCALE_DERP_UPDATE_FREQUENCY": "10s", diff --git a/integration/general_test.go b/integration/general_test.go index 985c952974..eb26cea903 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -44,7 +44,6 @@ func TestPingAllByIP(t *testing.T) { hsic.WithTestName("pingallbyip"), hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), - hsic.WithHostnameAsServerURL(), hsic.WithIPAllocationStrategy(types.IPAllocationStrategyRandom), ) assertNoErrHeadscaleEnv(t, err) @@ -123,12 +122,9 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) { opts := []hsic.Option{hsic.WithTestName("pingallbyip")} if https { - opts = []hsic.Option{ - hsic.WithTestName("pingallbyip"), - hsic.WithEmbeddedDERPServerOnly(), + opts = append(opts, []hsic.Option{ hsic.WithTLS(), - hsic.WithHostnameAsServerURL(), - } + }...) } err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, opts...) @@ -172,7 +168,7 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) { // https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38 // https://github.com/juanfont/headscale/issues/2164 if !https { - time.Sleep(3 * time.Minute) + time.Sleep(5 * time.Minute) } for userName := range spec { @@ -1050,7 +1046,6 @@ func TestPingAllByIPManyUpDown(t *testing.T) { hsic.WithTestName("pingallbyipmany"), hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), - hsic.WithHostnameAsServerURL(), ) assertNoErrHeadscaleEnv(t, err) @@ -1133,7 +1128,6 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) { hsic.WithTestName("deletenocrash"), hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), - hsic.WithHostnameAsServerURL(), ) assertNoErrHeadscaleEnv(t, err) diff --git a/integration/hsic/config.go b/integration/hsic/config.go index 509052a300..76a5176c0a 100644 --- a/integration/hsic/config.go +++ b/integration/hsic/config.go @@ -26,9 +26,7 @@ func DefaultConfigEnv() map[string]string { "HEADSCALE_DNS_NAMESERVERS_GLOBAL": "127.0.0.11 1.1.1.1", "HEADSCALE_PRIVATE_KEY_PATH": "/tmp/private.key", "HEADSCALE_NOISE_PRIVATE_KEY_PATH": "/tmp/noise_private.key", - "HEADSCALE_LISTEN_ADDR": "0.0.0.0:8080", "HEADSCALE_METRICS_LISTEN_ADDR": "0.0.0.0:9090", - "HEADSCALE_SERVER_URL": "http://headscale:8080", "HEADSCALE_DERP_URLS": "https://controlplane.tailscale.com/derpmap/default", "HEADSCALE_DERP_AUTO_UPDATE_ENABLED": "false", "HEADSCALE_DERP_UPDATE_FREQUENCY": "1m", diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 883fc8bc02..e38abd1ce3 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -7,9 +7,7 @@ import ( "fmt" "io" "log" - "net" "net/http" - "net/url" "os" "path" "strconv" @@ -166,17 +164,6 @@ func WithHostname(hostname string) Option { } } -// WithHostnameAsServerURL sets the Headscale ServerURL based on -// the Hostname. -func WithHostnameAsServerURL() Option { - return func(hsic *HeadscaleInContainer) { - hsic.env["HEADSCALE_SERVER_URL"] = fmt.Sprintf("http://%s", - net.JoinHostPort(hsic.GetHostname(), - fmt.Sprintf("%d", hsic.port)), - ) - } -} - // WithFileInContainer adds a file to the container at the given path. func WithFileInContainer(path string, contents []byte) Option { return func(hsic *HeadscaleInContainer) { @@ -297,16 +284,6 @@ func New( portProto := fmt.Sprintf("%d/tcp", hsic.port) - serverURL, err := url.Parse(hsic.env["HEADSCALE_SERVER_URL"]) - if err != nil { - return nil, err - } - - if len(hsic.tlsCert) != 0 && len(hsic.tlsKey) != 0 { - serverURL.Scheme = "https" - hsic.env["HEADSCALE_SERVER_URL"] = serverURL.String() - } - headscaleBuildOptions := &dockertest.BuildOptions{ Dockerfile: IntegrationTestDockerFileName, ContextDir: dockerContextPath, @@ -352,6 +329,12 @@ func New( hsic.env["HEADSCALE_TLS_CERT_PATH"] = tlsCertPath hsic.env["HEADSCALE_TLS_KEY_PATH"] = tlsKeyPath } + + // Server URL and Listen Addr should not be overridable outside of + // the configuration passed to docker. + hsic.env["HEADSCALE_SERVER_URL"] = hsic.GetEndpoint() + hsic.env["HEADSCALE_LISTEN_ADDR"] = fmt.Sprintf("0.0.0.0:%d", hsic.port) + for key, value := range hsic.env { env = append(env, fmt.Sprintf("%s=%s", key, value)) } @@ -649,7 +632,7 @@ func (t *HeadscaleInContainer) GetHealthEndpoint() string { // GetEndpoint returns the Headscale endpoint for the HeadscaleInContainer. func (t *HeadscaleInContainer) GetEndpoint() string { hostEndpoint := fmt.Sprintf("%s:%d", - t.GetIP(), + t.GetHostname(), t.port) if t.hasTLS() { diff --git a/integration/scenario.go b/integration/scenario.go index 987b8dbeb9..e45446a719 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -347,6 +347,51 @@ func (s *Scenario) CreateUser(user string) error { /// Client related stuff +func (s *Scenario) CreateTailscaleNode( + version string, + opts ...tsic.Option, +) (TailscaleClient, error) { + headscale, err := s.Headscale() + if err != nil { + return nil, fmt.Errorf("failed to create tailscale node (version: %s): %w", version, err) + } + + cert := headscale.GetCert() + hostname := headscale.GetHostname() + + s.mu.Lock() + defer s.mu.Unlock() + opts = append(opts, + tsic.WithCACert(cert), + tsic.WithHeadscaleName(hostname), + ) + + tsClient, err := tsic.New( + s.pool, + version, + s.network, + opts..., + ) + if err != nil { + return nil, fmt.Errorf( + "failed to create tailscale (%s) node: %w", + tsClient.Hostname(), + err, + ) + } + + err = tsClient.WaitForNeedsLogin() + if err != nil { + return nil, fmt.Errorf( + "failed to wait for tailscaled (%s) to need login: %w", + tsClient.Hostname(), + err, + ) + } + + return tsClient, nil +} + // CreateTailscaleNodesInUser creates and adds a new TailscaleClient to a // User in the Scenario. func (s *Scenario) CreateTailscaleNodesInUser( diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index e63a7b6ecf..c2cb8515b0 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -466,7 +466,7 @@ func (t *TailscaleInContainer) Login( // This login mechanism uses web + command line flow for authentication. func (t *TailscaleInContainer) LoginWithURL( loginServer string, -) (*url.URL, error) { +) (loginURL *url.URL, err error) { command := []string{ "tailscale", "up", @@ -475,20 +475,27 @@ func (t *TailscaleInContainer) LoginWithURL( "--accept-routes=false", } - _, stderr, err := t.Execute(command) + stdout, stderr, err := t.Execute(command) if errors.Is(err, errTailscaleNotLoggedIn) { return nil, errTailscaleCannotUpWithoutAuthkey } - urlStr := strings.ReplaceAll(stderr, "\nTo authenticate, visit:\n\n\t", "") + defer func() { + if err != nil { + log.Printf("join command: %q", strings.Join(command, " ")) + } + }() + + urlStr := strings.ReplaceAll(stdout+stderr, "\nTo authenticate, visit:\n\n\t", "") urlStr = strings.TrimSpace(urlStr) + if urlStr == "" { + return nil, fmt.Errorf("failed to get login URL: stdout: %s, stderr: %s", stdout, stderr) + } + // parse URL - loginURL, err := url.Parse(urlStr) + loginURL, err = url.Parse(urlStr) if err != nil { - log.Printf("Could not parse login URL: %s", err) - log.Printf("Original join command result: %s", stderr) - return nil, err } @@ -497,12 +504,17 @@ func (t *TailscaleInContainer) LoginWithURL( // Logout runs the logout routine on the given Tailscale instance. func (t *TailscaleInContainer) Logout() error { - _, _, err := t.Execute([]string{"tailscale", "logout"}) + stdout, stderr, err := t.Execute([]string{"tailscale", "logout"}) if err != nil { return err } - return nil + stdout, stderr, _ = t.Execute([]string{"tailscale", "status"}) + if !strings.Contains(stdout+stderr, "Logged out.") { + return fmt.Errorf("failed to logout, stdout: %s, stderr: %s", stdout, stderr) + } + + return t.waitForBackendState("NeedsLogin") } // Helper that runs `tailscale up` with no arguments. @@ -826,28 +838,16 @@ func (t *TailscaleInContainer) FailingPeersAsString() (string, bool, error) { // WaitForNeedsLogin blocks until the Tailscale (tailscaled) instance has // started and needs to be logged into. func (t *TailscaleInContainer) WaitForNeedsLogin() error { - return t.pool.Retry(func() error { - status, err := t.Status() - if err != nil { - return errTailscaleStatus(t.hostname, err) - } - - // ipnstate.Status.CurrentTailnet was added in Tailscale 1.22.0 - // https://github.com/tailscale/tailscale/pull/3865 - // - // Before that, we can check the BackendState to see if the - // tailscaled daemon is connected to the control system. - if status.BackendState == "NeedsLogin" { - return nil - } - - return errTailscaledNotReadyForLogin - }) + return t.waitForBackendState("NeedsLogin") } // WaitForRunning blocks until the Tailscale (tailscaled) instance is logged in // and ready to be used. func (t *TailscaleInContainer) WaitForRunning() error { + return t.waitForBackendState("Running") +} + +func (t *TailscaleInContainer) waitForBackendState(state string) error { return t.pool.Retry(func() error { status, err := t.Status() if err != nil { @@ -859,7 +859,7 @@ func (t *TailscaleInContainer) WaitForRunning() error { // // Before that, we can check the BackendState to see if the // tailscaled daemon is connected to the control system. - if status.BackendState == "Running" { + if status.BackendState == state { return nil }