Skip to content

Commit

Permalink
Add basic code for binding partition revalidation (#56649)
Browse files Browse the repository at this point in the history
This adds the binding partition revalidation code from #54654. This is
the last piece of that PR that hasn't been merged yet - however the TODO
in that PR still stands for future work.

This PR itself adds a callback that gets triggered by deleting a
binding. It will then walk all code in the system and invalidate code
instances of Methods whose lowered source referenced the given global.
This walk is quite slow. Future work will add backedges and
optimizations to make this faster, but the basic functionality should be
in place with this PR.
  • Loading branch information
Keno authored Jan 9, 2025
1 parent 1ebacac commit cfd3922
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 2 deletions.
1 change: 1 addition & 0 deletions base/Base_compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ include("ordering.jl")
using .Order

include("coreir.jl")
include("invalidation.jl")

# For OS specific stuff
# We need to strcat things here, before strings are really defined
Expand Down
111 changes: 111 additions & 0 deletions base/invalidation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

struct GlobalRefIterator
mod::Module
end
IteratorSize(::Type{GlobalRefIterator}) = SizeUnknown()
globalrefs(mod::Module) = GlobalRefIterator(mod)

function iterate(gri::GlobalRefIterator, i = 1)
m = gri.mod
table = ccall(:jl_module_get_bindings, Ref{SimpleVector}, (Any,), m)
i == length(table) && return nothing
b = table[i]
b === nothing && return iterate(gri, i+1)
return ((b::Core.Binding).globalref, i+1)
end

const TYPE_TYPE_MT = Type.body.name.mt
const NONFUNCTION_MT = Core.MethodTable.name.mt
function foreach_module_mtable(visit, m::Module, world::UInt)
for gb in globalrefs(m)
binding = gb.binding
bpart = lookup_binding_partition(world, binding)
if is_defined_const_binding(binding_kind(bpart))
v = partition_restriction(bpart)
uw = unwrap_unionall(v)
name = gb.name
if isa(uw, DataType)
tn = uw.name
if tn.module === m && tn.name === name && tn.wrapper === v && isdefined(tn, :mt)
# this is the original/primary binding for the type (name/wrapper)
mt = tn.mt
if mt !== nothing && mt !== TYPE_TYPE_MT && mt !== NONFUNCTION_MT
@assert mt.module === m
visit(mt) || return false
end
end
elseif isa(v, Module) && v !== m && parentmodule(v) === m && _nameof(v) === name
# this is the original/primary binding for the submodule
foreach_module_mtable(visit, v, world) || return false
elseif isa(v, Core.MethodTable) && v.module === m && v.name === name
# this is probably an external method table here, so let's
# assume so as there is no way to precisely distinguish them
visit(v) || return false
end
end
end
return true
end

function foreach_reachable_mtable(visit, world::UInt)
visit(TYPE_TYPE_MT) || return
visit(NONFUNCTION_MT) || return
for mod in loaded_modules_array()
foreach_module_mtable(visit, mod, world)
end
end

function should_invalidate_code_for_globalref(gr::GlobalRef, src::CodeInfo)
found_any = false
labelchangemap = nothing
stmts = src.code
isgr(g::GlobalRef) = gr.mod == g.mod && gr.name === g.name
isgr(g) = false
for i = 1:length(stmts)
stmt = stmts[i]
if isgr(stmt)
found_any = true
continue
end
for ur in Compiler.userefs(stmt)
arg = ur[]
# If any of the GlobalRefs in this stmt match the one that
# we are about, we need to move out all GlobalRefs to preserve
# effect order, in case we later invalidate a different GR
if isa(arg, GlobalRef)
if isgr(arg)
@assert !isa(stmt, PhiNode)
found_any = true
break
end
end
end
end
return found_any
end

function invalidate_code_for_globalref!(gr::GlobalRef, new_max_world::UInt)
valid_in_valuepos = false
foreach_reachable_mtable(new_max_world) do mt::Core.MethodTable
for method in MethodList(mt)
if isdefined(method, :source)
src = _uncompressed_ir(method)
old_stmts = src.code
if should_invalidate_code_for_globalref(gr, src)
for mi in specializations(method)
ci = mi.cache
while true
if ci.max_world > new_max_world
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world)
end
isdefined(ci, :next) || break
ci = ci.next
end
end
end
end
end
return true
end
end
5 changes: 5 additions & 0 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -1867,6 +1867,11 @@ static void invalidate_code_instance(jl_code_instance_t *replaced, size_t max_wo
JL_UNLOCK(&replaced_mi->def.method->writelock);
}

JL_DLLEXPORT void jl_invalidate_code_instance(jl_code_instance_t *replaced, size_t max_world)
{
invalidate_code_instance(replaced, max_world, 1);
}

static void _invalidate_backedges(jl_method_instance_t *replaced_mi, size_t max_world, int depth) {
jl_array_t *backedges = replaced_mi->backedges;
if (backedges) {
Expand Down
26 changes: 24 additions & 2 deletions src/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -1032,6 +1032,21 @@ JL_DLLEXPORT void jl_set_const(jl_module_t *m JL_ROOTING_ARGUMENT, jl_sym_t *var
jl_gc_wb(bpart, val);
}

void jl_invalidate_binding_refs(jl_globalref_t *ref, size_t new_world)
{
static jl_value_t *invalidate_code_for_globalref = NULL;
if (invalidate_code_for_globalref == NULL && jl_base_module != NULL)
invalidate_code_for_globalref = jl_get_global(jl_base_module, jl_symbol("invalidate_code_for_globalref!"));
if (!invalidate_code_for_globalref)
jl_error("Binding invalidation is not permitted during bootstrap.");
if (jl_generating_output())
jl_error("Binding invalidation is not permitted during image generation.");
jl_value_t *boxed_world = jl_box_ulong(new_world);
JL_GC_PUSH1(&boxed_world);
jl_call2((jl_function_t*)invalidate_code_for_globalref, (jl_value_t*)ref, boxed_world);
JL_GC_POP();
}

extern jl_mutex_t world_counter_lock;
JL_DLLEXPORT void jl_disable_binding(jl_globalref_t *gr)
{
Expand All @@ -1046,9 +1061,11 @@ JL_DLLEXPORT void jl_disable_binding(jl_globalref_t *gr)

JL_LOCK(&world_counter_lock);
jl_task_t *ct = jl_current_task;
size_t last_world = ct->world_age;
size_t new_max_world = jl_atomic_load_acquire(&jl_world_counter);
// TODO: Trigger invalidation here
(void)ct;
ct->world_age = jl_typeinf_world;
jl_invalidate_binding_refs(gr, new_max_world);
ct->world_age = last_world;
jl_atomic_store_release(&bpart->max_world, new_max_world);
jl_atomic_store_release(&jl_world_counter, new_max_world + 1);
JL_UNLOCK(&world_counter_lock);
Expand Down Expand Up @@ -1334,6 +1351,11 @@ JL_DLLEXPORT void jl_add_to_module_init_list(jl_value_t *mod)
jl_array_ptr_1d_push(jl_module_init_order, mod);
}

JL_DLLEXPORT jl_svec_t *jl_module_get_bindings(jl_module_t *m)
{
return jl_atomic_load_relaxed(&m->bindings);
}

JL_DLLEXPORT void jl_init_restored_module(jl_value_t *mod)
{
if (!jl_generating_output() || jl_options.incremental) {
Expand Down
7 changes: 7 additions & 0 deletions test/rebinding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,11 @@ module Rebinding
@test Base.@world(Foo, defined_world_age) == typeof(x)
@test Base.@world(Rebinding.Foo, defined_world_age) == typeof(x)
@test Base.@world((@__MODULE__).Foo, defined_world_age) == typeof(x)

# Test invalidation (const -> undefined)
const delete_me = 1
f_return_delete_me() = delete_me
@test f_return_delete_me() == 1
Base.delete_binding(@__MODULE__, :delete_me)
@test_throws UndefVarError f_return_delete_me()
end

0 comments on commit cfd3922

Please sign in to comment.