Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize != string allocation return results #32

Merged
merged 5 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 58 additions & 8 deletions engine_stringmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@

pool := newErrPool(errPoolOpts{concurrency: n.concurrency})

neqOptimized := false

// First, handle equality matching.
for item := range n.vars {
path := item
Expand All @@ -83,15 +85,23 @@
}
}

m := n.equalitySearch(ctx, path, str)
m, opt := n.equalitySearch(ctx, path, str)

l.Lock()
matched = append(matched, m...)
if opt {
neqOptimized = true
}
l.Unlock()
return nil
})
}

// Wait for equality matching to optimize inequality matching
if err := eg.Wait(); err != nil {

Check failure on line 101 in engine_stringmap.go

View workflow job for this annotation

GitHub Actions / lint

undefined: eg (typecheck)

Check failure on line 101 in engine_stringmap.go

View workflow job for this annotation

GitHub Actions / test-linux-race

undefined: eg
return nil, err
}

// Then, iterate through the inequality matches.
for item := range n.inequality {
path := item
Expand All @@ -109,7 +119,7 @@
}
}

m := n.inequalitySearch(ctx, path, str)
m := n.inequalitySearch(ctx, path, str, neqOptimized, matched)

l.Lock()
matched = append(matched, m...)
Expand All @@ -131,11 +141,11 @@
return nil
}

return n.equalitySearch(ctx, variable, str)

matched, _ = n.equalitySearch(ctx, variable, str)
return matched
}

func (n *stringLookup) equalitySearch(ctx context.Context, variable string, input string) (matched []*StoredExpressionPart) {
func (n *stringLookup) equalitySearch(ctx context.Context, variable string, input string) (matched []*StoredExpressionPart, neqOptimized bool) {
n.lock.RLock()
defer n.lock.RUnlock()

Expand All @@ -150,27 +160,67 @@
// The variables don't match.
continue
}

if part.GroupID.Flag() != OptimizeNone {
neqOptimized = true
}

filtered[i] = part
i++
}
filtered = filtered[0:i]

return filtered
return filtered, neqOptimized
}

func (n *stringLookup) inequalitySearch(ctx context.Context, variable string, input string) (matched []*StoredExpressionPart) {
// inequalitySearch performs lookups for != matches.
func (n *stringLookup) inequalitySearch(ctx context.Context, variable string, input string, neqOptimized bool, currentMatches []*StoredExpressionPart) (matched []*StoredExpressionPart) {
if len(n.inequality[variable]) == 0 {
return nil
}

n.lock.RLock()
defer n.lock.RUnlock()

hashedInput := n.hash(input)

var found map[groupID]int8

if neqOptimized {
// If we're optimizing the "neq" value, we have a compound group which has both an == and != joined:
// `a == a && b != c`.
//
// In these cases, we'd naively return every StoredExpressionPart in the filter, as b != c - disregarding
// the `a == a` match.
//
// With optimizations, we check that there's the right number of string `==` matches in the group before
// evaluating !=, ensuring we keep allocations to a minimum.
found = map[groupID]int8{}
for _, match := range currentMatches {
found[match.GroupID]++
}
}

results := []*StoredExpressionPart{}
for value, exprs := range n.inequality[variable] {
if value == hashedInput {
continue
}
results = append(results, exprs...)

if !neqOptimized {
results = append(results, exprs...)
continue
}

for _, expr := range exprs {
res, ok := found[expr.GroupID]
if !ok || res < int8(expr.GroupID.Flag()) {
continue
}
results = append(results, expr)
}
}

return results
}

Expand Down
41 changes: 39 additions & 2 deletions expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/hex"
"fmt"
"math/rand"
"runtime"
"strings"
"sync"
"testing"
Expand Down Expand Up @@ -181,6 +182,7 @@ func TestEvaluate_Strings(t *testing.T) {
}

func TestEvaluate_Strings_Inequality(t *testing.T) {

ctx := context.Background()
parser := NewTreeParser(NewCachingCompiler(newEnv(), nil))

Expand All @@ -194,11 +196,12 @@ func TestEvaluate_Strings_Inequality(t *testing.T) {
require.NoError(t, err)

n := 100_000

addOtherExpressions(n, e, loader)

require.EqualValues(t, n+1, e.Len())

//mem := getMem()
//printMem(mem, "no matches")

t.Run("It matches items", func(t *testing.T) {
pre := time.Now()
evals, matched, err := e.Evaluate(ctx, map[string]any{
Expand All @@ -222,6 +225,8 @@ func TestEvaluate_Strings_Inequality(t *testing.T) {
require.GreaterOrEqual(t, matched, int32(1))
})

printMem(getMem(), "first match")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove?


t.Run("It handles non-matching data", func(t *testing.T) {
pre := time.Now()
evals, matched, err := e.Evaluate(ctx, map[string]any{
Expand All @@ -241,6 +246,8 @@ func TestEvaluate_Strings_Inequality(t *testing.T) {
require.EqualValues(t, 1, len(evals))
require.EqualValues(t, 1, matched)
})

printMem(getMem(), "second match")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove?

}

func TestEvaluate_Numbers(t *testing.T) {
Expand Down Expand Up @@ -1189,3 +1196,33 @@ func addOtherExpressions(n int, e AggregateEvaluator, loader *evalLoader) {
}
wg.Wait()
}

func getMem() runtime.MemStats {
var m runtime.MemStats
runtime.ReadMemStats(&m)
return m
}

//nolint:all
func deltaMem(prev runtime.MemStats) runtime.MemStats {
next := getMem()

return runtime.MemStats{
HeapAlloc: next.HeapAlloc - prev.HeapAlloc,
Alloc: next.Alloc - prev.Alloc,
TotalAlloc: next.TotalAlloc - prev.TotalAlloc,
}
}

func printMem(m runtime.MemStats, label ...string) {
if len(label) > 0 {
fmt.Printf("\t%s\n", label[0])
}

fmt.Printf("\tAlloc: %d MiB\n", bToMb(m.Alloc))
fmt.Printf("\tTotalAlloc: %d MiB\n", bToMb(m.TotalAlloc))
}

func bToMb(b uint64) uint64 {
return b / 1024 / 1024
}
44 changes: 36 additions & 8 deletions groupid.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,34 @@ import (
"encoding/hex"
)

// groupID represents a group ID. The first 2 byets are an int16 size of the expression group,
// representing the number of predicates within the expression. The last 6 bytes are a random
// ID for the predicate group.
// groupID represents a group ID. Layout, in bytes:
// - 2: an int16 size of the expression group,
// - 1: optimization flag, for optimizing "!=" in string matching
// - 5: random ID for group
type groupID [8]byte

// type internedGroupID unique.Handle[groupID]
//
// func (i internedGroupID) Value() groupID {
// return unique.Handle[groupID](i).Value()
// }
//
// func (i internedGroupID) Size() uint16 {
// // Uses unsafe pointers to access the underlying groupID
// // to return the size without a copy.
// handlePtr := unsafe.Pointer(&i)
// unsafe.Slice(
// // return (*groupID)(unsafe.Pointer(unsafe.SliceData(([8]byte)(handlePtr)))).Size()
// }

var rander = rand.Read

type RandomReader func(p []byte) (n int, err error)

const (
OptimizeNone = 0x0
)

func (g groupID) String() string {
return hex.EncodeToString(g[:])
}
Expand All @@ -23,13 +42,22 @@ func (g groupID) Size() uint16 {
return binary.NativeEndian.Uint16(g[0:2])
}

func newGroupID(size uint16) groupID {
return newGroupIDWithReader(size, rander)
func (g groupID) Flag() byte {
return g[2]
}

func newGroupIDWithReader(size uint16, rander RandomReader) groupID {
func newGroupID(size uint16, optimizeFlag byte) groupID {
return newGroupIDWithReader(size, optimizeFlag, rander)
}

func newGroupIDWithReader(size uint16, optimizeFlag byte, rander RandomReader) groupID {
id := make([]byte, 8)
binary.NativeEndian.PutUint16(id, size)
_, _ = rander(id[2:])
return [8]byte(id[0:8])
// Set the optimize byte.
id[2] = optimizeFlag
_, _ = rander(id[3:])

gid := groupID([8]byte(id[0:8]))
// interned := internedGroupID(unique.Make(gid))
return gid
}
9 changes: 7 additions & 2 deletions groupid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,13 @@ import (

func TestGroupID(t *testing.T) {
for i := uint16(0); i < 128; i++ {
gid := newGroupID(i)
require.Equal(t, i, gid.Size())
gid := newGroupID(i, 0x0)

require.NotEmpty(t, gid[2:])
require.Equal(t, i, gid.Size())

// check unsafe size method works
// gid := internedGID.Value()
// require.EqualValues(t, int(i), int(internedGID.Size()))
}
}
33 changes: 32 additions & 1 deletion parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,37 @@ func navigateAST(nav expr, parent *Node, vars LiftedArgs, rand RandomReader) ([]
total += 1
}

// For each AND, check to see if we have more than one string part, and check to see
// whether we have a "!=" and an "==" chained together. If so, this lets us optimize
// != checks so that we only return the aggregate match if the other "==" also matches.
//
// This is necessary: != returns basically every expression part, which is hugely costly
// in terms of allocation. We want to avoid that if poss.
var (
stringEq uint8
hasStringNeq bool
)
for _, item := range parent.Ands {
if item.Predicate == nil {
continue
}
if _, ok := item.Predicate.Literal.(string); !ok {
continue
}
if item.Predicate.Operator == operators.Equals {
stringEq++
}
if item.Predicate.Operator == operators.NotEquals {
hasStringNeq = true
}
}

flag := byte(OptimizeNone)
if stringEq > 0 && hasStringNeq {
// The flag is the number of string equality checks in the == group.
flag = byte(stringEq)
}

// Create a new group ID which tracks the number of expressions that must match
// within this group in order for the group to pass.
//
Expand All @@ -500,7 +531,7 @@ func navigateAST(nav expr, parent *Node, vars LiftedArgs, rand RandomReader) ([]
// When checking an incoming event, we match the event against each node's
// ident/variable. Using the group ID, we can see if we've matched N necessary
// items from the same identifier. If so, the evaluation is true.
parent.GroupID = newGroupIDWithReader(uint16(total), rand)
parent.GroupID = newGroupIDWithReader(uint16(total), flag, rand)

// For each sub-group, add the same group IDs to children if there's no nesting.
//
Expand Down
Loading
Loading