Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: proxy and add tests #2843

Merged
merged 2 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/cmd/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ var agentCmd = &cobra.Command{
Short: lang.CmdInternalAgentShort,
Long: lang.CmdInternalAgentLong,
RunE: func(cmd *cobra.Command, _ []string) error {
return agent.StartWebhook(cmd.Context())
cluster, err := cluster.NewCluster()
if err != nil {
return err
}
return agent.StartWebhook(cmd.Context(), cluster)
},
}

Expand All @@ -52,7 +56,11 @@ var httpProxyCmd = &cobra.Command{
Short: lang.CmdInternalProxyShort,
Long: lang.CmdInternalProxyLong,
RunE: func(cmd *cobra.Command, _ []string) error {
return agent.StartHTTPProxy(cmd.Context())
cluster, err := cluster.NewCluster()
phillebaba marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return err
}
return agent.StartHTTPProxy(cmd.Context(), cluster)
},
}

Expand Down
1 change: 0 additions & 1 deletion src/config/lang/english.go
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,6 @@ const (
AgentErrMarshallJSONPatch = "unable to marshall the json patch"
AgentErrMarshalResponse = "unable to marshal the response"
AgentErrNilReq = "malformed admission review: request is nil"
AgentErrUnableTransform = "unable to transform the provided request; see zarf http proxy logs for more details"
)

// Package create
Expand Down
97 changes: 28 additions & 69 deletions src/internal/agent/http/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package http

import (
"context"
"crypto/tls"
"fmt"
"io"
Expand All @@ -14,51 +13,43 @@ import (
"net/url"
"strings"

"github.com/zarf-dev/zarf/src/config/lang"
"github.com/zarf-dev/zarf/src/pkg/cluster"
"github.com/zarf-dev/zarf/src/pkg/message"
"github.com/zarf-dev/zarf/src/pkg/transform"
"github.com/zarf-dev/zarf/src/types"
)

// ProxyHandler constructs a new httputil.ReverseProxy and returns an http handler.
func ProxyHandler() http.HandlerFunc {
func ProxyHandler(cluster *cluster.Cluster) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
err := proxyRequestTransform(r)
state, err := cluster.LoadZarfState(r.Context())
if err != nil {
message.Debugf("%#v", err)
w.WriteHeader(http.StatusInternalServerError)
//nolint: errcheck // ignore
w.Write([]byte(lang.AgentErrUnableTransform))
w.Write([]byte("unable to load Zarf state, see the Zarf HTTP proxy logs for more details"))
return
}
err = proxyRequestTransform(r, state)
if err != nil {
message.Debugf("%#v", err)
w.WriteHeader(http.StatusInternalServerError)
//nolint: errcheck // ignore
w.Write([]byte("unable to transform the provided request, see the Zarf HTTP proxy logs for more details"))
return
}

proxy := &httputil.ReverseProxy{Director: func(_ *http.Request) {}, ModifyResponse: proxyResponseTransform}
proxy.ServeHTTP(w, r)
}
}

func proxyRequestTransform(r *http.Request) error {
message.Debugf("Before Req %#v", r)
message.Debugf("Before Req URL %#v", r.URL)

func proxyRequestTransform(r *http.Request, state *types.ZarfState) error {
// We add this so that we can use it to rewrite urls in the response if needed
r.Header.Add("X-Forwarded-Host", r.Host)

// We remove this so that go will encode and decode on our behalf (see https://pkg.go.dev/net/http#Transport DisableCompression)
r.Header.Del("Accept-Encoding")

c, err := cluster.NewCluster()
if err != nil {
return err
}
ctx := context.Background()
state, err := c.LoadZarfState(ctx)
if err != nil {
return err
}

var targetURL *url.URL

// Setup authentication for each type of service based on User Agent
switch {
case isGitUserAgent(r.UserAgent()):
Expand All @@ -70,6 +61,8 @@ func proxyRequestTransform(r *http.Request) error {
}

// Transform the URL; if we see the NoTransform prefix, strip it; otherwise, transform the URL based on User Agent
var err error
var targetURL *url.URL
if strings.HasPrefix(r.URL.Path, transform.NoTransform) {
switch {
case isGitUserAgent(r.UserAgent()):
Expand All @@ -89,7 +82,6 @@ func proxyRequestTransform(r *http.Request) error {
targetURL, err = transform.GenTransformURL(state.ArtifactServer.Address, getTLSScheme(r.TLS)+r.Host+r.URL.String())
}
}

if err != nil {
return err
}
Expand All @@ -98,19 +90,12 @@ func proxyRequestTransform(r *http.Request) error {
r.URL = targetURL
r.RequestURI = getRequestURI(targetURL.Path, targetURL.RawQuery, targetURL.Fragment)

message.Debugf("After Req %#v", r)
message.Debugf("After Req URL%#v", r.URL)

return nil
}

func proxyResponseTransform(resp *http.Response) error {
message.Debugf("Before Resp %#v", resp)

// Handle redirection codes (3xx) by adding a marker to let Zarf know this has been redirected
if resp.StatusCode/100 == 3 {
message.Debugf("Before Resp Location %#v", resp.Header.Get("Location"))

locationURL, err := url.Parse(resp.Header.Get("Location"))
if err != nil {
return err
Expand All @@ -119,72 +104,46 @@ func proxyResponseTransform(resp *http.Response) error {
locationURL.Host = resp.Request.Header.Get("X-Forwarded-Host")

resp.Header.Set("Location", locationURL.String())

message.Debugf("After Resp Location %#v", resp.Header.Get("Location"))
}

contentType := resp.Header.Get("Content-Type")

// Handle text content returns that may contain links
contentType := resp.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "text") || strings.HasPrefix(contentType, "application/json") || strings.HasPrefix(contentType, "application/xml") {
err := replaceBodyLinks(resp)

forwardedPrefix := fmt.Sprintf("%s%s%s", getTLSScheme(resp.Request.TLS), resp.Request.Header.Get("X-Forwarded-Host"), transform.NoTransform)
targetPrefix := fmt.Sprintf("%s%s", getTLSScheme(resp.TLS), resp.Request.Host)
b, err := io.ReadAll(resp.Body)
if err != nil {
message.Debugf("%#v", err)
return err
}
}

message.Debugf("After Resp %#v", resp)

return nil
}

func replaceBodyLinks(resp *http.Response) error {
message.Debugf("Resp Request: %#v", resp.Request)

// Create the forwarded (online) and target (offline) URL prefixes to replace
forwardedPrefix := fmt.Sprintf("%s%s%s", getTLSScheme(resp.Request.TLS), resp.Request.Header.Get("X-Forwarded-Host"), transform.NoTransform)
targetPrefix := fmt.Sprintf("%s%s", getTLSScheme(resp.TLS), resp.Request.Host)
err = resp.Body.Close()
if err != nil {
return err
}
bodyString := strings.ReplaceAll(string(b), targetPrefix, forwardedPrefix)

b, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
err = resp.Body.Close()
if err != nil {
return err
resp.Body = io.NopCloser(strings.NewReader(bodyString))
resp.ContentLength = int64(len(bodyString))
resp.Header.Set("Content-Length", fmt.Sprint(int64(len(bodyString))))
}
bodyString := strings.ReplaceAll(string(b), targetPrefix, forwardedPrefix)

// Setup the new reader, and correct the content length
resp.Body = io.NopCloser(strings.NewReader(bodyString))
resp.ContentLength = int64(len(bodyString))
resp.Header.Set("Content-Length", fmt.Sprint(int64(len(bodyString))))

return nil
}

func getTLSScheme(tls *tls.ConnectionState) string {
scheme := "https://"

if tls == nil {
scheme = "http://"
}

return scheme
}

func getRequestURI(path, query, fragment string) string {
uri := path

if query != "" {
uri += "?" + query
}

if fragment != "" {
uri += "#" + fragment
}

return uri
}

Expand Down
Loading