Skip to content

Commit

Permalink
spawnat
Browse files Browse the repository at this point in the history
  • Loading branch information
carstenbauer committed Jan 29, 2024
1 parent 3701655 commit 00f6625
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 10 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Manifest.toml
.vscode
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,12 @@ versus
julia> Core.Compiler.return_type(() -> fetch(Threads.@spawn 1 + 1), Tuple{})
Any
```

The package also provides `StableTasks.@spawnat` (not exported), which is similar to `StableTasks.@spawn` but creates a *sticky* task (it won't migrate) on a specific thread.

```julia
julia> t = StableTasks.@spawnat 4 Threads.threadid();

julia> fetch(t)
4
```
1 change: 1 addition & 0 deletions src/StableTasks.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module StableTasks

macro spawn end
macro spawnat end

using Base: RefValue
struct StableTask{T}
Expand Down
47 changes: 39 additions & 8 deletions src/internals.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module Internals
module Internals

import StableTasks: @spawn, StableTask
import StableTasks: @spawn, @spawnat, StableTask

function Base.fetch(t::StableTask{T}) where {T}
fetch(t.t)
Expand All @@ -26,14 +26,12 @@ Base.schedule(t, val; error=false) = (schedule(t.t, val; error); t)


macro spawn(ex)
tp = QuoteNode(:default)

letargs = _lift_one_interp!(ex)

thunk = replace_linenums!(:(()->($(esc(ex)))), __source__)
thunk = replace_linenums!(:(() -> ($(esc(ex)))), __source__)
var = esc(Base.sync_varname) # This is for the @sync macro which sets a local variable whose name is
# the symbol bound to Base.sync_varname
# I asked on slack and this is apparently safe to consider a public API
# the symbol bound to Base.sync_varname
# I asked on slack and this is apparently safe to consider a public API
quote
let $(letargs...)
f = $thunk
Expand All @@ -51,6 +49,39 @@ macro spawn(ex)
end
end

macro spawnat(thrdid, ex)
letargs = _lift_one_interp!(ex)

thunk = replace_linenums!(:(() -> ($(esc(ex)))), __source__)
var = esc(Base.sync_varname)

tid = esc(thrdid)
@static if VERSION < v"1.9"
nt = :(Threads.nthreads())
else
nt = :(Threads.maxthreadid())
end
quote
if $tid < 1 || $tid > $nt
throw(ArgumentError("Invalid thread id ($($tid)). Must be between in " *
"1:(total number of threads), i.e. $(1:$nt)."))
end
let $(letargs...)
thunk = $thunk
RT = Core.Compiler.return_type(thunk, Tuple{})
ret = Ref{RT}()
thunk_wrap = () -> (ret[] = thunk(); nothing)
local task = Task(thunk_wrap)
task.sticky = true
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), task, $tid - 1)
if $(Expr(:islocal, var))
put!($var, task)
end
schedule(task)
StableTask(task, ret)
end
end
end

# Copied from base rather than calling it directly because who knows if it'll change in the future
function _lift_one_interp!(e)
Expand All @@ -74,7 +105,7 @@ function _lift_one_interp_helper(expr::Expr, in_quote_context, letargs)
elseif expr.head === :macrocall
return expr # Don't recur into macro calls, since some other macros use $
end
for (i,e) in enumerate(expr.args)
for (i, e) in enumerate(expr.args)
expr.args[i] = _lift_one_interp_helper(e, in_quote_context, letargs)
end
expr
Expand Down
24 changes: 22 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
using Test, StableTasks
using StableTasks: @spawn
using StableTasks: @spawn, @spawnat

@testset "Type stability" begin
@test 2 ==@inferred fetch(@spawn 1 + 1)
@test 2 == @inferred fetch(@spawn 1 + 1)
t = @eval @spawn inv([1 2 ; 3 4])
@test inv([1 2 ; 3 4]) == @inferred fetch(t)

@test 2 == @inferred fetch(@spawnat 1 1 + 1)
t = @eval @spawnat 1 inv([1 2 ; 3 4])
@test inv([1 2 ; 3 4]) == @inferred fetch(t)
end

@testset "API funcs" begin
Expand All @@ -22,4 +26,20 @@ end
@test r[] == 0
end
@test r[] == 1

T = @spawnat 1 rand(Bool)
@test isnothing(wait(T))
@test istaskdone(T)
@test istaskfailed(T) == false
@test istaskstarted(T)
@test fetch(@spawnat 1 Threads.threadid()) == 1
r = Ref(0)
@sync begin
@spawnat 1 begin
sleep(5)
r[] = 1
end
@test r[] == 0
end
@test r[] == 1
end

2 comments on commit 00f6625

@MasonProtter
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: Version 0.1.0 already exists

Please sign in to comment.