Skip to content

Commit

Permalink
Merge pull request #1024 from stgraber/cli
Browse files Browse the repository at this point in the history
Properly handle request retries on OIDC
  • Loading branch information
stgraber authored Jul 20, 2024
2 parents 0b0aa95 + 173e64c commit 65ac8ef
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 6 deletions.
16 changes: 15 additions & 1 deletion client/incus.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,21 @@ func (r *ProtocolIncus) DoHTTP(req *http.Request) (*http.Response, error) {
return r.oidcClient.do(req)
}

return r.http.Do(req)
resp, err := r.http.Do(req)
if resp != nil && resp.StatusCode == http.StatusUseProxy && req.GetBody != nil {
// Reset the request body.
body, err := req.GetBody()
if err != nil {
return nil, err
}

req.Body = body

// Retry the request.
return r.http.Do(req)
}

return resp, err
}

// DoWebsocket performs a websocket connection, using OIDC authentication if set.
Expand Down
2 changes: 2 additions & 0 deletions client/incus_images.go
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,8 @@ func (r *ProtocolIncus) CreateImage(image api.ImagesPost, args *ImageCreateArgs)
return nil, err
}

req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(body), nil }

// Setup the headers
req.Header.Set("Content-Type", contentType)
if image.Public {
Expand Down
4 changes: 4 additions & 0 deletions client/incus_instances.go
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,7 @@ func (r *ProtocolIncus) CreateInstanceFromBackup(args InstanceBackupArgs) (Opera
return nil, err
}

req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(args.BackupFile), nil }
req.Header.Set("Content-Type", "application/octet-stream")

if args.PoolName != "" {
Expand Down Expand Up @@ -1472,6 +1473,8 @@ func (r *ProtocolIncus) CreateInstanceFile(instanceName string, filePath string,
return err
}

req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(args.Content), nil }

// Set the various headers
if args.UID > -1 {
req.Header.Set("X-Incus-uid", fmt.Sprintf("%d", args.UID))
Expand Down Expand Up @@ -2403,6 +2406,7 @@ func (r *ProtocolIncus) CreateInstanceTemplateFile(instanceName string, template
return err
}

req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(content), nil }
req.Header.Set("Content-Type", "application/octet-stream")

// Send the request
Expand Down
9 changes: 9 additions & 0 deletions client/incus_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ import (
"golang.org/x/oauth2"
)

// ErrOIDCExpired is returned when the token is expired and we can't retry the request ourselves.
var ErrOIDCExpired = fmt.Errorf("OIDC token expired, please re-try the request")

// setupOIDCClient initializes the OIDC (OpenID Connect) client with given tokens if it hasn't been set up already.
// It also assigns the protocol's http client to the oidcClient's httpClient.
func (r *ProtocolIncus) setupOIDCClient(token *oidc.Tokens[*oidc.IDTokenClaims]) {
Expand Down Expand Up @@ -119,6 +122,7 @@ func (o *oidcClient) do(req *http.Request) (*http.Response, error) {
return resp, nil
}

// Refresh the token.
err = o.refresh(issuer, clientID)
if err != nil {
err = o.authenticate(issuer, clientID, audience)
Expand All @@ -127,6 +131,11 @@ func (o *oidcClient) do(req *http.Request) (*http.Response, error) {
}
}

// If not dealing with something we can retry, return a clear error.
if req.Method != "GET" && req.GetBody == nil {
return resp, ErrOIDCExpired
}

// Set the new access token in the header.
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", o.tokens.AccessToken))

Expand Down
1 change: 1 addition & 0 deletions client/incus_storage_buckets.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ func (r *ProtocolIncus) CreateStoragePoolBucketFromBackup(pool string, args Stor
return nil, err
}

req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(args.BackupFile), nil }
req.Header.Set("Content-Type", "application/octet-stream")

if args.Name != "" {
Expand Down
10 changes: 6 additions & 4 deletions client/incus_storage_volumes.go
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,10 @@ func (r *ProtocolIncus) CreateStoragePoolVolumeFromISO(pool string, args Storage
return nil, err
}

if args.Name == "" {
return nil, fmt.Errorf("Missing volume name")
}

path := fmt.Sprintf("/storage-pools/%s/volumes/custom", url.PathEscape(pool))

// Prepare the HTTP request.
Expand All @@ -996,10 +1000,7 @@ func (r *ProtocolIncus) CreateStoragePoolVolumeFromISO(pool string, args Storage
return nil, err
}

if args.Name == "" {
return nil, fmt.Errorf("Missing volume name")
}

req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(args.BackupFile), nil }
req.Header.Set("Content-Type", "application/octet-stream")
req.Header.Set("X-Incus-name", args.Name)
req.Header.Set("X-Incus-type", "iso")
Expand Down Expand Up @@ -1057,6 +1058,7 @@ func (r *ProtocolIncus) CreateStoragePoolVolumeFromBackup(pool string, args Stor
return nil, err
}

req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(args.BackupFile), nil }
req.Header.Set("Content-Type", "application/octet-stream")

if args.Name != "" {
Expand Down
11 changes: 10 additions & 1 deletion cmd/incus/remote_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,16 @@ func (t remoteProxyTransport) RoundTrip(r *http.Request) (*http.Response, error)
r.URL.Host = t.baseURL.Host
r.RequestURI = ""

return t.s.DoHTTP(r)
resp, err := t.s.DoHTTP(r)
if err == incus.ErrOIDCExpired {
// Override the response so the client knows to retry the request.
resp.StatusCode = http.StatusUseProxy
resp.Status = "Retry the request for OIDC refresh"

return resp, nil
}

return resp, err
}

type remoteProxyHandler struct {
Expand Down

0 comments on commit 65ac8ef

Please sign in to comment.