diff --git a/ext/InferOptFrankWolfeExt.jl b/ext/InferOptFrankWolfeExt.jl index cdb8cff..cc9b06b 100644 --- a/ext/InferOptFrankWolfeExt.jl +++ b/ext/InferOptFrankWolfeExt.jl @@ -1,12 +1,11 @@ module InferOptFrankWolfeExt using DifferentiableFrankWolfe: DiffFW, LinearMaximizationOracleWithKwargs -using InferOpt: InferOpt, RegularizedGeneric, FixedAtomsProbabilityDistribution +using InferOpt: + InferOpt, Regularized, FixedAtomsProbabilityDistribution, FrankWolfeOptimizer using InferOpt: compute_expectation, compute_probability_distribution using LinearAlgebra: dot -## Forward pass - function InferOpt.compute_probability_distribution( dfw::DiffFW, θ::AbstractArray; frank_wolfe_kwargs=NamedTuple() ) @@ -23,17 +22,23 @@ Construct a `DifferentiableFrankWolfe.DiffFW` struct and call `compute_probabili Keyword arguments are passed to the underlying linear maximizer. """ function InferOpt.compute_probability_distribution( - regularized::RegularizedGeneric, θ::AbstractArray; kwargs... + optimizer::FrankWolfeOptimizer, θ::AbstractArray; kwargs... ) - (; maximizer, Ω, Ω_grad, frank_wolfe_kwargs) = regularized + (; linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs) = optimizer f(y, θ) = Ω(y) - dot(θ, y) f_grad1(y, θ) = Ω_grad(y) - θ - lmo = LinearMaximizationOracleWithKwargs(maximizer, kwargs) + lmo = LinearMaximizationOracleWithKwargs(linear_maximizer, kwargs) dfw = DiffFW(f, f_grad1, lmo) probadist = compute_probability_distribution(dfw, θ; frank_wolfe_kwargs) return probadist end +function InferOpt.compute_probability_distribution( + regularized::Regularized{<:FrankWolfeOptimizer}, θ::AbstractArray; kwargs... +) + return compute_probability_distribution(regularized.optimizer, θ; kwargs...) +end + """ (regularized::RegularizedGeneric)(θ; kwargs...) @@ -41,8 +46,8 @@ Apply `compute_probability_distribution(regularized, θ)` and return the expecta Keyword arguments are passed to the underlying linear maximizer. """ -function (regularized::RegularizedGeneric)(θ::AbstractArray; kwargs...) - probadist = compute_probability_distribution(regularized, θ; kwargs...) +function (optimizer::FrankWolfeOptimizer)(θ::AbstractArray; kwargs...) + probadist = compute_probability_distribution(optimizer, θ; kwargs...) return compute_expectation(probadist) end diff --git a/src/InferOpt.jl b/src/InferOpt.jl index 752236b..9072974 100644 --- a/src/InferOpt.jl +++ b/src/InferOpt.jl @@ -17,11 +17,11 @@ include("plus_identity/plus_identity.jl") include("interpolation/interpolation.jl") -include("regularized/isregularized.jl") include("regularized/regularized_utils.jl") include("regularized/soft_argmax.jl") include("regularized/sparse_argmax.jl") -include("regularized/regularized_generic.jl") +include("regularized/regularized.jl") +include("regularized/frank_wolfe_optimizer.jl") include("perturbed/abstract_perturbed.jl") include("perturbed/additive.jl") @@ -54,9 +54,8 @@ export Interpolation export half_square_norm export shannon_entropy, negative_shannon_entropy export one_hot_argmax, ranking -export IsRegularized export soft_argmax, sparse_argmax -export RegularizedGeneric +export Regularized export PerturbedAdditive export PerturbedMultiplicative @@ -71,4 +70,8 @@ export StructuredSVMLoss export ImitationLoss, get_y_true +export SparseArgmax, SoftArgmax + +export RegularizedFrankWolfe + end diff --git a/src/fenchel_young/fenchel_young.jl b/src/fenchel_young/fenchel_young.jl index 6b4a8a4..2d29a1c 100644 --- a/src/fenchel_young/fenchel_young.jl +++ b/src/fenchel_young/fenchel_young.jl @@ -24,9 +24,9 @@ function (fyl::FenchelYoungLoss)(θ::AbstractArray, y_true::AbstractArray; kwarg return l end -@traitfn function fenchel_young_loss_and_grad( +function fenchel_young_loss_and_grad( fyl::FenchelYoungLoss{P}, θ::AbstractArray, y_true::AbstractArray; kwargs... -) where {P; IsRegularized{P}} +) where {P<:Regularized} (; predictor) = fyl ŷ = predictor(θ; kwargs...) Ωy_true = compute_regularization(predictor, y_true) diff --git a/src/regularized/regularized_generic.jl b/src/regularized/frank_wolfe_optimizer.jl similarity index 57% rename from src/regularized/regularized_generic.jl rename to src/regularized/frank_wolfe_optimizer.jl index 9d41a7b..9229912 100644 --- a/src/regularized/regularized_generic.jl +++ b/src/regularized/frank_wolfe_optimizer.jl @@ -1,5 +1,5 @@ """ - RegularizedGeneric{M,RF,RG} + FrankWolfeOptimizer{M,RF,RG,FWK} Differentiable regularized prediction function `ŷ(θ) = argmax_{y ∈ C} {θᵀy - Ω(y)}`. @@ -9,7 +9,7 @@ Relies on the Frank-Wolfe algorithm to minimize a concave objective on a polytop Since this is a conditional dependency, you need to run `import DifferentiableFrankWolfe` before using `RegularizedGeneric`. # Fields -- `maximizer::M`: linear maximization oracle `θ -> argmax_{x ∈ C} θᵀx`, implicitly defines the polytope `C` +- `linear_maximizer::M`: linear maximization oracle `θ -> argmax_{x ∈ C} θᵀx`, implicitly defines the polytope `C` - `Ω::RF`: regularization function `Ω(y)` - `Ω_grad::RG`: gradient of the regularization function `∇Ω(y)` - `frank_wolfe_kwargs::FWK`: keyword arguments passed to the Frank-Wolfe algorithm @@ -32,30 +32,14 @@ Some values you can tune: See the documentation of FrankWolfe.jl for details. """ -struct RegularizedGeneric{M,RF,RG,FWK} - maximizer::M +struct FrankWolfeOptimizer{M,RF,RG,FWK} + linear_maximizer::M Ω::RF Ω_grad::RG frank_wolfe_kwargs::FWK end -function Base.show(io::IO, regularized::RegularizedGeneric) - (; maximizer, Ω, Ω_grad) = regularized - return print(io, "RegularizedGeneric($maximizer, $Ω, $Ω_grad)") -end - -@traitimpl IsRegularized{RegularizedGeneric} - -function compute_regularization(regularized::RegularizedGeneric, y::AbstractArray) - return regularized.Ω(y) -end - -""" - (regularized::RegularizedGeneric)(θ; kwargs...) - -Apply `compute_probability_distribution(regularized, θ, kwargs...)` and return the expectation. -""" -function (regularized::RegularizedGeneric)(θ::AbstractArray; kwargs...) - probadist = compute_probability_distribution(regularized, θ; kwargs...) - return compute_expectation(probadist) +function Base.show(io::IO, optimizer::FrankWolfeOptimizer) + (; linear_maximizer, Ω, Ω_grad) = optimizer + return print(io, "RegularizedGeneric($linear_maximizer, $Ω, $Ω_grad)") end diff --git a/src/regularized/isregularized.jl b/src/regularized/isregularized.jl deleted file mode 100644 index 6f16184..0000000 --- a/src/regularized/isregularized.jl +++ /dev/null @@ -1,23 +0,0 @@ -""" - IsRegularized{P} - -Trait-based interface for regularized prediction functions `ŷ(θ) = argmax {θᵀy - Ω(y)}`. - -For `predictor::P` to comply with this interface, the following methods must exist: -- `(predictor)(θ)` -- `compute_regularization(predictor, y)` - -# Available implementations -- [`one_hot_argmax`](@ref) -- [`soft_argmax`](@ref) -- [`sparse_argmax`](@ref) -- [`RegularizedGeneric`](@ref) -""" -@traitdef IsRegularized{P} - -""" - compute_regularization(predictor::P, y) - -Compute the convex regularization function `Ω(y)`. -""" -function compute_regularization end diff --git a/src/regularized/regularized.jl b/src/regularized/regularized.jl new file mode 100644 index 0000000..5e2ea92 --- /dev/null +++ b/src/regularized/regularized.jl @@ -0,0 +1,46 @@ +""" +optimizer: θ ⟼ argmax θᵀy - Ω(y) +""" +struct Regularized{O,R} + Ω::R + optimizer::O +end + +function Base.show(io::IO, regularized::Regularized) + (; optimizer, Ω) = regularized + return print(io, "Regularized($optimizer, $Ω)") +end + +function (regularized::Regularized)(θ::AbstractArray; kwargs...) + return regularized.optimizer(θ; kwargs...) +end + +function compute_regularization(regularized::Regularized, y::AbstractArray) + return regularized.Ω(y) +end + +# Specific constructors + +""" +TODO +""" +function SparseArgmax() + return Regularized(sparse_argmax_regularization, sparse_argmax) +end + +""" +TODO +""" +function SoftArgmax() + return Regularized(soft_argmax_regularization, soft_argmax) +end + +""" +TODO +""" +function RegularizedFrankWolfe(linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs=NamedTuple()) + # TODO : add a warning if DifferentiableFrankWolfe is not imported ? + return Regularized( + Ω, FrankWolfeOptimizer(linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs) + ) +end diff --git a/src/regularized/soft_argmax.jl b/src/regularized/soft_argmax.jl index 21864e0..d235b8f 100644 --- a/src/regularized/soft_argmax.jl +++ b/src/regularized/soft_argmax.jl @@ -10,8 +10,12 @@ function soft_argmax(z::AbstractVector; kwargs...) return s end -@traitimpl IsRegularized{typeof(soft_argmax)} +# @traitimpl IsRegularized{typeof(soft_argmax)} -function compute_regularization(::typeof(soft_argmax), y::AbstractVector{R}) where {R<:Real} +# function compute_regularization(::typeof(soft_argmax), y::AbstractVector{R}) where {R<:Real} +# return isprobadist(y) ? negative_shannon_entropy(y) : typemax(R) +# end + +function soft_argmax_regularization(y::AbstractVector) return isprobadist(y) ? negative_shannon_entropy(y) : typemax(R) end diff --git a/src/regularized/sparse_argmax.jl b/src/regularized/sparse_argmax.jl index 56f5dad..260ecae 100644 --- a/src/regularized/sparse_argmax.jl +++ b/src/regularized/sparse_argmax.jl @@ -10,11 +10,15 @@ function sparse_argmax(z::AbstractVector; kwargs...) return p end -@traitimpl IsRegularized{typeof(sparse_argmax)} +# @traitimpl IsRegularized{typeof(sparse_argmax)} -function compute_regularization( - ::typeof(sparse_argmax), y::AbstractVector{R} -) where {R<:Real} +# function compute_regularization( +# ::typeof(sparse_argmax), y::AbstractVector{R} +# ) where {R<:Real} +# return isprobadist(y) ? half_square_norm(y) : typemax(R) +# end + +function sparse_argmax_regularization(y::AbstractVector) return isprobadist(y) ? half_square_norm(y) : typemax(R) end diff --git a/test/argmax.jl b/test/argmax.jl index 977fbec..71a610d 100644 --- a/test/argmax.jl +++ b/test/argmax.jl @@ -52,7 +52,7 @@ end PipelineLossImitation; instance_dim=5, true_maximizer=one_hot_argmax, - maximizer=sparse_argmax, + maximizer=SparseArgmax(), loss=mse, error_function=hamming_distance, ) @@ -67,7 +67,7 @@ end PipelineLossImitation; instance_dim=5, true_maximizer=one_hot_argmax, - maximizer=soft_argmax, + maximizer=SoftArgmax(), loss=mse, error_function=hamming_distance, ) @@ -112,7 +112,7 @@ end PipelineLossImitation; instance_dim=5, true_maximizer=one_hot_argmax, - maximizer=RegularizedGeneric( + maximizer=RegularizedFrankWolfe( one_hot_argmax, half_square_norm, identity, @@ -133,7 +133,7 @@ end instance_dim=5, true_maximizer=one_hot_argmax, maximizer=identity, - loss=FenchelYoungLoss(sparse_argmax), + loss=FenchelYoungLoss(SparseArgmax()), error_function=hamming_distance, ) end @@ -148,7 +148,7 @@ end instance_dim=5, true_maximizer=one_hot_argmax, maximizer=identity, - loss=FenchelYoungLoss(soft_argmax), + loss=FenchelYoungLoss(SoftArgmax()), error_function=hamming_distance, ) end @@ -194,7 +194,7 @@ end true_maximizer=one_hot_argmax, maximizer=identity, loss=FenchelYoungLoss( - RegularizedGeneric( + RegularizedFrankWolfe( one_hot_argmax, half_square_norm, identity, @@ -259,7 +259,7 @@ end true_maximizer=one_hot_argmax, maximizer=identity, loss=Pushforward( - RegularizedGeneric( + RegularizedFrankWolfe( one_hot_argmax, half_square_norm, identity, diff --git a/test/imitation_loss.jl b/test/imitation_loss.jl index ed60aa4..ec0792b 100644 --- a/test/imitation_loss.jl +++ b/test/imitation_loss.jl @@ -63,7 +63,7 @@ end instance_dim=5, true_maximizer=one_hot_argmax, maximizer=identity, - loss=FenchelYoungLoss(sparse_argmax), + loss=FenchelYoungLoss(SparseArgmax()), error_function=hamming_distance, true_encoder, verbose=false, @@ -98,7 +98,7 @@ end instance_dim=5, true_maximizer=one_hot_argmax, maximizer=identity, - loss=FenchelYoungLoss(soft_argmax), + loss=FenchelYoungLoss(SoftArgmax()), error_function=hamming_distance, true_encoder, verbose=false, diff --git a/test/paths.jl b/test/paths.jl index 936d09e..52a06f1 100644 --- a/test/paths.jl +++ b/test/paths.jl @@ -88,7 +88,7 @@ end ) end -@testitem "Paths - imit - MSE RegularizedGeneric" default_imports = false begin +@testitem "Paths - imit - MSE RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -97,7 +97,7 @@ end PipelineLossImitation; instance_dim=(5, 5), true_maximizer=shortest_path_maximizer, - maximizer=RegularizedGeneric( + maximizer=RegularizedFrankWolfe( shortest_path_maximizer, half_square_norm, identity, @@ -143,7 +143,7 @@ end ) end -@testitem "Paths - imit - FYL RegularizedGeneric" default_imports = false begin +@testitem "Paths - imit - FYL RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -154,7 +154,7 @@ end true_maximizer=shortest_path_maximizer, maximizer=identity, loss=FenchelYoungLoss( - RegularizedGeneric( + RegularizedFrankWolfe( shortest_path_maximizer, half_square_norm, identity, @@ -210,7 +210,7 @@ end ) end -@testitem "Paths - exp - Pushforward RegularizedGeneric" default_imports = false begin +@testitem "Paths - exp - Pushforward RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, LinearAlgebra, Random @@ -224,7 +224,7 @@ end true_maximizer=shortest_path_maximizer, maximizer=identity, loss=Pushforward( - RegularizedGeneric( + RegularizedFrankWolfe( shortest_path_maximizer, half_square_norm, identity, diff --git a/test/ranking.jl b/test/ranking.jl index b3f32c0..9640bb2 100644 --- a/test/ranking.jl +++ b/test/ranking.jl @@ -88,7 +88,7 @@ end ) end -@testitem "Ranking - imit - MSE RegularizedGeneric" default_imports = false begin +@testitem "Ranking - imit - MSE RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -97,7 +97,7 @@ end PipelineLossImitation; instance_dim=5, true_maximizer=ranking, - maximizer=RegularizedGeneric( + maximizer=RegularizedFrankWolfe( ranking, half_square_norm, identity, @@ -139,7 +139,7 @@ end ) end -@testitem "Ranking - imit - FYL RegularizedGeneric" default_imports = false begin +@testitem "Ranking - imit - FYL RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -150,7 +150,7 @@ end true_maximizer=ranking, maximizer=identity, loss=FenchelYoungLoss( - RegularizedGeneric( + RegularizedFrankWolfe( ranking, half_square_norm, identity, @@ -202,7 +202,7 @@ end ) end -@testitem "Ranking - exp - Pushforward RegularizedGeneric" default_imports = false begin +@testitem "Ranking - exp - Pushforward RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, LinearAlgebra, Random @@ -216,7 +216,7 @@ end true_maximizer=ranking, maximizer=identity, loss=Pushforward( - RegularizedGeneric( + RegularizedFrankWolfe( ranking, half_square_norm, identity,