From b9f4f69f1214338c54399fe4e00fa45e10b65e75 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 30 Jun 2023 23:06:09 +0200 Subject: [PATCH] The OVERHAUL (#78) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Complete refactor of the Regularized structures: - Only one struct named `Regularized`, every regularized layer is a particular case of it - Specific constructors for `SparseArgmax`, `SoftArgmax`, and `RegularizedFrankWolfe` - Now we can also use `Regularized` with a custom optimizer (we may need to test this feature) * The OVERHAUL * Fix tutorial in tests, more docs * Bump version * Fix Léo's remarks * Fix imports * typo * Don't export compress_dist --------- Co-authored-by: BatyLeo --- .gitignore | 4 +- CITATION.bib | 4 +- Project.toml | 4 +- README.md | 10 +- docs/Manifest.toml | 713 +++++++++--------- docs/Project.toml | 4 +- docs/make.jl | 7 +- ..._tutorials.md => advanced_applications.md} | 0 docs/src/algorithms.md | 107 --- docs/src/api.md | 25 + docs/src/background.md | 5 +- docs/src/losses.md | 10 + docs/src/optim.md | 10 + examples/tutorial.jl | 178 +++++ ext/InferOptFrankWolfeExt.jl | 35 +- src/InferOpt.jl | 68 +- src/fenchel_young/fenchel_young.jl | 57 -- src/fenchel_young/perturbed.jl | 47 -- src/imitation/fenchel_young_loss.jl | 114 +++ src/imitation/imitation_loss.jl | 84 +++ src/{spo => imitation}/spoplus_loss.jl | 17 +- src/imitation/ssvm_loss.jl | 64 ++ src/imitation/zero_one_loss.jl | 58 ++ src/imitation_loss/imitation_loss.jl | 79 -- src/interface.jl | 50 ++ src/perturbed/abstract_perturbed.jl | 22 +- src/perturbed/additive.jl | 10 +- src/perturbed/multiplicative.jl | 10 +- src/plus_identity/plus_identity.jl | 29 - src/regularized/abstract_regularized.jl | 27 + src/regularized/isregularized.jl | 23 - src/regularized/regularized_frank_wolfe.jl | 65 ++ src/regularized/regularized_generic.jl | 51 -- src/regularized/soft_argmax.jl | 17 +- src/regularized/sparse_argmax.jl | 15 +- src/simple/identity.jl | 29 + .../interpolation.jl | 6 +- src/ssvm/isbaseloss.jl | 20 - src/ssvm/ssvm_loss.jl | 54 -- src/ssvm/zeroone_baseloss.jl | 26 - src/utils/probability_distribution.jl | 40 - src/utils/pushforward.jl | 30 +- .../some_functions.jl} | 0 test/argmax.jl | 54 +- test/code.jl | 2 + test/imitation_loss.jl | 46 +- test/paths.jl | 40 +- test/ranking.jl | 40 +- test/tutorial.jl | 178 +---- 49 files changed, 1302 insertions(+), 1286 deletions(-) rename docs/src/{advanced_tutorials.md => advanced_applications.md} (100%) delete mode 100644 docs/src/algorithms.md create mode 100644 docs/src/api.md create mode 100644 docs/src/losses.md create mode 100644 docs/src/optim.md create mode 100644 examples/tutorial.jl delete mode 100644 src/fenchel_young/fenchel_young.jl delete mode 100644 src/fenchel_young/perturbed.jl create mode 100644 src/imitation/fenchel_young_loss.jl create mode 100644 src/imitation/imitation_loss.jl rename src/{spo => imitation}/spoplus_loss.jl (82%) create mode 100644 src/imitation/ssvm_loss.jl create mode 100644 src/imitation/zero_one_loss.jl delete mode 100644 src/imitation_loss/imitation_loss.jl create mode 100644 src/interface.jl delete mode 100644 src/plus_identity/plus_identity.jl create mode 100644 src/regularized/abstract_regularized.jl delete mode 100644 src/regularized/isregularized.jl create mode 100644 src/regularized/regularized_frank_wolfe.jl delete mode 100644 src/regularized/regularized_generic.jl create mode 100644 src/simple/identity.jl rename src/{interpolation => simple}/interpolation.jl (87%) delete mode 100644 src/ssvm/isbaseloss.jl delete mode 100644 src/ssvm/ssvm_loss.jl delete mode 100644 src/ssvm/zeroone_baseloss.jl rename src/{regularized/regularized_utils.jl => utils/some_functions.jl} (100%) diff --git a/.gitignore b/.gitignore index 0f067eb..11feb24 100644 --- a/.gitignore +++ b/.gitignore @@ -3,8 +3,10 @@ *.jl.mem /Manifest.toml /docs/build/ +/docs/src/index.md +/docs/src/tutorial.md /test/profiling.jl - +.vscode *.tar.gz *.zip *.csv diff --git a/CITATION.bib b/CITATION.bib index 223ad9e..8fc31e6 100644 --- a/CITATION.bib +++ b/CITATION.bib @@ -2,7 +2,7 @@ @misc{InferOpt.jl author = {Guillaume Dalle, Léo Baty, Louis Bouvier and Axel Parmentier}, title = {InferOpt.jl}, url = {https://github.com/axelparmentier/InferOpt.jl}, - version = {v0.4.0}, - year = {2022}, + version = {v0.5.0}, + year = {2023}, month = {7} } diff --git a/Project.toml b/Project.toml index e0e0705..a1f65f9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,14 +1,13 @@ name = "InferOpt" uuid = "4846b161-c94e-4150-8dac-c7ae193c601f" authors = ["Guillaume Dalle", "Léo Baty", "Louis Bouvier", "Axel Parmentier"] -version = "0.5.0-DEV" +version = "0.5.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DifferentiableFrankWolfe = "b383313e-5450-4164-a800-befbd27b574d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d" @@ -22,7 +21,6 @@ InferOptFrankWolfeExt = "DifferentiableFrankWolfe" [compat] ChainRulesCore = "1" DifferentiableFrankWolfe = "0.1.2" -SimpleTraits = "0.9" StatsBase = "0.33, 0.34" TestItemRunner = "0.2.2" ThreadsX = "0.1.11" diff --git a/README.md b/README.md index 8453231..cc2eb6f 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ To install the stable version, open a Julia REPL and run the following command: julia> using Pkg; Pkg.add("InferOpt") ``` -To install the development version, run this command instead: +To install the development version (*recommended for now*), run this command instead: ```julia julia> using Pkg; Pkg.add(url="https://github.com/axelparmentier/InferOpt.jl") @@ -33,3 +33,11 @@ julia> using Pkg; Pkg.add(url="https://github.com/axelparmentier/InferOpt.jl") If you use our package in your research, please cite the following paper: > [Learning with Combinatorial Optimization Layers: a Probabilistic Approach](https://arxiv.org/abs/2207.13513) - Guillaume Dalle, Léo Baty, Louis Bouvier and Axel Parmentier (2022) + +## Related packages + +The following libraries implement similar functionalities: + +- [ImplicitDifferentiation.jl](https://github.com/gdalle/ImplicitDifferentiation.jl): automatic differentiation of implicit functions +- [DiffOpt.jl](https://github.com/jump-dev/DiffOpt.jl): differentiating convex optimization programs w.r.t. program parameters +- [JAXopt](https://github.com/google/jaxopt): hardware accelerated, batchable and differentiable optimizers in JAX \ No newline at end of file diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 7ab9c0e..7f68970 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,14 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.8.2" +julia_version = "1.9.1" manifest_format = "2.0" -project_hash = "f148cd77e0cb10305714e924b2b62579b309736a" - -[[deps.AMD]] -deps = ["Libdl", "LinearAlgebra", "SparseArrays", "Test"] -git-tree-sha1 = "00163dc02b882ca5ec032400b919e5f5011dbd31" -uuid = "14f7f29c-3bd6-536c-9a0b-7339e30b5a3e" -version = "0.5.0" +project_hash = "8b69cefb3ccfc9c2c7fa95c168c6cd682958d99a" [[deps.ANSIColoredPrinters]] git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" @@ -16,22 +10,29 @@ uuid = "a4c015fc-c6ff-483c-b24f-f7ea428134e9" version = "0.0.1" [[deps.AbstractFFTs]] -deps = ["ChainRulesCore", "LinearAlgebra"] -git-tree-sha1 = "69f7020bd72f069c219b5e8c236c1fa90d2cb409" +deps = ["LinearAlgebra"] +git-tree-sha1 = "8bc0aaec0ca548eb6cf5f0d7d16351650c1ee956" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.2.1" +version = "1.3.2" +weakdeps = ["ChainRulesCore"] + + [deps.AbstractFFTs.extensions] + AbstractFFTsChainRulesCoreExt = "ChainRulesCore" -[[deps.Accessors]] -deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "Test"] -git-tree-sha1 = "eb7a1342ff77f4f9b6552605f27fd432745a53a3" -uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.22" +[[deps.AbstractTrees]] +git-tree-sha1 = "faa260e4cb5aba097a73fab382dd4b5819d8ec8c" +uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +version = "0.4.4" [[deps.Adapt]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "195c5505521008abea5aee4f96930717958eac6f" +deps = ["LinearAlgebra", "Requires"] +git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.4.0" +version = "3.6.2" +weakdeps = ["StaticArrays"] + + [deps.Adapt.extensions] + AdaptStaticArraysExt = "StaticArrays" [[deps.ArgCheck]] git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" @@ -48,44 +49,40 @@ git-tree-sha1 = "62e51b39331de8911e4a7ff6f5aaf38a5f4cc0ae" uuid = "ec485272-7323-5ecc-a04f-4719b315124d" version = "0.2.0" -[[deps.Arpack]] -deps = ["Arpack_jll", "Libdl", "LinearAlgebra", "Logging"] -git-tree-sha1 = "9b9b347613394885fd1c8c7729bfc60528faa436" -uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97" -version = "0.5.4" - -[[deps.Arpack_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "OpenBLAS_jll", "Pkg"] -git-tree-sha1 = "5ba6c757e8feccf03a1554dfaf3e26b3cfc7fd5e" -uuid = "68821587-b530-5797-8361-c406ea357684" -version = "3.5.1+1" - -[[deps.ArrayInterface]] -deps = ["ArrayInterfaceCore", "Compat", "IfElse", "LinearAlgebra", "Static"] -git-tree-sha1 = "d6173480145eb632d6571c148d94b9d3d773820e" -uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "6.0.23" - -[[deps.ArrayInterfaceCore]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "c46fb7dd1d8ca1d213ba25848a5ec4e47a1a1b08" -uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2" -version = "0.1.26" - [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" +[[deps.Atomix]] +deps = ["UnsafeAtomics"] +git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" +uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" +version = "0.1.0" + [[deps.BFloat16s]] deps = ["LinearAlgebra", "Printf", "Random", "Test"] -git-tree-sha1 = "a598ecb0d717092b5539dbbe890c98bac842b072" +git-tree-sha1 = "dbf84058d0a8cbbadee18d25cf606934b22d7c66" uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -version = "0.2.0" +version = "0.4.2" [[deps.BangBang]] -deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"] -git-tree-sha1 = "7fe6d92c4f281cf4ca6f2fba0ce7b299742da7ca" +deps = ["Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables"] +git-tree-sha1 = "e28912ce94077686443433c2800104b061a827ed" uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -version = "0.3.37" +version = "0.3.39" + + [deps.BangBang.extensions] + BangBangChainRulesCoreExt = "ChainRulesCore" + BangBangDataFramesExt = "DataFrames" + BangBangStaticArraysExt = "StaticArrays" + BangBangStructArraysExt = "StructArrays" + BangBangTypedTablesExt = "TypedTables" + + [deps.BangBang.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" @@ -95,70 +92,64 @@ git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" uuid = "9718e550-a3fa-408a-8086-8db961cd8217" version = "0.1.1" -[[deps.BenchmarkTools]] -deps = ["JSON", "Logging", "Printf", "Profile", "Statistics", "UUIDs"] -git-tree-sha1 = "d9a9701b899b30332bbcb3e1679c41cce81fb0e8" -uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -version = "1.3.2" - -[[deps.Bzip2_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "19a35467a82e236ff51bc17a3a44b69ef35185a2" -uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" -version = "1.0.8+0" - [[deps.CEnum]] git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" version = "0.4.2" [[deps.CUDA]] -deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"] -git-tree-sha1 = "49549e2c28ffb9cc77b3689dc10e46e6271e9452" +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "442d989978ed3ff4e174c928ee879dc09d1ef693" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "3.12.0" +version = "4.3.2" + +[[deps.CUDA_Driver_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] +git-tree-sha1 = "498f45593f6ddc0adff64a9310bb6710e851781b" +uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" +version = "0.5.0+1" + +[[deps.CUDA_Runtime_Discovery]] +deps = ["Libdl"] +git-tree-sha1 = "bcc4a23cbbd99c8535a5318455dcf0f2546ec536" +uuid = "1af6417a-86b4-443c-805f-a4643ffb695f" +version = "0.2.2" + +[[deps.CUDA_Runtime_jll]] +deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "5248d9c45712e51e27ba9b30eebec65658c6ce29" +uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" +version = "0.6.0+0" + +[[deps.CUDNN_jll]] +deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "2918fbffb50e3b7a0b9127617587afa76d4276e8" +uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645" +version = "8.8.1+0" [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"] -git-tree-sha1 = "0c8c8887763f42583e1206ee35413a43c91e2623" +git-tree-sha1 = "1cdf290d4feec68824bfb84f4bfc9f3aba185647" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.45.0" +version = "1.51.1" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "e7ff6cadf743c098e08fca25c91103ee4303c9bb" +git-tree-sha1 = "e30f2f4e20f7f186dc36529910beaedc60cfa644" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.15.6" - -[[deps.ChangesOfVariables]] -deps = ["ChainRulesCore", "LinearAlgebra", "Test"] -git-tree-sha1 = "38f7a08f19d8810338d4f5085211c7dfa5d5bdd8" -uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" -version = "0.1.4" +version = "1.16.0" [[deps.CodeTracking]] deps = ["InteractiveUtils", "UUIDs"] -git-tree-sha1 = "cc4bd91eba9cdbbb4df4746124c22c0832a460d6" +git-tree-sha1 = "d730914ef30a06732bdd9f763f6cc32e92ffbff1" uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" -version = "1.1.1" - -[[deps.CodecBzip2]] -deps = ["Bzip2_jll", "Libdl", "TranscodingStreams"] -git-tree-sha1 = "2e62a725210ce3c3c2e1a3080190e7ca491f18d7" -uuid = "523fee87-0ab8-5b00-afb7-3ecf72e48cfd" -version = "0.7.2" - -[[deps.CodecZlib]] -deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da" -uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.0" +version = "1.3.1" [[deps.ColorSchemes]] -deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "Random"] -git-tree-sha1 = "1fd869cc3875b57347f7027521f561cf46d1fcd8" +deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] +git-tree-sha1 = "be6ab11021cd29f0344d5c4357b163af05a48cba" uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.19.0" +version = "3.21.0" [[deps.ColorTypes]] deps = ["FixedPointNumbers", "Random"] @@ -168,15 +159,15 @@ version = "0.11.4" [[deps.ColorVectorSpace]] deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "SpecialFunctions", "Statistics", "TensorCore"] -git-tree-sha1 = "d08c20eef1f2cbc6e60fd3612ac4340b89fea322" +git-tree-sha1 = "600cc5508d66b78aae350f7accdb58763ac18589" uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" -version = "0.9.9" +version = "0.9.10" [[deps.Colors]] deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] -git-tree-sha1 = "417b0ed7b8b838aa6ca0a87aadf1bb9eb111ce40" +git-tree-sha1 = "fc08e5930ee9a4e03f84bfb5211cb54e7769758a" uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.12.8" +version = "0.12.10" [[deps.CommonSubexpressions]] deps = ["MacroTools", "Test"] @@ -185,26 +176,44 @@ uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" version = "0.3.0" [[deps.Compat]] -deps = ["Dates", "LinearAlgebra", "UUIDs"] -git-tree-sha1 = "3ca828fe1b75fa84b021a7860bd039eaea84d2f2" +deps = ["UUIDs"] +git-tree-sha1 = "4e88377ae7ebeaf29a047aa1ee40826e0b708a5d" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.3.0" +version = "4.7.0" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "0.5.2+0" +version = "1.0.2+0" [[deps.CompositionsBase]] -git-tree-sha1 = "455419f7e328a1a2493cabc6428d79e951349769" +git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" -version = "0.1.1" +version = "0.1.2" + + [deps.CompositionsBase.extensions] + CompositionsBaseInverseFunctionsExt = "InverseFunctions" + + [deps.CompositionsBase.weakdeps] + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" [[deps.ConstructionBase]] deps = ["LinearAlgebra"] -git-tree-sha1 = "fb21ddd70a051d882a1686a5a550990bbe371a95" +git-tree-sha1 = "738fec4d684a9a6ee9598a8bfee305b26831f28c" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.4.1" +version = "1.5.2" + + [deps.ConstructionBase.extensions] + ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseStaticArraysExt = "StaticArrays" + + [deps.ConstructionBase.weakdeps] + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [[deps.ContextVariablesX]] deps = ["Compat", "Logging", "UUIDs"] @@ -223,15 +232,15 @@ uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" version = "4.1.1" [[deps.DataAPI]] -git-tree-sha1 = "e08915633fcb3ea83bf9d6126292e5bc5c739922" +git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.13.0" +version = "1.15.0" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "d1fff3a548102f48987a52a2e0d114fa97d730f0" +git-tree-sha1 = "cf25ccb972fec4e4817764d01c82386ae94f77b4" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.13" +version = "0.18.14" [[deps.DataValueInterfaces]] git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" @@ -249,7 +258,9 @@ version = "0.1.2" [[deps.DelimitedFiles]] deps = ["Mmap"] +git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" +version = "1.9.1" [[deps.DiffResults]] deps = ["StaticArraysCore"] @@ -259,9 +270,9 @@ version = "1.1.0" [[deps.DiffRules]] deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "c5b6685d53f933c11404a3ae9822afe30d522494" +git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.12.2" +version = "1.15.1" [[deps.Distributed]] deps = ["Random", "Serialization", "Sockets"] @@ -269,15 +280,15 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[deps.DocStringExtensions]] deps = ["LibGit2"] -git-tree-sha1 = "c36550cb29cbe373e95b3f40486b9a4148f89ffd" +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.9.2" +version = "0.9.3" [[deps.Documenter]] deps = ["ANSIColoredPrinters", "Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] -git-tree-sha1 = "6030186b00a38e9d0434518627426570aac2ef95" +git-tree-sha1 = "58fea7c536acd71f3eef6be3b21c0df5f3df88fd" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.27.23" +version = "0.27.24" [[deps.Downloads]] deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] @@ -285,9 +296,9 @@ uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" version = "1.6.0" [[deps.ExprTools]] -git-tree-sha1 = "56559bbef6ca5ea0c0818fa5c90320398a6fbf8d" +git-tree-sha1 = "c1d06d129da9f55715c6c212866f5b1bddc5fa00" uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" -version = "0.1.8" +version = "0.1.9" [[deps.FLoops]] deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] @@ -301,25 +312,14 @@ git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" version = "0.1.1" -[[deps.FastClosures]] -git-tree-sha1 = "acebe244d53ee1b461970f8910c235b259e772ef" -uuid = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" -version = "0.3.2" - -[[deps.FileIO]] -deps = ["Pkg", "Requires", "UUIDs"] -git-tree-sha1 = "7be5f99f7d15578798f338f5433b6c432ea8037b" -uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -version = "1.16.0" - [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] -git-tree-sha1 = "802bfc139833d2ba893dd9e62ba1767c88d708ae" +git-tree-sha1 = "7072f1e3e5a8be51d525d64f63d3ec1287ff2790" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.13.5" +version = "0.13.11" [[deps.FixedPointNumbers]] deps = ["Statistics"] @@ -328,51 +328,34 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" version = "0.8.4" [[deps.Flux]] -deps = ["Adapt", "ArrayInterface", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Test", "Zygote"] -git-tree-sha1 = "66b62bf72c4b5d4904441ed0677eab53266033c7" +deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote", "cuDNN"] +git-tree-sha1 = "3e2c3704c2173ab4b1935362384ca878b53d4c34" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.13.7" +version = "0.13.17" -[[deps.FoldsThreads]] -deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"] -git-tree-sha1 = "eb8e1989b9028f7e0985b4268dabe94682249025" -uuid = "9c68100b-dfe1-47cf-94c8-95104e173443" -version = "0.1.1" + [deps.Flux.extensions] + AMDGPUExt = "AMDGPU" + FluxMetalExt = "Metal" + + [deps.Flux.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" [[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] -git-tree-sha1 = "10fa12fe96e4d76acfa738f4df2126589a67374f" +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] +git-tree-sha1 = "00e252f4d706b3d55a8863432e742bf5717b498d" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.33" - -[[deps.FrankWolfe]] -deps = ["Arpack", "GenericSchur", "Hungarian", "LinearAlgebra", "MathOptInterface", "Printf", "ProgressMeter", "Random", "Setfield", "SparseArrays", "TimerOutputs"] -git-tree-sha1 = "3a9895914d003f1ffba6ef933eb7a1a46e677b84" -uuid = "f55ce6ea-fdc5-4628-88c5-0087fe54bd30" -version = "0.2.15" - -[[deps.FreeType]] -deps = ["CEnum", "FreeType2_jll"] -git-tree-sha1 = "cabd77ab6a6fdff49bfd24af2ebe76e6e018a2b4" -uuid = "b38be410-82b0-50bf-ab77-7b57e271db43" -version = "4.0.0" - -[[deps.FreeType2_jll]] -deps = ["Artifacts", "Bzip2_jll", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] -git-tree-sha1 = "87eb71354d8ec1a96d4a7636bd57a7347dde3ef9" -uuid = "d7e528f0-a631-5988-bf34-fe36492bcfd7" -version = "2.10.4+0" - -[[deps.FunctionWrappers]] -git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e" -uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" -version = "1.1.3" +version = "0.10.35" +weakdeps = ["StaticArrays"] + + [deps.ForwardDiff.extensions] + ForwardDiffStaticArraysExt = "StaticArrays" [[deps.Functors]] deps = ["LinearAlgebra"] -git-tree-sha1 = "a2657dd0f3e8a61dbe70fc7c122038bd33790af5" +git-tree-sha1 = "478f8c3145bb91d82c2cf20433e8c1b30df454cc" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.3.0" +version = "0.4.4" [[deps.Future]] deps = ["Random"] @@ -380,68 +363,57 @@ uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[deps.GPUArrays]] deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "45d7deaf05cbb44116ba785d147c518ab46352d7" +git-tree-sha1 = "2e57b4a4f9cc15e85a24d603256fe08e527f48d1" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "8.5.0" +version = "8.8.1" [[deps.GPUArraysCore]] deps = ["Adapt"] -git-tree-sha1 = "6872f5ec8fd1a38880f027a26739d42dcda6691f" +git-tree-sha1 = "2d6ca471a6c7b536127afccfa7564b5b39227fe0" uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.2" +version = "0.1.5" [[deps.GPUCompiler]] -deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "76f70a337a153c1632104af19d29023dbb6f30dd" +deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"] +git-tree-sha1 = "cb090aea21c6ca78d59672a7e7d13bd56d09de64" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.16.6" - -[[deps.GenericSchur]] -deps = ["LinearAlgebra", "Printf"] -git-tree-sha1 = "fb69b2a645fa69ba5f474af09221b9308b160ce6" -uuid = "c145ed77-6b09-5dd9-b285-bf645a82121e" -version = "0.5.3" +version = "0.20.3" [[deps.Graphs]] deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] -git-tree-sha1 = "ba2d094a88b6b287bd25cfa86f301e7693ffae2f" +git-tree-sha1 = "1cf1d7dcb4bc32d7b4a5add4232db3750c27ecb4" uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" -version = "1.7.4" +version = "1.8.0" [[deps.GridGraphs]] deps = ["DataStructures", "FillArrays", "Graphs", "SparseArrays"] -git-tree-sha1 = "2155496389b0069dd578bb06ed64032b34f8d0b0" +git-tree-sha1 = "858b2a7a7798e649dc5612792969541e3a88379a" uuid = "dd2b58c7-5af7-4f17-9e46-57c68ac813fb" -version = "0.9.0" - -[[deps.Hungarian]] -deps = ["LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "371a7df7a6cce5909d6c576f234a2da2e3fa0c98" -uuid = "e91730f6-4275-51fb-a7a0-7064cfbd3b39" -version = "0.6.0" +version = "0.9.1" [[deps.IOCapture]] deps = ["Logging", "Random"] -git-tree-sha1 = "f7be53659ab06ddc986428d3a9dcc95f6fa6705a" +git-tree-sha1 = "d75853a0bdbfb1ac815478bacd89cd27b550ace6" uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" -version = "0.2.2" +version = "0.2.3" [[deps.IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "2e99184fca5eb6f075944b04c22edec29beb4778" +git-tree-sha1 = "eac00994ce3229a464c2847e956d77a2c64ad3a5" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.7" - -[[deps.IfElse]] -git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" -uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" -version = "0.1.1" +version = "0.4.10" [[deps.InferOpt]] -deps = ["ChainRulesCore", "FrankWolfe", "Krylov", "LinearAlgebra", "LinearOperators", "Random", "SimpleTraits", "SparseArrays", "Statistics", "StatsBase", "Test", "ThreadsX"] +deps = ["ChainRulesCore", "LinearAlgebra", "Random", "Statistics", "StatsBase", "ThreadsX"] path = ".." uuid = "4846b161-c94e-4150-8dac-c7ae193c601f" -version = "0.3.1" +version = "0.5.0-DEV" + + [deps.InferOpt.extensions] + InferOptFrankWolfeExt = "DifferentiableFrankWolfe" + + [deps.InferOpt.weakdeps] + DifferentiableFrankWolfe = "b383313e-5450-4164-a800-befbd27b574d" [[deps.Inflate]] git-tree-sha1 = "5cd07aab533df5170988219191dfad0519391428" @@ -457,16 +429,10 @@ version = "0.3.1" deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -[[deps.InverseFunctions]] -deps = ["Test"] -git-tree-sha1 = "49510dfcb407e572524ba94aeae2fced1f3feb0f" -uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.8" - [[deps.IrrationalConstants]] -git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" +git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.1.1" +version = "0.2.2" [[deps.IteratorInterfaceExtensions]] git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" @@ -481,15 +447,15 @@ version = "1.4.1" [[deps.JSON]] deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e" +git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.3" +version = "0.21.4" [[deps.JuliaInterpreter]] deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"] -git-tree-sha1 = "0f960b1404abb0b244c1ece579a0ec78d056a5d1" +git-tree-sha1 = "6a125e6a4cb391e0b9adbd1afa9e771c2179f8ef" uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" -version = "0.9.15" +version = "0.9.23" [[deps.JuliaVariables]] deps = ["MLStyle", "NameResolution"] @@ -497,29 +463,23 @@ git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" version = "0.2.4" -[[deps.Krylov]] -deps = ["LinearAlgebra", "Printf", "SparseArrays"] -git-tree-sha1 = "92256444f81fb094ff5aa742ed10835a621aef75" -uuid = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" -version = "0.8.4" - -[[deps.LDLFactorizations]] -deps = ["AMD", "LinearAlgebra", "SparseArrays", "Test"] -git-tree-sha1 = "cbf4b646f82bfc58bb48bcca9dcce2eb88da4cd1" -uuid = "40e66cde-538c-5869-a4ad-c39174c6795b" -version = "0.10.0" +[[deps.KernelAbstractions]] +deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "b48617c5d764908b5fac493cd907cf33cc11eec1" +uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +version = "0.9.6" [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "e7e9184b0bf0158ac4e4aa9daf00041b5909bf1a" +git-tree-sha1 = "5007c1421563108110bbd57f63d8ad4565808818" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "4.14.0" +version = "5.2.0" [[deps.LLVMExtra_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"] -git-tree-sha1 = "771bfe376249626d3ca12bcd58ba243d3f961576" +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "1222116d7313cdefecf3d45a2bc1a89c4e7c9217" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.16+0" +version = "0.0.22+0" [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] @@ -548,15 +508,9 @@ version = "1.10.2+0" uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" [[deps.LinearAlgebra]] -deps = ["Libdl", "libblastrampoline_jll"] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -[[deps.LinearOperators]] -deps = ["FastClosures", "LDLFactorizations", "LinearAlgebra", "Printf", "SparseArrays", "TimerOutputs"] -git-tree-sha1 = "088eac0646933c3ee2ae67b966a1e24ff348c49f" -uuid = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" -version = "2.4.1" - [[deps.Literate]] deps = ["Base64", "IOCapture", "JSON", "REPL"] git-tree-sha1 = "1c4418beaa6664041e0f9b48f0710f57bff2fcbe" @@ -564,30 +518,40 @@ uuid = "98b081ad-f1c9-55d3-8b20-4c87d4299306" version = "2.14.0" [[deps.LogExpFunctions]] -deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "94d9c52ca447e23eac0c0f074effbcd38830deb5" +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "c3ce8e7420b3a6e071e0fe4745f5d4300e37b13f" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.18" +version = "0.3.24" + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + + [deps.LogExpFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[deps.LoweredCodeUtils]] deps = ["JuliaInterpreter"] -git-tree-sha1 = "dedbebe234e06e1ddad435f5c6f4b85cd8ce55f7" +git-tree-sha1 = "60168780555f3e663c536500aa790b6368adc02a" uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b" -version = "2.2.2" +version = "2.3.0" [[deps.MLStyle]] -git-tree-sha1 = "43f9be9c281179fe44205e2dc19f22e71e022d41" +git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" uuid = "d8e11817-5142-5d16-987a-aa16d5891078" -version = "0.4.15" +version = "0.4.17" [[deps.MLUtils]] -deps = ["ChainRulesCore", "DelimitedFiles", "FLoops", "FoldsThreads", "Random", "ShowCases", "Statistics", "StatsBase", "Transducers"] -git-tree-sha1 = "824e9dfc7509cab1ec73ba77b55a916bb2905e26" +deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] +git-tree-sha1 = "3504cdb8c2bc05bde4d4b09a81b01df88fcbbba0" uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" -version = "0.2.11" +version = "0.4.3" [[deps.MacroTools]] deps = ["Markdown", "Random"] @@ -596,68 +560,62 @@ uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" version = "0.5.10" [[deps.MarchingCubes]] -deps = ["SnoopPrecompile", "StaticArrays"] -git-tree-sha1 = "ffc66942498a5f0d02b9e7b1b1af0f5873142cdc" +deps = ["PrecompileTools", "StaticArrays"] +git-tree-sha1 = "c8e29e2bacb98c9b6f10445227a8b0402f2f173a" uuid = "299715c1-40a9-479a-aaf9-4a633d36f717" -version = "0.1.4" +version = "0.1.8" [[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" -[[deps.MathOptInterface]] -deps = ["BenchmarkTools", "CodecBzip2", "CodecZlib", "DataStructures", "ForwardDiff", "JSON", "LinearAlgebra", "MutableArithmetics", "NaNMath", "OrderedCollections", "Printf", "SparseArrays", "SpecialFunctions", "Test", "Unicode"] -git-tree-sha1 = "ceed48edffe0325a6e9ea00ecf3607af5089c413" -uuid = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" -version = "1.9.0" - [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.0+0" +version = "2.28.2+0" [[deps.MicroCollections]] deps = ["BangBang", "InitialValues", "Setfield"] -git-tree-sha1 = "4d5917a26ca33c66c8e5ca3247bd163624d35493" +git-tree-sha1 = "629afd7d10dbc6935ec59b32daeb33bc4460a42e" uuid = "128add7d-3638-4c79-886c-908ea0c25c34" -version = "0.1.3" +version = "0.1.4" [[deps.Missings]] deps = ["DataAPI"] -git-tree-sha1 = "bf210ce90b6c9eed32d25dbcae1ebc565df2687f" +git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.0.2" +version = "1.1.0" [[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2022.2.1" - -[[deps.MutableArithmetics]] -deps = ["LinearAlgebra", "SparseArrays", "Test"] -git-tree-sha1 = "1d57a7dc42d563ad6b5e95d7a8aebd550e5162c0" -uuid = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" -version = "1.0.5" +version = "2022.10.11" [[deps.NNlib]] -deps = ["Adapt", "ChainRulesCore", "LinearAlgebra", "Pkg", "Requires", "Statistics"] -git-tree-sha1 = "00bcfcea7b2063807fdcab2e0ce86ef00b8b8000" +deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] +git-tree-sha1 = "72240e3f5ca031937bd536182cb2c031da5f46dd" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.8.10" +version = "0.8.21" + + [deps.NNlib.extensions] + NNlibAMDGPUExt = "AMDGPU" + + [deps.NNlib.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" [[deps.NNlibCUDA]] -deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"] -git-tree-sha1 = "4429261364c5ea5b7308aecaa10e803ace101631" +deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics", "cuDNN"] +git-tree-sha1 = "f94a9684394ff0d325cc12b06da7032d8be01aaf" uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" -version = "0.2.4" +version = "0.2.7" [[deps.NaNMath]] deps = ["OpenLibm_jll"] -git-tree-sha1 = "a7c3d1da1189a1c2fe843a3bfa04d18d20eb3211" +git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.0.1" +version = "1.0.2" [[deps.NameResolution]] deps = ["PrettyPrint"] @@ -670,15 +628,15 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" version = "1.2.0" [[deps.OneHotArrays]] -deps = ["Adapt", "ChainRulesCore", "GPUArrays", "LinearAlgebra", "MLUtils", "NNlib"] -git-tree-sha1 = "2f6efe2f76d57a0ee67cb6eff49b4d02fccbd175" +deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] +git-tree-sha1 = "5e4029759e8699ec12ebdf8721e51a659443403c" uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" -version = "0.1.0" +version = "0.2.4" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.20+0" +version = "0.3.21+4" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] @@ -693,31 +651,37 @@ version = "0.5.5+0" [[deps.Optimisers]] deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "8a9102cb805df46fc3d6effdc2917f09b0215c0b" +git-tree-sha1 = "6a01f65dd8583dee82eecc2a19b0ff21521aa749" uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.2.10" +version = "0.2.18" [[deps.OrderedCollections]] -git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" +git-tree-sha1 = "d321bf2de576bf25ec4d3e4360faca399afca282" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.4.1" +version = "1.6.0" [[deps.Parsers]] -deps = ["Dates", "SnoopPrecompile"] -git-tree-sha1 = "cceb0257b662528ecdf0b4b4302eb00e767b38e7" +deps = ["Dates", "PrecompileTools", "UUIDs"] +git-tree-sha1 = "4b2e829ee66d4218e0cef22c0a64ee37cf258c29" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.5.0" +version = "2.7.1" [[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.8.0" +version = "1.9.0" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "9673d39decc5feece56ef3940e5dafba15ba0f81" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.1.2" [[deps.Preferences]] deps = ["TOML"] -git-tree-sha1 = "47e5f437cc0e7ef2ce8406ce1e7e24d44915f88d" +git-tree-sha1 = "7eb1686b4f04b82f96ed7a4ea5890a4f0c7a09f1" uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.3.0" +version = "1.4.0" [[deps.PrettyPrint]] git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" @@ -728,10 +692,6 @@ version = "0.2.0" deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" -[[deps.Profile]] -deps = ["Printf"] -uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" - [[deps.ProgressLogging]] deps = ["Logging", "SHA", "UUIDs"] git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" @@ -754,9 +714,9 @@ uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [[deps.Random123]] deps = ["Random", "RandomNumbers"] -git-tree-sha1 = "7a1a306b72cfa60634f03a911405f4e64d1b718b" +git-tree-sha1 = "552f30e847641591ba3f39fd1bed559b9deb0ef3" uuid = "74087812-796a-5b5d-8853-05524746bad3" -version = "1.6.0" +version = "1.6.1" [[deps.RandomNumbers]] deps = ["Random", "Requires"] @@ -789,14 +749,20 @@ version = "1.3.0" [[deps.Revise]] deps = ["CodeTracking", "Distributed", "FileWatching", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "Pkg", "REPL", "Requires", "UUIDs", "Unicode"] -git-tree-sha1 = "dad726963ecea2d8a81e26286f625aee09a91b7c" +git-tree-sha1 = "1e597b93700fa4045d7189afa7c004e0584ea548" uuid = "295af30f-e4ad-537b-8983-00126c2a3abe" -version = "3.4.0" +version = "3.5.3" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" +[[deps.Scratch]] +deps = ["Dates"] +git-tree-sha1 = "30449ee12237627992a99d5e30ae63e4d78cd24a" +uuid = "6c6a2e73-6563-6170-7368-637461726353" +version = "1.2.0" + [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -821,29 +787,28 @@ git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" version = "0.9.4" -[[deps.SnoopPrecompile]] -git-tree-sha1 = "f604441450a3c0569830946e5b33b78c928e1a85" -uuid = "66db9d55-30c0-4569-8b51-7e840670fc0c" -version = "1.0.1" - [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" [[deps.SortingAlgorithms]] deps = ["DataStructures"] -git-tree-sha1 = "a4ada03f999bd01b3a25dcaa30b2d929fe537e00" +git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee" uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.1.0" +version = "1.1.1" [[deps.SparseArrays]] -deps = ["LinearAlgebra", "Random"] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[deps.SpecialFunctions]] -deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "d75bda01f8c31ebb72df80a46c88b25d1c79c56d" +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "7beb031cf8145577fbccacd94b8a8f4ce78428d3" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.1.7" +version = "2.3.0" +weakdeps = ["ChainRulesCore"] + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" [[deps.SplittablesBase]] deps = ["Setfield", "Test"] @@ -851,17 +816,11 @@ git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" uuid = "171d559e-b47b-412a-8079-5efa626c420e" version = "0.1.15" -[[deps.Static]] -deps = ["IfElse"] -git-tree-sha1 = "03170c1e8a94732c1d835ce4c5b904b4b52cba1c" -uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" -version = "0.7.8" - [[deps.StaticArrays]] deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] -git-tree-sha1 = "f86b3a049e5d05227b10e15dbb315c5b90f14988" +git-tree-sha1 = "832afbae2a45b4ae7e831f86965469a24d1d8a83" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.5.9" +version = "1.5.26" [[deps.StaticArraysCore]] git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a" @@ -871,33 +830,35 @@ version = "1.4.0" [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.9.0" [[deps.StatsAPI]] deps = ["LinearAlgebra"] -git-tree-sha1 = "f9af7f195fb13589dd2e2d57fdb401717d2eb1f6" +git-tree-sha1 = "45a7769a04a3cf80da1c1c7c60caf932e6f4c9f7" uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.5.0" +version = "1.6.0" [[deps.StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "d1bf48bfcc554a3761a133fe3a9bb01488e06916" +git-tree-sha1 = "75ebe04c5bed70b91614d684259b661c9e6274a4" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.21" +version = "0.34.0" [[deps.StructArrays]] -deps = ["Adapt", "DataAPI", "StaticArraysCore", "Tables"] -git-tree-sha1 = "13237798b407150a6d2e2bce5d793d7d9576e99e" +deps = ["Adapt", "DataAPI", "GPUArraysCore", "StaticArraysCore", "Tables"] +git-tree-sha1 = "521a0e828e98bb69042fec1809c1b5a680eb7389" uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -version = "0.6.13" +version = "0.6.15" -[[deps.SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] -uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "5.10.1+6" [[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.0" +version = "1.0.3" [[deps.TableTraits]] deps = ["IteratorInterfaceExtensions"] @@ -907,14 +868,14 @@ version = "1.0.1" [[deps.Tables]] deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"] -git-tree-sha1 = "c79322d36826aa2f4fd8ecfa96ddb47b174ac78d" +git-tree-sha1 = "1544b926975372da01227b382066ab70e574a3ec" uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.10.0" +version = "1.10.1" [[deps.Tar]] deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.1" +version = "1.10.0" [[deps.TensorCore]] deps = ["LinearAlgebra"] @@ -934,21 +895,29 @@ version = "0.1.11" [[deps.TimerOutputs]] deps = ["ExprTools", "Printf"] -git-tree-sha1 = "f2fd3f288dfc6f507b0c3a2eb3bac009251e548b" +git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7" uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.22" - -[[deps.TranscodingStreams]] -deps = ["Random", "Test"] -git-tree-sha1 = "8a75929dcd3c38611db2f8d08546decb514fcadf" -uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.9.9" +version = "0.5.23" [[deps.Transducers]] deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] -git-tree-sha1 = "77fea79baa5b22aeda896a8d9c6445a74500a2c2" +git-tree-sha1 = "a66fb81baec325cf6ccafa243af573b031e87b00" uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.74" +version = "0.4.77" + + [deps.Transducers.extensions] + TransducersBlockArraysExt = "BlockArrays" + TransducersDataFramesExt = "DataFrames" + TransducersLazyArraysExt = "LazyArrays" + TransducersOnlineStatsBaseExt = "OnlineStatsBase" + TransducersReferenceablesExt = "Referenceables" + + [deps.Transducers.weakdeps] + BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" + OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" + Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" [[deps.UUIDs]] deps = ["Random", "SHA"] @@ -958,38 +927,74 @@ uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" [[deps.UnicodePlots]] -deps = ["ColorSchemes", "ColorTypes", "Contour", "Crayons", "Dates", "FileIO", "FreeType", "LinearAlgebra", "MarchingCubes", "NaNMath", "Printf", "Requires", "SnoopPrecompile", "SparseArrays", "StaticArrays", "StatsBase", "Unitful"] -git-tree-sha1 = "80403d1795114d9150f93870e707c6709a2d3cfe" +deps = ["ColorSchemes", "ColorTypes", "Contour", "Crayons", "Dates", "LinearAlgebra", "MarchingCubes", "NaNMath", "PrecompileTools", "Printf", "Requires", "SparseArrays", "StaticArrays", "StatsBase"] +git-tree-sha1 = "b96de03092fe4b18ac7e4786bee55578d4b75ae8" uuid = "b8865327-cd53-5732-bb35-84acbb429228" -version = "3.2.4" +version = "3.6.0" + + [deps.UnicodePlots.extensions] + FreeTypeExt = ["FileIO", "FreeType"] + ImageInTerminalExt = "ImageInTerminal" + IntervalSetsExt = "IntervalSets" + TermExt = "Term" + UnitfulExt = "Unitful" + + [deps.UnicodePlots.weakdeps] + FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" + FreeType = "b38be410-82b0-50bf-ab77-7b57e271db43" + ImageInTerminal = "d8c32880-2388-543b-8c61-d9f865259254" + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + Term = "22787eb5-b846-44ae-b979-8e399b8463ab" + Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" + +[[deps.UnsafeAtomics]] +git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" +uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" +version = "0.2.1" -[[deps.Unitful]] -deps = ["ConstructionBase", "Dates", "LinearAlgebra", "Random"] -git-tree-sha1 = "d57a4ed70b6f9ff1da6719f5f2713706d57e0d66" -uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" -version = "1.12.0" +[[deps.UnsafeAtomicsLLVM]] +deps = ["LLVM", "UnsafeAtomics"] +git-tree-sha1 = "ea37e6066bf194ab78f4e747f5245261f17a7175" +uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" +version = "0.1.2" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.12+3" +version = "1.2.13+0" [[deps.Zygote]] -deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "66cc604b9a27a660e25a54e408b4371123a186a6" +deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "5be3ddb88fc992a7d8ea96c3f10a49a7e98ebc7b" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.49" +version = "0.6.62" + + [deps.Zygote.extensions] + ZygoteColorsExt = "Colors" + ZygoteDistancesExt = "Distances" + ZygoteTrackerExt = "Tracker" + + [deps.Zygote.weakdeps] + Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" + Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [[deps.ZygoteRules]] -deps = ["MacroTools"] -git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0" +deps = ["ChainRulesCore", "MacroTools"] +git-tree-sha1 = "977aed5d006b840e2e40c0b48984f7463109046d" uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.2" +version = "0.2.3" + +[[deps.cuDNN]] +deps = ["CEnum", "CUDA", "CUDNN_jll"] +git-tree-sha1 = "f65490d187861d6222cb38bcbbff3fd949a7ec3e" +uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +version = "1.0.4" [[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] +deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.1.1+0" +version = "5.8.0+0" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] diff --git a/docs/Project.toml b/docs/Project.toml index 38927f6..44924bf 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,9 +1,11 @@ [deps] +AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" GridGraphs = "dd2b58c7-5af7-4f17-9e46-57c68ac813fb" InferOpt = "4846b161-c94e-4150-8dac-c7ae193c601f" +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" @@ -11,4 +13,4 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" \ No newline at end of file +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/docs/make.jl b/docs/make.jl index 62eb459..c2bf664 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -23,7 +23,7 @@ end # Parse test/tutorial.jl into docs/src/tutorial.md (overwriting) -tuto_jl_file = joinpath(dirname(@__DIR__), "test", "tutorial.jl") +tuto_jl_file = joinpath(dirname(@__DIR__), "examples", "tutorial.jl") tuto_md_dir = joinpath(@__DIR__, "src") Literate.markdown(tuto_jl_file, tuto_md_dir; documenter=true, execute=false) @@ -39,9 +39,10 @@ makedocs(; ), pages=[ "Home" => "index.md", - "Tutorials" => ["tutorial.md", "advanced_tutorials.md"], "Background" => "background.md", - "Algorithms & API" => "algorithms.md", + "Examples" => ["tutorial.md", "advanced_applications.md"], + "Algorithms" => ["optim.md", "losses.md"], + "API reference" => "api.md", ], ) diff --git a/docs/src/advanced_tutorials.md b/docs/src/advanced_applications.md similarity index 100% rename from docs/src/advanced_tutorials.md rename to docs/src/advanced_applications.md diff --git a/docs/src/algorithms.md b/docs/src/algorithms.md deleted file mode 100644 index c1e34a1..0000000 --- a/docs/src/algorithms.md +++ /dev/null @@ -1,107 +0,0 @@ -# API Reference - -## Probability distributions - -```@autodocs -Modules = [InferOpt] -Pages = ["utils/probability_distribution.jl", "utils/pushforward.jl"] -``` - -## Plus identity - -!!! note "Reference" - [Backpropagation through Combinatorial Algorithms: Identity with Projection Works](https://arxiv.org/abs/2205.15213) - -```@autodocs -Modules = [InferOpt] -Pages = ["plus_identity/plus_identity.jl"] -``` - -## Interpolation - -!!! note "Reference" - [Differentiation of Blackbox Combinatorial Solvers](https://arxiv.org/abs/1912.02175) - -```@autodocs -Modules = [InferOpt] -Pages = ["interpolation/interpolation.jl"] -``` - -## Smart "Predict, then Optimize" - -!!! note "Reference" - [Smart "Predict, then Optimize"](https://arxiv.org/abs/1710.08005) - -```@autodocs -Modules = [InferOpt] -Pages = ["spo/spoplus_loss.jl"] -``` - -## Structured Support Vector Machines - -!!! note "Reference" - [Structured learning and prediction in computer vision](https://pub.ist.ac.at/~chl/papers/nowozin-fnt2011.pdf), Chapter 6 - -```@autodocs -Modules = [InferOpt] -Pages = ["ssvm/isbaseloss.jl", "ssvm/ssvm_loss.jl", "ssvm/zeroone_baseloss.jl"] -``` - -## Frank-Wolfe - -!!! note "Reference" - [Efficient and Modular Implicit Differentiation](http://arxiv.org/abs/2105.15183) - -!!! note "Reference" - [FrankWolfe.jl: a high-performance and flexible toolbox for Frank-Wolfe algorithms and Conditional Gradients](https://arxiv.org/abs/2104.06675) - -```@autodocs -Modules = [InferOpt] -Pages = ["frank_wolfe/frank_wolfe_utils.jl", "frank_wolfe/differentiable_frank_wolfe.jl"] -``` - -## Regularized optimizers - -!!! note "Reference" - [Learning with Fenchel-Young Losses](https://arxiv.org/abs/1901.02324) - -```@autodocs -Modules = [InferOpt] -Pages = ["regularized/isregularized.jl", "regularized/regularized_generic.jl", "regularized/regularized_utils.jl", "regularized/soft_argmax.jl", "regularized/sparse_argmax.jl"] -``` - -## Perturbed optimizers - -!!! note "Reference" - [Learning with Differentiable Perturbed Optimizers](https://arxiv.org/abs/2002.08676) - -```@autodocs -Modules = [InferOpt] -Pages = ["perturbed/abstract_perturbed.jl", "perturbed/additive.jl", "perturbed/multiplicative.jl"] -``` - -## Fenchel-Young losses - -!!! note "Reference" - [Learning with Fenchel-Young Losses](https://arxiv.org/abs/1901.02324) - -```@autodocs -Modules = [InferOpt] -Pages = ["fenchel_young/fenchel_young.jl", "fenchel_young/perturbed.jl"] -``` - -## Generalized imitation losses - -!!! note "Reference" - [Learning with Combinatorial Optimization Layers: a Probabilistic Approach](https://arxiv.org/abs/2207.13513) - -```@autodocs -Modules = [InferOpt] -Pages = ["imitation_loss/imitation_loss.jl"] -``` - -## Index - -```@index -Modules = [InferOpt] -``` \ No newline at end of file diff --git a/docs/src/api.md b/docs/src/api.md new file mode 100644 index 0000000..c95c58e --- /dev/null +++ b/docs/src/api.md @@ -0,0 +1,25 @@ +# API Reference + +```@docs +InferOpt +``` + +## Types + +```@autodocs +Modules = [InferOpt] +Order = [:type] +``` + +## Functions + +```@autodocs +Modules = [InferOpt] +Order = [:function] +``` + +## Index + +```@index +Modules = [InferOpt] +``` \ No newline at end of file diff --git a/docs/src/background.md b/docs/src/background.md index 8515d01..a7e3e5c 100644 --- a/docs/src/background.md +++ b/docs/src/background.md @@ -35,6 +35,5 @@ Since we want our package to be as generic as possible, we don't make any assump That way, the best solver can be selected for each use case. We only ask the user to provide a black box function called `maximizer`, taking $\theta$ as argument and returning $f(\theta)$. -This function is then wrapped into a callable Julia `struct`, which can be used (for instance) within neural networks from the [`Flux.jl`](https://github.com/FluxML/Flux.jl) library. -To achieve this compatibility, we leverage Julia's automatic differentiation (AD) ecosystem, which revolves around the [`ChainRules.jl`](https://github.com/JuliaDiff/ChainRules.jl) package. -See their [documentation](https://juliadiff.org/ChainRulesCore.jl/dev/index.html) for more details. \ No newline at end of file +This function is then wrapped into a callable Julia `struct`, which can be used (for instance) within neural networks from the [Flux.jl](https://github.com/FluxML/Flux.jl) or [Lux.jl](https://github.com/LuxDL/Lux.jl) library. +To achieve this compatibility, we leverage Julia's automatic differentiation (AD) ecosystem, which revolves around the [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) package. \ No newline at end of file diff --git a/docs/src/losses.md b/docs/src/losses.md new file mode 100644 index 0000000..4d4ef49 --- /dev/null +++ b/docs/src/losses.md @@ -0,0 +1,10 @@ +# Losses + +!!! info "Work in progress" + Come back later! + +```@example +using AbstractTrees, InferOpt, InteractiveUtils +AbstractTrees.children(x::Type) = subtypes(x) +print_tree(InferOpt.AbstractLossLayer) +``` \ No newline at end of file diff --git a/docs/src/optim.md b/docs/src/optim.md new file mode 100644 index 0000000..36f0b90 --- /dev/null +++ b/docs/src/optim.md @@ -0,0 +1,10 @@ +# Optimization + +!!! info "Work in progress" + Come back later! + +```@example +using AbstractTrees, InferOpt, InteractiveUtils +AbstractTrees.children(x::Type) = subtypes(x) +print_tree(InferOpt.AbstractOptimizationLayer) +``` \ No newline at end of file diff --git a/examples/tutorial.jl b/examples/tutorial.jl new file mode 100644 index 0000000..b1c27ee --- /dev/null +++ b/examples/tutorial.jl @@ -0,0 +1,178 @@ +# # Basic tutorial + +# ## Context + +#= +Let us imagine that we observe the itineraries chosen by a public transport user in several different networks, and that we want to understand their decision-making process (a.k.a. recover their utility function). + +More precisely, each point in our dataset consists in: +- a graph $G$ +- a shortest path $P$ from the top left to the bottom right corner + +We don't know the true costs that were used to compute the shortest path, but we can exploit a set of features to approximate these costs. +The question is: how should we combine these features? + +We will use InferOpt.jl to learn the appropriate weights, so that we may propose relevant paths to the user in the future. +=# + +using Flux +using Graphs +using GridGraphs +using InferOpt +using LinearAlgebra +using ProgressMeter +using Random +using Statistics +using Test +using UnicodePlots + +Random.seed!(63); + +# ## Grid graphs + +#= +For the purposes of this tutorial, we consider grid graphs, as implemented in [GridGraphs.jl](https://github.com/gdalle/GridGraphs.jl). +In such graphs, each vertex corresponds to a couple of coordinates $(i, j)$, where $1 \leq i \leq h$ and $1 \leq j \leq w$. + +To ensure acyclicity, we only allow the user to move right, down or both. +Since the cost of a move is defined as the cost of the arrival vertex, any grid graph is entirely characterized by its cost matrix $\theta \in \mathbb{R}^{h \times w}$. +=# + +h, w = 50, 100 +queen_directions = GridGraphs.QUEEN_ACYCLIC_DIRECTIONS +g = GridGraph(rand(h, w); directions=queen_directions); + +#= +For convenience, GridGraphs.jl also provides custom functions to compute shortest paths efficiently. +Let us see what those paths look like. +=# + +p = path_to_matrix(g, grid_topological_sort(g, 1, nv(g))); +spy(p) + +# ## Dataset + +#= +As announced, we do not know the cost of each vertex, only a set of relevant features. +Let us assume that the user combines them using a shallow neural network. +=# + +nb_features = 5 +true_encoder = Chain(Dense(nb_features, 1), z -> dropdims(z; dims=1)); + +#= +The true vertex costs computed from this encoding are then used within shortest path computations. +To be consistent with the literature, we frame this problem as a linear maximization problem, which justifies the change of sign in front of $\theta$. +Note that `linear_maximizer` can take keyword arguments, eg. to give additional information about the instance that `θ` doesn't contain. +=# + +function linear_maximizer(θ; directions) + g = GridGraph(-θ; directions=directions) + path = grid_topological_sort(g, 1, nv(g)) + return path_to_matrix(g, path) +end; + +#= +We now have everything we need to build our dataset. +=# + +nb_instances = 30 + +X_train = [randn(Float32, nb_features, h, w) for n in 1:nb_instances]; +θ_train = [true_encoder(x) for x in X_train]; +Y_train = [linear_maximizer(θ; directions=queen_directions) for θ in θ_train]; + +# ## Learning + +#= +We create a trainable model with the same structure as the true encoder but another set of randomly-initialized weights. +=# + +initial_encoder = Chain(Dense(nb_features, 1), z -> dropdims(z; dims=1)); + +#= +Here is the crucial part where InferOpt.jl intervenes: the choice of a clever loss function that enables us to +- differentiate through the shortest path maximizer, even though it is a combinatorial operation +- evaluate the quality of our model based on the paths that it recommends +=# + +layer = PerturbedMultiplicative(linear_maximizer; ε=1.0, nb_samples=5); +loss = FenchelYoungLoss(layer); + +#= +This probabilistic layer is just a thin wrapper around our `linear_maximizer`, but with a very different behavior: +=# + +p_layer = layer(θ_train[1]; directions=queen_directions); +spy(p_layer) + +#= +Instead of choosing just one path, it spreads over several possible paths, allowing its output to change smoothly as $\theta$ varies. +Thanks to this smoothing, we can now train our model with a standard gradient optimizer. +=# + +encoder = deepcopy(initial_encoder) +opt = Flux.Adam(); +losses = Float64[] +for epoch in 1:100 + l = 0.0 + for (x, y) in zip(X_train, Y_train) + grads = gradient(Flux.params(encoder)) do + l += loss(encoder(x), y; directions=queen_directions) + end + Flux.update!(opt, Flux.params(encoder), grads) + end + push!(losses, l) +end; + +# ## Results + +#= +Since the Fenchel-Young loss is convex, it is no wonder that optimization worked like a charm. +=# + +lineplot(losses; xlabel="Epoch", ylabel="Loss") + +#= +To assess performance, we can compare the learned weights with their true (hidden) values +=# + +learned_weight = encoder[1].weight / norm(encoder[1].weight) +true_weight = true_encoder[1].weight / norm(true_encoder[1].weight) +hcat(learned_weight, true_weight) + +#= +We are quite close to recovering the exact user weights. +But in reality, it doesn't matter as much as our ability to provide accurate path predictions. +Let us therefore compare our predictions with the actual paths on the training set. +=# + +normalized_hamming(x, y) = mean(x[i] != y[i] for i in eachindex(x)); + +#- + +Y_train_pred = [linear_maximizer(encoder(x); directions=queen_directions) for x in X_train]; + +train_error = mean( + normalized_hamming(y, y_pred) for (y, y_pred) in zip(Y_train, Y_train_pred) +) + +# Not too bad, at least compared with our random initial encoder. + +Y_train_pred_initial = [ + linear_maximizer(initial_encoder(x); directions=queen_directions) for x in X_train +]; + +train_error_initial = mean( + normalized_hamming(y, y_pred) for (y, y_pred) in zip(Y_train, Y_train_pred_initial) +) + +#= +This is definitely a success. +Of course in real prediction settings we should measure performance on a test set as well. +This is left as an exercise to the reader. +=# + +# CI tests, not included in the documentation #src + +@test train_error < train_error_initial / 3 #src diff --git a/ext/InferOptFrankWolfeExt.jl b/ext/InferOptFrankWolfeExt.jl index cdb8cff..d9eb7df 100644 --- a/ext/InferOptFrankWolfeExt.jl +++ b/ext/InferOptFrankWolfeExt.jl @@ -1,49 +1,28 @@ module InferOptFrankWolfeExt using DifferentiableFrankWolfe: DiffFW, LinearMaximizationOracleWithKwargs -using InferOpt: InferOpt, RegularizedGeneric, FixedAtomsProbabilityDistribution +using InferOpt: InferOpt, RegularizedFrankWolfe, FixedAtomsProbabilityDistribution using InferOpt: compute_expectation, compute_probability_distribution using LinearAlgebra: dot -## Forward pass - -function InferOpt.compute_probability_distribution( - dfw::DiffFW, θ::AbstractArray; frank_wolfe_kwargs=NamedTuple() -) - weights, atoms = dfw.implicit(θ; frank_wolfe_kwargs=frank_wolfe_kwargs) - probadist = FixedAtomsProbabilityDistribution(atoms, weights) - return probadist -end - """ - compute_probability_distribution(regularized::RegularizedGeneric, θ; kwargs...) + compute_probability_distribution(regularized::RegularizedFrankWolfe, θ; kwargs...) Construct a `DifferentiableFrankWolfe.DiffFW` struct and call `compute_probability_distribution` on it. Keyword arguments are passed to the underlying linear maximizer. """ function InferOpt.compute_probability_distribution( - regularized::RegularizedGeneric, θ::AbstractArray; kwargs... + regularized::RegularizedFrankWolfe, θ::AbstractArray; kwargs... ) - (; maximizer, Ω, Ω_grad, frank_wolfe_kwargs) = regularized + (; linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs) = regularized f(y, θ) = Ω(y) - dot(θ, y) f_grad1(y, θ) = Ω_grad(y) - θ - lmo = LinearMaximizationOracleWithKwargs(maximizer, kwargs) + lmo = LinearMaximizationOracleWithKwargs(linear_maximizer, kwargs) dfw = DiffFW(f, f_grad1, lmo) - probadist = compute_probability_distribution(dfw, θ; frank_wolfe_kwargs) + weights, atoms = dfw.implicit(θ; frank_wolfe_kwargs=frank_wolfe_kwargs) + probadist = FixedAtomsProbabilityDistribution(atoms, weights) return probadist end -""" - (regularized::RegularizedGeneric)(θ; kwargs...) - -Apply `compute_probability_distribution(regularized, θ)` and return the expectation. - -Keyword arguments are passed to the underlying linear maximizer. -""" -function (regularized::RegularizedGeneric)(θ::AbstractArray; kwargs...) - probadist = compute_probability_distribution(regularized, θ; kwargs...) - return compute_expectation(probadist) -end - end diff --git a/src/InferOpt.jl b/src/InferOpt.jl index 752236b..108d07f 100644 --- a/src/InferOpt.jl +++ b/src/InferOpt.jl @@ -1,74 +1,72 @@ +""" + InferOpt + +A toolbox for using combinatorial optimization algorithms within machine learning pipelines. + +See our preprint +""" module InferOpt using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, Tangent, ZeroTangent using ChainRulesCore: rrule, rrule_via_ad, unthunk using LinearAlgebra: dot using Random: AbstractRNG, GLOBAL_RNG, MersenneTwister, rand, seed! -using SimpleTraits: SimpleTraits -using SimpleTraits: @traitdef, @traitfn, @traitimpl using Statistics: mean using StatsBase: StatsBase, sample using ThreadsX: ThreadsX +include("interface.jl") + +include("utils/some_functions.jl") include("utils/probability_distribution.jl") include("utils/pushforward.jl") -include("plus_identity/plus_identity.jl") +include("simple/interpolation.jl") +include("simple/identity.jl") -include("interpolation/interpolation.jl") - -include("regularized/isregularized.jl") -include("regularized/regularized_utils.jl") +include("regularized/abstract_regularized.jl") include("regularized/soft_argmax.jl") include("regularized/sparse_argmax.jl") -include("regularized/regularized_generic.jl") +include("regularized/regularized_frank_wolfe.jl") include("perturbed/abstract_perturbed.jl") include("perturbed/additive.jl") include("perturbed/multiplicative.jl") -include("fenchel_young/perturbed.jl") -include("fenchel_young/fenchel_young.jl") - -include("spo/spoplus_loss.jl") - -include("ssvm/isbaseloss.jl") -include("ssvm/zeroone_baseloss.jl") -include("ssvm/ssvm_loss.jl") - -include("imitation_loss/imitation_loss.jl") +include("imitation/spoplus_loss.jl") +include("imitation/ssvm_loss.jl") +include("imitation/fenchel_young_loss.jl") +include("imitation/imitation_loss.jl") +include("imitation/zero_one_loss.jl") if !isdefined(Base, :get_extension) include("../ext/InferOptFrankWolfeExt.jl") end +export half_square_norm +export shannon_entropy, negative_shannon_entropy +export one_hot_argmax, ranking + export FixedAtomsProbabilityDistribution -export compute_expectation, compress_distribution! -export Pushforward +export compute_expectation export compute_probability_distribution +export Pushforward -export PlusIdentity - +export IdentityRelaxation export Interpolation -export half_square_norm -export shannon_entropy, negative_shannon_entropy -export one_hot_argmax, ranking -export IsRegularized -export soft_argmax, sparse_argmax -export RegularizedGeneric +export AbstractRegularized +export SoftArgmax, soft_argmax +export SparseArgmax, sparse_argmax +export RegularizedFrankWolfe +export AbstractPerturbed export PerturbedAdditive export PerturbedMultiplicative export FenchelYoungLoss - -export SPOPlusLoss - -export IsBaseLoss -export ZeroOneBaseLoss export StructuredSVMLoss - -export ImitationLoss, get_y_true +export ImitationLoss +export SPOPlusLoss end diff --git a/src/fenchel_young/fenchel_young.jl b/src/fenchel_young/fenchel_young.jl deleted file mode 100644 index 6b4a8a4..0000000 --- a/src/fenchel_young/fenchel_young.jl +++ /dev/null @@ -1,57 +0,0 @@ -""" - FenchelYoungLoss{P} - -Fenchel-Young loss associated with a given regularized prediction function. - -# Fields -- `predictor::P`: prediction function of the form `ŷ(θ) = argmax {θᵀy - Ω(y)}` - -Reference: -""" -struct FenchelYoungLoss{P} - predictor::P -end - -function Base.show(io::IO, fyl::FenchelYoungLoss) - (; predictor) = fyl - return print(io, "FenchelYoungLoss($predictor)") -end - -## Forward pass - -function (fyl::FenchelYoungLoss)(θ::AbstractArray, y_true::AbstractArray; kwargs...) - l, _ = fenchel_young_loss_and_grad(fyl, θ, y_true; kwargs...) - return l -end - -@traitfn function fenchel_young_loss_and_grad( - fyl::FenchelYoungLoss{P}, θ::AbstractArray, y_true::AbstractArray; kwargs... -) where {P; IsRegularized{P}} - (; predictor) = fyl - ŷ = predictor(θ; kwargs...) - Ωy_true = compute_regularization(predictor, y_true) - Ωŷ = compute_regularization(predictor, ŷ) - l = (Ωy_true - dot(θ, y_true)) - (Ωŷ - dot(θ, ŷ)) - g = ŷ - y_true - return l, g -end - -function fenchel_young_loss_and_grad( - fyl::FenchelYoungLoss{P}, θ::AbstractArray, y_true::AbstractArray; kwargs... -) where {P<:AbstractPerturbed} - (; predictor) = fyl - F, almost_ŷ = fenchel_young_F_and_first_part_of_grad(predictor, θ; kwargs...) - l = F - dot(θ, y_true) - g = almost_ŷ - y_true - return l, g -end - -## Backward pass - -function ChainRulesCore.rrule( - fyl::FenchelYoungLoss, θ::AbstractArray, y_true::AbstractArray; kwargs... -) - l, g = fenchel_young_loss_and_grad(fyl, θ, y_true; kwargs...) - fyl_pullback(dl) = NoTangent(), dl * g, NoTangent() - return l, fyl_pullback -end diff --git a/src/fenchel_young/perturbed.jl b/src/fenchel_young/perturbed.jl deleted file mode 100644 index 1f22d33..0000000 --- a/src/fenchel_young/perturbed.jl +++ /dev/null @@ -1,47 +0,0 @@ -function compute_F_and_y_samples( - perturbed::AbstractPerturbed{false}, θ::AbstractArray, Z_samples; kwargs... -) - F_and_y_samples = [ - fenchel_young_F_and_first_part_of_grad(perturbed, θ, Z; kwargs...) for - Z in Z_samples - ] - return F_and_y_samples -end - -function compute_F_and_y_samples( - perturbed::AbstractPerturbed{true}, θ::AbstractArray, Z_samples; kwargs... -) - return ThreadsX.map( - Z -> fenchel_young_F_and_first_part_of_grad(perturbed, θ, Z; kwargs...), Z_samples - ) -end - -function fenchel_young_F_and_first_part_of_grad( - perturbed::AbstractPerturbed, θ::AbstractArray; kwargs... -) - Z_samples = sample_perturbations(perturbed, θ) - F_and_y_samples = compute_F_and_y_samples(perturbed, θ, Z_samples; kwargs...) - return mean(first, F_and_y_samples), mean(last, F_and_y_samples) -end - -function fenchel_young_F_and_first_part_of_grad( - perturbed::PerturbedAdditive, θ::AbstractArray, Z::AbstractArray; kwargs... -) - (; maximizer, ε) = perturbed - θ_perturbed = θ .+ ε .* Z - y = maximizer(θ_perturbed; kwargs...) - F = dot(θ_perturbed, y) - return F, y -end - -function fenchel_young_F_and_first_part_of_grad( - perturbed::PerturbedMultiplicative, θ::AbstractArray, Z::AbstractArray; kwargs... -) - (; maximizer, ε) = perturbed - eZ = exp.(ε .* Z .- ε^2) - θ_perturbed = θ .* eZ - y = maximizer(θ_perturbed; kwargs...) - F = dot(θ_perturbed, y) - y_scaled = y .* eZ - return F, y_scaled -end diff --git a/src/imitation/fenchel_young_loss.jl b/src/imitation/fenchel_young_loss.jl new file mode 100644 index 0000000..26533df --- /dev/null +++ b/src/imitation/fenchel_young_loss.jl @@ -0,0 +1,114 @@ +""" + FenchelYoungLoss <: AbstractLossLayer + +Fenchel-Young loss associated with a given optimization layer. +``` +L(θ, y_true) = (Ω(y_true) - θᵀy_true) - (Ω(ŷ) - θᵀŷ) +``` + +Reference: + +# Fields + +- `optimization_layer::AbstractOptimizationLayer`: optimization layer that can be formulated as `ŷ(θ) = argmax {θᵀy - Ω(y)}` (either regularized or perturbed) +""" +struct FenchelYoungLoss{O<:AbstractOptimizationLayer} <: AbstractLossLayer + optimization_layer::O +end + +function Base.show(io::IO, fyl::FenchelYoungLoss) + (; optimization_layer) = fyl + return print(io, "FenchelYoungLoss($optimization_layer)") +end + +## Forward pass + +""" + (fyl::FenchelYoungLoss)(θ, y_true; kwargs...) +""" +function (fyl::FenchelYoungLoss)(θ::AbstractArray, y_true::AbstractArray; kwargs...) + l, _ = fenchel_young_loss_and_grad(fyl, θ, y_true; kwargs...) + return l +end + +function fenchel_young_loss_and_grad( + fyl::FenchelYoungLoss{O}, θ::AbstractArray, y_true::AbstractArray; kwargs... +) where {O<:AbstractRegularized} + (; optimization_layer) = fyl + ŷ = optimization_layer(θ; kwargs...) + Ωy_true = compute_regularization(optimization_layer, y_true) + Ωŷ = compute_regularization(optimization_layer, ŷ) + l = (Ωy_true - dot(θ, y_true)) - (Ωŷ - dot(θ, ŷ)) + g = ŷ - y_true + return l, g +end + +function fenchel_young_loss_and_grad( + fyl::FenchelYoungLoss{O}, θ::AbstractArray, y_true::AbstractArray; kwargs... +) where {O<:AbstractPerturbed} + (; optimization_layer) = fyl + F, almost_ŷ = fenchel_young_F_and_first_part_of_grad(optimization_layer, θ; kwargs...) + l = F - dot(θ, y_true) + g = almost_ŷ - y_true + return l, g +end + +## Backward pass + +function ChainRulesCore.rrule( + fyl::FenchelYoungLoss, θ::AbstractArray, y_true::AbstractArray; kwargs... +) + l, g = fenchel_young_loss_and_grad(fyl, θ, y_true; kwargs...) + fyl_pullback(dl) = NoTangent(), dl * g, NoTangent() + return l, fyl_pullback +end + +## Specific overrides for perturbed layers + +function compute_F_and_y_samples( + perturbed::AbstractPerturbed{false}, θ::AbstractArray, Z_samples; kwargs... +) + F_and_y_samples = [ + fenchel_young_F_and_first_part_of_grad(perturbed, θ, Z; kwargs...) for + Z in Z_samples + ] + return F_and_y_samples +end + +function compute_F_and_y_samples( + perturbed::AbstractPerturbed{true}, θ::AbstractArray, Z_samples; kwargs... +) + return ThreadsX.map( + Z -> fenchel_young_F_and_first_part_of_grad(perturbed, θ, Z; kwargs...), Z_samples + ) +end + +function fenchel_young_F_and_first_part_of_grad( + perturbed::AbstractPerturbed, θ::AbstractArray; kwargs... +) + Z_samples = sample_perturbations(perturbed, θ) + F_and_y_samples = compute_F_and_y_samples(perturbed, θ, Z_samples; kwargs...) + return mean(first, F_and_y_samples), mean(last, F_and_y_samples) +end + +function fenchel_young_F_and_first_part_of_grad( + perturbed::PerturbedAdditive, θ::AbstractArray, Z::AbstractArray; kwargs... +) + (; maximizer, ε) = perturbed + θ_perturbed = θ .+ ε .* Z + y = maximizer(θ_perturbed; kwargs...) + F = dot(θ_perturbed, y) + return F, y +end + +function fenchel_young_F_and_first_part_of_grad( + perturbed::PerturbedMultiplicative, θ::AbstractArray, Z::AbstractArray; kwargs... +) + (; maximizer, ε) = perturbed + eZ = exp.(ε .* Z .- ε^2) + θ_perturbed = θ .* eZ + y = maximizer(θ_perturbed; kwargs...) + F = dot(θ_perturbed, y) + y_scaled = y .* eZ + return F, y_scaled +end diff --git a/src/imitation/imitation_loss.jl b/src/imitation/imitation_loss.jl new file mode 100644 index 0000000..9f4ee71 --- /dev/null +++ b/src/imitation/imitation_loss.jl @@ -0,0 +1,84 @@ +""" + ImitationLoss <: AbstractLossLayer + +Generic imitation loss of the form +``` +L(θ, t_true) = max_y {δ(y, t_true) + α θᵀ(y - y_true) - (Ω(y) - Ω(y_true))} +``` + +- When `δ` is zero, this is equivalent to a [`FenchelYoungLoss`](@ref). +- When `Ω` is zero, this is equivalent to a [`StructuredSVMLoss`](@ref). + +Note: by default, `t_true` is a named tuple with field `y_true`, but it can be any data structure for which the [`get_y_true`](@ref) method is implemented. + +# Fields + +- `aux_loss_maximizer`: function of `(θ, t_true, α)` that computes the argmax in the problem above +- `δ`: base loss function +- `Ω`: regularization function +- `α::Float64`: hyperparameter with a default value of 1.0 +""" +struct ImitationLoss{M,L,R} <: AbstractLossLayer + aux_loss_maximizer::M + δ::L + Ω::R + α::Float64 +end + +function Base.show(io::IO, l::ImitationLoss) + (; aux_loss_maximizer, δ, Ω, α) = l + return print(io, "ImitationLoss($aux_loss_maximizer, $δ, $Ω, $α)") +end + +""" + ImitationLoss(; aux_loss_maximizer, δ, Ω, α=1.0) + +Explicit constructor with keyword arguments. +""" +function ImitationLoss(; aux_loss_maximizer, δ, Ω, α=1.0) + return ImitationLoss(aux_loss_maximizer, δ, Ω, float(α)) +end + +""" + get_y_true(t_true::Any) + +Retrieve `y_true` from `t_true`. + +This method should be implemented when using a custom data structure for `t_true` other than a `NamedTuple`. +""" +function get_y_true end + +""" + get_y_true(t_true::NamedTuple) + +Retrieve `y_true` from `t_true`. `t_true` must contain an `y_true` field. +""" +get_y_true(t_true::NamedTuple) = t_true.y_true + +function prediction_and_loss(l::ImitationLoss, θ::AbstractArray, t_true; kwargs...) + (; aux_loss_maximizer, δ, Ω, α) = l + y_true = get_y_true(t_true) + ŷ = aux_loss_maximizer(θ, t_true, α; kwargs...) + l = δ(ŷ, t_true) + α * (dot(θ, ŷ) - dot(θ, y_true)) + Ω(y_true) - Ω(ŷ) + return ŷ, l +end + +## Forward pass + +""" + (il::ImitationLoss)(θ, t_true; kwargs...) +""" +function (il::ImitationLoss)(θ::AbstractArray, t_true; kwargs...) + _, l = prediction_and_loss(il, θ, t_true; kwargs...) + return l +end + +## Backward pass + +function ChainRulesCore.rrule(il::ImitationLoss, θ::AbstractArray, t_true; kwargs...) + (; α) = il + y_true = get_y_true(t_true) + ŷ, l = prediction_and_loss(il, θ, t_true; kwargs...) + il_pullback(dl) = NoTangent(), dl .* α .* (ŷ .- y_true), NoTangent() + return l, il_pullback +end diff --git a/src/spo/spoplus_loss.jl b/src/imitation/spoplus_loss.jl similarity index 82% rename from src/spo/spoplus_loss.jl rename to src/imitation/spoplus_loss.jl index 40a4626..bc0d7f7 100644 --- a/src/spo/spoplus_loss.jl +++ b/src/imitation/spoplus_loss.jl @@ -1,15 +1,15 @@ """ - SPOPlusLoss{F} + SPOPlusLoss <: AbstractLossLayer Convex surrogate of the Smart "Predict-then-Optimize" loss. # Fields -- `maximizer::F`: linear maximizer function of the form `θ ⟼ ŷ(θ) = argmax θᵀy` -- `α::Float64`: convexification parameter +- `maximizer`: linear maximizer function of the form `θ -> ŷ(θ) = argmax θᵀy` +- `α::Float64`: convexification parameter, default = 2.0 Reference: """ -struct SPOPlusLoss{F} +struct SPOPlusLoss{F} <: AbstractLossLayer maximizer::F α::Float64 end @@ -19,10 +19,16 @@ function Base.show(io::IO, spol::SPOPlusLoss) return print(io, "SPOPlusLoss($maximizer, $α)") end +""" + SPOPlusLoss(maximizer; α=2.0) +""" SPOPlusLoss(maximizer; α=2.0) = SPOPlusLoss(maximizer, float(α)) ## Forward pass +""" + (spol::SPOPlusLoss)(θ, θ_true, y_true; kwargs...) +""" function (spol::SPOPlusLoss)( θ::AbstractArray, θ_true::AbstractArray, y_true::AbstractArray; kwargs... ) @@ -33,6 +39,9 @@ function (spol::SPOPlusLoss)( return l end +""" + (spol::SPOPlusLoss)(θ, θ_true; kwargs...) +""" function (spol::SPOPlusLoss)(θ::AbstractArray, θ_true::AbstractArray; kwargs...) y_true = spol.maximizer(θ_true; kwargs...) return spol(θ, θ_true, y_true) diff --git a/src/imitation/ssvm_loss.jl b/src/imitation/ssvm_loss.jl new file mode 100644 index 0000000..5149b42 --- /dev/null +++ b/src/imitation/ssvm_loss.jl @@ -0,0 +1,64 @@ +""" + StructuredSVMLoss <: AbstractLossLayer + +Loss associated with the Structured Support Vector Machine, defined by +``` +L(θ, y_true) = max_y {δ(y, y_true) + α θᵀ(y - y_true)} +``` + +Reference: (Chapter 6) + +# Fields + +- `aux_loss_maximizer::M`: function of `(θ, y_true, α)` that computes the argmax in the problem above +- `δ::L`: base loss function +- `α::Float64`: hyperparameter with a default value of 1.0 +""" +struct StructuredSVMLoss{M,L} <: AbstractLossLayer + aux_loss_maximizer::M + δ::L + α::Float64 +end + +""" + StructuredSVMLoss(; aux_loss_maximizer, δ, α=1.0) +""" +function StructuredSVMLoss(; aux_loss_maximizer, δ, α=1.0) + return StructuredSVMLoss(aux_loss_maximizer, δ, float(α)) +end + +function Base.show(io::IO, ssvml::StructuredSVMLoss) + (; aux_loss_maximizer, δ, α) = ssvml + return print(io, "StructuredSVMLoss($aux_loss_maximizer, $δ, $α)") +end + +function prediction_and_loss( + ssvml::StructuredSVMLoss, θ::AbstractArray, y_true::AbstractArray; kwargs... +) + (; aux_loss_maximizer, δ, α) = ssvml + ŷ = aux_loss_maximizer(θ, y_true, α; kwargs...) + l = δ(ŷ, y_true) + α * (dot(θ, ŷ) - dot(θ, y_true)) + return ŷ, l +end + +## Forward pass + +""" + (ssvml::StructuredSVMLoss)(θ, y_true; kwargs...) +""" +function (ssvml::StructuredSVMLoss)(θ::AbstractArray, y_true::AbstractArray; kwargs...) + _, l = prediction_and_loss(ssvml, θ, y_true; kwargs...) + return l +end + +## Backward pass + +function ChainRulesCore.rrule( + ssvml::StructuredSVMLoss, θ::AbstractArray, y_true::AbstractArray; kwargs... +) + (; α) = ssvml + ŷ, l = prediction_and_loss(ssvml, θ, y_true; kwargs...) + g = α .* (ŷ .- y_true) + ssvml_pullback(dl) = NoTangent(), dl * g, NoTangent() + return l, ssvml_pullback +end diff --git a/src/imitation/zero_one_loss.jl b/src/imitation/zero_one_loss.jl new file mode 100644 index 0000000..fae9364 --- /dev/null +++ b/src/imitation/zero_one_loss.jl @@ -0,0 +1,58 @@ +""" + zero_one_loss(y, y_true) + +0-1 loss for multiclass classification: `δ(y, y_true) = 0` if `y = y_true`, and `1` otherwise. +""" +function zero_one_loss(y::AbstractArray, y_true::AbstractArray) + return y == y_true ? zero(eltype(y)) : one(eltype(y)) +end + +""" + zero_one_loss_maximizer(y, y_true; α) + +For `δ = zero_one_loss`, compute +``` +argmax_y {δ(y, y_true) + α θᵀ(y - y_true)} +``` +""" +function zero_one_loss_maximizer( + θ::AbstractVector, + y_true::AbstractVector{R}, # TODO: does it work with arrays? + α, +) where {R<:Real} + i_true = findfirst(==(one(R)), y_true) + i_θ = argmax(θ) + y = zeros(R, size(y_true)) + if (i_true == i_θ) || (α * θ[i_true] > 1 + α * θ[i_θ]) + y[i_true] = one(R) + else + y[i_θ] = one(R) + end + return y +end + +""" + ZeroOneStructuredSVMLoss + +Implementation of the [`StructuredSVMLoss`](@ref) based on a 0-1 loss for multiclass classification. +""" +function ZeroOneStructuredSVMLoss(α=1) + return StructuredSVMLoss(; + aux_loss_maximizer=zero_one_loss_maximizer, δ=zero_one_loss, α=α + ) +end + +""" + ZeroOneStructuredSVMLoss(α) + +Implementation of the [`ImitationLoss`](@ref) based on a 0-1 loss for multiclass classification with no regularization. +""" +function ZeroOneImitationLoss(α=1) + return ImitationLoss(; + δ=(y, t_true) -> zero_one_loss(y, get_y_true(t_true)), + Ω=y -> 0, + α=α, + aux_loss_maximizer=(θ, t_true, α) -> + zero_one_loss_maximizer(θ, get_y_true(t_true), α), + ) +end diff --git a/src/imitation_loss/imitation_loss.jl b/src/imitation_loss/imitation_loss.jl deleted file mode 100644 index 6f3c55f..0000000 --- a/src/imitation_loss/imitation_loss.jl +++ /dev/null @@ -1,79 +0,0 @@ -""" - ImitationLoss{L,R,P} - -Generic imitation loss: max_y base_loss(y, t_true) + α θᵀ(y - y_true) - (Ω(y) - Ω(y_true))). - -When `base_loss = 0`, this loss is equivalent to a [`FenchelYoungLoss`](@ref). -When `Ω = 0`, this loss is equivalent to the [`StructuredSVMLoss`](@ref). - -# Fields -- `maximizer::P`: function that computes - argmax_y base_loss(y, t_true) + α θᵀ(y - y_true) - (Ω(y) - Ω(y_true)), takes (θ, y_true, kwargs...) - or (θ, t_true, kwargs...) as input -- `base_loss::L`: base loss, takes (y, t_true) as input -- `Ω::R`: regularization, takes y as input -- `α::Float64`: default value of 1.0 - -Note: by default, `t_true` is a named tuple with field `y_true`, - but can be any data structure for which the [`get_y_true`](@ref) method is implemented. -""" -struct ImitationLoss{P,L,R} - maximizer::P - base_loss::L - Ω::R - α::Float64 -end - -function Base.show(io::IO, l::ImitationLoss) - (; base_loss, Ω, maximizer, α) = l - return print(io, "ImitationLoss($maximizer, $base_loss, $Ω, $α)") -end - -""" - ImitationLoss(maximizer[; base_loss=(y,t_true)->0.0, Ω=y->0.0, α=1.0]) - -Shorter constructor with defaults. -""" -function ImitationLoss(maximizer; base_loss=(y, t_true) -> 0.0, Ω=y -> 0.0, α=1.0) - return ImitationLoss(maximizer, base_loss, Ω, α) -end - -""" - get_y_true(t_true::Any) - -Retrieve `y_true` from `t_true`. -This method should be implemented when using a custom data structure for `t_true` other than a `NamedTuple`. -""" -function get_y_true end - -""" - get_y_true(t_true::NamedTuple) - -Retrieve `y_true` from `t_true`. `t_true` must contain an `y_true` field. -""" -get_y_true(t_true::NamedTuple) = t_true.y_true - -function prediction_and_loss(l::ImitationLoss, θ::AbstractArray, t_true; kwargs...) - (; base_loss, Ω, maximizer, α) = l - y_true = get_y_true(t_true) - ŷ = maximizer(θ, t_true; kwargs...) - l = base_loss(ŷ, t_true) + α * (dot(θ, ŷ) - dot(θ, y_true)) + Ω(y_true) - Ω(ŷ) - return ŷ, l -end - -## Forward pass - -function (l::ImitationLoss)(θ::AbstractArray, t_true; kwargs...) - _, l = prediction_and_loss(l, θ, t_true; kwargs...) - return l -end - -## Backward pass - -function ChainRulesCore.rrule(l::ImitationLoss, θ::AbstractArray, t_true; kwargs...) - (; α) = l - y_true = get_y_true(t_true) - ŷ, l = prediction_and_loss(l, θ, t_true; kwargs...) - l_pullback(dl) = NoTangent(), dl .* α .* (ŷ .- y_true), NoTangent() - return l, l_pullback -end diff --git a/src/interface.jl b/src/interface.jl new file mode 100644 index 0000000..cb4ac75 --- /dev/null +++ b/src/interface.jl @@ -0,0 +1,50 @@ +""" + AbstractLayer + +Supertype for all the layers defined in InferOpt. + +All of these layers are callable, and differentiable with any ChainRules-compatible autodiff backend. + +# Interface +- `(layer::AbstractLayer)(args...; kwargs...)` +""" +abstract type AbstractLayer end + +## Optimization + +""" + AbstractOptimizationLayer <: AbstractLayer + +Supertype for all the optimization layers defined in InferOpt. + +# Interface +- `(layer::AbstractOptimizationLayer)(θ; kwargs...)` +- `compute_probability_distribution(layer, θ; kwargs...)` (only if the layer is probabilistic) +""" +abstract type AbstractOptimizationLayer <: AbstractLayer end + +## Losses + +""" + AbstractLossLayer <: AbstractLayer + +Supertype for all the loss layers defined in InferOpt. + +Depending on the precise loss, the arguments to the layer might vary + +# Interface +- `(layer::AbstractLossLayer)(θ; kwargs...)` or +- `(layer::AbstractLossLayer)(θ, θ_true; kwargs...)` or +- `(layer::AbstractLossLayer)(θ, y_true; kwargs...)` or +- `(layer::AbstractLossLayer)(θ, (; θ_true, y_true); kwargs...)` +""" +abstract type AbstractLossLayer <: AbstractLayer end + +## Checking specific properties + +""" + compute_probability_distribution(layer, θ; kwargs...) + +Apply a probabilistic optimization layer to an objective direction `θ` in order to generate a [`FixedAtomsProbabilityDistribution`](@ref) on the vertices of a polytope. +""" +function compute_probability_distribution end diff --git a/src/perturbed/abstract_perturbed.jl b/src/perturbed/abstract_perturbed.jl index 6d58604..3f9180f 100644 --- a/src/perturbed/abstract_perturbed.jl +++ b/src/perturbed/abstract_perturbed.jl @@ -1,20 +1,16 @@ """ - AbstractPerturbed{B} + AbstractPerturbed{parallel} <: AbstractOptimizationLayer Differentiable perturbation of a black box optimizer. -The parameter `parallel` is a boolean value, equal to true if the perturbations are run in parallel. - -# Applicable functions -- [`compute_probability_distribution(perturbed::AbstractPerturbed, θ)`](@ref) -- `(perturbed::AbstractPerturbed)(θ)` +The parameter `parallel` is a boolean value, equal to true if the perturbations are run in parallel. -# Available subtypes +# Available implementations - [`PerturbedAdditive`](@ref) - [`PerturbedMultiplicative`](@ref) -These subtypes share the following fields: +These two subtypes share the following fields: - `maximizer`: black box optimizer - `ε`: magnitude of the perturbation @@ -22,7 +18,7 @@ These subtypes share the following fields: - `rng::AbstractRNG`: random number generator - `seed::Union{Nothing,Int}`: random seed """ -abstract type AbstractPerturbed{parallel} end +abstract type AbstractPerturbed{parallel} <: AbstractOptimizationLayer end """ sample_perturbations(perturbed::AbstractPerturbed, θ) @@ -67,9 +63,11 @@ function compute_probability_distribution( end """ - compute_probability_distribution(perturbed::AbstractPerturbed, θ) + compute_probability_distribution(perturbed::AbstractPerturbed, θ; kwargs...) Turn random perturbations of `θ` into a distribution on polytope vertices. + +Keyword arguments are passed to the underlying linear maximizer. """ function compute_probability_distribution( perturbed::AbstractPerturbed, θ::AbstractArray; kwargs... @@ -79,9 +77,9 @@ function compute_probability_distribution( end """ - (perturbed::AbstractPerturbed)(θ) + (perturbed::AbstractPerturbed)(θ; kwargs...) -Apply `compute_probability_distribution(perturbed, θ)` and return the expectation. +Apply `compute_probability_distribution(perturbed, θ; kwargs...)` and return the expectation. """ function (perturbed::AbstractPerturbed)(θ::AbstractArray; kwargs...) probadist = compute_probability_distribution(perturbed, θ; kwargs...) diff --git a/src/perturbed/additive.jl b/src/perturbed/additive.jl index 8a486c4..90187dc 100644 --- a/src/perturbed/additive.jl +++ b/src/perturbed/additive.jl @@ -1,11 +1,11 @@ """ - PerturbedAdditive{F} + PerturbedAdditive <: AbstractPerturbed -Differentiable normal perturbation of a black-box optimizer of type `F`: the input undergoes `θ -> θ + εZ` where `Z ∼ N(0, I)`. - -See also: [`AbstractPerturbed`](@ref). +Differentiable normal perturbation of a black-box maximizer: the input undergoes `θ -> θ + εZ` where `Z ∼ N(0, I)`. Reference: + +See [`AbstractPerturbed`](@ref) for more details. """ struct PerturbedAdditive{F,R<:AbstractRNG,S<:Union{Nothing,Int},parallel} <: AbstractPerturbed{parallel} @@ -32,8 +32,6 @@ end """ PerturbedAdditive(maximizer[; ε=1.0, nb_samples=1]) - -Shorter constructor with defaults. """ function PerturbedAdditive( maximizer::F; diff --git a/src/perturbed/multiplicative.jl b/src/perturbed/multiplicative.jl index eb2c9ae..15c9f14 100644 --- a/src/perturbed/multiplicative.jl +++ b/src/perturbed/multiplicative.jl @@ -1,11 +1,11 @@ """ - PerturbedMultiplicative{F} + PerturbedMultiplicative <: AbstractPerturbed -Differentiable log-normal perturbation of a black-box optimizer of type `F`: the input undergoes `θ -> θ ⊙ exp[εZ - ε²/2]` where `Z ∼ N(0, I)`. +Differentiable log-normal perturbation of a black-box maximizer: the input undergoes `θ -> θ ⊙ exp[εZ - ε²/2]` where `Z ∼ N(0, I)`. -See also: [`AbstractPerturbed`](@ref). +Reference: -Reference: preprint coming soon. +See [`AbstractPerturbed`](@ref) for more details. """ struct PerturbedMultiplicative{F,R<:AbstractRNG,S<:Union{Nothing,Int},parallel} <: AbstractPerturbed{parallel} @@ -32,8 +32,6 @@ end """ PerturbedMultiplicative(maximizer[; ε=1.0, nb_samples=1]) - -Shorter constructor with defaults. """ function PerturbedMultiplicative( maximizer::F; diff --git a/src/plus_identity/plus_identity.jl b/src/plus_identity/plus_identity.jl deleted file mode 100644 index 6cbf850..0000000 --- a/src/plus_identity/plus_identity.jl +++ /dev/null @@ -1,29 +0,0 @@ -""" - PlusIdentity{F} - -Naive relaxation of a black-box optimizer where constraints are simply forgotten. - -Consider (centering and) normalizing `θ` before applying it. - -# Fields -- `maximizer::F`: underlying argmax function - -Reference: -""" -struct PlusIdentity{F} - maximizer::F -end - -function Base.show(io::IO, plusid::PlusIdentity) - return print(io, "PlusIdentity($(plusid.maximizer)") -end - -function (plusid::PlusIdentity)(θ::AbstractArray; kwargs...) - return plusid.maximizer(θ; kwargs...) -end - -function ChainRulesCore.rrule(plusid::PlusIdentity, θ::AbstractArray; kwargs...) - y = plusid.maximizer(θ; kwargs...) - plusid_pullback(dy) = NoTangent(), dy - return y, plusid_pullback -end diff --git a/src/regularized/abstract_regularized.jl b/src/regularized/abstract_regularized.jl new file mode 100644 index 0000000..9ca8405 --- /dev/null +++ b/src/regularized/abstract_regularized.jl @@ -0,0 +1,27 @@ +""" + AbstractRegularized{parallel} <: AbstractOptimizationLayer + +Convex regularization perturbation of a black box optimizer +``` +ŷ(θ) = argmax_{y ∈ C} {θᵀy - Ω(y)} +``` + +# Interface + +- `(regularized::AbstractRegularized)(θ; kwargs...)`: return `ŷ(θ)` +- `compute_regularization(regularized, y)`: return `Ω(y)` + +# Available implementations + +- [`SoftArgmax`](@ref) +- [`SparseArgmax`](@ref) +- [`RegularizedFrankWolfe`](@ref) +""" +abstract type AbstractRegularized <: AbstractOptimizationLayer end + +""" + compute_regularization(regularized, y) + +Return the convex penalty `Ω(y)` associated with an `AbstractRegularized` layer. +""" +function compute_regularization end diff --git a/src/regularized/isregularized.jl b/src/regularized/isregularized.jl deleted file mode 100644 index 6f16184..0000000 --- a/src/regularized/isregularized.jl +++ /dev/null @@ -1,23 +0,0 @@ -""" - IsRegularized{P} - -Trait-based interface for regularized prediction functions `ŷ(θ) = argmax {θᵀy - Ω(y)}`. - -For `predictor::P` to comply with this interface, the following methods must exist: -- `(predictor)(θ)` -- `compute_regularization(predictor, y)` - -# Available implementations -- [`one_hot_argmax`](@ref) -- [`soft_argmax`](@ref) -- [`sparse_argmax`](@ref) -- [`RegularizedGeneric`](@ref) -""" -@traitdef IsRegularized{P} - -""" - compute_regularization(predictor::P, y) - -Compute the convex regularization function `Ω(y)`. -""" -function compute_regularization end diff --git a/src/regularized/regularized_frank_wolfe.jl b/src/regularized/regularized_frank_wolfe.jl new file mode 100644 index 0000000..fa7923c --- /dev/null +++ b/src/regularized/regularized_frank_wolfe.jl @@ -0,0 +1,65 @@ +""" + RegularizedFrankWolfe <: AbstractRegularized + +Regularized optimization layer which relies on the Frank-Wolfe algorithm to define a probability distribution while solving +``` +ŷ(θ) = argmax_{y ∈ C} {θᵀy - Ω(y)} +``` + +!!! warning "Warning" + Since this is a conditional dependency, you need to have loaded the package DifferentiableFrankWolfe.jl before using `RegularizedFrankWolfe`. + +# Fields + +- `linear_maximizer`: linear maximization oracle `θ -> argmax_{x ∈ C} θᵀx`, implicitly defines the polytope `C` +- `Ω`: regularization function `Ω(y)` +- `Ω_grad`: gradient function of the regularization function `∇Ω(y)` +- `frank_wolfe_kwargs`: named tuple of keyword arguments passed to the Frank-Wolfe algorithm + +# Frank-Wolfe parameters + +Some values you can tune: +- `epsilon::Float64`: precision target +- `max_iteration::Integer`: max number of iterations +- `timeout::Float64`: max runtime in seconds +- `lazy::Bool`: caching strategy +- `away_steps::Bool`: avoid zig-zagging +- `line_search::FrankWolfe.LineSearchMethod`: step size selection +- `verbose::Bool`: console output + +See the documentation of FrankWolfe.jl for details. +""" +struct RegularizedFrankWolfe{M,RF,RG,FWK} <: AbstractRegularized + linear_maximizer::M + Ω::RF + Ω_grad::RG + frank_wolfe_kwargs::FWK +end + +""" + RegularizedFrankWolfe(linear_maximizer; Ω, Ω_grad, frank_wolfe_kwargs=(;)) +""" +function RegularizedFrankWolfe(linear_maximizer; Ω, Ω_grad, frank_wolfe_kwargs=NamedTuple()) + return RegularizedFrankWolfe(linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs) +end + +function Base.show(io::IO, regularized::RegularizedFrankWolfe) + (; linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs) = regularized + return print( + io, "RegularizedFrankWolfe($linear_maximizer, $Ω, $Ω_grad, $frank_wolfe_kwargs)" + ) +end + +function compute_regularization(regularized::RegularizedFrankWolfe, y::AbstractArray) + return regularized.Ω(y) +end + +""" + (regularized::RegularizedFrankWolfe)(θ; kwargs...) + +Apply `compute_probability_distribution(regularized, θ; kwargs...)` and return the expectation. +""" +function (regularized::RegularizedFrankWolfe)(θ::AbstractArray; kwargs...) + probadist = compute_probability_distribution(regularized, θ; kwargs...) + return compute_expectation(probadist) +end diff --git a/src/regularized/regularized_generic.jl b/src/regularized/regularized_generic.jl deleted file mode 100644 index 650e9aa..0000000 --- a/src/regularized/regularized_generic.jl +++ /dev/null @@ -1,51 +0,0 @@ -""" - RegularizedGeneric{M,RF,RG} - -Differentiable regularized prediction function `ŷ(θ) = argmax_{y ∈ C} {θᵀy - Ω(y)}`. - -Relies on the Frank-Wolfe algorithm to minimize a concave objective on a polytope. - -!!! warning "Warning" - Since this is a conditional dependency, you need to run `import DifferentiableFrankWolfe` before using `RegularizedGeneric`. - -# Fields -- `maximizer::M`: linear maximization oracle `θ -> argmax_{x ∈ C} θᵀx`, implicitly defines the polytope `C` -- `Ω::RF`: regularization function `Ω(y)` -- `Ω_grad::RG`: gradient of the regularization function `∇Ω(y)` -- `frank_wolfe_kwargs::FWK`: keyword arguments passed to the Frank-Wolfe algorithm - -# Applicable methods - -- [`compute_probability_distribution(regularized::RegularizedGeneric, θ; kwargs...)`](@ref) -- `(regularized::RegularizedGeneric)(θ; kwargs...)` - -# Frank-Wolfe parameters - -Some values you can tune: -- `epsilon::Float64`: precision target -- `max_iteration::Integer`: max number of iterations -- `timeout::Float64`: max runtime in seconds -- `lazy::Bool`: caching strategy -- `away_steps::Bool`: avoid zig-zagging -- `line_search::FrankWolfe.LineSearchMethod`: step size selection -- `verbose::Bool`: console output - -See the documentation of FrankWolfe.jl for details. -""" -struct RegularizedGeneric{M,RF,RG,FWK} - maximizer::M - Ω::RF - Ω_grad::RG - frank_wolfe_kwargs::FWK -end - -function Base.show(io::IO, regularized::RegularizedGeneric) - (; maximizer, Ω, Ω_grad) = regularized - return print(io, "RegularizedGeneric($maximizer, $Ω, $Ω_grad)") -end - -@traitimpl IsRegularized{RegularizedGeneric} - -function compute_regularization(regularized::RegularizedGeneric, y::AbstractArray) - return regularized.Ω(y) -end diff --git a/src/regularized/soft_argmax.jl b/src/regularized/soft_argmax.jl index 21864e0..1a3f2a0 100644 --- a/src/regularized/soft_argmax.jl +++ b/src/regularized/soft_argmax.jl @@ -1,17 +1,20 @@ """ - soft_argmax(z) + SoftArgmax <: Regularized Soft argmax activation function `s(z) = (e^zᵢ / ∑ e^zⱼ)ᵢ`. Corresponds to regularized prediction on the probability simplex with entropic penalty. """ -function soft_argmax(z::AbstractVector; kwargs...) - s = exp.(z) / sum(exp, z) - return s -end +struct SoftArgmax <: AbstractRegularized end + +(::SoftArgmax)(z) = soft_argmax(z) +compute_regularization(::SoftArgmax, y) = soft_argmax_regularization(y) -@traitimpl IsRegularized{typeof(soft_argmax)} +function soft_argmax(z::AbstractVector) + s = exp.(z) + return s ./ sum(s) +end -function compute_regularization(::typeof(soft_argmax), y::AbstractVector{R}) where {R<:Real} +function soft_argmax_regularization(y::AbstractVector) return isprobadist(y) ? negative_shannon_entropy(y) : typemax(R) end diff --git a/src/regularized/sparse_argmax.jl b/src/regularized/sparse_argmax.jl index 56f5dad..56919f2 100644 --- a/src/regularized/sparse_argmax.jl +++ b/src/regularized/sparse_argmax.jl @@ -1,27 +1,28 @@ """ - sparse_argmax(z) + SparseArgmax <: AbstractRegularized Compute the Euclidean projection of the vector `z` onto the probability simplex. Corresponds to regularized prediction on the probability simplex with square norm penalty. """ +struct SparseArgmax <: AbstractRegularized end + +(::SparseArgmax)(z) = sparse_argmax(z) +compute_regularization(::SparseArgmax, y) = sparse_argmax_regularization(y) + function sparse_argmax(z::AbstractVector; kwargs...) p, _ = simplex_projection_and_support(z) return p end -@traitimpl IsRegularized{typeof(sparse_argmax)} - -function compute_regularization( - ::typeof(sparse_argmax), y::AbstractVector{R} -) where {R<:Real} +function sparse_argmax_regularization(y::AbstractVector) return isprobadist(y) ? half_square_norm(y) : typemax(R) end """ simplex_projection_and_support(z) -Compute the Euclidean projection `p` of `z` on the probability simplex (also called [`sparse_argmax`](@ref)), and the indicators `s` of its support. +Compute the Euclidean projection `p` of `z` on the probability simplex (also called `sparse_argmax`), and the indicators `s` of its support. Reference: . """ diff --git a/src/simple/identity.jl b/src/simple/identity.jl new file mode 100644 index 0000000..3f68218 --- /dev/null +++ b/src/simple/identity.jl @@ -0,0 +1,29 @@ +""" + IdentityRelaxation <: AbstractOptimizationLayer + +Naive relaxation of a black-box optimizer where constraints are simply forgotten. + +Consider (centering and) normalizing `θ` before applying it. + +# Fields +- `maximizer`: underlying argmax function + +Reference: +""" +struct IdentityRelaxation{F} <: AbstractOptimizationLayer + maximizer::F +end + +function Base.show(io::IO, id::IdentityRelaxation) + return print(io, "IdentityRelaxation($(id.maximizer)") +end + +function (id::IdentityRelaxation)(θ::AbstractArray; kwargs...) + return id.maximizer(θ; kwargs...) +end + +function ChainRulesCore.rrule(id::IdentityRelaxation, θ::AbstractArray; kwargs...) + y = id.maximizer(θ; kwargs...) + id_pullback(dy) = NoTangent(), dy + return y, id_pullback +end diff --git a/src/interpolation/interpolation.jl b/src/simple/interpolation.jl similarity index 87% rename from src/interpolation/interpolation.jl rename to src/simple/interpolation.jl index 3c5d921..4897372 100644 --- a/src/interpolation/interpolation.jl +++ b/src/simple/interpolation.jl @@ -1,15 +1,15 @@ """ - Interpolation{F} + Interpolation <: AbstractOptimizationLayer Piecewise-linear interpolation of a black-box optimizer. # Fields -- `maximizer::F`: underlying argmax function +- `maximizer`: underlying argmax function - `λ::Float64`: smoothing parameter (smaller = more faithful approximation, larger = more informative gradients) Reference: """ -struct Interpolation{F} +struct Interpolation{F} <: AbstractOptimizationLayer maximizer::F λ::Float64 end diff --git a/src/ssvm/isbaseloss.jl b/src/ssvm/isbaseloss.jl deleted file mode 100644 index 3186dc7..0000000 --- a/src/ssvm/isbaseloss.jl +++ /dev/null @@ -1,20 +0,0 @@ -""" - IsBaseLoss{L} - -Trait-based interface for loss functions `δ(y, y_true)`, which are the base of the more complex `StructuredSVMLoss`. - -For `δ::L` to comply with this interface, the following methods must exist: -- `(δ)(y, y_true)` -- [`compute_maximizer(δ, θ, α, y_true)`](@ref) - -# Available implementations -- [`ZeroOneBaseLoss`](@ref) -""" -@traitdef IsBaseLoss{L} - -""" - compute_maximizer(δ, θ, α, y_true) - -Compute `argmax_y {δ(y, y_true) + α θᵀ(y - y_true)}` to deduce the gradient of a `StructuredSVMLoss`. -""" -function compute_maximizer end diff --git a/src/ssvm/ssvm_loss.jl b/src/ssvm/ssvm_loss.jl deleted file mode 100644 index 87a9d3f..0000000 --- a/src/ssvm/ssvm_loss.jl +++ /dev/null @@ -1,54 +0,0 @@ -""" - StructuredSVMLoss{L} - -Loss associated with the Structured Support Vector Machine. - -`SSVM(θ, y_true) = max_y {base_loss(y, y_true) + α θᵀ(y - y_true)}` - -# Fields -- `base_loss::L`: of the `IsBaseLoss` trait -- `α::Float64` - -Reference: (Chapter 6) -""" -struct StructuredSVMLoss{L} - base_loss::L - α::Float64 -end - -function Base.show(io::IO, ssvml::StructuredSVMLoss) - (; base_loss, α) = ssvml - return print(io, "StructuredSVMLoss($base_loss, $α)") -end - -StructuredSVMLoss(base_loss; α=1.0) = StructuredSVMLoss(base_loss, float(α)) - -@traitfn function prediction_and_loss( - ssvml::StructuredSVMLoss{L}, θ::AbstractArray, y_true::AbstractArray; kwargs... -) where {L; IsBaseLoss{L}} - (; base_loss, α) = ssvml - ŷ = compute_maximizer(base_loss, θ, α, y_true; kwargs...) - l = base_loss(ŷ, y_true) + α * (dot(θ, ŷ) - dot(θ, y_true)) - return ŷ, l -end - -## Forward pass - -@traitfn function (ssvml::StructuredSVMLoss{L})( - θ::AbstractArray, y_true::AbstractArray; kwargs... -) where {L; IsBaseLoss{L}} - _, l = prediction_and_loss(ssvml, θ, y_true; kwargs...) - return l -end - -## Backward pass - -@traitfn function ChainRulesCore.rrule( - ssvml::StructuredSVMLoss{L}, θ::AbstractArray, y_true::AbstractArray; kwargs... -) where {L; IsBaseLoss{L}} - (; α) = ssvml - ŷ, l = prediction_and_loss(ssvml, θ, y_true; kwargs...) - g = α .* (ŷ .- y_true) - ssvml_pullback(dl) = NoTangent(), dl * g, NoTangent() - return l, ssvml_pullback -end diff --git a/src/ssvm/zeroone_baseloss.jl b/src/ssvm/zeroone_baseloss.jl deleted file mode 100644 index d0f5345..0000000 --- a/src/ssvm/zeroone_baseloss.jl +++ /dev/null @@ -1,26 +0,0 @@ -""" - ZeroOneBaseLoss - -0-1 loss for multiclass classification: `δ(y, y_true) = 0` if `y = y_true`, and `1` otherwise. -""" -struct ZeroOneBaseLoss end - -@traitimpl IsBaseLoss{ZeroOneBaseLoss} - -function (::ZeroOneBaseLoss)(y::AbstractArray, y_true::AbstractArray) - return y == y_true ? zero(eltype(y)) : one(eltype(y)) -end - -function compute_maximizer( - ::ZeroOneBaseLoss, θ::AbstractVector, α::Real, y_true::AbstractVector{R} -) where {R<:Real} - i_true = findfirst(==(one(R)), y_true) - i_θ = argmax(θ) - y = zeros(R, size(y_true)) - if (i_true == i_θ) || (α * θ[i_true] > 1 + α * θ[i_θ]) - y[i_true] = one(R) - else - y[i_θ] = one(R) - end - return y -end diff --git a/src/utils/probability_distribution.jl b/src/utils/probability_distribution.jl index 904283a..0fdc21a 100644 --- a/src/utils/probability_distribution.jl +++ b/src/utils/probability_distribution.jl @@ -83,43 +83,3 @@ function ChainRulesCore.rrule( end return e, expectation_pullback end - -""" - compute_probability_distribution(layer, θ) - -Apply a probabilistic layer (regularized or perturbed) to an objective direction `θ` in order to generate a [`FixedAtomsProbabilityDistribution`](@ref) on the vertices of a polytope. - -The following layer types are supported: -- [`AbstractPerturbed`](@ref) -- [`RegularizedGeneric`](@ref) -""" -function compute_probability_distribution end - -""" - compress_distribution!(probadist[; atol]) - -Remove duplicated atoms in `probadist` (up to a tolerance on equality). - -This function can break probabilistic layers if used during training. It is only meant for analyzing outputs. -""" -function compress_distribution!( - probadist::FixedAtomsProbabilityDistribution{A,W}; atol=0 -) where {A,W} - (; atoms, weights) = probadist - to_delete = Int[] - for i in length(probadist):-1:1 - ai = atoms[i] - for j in 1:(i - 1) - aj = atoms[j] - if isapprox(ai, aj; atol=atol) - weights[j] += weights[i] - push!(to_delete, i) - break - end - end - end - sort!(to_delete) - deleteat!(atoms, to_delete) - deleteat!(weights, to_delete) - return probadist -end diff --git a/src/utils/pushforward.jl b/src/utils/pushforward.jl index a4eb765..ddd915b 100644 --- a/src/utils/pushforward.jl +++ b/src/utils/pushforward.jl @@ -1,53 +1,53 @@ """ - Pushforward{L,G} + Pushforward <: AbstractLayer -Differentiable pushforward of a probabilistic `layer` with an arbitrary function `post_processing`. +Differentiable pushforward of a probabilistic optimization layer with an arbitrary function post-processing function. `Pushforward` can be used for direct regret minimization (aka learning by experience) when the post-processing returns a cost. # Fields -- `layer::L`: anything that implements `compute_probability_distribution(layer, θ; kwargs...)` -- `post_processing::P`: callable +- `optimization_layer::AbstractOptimizationLayer`: probabilistic optimization layer +- `post_processing`: callable See also: [`FixedAtomsProbabilityDistribution`](@ref). """ -struct Pushforward{L,P} - layer::L +struct Pushforward{O<:AbstractOptimizationLayer,P} <: AbstractLayer + optimization_layer::O post_processing::P end function Base.show(io::IO, pushforward::Pushforward) - (; layer, post_processing) = pushforward - return print(io, "Pushforward($layer, $post_processing)") + (; optimization_layer, post_processing) = pushforward + return print(io, "Pushforward($optimization_layer, $post_processing)") end """ compute_probability_distribution(pushforward, θ) -Output the distribution of `pushforward.post_processing(X)`, where `X` follows the distribution defined by `pushforward.layer` applied to `θ`. +Output the distribution of `pushforward.post_processing(X)`, where `X` follows the distribution defined by `pushforward.optimization_layer` applied to `θ`. This function is not differentiable if `pushforward.post_processing` isn't. See also: [`apply_on_atoms`](@ref). """ function compute_probability_distribution(pushforward::Pushforward, θ; kwargs...) - (; layer, post_processing) = pushforward - probadist = compute_probability_distribution(layer, θ; kwargs...) + (; optimization_layer, post_processing) = pushforward + probadist = compute_probability_distribution(optimization_layer, θ; kwargs...) post_processed_probadist = apply_on_atoms(post_processing, probadist; kwargs...) return post_processed_probadist end """ - (pushforward::Pushforward)(θ) + (pushforward::Pushforward)(θ; kwargs...) -Output the expectation of `pushforward.post_processing(X)`, where `X` follows the distribution defined by `pushforward.layer` applied to `θ`. +Output the expectation of `pushforward.post_processing(X)`, where `X` follows the distribution defined by `pushforward.optimization_layer` applied to `θ`. Unlike [`compute_probability_distribution(pushforward, θ)`](@ref), this function is differentiable, even if `pushforward.post_processing` isn't. See also: [`compute_expectation`](@ref). """ function (pushforward::Pushforward)(θ::AbstractArray; kwargs...) - (; layer, post_processing) = pushforward - probadist = compute_probability_distribution(layer, θ; kwargs...) + (; optimization_layer, post_processing) = pushforward + probadist = compute_probability_distribution(optimization_layer, θ; kwargs...) return compute_expectation(probadist, post_processing; kwargs...) end diff --git a/src/regularized/regularized_utils.jl b/src/utils/some_functions.jl similarity index 100% rename from src/regularized/regularized_utils.jl rename to src/utils/some_functions.jl diff --git a/test/argmax.jl b/test/argmax.jl index 977fbec..61c8ec8 100644 --- a/test/argmax.jl +++ b/test/argmax.jl @@ -38,12 +38,12 @@ end instance_dim=5, true_maximizer=one_hot_argmax, maximizer=identity, - loss=StructuredSVMLoss(ZeroOneBaseLoss()), + loss=InferOpt.ZeroOneStructuredSVMLoss(), error_function=hamming_distance, ) end -@testitem "Argmax - imit - MSE sparse argmax" default_imports = false begin +@testitem "Argmax - imit - MSE SparseArgmax" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -52,13 +52,13 @@ end PipelineLossImitation; instance_dim=5, true_maximizer=one_hot_argmax, - maximizer=sparse_argmax, + maximizer=SparseArgmax(), loss=mse, error_function=hamming_distance, ) end -@testitem "Argmax - imit - MSE soft argmax" default_imports = false begin +@testitem "Argmax - imit - MSE SoftArgmax" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -67,7 +67,7 @@ end PipelineLossImitation; instance_dim=5, true_maximizer=one_hot_argmax, - maximizer=soft_argmax, + maximizer=SoftArgmax(), loss=mse, error_function=hamming_distance, ) @@ -103,7 +103,7 @@ end ) end -@testitem "Argmax - imit - MSE RegularizedGeneric" default_imports = false begin +@testitem "Argmax - imit - MSE RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -112,18 +112,18 @@ end PipelineLossImitation; instance_dim=5, true_maximizer=one_hot_argmax, - maximizer=RegularizedGeneric( - one_hot_argmax, - half_square_norm, - identity, - (; max_iteration=10, line_search=FrankWolfe.Agnostic()), + maximizer=RegularizedFrankWolfe( + one_hot_argmax; + Ω=half_square_norm, + Ω_grad=identity, + frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()), ), loss=mse, error_function=hamming_distance, ) end -@testitem "Argmax - imit - FYL sparse argmax" default_imports = false begin +@testitem "Argmax - imit - FYL SparseArgmax" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -133,12 +133,12 @@ end instance_dim=5, true_maximizer=one_hot_argmax, maximizer=identity, - loss=FenchelYoungLoss(sparse_argmax), + loss=FenchelYoungLoss(SparseArgmax()), error_function=hamming_distance, ) end -@testitem "Argmax - imit - FYL soft argmax" default_imports = false begin +@testitem "Argmax - imit - FYL SoftArgmax" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -148,7 +148,7 @@ end instance_dim=5, true_maximizer=one_hot_argmax, maximizer=identity, - loss=FenchelYoungLoss(soft_argmax), + loss=FenchelYoungLoss(SoftArgmax()), error_function=hamming_distance, ) end @@ -183,7 +183,7 @@ end ) end -@testitem "Argmax - imit - FYL RegularizedGeneric" default_imports = false begin +@testitem "Argmax - imit - FYL RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -194,11 +194,11 @@ end true_maximizer=one_hot_argmax, maximizer=identity, loss=FenchelYoungLoss( - RegularizedGeneric( - one_hot_argmax, - half_square_norm, - identity, - (; max_iteration=10, line_search=FrankWolfe.Agnostic()), + RegularizedFrankWolfe( + one_hot_argmax; + Ω=half_square_norm, + Ω_grad=identity, + frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()), ), ), error_function=hamming_distance, @@ -245,7 +245,7 @@ end ) end -@testitem "Argmax - exp - Pushforward RegularizedGeneric" default_imports = false begin +@testitem "Argmax - exp - Pushforward RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, LinearAlgebra, Random @@ -259,11 +259,11 @@ end true_maximizer=one_hot_argmax, maximizer=identity, loss=Pushforward( - RegularizedGeneric( - one_hot_argmax, - half_square_norm, - identity, - (; max_iteration=10, line_search=FrankWolfe.Agnostic()), + RegularizedFrankWolfe( + one_hot_argmax; + Ω=half_square_norm, + Ω_grad=identity, + frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()), ), cost, ), diff --git a/test/code.jl b/test/code.jl index e14b76b..33899a3 100644 --- a/test/code.jl +++ b/test/code.jl @@ -1,6 +1,8 @@ @testitem "Quality (Aqua.jl)" begin using Aqua + using StatsBase Aqua.test_all(InferOpt; ambiguities=false) + Aqua.test_ambiguities(InferOpt; exclude=[StatsBase.TestStat]) end @testitem "Correctness (JET.jl)" begin diff --git a/test/imitation_loss.jl b/test/imitation_loss.jl index ed60aa4..df70e82 100644 --- a/test/imitation_loss.jl +++ b/test/imitation_loss.jl @@ -5,19 +5,13 @@ true_encoder = encoder_factory() - ssvm_base_loss = ZeroOneBaseLoss() - Random.seed!(67) perf = test_pipeline!( PipelineLossImitationLoss; instance_dim=5, true_maximizer=one_hot_argmax, maximizer=identity, - loss=ImitationLoss( - (θ, t_true) -> - InferOpt.compute_maximizer(ssvm_base_loss, θ, 1.0, get_y_true(t_true)); - base_loss=(y, t_true) -> ssvm_base_loss(y, t_true.y_true), - ), + loss=InferOpt.ZeroOneImitationLoss(), error_function=hamming_distance, true_encoder, ) @@ -28,7 +22,7 @@ instance_dim=5, true_maximizer=one_hot_argmax, maximizer=identity, - loss=StructuredSVMLoss(ZeroOneBaseLoss()), + loss=InferOpt.ZeroOneStructuredSVMLoss(), error_function=hamming_distance, true_encoder, verbose=false, @@ -39,7 +33,7 @@ @test all(isapprox.(perf.test_losses, benchmark_perf.test_losses, rtol=0.001)) end -@testitem "ImitationLoss vs FYL sparsemax" default_imports = false begin +@testitem "ImitationLoss vs FYL SparseMax" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random, Test Random.seed!(63) @@ -52,7 +46,11 @@ end instance_dim=5, true_maximizer=one_hot_argmax, maximizer=identity, - loss=ImitationLoss((θ, t_true) -> sparse_argmax(θ); Ω=half_square_norm), + loss=ImitationLoss(; + δ=(y, t_true) -> 0, + Ω=y -> half_square_norm(y), + aux_loss_maximizer=(θ, t_true, α) -> sparse_argmax(θ), + ), error_function=hamming_distance, true_encoder, ) @@ -63,7 +61,7 @@ end instance_dim=5, true_maximizer=one_hot_argmax, maximizer=identity, - loss=FenchelYoungLoss(sparse_argmax), + loss=FenchelYoungLoss(SparseArgmax()), error_function=hamming_distance, true_encoder, verbose=false, @@ -74,7 +72,7 @@ end @test all(isapprox.(perf.test_losses, benchmark_perf.test_losses, rtol=0.001)) end -@testitem "ImitationLoss vs FYL softmax" default_imports = false begin +@testitem "ImitationLoss vs FYL SoftMax" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random, Test Random.seed!(63) @@ -87,7 +85,11 @@ end instance_dim=5, true_maximizer=one_hot_argmax, maximizer=identity, - loss=ImitationLoss((θ, t_true) -> soft_argmax(θ); Ω=negative_shannon_entropy), + loss=ImitationLoss(; + δ=(y, t_true) -> 0, + Ω=y -> negative_shannon_entropy(y), + aux_loss_maximizer=(θ, t_true, α) -> soft_argmax(θ), + ), error_function=hamming_distance, true_encoder, ) @@ -98,7 +100,7 @@ end instance_dim=5, true_maximizer=one_hot_argmax, maximizer=identity, - loss=FenchelYoungLoss(soft_argmax), + loss=FenchelYoungLoss(SoftArgmax()), error_function=hamming_distance, true_encoder, verbose=false, @@ -116,9 +118,9 @@ end true_encoder = encoder_factory() - function spo_predictor(θ, t_true; kwargs...) + function spo_predictor(θ, t_true, α; kwargs...) (; θ_true) = t_true - θ_α = θ - θ_true + θ_α = α * θ - θ_true y_α = one_hot_argmax(θ_α; kwargs...) return y_α end @@ -134,7 +136,9 @@ end instance_dim=5, true_maximizer=one_hot_argmax, maximizer=identity, - loss=ImitationLoss(spo_predictor; base_loss=spo_base_loss), + loss=ImitationLoss(; + δ=spo_base_loss, Ω=y -> 0, α=1, aux_loss_maximizer=spo_predictor + ), error_function=hamming_distance, true_encoder, ) @@ -163,9 +167,9 @@ end true_encoder = encoder_factory() - function spo_predictor(θ, t_true; kwargs...) + function spo_predictor(θ, t_true, α; kwargs...) (; θ_true) = t_true - θ_α = 2 .* θ - θ_true + θ_α = 2 * θ - θ_true y_α = one_hot_argmax(θ_α; kwargs...) return y_α end @@ -181,7 +185,9 @@ end instance_dim=5, true_maximizer=one_hot_argmax, maximizer=identity, - loss=ImitationLoss(spo_predictor; base_loss=spo_base_loss, α=2.0), + loss=ImitationLoss(; + δ=spo_base_loss, Ω=y -> 0, α=2, aux_loss_maximizer=spo_predictor + ), error_function=hamming_distance, true_encoder, ) diff --git a/test/paths.jl b/test/paths.jl index 936d09e..865355d 100644 --- a/test/paths.jl +++ b/test/paths.jl @@ -28,7 +28,7 @@ end ) end -@testitem "Paths - imit - MSE PlusIdentity" default_imports = false begin +@testitem "Paths - imit - MSE IdentityRelaxation" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, LinearAlgebra, Random Random.seed!(63) @@ -37,7 +37,7 @@ end PipelineLossImitation; instance_dim=(5, 5), true_maximizer=shortest_path_maximizer, - maximizer=normalize ∘ PlusIdentity(shortest_path_maximizer), + maximizer=normalize ∘ IdentityRelaxation(shortest_path_maximizer), loss=mse, error_function=mse, ) @@ -88,7 +88,7 @@ end ) end -@testitem "Paths - imit - MSE RegularizedGeneric" default_imports = false begin +@testitem "Paths - imit - MSE RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -97,11 +97,11 @@ end PipelineLossImitation; instance_dim=(5, 5), true_maximizer=shortest_path_maximizer, - maximizer=RegularizedGeneric( - shortest_path_maximizer, - half_square_norm, - identity, - (; max_iteration=10, line_search=FrankWolfe.Agnostic()), + maximizer=RegularizedFrankWolfe( + shortest_path_maximizer; + Ω=half_square_norm, + Ω_grad=identity, + frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()), ), loss=mse, error_function=mse, @@ -143,7 +143,7 @@ end ) end -@testitem "Paths - imit - FYL RegularizedGeneric" default_imports = false begin +@testitem "Paths - imit - FYL RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -154,11 +154,11 @@ end true_maximizer=shortest_path_maximizer, maximizer=identity, loss=FenchelYoungLoss( - RegularizedGeneric( - shortest_path_maximizer, - half_square_norm, - identity, - (; max_iteration=10, line_search=FrankWolfe.Agnostic()), + RegularizedFrankWolfe( + shortest_path_maximizer; + Ω=half_square_norm, + Ω_grad=identity, + frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()), ), ), error_function=mse, @@ -210,7 +210,7 @@ end ) end -@testitem "Paths - exp - Pushforward RegularizedGeneric" default_imports = false begin +@testitem "Paths - exp - Pushforward RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, LinearAlgebra, Random @@ -224,11 +224,11 @@ end true_maximizer=shortest_path_maximizer, maximizer=identity, loss=Pushforward( - RegularizedGeneric( - shortest_path_maximizer, - half_square_norm, - identity, - (; max_iteration=10, line_search=FrankWolfe.Agnostic()), + RegularizedFrankWolfe( + shortest_path_maximizer; + Ω=half_square_norm, + Ω_grad=identity, + frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()), ), cost, ), diff --git a/test/ranking.jl b/test/ranking.jl index b3f32c0..02766fb 100644 --- a/test/ranking.jl +++ b/test/ranking.jl @@ -28,7 +28,7 @@ end ) end -@testitem "Ranking - imit - MSE PlusIdentity" default_imports = false begin +@testitem "Ranking - imit - MSE IdentityRelaxation" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, LinearAlgebra, Random Random.seed!(63) @@ -37,7 +37,7 @@ end PipelineLossImitation; instance_dim=5, true_maximizer=ranking, - maximizer=normalize ∘ PlusIdentity(ranking), + maximizer=normalize ∘ IdentityRelaxation(ranking), loss=mse, error_function=hamming_distance, ) @@ -88,7 +88,7 @@ end ) end -@testitem "Ranking - imit - MSE RegularizedGeneric" default_imports = false begin +@testitem "Ranking - imit - MSE RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -97,11 +97,11 @@ end PipelineLossImitation; instance_dim=5, true_maximizer=ranking, - maximizer=RegularizedGeneric( - ranking, - half_square_norm, - identity, - (; max_iteration=10, line_search=FrankWolfe.Agnostic()), + maximizer=RegularizedFrankWolfe( + ranking; + Ω=half_square_norm, + Ω_grad=identity, + frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()), ), loss=mse, error_function=hamming_distance, @@ -139,7 +139,7 @@ end ) end -@testitem "Ranking - imit - FYL RegularizedGeneric" default_imports = false begin +@testitem "Ranking - imit - FYL RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -150,11 +150,11 @@ end true_maximizer=ranking, maximizer=identity, loss=FenchelYoungLoss( - RegularizedGeneric( - ranking, - half_square_norm, - identity, - (; max_iteration=10, line_search=FrankWolfe.Agnostic()), + RegularizedFrankWolfe( + ranking; + Ω=half_square_norm, + Ω_grad=identity, + frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()), ), ), error_function=hamming_distance, @@ -202,7 +202,7 @@ end ) end -@testitem "Ranking - exp - Pushforward RegularizedGeneric" default_imports = false begin +@testitem "Ranking - exp - Pushforward RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/InferOptTestUtils.jl") using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, LinearAlgebra, Random @@ -216,11 +216,11 @@ end true_maximizer=ranking, maximizer=identity, loss=Pushforward( - RegularizedGeneric( - ranking, - half_square_norm, - identity, - (; max_iteration=10, line_search=FrankWolfe.Agnostic()), + RegularizedFrankWolfe( + ranking; + Ω=half_square_norm, + Ω_grad=identity, + frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()), ), cost, ), diff --git a/test/tutorial.jl b/test/tutorial.jl index 8014548..1636a17 100644 --- a/test/tutorial.jl +++ b/test/tutorial.jl @@ -1,175 +1,3 @@ -# # Basic tutorial - -# ## Context - -#= -Let us imagine that we observe the itineraries chosen by a public transport user in several different networks, and that we want to understand their decision-making process (a.k.a. recover their utility function). - -More precisely, each point in our dataset consists in: -- a graph ``G`` -- a shortest path ``P`` from the top left to the bottom right corner - -We don't know the true costs that were used to compute the shortest path, but we can exploit a set of features to approximate these costs. -The question is: how should we combine these features? - -We will use `InferOpt` to learn the appropriate weights, so that we may propose relevant paths to the user in the future. -=# - -using Flux -using Graphs -using GridGraphs -using InferOpt -using LinearAlgebra -using ProgressMeter -using Random -using Statistics -using Test -using UnicodePlots - -Random.seed!(63); - -# ## Grid graphs - -#= -For the purposes of this tutorial, we consider grid graphs, as implemented in [GridGraphs.jl](https://github.com/gdalle/GridGraphs.jl). -In such graphs, each vertex corresponds to a couple of coordinates ``(i, j)``, where ``1 \leq i \leq h`` and ``1 \leq j \leq w``. - -To ensure acyclicity, we only allow the user to move right, down or both. -Since the cost of a move is defined as the cost of the arrival vertex, any grid graph is entirely characterized by its cost matrix ``\theta \in \mathbb{R}^{h \times w}``. -=# - -h, w = 50, 100 -g = GridGraph(rand(h, w); directions=GridGraphs.QUEEN_ACYCLIC_DIRECTIONS); - -#= -For convenience, `GridGraphs.jl` also provides custom functions to compute shortest paths efficiently. -Let us see what those paths look like. -=# - -p = path_to_matrix(g, grid_topological_sort(g, 1, nv(g))); -spy(p) - -# ## Dataset - -#= -As announced, we do not know the cost of each vertex, only a set of relevant features. -Let us assume that the user combines them using a shallow neural network. -=# - -nb_features = 5 - -true_encoder = Chain(Dense(nb_features, 1), z -> dropdims(z; dims=1)); - -#= -The true vertex costs computed from this encoding are then used within shortest path computations. -To be consistent with the literature, we frame this problem as a linear maximization problem, which justifies the change of sign in front of ``\theta``. -=# - -function linear_maximizer(θ) - g = GridGraph(-θ; directions=GridGraphs.QUEEN_ACYCLIC_DIRECTIONS) - path = grid_topological_sort(g, 1, nv(g)) - return path_to_matrix(g, path) -end; - -#= -We now have everything we need to build our dataset. -=# - -nb_instances = 30 - -X_train = [randn(Float32, nb_features, h, w) for n in 1:nb_instances]; -θ_train = [true_encoder(x) for x in X_train]; -Y_train = [linear_maximizer(θ) for θ in θ_train]; - -# ## Learning - -#= -We create a trainable model with the same structure as the true encoder but another set of randomly-initialized weights. -=# - -initial_encoder = Chain(Dense(nb_features, 1), z -> dropdims(z; dims=1)); - -#= -Here is the crucial part where `InferOpt` intervenes: the choice of a clever loss function that enables us to -- differentiate through the shortest path maximizer, even though it is a combinatorial operation -- evaluate the quality of our model based on the paths that it recommends -=# - -regularized_predictor = PerturbedAdditive(linear_maximizer; ε=1.0, nb_samples=5); -loss = FenchelYoungLoss(regularized_predictor); - -#= -The regularized predictor is just a thin wrapper around our `linear_maximizer`, but with a very different behavior: -=# - -p_regularized = regularized_predictor(θ_train[1]); -spy(p_regularized) - -#= -Instead of choosing just one path, it spreads over several possible paths, allowing its output to change smoothly as ``\theta`` varies. -Thanks to this smoothing, we can now train our model with a standard gradient optimizer. -=# - -encoder = deepcopy(initial_encoder) -opt = Flux.Adam(); -losses = Float64[] -for epoch in 1:200 - l = 0.0 - for (x, y) in zip(X_train, Y_train) - grads = gradient(Flux.params(encoder)) do - l += loss(encoder(x), y) - end - Flux.update!(opt, Flux.params(encoder), grads) - end - push!(losses, l) -end; - -# ## Results - -#= -Since the Fenchel-Young loss is convex, it is no wonder that optimization worked like a charm. -=# - -lineplot(losses; xlabel="Epoch", ylabel="Loss") - -#= -To assess performance, we can compare the learned weights with their true (hidden) values -=# - -learned_weight = encoder[1].weight / norm(encoder[1].weight) -true_weight = true_encoder[1].weight / norm(true_encoder[1].weight) -hcat(learned_weight, true_weight) - -#= -We are quite close to recovering the exact user weights. -But in reality, it doesn't matter as much as our ability to provide accurate path predictions. -Let us therefore compare our predictions with the actual paths on the training set. -=# - -normalized_hamming(x, y) = mean(x[i] != y[i] for i in eachindex(x)) - -#- - -Y_train_pred = [linear_maximizer(encoder(x)) for x in X_train]; - -train_error = mean( - normalized_hamming(y, y_pred) for (y, y_pred) in zip(Y_train, Y_train_pred) -) - -# Not too bad, at least compared with our random initial encoder. - -Y_train_pred_initial = [linear_maximizer(initial_encoder(x)) for x in X_train]; - -train_error_initial = mean( - normalized_hamming(y, y_pred) for (y, y_pred) in zip(Y_train, Y_train_pred_initial) -) - -#= -This is definitely a success. -Of course in real prediction settings we should measure performance on a test set as well. -This is left as an exercise to the reader. -=# - -# CI tests, not included in the documentation #src - -@test train_error < train_error_initial / 3 #src +@testitem "Tutorial" begin + include(joinpath(dirname(@__DIR__), "examples", "tutorial.jl")) +end