diff --git a/rustworkx-core/src/connectivity/conn_components.rs b/rustworkx-core/src/connectivity/conn_components.rs index 5bdf19900..6e6e53e0b 100644 --- a/rustworkx-core/src/connectivity/conn_components.rs +++ b/rustworkx-core/src/connectivity/conn_components.rs @@ -9,14 +9,15 @@ // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the // License for the specific language governing permissions and limitations // under the License. - use hashbrown::HashSet; use std::collections::VecDeque; use std::hash::Hash; - -use petgraph::visit::{GraphProp, IntoNeighborsDirected, IntoNodeIdentifiers, VisitMap, Visitable}; +// use super::digraph; +use petgraph::visit::{ + GraphProp, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount, NodeIndexable, VisitMap, + Visitable, +}; use petgraph::{Incoming, Outgoing}; - /// Given an graph, a node in the graph, and a visit_map, /// return the set of nodes connected to the given node /// using breadth first search and treating all edges @@ -165,6 +166,33 @@ where num_components } +pub fn is_connected(graph: G) -> bool +where + G: GraphProp + + IntoNeighborsDirected + + Visitable + + IntoNodeIdentifiers + + NodeIndexable + + NodeCount, + G::NodeId: Eq + Hash, +{ + match graph.node_identifiers().next() { + Some(node) => { + let component = node_connected_component(graph, node); + component.len() == graph.node_count() + } + None => false, + } +} + +pub fn node_connected_component(graph: G, node: G::NodeId) -> HashSet +where + G: GraphProp + IntoNeighborsDirected + Visitable + IntoNodeIdentifiers, + G::NodeId: Eq + Hash, +{ + bfs_undirected(&graph, node, &mut graph.visit_map()) +} + #[cfg(test)] mod test_conn_components { use hashbrown::HashSet; @@ -174,7 +202,9 @@ mod test_conn_components { use petgraph::{Directed, Undirected}; use std::iter::FromIterator; - use crate::connectivity::{bfs_undirected, connected_components, number_connected_components}; + use crate::connectivity::{bfs_undirected, node_connected_component,connected_components, number_connected_components}; + + use super::is_connected; #[test] fn test_number_connected() { @@ -182,6 +212,29 @@ mod test_conn_components { assert_eq!(number_connected_components(&graph), 2); } + #[test] + fn test_is_connected() { + let graph = Graph::<(), (), Directed>::from_edges([(0,1), (1,2), (2,3)]); + assert_eq!(is_connected(&graph), true); + } + + #[test] + fn test_is_not_connected(){ + let disconnected_graph = Graph::<(), (), Directed>::from_edges([(0,1), (3,4)]); + assert_eq!(is_connected(&disconnected_graph), false); + } + + #[test] + fn test_node_connected_components(){ + let graph = Graph::<(), (), Directed>::from_edges(&[(0, 1), (1, 2), (2, 3), (4,5)]); + let node_idx: NodeIndex = NodeIndex::new(3); + let expected: HashSet = HashSet::from_iter([ndx(3),ndx(0),ndx(1),ndx(2)]); + let component = node_connected_component(&graph, node_idx); + assert_eq!(component, expected); + + } + + #[test] fn test_number_node_holes() { let mut graph = Graph::<(), (), Directed>::from_edges([(0, 1), (1, 2)]); diff --git a/rustworkx-core/src/connectivity/mod.rs b/rustworkx-core/src/connectivity/mod.rs index a86e97841..e014529e0 100644 --- a/rustworkx-core/src/connectivity/mod.rs +++ b/rustworkx-core/src/connectivity/mod.rs @@ -22,4 +22,6 @@ pub use biconnected::articulation_points; pub use chain::chain_decomposition; pub use conn_components::bfs_undirected; pub use conn_components::connected_components; +pub use conn_components::is_connected; +pub use conn_components::node_connected_component; pub use conn_components::number_connected_components; diff --git a/src/connectivity/mod.rs b/src/connectivity/mod.rs index 4def48ae3..83d03dcd8 100644 --- a/src/connectivity/mod.rs +++ b/src/connectivity/mod.rs @@ -261,20 +261,7 @@ pub fn connected_components(graph: &graph::PyGraph) -> Vec> { #[pyfunction] #[pyo3(text_signature = "(graph, node, /)")] pub fn node_connected_component(graph: &graph::PyGraph, node: usize) -> PyResult> { - let node = NodeIndex::new(node); - - if !graph.graph.contains_node(node) { - return Err(InvalidNode::new_err( - "The input index for 'node' is not a valid node index", - )); - } - - Ok( - connectivity::bfs_undirected(&graph.graph, node, &mut graph.graph.visit_map()) - .into_iter() - .map(|x| x.index()) - .collect(), - ) + connectivity::node_connected_component(&graph.graph, node) } /// Check if the graph is connected. @@ -288,13 +275,7 @@ pub fn node_connected_component(graph: &graph::PyGraph, node: usize) -> PyResult #[pyfunction] #[pyo3(text_signature = "(graph, /)")] pub fn is_connected(graph: &graph::PyGraph) -> PyResult { - match graph.graph.node_indices().next() { - Some(node) => { - let component = node_connected_component(graph, node.index())?; - Ok(component.len() == graph.graph.node_count()) - } - None => Err(NullGraph::new_err("Invalid operation on a NullGraph")), - } + connectivity::is_connected(&graph.graph) } /// Find the number of weakly connected components in a directed graph