Skip to content

Commit

Permalink
Merge pull request #745 from luraproject/propagate_seq_merger_params
Browse files Browse the repository at this point in the history
Global param propagation for sequential merger
  • Loading branch information
kpacha authored Jan 17, 2025
2 parents fff63b0 + fb3d36f commit 9b8a7c3
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 89 deletions.
214 changes: 145 additions & 69 deletions proxy/merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package proxy
import (
"context"
"fmt"
"io"
"net/http"
"regexp"
"strconv"
Expand All @@ -16,7 +17,7 @@ import (
)

// NewMergeDataMiddleware creates proxy middleware for merging responses from several backends
func NewMergeDataMiddleware(logger logging.Logger, endpointConfig *config.EndpointConfig) Middleware {
func NewMergeDataMiddleware(logger logging.Logger, endpointConfig *config.EndpointConfig) Middleware { // skipcq: GO-R1005
totalBackends := len(endpointConfig.Backend)
if totalBackends == 0 {
logger.Fatal("all endpoints must have at least one backend: NewMergeDataMiddleware")
Expand All @@ -27,7 +28,7 @@ func NewMergeDataMiddleware(logger logging.Logger, endpointConfig *config.Endpoi
}
serviceTimeout := time.Duration(85*endpointConfig.Timeout.Nanoseconds()/100) * time.Nanosecond
combiner := getResponseCombiner(endpointConfig.ExtraConfig)
isSequential := shouldRunSequentialMerger(endpointConfig)
isSequential, propagatedParams := sequentialMergerConfig(endpointConfig)

logger.Debug(
fmt.Sprintf(
Expand Down Expand Up @@ -57,24 +58,86 @@ func NewMergeDataMiddleware(logger logging.Logger, endpointConfig *config.Endpoi
return parallelMerge(reqClone, serviceTimeout, combiner, next...)
}

patterns := make([]string, len(endpointConfig.Backend))
sequentialReplacements := make([][]sequentialBackendReplacement, totalBackends)

var rePropagatedParams = regexp.MustCompile(`[Rr]esp(\d+)_?([\w-.]+)?`)
var reUrlPatterns = regexp.MustCompile(`\{\{\.Resp(\d+)_([\w-.]+)\}\}`)
destKeyGenerator := func(i string, t string) string {
key := "Resp" + i
if t != "" {
key += "_" + t
}
return key
}

for i, b := range endpointConfig.Backend {
patterns[i] = b.URLPattern
for _, match := range reUrlPatterns.FindAllStringSubmatch(b.URLPattern, -1) {
if len(match) > 1 {
backendIndex, err := strconv.Atoi(match[1])
if err != nil {
continue
}

sequentialReplacements[i] = append(sequentialReplacements[i], sequentialBackendReplacement{
backendIndex: backendIndex,
destination: destKeyGenerator(match[1], match[2]),
source: strings.Split(match[2], "."),
fullResponse: match[2] == "",
})
}
}

if i > 0 {
for _, p := range propagatedParams {
for _, match := range rePropagatedParams.FindAllStringSubmatch(p, -1) {
if len(match) > 1 {
backendIndex, err := strconv.Atoi(match[1])
if err != nil || backendIndex >= totalBackends {
continue
}

sequentialReplacements[i] = append(sequentialReplacements[i], sequentialBackendReplacement{
backendIndex: backendIndex,
destination: destKeyGenerator(match[1], match[2]),
source: strings.Split(match[2], "."),
fullResponse: match[2] == "",
})
}
}
}
}
}
return sequentialMerge(reqClone, patterns, serviceTimeout, combiner, next...)

return sequentialMerge(reqClone, sequentialReplacements, serviceTimeout, combiner, next...)
}
}

func shouldRunSequentialMerger(cfg *config.EndpointConfig) bool {
type sequentialBackendReplacement struct {
backendIndex int
destination string
source []string
fullResponse bool
}

func sequentialMergerConfig(cfg *config.EndpointConfig) (bool, []string) {
enabled := false
var propagatedParams []string
if v, ok := cfg.ExtraConfig[Namespace]; ok {
if e, ok := v.(map[string]interface{}); ok {
if v, ok := e[isSequentialKey]; ok {
c, ok := v.(bool)
return ok && c
enabled = ok && c
}
if v, ok := e[sequentialPropagateKey]; ok {
if a, ok := v.([]interface{}); ok {
for _, p := range a {
propagatedParams = append(propagatedParams, p.(string))
}
}
}
}
}
return false
return enabled, propagatedParams
}

func hasUnsafeBackends(cfg *config.EndpointConfig) bool {
Expand Down Expand Up @@ -118,75 +181,92 @@ func parallelMerge(reqCloner func(*Request) *Request, timeout time.Duration, rc
}
}

var reMergeKey = regexp.MustCompile(`\{\{\.Resp(\d+)_([\w-\.]+)\}\}`)

func sequentialMerge(reqCloner func(*Request) *Request, patterns []string, timeout time.Duration, rc ResponseCombiner, next ...Proxy) Proxy {
func sequentialMerge(reqCloner func(*Request) *Request, sequentialReplacements [][]sequentialBackendReplacement, timeout time.Duration, rc ResponseCombiner, next ...Proxy) Proxy { // skipcq: GO-R1005
return func(ctx context.Context, request *Request) (*Response, error) {
localCtx, cancel := context.WithTimeout(ctx, timeout)

parts := make([]*Response, len(next))
out := make(chan *Response, 1)
errCh := make(chan error, 1)
sequentialMergeRegistry := map[string]string{}

acc := newIncrementalMergeAccumulator(len(next), rc)
TxLoop:
for i, n := range next {
if i > 0 {
for _, match := range reMergeKey.FindAllStringSubmatch(patterns[i], -1) {
if len(match) > 1 {
rNum, err := strconv.Atoi(match[1])
if err != nil || rNum >= i || parts[rNum] == nil {
continue
}
key := "Resp" + match[1] + "_" + match[2]

var v interface{}
var ok bool

data := parts[rNum].Data
keys := strings.Split(match[2], ".")
if len(keys) > 1 {
for _, k := range keys[:len(keys)-1] {
v, ok = data[k]
if !ok {
break
}
clean, ok := v.(map[string]interface{})
if !ok {
break
}
data = clean
for _, r := range sequentialReplacements[i] {
if r.backendIndex >= i || parts[r.backendIndex] == nil {
continue
}

var v interface{}
var ok bool

data := parts[r.backendIndex].Data
if len(r.source) > 1 {
for _, k := range r.source[:len(r.source)-1] {
v, ok = data[k]
if !ok {
break
}
clean, ok := v.(map[string]interface{})
if !ok {
break
}
data = clean
}
}

v, ok = data[keys[len(keys)-1]]
if !ok {
if found := sequentialMergeRegistry[r.destination]; found != "" {
request.Params[r.destination] = found
continue
}

if r.fullResponse {
if parts[r.backendIndex].Io == nil {
continue
}
switch clean := v.(type) {
case []interface{}:
if len(clean) == 0 {
request.Params[key] = ""
continue
}
var b strings.Builder
for i := 0; i < len(clean)-1; i++ {
fmt.Fprintf(&b, "%v,", clean[i])
}
fmt.Fprintf(&b, "%v", clean[len(clean)-1])
request.Params[key] = b.String()
case string:
request.Params[key] = clean
case int:
request.Params[key] = strconv.Itoa(clean)
case float64:
request.Params[key] = strconv.FormatFloat(clean, 'E', -1, 32)
case bool:
request.Params[key] = strconv.FormatBool(clean)
default:
request.Params[key] = fmt.Sprintf("%v", v)
buf, err := io.ReadAll(parts[r.backendIndex].Io)

if err == nil {
request.Params[r.destination] = string(buf)
sequentialMergeRegistry[r.destination] = string(buf)
}
continue
}

v, ok = data[r.source[len(r.source)-1]]
if !ok {
continue
}

var param string

switch clean := v.(type) {
case []interface{}:
if len(clean) == 0 {
request.Params[r.destination] = ""
break
}
var b strings.Builder
for i := 0; i < len(clean)-1; i++ {
fmt.Fprintf(&b, "%v,", clean[i])
}
fmt.Fprintf(&b, "%v", clean[len(clean)-1])
param = b.String()
case string:
param = clean
case int:
param = strconv.Itoa(clean)
case float64:
param = strconv.FormatFloat(clean, 'E', -1, 32)
case bool:
param = strconv.FormatBool(clean)
default:
param = fmt.Sprintf("%v", v)
}
request.Params[r.destination] = param
sequentialMergeRegistry[r.destination] = param
}
}

Expand Down Expand Up @@ -284,30 +364,25 @@ func requestPart(ctx context.Context, next Proxy, request *Request, out chan<- *
}

func sequentialRequestPart(ctx context.Context, next Proxy, request *Request, out chan<- *Response, failed chan<- error) {
localCtx, cancel := context.WithCancel(ctx)

copyRequest := CloneRequest(request)

in, err := next(localCtx, request)
in, err := next(ctx, request)

*request = *copyRequest

if err != nil {
failed <- err
cancel()
return
}
if in == nil {
failed <- errNullResult
cancel()
return
}
select {
case out <- in:
case <-ctx.Done():
failed <- ctx.Err()
}
cancel()
}

func newMergeError(errs []error) error {
Expand Down Expand Up @@ -342,9 +417,10 @@ func RegisterResponseCombiner(name string, f ResponseCombiner) {
}

const (
mergeKey = "combiner"
isSequentialKey = "sequential"
defaultCombinerName = "default"
mergeKey = "combiner"
isSequentialKey = "sequential"
sequentialPropagateKey = "sequential_propagated_params"
defaultCombinerName = "default"
)

var responseCombiners = initResponseCombiners()
Expand Down Expand Up @@ -382,7 +458,7 @@ func combineData(total int, parts []*Response) *Response {
}
isComplete = isComplete && part.IsComplete
if retResponse == nil {
retResponse = part
retResponse = &Response{Data: part.Data, IsComplete: isComplete}
continue
}
for k, v := range part.Data {
Expand Down
Loading

0 comments on commit 9b8a7c3

Please sign in to comment.