diff --git a/language/src/common/data.t b/language/src/common/data.t index 2b3abf9421..ffed1a77c2 100644 --- a/language/src/common/data.t +++ b/language/src/common/data.t @@ -315,6 +315,7 @@ data.vector.__index = data.vector function data.vector.__eq(a, b) assert(data.is_vector(a) and data.is_vector(b)) + -- FIXME: Looks like a bug, as it ought to make sure they are the same length. for i, v in ipairs(a) do if v ~= b[i] then return false @@ -428,6 +429,22 @@ function data.map:__newindex(k, v) self:put(k, v) end +function data.map:__eq(x) + for k, v in x:items() do + if not self:has(k) then + return false + end + end + + for k, v in self:items() do + if not (x:has(k) and x[k] == v) then + return false + end + end + + return true +end + function data.map:has(k) return self.__values_by_hash[data.hash(k)] end @@ -473,12 +490,21 @@ function data.map:copy() return self:map(function(k, v) return v end) end +function data.map:copy_recursive(n) + if n == 0 then + return self:copy() + end + + return self:map(function(k, v) return v:copy_recursive(n - 1) end) +end + function data.map:map(fn) local result = data.newmap() for k, v in self:items() do result:put(k, fn(k, v)) end - return result + rawset(result, "__default", self.__default) + return setmetatable(result, getmetatable(self)) end function data.map:map_list(fn) @@ -505,6 +531,7 @@ data.default_map = setmetatable( -- So, apparently for this to work you must re-list any metamethods. __tostring = data.map.__tostring, __newindex = data.map.__newindex, + __eq = data.map.__eq, }, { __index = data.map, }) diff --git a/language/src/regent/codegen.t b/language/src/regent/codegen.t index 657e76c9e3..37ec9616f9 100644 --- a/language/src/regent/codegen.t +++ b/language/src/regent/codegen.t @@ -97,8 +97,6 @@ function context:new_local_scope(divergence, must_epoch, must_epoch_point, break return setmetatable({ variant = self.variant, expected_return_type = self.expected_return_type, - privileges = self.privileges, - constraints = self.constraints, orderings = self.orderings, task = self.task, task_meta = self.task_meta, @@ -116,13 +114,11 @@ function context:new_local_scope(divergence, must_epoch, must_epoch_point, break }, context) end -function context:new_task_scope(expected_return_type, constraints, orderings, leaf, task_meta, task, ctx, runtime) +function context:new_task_scope(expected_return_type, orderings, leaf, task_meta, task, ctx, runtime) assert(expected_return_type and task and ctx and runtime) return setmetatable({ variant = self.variant, expected_return_type = expected_return_type, - privileges = data.newmap(), - constraints = constraints, orderings = orderings, task = task, task_meta = task_meta, @@ -763,7 +759,7 @@ local function unpack_region(cx, region_expr, region_type, static_region_type) assert(not cx:has_region(region_type)) local r = terralib.newsymbol(region_type, "r") - local lr = terralib.newsymbol(c.legion_logical_region_t, "lr") + local lr = terralib.newsymbol(c.legion_logical_region_t, "lr") local is = terralib.newsymbol(c.legion_index_space_t, "is") local it = false if cache_index_iterator then @@ -785,11 +781,18 @@ local function unpack_region(cx, region_expr, region_type, static_region_type) end end - local parent_region_type = std.search_constraint_predicate( - cx, region_type, {}, - function(cx, region) - return cx:has_region(region) - end) + local parent_region_type = region_type + if not cx:has_region(region_type) then + local region_cx = cx.task_meta:get_region_context() + if region_cx:has_region(region_type) then + for parent in region_cx.superregions[region_type]:items() do + if cx:has_region(parent) then + parent_region_type = parent + break + end + end + end + end if not parent_region_type then error("failed to find appropriate for region " .. tostring(region_type) .. " in unpack", 2) end @@ -2002,7 +2005,7 @@ function rawref:read(cx) return self:__ref(cx) end -function rawref:write(cx, value) +function rawref:write(cx, value, expr_type) local value_expr = value:read(cx) local ref_expr = self:__ref(cx) local cleanup = make_cleanup_item(cx, ref_expr.value, self.value_type.type) @@ -3493,7 +3496,7 @@ local lift_cast_to_futures = terralib.memoize( leaf = true, inner = false, idempotent = true, - replicable = false, + replicable = false, }, region_divergence = false, metadata = false, @@ -3509,7 +3512,7 @@ local lift_cast_to_futures = terralib.memoize( task:set_privileges(node.privileges) task:set_conditions({}) task:set_param_constraints(node.constraints) - task:set_constraints({}) + task:set_region_context(std.region_context()) task:set_region_universe(data.newmap()) return codegen.entry(node) end) @@ -4751,7 +4754,7 @@ function codegen.expr_cross_product_array(cx, node) var color_space = c.legion_index_partition_get_color_space( [cx.runtime], [lhs.value].impl.index_partition) - var color_domain = + var color_domain = c.legion_index_space_get_domain([cx.runtime], color_space) std.assert(color_domain.dim == 1, "color domain should be 1D") var start_color = color_domain.rect_data[0] @@ -7071,7 +7074,7 @@ local lift_unary_op_to_futures = terralib.memoize( task:set_privileges(node.privileges) task:set_conditions({}) task:set_param_constraints(node.constraints) - task:set_constraints({}) + task:set_region_context(std.region_context()) task:set_region_universe(data.newmap()) return codegen.entry(node) end) @@ -7170,7 +7173,7 @@ local lift_binary_op_to_futures = terralib.memoize( task:set_privileges(node.privileges) task:set_conditions({}) task:set_param_constraints(node.constraints) - task:set_constraints({}) + task:set_region_context(std.region_context()) task:set_region_universe(data.newmap()) return codegen.entry(node) end) @@ -8473,7 +8476,7 @@ function codegen.stat_for_list(cx, node) [rect_it_step]([rect_it]) end ::[break_label]:: - [rect_it_destroy]([rect_it]) + [rect_it_destroy]([rect_it]) [postamble] end @@ -9238,6 +9241,104 @@ function codegen.stat_var(cx, node) else decls:insert(quote var [lhs] end) end + + if not node.value and std.is_region(node.symbol:gettype()) and + not cx:has_ispace(node.symbol:gettype():ispace()) then + local region_type = node.symbol:gettype() + local ispace_type = region_type:ispace() + local index_type = ispace_type.index_type + + local field_paths, field_types = std.flatten_struct_fields(region_type:fspace()) + local privileges, privilege_field_paths, privilege_field_types = + cx.task_meta:get_region_context():find_codegen_privileges(region_type) + + if #privileges == 0 then + -- If there are no privileges then the region cannot be used anyway. + return quote [actions]; [decls] end + end + + local privileges_by_field_path = std.group_task_privileges_by_field_path( + privileges, privilege_field_paths) + + local field_id_array_buffer = terralib.newsymbol(&c.legion_field_id_t[#field_paths], "field_ids") + local field_id_array = `(@[field_id_array_buffer]) + local field_ids_by_field_path = data.dict( + data.zip(field_paths:map(data.hash), data.mapi(function(field_i, _) return `([field_id_array][field_i - 1]) end, field_paths))) + + decls:insert(quote + var [field_id_array_buffer] = [&c.legion_field_id_t[#field_paths]](c.malloc([#field_paths] * [terralib.sizeof(c.legion_field_id_t)])) + end) + + local physical_region_i = 0 + local physical_regions_by_field_path = {} + local base_pointers_by_field_path = {} + local strides_by_field_path = {} + for i, field_paths in ipairs(privilege_field_paths) do + local privilege = privileges[i] + local field_types = privilege_field_types[i] + local physical_region = terralib.newsymbol( + c.legion_physical_region_t, + "pr_" .. tostring(physical_region_i)) + physical_region_i = physical_region_i + 1 + decls:insert(quote var [physical_region] end) + + assert(#field_paths == 1) + local field_path = field_paths[1] + local field_type = field_types[1] + + physical_regions_by_field_path[field_path:hash()] = physical_region + + local dim = data.max(index_type.dim, 1) + local strides = terralib.newlist() + for j = 1, dim do + strides:insert(terralib.newsymbol(c.size_t, "stride" .. tostring(j))) + decls:insert(quote var [ strides[j] ] end) + end + + local base_pointer + if std.is_regent_array(field_type) then + base_pointer = terralib.newsymbol((&elem_type)[field_type.N], "base_pointer") + else + base_pointer = terralib.newsymbol(&field_type, "base_pointer") + end + decls:insert(quote var [base_pointer] end) + + base_pointers_by_field_path[field_path:hash()] = base_pointer + strides_by_field_path[field_path:hash()] = strides + + assert(privileges_by_field_path[field_path:hash()] ~= "none") + end + + local is = terralib.newsymbol(c.legion_index_space_t, "is") + local it = false + if cache_index_iterator then + it = terralib.newsymbol(c.legion_terra_cached_index_iterator_t, "it") + decls:insert(quote var [it] end) + end + local _, domain, bounds = index_space_bounds(cx, is, ispace_type) + if bounds then + decls:insert(quote var [bounds] end) + end + decls:insert(quote + var [domain] + var [is] + end) + + cx:add_ispace_root(ispace_type, is, it, domain, bounds) + + cx:add_region_root(region_type, node.symbol:getsymbol(), + field_paths, + privilege_field_paths, + privileges_by_field_path, + data.dict(data.zip(field_paths:map(data.hash), field_types)), + field_ids_by_field_path, + field_id_array, + data.dict(data.zip(field_paths:map(data.hash), field_types:map(function(_) return false end))), + physical_regions_by_field_path, + base_pointers_by_field_path, + strides_by_field_path) + end + return quote [actions]; [decls] end end @@ -9339,6 +9440,62 @@ function codegen.stat_assignment(cx, node) expr.just(quote end, rhs_expr.value), std.as_read(node.rhs.expr_type)) + local lhs_type = std.as_read(node.lhs.expr_type) + local rhs_type = std.as_read(node.rhs.expr_type) + if std.is_region(lhs_type) and std.is_region(rhs_type) and lhs_type ~= rhs_type then + -- Move associated region information as well + + local lhs_region = cx:region(lhs_type) + local rhs_region = cx:region(rhs_type) + local lhs_ispace = cx:ispace(lhs_type:ispace()) + local rhs_ispace = cx:ispace(rhs_type:ispace()) + + for i, field in ipairs(lhs_region.field_paths) do + actions:insert(quote + [lhs_region.field_id_array][ [i - 1] ] = [rhs_region.field_ids[data.hash(field)]] + end) + end + + for i, field_paths in ipairs(lhs_region.privilege_field_paths) do + assert(#field_paths == 1) + local field_path = field_paths[1] + local field_path_hash = field_path:hash() + + actions:insert(quote + [lhs_region.physical_regions[field_path_hash]] = [rhs_region.physical_regions[field_path_hash]] + [lhs_region.base_pointers[field_path_hash]] = [rhs_region.base_pointers[field_path_hash]] + end) + + local dim = data.max(lhs_type:ispace().dim, 1) + for j = 1, dim do + actions:insert(quote + [ lhs_region.strides[field_path_hash][j] ] = [ rhs_region.strides[field_path_hash][j] ] + end) + end + end + + actions:insert(quote + [lhs_region.logical_region].impl = [rhs_region.logical_region].impl + [lhs_ispace.index_space] = [rhs_ispace.index_space] + [lhs_ispace.domain] = [rhs_ispace.domain] + end) + + if cache_index_iterator then + actions:insert(quote + [lhs_ispace.index_iterator] = [rhs_ispace.index_iterator] + end) + end + + if lhs_ispace.bounds then + assert(rhs_ispace.bounds) + actions:insert(quote + [lhs_ispace.bounds] = [rhs_ispace.bounds] + end) + end + + return quote [actions] end + end + actions:insert(lhs:write(cx, rhs, node.lhs.expr_type).actions) return quote [actions] end @@ -9533,7 +9690,7 @@ local make_dummy_task = terralib.memoize( task:set_privileges(node.privileges) task:set_conditions({}) task:set_param_constraints(node.constraints) - task:set_constraints({}) + task:set_region_context(std.region_context()) task:set_region_universe(data.newmap()) return codegen.entry(node) end) @@ -10166,6 +10323,24 @@ local function setup_regent_calling_convention_metadata(node, task) task:set_field_id_param_labels(param_field_id_labels) end +local function find_privileged_parent(task, region_type) + local privileges = task:get_privileges() + local region_cx = task:get_region_context() + if not region_cx:has_region(region_type) then + return + end + + for _, privilege_list in ipairs(privileges) do + for _, privilege in ipairs(privilege_list) do + local parent = privilege.region:gettype() + if region_cx:has_region(parent) and parent ~= region_type and + region_cx.superregions[region_type][parent] then + return parent + end + end + end +end + function codegen.top_task(cx, node) log_codegen:info("%s", "Starting codegen for task " .. tostring(node.name)) @@ -10253,23 +10428,10 @@ function codegen.top_task(cx, node) end local cx = cx:new_task_scope(return_type, - task:get_constraints(), orderings, variant:get_config_options().leaf, task, c_task, c_context, c_runtime) - -- FIXME: This code should be deduplicated with type_check, no - -- reason to do it twice.... - for _, privilege_list in ipairs(task.privileges) do - for _, privilege in ipairs(privilege_list) do - local privilege_type = privilege.privilege - local region = privilege.region - local field_path = privilege.field_path - assert(std.type_supports_privileges(region:gettype())) - std.add_privilege(cx, privilege_type, region:gettype(), field_path) - end - end - -- Unpack the by-value parameters to the task. local task_setup = terralib.newlist() -- FIXME: This is an obnoxious hack to avoid inline mappings in shard tasks. @@ -10373,10 +10535,8 @@ function codegen.top_task(cx, node) local physical_region_actions = terralib.newlist() local base_pointers = terralib.newlist() local base_pointers_by_field_path = {} - local strides = terralib.newlist() local strides_by_field_path = {} for i, field_paths in ipairs(privilege_field_paths) do - local privilege = privileges[i] local field_types = privilege_field_types[i] local flag = flags[i] local physical_region = terralib.newsymbol( @@ -10438,7 +10598,6 @@ function codegen.top_task(cx, node) for i, field_paths in ipairs(privilege_field_paths) do local field_types = privilege_field_types[i] - local privilege = privileges[i] local physical_region = physical_regions[i] local physical_region_index = physical_regions_index[i] @@ -10473,14 +10632,7 @@ function codegen.top_task(cx, node) local parent if not has_privileges then - local parent_has_privileges = false - for _, field_path in ipairs(field_paths) do - for i = #field_path, 0, -1 do - parent = std.search_any_privilege(cx, region_type, field_path:slice(1, i), {}) - if parent then break end - end - if parent then break end - end + parent = find_privileged_parent(task, region_type) end if parent and cx:has_region(parent) then @@ -10695,7 +10847,7 @@ function codegen.top(cx, node) end, node) end - if node.annotations.cuda:is(ast.annotation.Demand) then + if node.annotations.cuda:is(ast.annotation.Demand) then if not cudahelper.check_cuda_available() then report.warn(node, "ignoring demand pragma at " .. node.span.source .. diff --git a/language/src/regent/dataflow.t b/language/src/regent/dataflow.t new file mode 100644 index 0000000000..9928015a89 --- /dev/null +++ b/language/src/regent/dataflow.t @@ -0,0 +1,302 @@ +-- Copyright 2019 Stanford University, NVIDIA Corporation +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- Data-flow analysis framework + +local ast = require("regent/ast") +local std = require("regent/std") + +local dataflow = {} + +local forward_metatable = {} +local forward = setmetatable({}, forward_metatable) +forward.__index = forward +dataflow.forward = forward + +-- TODO: Maybe make meet be in place + +-- Forward data-flow analysis. The following methods are needed for the +-- semilattice: +-- +, meet operation +-- ==, check for equality of states (for determining convergence) +-- :copy, duplicate state +-- +-- Additionally, it needs the following four methods to implement the transfer +-- function. +-- +-- :enter_block +-- :exit_block +-- :statement (does not need to handle control flow) +-- +-- After constructing an instance, which should only be used inside a single +-- function, use :block to run the analysis on a block of code. + +function forward_metatable:__call(exit_state) + local out = { + block_entries = {}, + exit_stack = { exit_state, len = 1 }, + } + + return setmetatable(out, forward) +end + +function forward:exit_state() + return self.exit_stack[1] +end + +function forward:set_exit_state(value) + self.exit_stack[1] = value +end + +local function push(stack, state) + stack.len = stack.len + 1 + stack[stack.len] = state +end + +local function pop(stack) + local top = stack[stack.len] + stack[stack.len] = nil + stack.len = stack.len - 1 + return top +end + +function forward:block(block, entry, dont_exit) + if not entry then + return nil + end + + if self.block_entries[block] == entry then + -- No need to redo this execution path. + return nil + end + self.block_entries[block] = entry + + entry:enter_block(block) + for _, stat in ipairs(block.stats) do + entry = self:stat(stat, entry) + if not entry then + return nil + end + end + if not dont_exit then + entry:exit_block(block) + end + + return entry +end + +function forward:stat_if(node, entry) + local cond = self:unused_expr(node.cond, entry) + local exit = self:block(node.then_block, dataflow.copy(cond)) + + for _, else_if in ipairs(node.elseif_blocks) do + cond = self:unused_expr(else_if.cond, cond) + local block = self:block(else_if.block, dataflow.copy(cond)) + exit = dataflow.meet(exit, block) + end + + local else_block = self:block(node.else_block, cond) + exit = dataflow.meet(exit, else_block) + return exit +end + +function forward:stat_while(node, entry) + local exit + repeat + local cond = self:unused_expr(node.cond, dataflow.copy(entry)) + exit = dataflow.meet_copy(exit, cond) + push(self.exit_stack, exit) + local block = self:block(node.block, cond) + exit = pop(self.exit_stack) + + local updated + entry, updated = dataflow.meet_updated(entry, block) + until not updated + + return exit +end + +function forward:stat_repeat(node, entry) + local exit + repeat + push(self.exit_stack, exit) + local block = self:block(node.block, dataflow.copy(entry), true) + local cond = self:unused_expr(node.until_cond, block) + if cond then + cond:exit_block(node.block) + end + exit = dataflow.meet(pop(self.exit_stack), cond) + + local updated + entry, updated = dataflow.meet_updated(entry, cond) + until not updated + + return exit +end + +function forward:stat_for_num(node, entry) + for _, value in ipairs(node.values) do + entry = self:unused_expr(value, entry) + end + + local exit + repeat + exit = dataflow.meet(exit, entry) + push(self.exit_stack, exit) + local block = self:block(node.block, dataflow.copy(entry)) + exit = pop(self.exit_stack) + + local updated + entry, updated = dataflow.meet_updated(entry, block) + until not updated + + return exit +end + +function forward:stat_for_list(node, entry) + entry = self:unused_expr(node.value, entry) + + local exit + repeat + exit = dataflow.meet(exit, entry) + push(self.exit_stack, exit) + local block = self:block(node.block, dataflow.copy(entry)) + exit = pop(self.exit_stack) + + local updated + entry, updated = dataflow.meet_updated(entry, block) + until not updated + + return exit +end + +function forward:stat_block(node, entry) + return self:block(node.block, entry) +end + +function forward:stat_break(node, entry) + entry:statement(node) + push(self.exit_stack, dataflow.meet_copy(pop(self.exit_stack), entry)) + + -- nil represents unreachable code. + return nil +end + +function forward:stat_return(node, entry) + entry:statement(node) + self.exit_stack[1] = dataflow.meet_copy(self.exit_stack[1], entry) + + return nil +end + +function forward:stat_other(node, entry) + entry:statement(node) + return entry +end + +local function unreachable() + assert(false, "unreachable") +end + +local forward_stat_node = { + [ast.typed.stat.If] = forward.stat_if, + [ast.typed.stat.While] = forward.stat_while, + [ast.typed.stat.ForNum] = forward.stat_for_num, + [ast.typed.stat.ForNumVectorized] = forward.stat_for_num, + [ast.typed.stat.IndexLaunchNum] = forward.stat_for_num, + [ast.typed.stat.ForList] = forward.stat_for_list, + [ast.typed.stat.ForListVectorized] = forward.stat_for_list, + [ast.typed.stat.IndexLaunchList] = forward.stat_for_list, + [ast.typed.stat.Repeat] = forward.stat_repeat, + [ast.typed.stat.Block] = forward.stat_block, + [ast.typed.stat.MustEpoch] = forward.stat_block, + [ast.typed.stat.ParallelizeWith] = forward.stat_block, + [ast.typed.stat.Break] = forward.stat_break, + [ast.typed.stat.Return] = forward.stat_return, + + [ast.typed.stat.Var] = forward.stat_other, + [ast.typed.stat.VarUnpack] = forward.stat_other, + [ast.typed.stat.Expr] = forward.stat_other, + [ast.typed.stat.Assignment] = forward.stat_other, + [ast.typed.stat.Reduce] = forward.stat_other, + [ast.typed.stat.RawDelete] = forward.stat_other, + [ast.typed.stat.Fence] = forward.stat_other, + [ast.typed.stat.ParallelPrefix] = forward.stat_other, + [ast.typed.stat.BeginTrace] = forward.stat_other, + [ast.typed.stat.EndTrace] = forward.stat_other, + [ast.typed.stat.MapRegions] = forward.stat_other, + [ast.typed.stat.UnmapRegions] = forward.stat_other, + + [ast.typed.stat.Elseif] = unreachable, + [ast.typed.stat.Internal] = unreachable, +} + +local forward_stat = ast.make_single_dispatch(forward_stat_node, {ast.typed.stat}) +function forward:stat(...) + return forward_stat(self)(...) +end + +function forward:unused_expr(node, entry) + if not entry then + return nil + end + + entry:statement( + ast.typed.stat.Expr { + expr = node, + span = node.span, + annotations = node.annotations, + }) + return entry +end + +function dataflow.meet(x, y) + if not y then + return x + elseif not x then + return y + else + return x + y + end +end + +function dataflow.meet_copy(x, y) + if not y then + return x + elseif not x then + return y:copy() + else + return x + y + end +end + +function dataflow.meet_updated(prev, x) + if not prev then + return x, prev ~= x + elseif not x then + return prev, false + else + local new = prev + x + return new, prev ~= new + end +end + +function dataflow.copy(x) + if not x then + return nil + end + return x:copy() +end + +return dataflow diff --git a/language/src/regent/desugar.t b/language/src/regent/desugar.t index 5f97e9d4e7..84e71a08ae 100644 --- a/language/src/regent/desugar.t +++ b/language/src/regent/desugar.t @@ -58,7 +58,7 @@ local function desugar_image_by_task(cx, node) std.newsymbol(partition_type:colors().index_type(colors_symbol)) local colors_expr = ast_util.mk_expr_colors_access(partition) local subregion_type = partition_type:subregion_dynamic() - std.add_constraint(cx, subregion_type, partition_type, std.subregion, false) + cx.constraints:add_constraint(subregion_type, partition_type, std.subregion) local subregion_expr = ast_util.mk_expr_index_access(partition, @@ -85,7 +85,6 @@ local function desugar_image_by_task(cx, node) ast_util.mk_expr_partition(image_partition_type, ast_util.mk_expr_id(colors_symbol), coloring_expr))) - std.add_constraint(cx, image_partition_type, parent_type, std.subregion, false) stats:insert( ast_util.mk_stat_expr( @@ -175,7 +174,7 @@ function desugar.block(cx, block) end function desugar.top_task(node) - local cx = { constraints = node.prototype:get_constraints() } + local cx = { constraints = node.prototype:get_region_context() } local body = node.body and desugar.block(cx, node.body) or false return node { body = body } diff --git a/language/src/regent/optimize_index_launches.t b/language/src/regent/optimize_index_launches.t index 6138b4ef6d..d5e5394eaa 100644 --- a/language/src/regent/optimize_index_launches.t +++ b/language/src/regent/optimize_index_launches.t @@ -39,16 +39,16 @@ end function context:new_local_scope() local cx = { - constraints = self.constraints, + region_cx = self.region_cx, loop_index = false, loop_variables = {}, } return setmetatable(cx, context) end -function context:new_task_scope(constraints) +function context:new_task_scope(region_cx) local cx = { - constraints = constraints, + region_cx = region_cx, } return setmetatable(cx, context) end @@ -149,11 +149,12 @@ local function analyze_noninterference_previous( region_type, other_region_type, std.disjointness) - local exclude_variables = { [cx.loop_index] = true } if not ( not std.type_maybe_eq(region_type.fspace_type, other_region_type.fspace_type) or - std.check_constraint(cx, constraint, exclude_variables) or + -- TODO Need to detect that the constraint holds in every iteration, not + -- just within one. + -- cx.region_cx:check_constraint(constraint) or check_privilege_noninterference(cx, task, arg, other_arg, mapping)) -- Index non-interference is handled at the type checker level -- and is captured in the constraints. @@ -911,7 +912,7 @@ end function optimize_index_launches.top_task(cx, node) if not node.body then return node end - local cx = cx:new_task_scope(node.prototype:get_constraints()) + local cx = cx:new_task_scope(node.prototype:get_region_context()) local body = optimize_index_launches.block(cx, node.body) return node { body = body } diff --git a/language/src/regent/optimize_mapping.t b/language/src/regent/optimize_mapping.t index 59e951beeb..c433463a77 100644 --- a/language/src/regent/optimize_mapping.t +++ b/language/src/regent/optimize_mapping.t @@ -34,10 +34,10 @@ function context:__newindex (field, value) error ("context has no field '" .. field .. "' (in assignment)", 2) end -function context:new_task_scope(constraints, region_universe) - assert(constraints and region_universe) +function context:new_task_scope(region_cx, region_universe) + assert(region_cx and region_universe) local cx = { - constraints = constraints, + region_cx = region_cx, region_universe = region_universe, } return setmetatable(cx, context) @@ -71,7 +71,7 @@ local function uses(cx, region_type, polarity) other_region_type, std.disjointness) if std.type_maybe_eq(region_type:fspace(), other_region_type:fspace()) and - not std.check_constraint(cx, constraint) + not cx.region_cx:check_constraint(constraint) then usage[other_region_type] = polarity end @@ -612,7 +612,7 @@ function optimize_mapping.top_task(cx, node) if not node.body then return node end local cx = cx:new_task_scope( - node.prototype:get_constraints(), + node.prototype:get_region_context(), node.prototype:get_region_universe()) local initial_usage = task_initial_usage(cx, node.privileges) local annotated_body = optimize_mapping.block(cx, node.body) diff --git a/language/src/regent/parallelize_tasks.t b/language/src/regent/parallelize_tasks.t index 8ca4516902..68af8051de 100644 --- a/language/src/regent/parallelize_tasks.t +++ b/language/src/regent/parallelize_tasks.t @@ -906,7 +906,7 @@ local SUBSET = setmetatable({}, { __tostring = function(self) return "SUBSET" en caller_context.__index = caller_context -function caller_context.new(constraints) +function caller_context.new(region_cx) local param = parallel_param.new({ dop = std.config["parallelize-dop"] }) local cx = { __param_stack = terralib.newlist { param }, @@ -932,7 +932,7 @@ function caller_context.new(constraints) -- keep parent-child relationship in region tree __parent_region = {}, -- the constraint graph to update for later stages - constraints = constraints, + region_cx = region_cx, -- symbols for caching region metadata __region_metadata_symbols = {}, -- parallelization parameter -> max dimension @@ -1234,8 +1234,8 @@ function caller_context:update_constraint(expr) local partition = value_type:partition() local parent = value_type:parent_region() local subregion = expr.expr_type - std.add_constraint(self, partition, parent, std.subregion, false) - std.add_constraint(self, subregion, partition, std.subregion, false) + self.region_cx:add_constraint(partition, parent, std.subregion) + self.region_cx:add_constraint(subregion, partition, std.subregion) end -- ##################################### @@ -2263,7 +2263,7 @@ function parallelize_task_calls.top_task(global_cx, node) end -- Add declartions for the variables that contain region metadata (e.g. bounds) - local caller_cx = caller_context.new(node.prototype:get_constraints()) + local caller_cx = caller_context.new(node.prototype:get_region_context()) local body = ast.flatmap_node_continuation( add_metadata_declarations(caller_cx), node.body) @@ -3308,7 +3308,7 @@ function parallelize_tasks.top_task(global_cx, node) task:set_flags(node.flags) task:set_conditions(node.conditions) task:set_param_constraints(node.prototype:get_param_constraints()) - task:set_constraints(node.prototype:get_constraints()) + task:set_region_context(node.prototype:get_region_context()) task:set_region_universe(region_universe) local parallelized = parallelize_tasks.top_task_body(task_cx, normalized) diff --git a/language/src/regent/std.t b/language/src/regent/std.t index 629377c036..b4555d6098 100644 --- a/language/src/regent/std.t +++ b/language/src/regent/std.t @@ -162,129 +162,96 @@ end -- ## Privilege and Constraint Helpers -- ################# -function std.add_privilege(cx, privilege, region, field_path) - assert(privilege:is(ast.privilege_kind)) - assert(std.type_supports_privileges(region)) - assert(data.is_tuple(field_path)) - if not cx.privileges[privilege] then - cx.privileges[privilege] = data.newmap() - end - if not cx.privileges[privilege][region] then - cx.privileges[privilege][region] = data.newmap() - end - cx.privileges[privilege][region][field_path] = true +function std.add_privilege(cx, node, privilege, region, field_path) + cx.dataflow_actions[node]:insert(function(region_cx) + region_cx:add_privilege(privilege, region, field_path) + end) end -function std.copy_privileges(cx, from_region, to_region) +function std.copy_privileges(cx, node, from_region, to_region) assert(std.type_supports_privileges(from_region)) assert(std.type_supports_privileges(to_region)) - local privileges_to_copy = terralib.newlist() - for privilege, privilege_regions in cx.privileges:items() do - local privilege_fields = privilege_regions[from_region] - if privilege_fields then - for _, field_path in privilege_fields:keys() do - privileges_to_copy:insert({privilege, to_region, field_path}) + cx.dataflow_actions[node]:insert(function(region_cx) + local privileges_to_copy = terralib.newlist() + for privilege, privileged_regions in region_cx.privileges:items() do + if privileged_regions:has(from_region) then + for _, field_path in privileged_regions[from_region]:keys() do + privileges_to_copy:insert({privilege, to_region, field_path}) + end end end - end - for _, privilege in ipairs(privileges_to_copy) do - std.add_privilege(cx, unpack(privilege)) - end -end -function std.add_constraint(cx, lhs, rhs, op, symmetric) - if std.is_cross_product(lhs) then lhs = lhs:partition() end - if std.is_cross_product(rhs) then rhs = rhs:partition() end - assert(std.type_supports_constraints(lhs)) - assert(std.type_supports_constraints(rhs)) - cx.constraints[op][lhs][rhs] = true - if symmetric then - std.add_constraint(cx, rhs, lhs, op, false) - end + for _, privilege in ipairs(privileges_to_copy) do + region_cx:add_privilege(unpack(privilege)) + end + end) end -function std.add_constraints(cx, constraints) - for _, constraint in ipairs(constraints) do - local lhs, rhs, op = constraint.lhs, constraint.rhs, constraint.op - local symmetric = op == std.disjointness - std.add_constraint(cx, lhs:gettype(), rhs:gettype(), op, symmetric) - end +function std.add_constraint(cx, node, lhs, rhs, op) + cx.dataflow_actions[node]:insert(function(region_cx) + region_cx:add_constraint(lhs, rhs, op) + end) end -function std.search_constraint_predicate(cx, region, visited, predicate) - if predicate(cx, region) then - return region - end +function std.add_constraints(cx, node, constraints) + cx.dataflow_actions[node]:insert(function(region_cx) + for _, constraint in ipairs(constraints) do + local lhs, rhs, op = constraint.lhs, constraint.rhs, constraint.op + region_cx:add_constraint(lhs, rhs, op) + end + end) +end - if visited[region] then - return nil +function std.require_privilege(cx, node, msg, privilege, region, field_path, region_symbol, action_list) + if not action_list then + action_list = cx.region_checks end - visited[region] = true - if cx.constraints:has(std.subregion) and cx.constraints[std.subregion]:has(region) then - for subregion, _ in cx.constraints[std.subregion][region]:items() do - local result = std.search_constraint_predicate( - cx, subregion, visited, predicate) - if result then return result end + action_list:insert(function(region_cx) + if not region_cx:check_privilege(privilege, region, field_path) then + if not region_symbol then + region_symbol = std.newsymbol(region) + end + report.error( + node, "invalid privileges in " .. msg .. ": " .. tostring(privilege) .. + "(" .. (data.newtuple(region_symbol) .. field_path):mkstring(".") .. ")") end - end - return nil + end) end -function std.search_privilege(cx, privilege, region, field_path, visited) - assert(privilege:is(ast.privilege_kind)) - assert(std.type_supports_privileges(region)) - assert(data.is_tuple(field_path)) - return std.search_constraint_predicate( - cx, region, visited, - function(cx, region) - return cx.privileges[privilege] and - cx.privileges[privilege][region] and - cx.privileges[privilege][region][field_path] - end) -end +function std.require_constraint(cx, node, msg, constraint, action_list) + if not action_list then + action_list = cx.region_checks + end -function std.check_privilege(cx, privilege, region, field_path) - assert(privilege:is(ast.privilege_kind)) - assert(std.type_supports_privileges(region)) - assert(data.is_tuple(field_path)) - for i = #field_path, 0, -1 do - if std.search_privilege(cx, privilege, region, field_path:slice(1, i), {}) then - return true + cx.region_checks:insert(function(region_cx) + if not region_cx:check_constraint(constraint) then + report.error(node, "invalid " .. msg .. " missing constraint " .. + tostring(constraint.lhs) .. " " .. tostring(constraint.op) .. + " " .. tostring(constraint.rhs)) end - if std.is_reduce(privilege) then - if std.search_privilege(cx, std.reads, region, field_path:slice(1, i), {}) and - std.search_privilege(cx, std.writes, region, field_path:slice(1, i), {}) - then - return true - end - end - end - return false + end) end -function std.search_any_privilege(cx, region, field_path, visited) - assert(std.is_region(region) and data.is_tuple(field_path)) - return std.search_constraint_predicate( - cx, region, visited, - function(cx, region) - for _, regions in cx.privileges:items() do - if regions[region] and regions[region][field_path] then - return region - end - end - return false - end) -end +function std.require_constraints(cx, node, msg, constraints, mapping) + if not mapping then + mapping = {} + end -function std.check_any_privilege(cx, region, field_path) - assert(std.is_region(region) and data.is_tuple(field_path)) - for i = #field_path, 0, -1 do - if std.search_any_privilege(cx, region, field_path:slice(1, i), {}) then - return true + cx.region_checks:insert(function(region_cx) + for _, constraint in ipairs(constraints) do + local constraint = { + lhs = mapping[constraint.lhs] or constraint.lhs, + rhs = mapping[constraint.rhs] or constraint.rhs, + op = constraint.op, + } + if not region_cx:check_constraint(constraint) then + report.error(node, "invalid " .. msg .. " missing constraint " .. + tostring(constraint.lhs) .. " " .. tostring(constraint.op) .. + " " .. tostring(constraint.rhs)) + end end - end - return false + end) end local function analyze_uses_variables_node(variables) @@ -380,43 +347,324 @@ local function uses_variables(region, variables) return false end -function std.search_constraint(cx, region, constraint, exclude_variables, - visited, reflexive, symmetric) - return std.search_constraint_predicate( - cx, region, visited, - function(cx, region) - if reflexive and region == constraint.rhs then - return true +-- This dataflow keeps track of constraints and privileges for all of the +-- regions in scope at a point in the code. It enforces a set of rules on these +-- relations: +-- +-- 1. Subregion reflexivity: r <= r +-- 2. Subregion transitivity: r <= s and s <= t implies r <= t +-- 3. Disjointness symmetry: r * s iff s * r +-- 4. Disjointness of subregions: r * s and x <= r and y <= s implies x * y +-- 5. Privileges of subregions: priv(r.field) and s <= r implies priv(s.field) +-- (note that field could be the empty string) +-- 6. Reads and writes implies reduces. +-- 7. Privileges of subfields: priv(r.field) implies priv(r.field.subfield) +-- 8. If a privilege holds on all subfields, the field has it as well. +-- +-- Rules 1-5 are applied immediately, while the others are applied on lookup. +-- +-- Meet operations output all constraints and privileges that apply in both +-- cases. +-- +-- The constraints and privileges are stored in the following fields. +-- privileges[priv][region][field]: priv(region.field) +-- disjoint[lhs][rhs]: lhs * rhs +-- superregions[child][parent]: child <= parent +-- subregions[parent][child]: child <= parent +-- +-- Rules +-- (?) Distinct inputs to function with writes permission are disjoint +-- Return value is always disjoint from everything disjoint from all its params. +-- Temporary: return value must be either a subregion is disjoint. +-- If a permission applies to a region then it applies to a field. +-- (?) If a permission applies to all fields of a region then it applies to a +-- region. Not currently true. +-- Which partition a region is associated with is kept track of separately, as +-- partitions cannot be in the constraints and it wouldn't make sense for them +-- to be anyway. + +local region_context_meta = {} +local region_context = setmetatable({}, region_context_meta) +region_context.__index = region_context +std.region_context = region_context + +function region_context_meta:__call() + local out = { + privileges = data.new_recursive_map(2), + disjoint = data.new_recursive_map(1), + superregions = data.new_recursive_map(1), + subregions = data.new_recursive_map(1), + } + + return setmetatable(out, region_context) +end + +function region_context:copy() + local out = { + privileges = self.privileges:copy_recursive(2), + disjoint = self.disjoint:copy_recursive(1), + superregions = self.superregions:copy_recursive(1), + subregions = self.subregions:copy_recursive(1), + } + + return setmetatable(out, region_context) +end + +function region_context:__eq(x) + return self.privileges == x.privileges and + self.disjoint == x.disjoint and + self.superregions == x.superregions and + self.subregions == x.subregions +end + +function region_context:add_region(region) + assert(std.type_supports_constraints(region)) + + -- 1. Subregion reflexivity + self.subregions[region][region] = true + self.superregions[region][region] = true +end + +function region_context:has_region(region) + return self.subregions:has(region) +end + +function region_context:remove_region(region) + if not self:has_region(region) then return end + + for privilege, privileged_regions in self.privileges:items() do + privileged_regions[region] = nil + end + + if self.disjoint:has(region) then + for rhs in self.disjoint[region]:items() do + self.disjoint[rhs][region] = nil + end + self.disjoint[region] = nil + end + + for parent in self.superregions[region]:items() do + self.subregions[parent][region] = nil + end + for child in self.subregions[region]:items() do + self.superregions[child][region] = nil + end + self.subregions[region] = nil + self.superregions[region] = nil +end + +-- Same as +-- self:add_region(new) +-- self:add_subregion_constraint(new, old) +-- self:add_subregion_constraint(old, new) +-- but more efficient. +function region_context:dup_region(old, new) + if old == new or not self:has_region(old) then + return + end + + assert(not self:has_region(new)) + assert(std.type_supports_constraints(new)) + + for privilege, privileged_regions in self.privileges:items() do + privileged_regions[new] = privileged_regions[old]:copy() + end + + if self.disjoint:has(old) then + self.disjoint[new] = self.disjoint[old]:copy() + for rhs in self.disjoint[old]:items() do + self.disjoint[rhs][new] = self.disjoint[rhs][old] + end + end + + self.subregions[new] = self.subregions[old]:copy() + self.superregions[new] = self.superregions[old]:copy() + for parent in self.superregions[old]:items() do + self.subregions[parent][new] = self.subregions[parent][old] + end + for child in self.subregions[old]:items() do + self.superregions[child][new] = self.superregions[child][old] + end + + self.subregions[new][new] = true + self.superregions[new][new] = true +end + +function region_context:add_privilege(privilege, region, field_path) + assert(privilege:is(ast.privilege_kind)) + assert(data.is_tuple(field_path)) + + if not self:has_region(region) then self:add_region(region) end + + for child in self.subregions[region]:items() do + self.privileges[privilege][region][field_path] = true + end +end + +function region_context:add_subregion_constraint(child, parent) + if not self:has_region(child) then self:add_region(child) end + if not self:has_region(parent) then self:add_region(parent) end + + if self.superregions[child][parent] then + return + end + + -- 2. Subregion transitivity + for child in self.subregions[child]:items() do + for parent in self.superregions[parent]:items() do + self.subregions[parent][child] = true + self.superregions[child][parent] = true + + -- 4. Disjointness of subregions + if self.disjoint:has(parent) then + for s in self.disjoint[parent]:items() do + self.disjoint[child][s] = true + self.disjoint[s][child] = true -- 3. Disjointness symmetry + end end - if cx.constraints:has(constraint.op) and - cx.constraints[constraint.op]:has(region) and - cx.constraints[constraint.op][region][constraint.rhs] and - not (exclude_variables and - uses_variables(region, exclude_variables) and - uses_variables(constraint.rhs, exclude_variables)) - then - return true + -- 5. Privileges of subregions + for privilege, privileged_regions in self.privileges:items() do + if privileged_regions:has(parent) then + for field in privileged_regions[parent]:items() do + privileged_regions[child][field] = true + end + end end + end + end +end - if symmetric then - local constraint = { - lhs = constraint.rhs, - rhs = region, - op = constraint.op, - } - if std.search_constraint(cx, constraint.lhs, constraint, - exclude_variables, {}, reflexive, false) - then - return true +function region_context:add_disjointness_constraint(lhs, rhs) + if not self:has_region(lhs) then self:add_region(lhs) end + if not self:has_region(rhs) then self:add_region(rhs) end + + if self.disjoint[lhs][rhs] then + return + end + + -- 4. Disjointness of subregions + for lhs in self.subregions[lhs]:items() do + for rhs in self.subregions[rhs]:items() do + self.disjoint[lhs][rhs] = true + self.disjoint[rhs][lhs] = true -- 3. Disjointness symmetry + end + end +end + +function region_context:add_constraint(lhs, rhs, op) + if std.is_symbol(lhs) then lhs = lhs:gettype() end + if std.is_symbol(rhs) then rhs = rhs:gettype() end + if std.is_cross_product(lhs) then lhs = lhs:partition() end + if std.is_cross_product(rhs) then rhs = rhs:partition() end + + if op == std.disjointness then + self:add_disjointness_constraint(lhs, rhs) + else + assert(op == std.subregion) + self:add_subregion_constraint(lhs, rhs) + end +end + +function region_context:meet_privilege(x, privilege, region, path) + local reads_writes = + privilege == std.reads and self:check_privilege_above(std.writes, region, path) or + privilege == std.writes and self:check_privilege_above(std.reads, region, path) + + return x:check_privilege_above(privilege, region, path, reads_writes) +end + +function region_context:__add(x) + local out = region_context() + + for privilege in self.privileges:items() do + if x.privileges:has(privilege) then + for region in self.privileges[privilege]:items() do + if x.privileges[privilege]:has(region) then + for path in self.privileges[privilege][region]:items() do + local matches, privilege = self:meet_privilege(x, privilege, region, path) + if matches then + out.privileges[privilege][region][path] = true + end + end + + for path in x.privileges[privilege][region]:items() do + local matches, privilege = x:meet_privilege(self, privilege, region, path) + if matches then + out.privileges[privilege][region][path] = true + end + end end end + end + end - return false - end) + for lhs, i in self.superregions:items() do + if x.superregions:has(lhs) then + for rhs in i:items() do + if x.superregions[lhs][rhs] then + out.superregions[lhs][rhs] = true + out.subregions[rhs][lhs] = true + end + end + end + end + + for lhs, i in self.disjoint:items() do + if x.disjoint:has(lhs) then + for rhs in i:items() do + if x.disjoint[lhs][rhs] then + out.disjoint[lhs][rhs] = true + end + end + end + end + + -- Any regions only present in one branch should be considered uninitialized + -- in the other, so we can just copy the privileges and constraints from the + -- initialized branch. + out:merge_context(self) + out:merge_context(x) + + return out +end + +-- Add all regions from x that are not in self, along with their privileges and +-- constraints. +function region_context:merge_context(x) + local to_add = terralib.newlist() + for region in x.subregions:items() do + if not self:has_region(region) then + to_add:insert(region) + end + end + + for _, region in ipairs(to_add) do + for privilege, privileged_regions in x.privileges:items() do + if privileged_regions:has(region) then + for path in privileged_regions[region]:items() do + self:add_privilege(privilege, region, path) + end + end + end + + for parent in x.superregions[region]:items() do + self:add_subregion_constraint(region, parent) + end + + for child in x.subregions[region]:items() do + self:add_subregion_constraint(child, region) + end + + if x.disjoint:has(region) then + for rhs in x.disjoint[region]:items() do + self:add_disjointness_constraint(region, rhs) + end + end + end end -function std.check_constraint(cx, constraint, exclude_variables) +function region_context:check_constraint(constraint) local lhs = constraint.lhs if lhs == std.wild then return true @@ -435,35 +683,123 @@ function std.check_constraint(cx, constraint, exclude_variables) if std.is_cross_product(rhs) then rhs = rhs:partition() end assert(std.type_supports_constraints(rhs)) - local constraint = { - lhs = lhs, - rhs = rhs, - op = constraint.op, - } - return std.search_constraint( - cx, constraint.lhs, constraint, exclude_variables, {}, - constraint.op == std.subregion --[[ reflexive ]], - constraint.op == std.disjointness --[[ symmetric ]]) + if not self:has_region(lhs) or not self:has_region(rhs) then + return constraint.op == std.subregion and lhs == rhs + end + + if constraint.op == std.subregion then + return self.superregions[lhs][rhs] + else + return self.disjoint:has(lhs) and self.disjoint[lhs][rhs] + end end -function std.check_constraints(cx, constraints, mapping) - if not mapping then - mapping = {} +function region_context:check_privilege_above(privilege, region, field_path, reads_writes) + assert(privilege:is(ast.privilege_kind)) + assert(std.type_supports_constraints(region)) + assert(data.is_tuple(field_path)) + + if self.privileges:has(privilege) and self.privileges[privilege]:has(region) then + for i = #field_path, 0, -1 do + if self.privileges[privilege][region][field_path:slice(1, i)] then + return true, privilege + end + end + + if reads_writes then + for i = #field_path, 0, -1 do + for reduction in self.privileges:items() do + if std.is_reduce(reduction) and + self.privileges[reduction]:has(region) and + self.privileges[reduction][region][field_path:slice(1, i)] then + return true, reduction + end + end + end + end end - for _, constraint in ipairs(constraints) do - local constraint = { - lhs = mapping[constraint.lhs] or constraint.lhs, - rhs = mapping[constraint.rhs] or constraint.rhs, - op = constraint.op, - } - if not std.check_constraint(cx, constraint) then - return false, constraint + if std.is_reduce(privilege) then + if self:check_privilege_above(std.reads, region, field_path) and + self:check_privilege_above(std.writes, region, field_path) then + return true, privilege + end + end + + return false +end + +local function check_privilege_below(self, privilege, region, field_path) + if not (self.privileges:has(privilege) and self.privileges[privilege]:has(region)) then + return false + end + + local field_type = std.get_field_path(region.fspace_type, field_path) + assert(field_type) + + if not (field_type:isstruct() or std.is_fspace_instance(field_type)) or + std.is_bounded_type(field_type)then + return false + end + + for _, entry in ipairs(field_type:getentries()) do + local subfield_name = entry[1] or entry.field + local subfield_path = field_path .. data.newtuple(subfield_name) + + if not self.privileges[privilege][region][subfield_path] and + not check_privilege_below(self, privilege, region, subfield_path) then + return false end end + return true end +function region_context:check_privilege(privilege, region, field_path) + return self:check_privilege_above(privilege, region, field_path) or + check_privilege_below(self, privilege, region, field_path) +end + +function region_context:find_codegen_privileges(region_type) + assert(std.type_supports_privileges(region_type)) + + -- Cannot group privileges as they might be grouped in different ways from + -- different branches of the code. + local ungrouped_privileges = terralib.newlist() + local ungrouped_field_paths = terralib.newlist() + local ungrouped_field_types = terralib.newlist() + + local field_paths, field_types = std.flatten_struct_fields(region_type:fspace()) + + local privilege_index = data.newmap() + for i, field_path in ipairs(field_paths) do + local field_type = field_types[i] + + local field_privilege = "none" + for privilege in self.privileges:items() do + if self:check_privilege(privilege, region_type, field_path) then + field_privilege = base.meet_privilege(field_privilege, tostring(privilege)) + end + end + + -- FIXME: For now, render write privileges as + -- read-write. Otherwise, write would get rendered as + -- write-discard, which would not be correct without explicit + -- user annotation. + if field_privilege == "writes" then + field_privilege = "reads_writes" + end + + if field_privilege ~= "none" then + ungrouped_privileges:insert(field_privilege) + ungrouped_field_paths:insert(terralib.newlist({field_path})) + ungrouped_field_types:insert(terralib.newlist({field_type})) + end + end + + return ungrouped_privileges, ungrouped_field_paths, ungrouped_field_types +end + -- ##################################### -- ## Physical Privilege Helpers -- ################# @@ -796,7 +1132,13 @@ local function reconstruct_return_as_arg_type(return_type, mapping) local result = reconstruct_param_as_arg_type(return_type, mapping, true) if result then return result end - return std.type_sub(return_type, mapping) + local output_type = std.type_sub(return_type, mapping) + if std.is_region(output_type) then + output_type = std.region( + std.newsymbol(output_type:ispace(), output_type.ispace_symbol:hasname()), + output_type:fspace()) + end + return output_type end function std.validate_args(node, params, args, isvararg, return_type, mapping, strict) @@ -1094,15 +1436,18 @@ function std.check_read(cx, node) local region_types, error_message = t:bounds() if region_types == nil then report.error(node, error_message) end local field_path = t.field_path - for i, region_type in ipairs(region_types) do - if not std.check_privilege(cx, std.reads, region_type, field_path) then - local regions = t.bounds_symbols - local ref_as_ptr = t.pointer_type.index_type(t.refers_to_type, unpack(regions)) - report.error(node, "invalid privilege reads(" .. - (data.newtuple(regions[i]) .. field_path):mkstring(".") .. - ") for dereference of " .. tostring(ref_as_ptr)) + + cx.region_checks:insert(function(region_cx) + for i, region_type in ipairs(region_types) do + if not region_cx:check_privilege(std.reads, region_type, field_path) then + local regions = t.bounds_symbols + local ref_as_ptr = t.pointer_type.index_type(t.refers_to_type, unpack(regions)) + report.error(node, "invalid privilege reads(" .. + (data.newtuple(regions[i]) .. field_path):mkstring(".") .. + ") for dereference of " .. tostring(ref_as_ptr)) + end end - end + end) end return std.as_read(t) end @@ -1114,15 +1459,18 @@ function std.check_write(cx, node) local region_types, error_message = t:bounds() if region_types == nil then report.error(node, error_message) end local field_path = t.field_path - for i, region_type in ipairs(region_types) do - if not std.check_privilege(cx, std.writes, region_type, field_path) then - local regions = t.bounds_symbols - local ref_as_ptr = t.pointer_type.index_type(t.refers_to_type, unpack(regions)) - report.error(node, "invalid privilege writes(" .. - (data.newtuple(regions[i]) .. field_path):mkstring(".") .. - ") for dereference of " .. tostring(ref_as_ptr)) + + cx.region_checks:insert(function(region_cx) + for i, region_type in ipairs(region_types) do + if not region_cx:check_privilege(std.writes, region_type, field_path) then + local regions = t.bounds_symbols + local ref_as_ptr = t.pointer_type.index_type(t.refers_to_type, unpack(regions)) + report.error(node, "invalid privilege writes(" .. + (data.newtuple(regions[i]) .. field_path):mkstring(".") .. + ") for dereference of " .. tostring(ref_as_ptr)) + end end - end + end) return std.as_read(t) elseif std.is_rawref(t) then return std.as_read(t) @@ -1138,15 +1486,18 @@ function std.check_reduce(cx, op, node) local region_types, error_message = t:bounds() if region_types == nil then report.error(node, error_message) end local field_path = t.field_path - for i, region_type in ipairs(region_types) do - if not std.check_privilege(cx, std.reduces(op), region_type, field_path) then - local regions = t.bounds_symbols - local ref_as_ptr = t.pointer_type.index_type(t.refers_to_type, unpack(regions)) - report.error(node, "invalid privilege " .. tostring(std.reduces(op)) .. "(" .. - (data.newtuple(regions[i]) .. field_path):mkstring(".") .. - ") for dereference of " .. tostring(ref_as_ptr)) + + cx.region_checks:insert(function(region_cx) + for i, region_type in ipairs(region_types) do + if not region_cx:check_privilege(std.reduces(op), region_type, field_path) then + local regions = t.bounds_symbols + local ref_as_ptr = t.pointer_type.index_type(t.refers_to_type, unpack(regions)) + report.error(node, "invalid privilege " .. tostring(std.reduces(op)) .. "(" .. + (data.newtuple(regions[i]) .. field_path):mkstring(".") .. + ") for dereference of " .. tostring(ref_as_ptr)) + end end - end + end) return std.as_read(t) elseif std.is_rawref(t) then return std.as_read(t) diff --git a/language/src/regent/std_base.t b/language/src/regent/std_base.t index bea2e0bad6..a73d3c1742 100644 --- a/language/src/regent/std_base.t +++ b/language/src/regent/std_base.t @@ -143,6 +143,12 @@ end -- ################# function base.type_meet(a, b) + if base.types.is_region(a) and base.types.is_region(b) and + a:ispace().index_type == b:ispace().index_type and + a.fspace_type == b.fspace_type then + return b + end + local function test() local terra query(x : a, y : b) if true then return x end @@ -1230,15 +1236,15 @@ function base.task:get_param_constraints() return self.param_constraints end -function base.task:set_constraints(t) - assert(not self.constraints) - assert(t) - self.constraints = t +function base.task:set_region_context(cx) + assert(not self.region_cx) + assert(cx) + self.region_cx = cx end -function base.task:get_constraints() - assert(self.constraints) - return self.constraints +function base.task:get_region_context() + assert(self.region_cx) + return self.region_cx end function base.task:set_region_universe(t) @@ -1432,7 +1438,7 @@ do flags = false, conditions = false, param_constraints = false, - constraints = false, + region_cx = false, region_universe = false, -- Variants and alternative versions: diff --git a/language/src/regent/symbol_table.t b/language/src/regent/symbol_table.t index 07ce496c10..f879e3abb3 100644 --- a/language/src/regent/symbol_table.t +++ b/language/src/regent/symbol_table.t @@ -69,6 +69,11 @@ function symbol_table:insert(node, index, value) return value end +function symbol_table:force_insert(index, value) + rawset(self.local_env, index, value) + return value +end + function symbol_table:env() return self.combined_env end diff --git a/language/src/regent/type_check.t b/language/src/regent/type_check.t index 5cbbb085af..5d79c2b578 100644 --- a/language/src/regent/type_check.t +++ b/language/src/regent/type_check.t @@ -21,6 +21,7 @@ local pretty = require("regent/pretty") local report = require("common/report") local std = require("regent/std") local symbol_table = require("regent/symbol_table") +local dataflow = require("regent/dataflow") local type_check = {} @@ -38,21 +39,43 @@ function context:__newindex(field, value) error("context has no field '" .. field .. "' (in assignment)", 2) end -function context:new_local_scope(must_epoch, breakable_loop) +function context:new_local_scope(must_epoch, breakable_loop, break_footer) assert(not (self.must_epoch and must_epoch)) + assert(breakable_loop or not break_footer) + must_epoch = self.must_epoch or must_epoch or false breakable_loop = self.breakable_loop or breakable_loop or false local cx = { type_env = self.type_env:new_local_scope(), + region_env = self.region_env:new_local_scope(), + privileges = self.privileges, constraints = self.constraints, + dataflow_actions = self.dataflow_actions, + region_checks = self.region_checks, region_universe = self.region_universe, expected_return_type = self.expected_return_type, fixup_nodes = self.fixup_nodes, must_epoch = must_epoch, breakable_loop = breakable_loop, external = self.external, + + recursive = self.recursive, + + break_footer = self.break_footer, + return_footer = self.return_footer, } + + if breakable_loop then + cx.break_footer = break_footer or false + end + if break_footer then + function cx.return_footer(...) + break_footer(...) + self.return_footer(...) + end + end + setmetatable(cx, context) return cx end @@ -60,14 +83,23 @@ end function context:new_task_scope(expected_return_type) local cx = { type_env = self.type_env:new_local_scope(), + region_env = self.region_env:new_local_scope(), + privileges = data.newmap(), constraints = data.new_recursive_map(2), + dataflow_actions = data.new_default_map(function() return terralib.newlist() end), + region_checks = terralib.newlist(), region_universe = data.newmap(), expected_return_type = {expected_return_type}, fixup_nodes = terralib.newlist(), must_epoch = false, breakable_loop = false, external = false, + + recursive = { value = false }, + + break_footer = false, + return_footer = function() end, } setmetatable(cx, context) return cx @@ -76,6 +108,7 @@ end function context.new_global_scope(type_env) local cx = { type_env = symbol_table.new_global_scope(type_env), + region_env = symbol_table.new_global_scope({}), } setmetatable(cx, context) return cx @@ -105,6 +138,192 @@ function context:set_external(external) self.external = external end +local region_dataflow_meta = {} +local region_dataflow = setmetatable({}, region_dataflow_meta) +region_dataflow.__index = region_dataflow + +function region_dataflow_meta:__call(type_check_cx, region_cx) + local out = { + context = region_cx, + type_check_cx = type_check_cx, + unreachable = false, + } + + return setmetatable(out, region_dataflow) +end + +function region_dataflow:copy() + local out = { + context = self.context:copy(), + type_check_cx = self.type_check_cx, + unreachable = self.unreachable, + } + + return setmetatable(out, region_dataflow) +end + +function region_dataflow:__eq(x) + return self.context == x.context +end + +function region_dataflow:__add(x) + assert(self.type_check_cx == x.type_check_cx) + + if self.unreachable then + return x:copy() + elseif x.unreachable then + return self:copy() + end + + local out = { + context = self.context + x.context, + type_check_cx = self.type_check_cx, + unreachable = false, + } + + return setmetatable(out, region_dataflow) +end + +function region_dataflow:enter_block(block) +end + +function region_dataflow:exit_block(block) +end + +function region_dataflow:stat_var(node) + if node.value then + self:expr(node.value) + end +end + +function region_dataflow:stat_expr(node) + self:expr(node.expr) +end + +function region_dataflow:stat_assignment(node) + self:expr(node.rhs) + self:expr(node.lhs) +end + +function region_dataflow:stat_reduce(node) + self:expr(node.lhs) + self:expr(node.rhs) +end + +function region_dataflow:stat_raw_delete(node) + self:expr(node.value) +end + +function region_dataflow:stat_parallel_prefix(node) + self:expr(node.lhs) + self:expr(node.rhs) + self:expr(node.dir) +end + +function region_dataflow:stat_return(node) + if node.value then + self:expr(node.value) + + local expected_type = self.type_check_cx:get_return_type() + local value_type = std.as_read(node.value.expr_type) + if std.is_region(expected_type) and std.is_region(value_type) then + assert(expected_type ~= value_type) + self.context:dup_region(value_type, expected_type) + end + end +end + +function region_dataflow:stat_ignore(node) +end + +local function unreachable(node) + assert(false, "unreachable") +end + +local region_dataflow_stat_node = { + [ast.typed.stat.Var] = region_dataflow.stat_var, + [ast.typed.stat.VarUnpack] = region_dataflow.stat_var, + [ast.typed.stat.Expr] = region_dataflow.stat_expr, + [ast.typed.stat.Assignment] = region_dataflow.stat_assignment, + [ast.typed.stat.Reduce] = region_dataflow.stat_reduce, + [ast.typed.stat.RawDelete] = region_dataflow.stat_raw_delete, + [ast.typed.stat.ParallelPrefix] = region_dataflow.stat_parallel_prefix, + [ast.typed.stat.Return] = region_dataflow.stat_return, + [ast.typed.stat.Break] = region_dataflow.stat_ignore, + [ast.typed.stat.Fence] = region_dataflow.stat_ignore, + + [ast.typed.stat.BeginTrace] = unreachable, + [ast.typed.stat.EndTrace] = unreachable, + [ast.typed.stat.MapRegions] = unreachable, + [ast.typed.stat.UnmapRegions] = unreachable, + + -- The rest are already handled by the data-flow analysis framework. + [ast.typed.stat] = unreachable, +} + +local region_dataflow_stat = ast.make_single_dispatch( + region_dataflow_stat_node, + {ast.typed.stat}) + +function region_dataflow:statement(node) + region_dataflow_stat(self)(node) + self:run_dataflow_actions(node) +end + +-- If the expression has region type then return the region symbol. Otherwise +-- return nothing. + +function region_dataflow:expr_id(node) + if std.is_region(node.value:gettype()) then + return node.value + end +end + +function region_dataflow:expr_call(node) + local expr_type = node.expr_type + if not std.is_region(expr_type) then + return + end + + if not node.fn.value.region_cx and + not self.type_check_cx.recursive.value then + self.type_check_cx.recursive.value = true + + -- Use other execution paths to determine the privileges of the return + -- value. + self.unreachable = true + end +end + +function region_dataflow:expr(node) + for k, v in pairs(node) do + if terralib.islist(v) then + for _, x in ipairs(v) do + if ast.is_node(x) and x:is(ast.typed.expr) then + self:expr(x) + end + end + elseif ast.is_node(v) and v:is(ast.typed.expr) then + self:expr(v) + end + end + + if node:is(ast.typed.expr.Call) then + self:expr_call(node) + end + + self:run_dataflow_actions(node) +end + +function region_dataflow:run_dataflow_actions(node) + local actions = self.type_check_cx.dataflow_actions + if actions:has(node) then + for _, action in ipairs(actions[node]) do + action(self.context) + end + end +end + function type_check.region_field(cx, node, region, prefix_path, value_type) assert(std.is_symbol(region)) local field_path = prefix_path .. data.newtuple(node.field_name) @@ -420,10 +639,13 @@ function type_check.constraints(cx, node) end function type_check.expr_id(cx, node) - local expr_type = cx.type_env:lookup(node, node.value) + local value = node.value + value = cx.region_env:safe_lookup(value) or value + + local expr_type = cx.type_env:lookup(node, value) return ast.typed.expr.ID { - value = node.value, + value = value, expr_type = expr_type, annotations = node.annotations, span = node.span, @@ -545,10 +767,6 @@ function type_check.expr_field_access(cx, node) end end - if constraints then - std.add_constraints(cx, constraints) - end - local field_type if std.is_region(std.as_read(unpack_type)) and node.field_name == "ispace" then field_type = std.as_read(unpack_type):ispace() @@ -563,6 +781,8 @@ function type_check.expr_field_access(cx, node) -- Volume can be retrieved on any ispace. field_type = int64 elseif std.is_region(std.as_read(unpack_type)) and (node.field_name == "bounds" or node.field_name == "volume") then + assert(not constraints) + -- Index space fields can also be retrieved through a region. return type_check.expr( cx, @@ -591,22 +811,28 @@ function type_check.expr_field_access(cx, node) end end - return ast.typed.expr.FieldAccess { + local typed_expr = ast.typed.expr.FieldAccess { value = value, field_name = node.field_name, expr_type = field_type, annotations = node.annotations, span = node.span, } + + if constraints then + std.add_constraints(cx, typed_expr, constraints) + end + + return typed_expr end -local function add_analyzable_disjointness_constraints(cx, partition, subregion) +local function add_analyzable_disjointness_constraints(cx, node, partition, subregion) local index = subregion:get_index_expr() local other_subregions = partition:subregions_constant() for _, other_subregion in other_subregions:items() do local other_index = other_subregion:get_index_expr() if affine_helper.analyze_index_noninterference(index, other_index) then - std.add_constraint(cx, subregion, other_subregion, std.disjointness, true) + std.add_constraint(cx, node, subregion, other_subregion, std.disjointness) end end end @@ -635,24 +861,26 @@ function type_check.expr_index_access(cx, node) local subregion if analyzable then subregion = value_type:subregion_constant(index) - - if value_type:is_disjoint() then - add_analyzable_disjointness_constraints(cx, value_type, subregion) - end else subregion = value_type:subregion_dynamic() end - std.add_constraint(cx, partition, parent, std.subregion, false) - std.add_constraint(cx, subregion, partition, std.subregion, false) - - return ast.typed.expr.IndexAccess { + local typed_expr = ast.typed.expr.IndexAccess { value = value, index = index, expr_type = subregion, annotations = node.annotations, span = node.span, } + + if analyzable and value_type:is_disjoint() then + add_analyzable_disjointness_constraints(cx, typed_expr, value_type, subregion) + end + + std.add_constraint(cx, typed_expr, partition, parent, std.subregion) + std.add_constraint(cx, typed_expr, subregion, partition, std.subregion) + + return typed_expr elseif std.is_cross_product(value_type) then local color_type = value_type:partition():colors().index_type if not std.validate_implicit_cast(index_type, color_type) then @@ -666,26 +894,28 @@ function type_check.expr_index_access(cx, node) if analyzable then subpartition = value_type:subpartition_constant(index) subregion = subpartition:parent_region() - - if value_type:is_disjoint() then - add_analyzable_disjointness_constraints(cx, value_type, subregion) - end else subpartition = value_type:subpartition_dynamic() subregion = subpartition:parent_region() end - std.add_constraint(cx, partition, parent, std.subregion, false) - std.add_constraint(cx, subregion, partition, std.subregion, false) - std.add_constraint(cx, subpartition:partition(), subregion, std.subregion, false) - - return ast.typed.expr.IndexAccess { + local typed_expr = ast.typed.expr.IndexAccess { value = value, index = index, expr_type = subpartition, annotations = node.annotations, span = node.span, } + + if analyzable and value_type:is_disjoint() then + add_analyzable_disjointness_constraints(cx, typed_expr, value_type, subregion) + end + + std.add_constraint(cx, typed_expr, partition, parent, std.subregion) + std.add_constraint(cx, typed_expr, subregion, partition, std.subregion) + std.add_constraint(cx, typed_expr, subpartition:partition(), subregion, std.subregion) + + return typed_expr elseif std.is_region(value_type) then -- FIXME: Need to check if this is a bounded type (with the right -- bound) and, if not, insert a dynamic cast. @@ -749,15 +979,18 @@ function type_check.expr_index_access(cx, node) else expr_type = value_type:slice(1) end - std.add_constraint(cx, expr_type, value_type, std.subregion, false) - return ast.typed.expr.IndexAccess { + local typed_expr = ast.typed.expr.IndexAccess { value = value, index = index, expr_type = expr_type, annotations = node.annotations, span = node.span, } + + std.add_constraint(cx, typed_expr, expr_type, value_type, std.subregion) + + return typed_expr end elseif std.is_transform_type(value_type) then local expected = std.int2d @@ -773,6 +1006,8 @@ function type_check.expr_index_access(cx, node) span = node.span, } else + -- FIXME: This returns a std.ref, which messes up permission checking. + -- Ask the Terra compiler to kindly tell us what type this operator returns. local test if std.is_regent_array(value_type) then @@ -869,6 +1104,16 @@ function type_check.expr_method_call(cx, node) } end +local function disjoint_from_all(region_cx, region, disjoint_regions) + for i, disjoint_region in ipairs(disjoint_regions) do + if std.is_region(disjoint_region) and + not region_cx:check_constraint({lhs = region, rhs = disjoint_region, op = std.disjointness}) then + return false + end + end + return true +end + function type_check.expr_call(cx, node) local fn = type_check.expr(cx, node.fn) local args = node.args:map( @@ -951,48 +1196,20 @@ function type_check.expr_call(cx, node) local expr_type, need_cast = std.validate_args( node, param_symbols, arg_symbols, def_type.isvararg, def_type.returntype, {}, false) + local mapping = {} + local arg_nums = {} if std.is_task(fn.value) then if cx.must_epoch then -- Inside a must epoch tasks are not allowed to return. expr_type = terralib.types.unit end - local mapping = {} for i, arg_symbol in ipairs(arg_symbols) do local param_symbol = param_symbols[i] local param_type = fn_type.parameters[i] mapping[param_symbol] = arg_symbol mapping[param_type] = arg_symbol - end - - local privileges = fn.value:get_privileges() - for _, privilege_list in ipairs(privileges) do - for _, privilege in ipairs(privilege_list) do - local privilege_type = privilege.privilege - local region = privilege.region - local field_path = privilege.field_path - assert(std.type_supports_privileges(region:gettype())) - local arg_region = mapping[region:gettype()] - if not std.check_privilege(cx, privilege_type, arg_region:gettype(), field_path) then - for i, arg in ipairs(arg_symbols) do - if std.type_eq(arg:gettype(), arg_region:gettype()) then - report.error( - node, "invalid privileges in argument " .. tostring(i) .. - ": " .. tostring(privilege_type) .. "(" .. - (data.newtuple(arg_region) .. field_path):mkstring(".") .. - ")") - end - end - assert(false) - end - end - end - - local constraints = fn.value:get_param_constraints() - local satisfied, constraint = std.check_constraints(cx, constraints, mapping) - if not satisfied then - report.error(node, "invalid call missing constraint " .. tostring(constraint.lhs) .. - " " .. tostring(constraint.op) .. " " .. tostring(constraint.rhs)) + arg_nums[arg_symbol] = i end end @@ -1026,9 +1243,84 @@ function type_check.expr_call(cx, node) annotations = node.annotations, span = node.span, } + if expr_type == untyped then cx.fixup_nodes:insert(result) + elseif std.is_partition(expr_type) then + std.add_constraint(cx, result, expr_type, expr_type:parent_region(), std.subregion) + end + + if std.is_task(fn.value) then + local privileges = fn.value:get_privileges() + for _, privilege_list in ipairs(privileges) do + for _, privilege in ipairs(privilege_list) do + local privilege_type = privilege.privilege + local region = privilege.region + local field_path = privilege.field_path + assert(std.type_supports_privileges(region:gettype())) + local arg_region = mapping[region:gettype()] + std.require_privilege(cx, node, "argument " .. arg_nums[arg_region], privilege_type, + arg_region:gettype(), field_path, arg_region, + cx.dataflow_actions[result]) + end + end + + local constraints = fn.value:get_param_constraints() + std.require_constraints(cx, node, "call", constraints, mapping, cx.dataflow_actions[result]) end + + if std.is_region(expr_type) then + local saved_universe = cx.region_universe:copy() + cx:intern_region(expr_type) + + arg_types = arg_symbols:map(function(x) return x:gettype() end) + + -- Determine return value's privileges and constraints + cx.dataflow_actions[result]:insert(function(region_cx) + local region = def_type.returntype + local fn_cx = fn.value.region_cx + if not fn_cx then return end + + for privilege, privileged_regions in fn_cx.privileges:items() do + if privileged_regions:has(region) then + for path in privileged_regions[region]:items() do + region_cx:add_privilege(privilege, expr_type, path) + end + end + end + + for parent in fn_cx.superregions[region]:items() do + if mapping[parent] then + region_cx:add_subregion_constraint(expr_type, mapping[parent]:gettype()) + end + end + + for child in fn_cx.subregions[region]:items() do + if mapping[child] then + region_cx:add_subregion_constraint(mapping[child]:gettype(), expr_type) + end + end + + if fn_cx.disjoint:has(region) then + for rhs in fn_cx.disjoint[region]:items() do + if mapping[rhs] then + region_cx:add_disjointness_constraint(expr_type, mapping[rhs]:gettype()) + end + end + end + + -- Returned regions must be disjoint from all other regions that are + -- disjoint from all of the parameters. + for other_region, _ in saved_universe:items() do + assert(not std.type_eq(expr_type, other_region)) + if std.type_maybe_eq(expr_type:fspace(), other_region:fspace()) and + disjoint_from_all(region_cx, other_region, arg_types) then + region_cx:add_disjointness_constraint(expr_type, other_region) + end + end + end) + end + return result end @@ -1078,11 +1370,7 @@ function type_check.expr_cast(cx, node) end std.validate_args(node, to_fields, from_fields, false, terralib.types.unit, mapping, false) - local satisfied, constraint = std.check_constraints(cx, to_constraints, mapping) - if not satisfied then - report.error(node, "invalid cast missing constraint " .. tostring(constraint.lhs) .. - " " .. tostring(constraint.op) .. " " .. tostring(constraint.rhs)) - end + std.require_constraints(cx, node, "cast", to_constraints, mapping) else if not std.validate_explicit_cast(from_type, to_type) then report.error(node, "invalid cast from " .. tostring(from_type) .. " to " .. tostring(to_type)) @@ -1340,18 +1628,22 @@ function type_check.expr_static_cast(cx, node) end local parent_region_map = {} - for i, value_region_symbol in ipairs(value_type.bounds_symbols) do - for j, expr_region_symbol in ipairs(expr_type.bounds_symbols) do - local constraint = std.constraint( - value_region_symbol, - expr_region_symbol, - std.subregion) - if std.check_constraint(cx, constraint) then - parent_region_map[i] = j - break + cx.region_checks:insert(function(region_cx) + for i, value_region_symbol in ipairs(value_type.bounds_symbols) do + for j, expr_region_symbol in ipairs(expr_type.bounds_symbols) do + local constraint = std.constraint( + value_region_symbol, + expr_region_symbol, + std.subregion) + if region_cx:check_constraint(constraint) then + parent_region_map[i] = j + break + end end end - end + + -- TODO: Why doesn't it check to make sure that parent_region_map is total? + end) return ast.typed.expr.StaticCast { value = value, @@ -1445,8 +1737,16 @@ function type_check.expr_region(cx, node) end assert(std.type_eq(ispace_symbol:gettype(), ispace_type)) - std.add_privilege(cx, std.reads, region, data.newtuple()) - std.add_privilege(cx, std.writes, region, data.newtuple()) + local typed_expr = ast.typed.expr.Region { + ispace = ispace, + fspace_type = node.fspace_type, + expr_type = region, + annotations = node.annotations, + span = node.span, + } + + std.add_privilege(cx, typed_expr, std.reads, region, data.newtuple()) + std.add_privilege(cx, typed_expr, std.writes, region, data.newtuple()) -- Freshly created regions are, by definition, disjoint from all -- other regions. for other_region, _ in cx.region_universe:items() do @@ -1454,18 +1754,12 @@ function type_check.expr_region(cx, node) -- But still, don't bother litering the constraint space with -- trivial constraints. if std.type_maybe_eq(region:fspace(), other_region:fspace()) then - std.add_constraint(cx, region, other_region, std.disjointness, true) + std.add_constraint(cx, typed_expr, region, other_region, std.disjointness) end end cx:intern_region(region) - return ast.typed.expr.Region { - ispace = ispace, - fspace_type = node.fspace_type, - expr_type = region, - annotations = node.annotations, - span = node.span, - } + return typed_expr end function type_check.expr_partition(cx, node) @@ -1551,7 +1845,7 @@ function type_check.expr_partition(cx, node) end assert(expr_type.parent_region_symbol:gettype() == region_type) - return ast.typed.expr.Partition { + local typed_expr = ast.typed.expr.Partition { disjointness = disjointness, region = region, coloring = coloring, @@ -1560,6 +1854,10 @@ function type_check.expr_partition(cx, node) annotations = node.annotations, span = node.span, } + + std.add_constraint(cx, typed_expr, expr_type, region_type, std.subregion) + + return typed_expr end function type_check.expr_partition_equal(cx, node) @@ -1604,13 +1902,17 @@ function type_check.expr_partition_equal(cx, node) end assert(expr_type.parent_region_symbol:gettype() == region_type) - return ast.typed.expr.PartitionEqual { + local typed_expr = ast.typed.expr.PartitionEqual { region = region, colors = colors, expr_type = expr_type, annotations = node.annotations, span = node.span, } + + std.add_constraint(cx, typed_expr, expr_type, region_type, std.subregion) + + return typed_expr end function type_check.expr_partition_by_field(cx, node) @@ -1651,12 +1953,8 @@ function type_check.expr_partition_by_field(cx, node) end local expr_type = std.partition(std.disjoint, region_symbol, colors_symbol) - if not std.check_privilege(cx, std.reads, region_type, region.fields[1]) then - report.error( - node, "invalid privileges in argument 1: " .. tostring(std.reads) .. "(" .. - (data.newtuple(expr_type.parent_region_symbol) .. region.fields[1]):mkstring(".") .. - ")") - end + std.require_privilege(cx, node, "argument 1", std.reads, region_type, + region.fields[1], expr_type.parent_region_symbol) -- Hack: Stuff the region type back into the partition's region -- argument, if necessary. @@ -1665,13 +1963,17 @@ function type_check.expr_partition_by_field(cx, node) end assert(expr_type.parent_region_symbol:gettype() == region_type) - return ast.typed.expr.PartitionByField { + local typed_expr = ast.typed.expr.PartitionByField { region = region, colors = colors, expr_type = expr_type, annotations = node.annotations, span = node.span, } + + std.add_constraint(cx, typed_expr, expr_type, region_type, std.subregion) + + return typed_expr end function type_check.expr_partition_by_restriction(cx, node) @@ -1750,7 +2052,7 @@ function type_check.expr_partition_by_restriction(cx, node) end assert(expr_type.parent_region_symbol:gettype() == region_type) - return ast.typed.expr.PartitionByRestriction { + local typed_expr = ast.typed.expr.PartitionByRestriction { disjointness = node.disjointness, region = region, transform = transform, @@ -1760,6 +2062,10 @@ function type_check.expr_partition_by_restriction(cx, node) annotations = node.annotations, span = node.span, } + + std.add_constraint(cx, typed_expr, expr_type, region_type, std.subregion) + + return typed_expr end function type_check.expr_image(cx, node) @@ -1817,16 +2123,10 @@ function type_check.expr_image(cx, node) local region_symbol if region.region:is(ast.typed.expr.ID) then region_symbol = region.region.value - else - region_symbol = std.newsymbol(region_type) end - if not std.check_privilege(cx, std.reads, region_type, region.fields[1]) then - report.error( - node, "invalid privileges in argument 3: " .. tostring(std.reads) .. "(" .. - (data.newtuple(region_symbol) .. region.fields[1]):mkstring(".") .. - ")") - end + std.require_privilege(cx, node, "argument 3", std.reads, region_type, + region.fields[1], region_symbol) local parent_symbol if parent:is(ast.typed.expr.ID) then @@ -1849,29 +2149,22 @@ function type_check.expr_image(cx, node) partition_type.parent_region_symbol, region_symbol, std.subregion) - if not std.check_constraint(cx, constraint) then - report.error(node, "invalid image missing constraint " .. - tostring(constraint.lhs) .. " " .. tostring(constraint.op) .. - " " .. tostring(constraint.rhs)) - end + std.require_constraint(cx, node, "image", constraint) end if std.is_bounded_type(field_type) and field_type:is_ptr() then -- Check that parent is a subregion of the field bounds. + -- FIXME: This should check that it is a subregion of one bound, not all. for _, bound_symbol in ipairs(field_type.bounds_symbols) do local constraint = std.constraint( parent_symbol, bound_symbol, std.subregion) - if not std.check_constraint(cx, constraint) then - report.error(node, "invalid image missing constraint " .. - tostring(constraint.lhs) .. " " .. tostring(constraint.op) .. - " " .. tostring(constraint.rhs)) - end + std.require_constraint(cx, node, "image", constraint) end end - return ast.typed.expr.Image { + local typed_expr = ast.typed.expr.Image { disjointness = node.disjointness, parent = parent, partition = partition, @@ -1880,6 +2173,10 @@ function type_check.expr_image(cx, node) annotations = node.annotations, span = node.span, } + + std.add_constraint(cx, typed_expr, expr_type, parent_type, std.subregion) + + return typed_expr end function type_check.expr_image_by_task(cx, node) @@ -1936,7 +2233,7 @@ function type_check.expr_image_by_task(cx, node) end local expr_type = std.partition(std.aliased, parent_symbol, partition_type.colors_symbol) - return ast.typed.expr.ImageByTask { + local typed_expr = ast.typed.expr.ImageByTask { parent = parent, partition = partition, task = task, @@ -1944,6 +2241,10 @@ function type_check.expr_image_by_task(cx, node) annotations = node.annotations, span = node.span, } + + std.add_constraint(cx, typed_expr, expr_type, parent_type, std.subregion) + + return typed_expr end function type_check.expr_preimage(cx, node) @@ -2001,16 +2302,10 @@ function type_check.expr_preimage(cx, node) local region_symbol if region.region:is(ast.typed.expr.ID) then region_symbol = region.region.value - else - region_symbol = std.newsymbol(region_type) end - if not std.check_privilege(cx, std.reads, region_type, region.fields[1]) then - report.error( - node, "invalid privileges in argument 3: " .. tostring(std.reads) .. "(" .. - (data.newtuple(region_symbol) .. region.fields[1]):mkstring(".") .. - ")") - end + std.require_privilege(cx, node, "argument 3", std.reads, region_type, + region.fields[1], region_symbol) local parent_symbol if parent:is(ast.typed.expr.ID) then @@ -2038,29 +2333,22 @@ function type_check.expr_preimage(cx, node) parent_symbol, region_symbol, std.subregion) - if not std.check_constraint(cx, constraint) then - report.error(node, "invalid image missing constraint " .. - tostring(constraint.lhs) .. " " .. tostring(constraint.op) .. - " " .. tostring(constraint.rhs)) - end + std.require_constraint(cx, node, "preimage", constraint) end if std.is_bounded_type(field_type) and field_type:is_ptr() then -- Check that partitions's parent is a subregion of the field bounds. + -- FIXME: This should check that it is a subregion of one bound, not all. for _, bound_symbol in ipairs(field_type.bounds_symbols) do local constraint = std.constraint( partition_type.parent_region_symbol, bound_symbol, std.subregion) - if not std.check_constraint(cx, constraint) then - report.error(node, "invalid image missing constraint " .. - tostring(constraint.lhs) .. " " .. tostring(constraint.op) .. - " " .. tostring(constraint.rhs)) - end + std.require_constraint(cx, node, "preimage", constraint) end end - return ast.typed.expr.Preimage { + local typed_expr = ast.typed.expr.Preimage { disjointness = node.disjointness, partition = partition, region = region, @@ -2069,6 +2357,10 @@ function type_check.expr_preimage(cx, node) annotations = node.annotations, span = node.span, } + + std.add_constraint(cx, typed_expr, expr_type, parent_type, std.subregion) + + return typed_expr end function type_check.expr_cross_product(cx, node) @@ -2145,17 +2437,19 @@ function type_check.expr_list_slice_partition(cx, node) -- FIXME: The privileges for these region aren't necessarily exactly -- one level up. - std.copy_privileges(cx, partition_type:parent_region(), expr_type) - -- FIXME: Copy constraints. - cx:intern_region(expr_type) - - return ast.typed.expr.ListSlicePartition { + local typed_expr = ast.typed.expr.ListSlicePartition { partition = partition, indices = indices, expr_type = expr_type, annotations = node.annotations, span = node.span, } + + std.copy_privileges(cx, typed_expr, partition_type:parent_region(), expr_type) + -- FIXME: Copy constraints. + cx:intern_region(expr_type) + + return typed_expr end function type_check.expr_list_duplicate_partition(cx, node) @@ -2176,8 +2470,16 @@ function type_check.expr_list_duplicate_partition(cx, node) partition_type:parent_region():fspace()), partition_type) - std.add_privilege(cx, std.reads, expr_type, data.newtuple()) - std.add_privilege(cx, std.writes, expr_type, data.newtuple()) + local typed_expr = ast.typed.expr.ListDuplicatePartition { + partition = partition, + indices = indices, + expr_type = expr_type, + annotations = node.annotations, + span = node.span, + } + + std.add_privilege(cx, typed_expr, std.reads, expr_type, data.newtuple()) + std.add_privilege(cx, typed_expr, std.writes, expr_type, data.newtuple()) -- Freshly created regions are, by definition, disjoint from all -- other regions. for other_region, _ in cx.region_universe:items() do @@ -2185,18 +2487,12 @@ function type_check.expr_list_duplicate_partition(cx, node) -- But still, don't bother litering the constraint space with -- trivial constraints. if std.type_maybe_eq(expr_type:fspace(), other_region:fspace()) then - std.add_constraint(cx, expr_type, other_region, std.disjointness, true) + std.add_constraint(cx, typed_expr, expr_type, other_region, std.disjointness) end end cx:intern_region(expr_type) - return ast.typed.expr.ListDuplicatePartition { - partition = partition, - indices = indices, - expr_type = expr_type, - annotations = node.annotations, - span = node.span, - } + return typed_expr end function type_check.expr_list_cross_product(cx, node) @@ -2219,9 +2515,7 @@ function type_check.expr_list_cross_product(cx, node) expr_type = std.list(std.list(rhs_type:subregion_dynamic(), nil, 1), nil, 1) end - std.add_constraint(cx, expr_type, rhs_type, std.subregion, false) - - return ast.typed.expr.ListCrossProduct { + local typed_expr = ast.typed.expr.ListCrossProduct { lhs = lhs, rhs = rhs, shallow = node.shallow, @@ -2229,6 +2523,10 @@ function type_check.expr_list_cross_product(cx, node) annotations = node.annotations, span = node.span, } + + std.add_constraint(cx, typed_expr, expr_type, rhs_type, std.subregion) + + return typed_expr end function type_check.expr_list_cross_product_complete(cx, node) @@ -2247,15 +2545,17 @@ function type_check.expr_list_cross_product_complete(cx, node) std.list(product_type:subregion_dynamic(), nil, 1), nil, 1) - std.add_constraint(cx, expr_type, product_type, std.subregion, false) - - return ast.typed.expr.ListCrossProductComplete { + local typed_expr = ast.typed.expr.ListCrossProductComplete { lhs = lhs, product = product, expr_type = expr_type, annotations = node.annotations, span = node.span, } + + std.add_constraint(cx, typed_expr, expr_type, product_type, std.subregion) + + return typed_expr end function type_check.expr_list_phase_barriers(cx, node) @@ -2558,60 +2858,19 @@ function type_check.expr_copy(cx, node) end for _, field_path in ipairs(src.fields) do - if not std.check_privilege(cx, std.reads, src_type, field_path) then - local src_symbol - if node.src.region:is(ast.specialized.expr.ID) then - src_symbol = node.src.region.value - else - src_symbol = std.newsymbol() - end - report.error( - node, "invalid privileges in copy: " .. tostring(std.reads) .. "(" .. - (data.newtuple(src_symbol) .. field_path):mkstring(".") .. ")") + local src_symbol + if node.src.region:is(ast.specialized.expr.ID) then + src_symbol = node.src.region.value end + std.require_privilege(cx, node, "copy", std.reads, src_type, field_path, src_symbol) end for _, field_path in ipairs(dst.fields) do - if node.op then - if not std.check_privilege(cx, std.reduces(node.op), dst_type, field_path) - then - local dst_symbol - if node.dst.region:is(ast.specialized.expr.ID) then - dst_symbol = node.dst.region.value - else - dst_symbol = std.newsymbol() - end - report.error( - node, - "invalid privileges in copy: " .. tostring(std.reduces(node.op)) .. - "(" .. (data.newtuple(dst_symbol) .. field_path):mkstring(".") .. - ")") - end - else - if not std.check_privilege(cx, std.reads, dst_type, field_path) then - local dst_symbol - if node.dst.region:is(ast.specialized.expr.ID) then - dst_symbol = node.dst.region.value - else - dst_symbol = std.newsymbol() - end - report.error( - node, "invalid privileges in copy: " .. tostring(std.reads) .. - "(" .. (data.newtuple(dst_symbol) .. field_path):mkstring(".") .. - ")") - end - if not std.check_privilege(cx, std.writes, dst_type, field_path) then - local dst_symbol - if node.dst.region:is(ast.specialized.expr.ID) then - dst_symbol = node.dst.region.value - else - dst_symbol = std.newsymbol() - end - report.error( - node, "invalid privileges in copy: " .. tostring(std.writes) .. - "(" .. (data.newtuple(dst_symbol) .. field_path):mkstring(".") .. - ")") - end + local privilege = node.op and std.reduces(node.op) or std.writes + local dst_symbol + if node.dst.region:is(ast.specialized.expr.ID) then + dst_symbol = node.dst.region.value end + std.require_privilege(cx, node, "copy", privilege, dst_type, field_path, dst_symbol) end return ast.typed.expr.Copy { @@ -2645,17 +2904,11 @@ function type_check.expr_fill(cx, node) end for _, field_path in ipairs(dst.fields) do - if not std.check_privilege(cx, std.writes, dst_type, field_path) then - local dst_symbol - if node.dst.region:is(ast.specialized.expr.ID) then - dst_symbol = node.dst.region.value - else - dst_symbol = std.newsymbol() - end - report.error( - node, "invalid privileges in fill: " .. tostring(std.writes) .. - "(" .. (data.newtuple(dst_symbol) .. field_path):mkstring(".") .. ")") + local dst_symbol + if node.dst.region:is(ast.specialized.expr.ID) then + dst_symbol = node.dst.region.value end + std.require_privilege(cx, node, "fill", std.writes, dst_type, field_path, dst_symbol) end for _, field_path in ipairs(dst.fields) do @@ -2686,27 +2939,12 @@ function type_check.expr_acquire(cx, node) local expr_type = terralib.types.unit for _, field_path in ipairs(region.fields) do - if not std.check_privilege(cx, std.reads, region_type, field_path) then + for _, privilege in ipairs({std.reads, std.writes}) do local region_symbol if node.region.region:is(ast.specialized.expr.ID) then region_symbol = node.region.region.value - else - region_symbol = std.newsymbol() end - report.error( - node, "invalid privileges in acquire: " .. tostring(std.reads) .. - "(" .. (data.newtuple(region_symbol) .. field_path):mkstring(".") .. ")") - end - if not std.check_privilege(cx, std.writes, region_type, field_path) then - local region_symbol - if node.region.region:is(ast.specialized.expr.ID) then - region_symbol = node.region.region.value - else - region_symbol = std.newsymbol() - end - report.error( - node, "invalid privileges in acquire: " .. tostring(std.writes) .. - "(" .. (data.newtuple(region_symbol) .. field_path):mkstring(".") .. ")") + std.require_privilege(cx, node, "acquire", privilege, region_type, field_path, region_symbol) end end @@ -2729,27 +2967,12 @@ function type_check.expr_release(cx, node) local expr_type = terralib.types.unit for _, field_path in ipairs(region.fields) do - if not std.check_privilege(cx, std.reads, region_type, field_path) then + for _, privilege in ipairs({std.reads, std.writes}) do local region_symbol if node.region.region:is(ast.specialized.expr.ID) then region_symbol = node.region.region.value - else - region_symbol = std.newsymbol() end - report.error( - node, "invalid privileges in release: " .. tostring(std.reads) .. - "(" .. (data.newtuple(region_symbol) .. field_path):mkstring(".") .. ")") - end - if not std.check_privilege(cx, std.writes, region_type, field_path) then - local region_symbol - if node.region.region:is(ast.specialized.expr.ID) then - region_symbol = node.region.region.value - else - region_symbol = std.newsymbol() - end - report.error( - node, "invalid privileges in release: " .. tostring(std.writes) .. - "(" .. (data.newtuple(region_symbol) .. field_path):mkstring(".") .. ")") + std.require_privilege(cx, node, "release", privilege, region_type, field_path, region_symbol) end end @@ -2794,27 +3017,12 @@ function type_check.expr_attach_hdf5(cx, node) field_map = field_map and insert_implicit_cast(field_map, field_map_type, &rawstring) for _, field_path in ipairs(region.fields) do - if not std.check_privilege(cx, std.reads, region_type, field_path) then + for _, privilege in ipairs({std.reads, std.writes}) do local region_symbol if node.region.region:is(ast.specialized.expr.ID) then region_symbol = node.region.region.value - else - region_symbol = std.newsymbol() end - report.error( - node, "invalid privileges in attach: " .. tostring(std.reads) .. - "(" .. (data.newtuple(region_symbol) .. field_path):mkstring(".") .. ")") - end - if not std.check_privilege(cx, std.writes, region_type, field_path) then - local region_symbol - if node.region.region:is(ast.specialized.expr.ID) then - region_symbol = node.region.region.value - else - region_symbol = std.newsymbol() - end - report.error( - node, "invalid privileges in detach: " .. tostring(std.writes) .. - "(" .. (data.newtuple(region_symbol) .. field_path):mkstring(".") .. ")") + std.require_privilege(cx, node, "attach", privilege, region_type, field_path, region_symbol) end end @@ -2840,27 +3048,12 @@ function type_check.expr_detach_hdf5(cx, node) end for _, field_path in ipairs(region.fields) do - if not std.check_privilege(cx, std.reads, region_type, field_path) then + for _, privilege in ipairs({std.reads, std.writes}) do local region_symbol if node.region.region:is(ast.specialized.expr.ID) then region_symbol = node.region.region.value - else - region_symbol = std.newsymbol() end - report.error( - node, "invalid privileges in detach: " .. tostring(std.reads) .. - "(" .. (data.newtuple(region_symbol) .. field_path):mkstring(".") .. ")") - end - if not std.check_privilege(cx, std.writes, region_type, field_path) then - local region_symbol - if node.region.region:is(ast.specialized.expr.ID) then - region_symbol = node.region.region.value - else - region_symbol = std.newsymbol() - end - report.error( - node, "invalid privileges in detach: " .. tostring(std.writes) .. - "(" .. (data.newtuple(region_symbol) .. field_path):mkstring(".") .. ")") + std.require_privilege(cx, node, "detach", privilege, region_type, field_path, region_symbol) end end @@ -2901,15 +3094,18 @@ function type_check.expr_with_scratch_fields(cx, node) end end - std.copy_privileges(cx, region_type, expr_type) - - return ast.typed.expr.WithScratchFields { + local typed_expr = ast.typed.expr.WithScratchFields { region = region, field_ids = field_ids, expr_type = expr_type, annotations = node.annotations, span = node.span, } + + std.copy_privileges(cx, typed_expr, region_type, expr_type) + cx:intern_region(expr_type) + + return typed_expr end local function unary_op_type(op) @@ -3042,6 +3238,7 @@ function type_check.expr_binary(cx, node) local rhs_type = std.check_read(cx, rhs) local expr_type + local parent_type if std.is_partition(lhs_type) then if not std.is_partition(rhs_type) then report.error(node.rhs, "type mismatch: expected a partition but got " .. tostring(rhs_type)) @@ -3069,6 +3266,7 @@ function type_check.expr_binary(cx, node) expr_type = std.partition( disjointness, lhs_type.parent_region_symbol, lhs_type.colors_symbol) + parent_type = lhs_type.parent_region_symbol:gettype() elseif std.is_region(lhs_type) then if not std.is_partition(rhs_type) then report.error(node.rhs, "type mismatch: expected a partition but got " .. tostring(rhs_type)) @@ -3091,6 +3289,7 @@ function type_check.expr_binary(cx, node) end expr_type = std.partition( rhs_type.disjointness, region_symbol, rhs_type.colors_symbol) + parent_type = region_symbol:gettype() elseif std.is_index_type(lhs_type) and (std.is_region(rhs_type) or std.is_ispace(rhs_type)) then if node.op ~= "<=" then report.error(node.rhs, "operator " .. tostring(node.op) .. @@ -3114,7 +3313,7 @@ function type_check.expr_binary(cx, node) expr_type = binary_ops[node.op](cx, node, lhs_type, rhs_type) end - return ast.typed.expr.Binary { + local typed_expr = ast.typed.expr.Binary { op = node.op, lhs = lhs, rhs = rhs, @@ -3122,6 +3321,12 @@ function type_check.expr_binary(cx, node) annotations = node.annotations, span = node.span, } + + if parent_type then + std.add_constraint(cx, typed_expr, expr_type, parent_type, std.subregion) + end + + return typed_expr end function type_check.expr_deref(cx, node) @@ -3244,26 +3449,28 @@ function type_check.expr_import_region(cx, node) end assert(std.type_eq(ispace_symbol:gettype(), ispace_type)) - std.add_privilege(cx, std.reads, region, data.newtuple()) - std.add_privilege(cx, std.writes, region, data.newtuple()) + local typed_expr = ast.typed.expr.ImportRegion { + ispace = ispace, + value = value, + field_ids = field_ids, + expr_type = region, + annotations = node.annotations, + span = node.span, + } + + std.add_privilege(cx, typed_expr, std.reads, region, data.newtuple()) + std.add_privilege(cx, typed_expr, std.writes, region, data.newtuple()) -- Freshly imported regions are considered as disjoint from all -- other regions. for other_region, _ in cx.region_universe:items() do assert(not std.type_eq(region, other_region)) if std.type_maybe_eq(region:fspace(), other_region:fspace()) then - std.add_constraint(cx, region, other_region, std.disjointness, true) + std.add_constraint(cx, typed_expr, region, other_region, std.disjointness) end end cx:intern_region(region) - return ast.typed.expr.ImportRegion { - ispace = ispace, - value = value, - field_ids = field_ids, - expr_type = region, - annotations = node.annotations, - span = node.span, - } + return typed_expr end function type_check.expr_import_partition(cx, node) @@ -3303,7 +3510,7 @@ function type_check.expr_import_partition(cx, node) end local partition = std.partition(node.disjointness, region_symbol, colors_symbol) - return ast.typed.expr.ImportPartition { + local typed_expr = ast.typed.expr.ImportPartition { region = region, colors = colors, value = value, @@ -3311,6 +3518,10 @@ function type_check.expr_import_partition(cx, node) annotations = node.annotations, span = node.span, } + + std.add_constraint(cx, typed_expr, partition, region_type, std.subregion) + + return typed_expr end function type_check.expr_parallelizer_constraint(cx, node) @@ -3432,15 +3643,94 @@ function type_check.expr(cx, node) return type_check_expr(cx)(node) end -function type_check.block(cx, node) +function type_check.block(cx, node, footer) + local typed_stats = terralib.newlist() + for _, stat in ipairs(node.stats) do + local typed_stat = type_check.stat(cx, stat) + + if terralib.islist(typed_stat) then + typed_stats:insertall(typed_stat) + else + typed_stats:insert(typed_stat) + end + end + + if footer then + footer(cx, typed_stats) + end + return ast.typed.Block { - stats = node.stats:map( - function(stat) return type_check.stat(cx, stat) end), + stats = typed_stats, span = node.span, } end +function type_check.declare_phi_vars(cx, node, changed_regions, stats, region_mapping) + for region in pairs(changed_regions) do + local old_type = region:gettype() + local type = std.region(std.newsymbol(old_type:ispace(), old_type.ispace_symbol:hasname()), old_type:fspace()) + --local type = std.region(old_type.ispace_symbol, old_type:fspace()) + local renamed_region = std.newsymbol(type, region:getname()) + cx.type_env:insert(node, renamed_region, std.rawref(&type)) + region_mapping[region] = renamed_region + + local decl = ast.typed.stat.Var { + symbol = renamed_region, + type = type, + value = false, + annotations = node.annotations, + span = node.span, + } + stats:insert(decl) + end +end + +function type_check.use_phi_vars(cx, node, region_mapping) + for region, renamed_region in pairs(region_mapping) do + cx.region_env:force_insert(region, renamed_region) + end +end + +function type_check.write_phi_vars(cx, node, region_mapping, stats) + for region, renamed_region in pairs(region_mapping) do + local assign = ast.typed.stat.Assignment { + lhs = ast.typed.expr.ID { + value = renamed_region, + expr_type = cx.type_env:lookup(node, renamed_region), + annotations = node.annotations, + span = node.span, + }, + + rhs = type_check.expr(cx, ast.specialized.expr.ID { + value = region, + annotations = node.annotations, + span = node.span, + }), + + metadata = false, + annotations = node.annotations, + span = node.span, + } + + cx.dataflow_actions[assign]:insert(function(region_cx) + local old = assign.rhs.value:gettype() + local new = renamed_region:gettype() + region_cx:remove_region(new) + region_cx:dup_region(old, new) + end) + + stats:insert(assign) + end +end + function type_check.stat_if(cx, node) + local changed_regions = {} + type_check.changed_regions(cx, node, changed_regions) + + local stats = terralib.newlist() + local region_mapping = {} + type_check.declare_phi_vars(cx, node, changed_regions, stats, region_mapping) + local cond = type_check.expr(cx, node.cond) local cond_type = std.check_read(cx, cond) if not std.validate_implicit_cast(cond_type, bool) then @@ -3448,20 +3738,28 @@ function type_check.stat_if(cx, node) end cond = insert_implicit_cast(cond, cond_type, bool) + local function footer(cx, stats) + type_check.write_phi_vars(cx, node, region_mapping, stats) + end + local then_cx = cx:new_local_scope() local else_cx = cx:new_local_scope() - return ast.typed.stat.If { + local if_stat = ast.typed.stat.If { cond = cond, - then_block = type_check.block(then_cx, node.then_block), + then_block = type_check.block(then_cx, node.then_block, footer), elseif_blocks = node.elseif_blocks:map( - function(block) return type_check.stat_elseif(cx, block) end), - else_block = type_check.block(else_cx, node.else_block), + function(block) return type_check.stat_elseif(cx, block, footer) end), + else_block = type_check.block(else_cx, node.else_block, footer), annotations = node.annotations, span = node.span, } + stats:insert(if_stat) + + type_check.use_phi_vars(cx, node, region_mapping) + return stats end -function type_check.stat_elseif(cx, node) +function type_check.stat_elseif(cx, node, footer) local cond = type_check.expr(cx, node.cond) local cond_type = std.check_read(cx, cond) if not std.validate_implicit_cast(cond_type, bool) then @@ -3472,13 +3770,41 @@ function type_check.stat_elseif(cx, node) local body_cx = cx:new_local_scope() return ast.typed.stat.Elseif { cond = cond, - block = type_check.block(body_cx, node.block), + block = type_check.block(body_cx, node.block, footer), annotations = node.annotations, span = node.span, } end +function type_check.make_loop_footers(cx, node, stats) + local out = {} + + out.changed_regions = {} + type_check.changed_regions(cx, node, out.changed_regions) + + out.region_mapping_inner = {} + out.region_mapping_after = {} + type_check.declare_phi_vars(cx, node, out.changed_regions, stats, out.region_mapping_inner) + type_check.declare_phi_vars(cx, node, out.changed_regions, stats, out.region_mapping_after) + + function out.break_footer(cx, stats) + type_check.write_phi_vars(cx, node, out.region_mapping_after, stats) + end + + function out.footer(cx, stats) + type_check.write_phi_vars(cx, node, out.region_mapping_inner, stats) + out.break_footer(cx, stats) + end + + out.footer(cx, stats) + type_check.use_phi_vars(cx, node, out.region_mapping_inner) + return out +end + function type_check.stat_while(cx, node) + local stats = terralib.newlist() + local footers = type_check.make_loop_footers(cx, node, stats) + local cond = type_check.expr(cx, node.cond) local cond_type = std.check_read(cx, cond) if not std.validate_implicit_cast(cond_type, bool) then @@ -3486,13 +3812,17 @@ function type_check.stat_while(cx, node) end cond = insert_implicit_cast(cond, cond_type, bool) - local body_cx = cx:new_local_scope(nil, true) - return ast.typed.stat.While { + local body_cx = cx:new_local_scope(nil, true, footers.break_footer) + local while_stat = ast.typed.stat.While { cond = cond, - block = type_check.block(body_cx, node.block), + block = type_check.block(body_cx, node.block, footers.footer), annotations = node.annotations, span = node.span, } + stats:insert(while_stat) + + type_check.use_phi_vars(cx, node, footers.region_mapping_after) + return stats end function type_check.stat_for_num(cx, node) @@ -3507,11 +3837,14 @@ function type_check.stat_for_num(cx, node) end end + local stats = terralib.newlist() + local footers = type_check.make_loop_footers(cx, node, stats) + -- Enter scope for header. - local cx = cx:new_local_scope() + local header_cx = cx:new_local_scope() local var_type = node.symbol:hastype() or value_types[1] if value_types[3] then - var_type = binary_op_type("+")(cx, node, var_type, value_types[3]) + var_type = binary_op_type("+")(header_cx, node, var_type, value_types[3]) end if not var_type:isintegral() then report.error(node, "numeric for loop expected integral type, got " .. tostring(var_type)) @@ -3520,18 +3853,22 @@ function type_check.stat_for_num(cx, node) node.symbol:settype(var_type) end assert(std.type_eq(var_type, node.symbol:gettype())) - cx.type_env:insert(node, node.symbol, var_type) + header_cx.type_env:insert(node, node.symbol, var_type) -- Enter scope for body. - local cx = cx:new_local_scope(nil, true) - return ast.typed.stat.ForNum { + local body_cx = header_cx:new_local_scope(nil, true, footers.break_footer) + local for_stat = ast.typed.stat.ForNum { symbol = node.symbol, values = values, - block = type_check.block(cx, node.block), + block = type_check.block(body_cx, node.block, footers.footer), metadata = false, annotations = node.annotations, span = node.span, } + stats:insert(for_stat) + + type_check.use_phi_vars(cx, node, footers.region_mapping_after) + return stats end function type_check.stat_for_list(cx, node) @@ -3545,8 +3882,11 @@ function type_check.stat_for_list(cx, node) tostring(value_type)) end + local stats = terralib.newlist() + local footers = type_check.make_loop_footers(cx, node, stats) + -- Enter scope for header. - local cx = cx:new_local_scope() + local header_cx = cx:new_local_scope() -- Hack: Try to recover the original symbol for this bound if possible local bound @@ -3585,23 +3925,33 @@ function type_check.stat_for_list(cx, node) node.symbol:settype(var_type) end assert(std.type_eq(var_type, node.symbol:gettype())) - cx.type_env:insert(node, node.symbol, var_type) + header_cx.type_env:insert(node, node.symbol, var_type) + if std.is_region(var_type) then + header_cx.region_env:insert(node, node.symbol, node.symbol) + end -- Enter scope for body. - local cx = cx:new_local_scope(nil, true) - return ast.typed.stat.ForList { + local body_cx = header_cx:new_local_scope(nil, true, footers.break_footer) + local for_stat = ast.typed.stat.ForList { symbol = node.symbol, value = value, - block = type_check.block(cx, node.block), + block = type_check.block(body_cx, node.block, footers.footer), metadata = false, annotations = node.annotations, span = node.span, } + stats:insert(for_stat) + + type_check.use_phi_vars(cx, node, footers.region_mapping_after) + return stats end function type_check.stat_repeat(cx, node) - local block_cx = cx:new_local_scope(nil, true) - local block = type_check.block(cx, node.block) + local stats = terralib.newlist() + local footers = type_check.make_loop_footers(cx, node, stats) + + local block_cx = cx:new_local_scope(nil, true, footers.break_footer) + local block = type_check.block(block_cx, node.block, footers.footer) local until_cond = type_check.expr(block_cx, node.until_cond) local until_cond_type = std.check_read(block_cx, until_cond) @@ -3610,12 +3960,16 @@ function type_check.stat_repeat(cx, node) end until_cond = insert_implicit_cast(until_cond, until_cond_type, bool) - return ast.typed.stat.Repeat { + local repeat_stat = ast.typed.stat.Repeat { block = block, until_cond = until_cond, annotations = node.annotations, span = node.span, } + stats:insert(repeat_stat) + + type_check.use_phi_vars(cx, node, footers.region_mapping_after) + return stats end function type_check.stat_must_epoch(cx, node) @@ -3648,6 +4002,19 @@ function type_check.stat_var(cx, node) local var_type = symbol:hastype() if var_type then + if value and std.is_region(var_type) then + local value_type_read = std.as_read(value_type) + if std.is_region(value_type_read) and + std.type_eq(var_type:ispace().index_type, value_type_read:ispace().index_type) and + std.type_eq(var_type.fspace_type, value_type_read.fspace_type) then + + -- Use value's region type instead, as otherwise they won't be + -- compatible. + var_type = value_type_read + symbol:settype(var_type) + end + end + if value and not std.validate_implicit_cast(value_type, var_type) then report.error(node, "type mismatch in var: expected " .. tostring(var_type) .. " but got " .. tostring(value_type)) end @@ -3663,6 +4030,10 @@ function type_check.stat_var(cx, node) end cx.type_env:insert(node, symbol, std.rawref(&var_type)) + if std.is_region(var_type) then + cx.region_env:insert(node, symbol, symbol) + end + value = value and insert_implicit_cast(value, value_type, symbol:gettype()) or false return ast.typed.stat.Var { @@ -3719,14 +4090,13 @@ function type_check.stat_var_unpack(cx, node) end assert(symbol:gettype() == field_type) cx.type_env:insert(node, symbol, std.rawref(&field_type)) + if std.is_region(field_type) then + cx.region_env:insert(node, symbol, symbol) + end field_types:insert(field_type) end - if constraints then - std.add_constraints(cx, constraints) - end - - return ast.typed.stat.VarUnpack { + local typed_expr = ast.typed.stat.VarUnpack { symbols = node.symbols, fields = node.fields, field_types = field_types, @@ -3734,6 +4104,12 @@ function type_check.stat_var_unpack(cx, node) annotations = node.annotations, span = node.span, } + + if constraints then + std.add_constraints(cx, typed_expr, constraints) + end + + return typed_expr end function type_check.stat_return(cx, node) @@ -3748,6 +4124,10 @@ function type_check.stat_return(cx, node) local expected_type = cx:get_return_type() assert(expected_type) if std.type_eq(expected_type, std.untyped) then + if std.is_region(value_type) then + local ispace = std.newsymbol(value_type:ispace(), value_type.ispace_symbol:hasname()) + value_type = std.region(ispace, value_type:fspace()) + end cx:set_return_type(value_type) else local result_type = std.type_meet(value_type, expected_type) @@ -3757,30 +4137,61 @@ function type_check.stat_return(cx, node) cx:set_return_type(result_type) end - return ast.typed.stat.Return { + local stats = terralib.newlist() + cx.return_footer(cx, stats) + + local return_stat = ast.typed.stat.Return { value = value, annotations = node.annotations, span = node.span, } + stats:insert(return_stat) + return stats end function type_check.stat_break(cx, node) if not cx.breakable_loop then report.error(node, "break must be inside a loop") end - return ast.typed.stat.Break { + + local stats = terralib.newlist() + if cx.break_footer then + cx.break_footer(cx, stats) + end + + local break_stat = ast.typed.stat.Break { annotations = node.annotations, span = node.span, } + stats:insert(break_stat) + return stats end function type_check.stat_assignment(cx, node) - local lhs = type_check.expr(cx, node.lhs) - local lhs_type = std.check_write(cx, lhs) - local rhs = type_check.expr(cx, node.rhs) local rhs_type = std.check_read(cx, rhs) + -- Rename region variables, like for SSA + if node.lhs:is(ast.specialized.expr.ID) then + local lhs = node.lhs.value + if cx.region_env:safe_lookup(lhs) then + local renamed_region = std.newsymbol(rhs_type, lhs:getname()) + cx.type_env:insert(node, renamed_region, std.rawref(&rhs_type)) + cx.region_env:force_insert(lhs, renamed_region) + + return ast.typed.stat.Var { + symbol = renamed_region, + type = rhs_type, + value = rhs, + annotations = node.annotations, + span = node.span, + } + end + end + + local lhs = type_check.expr(cx, node.lhs) + local lhs_type = std.check_write(cx, lhs) + if not std.validate_implicit_cast(rhs_type, lhs_type) then report.error(node, "type mismatch in assignment: expected " .. tostring(lhs_type) .. " but got " .. tostring(rhs_type)) end @@ -3935,20 +4346,14 @@ function type_check.stat_parallel_prefix(cx, node) report.error(node.dir, "type mismatch in argument 4: expected an integer type, but got " .. tostring(dir.expr_type)) end + local lhs_field_path = data.newtuple() if #lhs.fields > 0 then lhs_field_path = lhs.fields[1] end - if not std.check_privilege(cx, std.writes, lhs.expr_type, lhs_field_path) then - report.error(lhs, - "invalid privilege in argument 1: " .. tostring(std.writes) .. "(" .. - (data.newtuple(lhs.region.value) .. lhs_field_path):mkstring(".") .. ")") - end + std.require_privilege(cx, lhs, "argument 1", std.writes, lhs.expr_type, lhs_field_path, lhs.region.value) + local rhs_field_path = data.newtuple() if #rhs.fields > 0 then rhs_field_path = rhs.fields[1] end - if not std.check_privilege(cx, std.reads, rhs.expr_type, rhs_field_path) then - report.error(rhs, - "invalid privilege in argument 2: " .. tostring(std.reads) .. "(" .. - (data.newtuple(rhs.region.value) .. rhs_field_path):mkstring(".") .. ")") - end + std.require_privilege(cx, rhs, "argument 2", std.reads, rhs.expr_type, rhs_field_path, rhs.region.value) return ast.typed.stat.ParallelPrefix { lhs = lhs, @@ -3991,6 +4396,68 @@ function type_check.stat(cx, node) return type_check_stat(cx)(node) end +function type_check.changed_regions_block(cx, block, changed_regions) + for _, stat in ipairs(block.stats) do + type_check.changed_regions(cx, stat, changed_regions) + end +end + +function type_check.changed_regions_if(cx, node, changed_regions) + type_check.changed_regions_block(cx, node.then_block, changed_regions) + for _, else_if in ipairs(node.elseif_blocks) do + type_check.changed_regions_block(cx, else_if.block, changed_regions) + end + type_check.changed_regions_block(cx, node.else_block, changed_regions) +end + +function type_check.changed_regions_stat_block(cx, node, changed_regions) + type_check.changed_regions_block(cx, node.block, changed_regions) +end + +function type_check.changed_regions_assignment(cx, node, changed_regions) + if node.lhs:is(ast.specialized.expr.ID) then + local lhs = node.lhs.value + if cx.region_env:safe_lookup(lhs) then + changed_regions[lhs] = true + end + end +end + +function type_check.changed_regions_ignore(cx, node, changed_regions) +end + +local type_check_changed_regions_node = { + [ast.specialized.stat.If] = type_check.changed_regions_if, + [ast.specialized.stat.While] = type_check.changed_regions_stat_block, + [ast.specialized.stat.ForNum] = type_check.changed_regions_stat_block, + [ast.specialized.stat.ForList] = type_check.changed_regions_stat_block, + [ast.specialized.stat.Repeat] = type_check.changed_regions_stat_block, + [ast.specialized.stat.MustEpoch] = type_check.changed_regions_stat_block, + [ast.specialized.stat.Block] = type_check.changed_regions_stat_block, + [ast.specialized.stat.ParallelizeWith] = type_check.changed_regions_stat_block, + [ast.specialized.stat.Assignment] = type_check.changed_regions_assignment, + + [ast.specialized.stat.Var] = type_check.changed_regions_ignore, + [ast.specialized.stat.VarUnpack] = type_check.changed_regions_ignore, + [ast.specialized.stat.Reduce] = type_check.changed_regions_ignore, + [ast.specialized.stat.Expr] = type_check.changed_regions_ignore, + [ast.specialized.stat.Return] = type_check.changed_regions_ignore, + [ast.specialized.stat.Break] = type_check.changed_regions_ignore, + [ast.specialized.stat.RawDelete] = type_check.changed_regions_ignore, + [ast.specialized.stat.Fence] = type_check.changed_regions_ignore, + [ast.specialized.stat.ParallelPrefix] = type_check.changed_regions_ignore, + + [ast.specialized.stat.Elseif] = unreachable, +} + +local type_check_changed_regions = ast.make_single_dispatch( + type_check_changed_regions_node, + {ast.specialized.stat}) + +function type_check.changed_regions(cx, node, changed_regions) + return type_check_changed_regions(cx)(node, changed_regions) +end + local opaque_types = { [std.c.legion_domain_point_iterator_t] = true, [std.c.legion_coloring_t] = true, @@ -4053,6 +4520,9 @@ end function type_check.top_task_param(cx, node, task, mapping, is_defined) local param_type = node.symbol:gettype() cx.type_env:insert(node, node.symbol, std.rawref(¶m_type)) + if std.is_region(param_type) then + cx.region_env:insert(node, node.symbol, node.symbol) + end -- Check for parameters with duplicate types. if std.type_supports_constraints(param_type) then @@ -4138,7 +4608,6 @@ function type_check.top_task(cx, node) end end) end - std.add_privilege(cx, privilege_type, region_type, field_path) cx:intern_region(region_type) end end @@ -4154,7 +4623,6 @@ function type_check.top_task(cx, node) prototype:set_conditions(conditions) local constraints = type_check.constraints(cx, node.constraints) - std.add_constraints(cx, constraints) prototype:set_param_constraints(constraints) local body = node.body and type_check.block(cx, node.body) @@ -4167,17 +4635,55 @@ function type_check.top_task(cx, node) params:map(function(param) return param.param_type end), return_type, false) prototype:set_type(task_type, true) - for _, fixup_node in ipairs(cx.fixup_nodes) do - if fixup_node:is(ast.typed.expr.Call) then - local fn_type = fixup_node.fn.value:get_type() - assert(fn_type.returntype ~= untyped) - fixup_node.expr_type = fn_type.returntype - else - assert(false) + local region_cx = std.region_context() + for _, privilege_list in ipairs(privileges) do + for _, privilege in ipairs(privilege_list) do + region_cx:add_privilege(privilege.privilege, privilege.region:gettype(), privilege.field_path) end end + for _, constraint in ipairs(constraints) do + local lhs, rhs, op = constraint.lhs, constraint.rhs, constraint.op + region_cx:add_constraint(lhs, rhs, op) + end + + if body then + local entry = region_dataflow(cx, region_cx) + + local exit + local count = 0 + repeat + local df = dataflow.forward(exit) + exit = df:block(body, entry:copy()) + + if exit and not exit.unreachable and return_type ~= terralib.types.unit then + report.warn(node, "WARNING: function ends without return statement.") + end + exit = dataflow.meet(exit, df:exit_state()) + + local prev_region_cx = region_cx + region_cx = exit.context + prototype.region_cx = region_cx + count = count + 1 + assert(count < 30) + until not cx.recursive.value or prev_region_cx == region_cx + + for _, check in ipairs(cx.region_checks) do + check(region_cx) + end + + for _, fixup_node in ipairs(cx.fixup_nodes) do + if fixup_node:is(ast.typed.expr.Call) then + local fn_type = fixup_node.fn.value:get_type() + assert(fn_type.returntype ~= untyped) + fixup_node.expr_type = fn_type.returntype + else + assert(false) + end + end + else + prototype:set_region_context(region_cx) + end - prototype:set_constraints(cx.constraints) prototype:set_region_universe(cx.region_universe) return ast.typed.top.Task { diff --git a/language/tests/regent/compile_fail/region_assignment1.rg b/language/tests/regent/compile_fail/region_assignment1.rg new file mode 100644 index 0000000000..5bcfc9423f --- /dev/null +++ b/language/tests/regent/compile_fail/region_assignment1.rg @@ -0,0 +1,42 @@ +-- Copyright 2019 Stanford University +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- fails-with: +-- region_assignment1.rg:40: invalid call missing constraint $x <= $s + +import "regent" + +task assert_subregion(x : region(int), y : region(int)) +where + x <= y +do +end + +task main() + var s = region(ispace(ptr, 5), int) + var t = region(ispace(ptr, 5), int) + + s[0] = 1 + t[0] = 2 + + var x : region(int) + if true then + x = s + else + x = t + end + + assert_subregion(x, s) +end +regentlib.start(main) diff --git a/language/tests/regent/compile_fail/region_assignment2.rg b/language/tests/regent/compile_fail/region_assignment2.rg new file mode 100644 index 0000000000..ee4b57e57a --- /dev/null +++ b/language/tests/regent/compile_fail/region_assignment2.rg @@ -0,0 +1,42 @@ +-- Copyright 2019 Stanford University +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- fails-with: +-- region_assignment2.rg:40: invalid call missing constraint $x * $s + +import "regent" + +task assert_disjoint(x : region(int), y : region(int)) +where + x * y +do +end + +task main() + var s = region(ispace(ptr, 5), int) + var t = region(ispace(ptr, 5), int) + + s[0] = 1 + t[0] = 2 + + var x : region(int) + if true then + x = s + else + x = t + end + + assert_disjoint(x, s) +end +regentlib.start(main) diff --git a/language/tests/regent/compile_fail/region_return1.rg b/language/tests/regent/compile_fail/region_return1.rg new file mode 100644 index 0000000000..73dec95a35 --- /dev/null +++ b/language/tests/regent/compile_fail/region_return1.rg @@ -0,0 +1,42 @@ +-- Copyright 2019 Stanford University +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- fails-with: +-- region_return1.rg:40: invalid privilege writes($r) for dereference of ptr(int32, $r) + +import "regent" + +task switch(b : bool, s : region(int), t : region(int)) : region(int) +where + reads(s), + reads(t) +do + if b then + return s + else + return t + end +end + +task main() + var s = region(ispace(ptr, 5), int) + var t = region(ispace(ptr, 5), int) + + s[0] = 1 + t[0] = 2 + + var r = switch(true, s, t) + r[0] = 1 +end +regentlib.start(main) diff --git a/language/tests/regent/compile_fail/region_return2.rg b/language/tests/regent/compile_fail/region_return2.rg new file mode 100644 index 0000000000..ff133ac708 --- /dev/null +++ b/language/tests/regent/compile_fail/region_return2.rg @@ -0,0 +1,42 @@ +-- Copyright 2019 Stanford University +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- fails-with: +-- region_return2.rg:40: invalid privilege writes($r) for dereference of ptr(int32, $r) + +import "regent" + +task switch(b : bool, s : region(int), t : region(int)) : region(int) +where + reads(s), + reads(t) +do + if b then + return s + else + return switch(true, t, s) + end +end + +task main() + var s = region(ispace(ptr, 5), int) + var t = region(ispace(ptr, 5), int) + + s[0] = 1 + t[0] = 2 + + var r = switch(true, s, t) + r[0] = 1 +end +regentlib.start(main) diff --git a/language/tests/regent/compile_fail/region_return3.rg b/language/tests/regent/compile_fail/region_return3.rg new file mode 100644 index 0000000000..78a64389e5 --- /dev/null +++ b/language/tests/regent/compile_fail/region_return3.rg @@ -0,0 +1,40 @@ +-- Copyright 2019 Stanford University +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- fails-with: +-- region_return3.rg:38: invalid call missing constraint $s * $t + +import "regent" + +task assert_disjoint(x : region(int), y : region(int)) +where + x * y +do +end + +task constructor(s : region(int)) : region(int) + if false then + return s + else + return region(ispace(ptr, 5), int) + end +end + +task main() + var s = region(ispace(ptr, 5), int) + var t = constructor(s) + + assert_disjoint(s, t) +end +regentlib.start(main) diff --git a/language/tests/regent/compile_fail/region_return4.rg b/language/tests/regent/compile_fail/region_return4.rg new file mode 100644 index 0000000000..5f8ca1af7f --- /dev/null +++ b/language/tests/regent/compile_fail/region_return4.rg @@ -0,0 +1,41 @@ +-- Copyright 2019 Stanford University +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- fails-with: +-- region_return4.rg:39: invalid call missing constraint $w * $t + +import "regent" + +task assert_disjoint(x : region(int), y : region(int)) +where + x * y +do +end + +task constructor(s : region(int)) : region(int) + if false then + return s + else + return region(ispace(ptr, 5), int) + end +end + +task main() + var s = region(ispace(ptr, 5), int) + var t = constructor(s) + var w = constructor(s) + + assert_disjoint(w, t) +end +regentlib.start(main) diff --git a/language/tests/regent/compile_fail/type_mismatch_parallel_prefix13.rg b/language/tests/regent/compile_fail/type_mismatch_parallel_prefix13.rg index 110d425621..773c007396 100644 --- a/language/tests/regent/compile_fail/type_mismatch_parallel_prefix13.rg +++ b/language/tests/regent/compile_fail/type_mismatch_parallel_prefix13.rg @@ -13,7 +13,7 @@ -- limitations under the License. -- fails-with: --- type_mismatch_parallel_prefix13.rg:24: invalid privilege in argument 1: writes($r) +-- type_mismatch_parallel_prefix13.rg:24: invalid privileges in argument 1: writes($r) -- __parallel_prefix(r, s, +, 1) -- ^ diff --git a/language/tests/regent/compile_fail/type_mismatch_parallel_prefix14.rg b/language/tests/regent/compile_fail/type_mismatch_parallel_prefix14.rg index 6642990d25..39358bd32a 100644 --- a/language/tests/regent/compile_fail/type_mismatch_parallel_prefix14.rg +++ b/language/tests/regent/compile_fail/type_mismatch_parallel_prefix14.rg @@ -13,7 +13,7 @@ -- limitations under the License. -- fails-with: --- type_mismatch_parallel_prefix14.rg:26: invalid privilege in argument 2: reads($s) +-- type_mismatch_parallel_prefix14.rg:26: invalid privileges in argument 2: reads($s) -- __parallel_prefix(r, s, +, 1) -- ^ diff --git a/language/tests/regent/run_pass/privilege_reduce5.rg b/language/tests/regent/run_pass/privilege_reduce5.rg new file mode 100644 index 0000000000..aae79fc56a --- /dev/null +++ b/language/tests/regent/run_pass/privilege_reduce5.rg @@ -0,0 +1,40 @@ +-- Copyright 2019 Stanford University +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +import "regent" + +fspace s { + a : int, +} + +task f(r : region(s)) +where + reduces +(r.a) +do + r[0].a += 10 +end + +task g(r : region(s)) +where + reads(r), + writes(r.a) +do + f(r) +end + +task main() + var r = region(ispace(int1d, 1), a) + g(r) +end +regentlib.start(main) diff --git a/language/tests/regent/run_pass/region_assignment1.rg b/language/tests/regent/run_pass/region_assignment1.rg new file mode 100644 index 0000000000..e861e33fc5 --- /dev/null +++ b/language/tests/regent/run_pass/region_assignment1.rg @@ -0,0 +1,40 @@ +-- Copyright 2019 Stanford University +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +import "regent" + +task foo(x : region(int), y : region(int)) +where + x <= y, + y <= x, + reads writes(x) +do + y[0] = x[0] + x[0] = y[0] +end + +task main() + var s = region(ispace(ptr, 5), int) + var t = region(ispace(ptr, 5), int) + s[0] = 1 + t[0] = 2 + + s = t + foo(s, t) + + regentlib.assert(s[0] == 2, "test failed") + s[0] = 0 + regentlib.assert(s[0] == 0, "test failed") +end +regentlib.start(main) diff --git a/language/tests/regent/compile_fail/type_mismatch_assignment2.rg b/language/tests/regent/run_pass/region_assignment2.rg similarity index 75% rename from language/tests/regent/compile_fail/type_mismatch_assignment2.rg rename to language/tests/regent/run_pass/region_assignment2.rg index 19b51490ef..93b0f9cb99 100644 --- a/language/tests/regent/compile_fail/type_mismatch_assignment2.rg +++ b/language/tests/regent/run_pass/region_assignment2.rg @@ -12,17 +12,24 @@ -- See the License for the specific language governing permissions and -- limitations under the License. --- fails-with: --- type_mismatch_assignment2.rg:26: type mismatch in assignment: expected region(int32) but got region(int32) --- s = t --- ^ - import "regent" -task f() +task main() var s = region(ispace(ptr, 5), int) var t = region(ispace(ptr, 5), int) - s = t + s[0] = 1 + t[0] = 2 + + var x : region(int) + if true then + x = s + else + x = t + end + + regentlib.assert(x[0] == 1, "test failed") + x[0] = 0 + regentlib.assert(s[0] == 0, "test failed") end -f:compile() +regentlib.start(main) diff --git a/language/tests/regent/run_pass/region_return1.rg b/language/tests/regent/run_pass/region_return1.rg new file mode 100644 index 0000000000..02cbc46cd5 --- /dev/null +++ b/language/tests/regent/run_pass/region_return1.rg @@ -0,0 +1,49 @@ +-- Copyright 2019 Stanford University +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +import "regent" + +task switch(b : bool, s : region(int), t : region(int)) : region(int) +where + reads(s), + reads(t) +do + if b then + return s + else + return t + end +end + +task main() + var s = region(ispace(ptr, 5), int) + var t = region(ispace(ptr, 5), int) + + s[0] = 1 + t[0] = 2 + + var x : region(int) + if true then + x = s + else + x = t + end + + var r = switch(true, s, t) + + regentlib.assert(x[0] == r[0], "test failed") + x[0] = 0 + regentlib.assert(r[0] == 0, "test failed") +end +regentlib.start(main) diff --git a/language/tests/regent/run_pass/region_return2.rg b/language/tests/regent/run_pass/region_return2.rg new file mode 100644 index 0000000000..ca337597f1 --- /dev/null +++ b/language/tests/regent/run_pass/region_return2.rg @@ -0,0 +1,49 @@ +-- Copyright 2019 Stanford University +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +import "regent" + +task switch(b : bool, s : region(int), t : region(int)) : region(int) +where + reads(s), + reads(t) +do + if b then + return s + else + return switch(true, t, s) + end +end + +task main() + var s = region(ispace(ptr, 5), int) + var t = region(ispace(ptr, 5), int) + + s[0] = 1 + t[0] = 2 + + var x : region(int) + if true then + x = s + else + x = t + end + + var r = switch(true, s, t) + + regentlib.assert(x[0] == r[0], "test failed") + x[0] = 0 + regentlib.assert(r[0] == 0, "test failed") +end +regentlib.start(main) diff --git a/language/tests/regent/run_pass/region_return3.rg b/language/tests/regent/run_pass/region_return3.rg new file mode 100644 index 0000000000..ef775c63d9 --- /dev/null +++ b/language/tests/regent/run_pass/region_return3.rg @@ -0,0 +1,36 @@ +-- Copyright 2019 Stanford University +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +import "regent" + +task assert_disjoint(x : region(int), y : region(int)) +where + x * y +do +end + +task constructor() : region(int) + return region(ispace(ptr, 5), int) +end + +task main() + var s = region(ispace(ptr, 5), int) + var t = constructor() + var w = constructor() + + assert_disjoint(s, t) + assert_disjoint(s, w) + assert_disjoint(w, t) +end +regentlib.start(main) diff --git a/language/tests/regent/run_pass/region_return4.rg b/language/tests/regent/run_pass/region_return4.rg new file mode 100644 index 0000000000..03cee5a1c4 --- /dev/null +++ b/language/tests/regent/run_pass/region_return4.rg @@ -0,0 +1,42 @@ +-- Copyright 2019 Stanford University +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +import "regent" + +task assert_disjoint(x : region(int), y : region(int)) +where + x * y +do +end + +task constructor(s : region(int)) : region(int) + if false then + return s + else + return region(ispace(ptr, 5), int) + end +end + +task main() + var s = region(ispace(ptr, 5), int) + var t = constructor(s) + var u = region(ispace(ptr, 5), int) + var w = constructor(u) + + assert_disjoint(s, u) + assert_disjoint(s, w) + assert_disjoint(t, u) + assert_disjoint(t, w) +end +regentlib.start(main)