diff --git a/Project.toml b/Project.toml index a444893..35b0b7c 100644 --- a/Project.toml +++ b/Project.toml @@ -5,18 +5,12 @@ version = "0.1.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -Folds = "41a02a25-b8f0-4f67-bc48-60067656b558" -Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" [compat] ChainRulesCore = "1" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 159f12f..763da4a 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -16,9 +16,9 @@ version = "1.1.0" [[deps.Accessors]] deps = ["Compat", "CompositionsBase", "ConstructionBase", "Future", "LinearAlgebra", "MacroTools", "Requires", "Test"] -git-tree-sha1 = "2bba2aa45df94e95b1a9c2405d7cfc3d60281db8" +git-tree-sha1 = "0264a938934447408c7f0be8985afec2a2237af4" uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.9" +version = "0.1.11" [[deps.Adapt]] deps = ["LinearAlgebra"] @@ -42,9 +42,9 @@ version = "0.2.0" [[deps.ArrayInterface]] deps = ["Compat", "IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"] -git-tree-sha1 = "c933ce606f6535a7c7b98e1d86d5d1014f730596" +git-tree-sha1 = "81f0cb60dc994ca17f68d9fb7c942a5ae70d9ee4" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "5.0.7" +version = "5.0.8" [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" @@ -76,27 +76,21 @@ uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" version = "1.0.8+0" [[deps.CEnum]] -git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9" +git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.4.1" +version = "0.4.2" [[deps.CUDA]] deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"] -git-tree-sha1 = "ba75320aaa092b3e17c020a2d8b9e0a572dbfa6a" +git-tree-sha1 = "19fb33957a5f85efb3cc10e70cf4dd4e30174ac9" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "3.9.0" - -[[deps.Calculus]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" -uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" -version = "0.5.1" +version = "3.10.0" [[deps.ChainRules]] deps = ["ChainRulesCore", "Compat", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics"] -git-tree-sha1 = "8b887daa6af5daf705081061e36386190204ac87" +git-tree-sha1 = "ab656fb36197083c5817667e76cccd10d11f5c30" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.28.1" +version = "1.32.0" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] @@ -106,15 +100,15 @@ version = "1.14.0" [[deps.ChangesOfVariables]] deps = ["ChainRulesCore", "LinearAlgebra", "Test"] -git-tree-sha1 = "bf98fa45a0a4cee295de98d4c1462be26345b9a1" +git-tree-sha1 = "1e315e3f4b0b7ce40feded39c73049692126cf53" uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" -version = "0.1.2" +version = "0.1.3" [[deps.ColorTypes]] deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597" +git-tree-sha1 = "a985dc37e357a3b22b260a5def99f3530fb415d3" uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.11.0" +version = "0.11.2" [[deps.ColorVectorSpace]] deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "SpecialFunctions", "Statistics", "TensorCore"] @@ -173,15 +167,15 @@ uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" version = "4.1.1" [[deps.DataAPI]] -git-tree-sha1 = "cc70b17275652eb47bc9e5f81635981f13cea5c8" +git-tree-sha1 = "fb5f5316dd3fd4c5e7c30a24d50643b73e37cd40" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.9.0" +version = "1.10.0" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "3daef5523dd2e769dad2365274f760ff5f282c7d" +git-tree-sha1 = "cc1a8e22627f33c789ab60b36a9132ac050bbf75" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.11" +version = "0.18.12" [[deps.DataValueInterfaces]] git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" @@ -201,12 +195,6 @@ version = "0.1.2" deps = ["Mmap"] uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" -[[deps.DensityInterface]] -deps = ["InverseFunctions", "Test"] -git-tree-sha1 = "80c3e8639e3353e5d2912fb3a1916b8455e2494b" -uuid = "b429d917-457f-4dbc-8f4c-0cc954292b1d" -version = "0.4.0" - [[deps.DiffResults]] deps = ["StaticArrays"] git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805" @@ -215,20 +203,14 @@ version = "1.0.3" [[deps.DiffRules]] deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "dd933c4ef7b4c270aacd4eb88fa64c147492acf0" +git-tree-sha1 = "28d605d9a0ac17118fe2c5e9ce0fbb76c3ceb120" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.10.0" +version = "1.11.0" [[deps.Distributed]] deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" -[[deps.Distributions]] -deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"] -git-tree-sha1 = "f206814c860c2a909d2a467af0484d08edd05ee7" -uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.57" - [[deps.DocStringExtensions]] deps = ["LibGit2"] git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b" @@ -245,12 +227,6 @@ version = "0.27.15" deps = ["ArgTools", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -[[deps.DualNumbers]] -deps = ["Calculus", "NaNMath", "SpecialFunctions"] -git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" -uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" -version = "0.6.8" - [[deps.EarCut_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "3f3a2501fa7236e9b911e0f7a588c657e822bb6d" @@ -262,11 +238,6 @@ git-tree-sha1 = "56559bbef6ca5ea0c0818fa5c90320398a6fbf8d" uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" version = "0.1.8" -[[deps.ExternalDocstrings]] -git-tree-sha1 = "1224740fc4d07c989949e1c1b508ebd49a65a5f6" -uuid = "e189563c-0753-4f5e-ad5c-be4293c83fb4" -version = "0.1.1" - [[deps.FLoops]] deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] git-tree-sha1 = "4391d3ed58db9dc5a9883b23a0578316b4798b1f" @@ -281,9 +252,9 @@ version = "0.1.1" [[deps.FileIO]] deps = ["Pkg", "Requires", "UUIDs"] -git-tree-sha1 = "80ced645013a5dbdc52cf70329399c35ce007fae" +git-tree-sha1 = "9267e5f50b0e12fdfd5a2455534345c4cf2c7f7a" uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -version = "1.13.0" +version = "1.14.0" [[deps.FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] @@ -303,12 +274,6 @@ git-tree-sha1 = "e932b26ac243f312af2d9009de08b89be0e01a84" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" version = "0.13.0" -[[deps.Folds]] -deps = ["Accessors", "BangBang", "Baselet", "DefineSingletons", "Distributed", "ExternalDocstrings", "InitialValues", "MicroCollections", "Referenceables", "Requires", "Test", "ThreadedScans", "Transducers"] -git-tree-sha1 = "638109532de382a1f99b1aae1ca8b5d08515d85a" -uuid = "41a02a25-b8f0-4f67-bc48-60067656b558" -version = "0.2.8" - [[deps.FoldsThreads]] deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"] git-tree-sha1 = "eb8e1989b9028f7e0985b4268dabe94682249025" @@ -317,9 +282,9 @@ version = "0.1.1" [[deps.ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] -git-tree-sha1 = "1bd6fc0c344fc0cbee1f42f8d2e7ec8253dda2d2" +git-tree-sha1 = "89cc49bf5819f0a10a7a3c38885e7c7ee048de57" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.25" +version = "0.10.29" [[deps.FreeType]] deps = ["CEnum", "FreeType2_jll"] @@ -361,9 +326,9 @@ version = "8.3.2" [[deps.GPUCompiler]] deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "556190e1e0ea3e37d83059fc9aa576f1e2104375" +git-tree-sha1 = "05374e47bb136db517b33f62fbe852adf8deb0be" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.14.1" +version = "0.15.1" [[deps.GeometryBasics]] deps = ["EarCut_jll", "IterTools", "LinearAlgebra", "StaticArrays", "StructArrays", "Tables"] @@ -377,11 +342,11 @@ git-tree-sha1 = "57c021de207e234108a6f1454003120a1bf350c4" uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" version = "1.6.0" -[[deps.HypergeometricFunctions]] -deps = ["DualNumbers", "LinearAlgebra", "SpecialFunctions", "Test"] -git-tree-sha1 = "65e4589030ef3c44d3b90bdc5aac462b4bb05567" -uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" -version = "0.3.8" +[[deps.GridGraphs]] +deps = ["DataStructures", "Graphs", "SparseArrays"] +git-tree-sha1 = "b6d33f54428fee0174d0bfae256fdde1d5333594" +uuid = "dd2b58c7-5af7-4f17-9e46-57c68ac813fb" +version = "0.1.2" [[deps.IOCapture]] deps = ["Logging", "Random"] @@ -391,9 +356,9 @@ version = "0.2.2" [[deps.IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "7f43342f8d5fd30ead0ba1b49ab1a3af3b787d24" +git-tree-sha1 = "af14a478780ca78d5eb9908b263023096c2b9d64" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.5" +version = "0.4.6" [[deps.IfElse]] git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" @@ -401,7 +366,7 @@ uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" version = "0.1.1" [[deps.InferOpt]] -deps = ["ChainRulesCore", "DataStructures", "Distributions", "Folds", "Graphs", "LinearAlgebra", "ProgressMeter", "Random", "SimpleTraits", "SparseArrays", "Statistics", "Test", "UnicodePlots"] +deps = ["ChainRulesCore", "LinearAlgebra", "Random", "SimpleTraits", "SparseArrays", "Statistics", "Test"] path = ".." uuid = "4846b161-c94e-4150-8dac-c7ae193c601f" version = "0.1.0" @@ -422,9 +387,9 @@ uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" [[deps.InverseFunctions]] deps = ["Test"] -git-tree-sha1 = "91b5dcf362c5add98049e6c29ee756910b03051d" +git-tree-sha1 = "336cc738f03e069ef2cac55a104eb823455dca75" uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.3" +version = "0.1.4" [[deps.IrrationalConstants]] git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" @@ -461,15 +426,15 @@ version = "0.2.4" [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "c9b86064be5ae0f63e50816a5a90b08c474507ae" +git-tree-sha1 = "c8d47589611803a0f3b4813d9e267cd4e3dbcefb" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "4.9.1" +version = "4.11.1" [[deps.LLVMExtra_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] -git-tree-sha1 = "5558ad3c8972d602451efe9d81c78ec14ef4f5ef" +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"] +git-tree-sha1 = "771bfe376249626d3ca12bcd58ba243d3f961576" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.14+2" +version = "0.0.16+0" [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] @@ -506,23 +471,23 @@ version = "2.13.1" [[deps.LogExpFunctions]] deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "a970d55c2ad8084ca317a4658ba6ce99b7523571" +git-tree-sha1 = "09e4b894ce6a976c354a69041a04748180d43637" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.12" +version = "0.3.15" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[deps.MLStyle]] -git-tree-sha1 = "594e189325f66e23a8818e5beb11c43bb0141bcd" +git-tree-sha1 = "e49789e5eb7b2d5577aaea395bfcac769df64bb8" uuid = "d8e11817-5142-5d16-987a-aa16d5891078" -version = "0.4.10" +version = "0.4.11" [[deps.MLUtils]] deps = ["ChainRulesCore", "DelimitedFiles", "FLoops", "FoldsThreads", "Random", "ShowCases", "Statistics", "StatsBase"] -git-tree-sha1 = "32eeb46fa393ae36a4127c9442ade478c8d01117" +git-tree-sha1 = "95ab49a8c9afb6a8a0fc81df25617a6798c0fb73" uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" -version = "0.2.3" +version = "0.2.5" [[deps.MacroTools]] deps = ["Markdown", "Random"] @@ -564,9 +529,9 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159" [[deps.NNlib]] deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"] -git-tree-sha1 = "a59a614b8b4ea6dc1dcec8c6514e251f13ccbe10" +git-tree-sha1 = "f89de462a7bc3243f95834e75751d70b3a33e59d" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.8.4" +version = "0.8.5" [[deps.NNlibCUDA]] deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"] @@ -575,9 +540,9 @@ uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" version = "0.2.2" [[deps.NaNMath]] -git-tree-sha1 = "b086b7ea07f8e38cf122f5016af580881ac914fe" +git-tree-sha1 = "737a5957f387b17e74d4ad2f440eb330b39a62c5" uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "0.3.7" +version = "1.0.0" [[deps.NameResolution]] deps = ["PrettyPrint"] @@ -604,26 +569,20 @@ version = "0.5.5+0" [[deps.Optimisers]] deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "e440ecef249dea69e79248857e800e71820d386c" +git-tree-sha1 = "2442c3ddbda547c80e8b6451a103719d6a3593dd" uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.2.1" +version = "0.2.4" [[deps.OrderedCollections]] git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" version = "1.4.1" -[[deps.PDMats]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "3114946c67ef9925204cc024a73c9e679cebe0d7" -uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.11.8" - [[deps.Parsers]] deps = ["Dates"] -git-tree-sha1 = "621f4f3b4977325b9128d5fae7a8b4829a0c2222" +git-tree-sha1 = "1285416549ccfcdf0c50d4997a94331e88d68413" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.2.4" +version = "2.3.1" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] @@ -631,9 +590,9 @@ uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" [[deps.Preferences]] deps = ["TOML"] -git-tree-sha1 = "d3538e7f8a790dc8903519090857ef8e1283eecd" +git-tree-sha1 = "47e5f437cc0e7ef2ce8406ce1e7e24d44915f88d" uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.2.5" +version = "1.3.0" [[deps.PrettyPrint]] git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" @@ -656,12 +615,6 @@ git-tree-sha1 = "d7a7aef8f8f2d537104f170139553b14dfe39fe9" uuid = "92933f4c-e287-5a05-a399-4b506db050ca" version = "1.7.2" -[[deps.QuadGK]] -deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "78aadffb3efd2155af139781b8a8df1ef279ea39" -uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.4.2" - [[deps.REPL]] deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" @@ -693,30 +646,12 @@ git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" uuid = "189a3867-3050-52da-a836-e630ba90ab69" version = "1.2.2" -[[deps.Referenceables]] -deps = ["Adapt"] -git-tree-sha1 = "e681d3bfa49cd46c3c161505caddf20f0e62aaa9" -uuid = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" -version = "0.1.2" - [[deps.Requires]] deps = ["UUIDs"] git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" uuid = "ae029012-a4dd-5104-9daa-d747884805df" version = "1.3.0" -[[deps.Rmath]] -deps = ["Random", "Rmath_jll"] -git-tree-sha1 = "bf3188feca147ce108c76ad82c2792c57abe7b1f" -uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" -version = "0.7.0" - -[[deps.Rmath_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "68db32dff12bb6127bac73c209881191bf0efbb7" -uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.3.0+0" - [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -771,15 +706,15 @@ version = "0.1.14" [[deps.Static]] deps = ["IfElse"] -git-tree-sha1 = "87e9954dfa33fd145694e42337bdd3d5b07021a6" +git-tree-sha1 = "5309da1cdef03e95b73cd3251ac3a39f887da53e" uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" -version = "0.6.0" +version = "0.6.4" [[deps.StaticArrays]] deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "4f6ec5d99a28e1a749559ef7dd518663c5eca3d5" +git-tree-sha1 = "cd56bf18ed715e8b09f06ef8c6b781e6cdc49911" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.4.3" +version = "1.4.4" [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] @@ -787,9 +722,9 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [[deps.StatsAPI]] deps = ["LinearAlgebra"] -git-tree-sha1 = "8d7530a38dbd2c397be7ddd01a424e4f411dcc41" +git-tree-sha1 = "c82aaa13b44ea00134f8c9c89819477bd3986ecd" uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.2.2" +version = "1.3.0" [[deps.StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] @@ -797,21 +732,11 @@ git-tree-sha1 = "8977b17906b0a1cc74ab2e3a05faa16cf08a8291" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" version = "0.33.16" -[[deps.StatsFuns]] -deps = ["ChainRulesCore", "HypergeometricFunctions", "InverseFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] -git-tree-sha1 = "ca9f8a0c9f2e41431dc5b7697058a3f8f8b89498" -uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "1.0.0" - [[deps.StructArrays]] deps = ["Adapt", "DataAPI", "StaticArrays", "Tables"] -git-tree-sha1 = "57617b34fa34f91d536eb265df67c2d4519b8b98" +git-tree-sha1 = "e75d82493681dfd884a357952bbd7ab0608e1dc3" uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -version = "0.6.5" - -[[deps.SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] -uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" +version = "0.6.7" [[deps.TOML]] deps = ["Dates"] @@ -843,17 +768,11 @@ version = "0.1.1" deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -[[deps.ThreadedScans]] -deps = ["ArgCheck"] -git-tree-sha1 = "ca1ba3000289eacba571aaa4efcefb642e7a1de6" -uuid = "24d252fe-5d94-4a69-83ea-56a14333d47a" -version = "0.1.0" - [[deps.TimerOutputs]] deps = ["ExprTools", "Printf"] -git-tree-sha1 = "d60b0c96a16aaa42138d5d38ad386df672cb8bd8" +git-tree-sha1 = "7638550aaea1c9a1e86817a231ef0faa9aca79bd" uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.16" +version = "0.5.19" [[deps.Transducers]] deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] diff --git a/docs/Project.toml b/docs/Project.toml index 1519a9e..2732861 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,7 +1,8 @@ [deps] -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" +GridGraphs = "dd2b58c7-5af7-4f17-9e46-57c68ac813fb" InferOpt = "4846b161-c94e-4150-8dac-c7ae193c601f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" diff --git a/docs/src/tutorial.md b/docs/src/tutorial.md index b32e0e3..26cc5e9 100644 --- a/docs/src/tutorial.md +++ b/docs/src/tutorial.md @@ -19,8 +19,9 @@ We will use `InferOpt` to learn the appropriate weights, so that we may propose ````@example tutorial using Flux +using Graphs +using GridGraphs using InferOpt -using InferOpt.GridGraphs using InferOpt.Testing using LinearAlgebra using ProgressMeter @@ -35,7 +36,7 @@ nothing #hide ## Grid graphs -For the purposes of this tutorial, we consider grid graphs, as implemented in `InferOpt.GridGraphs`. +For the purposes of this tutorial, we consider grid graphs, as implemented in [GridGraphs.jl](https://github.com/gdalle/GridGraphs.jl). In such graphs, each vertex corresponds to a couple of coordinates ``(i, j)``, where ``1 \leq i \leq h`` and ``1 \leq j \leq w``. To ensure acyclicity, we only allow the user to move right, down or both. @@ -51,7 +52,7 @@ For convenience, `InferOpt.GridGraphs` also provides custom functions to compute Let us see what those look like. ````@example tutorial -p = grid_shortest_path(g, 1, nv(g)); +p = path_to_matrix(g, grid_topological_sort(g, 1, nv(g))); spy(p) ```` @@ -73,7 +74,8 @@ To be consistent with the literature, we frame this problem as a linear maximiza ````@example tutorial function linear_maximizer(θ) g = AcyclicGridGraph(-θ) - return grid_shortest_path(g, 1, nv(g)) + path = grid_topological_sort(g, 1, nv(g)) + return path_to_matrix(g, path) end; nothing #hide ```` diff --git a/src/InferOpt.jl b/src/InferOpt.jl index 3c68716..68a8852 100644 --- a/src/InferOpt.jl +++ b/src/InferOpt.jl @@ -1,18 +1,12 @@ module InferOpt using ChainRulesCore -using DataStructures -using Distributions -using Folds -using Graphs using LinearAlgebra -using ProgressMeter using Random using SimpleTraits using SparseArrays using Statistics using Test -using UnicodePlots include("interpolation/interpolation.jl") @@ -22,8 +16,8 @@ include("regularized/simplex.jl") include("regularized/prediction.jl") include("regularized/ranking.jl") -include("perturbed/perturbed.jl") -include("perturbed/perturbed_generic.jl") +include("perturbed/perturbed_abstract.jl") +include("perturbed/perturbed_normal.jl") include("fenchel_young/fenchel_young.jl") @@ -32,7 +26,8 @@ include("smart_predict_optimize/smart_predict_optimize.jl") include("structured_svm/structured_loss.jl") include("structured_svm/structured_svm.jl") -include("utils/grid_graphs/GridGraphs.jl") +include("experimental/perturbed_lognormal.jl") + include("utils/testing/Testing.jl") export shannon_entropy, half_square_norm @@ -43,7 +38,6 @@ export IsRegularizedPrediction export Interpolation export Perturbed, PerturbedCost -export PerturbedGeneric export FenchelYoungLoss diff --git a/src/experimental/perturbed_lognormal.jl b/src/experimental/perturbed_lognormal.jl new file mode 100644 index 0000000..915aeaf --- /dev/null +++ b/src/experimental/perturbed_lognormal.jl @@ -0,0 +1,39 @@ +lognormal_gradlogpdf(x::Real, μ::Real, σ::Real) = (log(x) - μ - one(x)) / (x * σ^2) + +""" + PerturbedLogNormal{F} + +Differentiable log-normal perturbation of a black-box optimizer. + +# Fields +- `maximizer::F`: underlying argmax function +- `ε::Float64`: noise scaling parameter +- `M::Int`: number of noise samples for Monte-Carlo computations +""" +struct PerturbedLogNormal{F} <: AbstractPerturbed + maximizer::F + ε::Float64 + M::Int +end + +PerturbedLogNormal(maximizer; ε=1.0, M=2) = PerturbedLogNormal(maximizer, float(ε), M) + +function (perturbed::PerturbedLogNormal)(θ::AbstractArray; kwargs...) + (; maximizer, ε, M) = perturbed + d = size(θ) + θ_samples = [exp.(log.(θ) + ε * randn(d) - ε^2 / 2) for _ in 1:M] + y_samples = [maximizer(θ_sample; kwargs...) for θ_sample in θ_samples] + y_mean = mean(y_samples) + return y_mean +end + +function compute_y_and_Fθ(perturbed::PerturbedLogNormal, θ::AbstractArray; kwargs...) + (; maximizer, ε, M) = perturbed + d = size(θ) + θ_samples = [exp.(log.(θ) + ε * randn(d) - ε^2 / 2) for _ in 1:M] + y_samples = [maximizer(θ_sample; kwargs...) for θ_sample in θ_samples] + F_θ_samples = [dot(θ_sample, y) for (θ_sample, y) in zip(θ_samples, y_samples)] + y_mean = mean(y_samples) + Fθ_mean = mean(F_θ_samples) # useful for computing Fenchel-Young loss + return y_mean, Fθ_mean +end diff --git a/src/perturbed/perturbed_abstract.jl b/src/perturbed/perturbed_abstract.jl new file mode 100644 index 0000000..f6d4d4b --- /dev/null +++ b/src/perturbed/perturbed_abstract.jl @@ -0,0 +1 @@ +abstract type AbstractPerturbed end diff --git a/src/perturbed/perturbed_generic.jl b/src/perturbed/perturbed_generic.jl deleted file mode 100644 index 9199661..0000000 --- a/src/perturbed/perturbed_generic.jl +++ /dev/null @@ -1,54 +0,0 @@ -""" - PerturbedGeneric{F,D} - -Differentiable perturbation of a black-box optimizer. - -# Fields -- `maximizer::F`: underlying argmax function -- `noise_dist::D`: function taking `θ` and returning a local noise distribution -- `M::Int`: number of noise samples for Monte-Carlo computations -""" -struct PerturbedGeneric{F,D} <: AbstractPerturbed - maximizer::F - noise_dist::D - M::Int -end - -function PerturbedGeneric(maximizer::F; noise_dist::D, M) where {F,D} - return PerturbedGeneric{F,D}(maximizer, noise_dist, M) -end - -function (perturbed::PerturbedGeneric)(θ::AbstractArray; kwargs...) - (; maximizer, noise_dist, M) = perturbed - local_noise_dist = noise_dist(θ) - y_samples = Folds.map(m -> maximizer(rand(local_noise_dist); kwargs...), 1:M) - y_mean = mean(y_samples) - return y_mean -end - -function compute_y_and_Fθ(perturbed::PerturbedGeneric, θ::AbstractArray; kwargs...) - (; maximizer, noise_dist, M) = perturbed - local_noise_dist = noise_dist(θ) - perturbed_θs = [rand(local_noise_dist) for _ in 1:M] - y_samples = Folds.map(θ_perturbed -> maximizer(θ_perturbed; kwargs...), perturbed_θs) - F_θ_sample = [dot(θ_perturbed, y) for (θ_perturbed, y) in zip(perturbed_θs, y_samples)] - y_mean = mean(y_samples) - Fθ_mean = mean(F_θ_sample) # useful for computing Fenchel-Young loss - return y_mean, Fθ_mean -end - -function ChainRulesCore.rrule(perturbed::PerturbedGeneric, θ::AbstractArray; kwargs...) - (; maximizer, noise_dist, M) = perturbed - local_noise_dist = noise_dist(θ) - ∇ν(z) = -gradlogpdf(local_noise_dist, z + θ) - εZ_samples = [rand(local_noise_dist) - θ for m in 1:M] - y_samples = [maximizer(θ + εZ; kwargs...) for εZ in εZ_samples] - y_mean = mean(y_samples) - - function perturbed_generic_pullback(dy) - vjp = mean(dot(dy, y_samples[m]) * ∇ν(εZ_samples[m]) for m in 1:M) - return NoTangent(), vjp - end - - return y_mean, perturbed_generic_pullback -end diff --git a/src/perturbed/perturbed.jl b/src/perturbed/perturbed_normal.jl similarity index 83% rename from src/perturbed/perturbed.jl rename to src/perturbed/perturbed_normal.jl index 0011c2b..bb5194f 100644 --- a/src/perturbed/perturbed.jl +++ b/src/perturbed/perturbed_normal.jl @@ -1,9 +1,7 @@ -abstract type AbstractPerturbed end - """ Perturbed{F} -Differentiable perturbation of a black-box optimizer. +Differentiable normal perturbation of a black-box optimizer. # Fields - `maximizer::F`: underlying argmax function @@ -21,7 +19,7 @@ Perturbed(maximizer; ε=1.0, M=2) = Perturbed(maximizer, float(ε), M) function (perturbed::Perturbed)(θ::AbstractArray; kwargs...) (; maximizer, ε, M) = perturbed d = size(θ) - y_samples = Folds.map(m -> maximizer(θ + ε * randn(d); kwargs...), 1:M) + y_samples = [maximizer(θ + ε * randn(d); kwargs...) for _ in 1:M] y_mean = mean(y_samples) return y_mean end @@ -29,11 +27,11 @@ end function compute_y_and_Fθ(perturbed::Perturbed, θ::AbstractArray; kwargs...) (; maximizer, ε, M) = perturbed d = size(θ) - perturbed_θs = [θ + ε * randn(d) for _ in 1:M] - y_samples = Folds.map(θ_perturbed -> maximizer(θ_perturbed; kwargs...), perturbed_θs) - F_θ_sample = [dot(θ_perturbed, y) for (θ_perturbed, y) in zip(perturbed_θs, y_samples)] + θ_samples = [θ + ε * randn(d) for _ in 1:M] + y_samples = [maximizer(θ_sample; kwargs...) for θ_sample in θ_samples] + F_θ_samples = [dot(θ_sample, y) for (θ_sample, y) in zip(θ_samples, y_samples)] y_mean = mean(y_samples) - Fθ_mean = mean(F_θ_sample) # useful for computing Fenchel-Young loss + Fθ_mean = mean(F_θ_samples) # useful for computing Fenchel-Young loss return y_mean, Fθ_mean end @@ -88,7 +86,7 @@ end function ChainRulesCore.rrule(perturbed_cost::PerturbedCost, θ::AbstractArray; kwargs...) (; maximizer, cost, ε, M) = perturbed_cost d = size(θ) - Z_samples = [randn(d) for m in 1:M] + Z_samples = [randn(d) for _ in 1:M] y_samples = [maximizer(θ + ε * Z; kwargs...) for Z in Z_samples] costs = [cost(y; kwargs...) for y in y_samples] diff --git a/src/utils/grid_graphs/GridGraphs.jl b/src/utils/grid_graphs/GridGraphs.jl deleted file mode 100644 index 2df39de..0000000 --- a/src/utils/grid_graphs/GridGraphs.jl +++ /dev/null @@ -1,15 +0,0 @@ -module GridGraphs - -using ..DataStructures -using ..Graphs - -include("abstract.jl") -include("acyclic.jl") -include("symmetric.jl") -include("shortest_paths.jl") - -export AcyclicGridGraph, SymmetricGridGraph -export grid_shortest_path, grid_shortest_path_cost -export nv, ne - -end diff --git a/src/utils/grid_graphs/abstract.jl b/src/utils/grid_graphs/abstract.jl deleted file mode 100644 index e9d20b4..0000000 --- a/src/utils/grid_graphs/abstract.jl +++ /dev/null @@ -1,63 +0,0 @@ -abstract type AbstractGridGraph{R<:AbstractFloat} <: AbstractGraph{Int} end - -## Basic accessors - -Base.size(g::AbstractGridGraph, args...) = size(g.cell_costs, args...) - -Base.eltype(::AbstractGridGraph) = Int -Graphs.edgetype(::AbstractGridGraph) = Edge{Int} - -Graphs.is_directed(::AbstractGridGraph) = true -Graphs.is_directed(::Type{<:AbstractGridGraph}) = true - -height(g::AbstractGridGraph) = size(g, 1) -width(g::AbstractGridGraph) = size(g, 2) - -Graphs.nv(g::AbstractGridGraph) = prod(size(g)) -Graphs.vertices(g::AbstractGridGraph) = 1:nv(g) -Graphs.has_vertex(g::AbstractGridGraph, v::Integer) = 1 <= v <= nv(g) - -## Indexing translators - -function node_index(g::AbstractGridGraph, i::Integer, j::Integer) - h, w = size(g) - if (1 <= i <= h) && (1 <= j <= w) - v = (j - 1) * h + (i - 1) + 1 # enumerate column by column - return v - else - return 0 - end -end - -function node_coord(g::AbstractGridGraph, v::Integer) - if has_vertex(g, v) - h, w = size(g) - j = (v - 1) ÷ h + 1 - i = (v - 1) % h + 1 - return i, j - else - return (0, 0) - end -end - -## Edges - -function Graphs.has_edge(g::AbstractGridGraph, s::Integer, d::Integer) - if has_vertex(g, s) && has_vertex(g, d) - is, js = node_coord(g, s) - id, jd = node_coord(g, d) - return (s != d) && (abs(is - id) <= 1) && (abs(js - jd) <= 1) # 8 neighbors max - else - return false - end -end - -function Graphs.edges(g::AbstractGridGraph) - return (Edge(s, d) for s in vertices(g) for d in outneighbors(g, s)) -end - -## Costs - -get_cost(g::AbstractGridGraph, v::Integer) = g.cell_costs[v] -get_cost(g::AbstractGridGraph, i::Integer, j::Integer) = g.cell_costs[i, j] -has_negative_costs(g::AbstractGridGraph) = any(<(0.0), g.cell_costs) diff --git a/src/utils/grid_graphs/acyclic.jl b/src/utils/grid_graphs/acyclic.jl deleted file mode 100644 index b863b6c..0000000 --- a/src/utils/grid_graphs/acyclic.jl +++ /dev/null @@ -1,58 +0,0 @@ -## Graph subtyping - -struct AcyclicGridGraph{R<:AbstractFloat} <: AbstractGridGraph{R} - cell_costs::Matrix{R} -end - -function Graphs.ne(g::AcyclicGridGraph) - h, w = size(g) - return ( - (h - 1) * (w - 1) * 3 + # topleft rectangle - (w - 1) * 1 + # bottom row - (h - 1) * 1 # bottom row - ) -end - -function Graphs.has_edge(g::AcyclicGridGraph, s::Integer, d::Integer) - if has_vertex(g, s) && has_vertex(g, d) - is, js = node_coord(g, s) - id, jd = node_coord(g, d) - return (s != d) && (0 <= id - is <= 1) && (0 <= jd - js <= 1) # 3 neighbors max - else - return false - end -end - -function Graphs.outneighbors(g::AcyclicGridGraph, s::Integer) - h, w = size(g) - i, j = node_coord(g, s) - possible_neighbors = ( # listed in ascending index order! - (i + 1, j + 0), # bottom - (i + 0, j + 1), # right - (i + 1, j + 1) # bottom right - ) - neighbors = ( - node_index(g, id, jd) for - (id, jd) in possible_neighbors if (1 <= id <= h) && (1 <= jd <= w) - ) - return neighbors -end - -function Graphs.inneighbors(g::AcyclicGridGraph, s::Integer) - h, w = size(g) - i, j = node_coord(g, s) - possible_neighbors = ( # listed in ascending index order! - (i - 1, j - 1), # top left - (i + 0, j - 1), # left - (i - 1, j + 0) # top - ) - neighbors = ( - node_index(g, id, jd) for - (id, jd) in possible_neighbors if (1 <= id <= h) && (1 <= jd <= w) - ) - return neighbors -end - -function grid_shortest_paths(g::AcyclicGridGraph, s::Integer) - return grid_topological_sorting(g, s) -end diff --git a/src/utils/grid_graphs/shortest_paths.jl b/src/utils/grid_graphs/shortest_paths.jl deleted file mode 100644 index d50163f..0000000 --- a/src/utils/grid_graphs/shortest_paths.jl +++ /dev/null @@ -1,78 +0,0 @@ -## Shortest path storage - -struct ShortestPathTree{R<:AbstractFloat} - parents::Vector{Int} - dists::Vector{R} -end - -## Dijkstra - -function grid_dijkstra(g::AbstractGridGraph{R}, s::Integer) where {R<:AbstractFloat} - @assert !has_negative_costs(g) - dists = fill(typemax(R), nv(g)) - parents = zeros(Int, nv(g)) - Q = PriorityQueue{Int,R}() - dists[s] = zero(R) - enqueue!(Q, s, zero(R)) - while !isempty(Q) - u = dequeue!(Q) - d_u = dists[u] - for v in outneighbors(g, u) - dist_through_u = d_u + get_cost(g, v) - if dist_through_u < dists[v] - dists[v] = dist_through_u - parents[v] = u - Q[v] = dist_through_u - end - end - end - return ShortestPathTree(parents, dists) -end - -## Topological sorting - -function grid_topological_sorting( - g::AbstractGridGraph{R}, s::Integer -) where {R<:AbstractFloat} - dists = fill(typemax(R), nv(g)) - parents = zeros(Int, nv(g)) - dists[s] = zero(R) - for u in s:nv(g) - for v in outneighbors(g, u) - c_uv = get_cost(g, v) - if dists[u] + c_uv < dists[v] - dists[v] = dists[u] + c_uv - parents[v] = u - end - end - end - return ShortestPathTree(parents, dists) -end - -## Rebuild path - -function grid_shortest_paths(g::AbstractGridGraph, s::Integer) - return error("Not implemented") -end - -function grid_shortest_path(g::AbstractGridGraph, s::Integer, d::Integer) - spt = grid_shortest_paths(g, s) - parents = spt.parents - v = d - path = [v] - while v != s - v = parents[v] - pushfirst!(path, v) - end - y = zeros(Bool, height(g), width(g)) - for v in path - i, j = node_coord(g, v) - y[i, j] = 1 - end - return y -end - -function grid_shortest_path_cost(g::AbstractGridGraph, s::Integer, d::Integer) - spt = grid_shortest_paths(g, s) - return spt.dists[d] -end diff --git a/src/utils/grid_graphs/symmetric.jl b/src/utils/grid_graphs/symmetric.jl deleted file mode 100644 index d5177a1..0000000 --- a/src/utils/grid_graphs/symmetric.jl +++ /dev/null @@ -1,39 +0,0 @@ -struct SymmetricGridGraph{R<:AbstractFloat} <: AbstractGridGraph{R} - cell_costs::Matrix{R} -end - -function Graphs.ne(g::SymmetricGridGraph) - h, w = size(g) - return ( - (h - 2) * (w - 2) * 8 + # central nodes - 2 * (h - 2) * 5 + # vertical borders - 2 * (w - 2) * 5 + # horizontal borders - 2 * 2 * 3 # corners - ) -end - -function Graphs.outneighbors(g::SymmetricGridGraph, s::Integer) - h, w = size(g) - i, j = node_coord(g, s) - possible_neighbors = ( # listed in ascending index order! - (i - 1, j - 1), # top left - (i + 0, j - 1), # left - (i + 1, j - 1), # bottom left - (i - 1, j + 0), # top - (i + 1, j + 0), # bottom - (i - 1, j + 1), # top right - (i + 0, j + 1), # right - (i + 1, j + 1), # bottom right - ) - neighbors = ( - node_index(g, id, jd) for - (id, jd) in possible_neighbors if (1 <= id <= h) && (1 <= jd <= w) - ) - return neighbors -end - -Graphs.inneighbors(g::SymmetricGridGraph, d::Integer) = outneighbors(g, d) - -function grid_shortest_paths(g::SymmetricGridGraph, s::Integer) - return grid_dijkstra(g, s) -end diff --git a/src/utils/testing/Testing.jl b/src/utils/testing/Testing.jl index f79fecf..3bb5468 100644 --- a/src/utils/testing/Testing.jl +++ b/src/utils/testing/Testing.jl @@ -1,11 +1,9 @@ module Testing -using ..InferOpt -using ..LinearAlgebra -using ..ProgressMeter -using ..Statistics -using ..Test -using ..UnicodePlots +using InferOpt +using LinearAlgebra +using Statistics +using Test include("dataset.jl") include("error.jl") @@ -17,7 +15,7 @@ dropfirstdim(z::AbstractArray) = dropdims(z; dims=1) export generate_dataset export mape, normalized_mape export hamming_distance, normalized_hamming_distance -export define_flux_loss +export define_pipeline_loss export init_perf, update_perf!, plot_perf, test_perf export dropfirstdim diff --git a/src/utils/testing/loss.jl b/src/utils/testing/loss.jl index adb610e..7e51b23 100644 --- a/src/utils/testing/loss.jl +++ b/src/utils/testing/loss.jl @@ -1,15 +1,15 @@ -function define_flux_loss(encoder, maximizer, loss, target) - flux_loss_none(x, θ, y) = loss(maximizer(encoder(x)); instance=x) - flux_loss_θ(x, θ, y) = loss(maximizer(encoder(x)), θ) - flux_loss_y(x, θ, y) = loss(maximizer(encoder(x)), y) - flux_loss_θy(x, θ, y) = loss(maximizer(encoder(x)), θ, y) +function define_pipeline_loss(encoder, maximizer, loss, target) + pipeline_loss_none(x, θ, y) = loss(maximizer(encoder(x)); instance=x) + pipeline_loss_θ(x, θ, y) = loss(maximizer(encoder(x)), θ) + pipeline_loss_y(x, θ, y) = loss(maximizer(encoder(x)), y) + pipeline_loss_θy(x, θ, y) = loss(maximizer(encoder(x)), θ, y) - flux_losses = Dict( - "none" => flux_loss_none, - "θ" => flux_loss_θ, - "y" => flux_loss_y, - "(θ,y)" => flux_loss_θy, + pipeline_losses = Dict( + "none" => pipeline_loss_none, + "θ" => pipeline_loss_θ, + "y" => pipeline_loss_y, + "(θ,y)" => pipeline_loss_θy, ) - return flux_losses[target] + return pipeline_losses[target] end diff --git a/src/utils/testing/perf.jl b/src/utils/testing/perf.jl index e02c591..fa99db0 100644 --- a/src/utils/testing/perf.jl +++ b/src/utils/testing/perf.jl @@ -20,7 +20,7 @@ function update_perf!( true_encoder, encoder, true_maximizer, - flux_loss, + pipeline_loss, error_function, cost, ) @@ -37,8 +37,8 @@ function update_perf!( (X_train, thetas_train, Y_train) = data_train (X_test, thetas_test, Y_test) = data_test - train_loss = sum(flux_loss(t...) for t in zip(data_train...)) - test_loss = sum(flux_loss(t...) for t in zip(data_test...)) + train_loss = sum(pipeline_loss(t...) for t in zip(data_train...)) + test_loss = sum(pipeline_loss(t...) for t in zip(data_test...)) Y_train_pred = generate_predictions(encoder, true_maximizer, X_train) Y_test_pred = generate_predictions(encoder, true_maximizer, X_test) @@ -114,7 +114,7 @@ function test_perf(perf_storage::NamedTuple; test_name::String) end end -function plot_perf(perf_storage::NamedTuple) +function plot_perf(perf_storage::NamedTuple; lineplot_function::Function) (; train_losses, test_losses, @@ -127,17 +127,17 @@ function plot_perf(perf_storage::NamedTuple) plts = [] if length(train_losses) > 0 - plt = lineplot(train_losses; xlabel="Epoch", title="Train loss") + plt = lineplot_function(train_losses; xlabel="Epoch", title="Train loss") push!(plts, plt) end if length(test_losses) > 0 - plt = lineplot(test_losses; xlabel="Epoch", title="Test loss") + plt = lineplot_function(test_losses; xlabel="Epoch", title="Test loss") push!(plts, plt) end if length(train_errors) > 0 - plt = lineplot( + plt = lineplot_function( train_errors; xlabel="Epoch", title="Train error", @@ -147,7 +147,7 @@ function plot_perf(perf_storage::NamedTuple) end if length(test_errors) > 0 - plt = lineplot( + plt = lineplot_function( test_errors; xlabel="Epoch", title="Test error", # ylim=(0, maximum(test_errors)) ) @@ -155,7 +155,7 @@ function plot_perf(perf_storage::NamedTuple) end if length(train_cost_gaps) > 0 - plt = lineplot( + plt = lineplot_function( train_cost_gaps; xlabel="Epoch", title="Train cost gap", @@ -165,7 +165,7 @@ function plot_perf(perf_storage::NamedTuple) end if length(train_cost_gaps) > 0 - plt = lineplot( + plt = lineplot_function( test_cost_gaps; xlabel="Epoch", title="Test cost gap", @@ -175,7 +175,7 @@ function plot_perf(perf_storage::NamedTuple) end if length(parameter_errors) > 0 - plt = lineplot( + plt = lineplot_function( parameter_errors; xlabel="Epoch", title="Parameter error", @@ -184,8 +184,5 @@ function plot_perf(perf_storage::NamedTuple) push!(plts, plt) end - for plt in plts - println(plt) - end - return nothing + return plts end diff --git a/test/Manifest.toml b/test/Manifest.toml index 79fe6a8..bc885d6 100644 --- a/test/Manifest.toml +++ b/test/Manifest.toml @@ -16,9 +16,9 @@ version = "1.1.0" [[deps.Accessors]] deps = ["Compat", "CompositionsBase", "ConstructionBase", "Future", "LinearAlgebra", "MacroTools", "Requires", "Test"] -git-tree-sha1 = "2bba2aa45df94e95b1a9c2405d7cfc3d60281db8" +git-tree-sha1 = "0264a938934447408c7f0be8985afec2a2237af4" uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.9" +version = "0.1.11" [[deps.Adapt]] deps = ["LinearAlgebra"] @@ -40,11 +40,17 @@ version = "2.3.0" [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +[[deps.ArnoldiMethod]] +deps = ["LinearAlgebra", "Random", "StaticArrays"] +git-tree-sha1 = "62e51b39331de8911e4a7ff6f5aaf38a5f4cc0ae" +uuid = "ec485272-7323-5ecc-a04f-4719b315124d" +version = "0.2.0" + [[deps.ArrayInterface]] deps = ["Compat", "IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"] -git-tree-sha1 = "c933ce606f6535a7c7b98e1d86d5d1014f730596" +git-tree-sha1 = "81f0cb60dc994ca17f68d9fb7c942a5ae70d9ee4" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "5.0.7" +version = "5.0.8" [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" @@ -76,27 +82,21 @@ uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" version = "1.0.8+0" [[deps.CEnum]] -git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9" +git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.4.1" +version = "0.4.2" [[deps.CUDA]] deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"] -git-tree-sha1 = "ba75320aaa092b3e17c020a2d8b9e0a572dbfa6a" +git-tree-sha1 = "19fb33957a5f85efb3cc10e70cf4dd4e30174ac9" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "3.9.0" - -[[deps.Calculus]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" -uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" -version = "0.5.1" +version = "3.10.0" [[deps.ChainRules]] deps = ["ChainRulesCore", "Compat", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics"] -git-tree-sha1 = "8b887daa6af5daf705081061e36386190204ac87" +git-tree-sha1 = "ab656fb36197083c5817667e76cccd10d11f5c30" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.28.1" +version = "1.32.0" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] @@ -106,15 +106,15 @@ version = "1.14.0" [[deps.ChangesOfVariables]] deps = ["ChainRulesCore", "LinearAlgebra", "Test"] -git-tree-sha1 = "bf98fa45a0a4cee295de98d4c1462be26345b9a1" +git-tree-sha1 = "1e315e3f4b0b7ce40feded39c73049692126cf53" uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" -version = "0.1.2" +version = "0.1.3" [[deps.ColorTypes]] deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597" +git-tree-sha1 = "a985dc37e357a3b22b260a5def99f3530fb415d3" uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.11.0" +version = "0.11.2" [[deps.ColorVectorSpace]] deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "SpecialFunctions", "Statistics", "TensorCore"] @@ -173,15 +173,15 @@ uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" version = "4.1.1" [[deps.DataAPI]] -git-tree-sha1 = "cc70b17275652eb47bc9e5f81635981f13cea5c8" +git-tree-sha1 = "fb5f5316dd3fd4c5e7c30a24d50643b73e37cd40" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.9.0" +version = "1.10.0" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "3daef5523dd2e769dad2365274f760ff5f282c7d" +git-tree-sha1 = "cc1a8e22627f33c789ab60b36a9132ac050bbf75" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.11" +version = "0.18.12" [[deps.DataValueInterfaces]] git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" @@ -201,12 +201,6 @@ version = "0.1.2" deps = ["Mmap"] uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" -[[deps.DensityInterface]] -deps = ["InverseFunctions", "Test"] -git-tree-sha1 = "80c3e8639e3353e5d2912fb3a1916b8455e2494b" -uuid = "b429d917-457f-4dbc-8f4c-0cc954292b1d" -version = "0.4.0" - [[deps.DiffResults]] deps = ["StaticArrays"] git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805" @@ -215,20 +209,14 @@ version = "1.0.3" [[deps.DiffRules]] deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "dd933c4ef7b4c270aacd4eb88fa64c147492acf0" +git-tree-sha1 = "28d605d9a0ac17118fe2c5e9ce0fbb76c3ceb120" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.10.0" +version = "1.11.0" [[deps.Distributed]] deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" -[[deps.Distributions]] -deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"] -git-tree-sha1 = "f206814c860c2a909d2a467af0484d08edd05ee7" -uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.57" - [[deps.DocStringExtensions]] deps = ["LibGit2"] git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b" @@ -245,12 +233,6 @@ version = "0.27.15" deps = ["ArgTools", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -[[deps.DualNumbers]] -deps = ["Calculus", "NaNMath", "SpecialFunctions"] -git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" -uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" -version = "0.6.8" - [[deps.EarCut_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "3f3a2501fa7236e9b911e0f7a588c657e822bb6d" @@ -276,9 +258,9 @@ version = "0.1.1" [[deps.FileIO]] deps = ["Pkg", "Requires", "UUIDs"] -git-tree-sha1 = "80ced645013a5dbdc52cf70329399c35ce007fae" +git-tree-sha1 = "9267e5f50b0e12fdfd5a2455534345c4cf2c7f7a" uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -version = "1.13.0" +version = "1.14.0" [[deps.FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] @@ -306,9 +288,9 @@ version = "0.1.1" [[deps.ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] -git-tree-sha1 = "1bd6fc0c344fc0cbee1f42f8d2e7ec8253dda2d2" +git-tree-sha1 = "89cc49bf5819f0a10a7a3c38885e7c7ee048de57" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.25" +version = "0.10.29" [[deps.FreeType]] deps = ["CEnum", "FreeType2_jll"] @@ -350,9 +332,9 @@ version = "8.3.2" [[deps.GPUCompiler]] deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "556190e1e0ea3e37d83059fc9aa576f1e2104375" +git-tree-sha1 = "05374e47bb136db517b33f62fbe852adf8deb0be" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.14.1" +version = "0.15.1" [[deps.GeometryBasics]] deps = ["EarCut_jll", "IterTools", "LinearAlgebra", "StaticArrays", "StructArrays", "Tables"] @@ -360,11 +342,17 @@ git-tree-sha1 = "83ea630384a13fc4f002b77690bc0afeb4255ac9" uuid = "5c1252a2-5f33-56bf-86c9-59e7332b4326" version = "0.4.2" -[[deps.HypergeometricFunctions]] -deps = ["DualNumbers", "LinearAlgebra", "SpecialFunctions", "Test"] -git-tree-sha1 = "65e4589030ef3c44d3b90bdc5aac462b4bb05567" -uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" -version = "0.3.8" +[[deps.Graphs]] +deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] +git-tree-sha1 = "57c021de207e234108a6f1454003120a1bf350c4" +uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" +version = "1.6.0" + +[[deps.GridGraphs]] +deps = ["DataStructures", "Graphs", "SparseArrays"] +git-tree-sha1 = "b6d33f54428fee0174d0bfae256fdde1d5333594" +uuid = "dd2b58c7-5af7-4f17-9e46-57c68ac813fb" +version = "0.1.2" [[deps.IOCapture]] deps = ["Logging", "Random"] @@ -374,15 +362,20 @@ version = "0.2.2" [[deps.IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "7f43342f8d5fd30ead0ba1b49ab1a3af3b787d24" +git-tree-sha1 = "af14a478780ca78d5eb9908b263023096c2b9d64" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.5" +version = "0.4.6" [[deps.IfElse]] git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" version = "0.1.1" +[[deps.Inflate]] +git-tree-sha1 = "f5fc07d4e706b84f72d54eedcc1c13d92fb0871c" +uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" +version = "0.1.2" + [[deps.InitialValues]] git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" @@ -394,9 +387,9 @@ uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" [[deps.InverseFunctions]] deps = ["Test"] -git-tree-sha1 = "91b5dcf362c5add98049e6c29ee756910b03051d" +git-tree-sha1 = "336cc738f03e069ef2cac55a104eb823455dca75" uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.3" +version = "0.1.4" [[deps.IrrationalConstants]] git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" @@ -433,15 +426,15 @@ version = "0.2.4" [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "c9b86064be5ae0f63e50816a5a90b08c474507ae" +git-tree-sha1 = "c8d47589611803a0f3b4813d9e267cd4e3dbcefb" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "4.9.1" +version = "4.11.1" [[deps.LLVMExtra_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] -git-tree-sha1 = "5558ad3c8972d602451efe9d81c78ec14ef4f5ef" +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"] +git-tree-sha1 = "771bfe376249626d3ca12bcd58ba243d3f961576" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.14+2" +version = "0.0.16+0" [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] @@ -472,23 +465,23 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[deps.LogExpFunctions]] deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "a970d55c2ad8084ca317a4658ba6ce99b7523571" +git-tree-sha1 = "09e4b894ce6a976c354a69041a04748180d43637" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.12" +version = "0.3.15" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[deps.MLStyle]] -git-tree-sha1 = "594e189325f66e23a8818e5beb11c43bb0141bcd" +git-tree-sha1 = "e49789e5eb7b2d5577aaea395bfcac769df64bb8" uuid = "d8e11817-5142-5d16-987a-aa16d5891078" -version = "0.4.10" +version = "0.4.11" [[deps.MLUtils]] deps = ["ChainRulesCore", "DelimitedFiles", "FLoops", "FoldsThreads", "Random", "ShowCases", "Statistics", "StatsBase"] -git-tree-sha1 = "32eeb46fa393ae36a4127c9442ade478c8d01117" +git-tree-sha1 = "95ab49a8c9afb6a8a0fc81df25617a6798c0fb73" uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" -version = "0.2.3" +version = "0.2.5" [[deps.MacroTools]] deps = ["Markdown", "Random"] @@ -530,9 +523,9 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159" [[deps.NNlib]] deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"] -git-tree-sha1 = "a59a614b8b4ea6dc1dcec8c6514e251f13ccbe10" +git-tree-sha1 = "f89de462a7bc3243f95834e75751d70b3a33e59d" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.8.4" +version = "0.8.5" [[deps.NNlibCUDA]] deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"] @@ -541,9 +534,9 @@ uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" version = "0.2.2" [[deps.NaNMath]] -git-tree-sha1 = "b086b7ea07f8e38cf122f5016af580881ac914fe" +git-tree-sha1 = "737a5957f387b17e74d4ad2f440eb330b39a62c5" uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "0.3.7" +version = "1.0.0" [[deps.NameResolution]] deps = ["PrettyPrint"] @@ -570,26 +563,20 @@ version = "0.5.5+0" [[deps.Optimisers]] deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "e440ecef249dea69e79248857e800e71820d386c" +git-tree-sha1 = "2442c3ddbda547c80e8b6451a103719d6a3593dd" uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.2.1" +version = "0.2.4" [[deps.OrderedCollections]] git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" version = "1.4.1" -[[deps.PDMats]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "3114946c67ef9925204cc024a73c9e679cebe0d7" -uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.11.8" - [[deps.Parsers]] deps = ["Dates"] -git-tree-sha1 = "621f4f3b4977325b9128d5fae7a8b4829a0c2222" +git-tree-sha1 = "1285416549ccfcdf0c50d4997a94331e88d68413" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.2.4" +version = "2.3.1" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] @@ -597,9 +584,9 @@ uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" [[deps.Preferences]] deps = ["TOML"] -git-tree-sha1 = "d3538e7f8a790dc8903519090857ef8e1283eecd" +git-tree-sha1 = "47e5f437cc0e7ef2ce8406ce1e7e24d44915f88d" uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.2.5" +version = "1.3.0" [[deps.PrettyPrint]] git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" @@ -622,12 +609,6 @@ git-tree-sha1 = "d7a7aef8f8f2d537104f170139553b14dfe39fe9" uuid = "92933f4c-e287-5a05-a399-4b506db050ca" version = "1.7.2" -[[deps.QuadGK]] -deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "78aadffb3efd2155af139781b8a8df1ef279ea39" -uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.4.2" - [[deps.REPL]] deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" @@ -665,18 +646,6 @@ git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" uuid = "ae029012-a4dd-5104-9daa-d747884805df" version = "1.3.0" -[[deps.Rmath]] -deps = ["Random", "Rmath_jll"] -git-tree-sha1 = "bf3188feca147ce108c76ad82c2792c57abe7b1f" -uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" -version = "0.7.0" - -[[deps.Rmath_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "68db32dff12bb6127bac73c209881191bf0efbb7" -uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.3.0+0" - [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -698,6 +667,12 @@ git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" version = "0.1.0" +[[deps.SimpleTraits]] +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" +uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" +version = "0.9.4" + [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" @@ -725,15 +700,15 @@ version = "0.1.14" [[deps.Static]] deps = ["IfElse"] -git-tree-sha1 = "87e9954dfa33fd145694e42337bdd3d5b07021a6" +git-tree-sha1 = "5309da1cdef03e95b73cd3251ac3a39f887da53e" uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" -version = "0.6.0" +version = "0.6.4" [[deps.StaticArrays]] deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "4f6ec5d99a28e1a749559ef7dd518663c5eca3d5" +git-tree-sha1 = "cd56bf18ed715e8b09f06ef8c6b781e6cdc49911" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.4.3" +version = "1.4.4" [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] @@ -741,9 +716,9 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [[deps.StatsAPI]] deps = ["LinearAlgebra"] -git-tree-sha1 = "8d7530a38dbd2c397be7ddd01a424e4f411dcc41" +git-tree-sha1 = "c82aaa13b44ea00134f8c9c89819477bd3986ecd" uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.2.2" +version = "1.3.0" [[deps.StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] @@ -751,21 +726,11 @@ git-tree-sha1 = "8977b17906b0a1cc74ab2e3a05faa16cf08a8291" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" version = "0.33.16" -[[deps.StatsFuns]] -deps = ["ChainRulesCore", "HypergeometricFunctions", "InverseFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] -git-tree-sha1 = "ca9f8a0c9f2e41431dc5b7697058a3f8f8b89498" -uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "1.0.0" - [[deps.StructArrays]] deps = ["Adapt", "DataAPI", "StaticArrays", "Tables"] -git-tree-sha1 = "57617b34fa34f91d536eb265df67c2d4519b8b98" +git-tree-sha1 = "e75d82493681dfd884a357952bbd7ab0608e1dc3" uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -version = "0.6.5" - -[[deps.SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] -uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" +version = "0.6.7" [[deps.TOML]] deps = ["Dates"] @@ -799,9 +764,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[deps.TimerOutputs]] deps = ["ExprTools", "Printf"] -git-tree-sha1 = "d60b0c96a16aaa42138d5d38ad386df672cb8bd8" +git-tree-sha1 = "7638550aaea1c9a1e86817a231ef0faa9aca79bd" uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.16" +version = "0.5.19" [[deps.Transducers]] deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] diff --git a/test/Project.toml b/test/Project.toml index 7f0fe67..d13589c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,8 +1,9 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" +GridGraphs = "dd2b58c7-5af7-4f17-9e46-57c68ac813fb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/test/argmax.jl b/test/argmax.jl index f16b8a7..a5676b3 100644 --- a/test/argmax.jl +++ b/test/argmax.jl @@ -5,10 +5,6 @@ using LinearAlgebra using Random using Test -Random.seed!(63) - -include("utils.jl") - ## Main functions nb_features = 5 diff --git a/test/jacobian_approx.jl b/test/jacobian_approx.jl index 046ec7f..5e5198f 100644 --- a/test/jacobian_approx.jl +++ b/test/jacobian_approx.jl @@ -1,18 +1,10 @@ -using Distributions using InferOpt using LinearAlgebra using Random using Test using Zygote -Random.seed!(63) - -one_hot_argmax_approximations = [ - Perturbed(one_hot_argmax; ε=0.3, M=1000) - PerturbedGeneric( - one_hot_argmax; noise_dist=θ -> MultivariateNormal(θ, 0.3^2 * I), M=1000 - ) -] +one_hot_argmax_approximations = [Perturbed(one_hot_argmax; ε=0.3, M=5000)] for approx in one_hot_argmax_approximations @testset verbose = true "$approx" begin diff --git a/test/paths.jl b/test/paths.jl index 4caf16a..3b4a61e 100644 --- a/test/paths.jl +++ b/test/paths.jl @@ -1,15 +1,12 @@ using Flux +using Graphs +using GridGraphs using InferOpt using InferOpt.Testing -using InferOpt.GridGraphs using LinearAlgebra using Random using Test -Random.seed!(63) - -include("utils.jl") - ## Main functions nb_features = 5 @@ -18,9 +15,12 @@ true_encoder = Chain(Dense(nb_features, 1), dropfirstdim) cost(y; instance) = dot(y, -true_encoder(instance)) error_function(y1, y2) = Flux.Losses.mse(y1, y2) -function true_maximizer(θ::AbstractMatrix; kwargs...) - g = AcyclicGridGraph(-θ) - return grid_shortest_path(g, 1, nv(g)) +function true_maximizer(θ::AbstractMatrix{R}; kwargs...) where {R<:Real} + g = AcyclicGridGraph{Int,R}(-θ) + shortest_path_tree = GridGraphs.grid_topological_sort(g, 1) + path = GridGraphs.get_path(shortest_path_tree, 1, nv(g)) + y = GridGraphs.path_to_matrix(g, path) + return y end ## Pipelines diff --git a/test/quality.jl b/test/quality.jl deleted file mode 100644 index 7d05cfb..0000000 --- a/test/quality.jl +++ /dev/null @@ -1,9 +0,0 @@ -using Aqua -using InferOpt - -Aqua.test_all( - InferOpt; - deps_compat=false, - project_extras=false, - ambiguities=false -) diff --git a/test/ranking.jl b/test/ranking.jl index 1268d15..38616b5 100644 --- a/test/ranking.jl +++ b/test/ranking.jl @@ -5,10 +5,6 @@ using LinearAlgebra using Random using Test -Random.seed!(63) - -include("utils.jl") - ## Main functions nb_features = 5 diff --git a/test/runtests.jl b/test/runtests.jl index 396559e..c354d19 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,8 +1,18 @@ +using Aqua +using InferOpt using Test +include("utils/pipelines.jl") +include("utils/loop.jl") + @testset verbose = true "InferOpt.jl" begin @testset verbose = true "Code quality (Aqua.jl)" begin - include("quality.jl") + Aqua.test_all( + InferOpt; + deps_compat=true, + project_extras=true, + ambiguities=false + ) end @testset verbose = true "Jacobian approx" begin include("jacobian_approx.jl") diff --git a/test/tutorial.jl b/test/tutorial.jl index b58b4c7..d2c06ba 100644 --- a/test/tutorial.jl +++ b/test/tutorial.jl @@ -16,8 +16,9 @@ We will use `InferOpt` to learn the appropriate weights, so that we may propose =# using Flux +using Graphs +using GridGraphs using InferOpt -using InferOpt.GridGraphs using InferOpt.Testing using LinearAlgebra using ProgressMeter @@ -31,7 +32,7 @@ Random.seed!(63); # ## Grid graphs #= -For the purposes of this tutorial, we consider grid graphs, as implemented in `InferOpt.GridGraphs`. +For the purposes of this tutorial, we consider grid graphs, as implemented in [GridGraphs.jl](https://github.com/gdalle/GridGraphs.jl). In such graphs, each vertex corresponds to a couple of coordinates ``(i, j)``, where ``1 \leq i \leq h`` and ``1 \leq j \leq w``. To ensure acyclicity, we only allow the user to move right, down or both. @@ -42,11 +43,11 @@ h, w = 50, 100 g = AcyclicGridGraph(rand(h, w)); #= -For convenience, `InferOpt.GridGraphs` also provides custom functions to compute shortest paths. -Let us see what those look like. +For convenience, `GridGraphs.jl` also provides custom functions to compute shortest paths efficiently. +Let us see what those paths look like. =# -p = grid_shortest_path(g, 1, nv(g)); +p = path_to_matrix(g, grid_topological_sort(g, 1, nv(g))); spy(p) # ## Dataset @@ -67,7 +68,8 @@ To be consistent with the literature, we frame this problem as a linear maximiza function linear_maximizer(θ) g = AcyclicGridGraph(-θ) - return grid_shortest_path(g, 1, nv(g)) + path = grid_topological_sort(g, 1, nv(g)) + return path_to_matrix(g, path) end; #= diff --git a/test/utils.jl b/test/utils.jl deleted file mode 100644 index 5bf3b86..0000000 --- a/test/utils.jl +++ /dev/null @@ -1,92 +0,0 @@ -using Flux -using InferOpt -using InferOpt.Testing -using ProgressMeter - -function list_standard_pipelines(true_maximizer; nb_features, cost=nothing) - pipelines = Dict{String,Vector}() - - pipelines["θ"] = [( - encoder=Chain(Dense(nb_features, 1), dropfirstdim), - maximizer=identity, - loss=SPOPlusLoss(true_maximizer), - )] - - pipelines["(θ,y)"] = [( - encoder=Chain(Dense(nb_features, 1), dropfirstdim), - maximizer=identity, - loss=SPOPlusLoss(true_maximizer), - )] - - pipelines["y"] = [ - # Perturbations - ( - encoder=Chain(Dense(nb_features, 1), dropfirstdim), - maximizer=identity, - loss=FenchelYoungLoss(Perturbed(true_maximizer; ε=1.0, M=5)), - ), - ( - encoder=Chain(Dense(nb_features, 1), dropfirstdim), - maximizer=Perturbed(true_maximizer; ε=1.0, M=5), - loss=Flux.Losses.mse, - ), - ] - - if !isnothing(cost) - pipelines["none"] = [( - encoder=Chain(Dense(nb_features, 1), dropfirstdim), - maximizer=identity, - loss=PerturbedCost(true_maximizer, cost; ε=1.0, M=5), - )] - end - - return pipelines -end - -function test_loop( - pipelines; - true_encoder, - true_maximizer, - data_train, - data_test, - error_function, - cost, - epochs, - show_plots, - setting_name="???", -) - pipelines = deepcopy(pipelines) - - for target in keys(pipelines), pipeline in pipelines[target] - (; encoder, maximizer, loss) = pipeline - flux_loss = define_flux_loss(encoder, maximizer, loss, target) - @info "Testing $setting_name" target encoder maximizer loss - - ## Optimization - - opt = ADAM() - perf_storage = init_perf() - - @showprogress for _ in 1:epochs - update_perf!( - perf_storage; - data_train=data_train, - data_test=data_test, - true_encoder=true_encoder, - encoder=encoder, - true_maximizer=true_maximizer, - flux_loss=flux_loss, - error_function=error_function, - cost=cost, - ) - Flux.train!(flux_loss, Flux.params(encoder), zip(data_train...), opt) - end - - ## Evaluation - - if show_plots - plot_perf(perf_storage) - end - test_perf(perf_storage; test_name="$target - $maximizer - $loss") - end -end diff --git a/test/utils/loop.jl b/test/utils/loop.jl new file mode 100644 index 0000000..48b27eb --- /dev/null +++ b/test/utils/loop.jl @@ -0,0 +1,56 @@ +using Flux +using InferOpt +using InferOpt.Testing +using ProgressMeter +using UnicodePlots + +function test_loop( + pipelines; + true_encoder, + true_maximizer, + data_train, + data_test, + error_function, + cost, + epochs, + show_plots, + setting_name="???", +) + pipelines = deepcopy(pipelines) + + for target in keys(pipelines), pipeline in pipelines[target] + (; encoder, maximizer, loss) = pipeline + pipeline_loss = define_pipeline_loss(encoder, maximizer, loss, target) + @info "Testing $setting_name" target encoder maximizer loss + + ## Optimization + + opt = ADAM() + perf_storage = init_perf() + + @showprogress for _ in 1:epochs + update_perf!( + perf_storage; + data_train=data_train, + data_test=data_test, + true_encoder=true_encoder, + encoder=encoder, + true_maximizer=true_maximizer, + pipeline_loss=pipeline_loss, + error_function=error_function, + cost=cost, + ) + Flux.train!(pipeline_loss, Flux.params(encoder), zip(data_train...), opt) + end + + ## Evaluation + + if show_plots + plts = plot_perf(perf_storage; lineplot_function = lineplot) + for plt in plts + println(plt) + end + end + test_perf(perf_storage; test_name="$target - $maximizer - $loss") + end +end diff --git a/test/utils/pipelines.jl b/test/utils/pipelines.jl new file mode 100644 index 0000000..d52d5b5 --- /dev/null +++ b/test/utils/pipelines.jl @@ -0,0 +1,43 @@ +using Flux +using InferOpt +using InferOpt.Testing + +function list_standard_pipelines(true_maximizer; nb_features, cost=nothing) + pipelines = Dict{String,Vector}() + + pipelines["θ"] = [( + encoder=Chain(Dense(nb_features, 1), dropfirstdim), + maximizer=identity, + loss=SPOPlusLoss(true_maximizer), + )] + + pipelines["(θ,y)"] = [( + encoder=Chain(Dense(nb_features, 1), dropfirstdim), + maximizer=identity, + loss=SPOPlusLoss(true_maximizer), + )] + + pipelines["y"] = [ + # Perturbations + ( + encoder=Chain(Dense(nb_features, 1), dropfirstdim), + maximizer=identity, + loss=FenchelYoungLoss(Perturbed(true_maximizer; ε=1.0, M=5)), + ), + ( + encoder=Chain(Dense(nb_features, 1), dropfirstdim), + maximizer=Perturbed(true_maximizer; ε=1.0, M=5), + loss=Flux.Losses.mse, + ), + ] + + if !isnothing(cost) + pipelines["none"] = [( + encoder=Chain(Dense(nb_features, 1), dropfirstdim), + maximizer=identity, + loss=PerturbedCost(true_maximizer, cost; ε=1.0, M=5), + )] + end + + return pipelines +end