Skip to content

Commit

Permalink
Fix condition to skip final reshape
Browse files Browse the repository at this point in the history
  • Loading branch information
KeitaNakamura committed Oct 25, 2024
1 parent 5125f84 commit 205308c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/TensorCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,13 @@ _squash_right(B::AbstractArray, ::Val{N}) where {N} = reshape(B, :,prod(size(B)[
_squash_right(B::AbstractVecOrMat, ::Val{1}) = B

function _boxdot_reshape(AB::AbstractArray, A::AbstractArray{T,N}, B::AbstractArray{S,M}, ::Val{K}) where {T,N,S,M,K}
N-K 1 && M-K 1 && N+M-2K 2 && return AB # These can skip final reshape
N == M == K+1 && return AB # These can skip final reshape
ax = ntuple(i -> iN-K ? axes(A, i) : axes(B, i-N+2K), Val(N+M-2K))
reshape(AB, ax) # some cases don't come here, so this doesn't really support OffsetArrays
end

# These can skip final reshape:
_boxdot_reshape(AB::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, ::Val) = AB
_boxdot_reshape(AB::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, ::Val{1}) = AB

# These produce scalar output:
function boxdot(A::AbstractArray{<:Any,N}, B::AbstractArray{<:Any,N}, ::Val{N}) where {N}
Expand Down

0 comments on commit 205308c

Please sign in to comment.