Skip to content

Commit

Permalink
Merge pull request #35 from GlowingScrewdriver/c-lisp
Browse files Browse the repository at this point in the history
More features for C-Lisp
  • Loading branch information
chsasank authored Jun 24, 2024
2 parents bc64548 + 86e2bd6 commit 373d080
Show file tree
Hide file tree
Showing 18 changed files with 587 additions and 61 deletions.
9 changes: 5 additions & 4 deletions src/backend/brilisp.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,15 +158,16 @@ def gen_store_instr(instr):


def brilisp(expr):
    """Convert a whole parsed brilisp program into a ``{"functions": [...]}`` dict.

    ``expr`` is the complete program list: its first element must be the
    tag ``"brilisp"`` and every following element must satisfy
    ``is_function``.  Each of those elements is passed to ``gen_function``.

    Raises AssertionError if the tag is missing or an element is not a
    function.
    """
    assert expr[0] == "brilisp"
    # Everything after the tag is a top-level function definition.
    body = expr[1:]
    for x in body:
        assert is_function(x), f"{x} is not a function"
    return {"functions": [gen_function(x) for x in body]}


def main():
    """Read a brilisp program as JSON from stdin and print the converted JSON.

    The program is handed to ``brilisp`` unmodified; checking the leading
    ``"brilisp"`` tag is done there, not here.
    """
    expr = json.load(sys.stdin)
    print(json.dumps(brilisp(expr)))


if __name__ == "__main__":
Expand Down
227 changes: 182 additions & 45 deletions src/backend/c-lisp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,40 @@ class BrilispCodeGenerator:
def __init__(self):
    """Set up empty type-tracking tables and the fixed-result-type opcode table."""
    # Type tracking
    self.symbol_types = {}  # Variable name -> type
    self.scopes = []  # Stack of scope tags
    self.function_types = {}  # Function name -> (ret-type, (arg-types...))
    self.pointer_types = {}  # For internal use, e.g. temporary pointer variables

    self.fixed_op_types = {
        # <opcode>: (<result type>, <number of operands>)
        # Integer arithmetic
        "add": ("int", 2),
        "sub": ("int", 2),
        "mul": ("int", 2),
        "div": ("int", 2),
        # Integer comparison
        "eq": ("bool", 2),
        "ne": ("bool", 2),
        "lt": ("bool", 2),
        "gt": ("bool", 2),
        "le": ("bool", 2),
        "ge": ("bool", 2),
        # Floating-point arithmetic
        "fadd": ("float", 2),
        "fsub": ("float", 2),
        "fmul": ("float", 2),
        "fdiv": ("float", 2),
        # Floating-point comparison
        "feq": ("bool", 2),
        "fne": ("bool", 2),
        "flt": ("bool", 2),
        "fgt": ("bool", 2),
        "fle": ("bool", 2),
        "fge": ("bool", 2),
        # Boolean logic
        "and": ("bool", 2),
        "or": ("bool", 2),
        "not": ("bool", 1),
    }

def c_lisp(self, prog):
Expand All @@ -51,6 +57,18 @@ def c_lisp(self, prog):

return ["brilisp"] + [self.gen_function(fn) for fn in prog[1:]]

def construct_scoped_name(self, name, scopes):
    """Return ``name`` followed by each scope tag, joined with dots.

    With no scopes the result is ``name`` itself.  ``scopes`` may be any
    iterable of strings.
    """
    return ".".join([name, *scopes])

def scoped_lookup(self, name):
    """Return the scoped name under which ``name`` is declared.

    Candidates are built from progressively shorter prefixes of the
    current scope stack, so the most deeply nested declaration wins and
    the unscoped name is tried last.

    Raises CodegenError if no candidate is present in ``symbol_types``.
    """
    # Be as specific as possible first: full scope stack down to no scope.
    for depth in range(len(self.scopes), -1, -1):
        scoped_name = self.construct_scoped_name(name, self.scopes[:depth])
        if scoped_name in self.symbol_types:
            return scoped_name
    raise CodegenError(f"Undeclared symbol: {name}")

def gen_function(self, func):
if not func[0] == "define":
raise CodegenError(f"Not a function: {func}")
Expand All @@ -59,7 +77,10 @@ def gen_function(self, func):
if not len(elem) == 2:
raise CodegenError(f"Bad function prototype: {func[1]}")

self.symbol_types = {} # Clear the symbol table
# Clear the symbol table and scope stack
self.symbol_types = {}
self.scopes = []

name, ret_type = func[1][0]
parm_types = []
for parm in func[1][1:]:
Expand All @@ -69,7 +90,7 @@ def gen_function(self, func):
return [
"define",
func[1],
*self.gen_stmt(func[2:]),
*self.gen_compound_stmt(func[2:], new_scope=False),
]

def gen_stmt(self, stmt):
Expand All @@ -87,6 +108,8 @@ def gen_stmt(self, stmt):
return self.gen_if_stmt(stmt)
elif self.is_for_stmt(stmt):
return self.gen_for_stmt(stmt)
elif self.is_while_stmt(stmt):
return self.gen_while_stmt(stmt)
else:
return self.gen_expr(stmt)
else:
Expand All @@ -95,6 +118,35 @@ def gen_stmt(self, stmt):
print(f"Error in statement: {stmt}")
raise e

def is_while_stmt(self, stmt):
    """Return True when the statement's head is the ``while`` keyword."""
    return stmt[0] == "while"

def gen_while_stmt(self, stmt):
    """Lower ``(while <cond> <stmt>...)`` into a label/branch loop.

    The emitted sequence is: the loop label, the condition's
    instructions, a ``br`` on the condition symbol to the continue or
    break label, the continue label, the body's instructions, a ``jmp``
    back to the loop label, and finally the break label.

    Raises CodegenError if the statement has fewer than 3 elements.
    """
    if len(stmt) < 3:
        raise CodegenError(f"Bad while statement: {stmt}")

    cond_sym, loop_lbl, cont_lbl, break_lbl = [
        random_label(CLISP_PREFIX, [extra])
        for extra in ("cond", "loop", "cont", "break")
    ]
    cond_expr_instr = self.gen_expr(stmt[1], res_sym=cond_sym)
    # Everything after the condition forms the loop body.
    loop_stmt_instr = self.gen_stmt(stmt[2:])

    return [
        ["label", loop_lbl],
        *cond_expr_instr,
        ["br", cond_sym, cont_lbl, break_lbl],
        ["label", cont_lbl],
        *loop_stmt_instr,
        ["jmp", loop_lbl],
        ["label", break_lbl],
    ]

def is_for_stmt(self, stmt):
    """Return True when the statement's head is the ``for`` keyword."""
    return stmt[0] == "for"

Expand Down Expand Up @@ -170,9 +222,10 @@ def gen_decl_stmt(self, stmt):
raise CodegenError(f"bad declare statement: {stmt}")

name, typ = stmt[1]
if name in self.symbol_types:
scoped_name = self.construct_scoped_name(name, self.scopes)
if scoped_name in self.symbol_types:
raise CodegenError(f"Re-declaration of variable {name}")
self.symbol_types[name] = typ
self.symbol_types[scoped_name] = typ
return []

def is_ret_stmt(self, stmt):
Expand All @@ -194,22 +247,30 @@ def gen_ret_stmt(self, stmt):
def is_compound_stmt(self, stmt):
    """Return True when ``stmt`` is a list whose first element is itself a list."""
    return isinstance(stmt, list) and isinstance(stmt[0], list)

def gen_compound_stmt(self, stmt, new_scope=True):
    """Generate instructions for each statement in ``stmt``, in order.

    When ``new_scope`` is true a fresh scope tag is pushed onto
    ``self.scopes`` for the duration of the statements and popped
    afterwards, even if generating one of the statements raises.

    Returns the concatenated instruction list.
    """
    if new_scope:
        self.scopes.append(random_label())
    try:
        instr_list = []
        for s in stmt:
            instr_list += self.gen_stmt(s)
        return instr_list
    finally:
        # Keep the scope stack balanced on both the normal and error paths.
        if new_scope:
            self.scopes.pop()

def is_set_expr(self, expr):
    """Return True when the expression's head is the ``set`` keyword."""
    return expr[0] == "set"

def gen_set_expr(self, expr, res_sym):
    """Generate instructions for ``(set <name> <value-expr>)``.

    The value expression is evaluated into ``res_sym``, then copied into
    the variable's scoped name using the variable's declared type.

    ``scoped_lookup`` raises CodegenError when ``name`` is not declared
    in any enclosing scope, so no separate membership check is needed.
    """
    name = expr[1]
    scoped_name = self.scoped_lookup(name)

    instr_list = self.gen_expr(expr[2], res_sym=res_sym)
    instr_list.append(
        ["set", [scoped_name, self.symbol_types[scoped_name]], ["id", res_sym]]
    )
    return instr_list

def get_literal_type(self, expr):
Expand Down Expand Up @@ -248,28 +309,96 @@ def is_var_expr(self, expr):
return isinstance(expr, str)

def gen_var_expr(self, expr, res_sym):
    """Generate the instruction that copies variable ``expr`` into ``res_sym``.

    If the variable's type has ``"ptr"`` as its first element, the type
    is also recorded for ``res_sym`` in ``self.pointer_types``.

    ``scoped_lookup`` raises CodegenError for an undeclared variable, so
    the name it returns is always present in ``symbol_types``.
    """
    scoped_name = self.scoped_lookup(expr)
    typ = self.symbol_types[scoped_name]
    if typ[0] == "ptr":
        self.pointer_types[res_sym] = typ
    return [["set", [res_sym, typ], ["id", scoped_name]]]

def is_binary_expr(self, expr):
return expr[0] in self.binary_op_types

def gen_binary_expr(self, expr, res_sym):
if not len(expr) == 3:
raise CodegenError(f"Binary operation takes only 2 operands: {expr}")
def is_fixed_type_expr(self, expr):
    """Return True when the expression's head is an opcode listed in ``fixed_op_types``."""
    return expr[0] in self.fixed_op_types

def gen_fixed_type_expr(self, expr, res_sym):
    """Generate instructions for an opcode with a fixed result type and arity.

    The result type and operand count come from ``fixed_op_types``.  Each
    operand is evaluated into its own temporary symbol, then a single
    ``set`` applies the opcode to those symbols and stores the result in
    ``res_sym``.

    Raises CodegenError if the number of operands does not match the
    opcode's arity.
    """
    opcode = expr[0]
    typ, n_ops = self.fixed_op_types[opcode]
    if len(expr) != n_ops + 1:
        # Report the opcode's real arity; unary opcodes such as `not` exist.
        raise CodegenError(
            f"`{opcode}` takes exactly {n_ops} operand(s): {expr}"
        )

    in_syms = [random_label(CLISP_PREFIX, [f"inp_{n}"]) for n in range(n_ops)]
    input_instrs = []
    for operand, sym in zip(expr[1:], in_syms):
        input_instrs += self.gen_expr(operand, sym)
    return [
        *input_instrs,
        ["set", [res_sym, typ], [opcode, *in_syms]],
    ]

def is_ptradd_expr(self, expr):
    """Return True when the expression's head is the ``ptradd`` keyword."""
    return expr[0] == "ptradd"

def gen_ptradd_expr(self, expr, res_sym):
    """Generate instructions for ``(ptradd <ptr-variable> <offset-expr>)``.

    The first operand is resolved with ``scoped_lookup``, i.e. it is
    treated as a declared variable name rather than a general
    expression.  The result carries the same type as that variable, and
    the type is recorded for ``res_sym`` in ``self.pointer_types``.

    Raises CodegenError if the expression does not have exactly 2 operands.
    """
    if len(expr) != 3:
        raise CodegenError(f"Bad ptradd expression: {expr}")

    offset_sym = random_label(CLISP_PREFIX)
    ptr_name = self.scoped_lookup(expr[1])
    ptr_type = self.symbol_types[ptr_name]
    self.pointer_types[res_sym] = ptr_type
    return [
        *self.gen_expr(expr[2], res_sym=offset_sym),
        ["set", [res_sym, ptr_type], ["ptradd", ptr_name, offset_sym]],
    ]

def is_load_expr(self, expr):
    """Return True when the expression's head is the ``load`` keyword."""
    return expr[0] == "load"

def gen_load_expr(self, expr, res_sym):
    """Generate instructions for ``(load <pointer-expr>)``.

    The pointer expression is evaluated into a temporary symbol; the
    result type is the second element of the pointer type recorded for
    that symbol in ``self.pointer_types``.

    Raises CodegenError if the expression does not have exactly 1
    operand, or if no pointer type was recorded for the operand.
    """
    if len(expr) != 2:
        raise CodegenError(f"Bad load expression: {expr}")

    ptr_sym = random_label(CLISP_PREFIX)
    # Generate the operand first: the pointer type for ptr_sym is read
    # from pointer_types only after this call.
    ptr_instrs = self.gen_expr(expr[1], res_sym=ptr_sym)
    if ptr_sym not in self.pointer_types:
        raise CodegenError(f"Cannot load from a non-pointer expression: {expr}")
    return [
        *ptr_instrs,
        ["set", [res_sym, self.pointer_types[ptr_sym][1]], ["load", ptr_sym]],
    ]

def is_store_expr(self, expr):
    """Return True when the expression's head is the ``store`` keyword."""
    return expr[0] == "store"

def gen_store_expr(self, expr, res_sym):
    """Generate instructions for ``(store <pointer-expr> <value-expr>)``.

    The pointer and value expressions are evaluated, in that order, into
    temporary symbols.  A ``store`` instruction follows, and the stored
    value is then copied into ``res_sym`` with the pointer's element
    type, so the expression yields the value that was stored.

    Raises CodegenError if the expression does not have exactly 2
    operands, or if no pointer type was recorded for the first operand.
    """
    if len(expr) != 3:
        raise CodegenError(f"Bad store expression: {expr}")

    val_sym, ptr_sym = [
        random_label(CLISP_PREFIX, [extra]) for extra in ("val", "ptr")
    ]
    ptr_instrs = self.gen_expr(expr[1], res_sym=ptr_sym)
    val_instrs = self.gen_expr(expr[2], res_sym=val_sym)
    # Checked only after both operands are generated, matching the point
    # at which pointer_types is read below.
    if ptr_sym not in self.pointer_types:
        raise CodegenError(f"Cannot store to a non-pointer expression: {expr}")
    return [
        *ptr_instrs,
        *val_instrs,
        ["store", ptr_sym, val_sym],
        ["set", [res_sym, self.pointer_types[ptr_sym][1]], ["id", val_sym]],
    ]

def is_alloc_expr(self, expr):
    """Return True when the expression's head is the ``alloc`` keyword."""
    return expr[0] == "alloc"

def gen_alloc_expr(self, expr, res_sym):
    """Generate instructions for ``(alloc <element-type> <size-expr>)``.

    The result type is ``["ptr", <element-type>]``; it is recorded for
    ``res_sym`` in ``self.pointer_types``.  The size expression is
    evaluated into a temporary symbol that is passed to ``alloc``.

    Raises CodegenError if the expression does not have exactly 2 operands.
    """
    if len(expr) != 3:
        raise CodegenError(f"Bad alloc expression: {expr}")

    ptr_type = ["ptr", expr[1]]
    self.pointer_types[res_sym] = ptr_type
    size_sym = random_label(CLISP_PREFIX)
    return [
        *self.gen_expr(expr[2], res_sym=size_sym),
        ["set", [res_sym, ptr_type], ["alloc", size_sym]],
    ]

def gen_expr(self, expr, res_sym=None):
Expand All @@ -282,8 +411,16 @@ def gen_expr(self, expr, res_sym=None):
return self.gen_call_expr(expr, res_sym)
elif self.is_var_expr(expr):
return self.gen_var_expr(expr, res_sym)
elif self.is_binary_expr(expr):
return self.gen_binary_expr(expr, res_sym)
elif self.is_fixed_type_expr(expr):
return self.gen_fixed_type_expr(expr, res_sym)
elif self.is_ptradd_expr(expr):
return self.gen_ptradd_expr(expr, res_sym)
elif self.is_load_expr(expr):
return self.gen_load_expr(expr, res_sym)
elif self.is_store_expr(expr):
return self.gen_store_expr(expr, res_sym)
elif self.is_alloc_expr(expr):
return self.gen_alloc_expr(expr, res_sym)
else:
raise CodegenError(f"Bad expression: {expr}")

Expand Down
2 changes: 1 addition & 1 deletion src/backend/tests/brilisp/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ tmp_out=$(mktemp --suffix '.out')

cp /dev/stdin $tmp_in
clang $tmp_in runtime.c -o $tmp_out -Wno-override-module -O2
$tmp_out
$tmp_out $@

rm $tmp_in $tmp_out
2 changes: 2 additions & 0 deletions src/backend/tests/c-lisp/array-sum.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
45
285
34 changes: 34 additions & 0 deletions src/backend/tests/c-lisp/array-sum.sexp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
(c-lisp
(define ((print int) (n int)))

(define ((arr_sum int) (a (ptr int)) (n int))
(declare (i int))
(declare (sum int))
(set sum 0)

(for ((set i 0)
(lt i n)
(set i (add i 1)))
(set sum (add sum (load (ptradd a i)))))
(ret sum))

(define ((main void))
(declare (arr1 (ptr int)))
(declare (arr2 (ptr int)))
(declare (i int))

(set arr1 (alloc int 10))
(set arr2 (alloc int 10))

(for ((set i 0)
(lt i 10)
(set i (add i 1)))
(declare (arr_i (ptr int)))
(set arr_i (ptradd arr1 i))
(store arr_i i)
(set arr_i (ptradd arr2 i))
(store arr_i (mul i i)))

(call print (call arr_sum arr1 10))
(call print (call arr_sum arr2 10))
(ret)))
Loading

0 comments on commit 373d080

Please sign in to comment.