Skip to content

Commit

Permalink
Merge pull request #34 from reizio/search-parent-type
Browse files Browse the repository at this point in the history
reiz.reizql: META() matcher for searching on parent types
  • Loading branch information
isidentical authored Mar 6, 2021
2 parents bc8edbb + 5b858b5 commit 8064f82
Show file tree
Hide file tree
Showing 23 changed files with 332 additions and 63 deletions.
2 changes: 2 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
static/Python-reiz-fielddb.json linguist-generated=true
static/Python-reiz.json linguist-generated=true
1 change: 1 addition & 0 deletions .github/coverage.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
comments: false
5 changes: 1 addition & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,8 @@ jobs:
env:
EDGEDB_SERVER_BIN: edgedb-server
run: |
# These tests doesn't work well with EdgeDB alpha 7, which is what
# we have in the CI. So until the newer version is released, we are
# just going to skip them
coverage run test_cases/runner.py --change-db-schema \
--start-edgedb-server --run-benchmarks --do-not-fail
--start-edgedb-server --run-benchmarks
coverage xml
- name: Upload coverage to Codecov
Expand Down
6 changes: 6 additions & 0 deletions reiz/ir/backends/edgeql.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,11 +597,17 @@ def attribute(self, base, attr):
else:
return Attribute(base, attr)

def optional(self, node):
if isinstance(node, self.subscript):
node = self.call("array_get", [node.item, node.value])
return node

set = Set
loop = For
name = Name
call = Call
cast = Cast
tuple = Tuple
union = Union
assign = Assign
exists = Exists
Expand Down
10 changes: 10 additions & 0 deletions reiz/ir/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ def combine_filters(self, left, right, operator="AND"):
else:
return self.filter(left, right, operator)

def combine_multi_filters(self, left, right_filters, operator="AND"):
for right_filter in right_filters:
left = self.combine_filters(left, right_filter, operator=operator)
return left

def l_combine_multi_filters(self, left_filters, right, operator="AND"):
for left_filter in reversed(left_filters):
right = self.combine_filters(left_filter, right, operator=operator)
return right

def merge(self, expressions):
union = next(expressions)
for expression in expressions:
Expand Down
35 changes: 20 additions & 15 deletions reiz/reizql/compiler/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def compile_matcher(node, state):

if state.is_root:
state.scope.exit()
if state.filters:
filters = IR.l_combine_multi_filters(state.filters, filters)
if state.variables:
namespace = IR.namespace(state.variables)
filters = IR.add_namespace(namespace, IR.select(filters))
Expand Down Expand Up @@ -125,7 +127,7 @@ def aggregate_array(state):
)
body = IR.loop(
IR.name(_COMPILER_WORKAROUND_FOR_TARGET),
state.parents[-1].compute_path(),
state.parents[-1].compute_path(allow_missing=True),
IR.select(path, order=IR.property("index")),
)
else:
Expand All @@ -137,24 +139,30 @@ def aggregate_array(state):
@codegen.register(grammar.List)
def compile_sequence(node, state):
total_length = len(node.items)
length_verifier = IR.filter(
IR.call("count", [state.compute_path()]), total_length, "="
)
verify_call = IR.call("count", [state.compute_path()])

if total_length == 0 or all(
item in (grammar.Ignore, grammar.Expand) for item in node.items
):
return length_verifier
length_verifier = IR.filter(verify_call, total_length, "=")

if total := node.items.count(grammar.Expand):
state.ensure(node, total == 1)
length_verifier = IR.filter(
IR.call("count", [state.compute_path()]), total_length - 1, ">="
)
length_verifier = IR.filter(verify_call, total_length - 1, ">=")

state.filters.append(length_verifier)
if total_length == 0 or all(
item in (grammar.Ignore, grammar.Expand) for item in node.items
):
return None

array_ref = IR.new_reference("sequence")
state.variables[array_ref] = aggregate_array(state)

# For length verifier, instead of re-accessing the path
# we'll use the already aggregated array. So switch the
# verifier function from 'count' to 'len' (one is for sets
# and the other one is for arrays.)
verify_call.func = "len"
verify_call.args = [array_ref]

expansion_seen = False
with state.temp_flag("in for loop"), state.temp_property(
"enumeration start depth", state.depth
Expand All @@ -174,10 +182,7 @@ def compile_sequence(node, state):
if item_filters := state.codegen(matcher):
filters = IR.combine_filters(filters, item_filters)

if filters:
return IR.combine_filters(length_verifier, filters)
else:
return length_verifier
return filters


@codegen.register(type(grammar.Cease))
Expand Down
29 changes: 29 additions & 0 deletions reiz/reizql/compiler/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from reiz.ir import IR
from reiz.reizql.parser import grammar
from reiz.serialization.transformers import ast

if TYPE_CHECKING:
BuiltinFunctionType = Callable[
Expand Down Expand Up @@ -92,3 +93,31 @@ def convert_length(node, state, arguments):

assert filters is not None
return filters


def metadata_parent(parent_node, state):
state.ensure(parent_node, len(parent_node.filters) == 1)

parent_field, filter_value = parent_node.filters.popitem()
state.ensure(parent_node, filter_value is grammar.Ignore)

with state.temp_pointer("parent_types"):
return IR.filter(
IR.tuple(
[parent_node.bound_node.type_id, IR.literal(parent_field)]
),
state.compute_path(),
"IN",
)


@Signature.register("META", ["parent"], {"parent": None})
def convert_meta(node, state, arguments):
state.ensure(node, state.pointer == "__metadata__")

filters = None
if arguments.parent:
filters = IR.combine_filters(
filters, metadata_parent(arguments.parent, state)
)
return filters
40 changes: 22 additions & 18 deletions reiz/reizql/compiler/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class CompilerState:

pointer_stack: List[str] = field(default_factory=list)
scope: Scope = field(default_factory=Scope)
filters: List[IR.expression] = field(default_factory=list)
variables: Dict[IR.name, IR.expression] = field(default_factory=dict)
properties: Dict[str, Any] = field(default_factory=dict)
parents: List[CompilerState] = field(default_factory=list, repr=False)
Expand All @@ -31,6 +32,7 @@ def from_parent(cls, name, parent):
depth=parent.depth + 1,
scope=parent.scope,
parents=parent.parents + [parent],
filters=parent.filters,
variables=parent.variables,
properties=parent.properties,
)
Expand Down Expand Up @@ -84,18 +86,23 @@ def compile(self, key, value):
with self.temp_pointer(key):
return self.codegen(value)

def compute_path(self):
base = None
for parent in self.get_ordered_parents():
if base is None:
if self.is_flag_set("in for loop"):
base = parent.pointer
else:
base = IR.attribute(None, parent.pointer)
else:
base = IR.attribute(
IR.typed(base, parent.match), parent.pointer
)
def compute_path(self, allow_missing=False):
parent, *parents = self.get_ordered_parents()

def get_pointer(state, allow_missing):
pointer = state.pointer
if allow_missing:
pointer = IR.optional(pointer)
return pointer

base = get_pointer(parent, allow_missing)
if not parent.is_flag_set("in for loop"):
base = IR.attribute(None, base)

for parent in parents:
base = IR.typed(base, parent.match)
base = IR.attribute(base, get_pointer(parent, allow_missing))

return base

def get_ordered_parents(self):
Expand All @@ -117,21 +124,18 @@ def ensure(self, node, condition):
if not condition:
raise ReizQLSyntaxError(f"compiler check failed for: {node!r}")

def as_unique_ref(self, prefix):
return
def is_special(self, name):
return name.startswith("__") and name.endswith("__")

@property
def is_root(self):
return self.depth == 0

@property
def can_raw_name_access(self):
return

@property
def pointer(self):
return IR.wrap(self.pointer_stack[-1], with_prefix=False)

@property
def field_info(self):
assert not self.is_special(self.match)
return FIELD_DB[self.match][self.pointer_stack[0]]
3 changes: 3 additions & 0 deletions reiz/reizql/parser/grammar.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import typing
from dataclasses import dataclass, field
from enum import auto
Expand Down Expand Up @@ -44,6 +45,8 @@ class LogicOperator(Unit, ReizEnum):
@dataclass
class Match(Expression):
name: str
bound_node: ast.AST = field(repr=False)

filters: typing.Dict[str, Expression] = field(default_factory=dict)
positional: bool = False

Expand Down
6 changes: 3 additions & 3 deletions reiz/reizql/parser/parse.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import ast
import functools

from reiz.ir import IR
from reiz.reizql.parser import ReizQLSyntaxError, grammar
from reiz.serialization.transformers import ast

BUILTIN_FUNCTIONS = ("ALL", "ANY", "LEN", "I")
BUILTIN_FUNCTIONS = ("META", "ALL", "ANY", "LEN", "I")
POSITION_ATTRIBUTES = frozenset(
("lineno", "col_offset", "end_lineno", "end_col_offset")
)
Expand Down Expand Up @@ -77,7 +77,7 @@ def parse_call(self, node):
)
query[arg.arg] = self.parse(arg.value)

return grammar.Match(name, query, positional=positional)
return grammar.Match(name, origin, query, positional=positional)

@parse.register(ast.BinOp)
def parse_binop(self, node):
Expand Down
35 changes: 30 additions & 5 deletions reiz/schema/builders/edgeql.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from reiz.utilities import ReizEnum

INDENT = " " * 4
CUSTOM_TYPE_BASE = "custom_types"


class ModelConstraint(str, ReizEnum):
Expand Down Expand Up @@ -61,6 +62,15 @@ def is_ordered_sequence(self):
) and not self.is_property


@dataclass
class InlinedModel:
base: str
elements: List[str]

def construct(self):
return f"{self.base}<{', '.join(self.elements)}>"


@dataclass
class Model:
model: str
Expand All @@ -72,11 +82,12 @@ class Model:
def enum(cls, name, members):
# We can't directly use ident-based enums since
# some are the members (like And, Or) are keywords
base = "enum"
base += "<"
base += ", ".join(repr(member) for member in members)
base += ">"
return cls(name, constraint=ModelConstraint.SCALAR, extending=[base])
base = InlinedModel("enum", [repr(member) for member in members])
return cls(
name,
constraint=ModelConstraint.SCALAR,
extending=[base.construct()],
)

def construct(self):
source = []
Expand Down Expand Up @@ -116,6 +127,7 @@ def __init__(self, schema):
self.schema = schema
self.enum_types = schema.setdefault("enum_types", [])
self.module_types = schema.setdefault("module_annotated_types", [])
self.custom_types = {}

def visit_Module(self, node):
yield Model(self.BASE_TYPE, constraint=ModelConstraint.ABSTRACT)
Expand All @@ -130,6 +142,10 @@ def fix_references(self, definitions):
for field in definition.fields:
if field.kind == "Module":
self.module_types.append(definition.model)
elif field.kind in self.custom_types:
field.kind = self.custom_types[field.kind]
field.is_property = True

if field.kind in self.enum_types:
field.is_property = True
if (
Expand All @@ -151,6 +167,9 @@ def visit_Product(self, node, name):
)

def visit_Sum(self, node, name):
if name == CUSTOM_TYPE_BASE:
return self.process_custom_types(node)

if pyasdl.is_simple_sum(node):
self.enum_types.append(name)
yield Model.enum(
Expand All @@ -162,6 +181,12 @@ def visit_Sum(self, node, name):
)
yield from self.visit_all(node.types, base=name)

def process_custom_types(self, node):
for constructor in node.types:
fields = self.visit_all(constructor.fields)
model = InlinedModel("tuple", [field.kind for field in fields])
self.custom_types[constructor.name] = model.construct()

def visit_Constructor(self, node, base):
return Model(
node.name,
Expand Down
5 changes: 5 additions & 0 deletions reiz/serialization/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ def serialize_sequence(sequence, context):
return IR.add_namespace(scope, loop)


@serialize.register(tuple)
def serialize_tuple(sequence, context):
return IR.tuple([serialize(value, context) for value in sequence])


@serialize.register(str)
@serialize.register(int)
def serialize_string(value, context):
Expand Down
Loading

0 comments on commit 8064f82

Please sign in to comment.