Skip to content

Commit

Permalink
WIP find_shortest_chains
Browse files Browse the repository at this point in the history
  • Loading branch information
seddonym committed Dec 9, 2024
1 parent ccd280a commit 6e6b7fd
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 10 deletions.
20 changes: 11 additions & 9 deletions rust/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ find_downstream_modules - DONE
find_shortest_chain - DONE
find_shortest_chains - TODO
chain_exists - DONE
find_illegal_dependencies_for_layers - TODO
add_module - DONE - need to add is_squashed
find_illegal_dependencies_for_layers - DONE
add_module - DONE
remove_module - DONE
add_import - DONE
remove_import - DONE
Expand Down Expand Up @@ -577,19 +577,21 @@ impl Graph {
&self,
importer: &Module,
imported: &Module,
as_packages: bool,
) -> HashSet<Vec<&Module>> {
let mut chains = HashSet::new();

let mut importer_modules: HashSet<&Module> = HashSet::from([importer]);
// TODO don't do this if module is squashed?
for descendant in self.find_descendants(&importer).unwrap() {
importer_modules.insert(descendant);
}

let mut imported_modules: HashSet<&Module> = HashSet::from([imported]);

// TODO don't do this if module is squashed?
for descendant in self.find_descendants(&imported).unwrap() {
imported_modules.insert(descendant);
if as_packages {
for descendant in self.find_descendants(&importer).unwrap() {
importer_modules.insert(descendant);
}
for descendant in self.find_descendants(&imported).unwrap() {
imported_modules.insert(descendant);
}
}

// TODO - Error if modules have shared descendants.
Expand Down
26 changes: 26 additions & 0 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,32 @@ impl GraphWrapper {
Some(chain.iter().map(|module| module.name.clone()).collect())
}

#[pyo3(signature = (importer, imported, as_packages=true))]
pub fn find_shortest_chains<'py>(
&self,
py: Python<'py>,
importer: &str,
imported: &str,
as_packages: bool,
) -> Bound<'py, PySet> {
let rust_chains: HashSet<Vec<&Module>> = self._graph.find_shortest_chains(
&Module::new(importer.to_string()),
&Module::new(imported.to_string()),
as_packages,
);

let mut tuple_chains: Vec<Bound<'py, PyTuple>> = vec![];
for rust_chain in rust_chains.iter() {
let module_names: Vec<Bound<'py, PyString>> = rust_chain
.iter()
.map(|module| PyString::new_bound(py, &module.name))
.collect();
let tuple = PyTuple::new_bound(py, &module_names);
tuple_chains.push(tuple);
}
PySet::new_bound(py, &tuple_chains).unwrap()
}

#[pyo3(signature = (importer, imported, as_packages=false))]
pub fn chain_exists(
&self,
Expand Down
2 changes: 1 addition & 1 deletion src/grimp/adaptors/rustgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def find_shortest_chain(self, importer: str, imported: str) -> tuple[str, ...] |
def find_shortest_chains(
self, importer: str, imported: str, as_packages: bool = True
) -> Set[Tuple[str, ...]]:
return self._pygraph.find_shortest_chains(importer, imported, as_packages)
return self._rustgraph.find_shortest_chains(importer, imported, as_packages)

def chain_exists(self, importer: str, imported: str, as_packages: bool = False) -> bool:
return self._rustgraph.chain_exists(importer, imported, as_packages)
Expand Down

0 comments on commit 6e6b7fd

Please sign in to comment.