From e5189a8ca47ef9c410a76934e24a364d2ba44adc Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Sun, 20 Nov 2022 08:39:45 -0500 Subject: [PATCH] handle many args better --- Project.toml | 2 +- src/broadcast.jl | 19 ++++------- src/condense_loopset.jl | 68 ++++++++++++++++++++++---------------- src/reconstruct_loopset.jl | 65 +++++++++++++++++++++++++++++++++--- 4 files changed, 109 insertions(+), 45 deletions(-) diff --git a/Project.toml b/Project.toml index f4c6ae341..d25081764 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LoopVectorization" uuid = "bdcacae8-1622-11e9-2a5c-532679323890" authors = ["Chris Elrod "] -version = "0.12.140" +version = "0.12.141" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/broadcast.jl b/src/broadcast.jl index a550c5e1b..2602b5351 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -44,10 +44,9 @@ end @inline ArrayInterface.device(::LowDimArray) = ArrayInterface.CPUPointer() @generated function ArrayInterface.size(A::LowDimArray{D,T,N}) where {D,T,N} t = Expr(:tuple) - gf = GlobalRef(Core, :getfield) for n ∈ 1:N if n > length(D) || D[n] - push!(t.args, Expr(:call, gf, :s, n, false)) + push!(t.args, Expr(:call, getfield, :s, n)) else push!(t.args, Expr(:call, Expr(:curly, lv(:StaticInt), 1))) end @@ -64,10 +63,9 @@ ArrayInterface.offsets(A::LowDimArray) = ArrayInterface.offsets(parent(A)) @generated function _lowdimfilter(::Val{D}, tup::Tuple{Vararg{Any,N}}) where {D,N} t = Expr(:tuple) - gf = GlobalRef(Core, :getfield) for n ∈ 1:N if n > length(D) || D[n] - push!(t.args, Expr(:call, gf, :tup, n, false)) + push!(t.args, Expr(:call, getfield, :tup, n)) end end Expr(:block, Expr(:meta, :inline), t) @@ -178,7 +176,6 @@ function _strides_expr(@nospecialize(s), @nospecialize(x), R::Vector{Int}, D::Ve N = length(R) q = Expr(:block, Expr(:meta, :inline)) strd_tup = Expr(:tuple) - gf = GlobalRef(Core, :getfield) ifel = GlobalRef(Core, :ifelse) Nrange = 1:1:N # type stability w/ respect 
to reverse use_stride_acc = true @@ -207,7 +204,7 @@ function _strides_expr(@nospecialize(s), @nospecialize(x), R::Vector{Int}, D::Ve elseif stride_acc ≠ 0 push!(strd_tup.args, staticexpr(stride_acc)) else - push!(strd_tup.args, :($gf(x, $n, false))) + push!(strd_tup.args, :($getfield(x, $n))) end else if xₙ_static @@ -217,7 +214,7 @@ function _strides_expr(@nospecialize(s), @nospecialize(x), R::Vector{Int}, D::Ve else push!( strd_tup.args, - :($ifel(isone($gf(s, $n, false)), zero($xₙ_type), $gf(x, $n, false))), + :($ifel(isone($getfield(s, $n)), zero($xₙ_type), $getfield(x, $n))), ) end end @@ -326,10 +323,9 @@ function add_broadcast!( Klen = gensym!(ls, "K") mA = gensym!(ls, "Aₘₖ") mB = gensym!(ls, "Bₖₙ") - gf = GlobalRef(Core, :getfield) pushprepreamble!(ls, Expr(:(=), mA, Expr(:(.), bcname, QuoteNode(:a)))) pushprepreamble!(ls, Expr(:(=), mB, Expr(:(.), bcname, QuoteNode(:b)))) - pushprepreamble!(ls, Expr(:(=), Klen, Expr(:call, gf, Expr(:call, :size, mB), 1, false))) + pushprepreamble!(ls, Expr(:(=), Klen, Expr(:call, getfield, Expr(:call, :size, mB), 1))) pushpreamble!(ls, Expr(:(=), Krange, Expr(:call, :(:), staticexpr(1), Klen))) k = gensym!(ls, "k") add_loop!(ls, Loop(k, 1, Klen, 1, Krange, Klen), k) @@ -481,10 +477,9 @@ function add_broadcast!( parents = Operation[] deps = Symbol[] # reduceddeps = Symbol[] - gf = GlobalRef(Core, :getfield) for (i, arg) ∈ enumerate(args) argname = gensym!(ls, "arg") - pushprepreamble!(ls, Expr(:(=), argname, Expr(:call, gf, bcargs, i, false))) + pushprepreamble!(ls, Expr(:(=), argname, Expr(:call, getfield, bcargs, i))) # dynamic dispatch parent = add_broadcast!( ls, @@ -539,7 +534,7 @@ end ::Val{UNROLL}, ::Val{dontbc}, ) where {T<:NativeTypes,N,BC<:Union{Broadcasted,Product},Mod,UNROLL,dontbc} - # 2 + 1 + 2 + 1 # we have an N dimensional loop. 
# need to construct the LoopSet ls = LoopSet(Mod) diff --git a/src/condense_loopset.jl b/src/condense_loopset.jl index 36f6aa23a..6ebabe919 100644 --- a/src/condense_loopset.jl +++ b/src/condense_loopset.jl @@ -4,11 +4,10 @@ Base.:|(u::Unsigned, it::IndexType) = u | UInt8(it) Base.:(==)(u::Unsigned, it::IndexType) = (u % UInt8) == UInt8(it) function _append_fields!(t::Expr, body::Expr, sym::Symbol, ::Type{T}) where {T} - gf = GlobalRef(Core, :getfield) for f ∈ 1:fieldcount(T) TF = fieldtype(T, f) Base.issingletontype(TF) && continue - gfcall = Expr(:call, gf, sym, f) + gfcall = Expr(:call, getfield, sym, f) if fieldcount(TF) ≡ 0 push!(t.args, gfcall) elseif TF <: DataType @@ -37,16 +36,15 @@ end body end function rebuild_fields(offset::Int, ::Type{T}) where {T} - gf = GlobalRef(Core, :getfield) call = (T <: Tuple) ? Expr(:tuple) : Expr(:new, T) for f ∈ 1:fieldcount(T) TF = fieldtype(T, f) if Base.issingletontype(TF) push!(call.args, TF.instance) elseif fieldcount(TF) ≡ 0 - push!(call.args, Expr(:call, gf, :t, (offset += 1), false)) + push!(call.args, Expr(:call, getfield, :t, (offset += 1))) elseif TF <: DataType - push!(call.args, Expr(:call, lv(:gettype), Expr(:call, gf, :t, (offset += 1), false))) + push!(call.args, Expr(:call, lv(:gettype), Expr(:call, getfield, :t, (offset += 1)))) else arg, offset = rebuild_fields(offset, TF) push!(call.args, arg) @@ -58,9 +56,9 @@ end if Base.issingletontype(T) return T.instance elseif fieldcount(T) ≡ 0 - call = Expr(:call, GlobalRef(Core, :getfield), :t, 1, false) + call = Expr(:call, getfield, :t, 1) elseif T <: DataType - call = Expr(:call, lv(:gettype), Expr(:call, GlobalRef(Core, :getfield), :t, 1, false)) + call = Expr(:call, lv(:gettype), Expr(:call, getfield, :t, 1)) else call, _ = rebuild_fields(0, T) end @@ -377,10 +375,10 @@ val(x) = Expr(:call, Expr(:curly, :Val, x)) quote $(Expr(:meta, :inline)) p, li = - VectorizationBase.tdot(x, (vsub_nsw(getfield(i, 1, false), one($I)),), strides(x)) + 
VectorizationBase.tdot(x, (vsub_nsw(getfield(i, 1), one($I)),), strides(x)) ptr = gep(p, li) si = ArrayInterface.StrideIndex{1,$(R[ri],),$(C === 1 ? 1 : 0)}( - (getfield(strides(x), $ri, false),), + (getfield(strides(x), $ri),), (Zero(),), ) stridedpointer(ptr, si, StaticInt{$(B === 1 ? 1 : 0)}()) @@ -394,8 +392,8 @@ end quote $(Expr(:meta, :inline)) si = ArrayInterface.StrideIndex{1,$(R[ri],),$(C === 1 ? 1 : 0)}( - (getfield(strides(x), $ri, false),), - (getfield(offsets(x), $ri, false),), + (getfield(strides(x), $ri),), + (getfield(offsets(x), $ri),), ) stridedpointer(pointer(x), si, StaticInt{$(B == 1 ? 1 : 0)}()) end @@ -550,7 +548,7 @@ function add_grouped_strided_pointer!(extra_args::Expr, ls::LoopSet) push!(gsp.args, val(matcheddims)) gsps = gensym!(ls, "#grouped#strided#pointer#") push!(extra_args.args, gsps) - pushpreamble!(ls, Expr(:(=), gsps, Expr(:call, GlobalRef(Core, :getfield), gsp, 1))) + pushpreamble!(ls, Expr(:(=), gsps, Expr(:call, getfield, gsp, 1))) preserve, shouldindbyind, roots end @@ -802,21 +800,10 @@ function generate_call_types( argmeta = argmeta_and_consts_description(ls, arraysymbolinds) loop_bounds = loop_boundaries(ls, shouldindbyind) loop_syms = tuple_expr(QuoteNode, ls.loopsymbols) - func = debug ? lv(:_turbo_loopset_debug) : lv(:_turbo_!) lbarg = debug ? Expr(:call, :typeof, loop_bounds) : loop_bounds configarg = (inline, u₁, u₂, v, ls.isbroadcast, thread, warncheckarg, safe) unroll_param_tup = Expr(:call, lv(:avx_config_val), :(Val{$configarg}()), VECTORWIDTHSYMBOL) - q = Expr( - :call, - func, - unroll_param_tup, - val(operation_descriptions), - val(arrayref_descriptions), - val(argmeta), - val(loop_syms), - ) - add_reassigned_syms!(extra_args, ls) # counterpart to `add_ops!` constants for (opid, sym) ∈ ls.preamble_symsym # counterpart to process_metadata! 
symsym extraction if instruction(ops[opid]) ≠ DROPPEDCONSTANT @@ -826,17 +813,42 @@ function generate_call_types( append!(extra_args.args, arraysymbolinds) # add_array_symbols! add_external_functions!(extra_args, ls) # extract_external_functions! add_outerreduct_types!(extra_args, ls) # extract_outerreduct_types! - if debug - vecwidthdefq = Expr(:block) + argcestimate = length(extra_args.args) - 1 + for ref = ls.refs_aliasing_syms + argcestimate += length(ref.loopedindex) + end + manyarg = !debug && (argcestimate > 16) + func = debug ? lv(:_turbo_loopset_debug) : (manyarg ? lv(:_turbo_manyarg!) : lv(:_turbo_!)) + q = Expr( + :call, + func, + unroll_param_tup, + val(operation_descriptions), + val(arrayref_descriptions), + val(argmeta), + val(loop_syms), + ) + vecwidthdefq = if debug push!(q.args, Expr(:tuple, lbarg, extra_args)) + Expr(:block) else vargsym = gensym(:vargsym) - vecwidthdefq = Expr(:block, Expr(:(=), vargsym, Expr(:tuple, lbarg, extra_args))) push!( q.args, - Expr(:call, GlobalRef(Base, :Val), Expr(:call, GlobalRef(Base, :typeof), vargsym)), - Expr(:(...), Expr(:call, lv(:flatten_to_tuple), vargsym)), + Expr(:call, GlobalRef(Base, :Val), Expr(:call, GlobalRef(Base, :typeof), vargsym)) ) + if manyarg + push!( + q.args, + Expr(:call, lv(:flatten_to_tuple), vargsym), + ) + else + push!( + q.args, + Expr(:(...), Expr(:call, lv(:flatten_to_tuple), vargsym)), + ) + end + Expr(:block, Expr(:(=), vargsym, Expr(:tuple, lbarg, extra_args))) end define_eltype_vec_width!(vecwidthdefq, ls, nothing, true) push!(vecwidthdefq.args, q) diff --git a/src/reconstruct_loopset.jl b/src/reconstruct_loopset.jl index 7fa0c9354..b3eef51d6 100644 --- a/src/reconstruct_loopset.jl +++ b/src/reconstruct_loopset.jl @@ -121,7 +121,7 @@ function Loop( end -extract_loop(l) = Expr(:call, GlobalRef(Core, :getfield), Symbol("#loop#bounds#"), l, false) +extract_loop(l) = Expr(:call, getfield, Symbol("#loop#bounds#"), l) function add_loops!(ls::LoopSet, LPSYM, LB) n = max(length(LPSYM), 
length(LB))
@@ -145,7 +145,7 @@ function add_loops!(
   ssym = String(sym)
   for k = N:-1:1
     axisexpr =
-      :(getfield(getfield(getfield(var"#loop#bounds#", $i, false), :indices), $k, false))
+      :($getfield($getfield($getfield(var"#loop#bounds#", $i), :indices), $k))
     add_loop!(
       ls,
       Loop(ls, axisexpr, Symbol(ssym * '#' * string(k) * '#'), T.parameters[k])::Loop,
@@ -258,7 +258,7 @@ function ArrayReferenceMeta(
 end
 
 
-extract_varg(i) = :(getfield(var"#vargs#", $i, false))
+extract_varg(i) = :($getfield(var"#vargs#", $i))
 # _extract(::Type{StaticInt{N}}) where {N} = N
 extract_gsp!(sptrs::Expr, name::Symbol) = (push!(sptrs.args, name); nothing)
 tupleranks(R::NTuple{8,Int}) = ntuple(n -> sum(R[n] .≥ R), Val{8}())
@@ -319,7 +319,7 @@ function _add_mref!(
   extract_gsp!(sptrs, tmpsp)
   strd_tup = Expr(:tuple)
   offsets_tup = Expr(:tuple)
-  gf = GlobalRef(Core, :getfield)
+  gf = getfield
   offsets = gensym(:offsets)
   strides = gensym(:strides)
   pushpreamble!(ls, Expr(:(=), offsets, Expr(:call, lv(:offsets), tmpsp)))
@@ -1019,3 +1019,76 @@ Execute an `@turbo` block. The block's code is represented via the arguments:
   post === ls.preamble ? q : Expr(:block, q, post)
   # @show var"#UNROLL#", var"#OPS#", var"#ARF#", var"#AM#", var"#LPSYM#", var"#LB#"
 end
+"""
+    _turbo_manyarg!(
+      ::Val{UNROLL}, ::Val{OPS}, ::Val{ARF}, ::Val{AM}, ::Val{LPSYM},
+      ::Val{Tuple{LB,V}}, flattened_args::Tuple,
+    )
+
+Counterpart of `_turbo_!` that receives the flattened arguments as a single tuple
+instead of as splatted positional arguments. `generate_call_types` emits a call to
+this method when not in debug mode and its argument-count estimate exceeds 16.
+
+The loop set is built by `_turbo_loopset` from the type parameters, and an assignment
+of `reassemble_tuple(Tuple{LB,V}, flattened_args)` is placed first in its preamble.
+The returned expression comes from `avx_threads_expr` when entry 10 of `UNROLL`
+(`nt`) exceeds 1 and there are as many loop symbols as loops, and from `avx_body`
+otherwise. If `hoist_constant_memory_accesses!` returns something other than the
+preamble, it is appended after that expression.
+"""
+@generated function _turbo_manyarg!(
+  ::Val{var"#UNROLL#"},
+  ::Val{var"#OPS#"},
+  ::Val{var"#ARF#"},
+  ::Val{var"#AM#"},
+  ::Val{var"#LPSYM#"},
+  ::Val{Tuple{var"#LB#",var"#V#"}},
+  var"#flattened#var#arguments#"::Tuple{Vararg{Any,var"#num#vargs#"}},
+) where {
+  var"#UNROLL#",
+  var"#OPS#",
+  var"#ARF#",
+  var"#AM#",
+  var"#LPSYM#",
+  var"#LB#",
+  var"#V#",
+  var"#num#vargs#",
+}
+  ls = _turbo_loopset(
+    var"#OPS#",
+    var"#ARF#",
+    var"#AM#",
+    var"#LPSYM#",
+    var"#LB#".parameters,
+    var"#V#".parameters,
+    var"#UNROLL#",
+  )
+  # Prepend, so this assignment is the first entry of the preamble.
+  pushfirst!(
+    ls.preamble.args,
+    :(
+      var"#lv#tuple#args#" =
+        reassemble_tuple(Tuple{var"#LB#",var"#V#"}, var"#flattened#var#arguments#")
+    ),
+  )
+  post = hoist_constant_memory_accesses!(ls)
+  nt = var"#UNROLL#"[10]
+  q = if nt > 1 && length(var"#LPSYM#") == length(ls.loops)
+    inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, _, wca, safe = var"#UNROLL#"
+    # `var"#OPS#"`, `var"#ARF#"`, `var"#AM#"` and `var"#LPSYM#"` are wrapped in `Expr`s to
+    # homogenize types. The config tuple carries `one(UInt)` in place of `nt`, which is
+    # passed as its own argument.
+    avx_threads_expr(
+      ls,
+      (inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, one(UInt), wca, safe),
+      nt,
+      :(Val{$(var"#OPS#")}()),
+      :(Val{$(var"#ARF#")}()),
+      :(Val{$(var"#AM#")}()),
+      :(Val{$(var"#LPSYM#")}()),
+    )
+  else
+    avx_body(ls, var"#UNROLL#")
+  end
+  post === ls.preamble ? q : Expr(:block, q, post)
+end