Skip to content

Commit

Permalink
Add Random.AbstractRNG type annotations (fixing dot_tilde_assume ambi…
Browse files Browse the repository at this point in the history
…guity)
  • Loading branch information
penelopeysm committed Jan 10, 2025
1 parent bc93fe4 commit 1bf5497
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
22 changes: 14 additions & 8 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ function tilde_observe!!(context, right, left, vi)
return left, acclogp_observe!!(context, vi, logp)
end

function assume(rng, spl::Sampler, dist)
function assume(rng::Random.AbstractRNG, spl::Sampler, dist)

Check warning on line 198 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L198

Added line #L198 was not covered by tests
return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))")
end

Expand Down Expand Up @@ -291,14 +291,18 @@ end
function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, vi)
return dot_assume(right, left, vns, vi)
end
function dot_tilde_assume(::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, vi)
function dot_tilde_assume(

Check warning on line 294 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L294

Added line #L294 was not covered by tests
::IsLeaf, rng::Random.AbstractRNG, ::AbstractContext, sampler, right, left, vns, vi
)
return dot_assume(rng, sampler, right, vns, left, vi)
end

function dot_tilde_assume(::IsParent, context::AbstractContext, args...)
return dot_tilde_assume(childcontext(context), args...)
end
function dot_tilde_assume(::IsParent, rng, context::AbstractContext, args...)
function dot_tilde_assume(

Check warning on line 303 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L303

Added line #L303 was not covered by tests
::IsParent, rng::Random.AbstractRNG, context::AbstractContext, args...
)
return dot_tilde_assume(rng, childcontext(context), args...)
end

Expand Down Expand Up @@ -371,7 +375,7 @@ function dot_assume(
end

function dot_assume(
rng,
rng::Random.AbstractRNG,
spl::Union{SampleFromPrior,SampleFromUniform},
dist::MultivariateDistribution,
vns::AbstractVector{<:VarName},
Expand Down Expand Up @@ -404,7 +408,7 @@ function dot_assume(
end

function dot_assume(
rng,
rng::Random.AbstractRNG,
spl::Union{SampleFromPrior,SampleFromUniform},
dists::Union{Distribution,AbstractArray{<:Distribution}},
vns::AbstractArray{<:VarName},
Expand All @@ -416,7 +420,9 @@ function dot_assume(
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns)))
return r, lp, vi
end
function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any)
function dot_assume(

Check warning on line 423 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L423

Added line #L423 was not covered by tests
rng::Random.AbstractRNG, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any
)
return error(
"[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing assume statement"
)
Expand All @@ -436,7 +442,7 @@ function _maybe_invlink_broadcast(vi, vn, dist)
end

function get_and_set_val!(
rng,
rng::Random.AbstractRNG,
vi::VarInfoOrThreadSafeVarInfo,
vns::AbstractVector{<:VarName},
dist::MultivariateDistribution,
Expand Down Expand Up @@ -478,7 +484,7 @@ function get_and_set_val!(
end

function get_and_set_val!(
rng,
rng::Random.AbstractRNG,
vi::VarInfoOrThreadSafeVarInfo,
vns::AbstractArray{<:VarName},
dists::Union{Distribution,AbstractArray{<:Distribution}},
Expand Down
4 changes: 2 additions & 2 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ function assume(
end

function dot_assume(
rng,
rng::Random.AbstractRNG,
spl::Union{SampleFromPrior,SampleFromUniform},
dists::Union{Distribution,AbstractArray{<:Distribution}},
vns::AbstractArray{<:VarName},
Expand All @@ -529,7 +529,7 @@ function dot_assume(
end

function dot_assume(
rng,
rng::Random.AbstractRNG,
spl::Union{SampleFromPrior,SampleFromUniform},
dist::MultivariateDistribution,
vns::AbstractVector{<:VarName},
Expand Down

0 comments on commit 1bf5497

Please sign in to comment.