From 8e9cc2daed7d23dac5509cd258abc0c653beae92 Mon Sep 17 00:00:00 2001 From: confoundry <107474190+confoundry@users.noreply.github.com> Date: Fri, 28 Apr 2023 18:58:02 +0100 Subject: [PATCH] Release 0.3.2 (#44) --- poetry.lock | 138 +++++++++--------- pyproject.toml | 3 +- .../adjacency/directed_acyclic.py | 6 +- src/causica/distributions/adjacency/enco.py | 5 +- .../distributions/adjacency/three_way.py | 3 +- src/causica/distributions/noise/bernoulli.py | 9 ++ 6 files changed, 89 insertions(+), 75 deletions(-) diff --git a/poetry.lock b/poetry.lock index cdeff18..c69c337 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.0 and should not be changed by hand. [[package]] name = "absl-py" @@ -651,63 +651,63 @@ test-no-images = ["pytest"] [[package]] name = "coverage" -version = "7.2.3" +version = "7.2.4" description = "Code coverage measurement for Python" category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "coverage-7.2.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e58c0d41d336569d63d1b113bd573db8363bc4146f39444125b7f8060e4e04f5"}, - {file = "coverage-7.2.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:344e714bd0fe921fc72d97404ebbdbf9127bac0ca1ff66d7b79efc143cf7c0c4"}, - {file = "coverage-7.2.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:974bc90d6f6c1e59ceb1516ab00cf1cdfbb2e555795d49fa9571d611f449bcb2"}, - {file = "coverage-7.2.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0743b0035d4b0e32bc1df5de70fba3059662ace5b9a2a86a9f894cfe66569013"}, - {file = "coverage-7.2.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d0391fb4cfc171ce40437f67eb050a340fdbd0f9f49d6353a387f1b7f9dd4fa"}, - {file = "coverage-7.2.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4a42e1eff0ca9a7cb7dc9ecda41dfc7cbc17cb1d02117214be0561bd1134772b"}, - {file = "coverage-7.2.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:be19931a8dcbe6ab464f3339966856996b12a00f9fe53f346ab3be872d03e257"}, - {file = "coverage-7.2.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:72fcae5bcac3333a4cf3b8f34eec99cea1187acd55af723bcbd559adfdcb5535"}, - {file = "coverage-7.2.3-cp310-cp310-win32.whl", hash = "sha256:aeae2aa38395b18106e552833f2a50c27ea0000122bde421c31d11ed7e6f9c91"}, - {file = "coverage-7.2.3-cp310-cp310-win_amd64.whl", hash = "sha256:83957d349838a636e768251c7e9979e899a569794b44c3728eaebd11d848e58e"}, - {file = "coverage-7.2.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:dfd393094cd82ceb9b40df4c77976015a314b267d498268a076e940fe7be6b79"}, - {file = "coverage-7.2.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:182eb9ac3f2b4874a1f41b78b87db20b66da6b9cdc32737fbbf4fea0c35b23fc"}, - {file = "coverage-7.2.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1bb1e77a9a311346294621be905ea8a2c30d3ad371fc15bb72e98bfcfae532df"}, - {file = "coverage-7.2.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca0f34363e2634deffd390a0fef1aa99168ae9ed2af01af4a1f5865e362f8623"}, - {file = "coverage-7.2.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:55416d7385774285b6e2a5feca0af9652f7f444a4fa3d29d8ab052fafef9d00d"}, - {file = "coverage-7.2.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:06ddd9c0249a0546997fdda5a30fbcb40f23926df0a874a60a8a185bc3a87d93"}, - {file = "coverage-7.2.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:fff5aaa6becf2c6a1699ae6a39e2e6fb0672c2d42eca8eb0cafa91cf2e9bd312"}, - {file = "coverage-7.2.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ea53151d87c52e98133eb8ac78f1206498c015849662ca8dc246255265d9c3c4"}, - {file = "coverage-7.2.3-cp311-cp311-win32.whl", hash = "sha256:8f6c930fd70d91ddee53194e93029e3ef2aabe26725aa3c2753df057e296b925"}, - {file = "coverage-7.2.3-cp311-cp311-win_amd64.whl", hash = "sha256:fa546d66639d69aa967bf08156eb8c9d0cd6f6de84be9e8c9819f52ad499c910"}, - {file = "coverage-7.2.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b2317d5ed777bf5a033e83d4f1389fd4ef045763141d8f10eb09a7035cee774c"}, - {file = "coverage-7.2.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:be9824c1c874b73b96288c6d3de793bf7f3a597770205068c6163ea1f326e8b9"}, - {file = "coverage-7.2.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c3b2803e730dc2797a017335827e9da6da0e84c745ce0f552e66400abdfb9a1"}, - {file = "coverage-7.2.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f69770f5ca1994cb32c38965e95f57504d3aea96b6c024624fdd5bb1aa494a1"}, - {file = "coverage-7.2.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:1127b16220f7bfb3f1049ed4a62d26d81970a723544e8252db0efde853268e21"}, - {file = "coverage-7.2.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:aa784405f0c640940595fa0f14064d8e84aff0b0f762fa18393e2760a2cf5841"}, - {file = "coverage-7.2.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:3146b8e16fa60427e03884301bf8209221f5761ac754ee6b267642a2fd354c48"}, - {file = "coverage-7.2.3-cp37-cp37m-win32.whl", hash = "sha256:1fd78b911aea9cec3b7e1e2622c8018d51c0d2bbcf8faaf53c2497eb114911c1"}, - {file = "coverage-7.2.3-cp37-cp37m-win_amd64.whl", hash = "sha256:0f3736a5d34e091b0a611964c6262fd68ca4363df56185902528f0b75dbb9c1f"}, - {file = "coverage-7.2.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:981b4df72c93e3bc04478153df516d385317628bd9c10be699c93c26ddcca8ab"}, - {file = "coverage-7.2.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c0045f8f23a5fb30b2eb3b8a83664d8dc4fb58faddf8155d7109166adb9f2040"}, - {file = "coverage-7.2.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f760073fcf8f3d6933178d67754f4f2d4e924e321f4bb0dcef0424ca0215eba1"}, - {file = "coverage-7.2.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c86bd45d1659b1ae3d0ba1909326b03598affbc9ed71520e0ff8c31a993ad911"}, - {file = "coverage-7.2.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:172db976ae6327ed4728e2507daf8a4de73c7cc89796483e0a9198fd2e47b462"}, - {file = "coverage-7.2.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:d2a3a6146fe9319926e1d477842ca2a63fe99af5ae690b1f5c11e6af074a6b5c"}, - {file = "coverage-7.2.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:f649dd53833b495c3ebd04d6eec58479454a1784987af8afb77540d6c1767abd"}, - {file = "coverage-7.2.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:7c4ed4e9f3b123aa403ab424430b426a1992e6f4c8fd3cb56ea520446e04d152"}, - {file = "coverage-7.2.3-cp38-cp38-win32.whl", hash = "sha256:eb0edc3ce9760d2f21637766c3aa04822030e7451981ce569a1b3456b7053f22"}, - {file = "coverage-7.2.3-cp38-cp38-win_amd64.whl", hash = "sha256:63cdeaac4ae85a179a8d6bc09b77b564c096250d759eed343a89d91bce8b6367"}, - {file = "coverage-7.2.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:20d1a2a76bb4eb00e4d36b9699f9b7aba93271c9c29220ad4c6a9581a0320235"}, - {file = "coverage-7.2.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ea748802cc0de4de92ef8244dd84ffd793bd2e7be784cd8394d557a3c751e21"}, - {file = "coverage-7.2.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:21b154aba06df42e4b96fc915512ab39595105f6c483991287021ed95776d934"}, - {file = "coverage-7.2.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd214917cabdd6f673a29d708574e9fbdb892cb77eb426d0eae3490d95ca7859"}, - {file = "coverage-7.2.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c2e58e45fe53fab81f85474e5d4d226eeab0f27b45aa062856c89389da2f0d9"}, - {file = "coverage-7.2.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:87ecc7c9a1a9f912e306997ffee020297ccb5ea388421fe62a2a02747e4d5539"}, - {file = "coverage-7.2.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:387065e420aed3c71b61af7e82c7b6bc1c592f7e3c7a66e9f78dd178699da4fe"}, - {file = "coverage-7.2.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ea3f5bc91d7d457da7d48c7a732beaf79d0c8131df3ab278e6bba6297e23c6c4"}, - {file = "coverage-7.2.3-cp39-cp39-win32.whl", hash = "sha256:ae7863a1d8db6a014b6f2ff9c1582ab1aad55a6d25bac19710a8df68921b6e30"}, - {file = "coverage-7.2.3-cp39-cp39-win_amd64.whl", hash = "sha256:3f04becd4fcda03c0160d0da9c8f0c246bc78f2f7af0feea1ec0930e7c93fa4a"}, - {file = "coverage-7.2.3-pp37.pp38.pp39-none-any.whl", hash = "sha256:965ee3e782c7892befc25575fa171b521d33798132692df428a09efacaffe8d0"}, - {file = "coverage-7.2.3.tar.gz", hash = "sha256:d298c2815fa4891edd9abe5ad6e6cb4207104c7dd9fd13aea3fdebf6f9b91259"}, + {file = "coverage-7.2.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9e5eedde6e6e241ec3816f05767cc77e7456bf5ec6b373fb29917f0990e2078f"}, + {file = "coverage-7.2.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5c6c6e3b8fb6411a2035da78d86516bfcfd450571d167304911814407697fb7a"}, + {file = "coverage-7.2.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f7668a621afc52db29f6867e0e9c72a1eec9f02c94a7c36599119d557cf6e471"}, + {file = "coverage-7.2.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cdfb53bef4b2739ff747ebbd76d6ac5384371fd3c7a8af08899074eba034d483"}, + {file = "coverage-7.2.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a5c4f2e44a2ae15fa6883898e756552db5105ca4bd918634cbd5b7c00e19e8a1"}, + {file = "coverage-7.2.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:700bc9fb1074e0c67c09fe96a803de66663830420781df8dc9fb90d7421d4ccb"}, + {file = "coverage-7.2.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ac4861241e693e21b280f07844ae0e0707665e1dfcbf9466b793584984ae45c4"}, + {file = "coverage-7.2.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3d6f3c5b6738a494f17c73b4aa3aa899865cc33a74aa85e3b5695943b79ad3ce"}, + {file = "coverage-7.2.4-cp310-cp310-win32.whl", hash = "sha256:437da7d2fcc35bf45e04b7e9cfecb7c459ec6f6dc17a8558ed52e8d666c2d9ab"}, + {file = "coverage-7.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:1d3893f285fd76f56651f04d1efd3bdce251c32992a64c51e5d6ec3ba9e3f9c9"}, + {file = "coverage-7.2.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6a17bf32e9e3333d78606ac1073dd20655dc0752d5b923fa76afd3bc91674ab4"}, + {file = "coverage-7.2.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f7ffdb3af2a01ce91577f84fc0faa056029fe457f3183007cffe7b11ea78b23c"}, + {file = "coverage-7.2.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89e63b38c7b888e00fd42ce458f838dccb66de06baea2da71801b0fc9070bfa0"}, + {file = "coverage-7.2.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4522dd9aeb9cc2c4c54ce23933beb37a4e106ec2ba94f69138c159024c8a906a"}, + {file = "coverage-7.2.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:29c7d88468f01a75231797173b52dc66d20a8d91b8bb75c88fc5861268578f52"}, + {file = "coverage-7.2.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:bc47015fc0455753e8aba1f38b81b731aaf7f004a0c390b404e0fcf1d6c1d72f"}, + {file = "coverage-7.2.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:5c122d120c11a236558c339a59b4b60947b38ac9e3ad30a0e0e02540b37bf536"}, + {file = "coverage-7.2.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:50fda3d33b705b9c01e3b772cfa7d14de8aec2ec2870e4320992c26d057fde12"}, + {file = "coverage-7.2.4-cp311-cp311-win32.whl", hash = "sha256:ab08af91cf4d847a6e15d7d5eeae5fead1487caf16ff3a2056dbe64d058fd246"}, + {file = "coverage-7.2.4-cp311-cp311-win_amd64.whl", hash = "sha256:876e4ef3eff00b50787867c5bae84857a9af4c369a9d5b266cd9b19f61e48ef7"}, + {file = "coverage-7.2.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:3fc9cde48de956bfbacea026936fbd4974ff1dc2f83397c6f1968f0142c9d50b"}, + {file = "coverage-7.2.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12bc9127c8aca2f7c25c9acca53da3db6799b2999b40f28c2546237b7ea28459"}, + {file = "coverage-7.2.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2857894c22833d3da6e113623a9b7440159b2295280b4e0d954cadbfa724b85a"}, + {file = "coverage-7.2.4-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4db4e6c115d869cd5397d3d21fd99e4c7053205c33a4ae725c90d19dcd178af"}, + {file = "coverage-7.2.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:f37ae1804596f13d811e0247ffc8219f5261b3565bdf45fcbb4fc091b8e9ff35"}, + {file = "coverage-7.2.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:cdee9a77fd0ce000781680b6a1f4b721c567f66f2f73a49be1843ff439d634f3"}, + {file = "coverage-7.2.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:0b65a6a5484b7f2970393d6250553c05b2ede069e0e18abe907fdc7f3528252e"}, + {file = "coverage-7.2.4-cp37-cp37m-win32.whl", hash = "sha256:1a3e8697cb40f28e5bcfb6f4bda7852d96dbb6f6fd7cc306aba4ae690c9905ab"}, + {file = "coverage-7.2.4-cp37-cp37m-win_amd64.whl", hash = "sha256:4078939c4b7053e14e87c65aa68dbed7867e326e450f94038bfe1a1b22078ff9"}, + {file = "coverage-7.2.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:603a2b172126e3b08c11ca34200143089a088cd0297d4cfc4922d2c1c3a892f9"}, + {file = "coverage-7.2.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:72751d117ceaad3b1ea3bcb9e85f5409bbe9fb8a40086e17333b994dbccc0718"}, + {file = "coverage-7.2.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f19ba9301e6fb0b94ba71fda9a1b02d11f0aab7f8e2455122a4e2921b6703c2f"}, + {file = "coverage-7.2.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d784177a7fb9d0f58d24d3e60638c8b729c3693963bf67fa919120f750db237"}, + {file = "coverage-7.2.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d2a9180beff1922b09bd7389e23454928e108449e646c26da5c62e29b0bf4e3"}, + {file = "coverage-7.2.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:39747afc854a7ee14e5e132da7db179d6281faf97dc51e6d7806651811c47538"}, + {file = "coverage-7.2.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:60feb703abc8d78e9427d873bcf924c9e30cf540a21971ef5a17154da763b60f"}, + {file = "coverage-7.2.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:c2becddfcbf3d994a8f4f9dd2b6015cae3a3eff50dedc6e4a17c3cccbe8f93d4"}, + {file = "coverage-7.2.4-cp38-cp38-win32.whl", hash = "sha256:56a674ad18d6b04008283ca03c012be913bf89d91c0803c54c24600b300d9e51"}, + {file = "coverage-7.2.4-cp38-cp38-win_amd64.whl", hash = "sha256:ab08e03add2cf5793e66ac1bbbb24acfa90c125476f5724f5d44c56eeec1d635"}, + {file = "coverage-7.2.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:92b565c51732ea2e7e541709ccce76391b39f4254260e5922e08e00971e88e33"}, + {file = "coverage-7.2.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8769a67e8816c7e94d5bf446fc0501641fde78fdff362feb28c2c64d45d0e9b1"}, + {file = "coverage-7.2.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56d74d6fbd5a98a5629e8467b719b0abea9ca01a6b13555d125c84f8bf4ea23d"}, + {file = "coverage-7.2.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d9f770c6052d9b5c9b0e824fd8c003fe33276473b65b4f10ece9565ceb62438e"}, + {file = "coverage-7.2.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3023ce23e41a6f006c09f7e6d62b6c069c36bdc9f7de16a5ef823acc02e6c63"}, + {file = "coverage-7.2.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:fabd1f4d12dfd6b4f309208c2f31b116dc5900e0b42dbafe4ee1bc7c998ffbb0"}, + {file = "coverage-7.2.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:e41a7f44e73b37c6f0132ecfdc1c8b67722f42a3d9b979e6ebc150c8e80cf13a"}, + {file = "coverage-7.2.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:864e36947289be05abd83267c4bade35e772526d3e9653444a9dc891faf0d698"}, + {file = "coverage-7.2.4-cp39-cp39-win32.whl", hash = "sha256:ea534200efbf600e60130c48552f99f351cae2906898a9cd924c1c7f2fb02853"}, + {file = "coverage-7.2.4-cp39-cp39-win_amd64.whl", hash = "sha256:00f8fd8a5fe1ffc3aef78ea2dbf553e5c0f4664324e878995e38d41f037eb2b3"}, + {file = "coverage-7.2.4-pp37.pp38.pp39-none-any.whl", hash = "sha256:856bcb837e96adede31018a0854ce7711a5d6174db1a84e629134970676c54fa"}, + {file = "coverage-7.2.4.tar.gz", hash = "sha256:7283f78d07a201ac7d9dc2ac2e4faaea99c4d302f243ee5b4e359f3e170dc008"}, ] [package.dependencies] @@ -1423,14 +1423,14 @@ files = [ [[package]] name = "jsonargparse" -version = "4.21.0" +version = "4.20.1" description = "Implement minimal boilerplate CLIs derived from type hints and parse from command line, config files and environment variables." category = "main" optional = false python-versions = ">=3.6" files = [ - {file = "jsonargparse-4.21.0-py3-none-any.whl", hash = "sha256:4d2ed32e3e477940c8e142282865d81bb920175ab7f862351b3635a56c38c2b7"}, - {file = "jsonargparse-4.21.0.tar.gz", hash = "sha256:37dd3870ec2b5faedb6842407ddb636d0d236cd377f7102bc8017329f2d65e86"}, + {file = "jsonargparse-4.20.1-py3-none-any.whl", hash = "sha256:416f173b37644745cc5b18bd14c5a6459b37bb085a4a76f6bfed91dd18fe8745"}, + {file = "jsonargparse-4.20.1.tar.gz", hash = "sha256:3000f06db29f094ff360b020a766fb634f9786b68715e1563d376db51c1f4801"}, ] [package.dependencies] @@ -1439,21 +1439,21 @@ PyYAML = ">=3.13" typeshed-client = {version = ">=2.1.0", optional = true, markers = "extra == \"signatures\""} [package.extras] -all = ["jsonargparse[argcomplete]", "jsonargparse[fsspec]", "jsonargparse[jsonnet]", "jsonargparse[jsonschema]", "jsonargparse[omegaconf]", "jsonargparse[reconplogger]", "jsonargparse[ruyaml]", "jsonargparse[signatures]", "jsonargparse[typing-extensions]", "jsonargparse[urls]"] +all = ["argcomplete (>=2.0.0)", "docstring-parser (>=0.15)", "fsspec (>=0.8.4)", "jsonnet (>=0.13.0)", "jsonnet-binary (>=0.17.0)", "jsonschema (>=3.2.0)", "omegaconf (>=2.1.1)", "reconplogger (>=4.4.0)", "requests (>=2.18.4)", "ruyaml (>=0.20.0)", "typeshed-client (>=2.1.0)", "typing-extensions (>=3.10.0.0)"] argcomplete = ["argcomplete (>=2.0.0)"] -dev = ["jsonargparse[doc]", "jsonargparse[mypy]", "jsonargparse[pylint]", "jsonargparse[test]", "pre-commit (>=2.19.0)", "pycodestyle (>=2.5.0)", "tox (>=3.25.0)"] +dev = ["Sphinx (>=1.7.9)", "autodocsumm (>=0.1.10)", "coverage (>=4.5.1)", "mypy (>=0.701)", "pre-commit (>=2.19.0)", "pycodestyle (>=2.5.0)", "pylint (>=2.15.6)", "responses (>=0.12.0)", "sphinx-autodoc-typehints (>=1.19.5)", "sphinx-rtd-theme (>=0.4.3)", "tox (>=3.25.0)", "types-PyYAML (>=6.0.11)", "types-requests (>=2.28.9)"] doc = ["Sphinx (>=1.7.9)", "autodocsumm (>=0.1.10)", "sphinx-autodoc-typehints (>=1.19.5)", "sphinx-rtd-theme (>=0.4.3)"] fsspec = ["fsspec (>=0.8.4)"] jsonnet = ["jsonnet (>=0.13.0)", "jsonnet-binary (>=0.17.0)"] jsonschema = ["jsonschema (>=3.2.0)"] maintainer = ["bump2version (>=0.5.11)"] -mypy = ["jsonargparse[types-pyyaml]", "mypy (>=0.701)"] +mypy = ["mypy (>=0.701)", "types-PyYAML (>=6.0.11)"] omegaconf = ["omegaconf (>=2.1.1)"] pylint = ["pylint (>=2.15.6)"] reconplogger = ["reconplogger (>=4.4.0)"] ruyaml = ["ruyaml (>=0.20.0)"] -signatures = ["docstring-parser (>=0.15)", "jsonargparse[typing-extensions]", "typeshed-client (>=2.1.0)"] -test = ["attrs (>=22.2.0)", "jsonargparse[test-no-urls]", "jsonargparse[types-pyyaml]", "pydantic (>=1.10.7)", "responses (>=0.12.0)", "types-requests (>=2.28.9)"] +signatures = ["docstring-parser (>=0.15)", "typeshed-client (>=2.1.0)"] +test = ["coverage (>=4.5.1)", "responses (>=0.12.0)", "types-PyYAML (>=6.0.11)", "types-requests (>=2.28.9)"] test-no-urls = ["coverage (>=4.5.1)"] types-pyyaml = ["types-PyYAML (>=6.0.11)"] typing-extensions = ["typing-extensions (>=3.10.0.0)"] @@ -1857,14 +1857,14 @@ files = [ [[package]] name = "mlflow" -version = "2.3.0" +version = "2.3.1" description = "MLflow: A Platform for ML Development and Productionization" category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "mlflow-2.3.0-py3-none-any.whl", hash = "sha256:b87c8f692e3a0661357eb94d149cf8bd4d4d5999d1f0481dd637d0bb02df4790"}, - {file = "mlflow-2.3.0.tar.gz", hash = "sha256:2a7ba60a2c790c7ea742f486838706d586fab701fec308ea4c732f4bbadd8409"}, + {file = "mlflow-2.3.1-py3-none-any.whl", hash = "sha256:699c512d659c7463a498e087c5f74d3d139b5708cf6aaaccfa398d7b0c095204"}, + {file = "mlflow-2.3.1.tar.gz", hash = "sha256:63439397b2718ce5747288ef5475f46b3716b370a517be3e3c67b799a247a186"}, ] [package.dependencies] @@ -1902,18 +1902,18 @@ waitress = {version = "<3", markers = "platform_system == \"Windows\""} [package.extras] aliyun-oss = ["aliyunstoreplugin"] databricks = ["azure-storage-file-datalake (>12)", "boto3 (>1)", "google-cloud-storage (>=1.30.0)"] -extras = ["azureml-core (>=1.2.0)", "boto3", "google-cloud-storage (>=1.30.0)", "kubernetes", "mlserver (>=1.2.0,<1.3)", "mlserver-mlflow (>=1.2.0,<1.3)", "prometheus-flask-exporter", "pyarrow", "pysftp", "requests-auth-aws-sigv4", "virtualenv"] +extras = ["azureml-core (>=1.2.0)", "boto3", "google-cloud-storage (>=1.30.0)", "kubernetes", "mlserver (>=1.2.0,!=1.3.1)", "mlserver-mlflow (>=1.2.0,!=1.3.1)", "prometheus-flask-exporter", "pyarrow", "pysftp", "requests-auth-aws-sigv4", "virtualenv"] sqlserver = ["mlflow-dbstore"] [[package]] name = "mlflow-skinny" -version = "2.3.0" +version = "2.3.1" description = "MLflow: A Platform for ML Development and Productionization" category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "mlflow_skinny-2.3.0-py3-none-any.whl", hash = "sha256:735700c83d2f55bb42658f4d52c4ed86c065422f52647e7054dd2c1207d83fb2"}, + {file = "mlflow_skinny-2.3.1-py3-none-any.whl", hash = "sha256:33ba9668ff027af8ef865ecaf9f4984e113d46e2a9bcd93b38b326c237e5b13c"}, ] [package.dependencies] @@ -1933,7 +1933,7 @@ sqlparse = ">=0.4.0,<1" [package.extras] aliyun-oss = ["aliyunstoreplugin"] databricks = ["azure-storage-file-datalake (>12)", "boto3 (>1)", "google-cloud-storage (>=1.30.0)"] -extras = ["azureml-core (>=1.2.0)", "boto3", "google-cloud-storage (>=1.30.0)", "kubernetes", "mlserver (>=1.2.0,<1.3)", "mlserver-mlflow (>=1.2.0,<1.3)", "prometheus-flask-exporter", "pyarrow", "pysftp", "requests-auth-aws-sigv4", "virtualenv"] +extras = ["azureml-core (>=1.2.0)", "boto3", "google-cloud-storage (>=1.30.0)", "kubernetes", "mlserver (>=1.2.0,!=1.3.1)", "mlserver-mlflow (>=1.2.0,!=1.3.1)", "prometheus-flask-exporter", "pyarrow", "pysftp", "requests-auth-aws-sigv4", "virtualenv"] sqlserver = ["mlflow-dbstore"] [[package]] @@ -3757,4 +3757,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "~3.10" -content-hash = "c8f701531aae824c64700c7a1874ae36b24634ebc4752b50b2247aa1cee221fc" +content-hash = "4e75da1d01bae2bd1675a51b22aa54241002acfc7634aef13e7b2f039d6154ed" diff --git a/pyproject.toml b/pyproject.toml index 117f96a..536d619 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "causica" -version = "0.3.1" +version = "0.3.2" description = "" readme = "README.md" authors = [] @@ -18,6 +18,7 @@ numpy = "^1.22.4" pandas = "^1.4.2" tensorboard = "^2.9.0" pytorch-lightning = {version = "^1.9.0", extras= ["extra"]} +jsonargparse = "<4.21.0" # 4.21.0 breaks lightning cli dataclasses-json = "^0.5.7" types-PyYAML = "^6.0.12.2" tensordict = "^0.1.0" diff --git a/src/causica/distributions/adjacency/directed_acyclic.py b/src/causica/distributions/adjacency/directed_acyclic.py index 853afb3..92474ea 100644 --- a/src/causica/distributions/adjacency/directed_acyclic.py +++ b/src/causica/distributions/adjacency/directed_acyclic.py @@ -72,13 +72,13 @@ def mode(self) -> torch.Tensor: We return the mode corresponding to the "default" ordering. There are 2 possibilities: - p >= 0.5: A lower triangular matrix of ones - p < 0.5: A matrix of zeros + p > 0.5: A lower triangular matrix of ones + p <= 0.5: A matrix of zeros Returns: A tensor of shape batch_shape + (num_nodes, num_nodes) """ - return fill_triangular(self.bern_dist.mode) + return fill_triangular(torch.nan_to_num(self.bern_dist.mode, nan=0.0)) def log_prob(self, value: torch.Tensor) -> torch.Tensor: raise NotImplementedError diff --git a/src/causica/distributions/adjacency/enco.py b/src/causica/distributions/adjacency/enco.py index e900028..d8c2b7d 100644 --- a/src/causica/distributions/adjacency/enco.py +++ b/src/causica/distributions/adjacency/enco.py @@ -144,7 +144,10 @@ def mode(self) -> torch.Tensor: A tensor of shape batch_shape + (num_nodes, num_nodes) """ logits = self._get_independent_bernoulli_logits() - return self.base_dist(logits).mode * (1.0 - torch.eye(self.num_nodes, device=logits.device)) + # bernoulli mode can be nan for very small logits, favour sparseness and set to 0 + return torch.nan_to_num(self.base_dist(logits).mode, nan=0.0) * ( + 1.0 - torch.eye(self.num_nodes, device=logits.device) + ) def log_prob(self, value: torch.Tensor) -> torch.Tensor: """ diff --git a/src/causica/distributions/adjacency/three_way.py b/src/causica/distributions/adjacency/three_way.py index ad1a08e..09d9cae 100644 --- a/src/causica/distributions/adjacency/three_way.py +++ b/src/causica/distributions/adjacency/three_way.py @@ -92,7 +92,8 @@ def mode(self) -> torch.Tensor: Returns: A tensor of shape batch_shape + (num_nodes, num_nodes) """ - return _triangular_vec_to_matrix(self.base_dist(self.logits).mode) + # bernoulli mode can be nan for very small logits, favour sparseness and set to 0 + return _triangular_vec_to_matrix(torch.nan_to_num(self.base_dist(self.logits).mode, 0.0)) def log_prob(self, value: torch.Tensor) -> torch.Tensor: """ diff --git a/src/causica/distributions/noise/bernoulli.py b/src/causica/distributions/noise/bernoulli.py index 4d4bd96..8693e1b 100644 --- a/src/causica/distributions/noise/bernoulli.py +++ b/src/causica/distributions/noise/bernoulli.py @@ -59,6 +59,15 @@ def noise_to_sample(self, noise: torch.Tensor) -> torch.Tensor: """ return ((self.delta_logits + noise) > 0).float() + @property + def mode(self): + """ + Override the default `mode` method to prevent it returning nan's. + + We favour sparseness, so if logit == 0, set the mode to be zero. + """ + return (self.logits > 0).to(self.logits) + class BernoulliNoiseModule(NoiseModule[IndependentNoise[BernoulliNoise]]): """Represents a BernoulliNoise distribution with learnable logits."""