From 6e6b7fdb193d3a01241f1aaeddd056fd1ed923b4 Mon Sep 17 00:00:00 2001 From: David Seddon Date: Mon, 9 Dec 2024 18:05:44 +0000 Subject: [PATCH] WIP find_shortest_chains --- rust/src/graph.rs | 20 +++++++++++--------- rust/src/lib.rs | 26 ++++++++++++++++++++++++++ src/grimp/adaptors/rustgraph.py | 2 +- 3 files changed, 38 insertions(+), 10 deletions(-) diff --git a/rust/src/graph.rs b/rust/src/graph.rs index 434015c4..16f89039 100644 --- a/rust/src/graph.rs +++ b/rust/src/graph.rs @@ -12,8 +12,8 @@ find_downstream_modules - DONE find_shortest_chain - DONE find_shortest_chains - TODO chain_exists - DONE - find_illegal_dependencies_for_layers - TODO -add_module - DONE - need to add is_squashed +find_illegal_dependencies_for_layers - DONE +add_module - DONE remove_module - DONE add_import - DONE remove_import - DONE @@ -577,19 +577,21 @@ impl Graph { &self, importer: &Module, imported: &Module, + as_packages: bool, ) -> HashSet> { let mut chains = HashSet::new(); let mut importer_modules: HashSet<&Module> = HashSet::from([importer]); - // TODO don't do this if module is squashed? - for descendant in self.find_descendants(&importer).unwrap() { - importer_modules.insert(descendant); - } - let mut imported_modules: HashSet<&Module> = HashSet::from([imported]); + // TODO don't do this if module is squashed? - for descendant in self.find_descendants(&imported).unwrap() { - imported_modules.insert(descendant); + if as_packages { + for descendant in self.find_descendants(&importer).unwrap() { + importer_modules.insert(descendant); + } + for descendant in self.find_descendants(&imported).unwrap() { + imported_modules.insert(descendant); + } } // TODO - Error if modules have shared descendants. diff --git a/rust/src/lib.rs b/rust/src/lib.rs index c646e992..13d7d018 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -246,6 +246,32 @@ impl GraphWrapper { Some(chain.iter().map(|module| module.name.clone()).collect()) } + #[pyo3(signature = (importer, imported, as_packages=true))] + pub fn find_shortest_chains<'py>( + &self, + py: Python<'py>, + importer: &str, + imported: &str, + as_packages: bool, + ) -> Bound<'py, PySet> { + let rust_chains: HashSet> = self._graph.find_shortest_chains( + &Module::new(importer.to_string()), + &Module::new(imported.to_string()), + as_packages, + ); + + let mut tuple_chains: Vec> = vec![]; + for rust_chain in rust_chains.iter() { + let module_names: Vec> = rust_chain + .iter() + .map(|module| PyString::new_bound(py, &module.name)) + .collect(); + let tuple = PyTuple::new_bound(py, &module_names); + tuple_chains.push(tuple); + } + PySet::new_bound(py, &tuple_chains).unwrap() + } + #[pyo3(signature = (importer, imported, as_packages=false))] pub fn chain_exists( &self, diff --git a/src/grimp/adaptors/rustgraph.py b/src/grimp/adaptors/rustgraph.py index f26abd7c..534b5d02 100644 --- a/src/grimp/adaptors/rustgraph.py +++ b/src/grimp/adaptors/rustgraph.py @@ -124,7 +124,7 @@ def find_shortest_chain(self, importer: str, imported: str) -> tuple[str, ...] | def find_shortest_chains( self, importer: str, imported: str, as_packages: bool = True ) -> Set[Tuple[str, ...]]: - return self._pygraph.find_shortest_chains(importer, imported, as_packages) + return self._rustgraph.find_shortest_chains(importer, imported, as_packages) def chain_exists(self, importer: str, imported: str, as_packages: bool = False) -> bool: return self._rustgraph.chain_exists(importer, imported, as_packages)