Skip to content

Commit

Permalink
broadcast: align ndims implementation with intent behind code
Browse files Browse the repository at this point in the history
The `N<:Integer` constraint was nonsensical, given that
`(N === Any) || (N isa Int)`. N5N3 noticed this back in 2022:
#44061 (comment)

Follow up on #44061. Also xref #45477.
  • Loading branch information
nsajko committed Jan 8, 2025
1 parent 8bf2802 commit d09703e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
16 changes: 14 additions & 2 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,21 @@ Base.@propagate_inbounds function Base.iterate(bc::Broadcasted, s)
end

Base.IteratorSize(::Type{T}) where {T<:Broadcasted} = Base.HasShape{ndims(T)}()
Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims(fieldtype(BC, :args))
Base.ndims(::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N<:Integer} = N
Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims_broadcasted(BC)
function Base.ndims(BC::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N}
if Any <: N
_maxndims_broadcasted(BC)
else
let n = N::Int
NTuple{n} # throw if negative
n
end
end
end

function _maxndims_broadcasted(BC::Type{<:Broadcasted})
_maxndims(fieldtype(BC, :args))
end
_maxndims(::Type{T}) where {T<:Tuple} = reduce(max, ntuple(n -> (F = fieldtype(T, n); F <: Tuple ? 1 : ndims(F)), Base._counttuple(T)))
_maxndims(::Type{<:Tuple{T}}) where {T} = T <: Tuple ? 1 : ndims(T)
function _maxndims(::Type{<:Tuple{T, S}}) where {T, S}
Expand Down
2 changes: 2 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,8 @@ let

@test @inferred(Base.IteratorSize(Broadcast.broadcasted(+, (1,2,3), a1, zeros(3,3,3)))) === Base.HasShape{3}()

@test @inferred(Base.IteratorSize(Base.broadcasted(randn))) === Base.HasShape{0}()

# inference on nested
bc = Base.broadcasted(+, AD1(randn(3)), AD1(randn(3)))
bc_nest = Base.broadcasted(+, bc , bc)
Expand Down

0 comments on commit d09703e

Please sign in to comment.