Skip to content

Commit

Permalink
llvm: Improve test coverage of compilation helpers (#3029)
Browse files Browse the repository at this point in the history
  • Loading branch information
jvesely authored Aug 7, 2024
2 parents 6b899b4 + db02a8e commit 310afb1
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 20 deletions.
29 changes: 13 additions & 16 deletions psyneulink/core/llvm/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def convert_type(builder, val, t):
return builder.trunc(val, t)
elif val.type.width < t.width:
# Python integers are signed
return builder.sext(val, t)
return builder.zext(val, t)
else:
assert False, "Unknown integer conversion: {} -> {}".format(val.type, t)

Expand All @@ -319,8 +319,7 @@ def convert_type(builder, val, t):
val = builder.fptrunc(val, ir.FloatType())
return builder.fptrunc(val, t)
else:
assert val.type == t
return val
assert False, "Unknown float conversion: {} -> {}".format(val.type, t)

assert False, "Unknown type conversion: {} -> {}".format(val.type, t)

Expand Down Expand Up @@ -409,16 +408,12 @@ def printf(builder, fmt, *args, override_debug=False):
#FIXME: Fix builtin printf and use that instead of this
libc_name = "msvcrt" if sys.platform == "win32" else "c"
libc = util.find_library(libc_name)
if libc is None:
warnings.warn("Standard libc library not found, 'printf' not available!")
return
assert libc is not None, "Standard libc library not found"

llvm.load_library_permanently(libc)
# Address will be none if the symbol is not found
printf_address = llvm.address_of_symbol("printf")
if printf_address is None:
warnings.warn("'printf' symbol not found in libc, 'printf' not available!")
return
assert printf_address is not None, "'printf' symbol not found in {}".format(libc)

# Direct pointer constants don't work
printf_ty = ir.FunctionType(ir.IntType(32), [ir.IntType(8).as_pointer()], var_arg=True)
Expand Down Expand Up @@ -758,14 +753,16 @@ def generate_sched_condition(self, builder, condition, cond_ptr, node,
node_state = builder.gep(nodes_states, [self.ctx.int32_ty(0), self.ctx.int32_ty(node_idx)])
param_ptr = get_state_ptr(builder, target, node_state, param)

if isinstance(param_ptr.type.pointee, ir.ArrayType):
if indices is None:
indices = [0, 0]
elif isinstance(indices, TimeScale):
indices = [indices.value]
# parameters in state include history of at least one element
# so they are always arrays.
assert isinstance(param_ptr.type.pointee, ir.ArrayType)

if indices is None:
indices = [0, 0]
elif isinstance(indices, TimeScale):
indices = [indices.value]

indices = [self.ctx.int32_ty(x) for x in [0] + list(indices)]
param_ptr = builder.gep(param_ptr, indices)
param_ptr = builder.gep(param_ptr, [self.ctx.int32_ty(x) for x in [0] + list(indices)])

val = builder.load(param_ptr)
val = convert_type(builder, val, ir.DoubleType())
Expand Down
47 changes: 43 additions & 4 deletions tests/llvm/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,10 @@ def test_helper_all_close(mode, var1, var2, atol, rtol):
builder.store(res, out)
builder.ret_void()

bin_f = pnlvm.LLVMBinaryFunction.get(custom_name)
res = bin_f.np_buffer_for_arg(2)

ref = np.allclose(vec1, vec2, **tolerance)
res = np.array(5, dtype=np.uint32)

bin_f = pnlvm.LLVMBinaryFunction.get(custom_name)

if mode == 'CPU':
bin_f(vec1, vec2, res)
Expand Down Expand Up @@ -558,7 +557,7 @@ def test_helper_convert_fp_type(t1, t2, mode, val):
np_dt1, np_dt2 = (np.dtype(bin_f.np_arg_dtypes[i]) for i in (0, 1))

# instantiate value, result and reference
x = np.asfarray(val, dtype=bin_f.np_arg_dtypes[0])
x = np.asfarray(val, dtype=np_dt1)
y = bin_f.np_buffer_for_arg(1)
ref = x.astype(np_dt2)

Expand All @@ -568,3 +567,43 @@ def test_helper_convert_fp_type(t1, t2, mode, val):
bin_f.cuda_wrap_call(x, y)

np.testing.assert_allclose(y, ref, equal_nan=True)


_int_types = [ir.IntType(64), ir.IntType(32), ir.IntType(16), ir.IntType(8)]


@pytest.mark.llvm
@pytest.mark.parametrize('mode', ['CPU', pytest.helpers.cuda_param('PTX')])
@pytest.mark.parametrize('t1', _int_types, ids=str)
@pytest.mark.parametrize('t2', _int_types, ids=str)
@pytest.mark.parametrize('val', [0, 1, -1, 127, -128, 255, -32768, 32767, 65535, np.iinfo(np.int32).min, np.iinfo(np.int32).max])
def test_helper_convert_int_type(t1, t2, mode, val):
with pnlvm.LLVMBuilderContext.get_current() as ctx:
func_ty = ir.FunctionType(ir.VoidType(), [t1.as_pointer(), t2.as_pointer()])
custom_name = ctx.get_unique_name("int_convert")
function = ir.Function(ctx.module, func_ty, name=custom_name)
x, y = function.args
block = function.append_basic_block(name="entry")
builder = ir.IRBuilder(block)

x_val = builder.load(x)
conv_x = pnlvm.helpers.convert_type(builder, x_val, y.type.pointee)
builder.store(conv_x, y)
builder.ret_void()

bin_f = pnlvm.LLVMBinaryFunction.get(custom_name)

# Get the argument numpy dtype
np_dt1, np_dt2 = (np.dtype(bin_f.np_arg_dtypes[i]) for i in (0, 1))

# instantiate value, result and reference
x = np.asarray(val).astype(np_dt1)
y = bin_f.np_buffer_for_arg(1)
ref = x.astype(np_dt2)

if mode == 'CPU':
bin_f(x, y)
else:
bin_f.cuda_wrap_call(x, y)

np.testing.assert_array_equal(y, ref)

0 comments on commit 310afb1

Please sign in to comment.