Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize != string allocation return results #32

Merged
merged 5 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 57 additions & 8 deletions engine_stringmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ func (n *stringLookup) Match(ctx context.Context, input map[string]any) ([]*Stor

pool := newErrPool(errPoolOpts{concurrency: n.concurrency})

neqOptimized := false

// First, handle equality matching.
for item := range n.vars {
path := item
Expand All @@ -83,15 +85,22 @@ func (n *stringLookup) Match(ctx context.Context, input map[string]any) ([]*Stor
}
}

m := n.equalitySearch(ctx, path, str)
m, opt := n.equalitySearch(ctx, path, str)

l.Lock()
matched = append(matched, m...)
if opt {
neqOptimized = true
}
l.Unlock()
return nil
})
}
if err := pool.Wait(); err != nil {
return nil, err
}

pool = newErrPool(errPoolOpts{concurrency: n.concurrency})
// Then, iterate through the inequality matches.
for item := range n.inequality {
path := item
Expand All @@ -109,7 +118,7 @@ func (n *stringLookup) Match(ctx context.Context, input map[string]any) ([]*Stor
}
}

m := n.inequalitySearch(ctx, path, str)
m := n.inequalitySearch(ctx, path, str, neqOptimized, matched)

l.Lock()
matched = append(matched, m...)
Expand All @@ -131,11 +140,11 @@ func (n *stringLookup) Search(ctx context.Context, variable string, input any) (
return nil
}

return n.equalitySearch(ctx, variable, str)

matched, _ = n.equalitySearch(ctx, variable, str)
return matched
}

func (n *stringLookup) equalitySearch(ctx context.Context, variable string, input string) (matched []*StoredExpressionPart) {
func (n *stringLookup) equalitySearch(ctx context.Context, variable string, input string) (matched []*StoredExpressionPart, neqOptimized bool) {
n.lock.RLock()
defer n.lock.RUnlock()

Expand All @@ -150,27 +159,67 @@ func (n *stringLookup) equalitySearch(ctx context.Context, variable string, inpu
// The variables don't match.
continue
}

if part.GroupID.Flag() != OptimizeNone {
neqOptimized = true
}

filtered[i] = part
i++
}
filtered = filtered[0:i]

return filtered
return filtered, neqOptimized
}

func (n *stringLookup) inequalitySearch(ctx context.Context, variable string, input string) (matched []*StoredExpressionPart) {
// inequalitySearch performs lookups for != matches on a single variable.
//
// It walks every value stored under n.inequality[variable] and collects the
// expression parts whose stored key differs from n.hash(input); parts stored
// under the same key as the input are skipped, because `x != input` is false
// for them.
//
// When neqOptimized is false, every such part is returned. When it is true,
// currentMatches (the parts matched so far by the caller) are tallied per
// group ID, and a != part is only returned if its group has at least
// GroupID.Flag() matches already. This avoids returning (and allocating for)
// != parts of compound expressions whose == parts did not match.
//
// The result is always a non-nil slice unless no inequality entries exist for
// variable, in which case nil is returned.
func (n *stringLookup) inequalitySearch(ctx context.Context, variable string, input string, neqOptimized bool, currentMatches []*StoredExpressionPart) (matched []*StoredExpressionPart) {
	// Take the read lock before touching n.inequality, the same way
	// equalitySearch does for its map; reading the map first and locking
	// afterwards would leave that first read unguarded.
	n.lock.RLock()
	defer n.lock.RUnlock()

	byValue := n.inequality[variable]
	if len(byValue) == 0 {
		return nil
	}

	hashedInput := n.hash(input)

	var found map[groupID]int8

	if neqOptimized {
		// If we're optimizing the "neq" value, we have a compound group which has both an == and != joined:
		// `a == a && b != c`.
		//
		// In these cases, we'd naively return every StoredExpressionPart in the filter, as b != c - disregarding
		// the `a == a` match.
		//
		// With optimizations, we check that there's the right number of string `==` matches in the group before
		// evaluating !=, ensuring we keep allocations to a minimum.
		found = map[groupID]int8{}
		for _, match := range currentMatches {
			found[match.GroupID]++
		}
	}

	results := []*StoredExpressionPart{}
	for value, exprs := range byValue {
		if value == hashedInput {
			// The input equals this literal, so `!=` does not hold.
			continue
		}

		if !neqOptimized {
			// No group filtering requested: everything under a differing
			// value is a candidate.
			results = append(results, exprs...)
			continue
		}

		// Only keep parts whose group already has enough matches. Parts are
		// appended exactly once, here, and never unconditionally.
		for _, expr := range exprs {
			res, ok := found[expr.GroupID]
			if !ok || res < int8(expr.GroupID.Flag()) {
				continue
			}
			results = append(results, expr)
		}
	}

	return results
}

Expand Down
38 changes: 36 additions & 2 deletions engine_stringmap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,40 @@ import (
"testing"

"github.com/google/cel-go/common/operators"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)

func TestEngineStringmap(t *testing.T) {
ctx := context.Background()
s := newStringEqualityMatcher(testConcurrency).(*stringLookup)

gid := newGroupID(4, 2) // optimized to 2 == matches.
exp := &ParsedExpression{
EvaluableID: uuid.NewSHA1(uuid.NameSpaceURL, []byte("eq-neq")),
}

a := ExpressionPart{
Parsed: exp,
GroupID: gid,
Predicate: &Predicate{
Ident: "async.data.id",
Literal: "123",
Operator: operators.Equals,
},
}
b := ExpressionPart{
Parsed: &ParsedExpression{EvaluableID: uuid.NewSHA1(uuid.NameSpaceURL, []byte("eq-single"))},
GroupID: newGroupID(1, 0), // This belongs to a "different" expression, but is the same pred.
Predicate: &Predicate{
Ident: "async.data.id",
Literal: "123",
Operator: operators.Equals,
},
}
c := ExpressionPart{
Parsed: exp,
GroupID: gid,
Predicate: &Predicate{
Ident: "async.data.another",
Literal: "456",
Expand All @@ -36,13 +48,17 @@ func TestEngineStringmap(t *testing.T) {

// Test inequality
d := ExpressionPart{
Parsed: exp,
GroupID: gid,
Predicate: &Predicate{
Ident: "async.data.neq",
Literal: "neq-1",
Operator: operators.NotEquals,
},
}
e := ExpressionPart{
Parsed: &ParsedExpression{EvaluableID: uuid.NewSHA1(uuid.NameSpaceURL, []byte("neq-single"))},
GroupID: newGroupID(1, 0), // This belongs to a "different" expression, but is the same pred.
Predicate: &Predicate{
Ident: "async.data.neq",
Literal: "neq-2",
Expand Down Expand Up @@ -134,7 +150,11 @@ func TestEngineStringmap(t *testing.T) {
},
})
require.NoError(t, err)
require.Equal(t, 4, len(found)) // matching plus inequality

// This should match "neq-single" and "eq-single" only. It shouldn't
// match the eq-neq expression, as the "async.data.another" part wasn't matched
// and the group ID's expression optimization filters it out.
require.Equal(t, 2, len(found))
})

t.Run("It matches data with null neq", func(t *testing.T) {
Expand All @@ -147,9 +167,23 @@ func TestEngineStringmap(t *testing.T) {
},
})
require.NoError(t, err)
require.Equal(t, 4, len(found)) // matching plus inequality
require.Equal(t, 2, len(found)) // matching plus inequality
})

t.Run("It matches data with expression optimizations in group ID", func(t *testing.T) {
found, err := s.Match(ctx, map[string]any{
"async": map[string]any{
"data": map[string]any{
"id": "123",
"another": "456",
"neq": "lol",
},
},
})
require.NoError(t, err)

require.Equal(t, 4, len(found))
})
}

func TestEngineStringmap_DuplicateValues(t *testing.T) {
Expand Down
49 changes: 45 additions & 4 deletions expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/hex"
"fmt"
"math/rand"
"runtime"
"strings"
"sync"
"testing"
Expand Down Expand Up @@ -120,7 +121,7 @@ func TestEvaluate_Strings(t *testing.T) {
ctx := context.Background()
parser := NewTreeParser(NewCachingCompiler(newEnv(), nil))

expected := tex(`event.data.account_id == "yes" && event.data.match == "true"`)
expected := tex(`event.data.account_id == "yes" && event.data.another == "ok" && event.data.match == "true"`)
loader := newEvalLoader()
loader.AddEval(expected)

Expand All @@ -145,6 +146,7 @@ func TestEvaluate_Strings(t *testing.T) {
"event": map[string]any{
"data": map[string]any{
"account_id": "yes",
"another": "ok",
"match": "true",
},
},
Expand All @@ -166,6 +168,7 @@ func TestEvaluate_Strings(t *testing.T) {
"event": map[string]any{
"data": map[string]any{
"account_id": "yes",
"another": "ok",
"match": "no",
},
},
Expand All @@ -181,10 +184,11 @@ func TestEvaluate_Strings(t *testing.T) {
}

func TestEvaluate_Strings_Inequality(t *testing.T) {

ctx := context.Background()
parser := NewTreeParser(NewCachingCompiler(newEnv(), nil))

expected := tex(`event.data.account_id == "yes" && event.data.neq != "neq"`)
expected := tex(`event.data.account_id == "yes" && event.data.another == "ok" && event.data.neq != "neq"`)
loader := newEvalLoader()
loader.AddEval(expected)

Expand All @@ -194,17 +198,19 @@ func TestEvaluate_Strings_Inequality(t *testing.T) {
require.NoError(t, err)

n := 100_000

addOtherExpressions(n, e, loader)

require.EqualValues(t, n+1, e.Len())

mem := getMem()
printMem(mem, "no matches")

t.Run("It matches items", func(t *testing.T) {
pre := time.Now()
evals, matched, err := e.Evaluate(ctx, map[string]any{
"event": map[string]any{
"data": map[string]any{
"account_id": "yes",
"another": "ok",
"match": "true",
"neq": "nah",
},
Expand All @@ -222,12 +228,15 @@ func TestEvaluate_Strings_Inequality(t *testing.T) {
require.GreaterOrEqual(t, matched, int32(1))
})

printMem(getMem(), "first match")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove?


t.Run("It handles non-matching data", func(t *testing.T) {
pre := time.Now()
evals, matched, err := e.Evaluate(ctx, map[string]any{
"event": map[string]any{
"data": map[string]any{
"account_id": "yes",
"another": "ok",
"match": "no",
"neq": "nah",
},
Expand All @@ -241,6 +250,8 @@ func TestEvaluate_Strings_Inequality(t *testing.T) {
require.EqualValues(t, 1, len(evals))
require.EqualValues(t, 1, matched)
})

printMem(getMem(), "second match")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove?

}

func TestEvaluate_Numbers(t *testing.T) {
Expand Down Expand Up @@ -1189,3 +1200,33 @@ func addOtherExpressions(n int, e AggregateEvaluator, loader *evalLoader) {
}
wg.Wait()
}

// getMem takes a snapshot of the Go runtime's memory statistics via
// runtime.ReadMemStats and returns it by value.
func getMem() (stats runtime.MemStats) {
	runtime.ReadMemStats(&stats)
	return stats
}

// deltaMem reports how much HeapAlloc, Alloc and TotalAlloc have grown since
// prev was captured. Only those three fields of the returned MemStats are
// populated; every other field is zero.
//
// The fields are unsigned, and HeapAlloc/Alloc can go down between snapshots
// (for example after a garbage collection). A plain subtraction would then
// wrap around to an enormous value, so each difference is clamped at zero:
// a zero delta means "did not grow".
//
//nolint:all
func deltaMem(prev runtime.MemStats) runtime.MemStats {
	// Take a fresh snapshot to compare against prev.
	var next runtime.MemStats
	runtime.ReadMemStats(&next)

	// grown returns after-before, or 0 when the counter shrank.
	grown := func(after, before uint64) uint64 {
		if after < before {
			return 0
		}
		return after - before
	}

	return runtime.MemStats{
		HeapAlloc:  grown(next.HeapAlloc, prev.HeapAlloc),
		Alloc:      grown(next.Alloc, prev.Alloc),
		TotalAlloc: grown(next.TotalAlloc, prev.TotalAlloc),
	}
}

// printMem writes the Alloc and TotalAlloc fields of m to stdout, converted
// with bToMb and labelled in MiB. If any label is supplied, the first one is
// printed on its own line beforehand; additional labels are ignored.
func printMem(m runtime.MemStats, label ...string) {
	if len(label) != 0 {
		heading := label[0]
		fmt.Printf("\t%s\n", heading)
	}

	allocMiB, totalMiB := bToMb(m.Alloc), bToMb(m.TotalAlloc)
	fmt.Printf("\tAlloc: %d MiB\n", allocMiB)
	fmt.Printf("\tTotalAlloc: %d MiB\n", totalMiB)
}

// bToMb converts a byte count to whole mebibytes (1 MiB = 1<<20 bytes),
// discarding any remainder.
func bToMb(b uint64) uint64 {
	const bytesPerMiB = 1 << 20
	return b / bytesPerMiB
}
Loading