Skip to content

Commit

Permalink
broadcast: align ndims implementation with intent behind code (#56999)
Browse files Browse the repository at this point in the history
The `N<:Integer` constraint was nonsensical, given that `(N === Any) ||
(N isa Int)`. N5N3 noticed this back in 2022:
#44061 (comment)

Follow up on #44061. Also xref #45477.
  • Loading branch information
nsajko authored Jan 10, 2025
1 parent 6ac351a commit d3964b6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
9 changes: 7 additions & 2 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,14 @@ Base.@propagate_inbounds function Base.iterate(bc::Broadcasted, s)
end

Base.IteratorSize(::Type{T}) where {T<:Broadcasted} = Base.HasShape{ndims(T)}()
Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims(fieldtype(BC, :args))
Base.ndims(::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N<:Integer} = N
Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims_broadcasted(BC)
# the `AbstractArrayStyle` type parameter is required to be either equal to `Any` or be an `Int` value
Base.ndims(BC::Type{<:Broadcasted{<:AbstractArrayStyle{Any},Nothing}}) = _maxndims_broadcasted(BC)
Base.ndims(::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N} = N::Int

function _maxndims_broadcasted(BC::Type{<:Broadcasted})
_maxndims(fieldtype(BC, :args))
end
_maxndims(::Type{T}) where {T<:Tuple} = reduce(max, ntuple(n -> (F = fieldtype(T, n); F <: Tuple ? 1 : ndims(F)), Base._counttuple(T)))
_maxndims(::Type{<:Tuple{T}}) where {T} = T <: Tuple ? 1 : ndims(T)
function _maxndims(::Type{<:Tuple{T, S}}) where {T, S}
Expand Down
2 changes: 2 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,8 @@ let

@test @inferred(Base.IteratorSize(Broadcast.broadcasted(+, (1,2,3), a1, zeros(3,3,3)))) === Base.HasShape{3}()

@test @inferred(Base.IteratorSize(Base.broadcasted(randn))) === Base.HasShape{0}()

# inference on nested
bc = Base.broadcasted(+, AD1(randn(3)), AD1(randn(3)))
bc_nest = Base.broadcasted(+, bc , bc)
Expand Down

0 comments on commit d3964b6

Please sign in to comment.