Skip to content

Commit

Permalink
Builder: infer more widths (#1683)
Browse files Browse the repository at this point in the history
  • Loading branch information
anshumanmohan authored Sep 6, 2023
1 parent ede694a commit 865d1a1
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 142 deletions.
231 changes: 146 additions & 85 deletions calyx-py/calyx/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,17 @@ def control(self, builder: Union[ast.Control, ControlBuilder]):
else:
self.component.controls = builder

def get_port_width(self, name: str) -> int:
def port_width(self, port: ExprBuilder) -> int:
"""Get the width of an expression, which may be a port of this component."""
name = ExprBuilder.unwrap(port).item.id.name
for input in self.component.inputs:
if input.id.name == name:
return input.width
for output in self.component.outputs:
if output.id.name == name:
return output.width
raise NotFoundError(
f"couldn't find port {name} on component {self.component.name}"
)
# Give up.
return None

def get_cell(self, name: str) -> CellBuilder:
"""Retrieve a cell builder by name."""
Expand Down Expand Up @@ -318,7 +319,7 @@ def binary(
"""Generate a binary cell of the kind specified in `operation`."""
self.prog.import_("primitives/binary_operators.futil")
name = name or self.generate_name(operation)
assert isinstance(name, str)
assert isinstance(name, str), f"name {name} is not a string"
return self.cell(name, ast.Stdlib.op(operation, size, signed))

def add(self, size: int, name: str = None, signed: bool = False) -> CellBuilder:
Expand Down Expand Up @@ -435,43 +436,64 @@ def binary_use(self, left, right, cell, groupname=None):
cell.right = right
return CellAndGroup(cell, comb_group)

def eq_use(self, left, right, width, signed=False, cellname=None):
def try_infer_width(self, width, left, right):
"""If `width` is None, try to infer it from `left` or `right`.
If that fails, raise an error.
"""
width = width or self.infer_width(left) or self.infer_width(right)
if not width:
raise WidthInferenceError(
"Cannot infer widths from `left` or `right`. "
"Consider providing width as an argument."
)
return width

def eq_use(self, left, right, signed=False, cellname=None, width=None):
"""Inserts wiring into `self` to check if `left` == `right`."""
width = self.try_infer_width(width, left, right)
return self.binary_use(left, right, self.eq(width, cellname, signed))

def neq_use(self, left, right, width, signed=False, cellname=None):
def neq_use(self, left, right, signed=False, cellname=None, width=None):
"""Inserts wiring into `self` to check if `left` != `right`."""
width = self.try_infer_width(width, left, right)
return self.binary_use(left, right, self.neq(width, cellname, signed))

def lt_use(self, left, right, width, signed=False, cellname=None):
def lt_use(self, left, right, signed=False, cellname=None, width=None):
"""Inserts wiring into `self` to check if `left` < `right`."""
width = self.try_infer_width(width, left, right)
return self.binary_use(left, right, self.lt(width, cellname, signed))

def le_use(self, left, right, width, signed=False, cellname=None):
def le_use(self, left, right, signed=False, cellname=None, width=None):
"""Inserts wiring into `self` to check if `left` <= `right`."""
width = self.try_infer_width(width, left, right)
return self.binary_use(left, right, self.le(width, cellname, signed))

def ge_use(self, left, right, width, signed=False, cellname=None):
def ge_use(self, left, right, signed=False, cellname=None, width=None):
"""Inserts wiring into `self` to check if `left` >= `right`."""
width = self.try_infer_width(width, left, right)
return self.binary_use(left, right, self.ge(width, cellname, signed))

def gt_use(self, left, right, width, signed=False, cellname=None):
def gt_use(self, left, right, signed=False, cellname=None, width=None):
"""Inserts wiring into `self` to check if `left` > `right`."""
width = self.try_infer_width(width, left, right)
return self.binary_use(left, right, self.gt(width, cellname, signed))

def add_use(self, left, right, width, signed=False, cellname=None):
def add_use(self, left, right, signed=False, cellname=None, width=None):
"""Inserts wiring into `self` to compute `left` + `right`."""
width = self.try_infer_width(width, left, right)
return self.binary_use(left, right, self.add(width, cellname, signed))

def sub_use(self, left, right, width, signed=False, cellname=None):
def sub_use(self, left, right, signed=False, cellname=None, width=None):
"""Inserts wiring into `self` to compute `left` - `right`."""
width = self.try_infer_width(width, left, right)
return self.binary_use(left, right, self.sub(width, cellname, signed))

def bitwise_flip_reg(self, reg, width, cellname=None):
def bitwise_flip_reg(self, reg, cellname=None):
"""Inserts wiring into `self` to bitwise-flip the contents of `reg`
and put the result back into `reg`.
"""
cellname = cellname or f"{reg.name}_not"
width = reg.infer_width_reg()
not_cell = self.not_(width, cellname)
with self.group(f"{cellname}_group") as not_group:
not_cell.in_ = reg.out
Expand All @@ -480,9 +502,10 @@ def bitwise_flip_reg(self, reg, width, cellname=None):
not_group.done = reg.done
return not_group

def incr(self, reg, width, val=1, signed=False, cellname=None):
def incr(self, reg, val=1, signed=False, cellname=None):
"""Inserts wiring into `self` to perform `reg := reg + val`."""
cellname = cellname or f"{reg.name}_incr"
width = reg.infer_width_reg()
add_cell = self.add(width, cellname, signed)
with self.group(f"{cellname}_group") as incr_group:
add_cell.left = reg.out
Expand All @@ -492,9 +515,10 @@ def incr(self, reg, width, val=1, signed=False, cellname=None):
incr_group.done = reg.done
return incr_group

def decr(self, reg, width, val=1, signed=False, cellname=None):
def decr(self, reg, val=1, signed=False, cellname=None):
"""Inserts wiring into `self` to perform `reg := reg - val`."""
cellname = cellname or f"{reg.name}_decr"
width = reg.infer_width_reg()
sub_cell = self.sub(width, cellname, signed)
with self.group(f"{cellname}_group") as decr_group:
sub_cell.left = reg.out
Expand Down Expand Up @@ -620,19 +644,41 @@ def sub_store_in_reg(
)

def eq_store_in_reg(self, left, right, cellname, width, ans_reg=None, signed=False):
"""Adds wiring into `self to perform `reg := left == right`."""
"""Inserts wiring into `self` to perform `reg := left == right`."""
return self.op_store_in_reg(
self.eq(width, cellname, signed), left, right, cellname, 1, ans_reg
)

def neq_store_in_reg(
self, left, right, cellname, width, ans_reg=None, signed=False
):
"""Adds wiring into `self to perform `reg := left != right`."""
"""Inserts wiring into `self` to perform `reg := left != right`."""
return self.op_store_in_reg(
self.neq(width, cellname, signed), left, right, cellname, 1, ans_reg
)

def infer_width(self, expr) -> int:
"""Infer the width of an expression."""
if isinstance(expr, int): # We can't infer the width of an integer.
return None
if self.port_width(expr): # It's an in/out port of this component!
return self.port_width(expr)
expr = ExprBuilder.unwrap(expr) # We unwrap the expr.
if isinstance(expr, ast.Atom): # Inferring width of Atom.
if isinstance(expr.item, ast.ThisPort): # Atom is a ThisPort.
# If we can infer it from this, great, otherwise give up.
return self.port_width(expr)
# Not a ThisPort, but maybe some `cell.port`?
cell_name = expr.item.id.name
port_name = expr.item.name
cell_builder = self.index[cell_name]
if not isinstance(cell_builder, CellBuilder):
return None # Something is wrong, we should have a CellBuilder
# Okay, we really have a CellBuilder.
# Let's try to infer the width of the port.
# If this fails, give up.
return cell_builder.infer_width(port_name)


@dataclass(frozen=True)
class CellAndGroup:
Expand Down Expand Up @@ -852,6 +898,11 @@ def __ne__(self, other: ExprBuilder):
"""Construct an inequality comparison with ==."""
return ExprBuilder(ast.Neq(self.expr, other.expr))

@property
def name(self):
"""Get the name of the expression."""
return self.expr.name

@classmethod
def unwrap(cls, obj):
"""Unwrap an expression builder, or return the object if it is not one."""
Expand Down Expand Up @@ -940,6 +991,66 @@ def is_seq_mem_d1(self) -> bool:
"""Check if the cell is a SeqMemD1 cell."""
return self.is_primitive("seq_mem_d1")

def infer_width_reg(self) -> int:
"""Infer the width of a register. That is, the width of `reg.in`."""
assert self._cell.comp.id == "std_reg", "Cell is not a register"
return self._cell.comp.args[0]

def infer_width(self, port_name) -> int:
"""Infer the width of a port on the cell."""
inst = self._cell.comp
prim = inst.id
if prim == "std_reg":
if port_name in ("in", "out"):
return inst.args[0]
if port_name == "write_en":
return 1
return None
# XXX(Caleb): add all the primitive names instead of adding whenever I need one
if prim in (
"std_add",
"std_sub",
"std_lt",
"std_le",
"std_ge",
"std_gt",
"std_eq",
"std_neq",
"std_sgt",
"std_slt",
"std_fp_sgt",
"std_fp_slt",
):
if port_name in ("left", "right"):
return inst.args[0]
if prim in ("std_mem_d1", "seq_mem_d1"):
if port_name == "write_en":
return 1
if port_name == "addr0":
return inst.args[2]
if port_name == "in":
return inst.args[0]
if prim == "seq_mem_d1" and port_name == "read_en":
return 1
if prim in (
"std_mult_pipe",
"std_smult_pipe",
"std_mod_pipe",
"std_smod_pipe",
"std_div_pipe",
"std_sdiv_pipe",
"std_fp_smult_pipe",
):
if port_name in ("left", "right"):
return inst.args[0]
if port_name == "go":
return 1
if prim == "std_wire" and port_name == "in":
return inst.args[0]

# Give up.
return None

@property
def name(self) -> str:
"""Get the name of the cell."""
Expand Down Expand Up @@ -1061,6 +1172,20 @@ def __enter__(self):
def __exit__(self, exc, value, tb):
TLS.groups.pop()

def infer_width(self, expr):
"""Try to guess the width of a port expression in this group."""
assert isinstance(expr, ast.Atom)
if isinstance(expr.item, ast.ThisPort):
return self.comp.port_width(expr)
cell_name = expr.item.id.name
port_name = expr.item.name

cell_builder = self.comp.index[cell_name]
if not isinstance(cell_builder, CellBuilder):
return None

return cell_builder.infer_width(port_name)


def const(width: int, value: int) -> ExprBuilder:
"""Build a sized integer constant expression.
Expand All @@ -1077,81 +1202,17 @@ def infer_width(expr):
Return an int, or None if we don't have a guess.
"""
assert TLS.groups, "int width inference only works inside `with group:`"
group_builder: GroupBuilder = TLS.groups[-1]

# Deal with `done` holes.
expr = ExprBuilder.unwrap(expr)
if isinstance(expr, ast.HolePort):
assert expr.name == "done", f"unknown hole {expr.name}"
return 1

# Otherwise, it's a `cell.port` lookup.
assert isinstance(expr, ast.Atom)
if isinstance(expr.item, ast.ThisPort):
name = expr.item.id.name
return group_builder.comp.get_port_width(name)
cell_name = expr.item.id.name
port_name = expr.item.name

# Look up the component for the referenced cell.
cell_builder = group_builder.comp.index[cell_name]
if isinstance(cell_builder, CellBuilder):
inst = cell_builder._cell.comp
else:
return None

# Extract widths from stdlib components we know.
prim = inst.id
if prim == "std_reg":
if port_name == "in":
return inst.args[0]
elif port_name == "write_en":
return 1
# XXX(Caleb): add all the primitive names instead of adding whenever I need one
elif prim in (
"std_add",
"std_lt",
"std_le",
"std_ge",
"std_gt",
"std_eq",
"std_sgt",
"std_slt",
"std_fp_sgt",
"std_fp_slt",
):
if port_name == "left" or port_name == "right":
return inst.args[0]
elif prim == "std_mem_d1" or prim == "seq_mem_d1":
if port_name == "write_en":
return 1
elif port_name == "addr0":
return inst.args[2]
elif port_name == "in":
return inst.args[0]
if prim == "seq_mem_d1":
if port_name == "read_en":
return 1
elif prim in (
"std_mult_pipe",
"std_smult_pipe",
"std_mod_pipe",
"std_smod_pipe",
"std_div_pipe",
"std_sdiv_pipe",
"std_fp_smult_pipe",
):
if port_name == "left" or port_name == "right":
return inst.args[0]
elif port_name == "go":
return 1
elif prim == "std_wire":
if port_name == "in":
return inst.args[0]
assert TLS.groups, "int width inference only works inside `with group:`"
group_builder: GroupBuilder = TLS.groups[-1]

# Give up.
return None
return group_builder.infer_width(expr)


def ctx_asgn(lhs: ExprBuilder, rhs: Union[ExprBuilder, CondExprBuilder]):
Expand Down
Loading

0 comments on commit 865d1a1

Please sign in to comment.