Skip to content

Commit

Permalink
[TTFS] clean interior-point kernels to improve inference (#228)
Browse files Browse the repository at this point in the history
- use @inbounds @simd in a systematic way before for loops
- fix small broadcast calls in kernels.jl
- rewrite broadcast operators in print_init to ease compiler job
  • Loading branch information
frapac authored Oct 7, 2022
1 parent 44cddcc commit 8260ac9
Show file tree
Hide file tree
Showing 10 changed files with 427 additions and 269 deletions.
8 changes: 5 additions & 3 deletions src/IPM/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@ function eval_grad_f_wrapper!(solver::MadNLPSolver, f::Vector{T},x::Vector{T}) w
nlp = solver.nlp
cnt = solver.cnt
@trace(solver.logger,"Evaluating objective gradient.")
obj_scaling = solver.obj_scale[] * (get_minimize(nlp) ? one(T) : -one(T))
x_nlpmodel = _madnlp_unsafe_wrap(x, get_nvar(nlp))
f_nlpmodel = _madnlp_unsafe_wrap(f, get_nvar(nlp))
cnt.eval_function_time += @elapsed grad!(
nlp,
x_nlpmodel,
f_nlpmodel
)
f.*=solver.obj_scale[] * (get_minimize(nlp) ? 1. : -1.)
_scal!(obj_scaling, f)
cnt.obj_grad_cnt+=1
cnt.obj_grad_cnt==1 && (is_valid(f) || throw(InvalidNumberException(:grad)))
return f
Expand All @@ -38,8 +39,8 @@ function eval_cons_wrapper!(solver::MadNLPSolver, c::Vector{T},x::Vector{T}) whe
c_nlpmodel
)
view(c,solver.ind_ineq).-=view(x,get_nvar(nlp)+1:solver.n)
c.-=solver.rhs
c.*=solver.con_scale
c .-= solver.rhs
c .*= solver.con_scale
cnt.con_cnt+=1
cnt.con_cnt==2 && (is_valid(c) || throw(InvalidNumberException(:cons)))
return c
Expand Down Expand Up @@ -124,3 +125,4 @@ function eval_lag_hess_wrapper!(solver::MadNLPSolver, kkt::AbstractDenseKKTSyste
cnt.lag_hess_cnt==1 && (is_valid(hess) || throw(InvalidNumberException(:hess)))
return hess
end

2 changes: 1 addition & 1 deletion src/IPM/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ function solve_refine_wrapper!(
xy = view(full(x), kkt.ind_eq_shifted)
xz = view(full(x), kkt.ind_ineq_shifted)

v_c .= 0.0
fill!(v_c, zero(T))
v_c[kkt.ind_ineq] .= (Σs .* bz .+ α .* bs) ./ α.^2
jtprod!(jv_t, kkt, v_c)
# init right-hand-side
Expand Down
Loading

0 comments on commit 8260ac9

Please sign in to comment.