diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index a51f784d20..cbdf094113 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -627,6 +627,7 @@ func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string, } c := &http.Cookie{ + Path: "/", Name: name, Value: val, MaxAge: int(time.Hour.Seconds()), diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index e8b4999189..e74eae56ab 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -10,6 +10,8 @@ import ( "log" "net" "net/http" + "net/http/cookiejar" + "net/http/httptest" "net/netip" "sort" "strconv" @@ -747,6 +749,24 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc }, nil } +type LoggingRoundTripper struct{} + +func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + noTls := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint + } + resp, err := noTls.RoundTrip(req) + if err != nil { + return nil, err + } + + log.Printf("---") + log.Printf("method: %s | url: %s", resp.Request.Method, resp.Request.URL.String()) + log.Printf("status: %d | cookies: %+v", resp.StatusCode, resp.Cookies()) + + return resp, nil +} + func (s *AuthOIDCScenario) runTailscaleUp( userStr, loginServer string, ) error { @@ -758,35 +778,39 @@ func (s *AuthOIDCScenario) runTailscaleUp( log.Printf("running tailscale up for user %s", userStr) if user, ok := s.users[userStr]; ok { for _, client := range user.Clients { - c := client + tsc := client user.joinWaitGroup.Go(func() error { - loginURL, err := c.LoginWithURL(loginServer) + loginURL, err := tsc.LoginWithURL(loginServer) if err != nil { - log.Printf("%s failed to run tailscale up: %s", c.Hostname(), err) + log.Printf("%s failed to run tailscale up: %s", tsc.Hostname(), err) } - loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetIP()) + loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetHostname()) loginURL.Scheme = "http" if len(headscale.GetCert()) > 0 { loginURL.Scheme = "https" } - insecureTransport := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint + httptest.NewRecorder() + hc := &http.Client{ + Transport: LoggingRoundTripper{}, + } + hc.Jar, err = cookiejar.New(nil) + if err != nil { + log.Printf("failed to create cookie jar: %s", err) } - log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String()) + log.Printf("%s login url: %s\n", tsc.Hostname(), loginURL.String()) - log.Printf("%s logging in with url", c.Hostname()) - httpClient := &http.Client{Transport: insecureTransport} + log.Printf("%s logging in with url", tsc.Hostname()) ctx := context.Background() req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) - resp, err := httpClient.Do(req) + resp, err := hc.Do(req) if err != nil { log.Printf( "%s failed to login using url %s: %s", - c.Hostname(), + tsc.Hostname(), loginURL, err, ) @@ -794,8 +818,10 @@ func (s *AuthOIDCScenario) runTailscaleUp( return err } + log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL)) + if resp.StatusCode != http.StatusOK { - log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status) + log.Printf("%s response code of oidc login request was %s", tsc.Hostname(), resp.Status) body, _ := io.ReadAll(resp.Body) log.Printf("body: %s", body) @@ -806,12 +832,12 @@ func (s *AuthOIDCScenario) runTailscaleUp( _, err = io.ReadAll(resp.Body) if err != nil { - log.Printf("%s failed to read response body: %s", c.Hostname(), err) + log.Printf("%s failed to read response body: %s", tsc.Hostname(), err) return err } - log.Printf("Finished request for %s to join tailnet", c.Hostname()) + log.Printf("Finished request for %s to join tailnet", tsc.Hostname()) return nil })