Skip to content

Commit

Permalink
Merge pull request #3 from mwien/list_dags
Browse files Browse the repository at this point in the history
List DAGs in MEC
  • Loading branch information
mwien authored Dec 31, 2024
2 parents 69b29bd + 43b4840 commit fbebdaf
Show file tree
Hide file tree
Showing 8 changed files with 302 additions and 3 deletions.
2 changes: 1 addition & 1 deletion cliquepicking_python/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion cliquepicking_python/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "cliquepicking"
version = "0.2.3"
version = "0.2.4"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
6 changes: 6 additions & 0 deletions cliquepicking_python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@ The module provides the functions
- ```mec_size(G)```, which outputs the number of DAGs in the MEC represented by CPDAG G
- ```mec_sample_dags(G, k)```, which returns k uniformly sampled DAGs from the MEC represented by CPDAG G
- ```mec_sample_orders(G, k)``` which returns topological orders of k uniformly sampled DAGs from the MEC represented by CPDAG G
- ```mec_list_dags(G)```, which returns a list of all DAGs in the MEC represented by CPDAG G
- ```mec_list_orders(G)```, which returns topological orders of all DAGs in the MEC represented by CPDAG G

The DAGs are returned as edge lists and they can be read e.g. in networkx using ```nx.DiGraph(dag)``` (see the example at the bottom).

Be aware that ```mec_sample_dags(G, k)``` holds (and returns) k DAGs in memory. (For large graphs) to avoid high memory demand, generate DAGs in smaller batches or use ```mec_sample_orders(G, k)```, which only returns the easier-to-store topological order.

The same holds for ```mec_list_dags(G)```, consider checking the size of the MEC using ```mec_size(G)``` before calling this method.

In all cases, G should be given as an edge list (vertices should be represented by zero-indexed integers), which includes ```(a, b)``` and ```(b, a)``` for undirected edges $a - b$ and only ```(a, b)``` for directed edges $a \rightarrow b$. E.g.

```python
Expand Down
26 changes: 25 additions & 1 deletion cliquepicking_python/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use pyo3::prelude::*;

use cliquepicking_rs::count::count_cpdag;
use cliquepicking_rs::enumerate::list_cpdag;
use cliquepicking_rs::enumerate::list_cpdag_orders;
use cliquepicking_rs::partially_directed_graph::PartiallyDirectedGraph;
use cliquepicking_rs::sample::sample_cpdag;
use cliquepicking_rs::sample::sample_cpdag_orders;
Expand All @@ -13,6 +15,8 @@ fn cliquepicking(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(mec_size, m)?)?;
m.add_function(wrap_pyfunction!(mec_sample_dags, m)?)?;
m.add_function(wrap_pyfunction!(mec_sample_orders, m)?)?;
m.add_function(wrap_pyfunction!(mec_list_dags, m)?)?;
m.add_function(wrap_pyfunction!(mec_list_orders, m)?)?;
Ok(())
}

Expand All @@ -37,14 +41,34 @@ fn mec_sample_dags(cpdag: Vec<(usize, usize)>, k: usize) -> PyResult<Vec<Vec<(us
Ok(samples)
}

/// Sample k DAGs uniformly from the Markov equivalence class represented by CPDAG cpdag.
/// Sample k DAGs (represented by a topological order) uniformly from the Markov equivalence class represented by CPDAG cpdag.
#[pyfunction]
fn mec_sample_orders(cpdag: Vec<(usize, usize)>, k: usize) -> PyResult<Vec<Vec<usize>>> {
let mx = max_element(&cpdag);
let g = PartiallyDirectedGraph::from_edge_list(cpdag, mx + 1);
Ok(sample_cpdag_orders(&g, k))
}

/// List all DAGs from the Markov equivalence class represented by CPDAG cpdag.
#[pyfunction]
fn mec_list_dags(cpdag: Vec<(usize, usize)>) -> PyResult<Vec<Vec<(usize, usize)>>> {
let mx = max_element(&cpdag);
let g = PartiallyDirectedGraph::from_edge_list(cpdag, mx + 1);
let samples = list_cpdag(&g)
.into_iter()
.map(|sample| sample.to_edge_list())
.collect();
Ok(samples)
}

/// List all DAGs (represented by a topological orderfrom the Markov equivalence class represented by CPDAG cpdag.
#[pyfunction]
fn mec_list_orders(cpdag: Vec<(usize, usize)>) -> PyResult<Vec<Vec<usize>>> {
let mx = max_element(&cpdag);
let g = PartiallyDirectedGraph::from_edge_list(cpdag, mx + 1);
Ok(list_cpdag_orders(&g))
}

// small helper
fn max_element(tuple_list: &[(usize, usize)]) -> usize {
let mut mx = 0;
Expand Down
257 changes: 257 additions & 0 deletions cliquepicking_rs/src/enumerate.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
use crate::{
directed_graph::DirectedGraph, graph::Graph, partially_directed_graph::PartiallyDirectedGraph,
};

#[derive(Debug)]
struct McsState {
ordering: Vec<usize>,
sets: Vec<Vec<usize>>,
cardinality: Vec<usize>,
max_cardinality: usize,
position: usize,
}

impl McsState {
pub fn new(n: usize) -> McsState {
let mut sets = vec![Vec::new(); n];
sets[0] = (0..n).collect();
McsState {
ordering: Vec::new(),
sets,
cardinality: vec![0; n],
max_cardinality: 0,
position: 0,
}
}
}

fn visit(g: &Graph, state: &mut McsState, u: usize) {
state.position += 1;
state.ordering.push(u);
state.cardinality[u] = usize::MAX; // TODO: use Option to encode this
for &v in g.neighbors(u) {
if state.cardinality[v] < g.n {
state.cardinality[v] += 1;
state.sets[state.cardinality[v]].push(v);
}
}
state.max_cardinality += 1;
while state.max_cardinality > 0 && state.sets[state.max_cardinality].is_empty() {
state.max_cardinality -= 1;
}
}

fn unvisit(g: &Graph, state: &mut McsState, u: usize, last_cardinality: usize) {
state.position -= 1;
state.ordering.pop();
state.cardinality[u] = last_cardinality;
state.sets[state.cardinality[u]].push(u);

for &v in g.neighbors(u).rev() {
// TODO: sets will get bigger and bigger -> cleanup?
if state.cardinality[v] < g.n {
state.cardinality[v] -= 1;
state.sets[state.cardinality[v]].push(v);
}
}

state.max_cardinality = state.cardinality[u];
}

fn reach(g: &Graph, st: &[usize], s: usize) -> Vec<usize> {
let mut visited = vec![false; g.n];
visited[s] = true;
let mut blocked = vec![true; g.n];
st.iter().for_each(|&v| blocked[v] = false);
let mut queue = vec![s];

while let Some(u) = queue.pop() {
for &v in g.neighbors(u) {
if !visited[v] && !blocked[v] {
queue.push(v);
visited[v] = true;
}
}
}

visited
.iter()
.enumerate()
.filter(|(_, &val)| val)
.map(|(i, _)| i)
.collect()
}

fn rec_list_chordal_orders(g: &Graph, orders: &mut Vec<Vec<usize>>, state: &mut McsState) {
if state.position == g.n {
orders.push(state.ordering.clone());
return;
}

// do this better
let u = loop {
while state.max_cardinality > 0 && state.sets[state.max_cardinality].is_empty() {
state.max_cardinality -= 1;
}
let next_vertex = state.sets[state.max_cardinality].pop().unwrap();
// use Result instead of this hack
if state.cardinality[next_vertex] == state.max_cardinality {
break next_vertex;
}
};

let last_cardinality = state.cardinality[u];
visit(g, state, u);
rec_list_chordal_orders(g, orders, state);
unvisit(g, state, u, last_cardinality);

let st: Vec<_> = state.sets[state.max_cardinality]
.iter()
.copied()
.filter(|&v| state.max_cardinality == state.cardinality[v])
.collect();
let reachable = reach(g, &st, u);

for x in reachable {
if x == u || state.cardinality[x] != state.max_cardinality {
continue;
}
let last_cardinality = state.cardinality[x];
visit(g, state, x);
rec_list_chordal_orders(g, orders, state);
unvisit(g, state, x, last_cardinality);
}
}

fn list_chordal_orders(g: &Graph) -> Vec<Vec<usize>> {
let mut orders = Vec::new();
rec_list_chordal_orders(g, &mut orders, &mut McsState::new(g.n));
orders
}

fn sort_order(d: &DirectedGraph, cmp: &[usize], order: &[usize]) -> Vec<usize> {
let mut component_no = vec![usize::MAX; *cmp.iter().max().unwrap() + 1];
let mut sorted_order = Vec::new();

let to = d.topological_order();
let mut found_comps = 0;
for &u in to.iter() {
if component_no[cmp[u]] == usize::MAX {
component_no[cmp[u]] = found_comps;
found_comps += 1;
sorted_order.push(Vec::new());
}
}

for &u in order.iter() {
let cmp_u = component_no[cmp[u]];
sorted_order[cmp_u].push(u);
}

sorted_order.into_iter().flatten().collect()
}

// TODO: rename
pub fn list_cpdag_orders(g: &PartiallyDirectedGraph) -> Vec<Vec<usize>> {
let undirected_subgraph = g.undirected_subgraph();
let directed_subgraph = g.directed_subgraph();
let unsorted_orders = list_chordal_orders(&undirected_subgraph);

// could use a method which only returns list of vertex lists
let (_, vertices) = undirected_subgraph.connected_components();
let mut cmp = vec![0; g.n];
vertices
.iter()
.enumerate()
.for_each(|(i, l)| l.iter().for_each(|&v| cmp[v] = i));

unsorted_orders
.iter()
.map(|order| sort_order(&directed_subgraph, &cmp, order))
.collect()
}

pub fn list_cpdag(g: &PartiallyDirectedGraph) -> Vec<DirectedGraph> {
let undirected_subgraph = g.undirected_subgraph();
let directed_subgraph = g.directed_subgraph();

let mut dags = Vec::new();
for order in list_cpdag_orders(g).iter() {
let mut position = vec![0; order.len()];
order.iter().enumerate().for_each(|(i, &v)| position[v] = i);
let mut dag_edge_list = directed_subgraph.to_edge_list();
for &(u, v) in undirected_subgraph.to_edge_list().iter() {
if u > v {
continue;
}
if position[u] < position[v] {
dag_edge_list.push((u, v));
} else {
dag_edge_list.push((v, u));
}
}
dags.push(DirectedGraph::from_edge_list(dag_edge_list, order.len()));
}
dags
}

#[cfg(test)]
mod tests {

use crate::partially_directed_graph::PartiallyDirectedGraph;

fn get_paper_graph() -> PartiallyDirectedGraph {
PartiallyDirectedGraph::from_edge_list(
vec![
(0, 1),
(1, 0),
(0, 2),
(2, 0),
(1, 2),
(2, 1),
(1, 3),
(3, 1),
(1, 4),
(4, 1),
(1, 5),
(5, 1),
(2, 3),
(3, 2),
(2, 4),
(4, 2),
(2, 5),
(5, 2),
(3, 4),
(4, 3),
(4, 5),
(5, 4),
],
6,
)
}

fn get_basic_graph() -> PartiallyDirectedGraph {
PartiallyDirectedGraph::from_edge_list(
vec![(0, 1), (1, 0), (1, 2), (2, 1), (0, 3), (2, 3)],
4,
)
}

#[test]
fn list_cpdag_basic_check() {
let dags = super::list_cpdag(&get_paper_graph());
assert_eq!(dags.len(), 54);
let dags = super::list_cpdag(&get_basic_graph());
assert_eq!(dags.len(), 3);
// TODO: better tests
}

#[test]
fn list_cpdag_orders_basic_check() {
let orders = super::list_cpdag_orders(&get_paper_graph());
assert_eq!(orders.len(), 54);
let orders = super::list_cpdag_orders(&get_basic_graph());
assert_eq!(orders.len(), 3);
// TODO: better tests
}
}
10 changes: 10 additions & 0 deletions cliquepicking_rs/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ impl Graph {
}
}

pub fn to_edge_list(&self) -> Vec<(usize, usize)> {
let mut edge_list = Vec::new();
for u in 0..self.n {
for &v in self.neighbors(u) {
edge_list.push((u, v));
}
}
edge_list
}

pub fn neighbors(&self, u: usize) -> std::slice::Iter<'_, usize> {
self.neighbors[u].iter()
}
Expand Down
1 change: 1 addition & 0 deletions cliquepicking_rs/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod count;
pub mod directed_graph;
pub mod enumerate;
pub mod graph;
pub mod partially_directed_graph;
pub mod sample;
Expand Down
1 change: 1 addition & 0 deletions cliquepicking_rs/src/sample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,7 @@ pub fn sample_chordal(g: &Graph, k: usize) -> Vec<DirectedGraph> {
}

// there are unnecessary allocations/conversions here, maybe optimize this at some point
// maybe call "sample_from_cpdag"
pub fn sample_cpdag(g: &PartiallyDirectedGraph, k: usize) -> Vec<DirectedGraph> {
let undirected_subgraph = g.undirected_subgraph();
let directed_subgraph = g.directed_subgraph();
Expand Down

0 comments on commit fbebdaf

Please sign in to comment.