diff --git a/internal/remotewrite/logchannel.go b/internal/remotewrite/logchannel.go index 1702a8a4..4d1d04aa 100644 --- a/internal/remotewrite/logchannel.go +++ b/internal/remotewrite/logchannel.go @@ -17,8 +17,6 @@ const ( var ( logMaxCount = int64(10) logInterval = 600 * time.Second - - LogChannels = []chan logMessage{} ) type logMessage struct { @@ -26,6 +24,21 @@ type logMessage struct { keyvals []interface{} } +type logManager struct { + logChannels map[string]chan logMessage +} + +func (lm *logManager) log(forEndpoint, key string, keyvals ...interface{}) { + stream, ok := lm.logChannels[forEndpoint] + if !ok || stream == nil { + return + } + stream <- logMessage{ + messageKey: key, + keyvals: keyvals, + } +} + type logCounter struct { // key for one log event logKey string @@ -46,7 +59,31 @@ func revertCounter(counter *logCounter) { } } -func InitChannels(logger log.Logger, size int) { +// newLogManager creates a new logManager for a list of endpoints +// and calls the custom process function if provided or defaultProcessFunction if nil +// process should start a go routine that reads from the logChannels and logs the messages +func newLogManager(logger log.Logger, forEndpoints []Endpoint, process func(logger log.Logger, messages map[string]chan logMessage)) *logManager { + logChannels := make(map[string]chan logMessage) + for i, ep := range forEndpoints { + if ep.Name == "" { + ep.Name = fmt.Sprintf("endpoint_%d", i) + } + logChannels[ep.Name] = make(chan logMessage) + + } + logChannels[thanosEndpointName] = make(chan logMessage) + + if process == nil { + process = defaultProcessFunction + } + process(logger, logChannels) + + return &logManager{ + logChannels: logChannels, + } +} + +func defaultProcessFunction(logger log.Logger, messages map[string]chan logMessage) { if os.Getenv("LOG_MAX_COUNT") != "" { v, err := strconv.ParseInt(os.Getenv("LOG_MAX_COUNT"), 10, 0) if err != nil { @@ -59,18 +96,17 @@ func InitChannels(logger log.Logger, size int) { logInterval = v } } - for i := 0; i < size; i++ { - LogChannels = append(LogChannels, make(chan logMessage)) - } - for i := 0; i < size; i++ { - j := i - counter := &logCounter{ - LogTimestamps: []time.Time{}, - } + + for _, v := range messages { go func() { + counter := &logCounter{ + LogTimestamps: []time.Time{}, + } + messageStream := v + for { select { - case message := <-LogChannels[j]: + case message := <-messageStream: if message.messageKey == successWrite { revertCounter(counter) } else { diff --git a/internal/remotewrite/proxy.go b/internal/remotewrite/proxy.go index 9fa7945a..5dfdff1b 100644 --- a/internal/remotewrite/proxy.go +++ b/internal/remotewrite/proxy.go @@ -2,7 +2,8 @@ package remotewrite import ( "bytes" - "io/ioutil" + "context" + "io" "net" "net/http" "net/url" @@ -18,13 +19,12 @@ import ( ) const ( - THANOS_ENDPOINT_NAME = "thanos-receiver" + thanosEndpointName = "thanos-receiver" ) type Endpoint struct { - Name string `yaml:"name"` - URL string `yaml:"url"` - // +optional + Name string `yaml:"name"` + URL string `yaml:"url"` ClientConfig *promconfig.HTTPClientConfig `yaml:"http_client_config,omitempty"` } @@ -42,107 +42,79 @@ var ( }, []string{"code", "name"}) ) -func remoteWrite(write *url.URL, endpoints []Endpoint, logger log.Logger) http.Handler { +func (rd *RequestDuplicator) remoteWrite(write *url.URL, endpoints []Endpoint, logger log.Logger, logManager *logManager) http.Handler { + var clientMap = map[string]*http.Client{} + clientMap = make(map[string]*http.Client) + defaultHTTPClient := defaultClient() + writePath := write.Path + writeHost := write.Host + if write.Scheme == "" { + write.Scheme = "http" + } + writeScheme := write.Scheme + + for _, ep := range endpoints { + var client = defaultHTTPClient + if ep.ClientConfig != nil { + epClient, err := promconfig.NewClientFromConfig(*ep.ClientConfig, ep.Name, + promconfig.WithDialContextFunc((&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext)) + if err == nil { + client = epClient + } + } + clientMap[ep.Name] = client + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requests.With(prometheus.Labels{"method": r.Method}).Inc() + rlogger := log.With(logger, "request", middleware.GetReqID(r.Context())) - body, _ := ioutil.ReadAll(r.Body) - _ = r.Body.Close() - r.Body = ioutil.NopCloser(bytes.NewBuffer(body)) - - if write != nil { - remotewriteUrl := url.URL{} - remotewriteUrl.Path = path.Join(write.Path, r.URL.Path) - remotewriteUrl.Host = write.Host - remotewriteUrl.Scheme = write.Scheme - endpoints[len(endpoints)-1].URL = remotewriteUrl.String() + body, err := io.ReadAll(r.Body) + if err != nil { + level.Error(rlogger).Log("msg", "failed to read request body", "err", err) + w.WriteHeader(http.StatusInternalServerError) + return } + headers := r.Header.Clone() - rlogger := log.With(logger, "request", middleware.GetReqID(r.Context())) - for i, endpoint := range endpoints { - var client *http.Client - var err error - if endpoint.ClientConfig == nil { - client = &http.Client{ - Transport: &http.Transport{ - DisableKeepAlives: true, - IdleConnTimeout: 30 * time.Second, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - }, - } - } else { - client, err = promconfig.NewClientFromConfig(*endpoint.ClientConfig, endpoint.Name, - promconfig.WithDialContextFunc((&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext)) - if err != nil { - //level.Error(rlogger).Log("msg", "failed to create a new HTTP client", "err", err) - LogChannels[i] <- logMessage{ - messageKey: "failed to create a new HTTP client", - keyvals: []interface{}{ - "msg", "failed to create a new HTTP client", "err", err, - }} - } - } + rwReq, err := rebuildProxyRequest(r, body, writePath, writeHost, writeScheme) + if err != nil { + level.Error(rlogger).Log("msg", "failed to rebuild the request", "err", err) + w.WriteHeader(http.StatusInternalServerError) + return + } - req, err := http.NewRequest(http.MethodPost, endpoint.URL, bytes.NewReader(body)) - req.Header = r.Header - if err != nil { - //level.Error(rlogger).Log("msg", "Failed to create the forward request", "err", err, "url", endpoint.URL) - LogChannels[i] <- logMessage{ - messageKey: "failed to create the forward request", - keyvals: []interface{}{ - "msg", "failed to create the forward request", "err", err, - }} - } else { - ep := endpoint - j := i + go func() { + for _, endpoint := range endpoints { go func() { - resp, err := client.Do(req) + req, err := mirrorRequestFromBody(body, headers, endpoint.URL) if err != nil { - remotewriteRequests.With(prometheus.Labels{"code": "", "name": ep.Name}).Inc() - //level.Error(rlogger).Log("msg", "Failed to send request to the server", "err", err) - LogChannels[j] <- logMessage{ - messageKey: "failed to send request to the server", - keyvals: []interface{}{ - "msg", "failed to send request to the server", "err", err, - }} - } else { - defer resp.Body.Close() - remotewriteRequests.With(prometheus.Labels{"code": strconv.Itoa(resp.StatusCode), "name": ep.Name}).Inc() - if resp.StatusCode >= 300 || resp.StatusCode < 200 { - responseBody, err := ioutil.ReadAll(resp.Body) - if err != nil { - //level.Error(rlogger).Log("msg", "Failed to read response of the forward request", "err", err, "return code", resp.Status, "url", ep.URL) - LogChannels[j] <- logMessage{ - messageKey: "failed to forward metrics" + resp.Status, - keyvals: []interface{}{ - "msg", "failed to forward metrics", "return code", resp.Status, "url", ep.URL, - }} - } else { - LogChannels[j] <- logMessage{ - messageKey: "Failed to forward metrics" + resp.Status, - keyvals: []interface{}{ - "msg", "failed to forward metrics", "return code", resp.Status, "response", string(responseBody), "url", ep.URL}} - } - } else { - level.Debug(rlogger).Log("msg", successWrite, "url", ep.URL) - LogChannels[j] <- logMessage{ - messageKey: successWrite, - } - } + level.Error(rlogger).Log("msg", "failed to build the remote write request", "url", endpoint.URL, "err", err) + return } + client := getClientForEndpoint(endpoint.Name, clientMap) + _ = rd.doRemoteWriteRequest(client, req, endpoint.Name, logger) }() + } + }() + + // handle the main remote write endpoint request synchronously + if write != nil { + statusCode := rd.doRemoteWriteRequest(defaultHTTPClient, rwReq, thanosEndpointName, logger) + w.WriteHeader(statusCode) } }) } -func Proxy(write *url.URL, endpoints []Endpoint, logger log.Logger, r *prometheus.Registry) http.Handler { +type RequestDuplicator struct { + logManager *logManager +} + +func (rd *RequestDuplicator) Proxy(write *url.URL, endpoints []Endpoint, logger log.Logger, r *prometheus.Registry) http.Handler { r.MustRegister(requests) r.MustRegister(remotewriteRequests) @@ -151,14 +123,93 @@ func Proxy(write *url.URL, endpoints []Endpoint, logger log.Logger, r *prometheu endpoints = []Endpoint{} } - if write != nil { - endpoints = append(endpoints, Endpoint{ - URL: write.String(), - Name: THANOS_ENDPOINT_NAME, - }) + if rd.logManager == nil { + rd.logManager = newLogManager(logger, endpoints, nil) } - InitChannels(logger, len(endpoints)) + return rd.remoteWrite(write, endpoints, logger, rd.logManager) +} + +func rebuildProxyRequest(r *http.Request, body []byte, reqPath, host, scheme string) (*http.Request, error) { + remotewriteUrl := url.URL{} + remotewriteUrl.Path = path.Join(reqPath, r.URL.Path) + remotewriteUrl.Host = host + remotewriteUrl.Scheme = scheme + + req, err := http.NewRequest(r.Method, remotewriteUrl.String(), bytes.NewReader(body)) + if err != nil { + return nil, err - return remoteWrite(write, endpoints, logger) + } + req.Header = r.Header.Clone() + req.WithContext(r.Context()) + return req, nil +} + +// mirrorRequestFromBody build a remote write request for the upstream remote write endpoint +// we enforce a 5s timeout here to avoid having unbounded goroutines due to slow backends +func mirrorRequestFromBody(body []byte, headers http.Header, endpoint string) (*http.Request, error) { + req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header = headers + ctx, _ := context.WithTimeout(context.Background(), 5*time.Second) + req = req.WithContext(ctx) + return req, nil +} + +func defaultClient() *http.Client { + return &http.Client{ + Transport: &http.Transport{ + DisableKeepAlives: true, + IdleConnTimeout: 30 * time.Second, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + }, + } +} + +func getClientForEndpoint(name string, fromPool map[string]*http.Client) *http.Client { + c, ok := fromPool[name] + if !ok { + return defaultClient() + } + return c +} + +func (rd *RequestDuplicator) doRemoteWriteRequest( + client *http.Client, + req *http.Request, + epName string, + logger log.Logger, +) int { + resp, err := client.Do(req) + if err != nil { + remotewriteRequests.With(prometheus.Labels{"code": "", "name": epName}).Inc() + rd.logManager.log(epName, "failed to send request to the server", "msg", "failed to send request to the server", "err", err) + return http.StatusInternalServerError + } + + remotewriteRequests.With(prometheus.Labels{"code": strconv.Itoa(resp.StatusCode), "name": epName}).Inc() + if resp.StatusCode >= 300 || resp.StatusCode < 200 { + responseBody, err := io.ReadAll(resp.Body) + keyVals := []interface{}{ + "msg", "failed to forward metrics", + "endpoint", epName, + "response code", resp.Status, + "response", string(responseBody), + "url", req.URL.String(), + } + + if err != nil { + keyVals = append(keyVals, "err", err) + } + rd.logManager.log(epName, "failed to forward metrics "+resp.Status, keyVals...) + return resp.StatusCode + } + level.Debug(logger).Log("msg", "Successfully forwarded metrics", "url", req.URL.String()) + return resp.StatusCode } diff --git a/internal/remotewrite/proxy_test.go b/internal/remotewrite/proxy_test.go index 82926684..c984eff3 100644 --- a/internal/remotewrite/proxy_test.go +++ b/internal/remotewrite/proxy_test.go @@ -2,61 +2,114 @@ package remotewrite import ( "bytes" + "fmt" + "github.com/go-kit/kit/log" + "github.com/observatorium/observatorium/internal" + "github.com/prometheus/client_golang/prometheus" "net/http" "net/http/httptest" "net/url" + "sync" "testing" - - "github.com/observatorium/observatorium/internal" - "github.com/prometheus/client_golang/prometheus" ) func TestProxy(t *testing.T) { logger := internal.NewLogger("debug", "logfmt", "test") - - // remoteWriteMain is the primary remote write endpoint that always returns 403 Forbidden. - remoteWriteMain := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - logger.Log("msg", "remote write main") - w.WriteHeader(http.StatusForbidden) - })) - defer remoteWriteMain.Close() - - // remoteWriteMirror is a secondary remote write endpoint that always returns 403 Forbidden. - remoteWriteMirror := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - logger.Log("msg", "remote write mirror") - w.WriteHeader(http.StatusForbidden) - })) - defer remoteWriteMirror.Close() - reg := prometheus.NewRegistry() + client := http.DefaultClient - writeURL, err := url.Parse(remoteWriteMain.URL) - if err != nil { - t.Fatal(err) + type parsedLog struct { + Message string + Endpoint string + Code int } - endpoints := []Endpoint{ + testCases := []struct { + name string + mainReturnCode int + mirrorReturnCode int + expectLogLength int + expectLogs map[string]parsedLog + }{ { - Name: "mirror", - URL: remoteWriteMirror.URL, + name: "test", + mainReturnCode: http.StatusForbidden, + mirrorReturnCode: http.StatusForbidden, + expectLogLength: 2, }, } - gateway := httptest.NewServer(Proxy(writeURL, endpoints, logger, reg)) - defer gateway.Close() + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + remoteWriteMain := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + logger.Log("msg", "remote write main") + w.WriteHeader(tc.mainReturnCode) + })) + defer remoteWriteMain.Close() - req, err := http.NewRequest(http.MethodPost, gateway.URL, bytes.NewBufferString("some metrics here :)")) - if err != nil { - t.Fatal(err) - } + // remoteWriteMirror is a secondary remote write endpoint that always returns 403 Forbidden. + remoteWriteMirror := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + logger.Log("msg", "remote write mirror") + w.WriteHeader(tc.mirrorReturnCode) + })) + defer remoteWriteMirror.Close() - client := http.DefaultClient + writeURL, err := url.Parse(remoteWriteMain.URL) + if err != nil { + t.Fatal(err) + } - res, err := client.Do(req) - if err != nil { - t.Fatal(err) - } - _ = res.Body.Close() - if res.StatusCode != http.StatusForbidden { - t.Fatalf("expected status code %d, got %d", http.StatusForbidden, res.StatusCode) + endpoints := []Endpoint{ + { + Name: "mirror", + URL: remoteWriteMirror.URL, + }, + } + + var expectKeyVals []logMessage + var wg sync.WaitGroup + lm := newLogManager(logger, endpoints, func(logger log.Logger, messages map[string]chan logMessage) { + for _, v := range messages { + wg.Add(1) + messageStream := v + go func() { + for { + select { + case message := <-messageStream: + expectKeyVals = append(expectKeyVals, message) + wg.Done() + } + } + }() + } + }) + rd := &RequestDuplicator{ + logManager: lm, + } + gateway := httptest.NewServer(rd.Proxy(writeURL, endpoints, logger, reg)) + defer gateway.Close() + + req, err := http.NewRequest(http.MethodPost, gateway.URL, bytes.NewBufferString("some metrics here :)")) + if err != nil { + t.Fatal(err) + } + + res, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + _ = res.Body.Close() + + if res.StatusCode != tc.mainReturnCode { + t.Fatalf("expected status code %d, got %d", tc.mainReturnCode, res.StatusCode) + } + + wg.Wait() + if expectKeyVals == nil || len(expectKeyVals) != 2 { + t.Fatalf("expected 2 log messages, got %d", len(expectKeyVals)) + } + for _, log := range expectKeyVals { + fmt.Println(log) + } + }) } }