Skip to content

Commit

Permalink
Merge pull request #6 from mwien/fix-flower-construction
Browse files Browse the repository at this point in the history
Fix flower construction bug
  • Loading branch information
mwien authored Jan 1, 2025
2 parents fbebdaf + 62120e0 commit 2a766d2
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 20 deletions.
2 changes: 1 addition & 1 deletion cliquepicking_python/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "cliquepicking"
version = "0.2.4"
version = "0.2.5"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
2 changes: 1 addition & 1 deletion cliquepicking_rs/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "cliquepicking_rs"
version = "0.1.0"
version = "0.2.5"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
21 changes: 8 additions & 13 deletions cliquepicking_rs/src/clique_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,33 +121,28 @@ impl CliqueTree {
if flowers[edge_id].is_empty() {
let mut flower = Vec::new();
flower.push(t);
let mut add_ids = Vec::new();
add_ids.push(edge_id);
visited[s] = true;
visited[t] = true;
let mut q = VecDeque::new();
q.push_back(t);
while !q.is_empty() {
let u = q.pop_front().unwrap();
for &v in self.tree.neighbors(u) {
if !visited[v] && st_sep.is_subset(&self.cliques[v]) {
if separators[self.get_edge_id(u, v)] == *st_sep {
add_ids.push(self.get_edge_id(u, v));
} else {
flower.push(v);
visited[v] = true;
q.push_back(v);
}
if !visited[v]
&& st_sep.is_subset(&self.cliques[v])
&& separators[self.get_edge_id(u, v)] != *st_sep
{
flower.push(v);
visited[v] = true;
q.push_back(v);
}
}
}
visited[s] = false;
for &f in &flower {
visited[f] = false;
}
for &id in &add_ids {
flowers[id] = IndexSet::from(flower.clone());
}
flowers[edge_id] = IndexSet::from(flower.clone());
}
}
}
Expand Down
100 changes: 95 additions & 5 deletions cliquepicking_rs/src/sample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,11 +530,10 @@ pub fn sample_cpdag_orders(g: &PartiallyDirectedGraph, k: usize) -> Vec<Vec<usiz
mod tests {
use std::collections::{HashMap, HashSet};

use crate::graph::Graph;
use crate::{graph::Graph, partially_directed_graph::PartiallyDirectedGraph};

#[test]
fn sample_amos_basic_check() {
let g = Graph::from_edge_list(
fn get_paper_graph() -> Graph {
Graph::from_edge_list(
vec![
(0, 1),
(0, 2),
Expand All @@ -549,7 +548,74 @@ mod tests {
(4, 5),
],
6,
);
)
}

fn get_issue4_graph() -> PartiallyDirectedGraph {
PartiallyDirectedGraph::from_edge_list(
vec![
(0, 1),
(1, 0),
(0, 2),
(2, 0),
(1, 2),
(2, 1),
(1, 3),
(3, 1),
(1, 4),
(4, 1),
],
5,
)
}

fn get_issue5_graph() -> PartiallyDirectedGraph {
PartiallyDirectedGraph::from_edge_list(
vec![
(9, 10),
(9, 13),
(9, 7),
(10, 9),
(10, 11),
(10, 12),
(13, 9),
(4, 5),
(4, 12),
(5, 4),
(0, 1),
(0, 3),
(1, 0),
(1, 19),
(6, 7),
(6, 14),
(6, 19),
(7, 6),
(7, 9),
(7, 8),
(14, 6),
(14, 15),
(8, 7),
(8, 19),
(16, 15),
(16, 18),
(16, 17),
(15, 16),
(15, 14),
(18, 16),
(18, 19),
(11, 10),
(11, 19),
(3, 17),
(3, 19),
(2, 3),
],
5,
)
}

#[test]
fn sample_amos_basic_check() {
let g = get_paper_graph();
let sample_size = 10_000;
let amos = super::sample_amos(&g, sample_size);
assert_eq!(amos.len(), sample_size);
Expand All @@ -567,4 +633,28 @@ mod tests {
}
assert_eq!(dags.len(), 54);
}

#[test]
fn sample_cpdag_basic_check() {
let g = get_issue4_graph();
let sample_size = 10_000;
let dags = super::sample_cpdag(&g, sample_size);
assert_eq!(dags.len(), sample_size);
let mut count_dags = HashMap::new();
for a in dags.iter() {
count_dags.entry(a).and_modify(|cnt| *cnt += 1).or_insert(1);
}
assert_eq!(count_dags.len(), 10);
let g = get_issue5_graph();
let sample_size = 10_000;
let dags = super::sample_cpdag(&g, sample_size);
assert_eq!(dags.len(), sample_size);
let mut count_dags = HashMap::new();
for a in dags.iter() {
count_dags.entry(a).and_modify(|cnt| *cnt += 1).or_insert(1);
}
assert_eq!(count_dags.len(), 44);
}

// TODO: test orders as well
}

0 comments on commit 2a766d2

Please sign in to comment.