[DRAFT] PyWhy Proposal of Causal Graphs and Their Operations
In https://github.com/py-why/dowhy/wiki/Networkx-Proposals-For-MixedEdgeGraph-Class-To-Enable-Causal-Graphs, we propose adding a new generic class to networkx that supports arbitrary mixed-edge graphs. PyWhy should nonetheless provide fundamental causal graph classes, because they i) give users and downstream causal packages a simple API, ii) enable error checking through types, and iii) reduce the error surface.
While it may be tempting to equip the causal graph classes with every causal operation known for a given graph, doing so makes implementation and maintenance clunky. Instead, we want to follow the approach NetworkX takes:
> Most of the NetworkX API is provided by functions which take a graph object as an argument. Methods of the graph object are limited to basic manipulation and reporting. This provides modularity of code and documentation. It also makes it easier for newcomers to learn about the package in stages.
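As an illustration of that style (a sketch, not part of the proposal), a hard intervention on a DAG can be written as a free function that takes a plain `nx.DiGraph` and returns a new graph, instead of being a method on a causal graph class. The name `intervene` is hypothetical.

```python
import networkx as nx


def intervene(graph: nx.DiGraph, nodes) -> nx.DiGraph:
    """Return a copy of ``graph`` with every edge into ``nodes`` removed.

    Hypothetical free function: the graph is passed in as an argument,
    mirroring the NetworkX style, and the input graph is left untouched.
    """
    nodes = set(nodes)
    mutilated = graph.copy()
    # A hard do-intervention cuts all incoming edges of the intervened nodes.
    mutilated.remove_edges_from(list(graph.in_edges(nodes)))
    return mutilated


dag = nx.DiGraph([("Z", "X"), ("X", "Y"), ("Z", "Y")])
dag_do_x = intervene(dag, ["X"])
print(sorted(dag_do_x.edges()))  # [('X', 'Y'), ('Z', 'Y')]
print(dag.has_edge("Z", "X"))    # True: the original graph is unchanged
```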
See this thread for a discussion of the benefits of adding these explicit subclasses: https://github.com/py-why/dowhy/discussions/525#discussioncomment-3132244
If an algorithm, such as a structure learning algorithm, relies on a graph implementation, then instead of asserting that the passed-in object is a "causal graph object", it should assert that the object follows a Protocol (https://peps.python.org/pep-0544/) that specifies the basic operations any implementation of that causal graph must support.
For example:

```python
from typing import Protocol


# the CPDAG protocol should be able to query basic operations
# and perform operations to orient edges
class CPDAGProtocol(Protocol):
    def adjacencies(self, node):
        # query all adjacencies of a node (any edge)
        ...

    def has_adjacency(self, u, v):
        # has any type of edge between u and v
        ...

    def has_undirected_edge(self, u, v): ...

    def remove_undirected_edge(self, u, v): ...

    def add_edge(self, u, v): ...


def run_pc_algorithm(graph: CPDAGProtocol, **kwargs):
    # the remaining parameters of the algorithm are elided
    ...
```
In practice, one can pass any class that implements CPDAGProtocol, including one that offers much more, such as richer querying of the learned CPDAG.
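A minimal sketch of this structural typing follows, assuming a hypothetical `SimpleCPDAG` that stores directed edges in an `nx.DiGraph` and undirected edges in an `nx.Graph`. It never inherits from `CPDAGProtocol`, yet `typing.runtime_checkable` lets `isinstance` confirm that it provides the required methods, and the hypothetical `orient` helper uses only protocol methods. Note that `runtime_checkable` checks only that the methods exist, not their signatures.

```python
from typing import Protocol, runtime_checkable

import networkx as nx


@runtime_checkable
class CPDAGProtocol(Protocol):
    def adjacencies(self, node): ...
    def has_adjacency(self, u, v) -> bool: ...
    def has_undirected_edge(self, u, v) -> bool: ...
    def remove_undirected_edge(self, u, v) -> None: ...
    def add_edge(self, u, v) -> None: ...


class SimpleCPDAG:
    """Hypothetical CPDAG: satisfies CPDAGProtocol without inheriting it."""

    def __init__(self):
        self.directed = nx.DiGraph()
        self.undirected = nx.Graph()

    def adjacencies(self, node):
        # neighbors of node through any edge type
        adj = set()
        if node in self.directed:
            adj |= set(self.directed.predecessors(node))
            adj |= set(self.directed.successors(node))
        if node in self.undirected:
            adj |= set(self.undirected.neighbors(node))
        return adj

    def has_adjacency(self, u, v):
        return (self.directed.has_edge(u, v) or self.directed.has_edge(v, u)
                or self.undirected.has_edge(u, v))

    def has_undirected_edge(self, u, v):
        return self.undirected.has_edge(u, v)

    def add_undirected_edge(self, u, v):
        # extra convenience method beyond what the protocol requires
        self.undirected.add_edge(u, v)

    def remove_undirected_edge(self, u, v):
        self.undirected.remove_edge(u, v)

    def add_edge(self, u, v):
        # adds a directed edge u -> v
        self.directed.add_edge(u, v)


def orient(graph: CPDAGProtocol, u, v) -> None:
    """Turn u - v into u -> v using only protocol methods."""
    if graph.has_undirected_edge(u, v):
        graph.remove_undirected_edge(u, v)
        graph.add_edge(u, v)


g = SimpleCPDAG()
g.add_undirected_edge("A", "B")
print(isinstance(g, CPDAGProtocol))  # True: structural check, no inheritance
orient(g, "A", "B")
print(g.has_undirected_edge("A", "B"), g.directed.has_edge("A", "B"))  # False True
```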
Although every graph class could be represented by a plain MixedEdgeGraph instance, doing so is undesirable. Dedicated subclasses for common causal graphs would provide i) error checking of edge operations and graph initialization, ii) a lightweight convenience API for common causal-graph operations (e.g. access to specific edge types and causality-specific notions such as "c-components"), and iii) additional causality-specific graph operations (e.g. the do-operation).
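Point i) can be sketched as follows. Since MixedEdgeGraph is only a proposal, the sketch uses a hypothetical stand-in, `MixedEdgeGraphSketch`, that keeps one networkx graph per edge type; the hypothetical `ADMGSketch` fixes the allowed edge types, so adding an undirected edge fails immediately instead of silently producing an invalid ADMG. All class and method names here are illustrative assumptions.

```python
import networkx as nx


class MixedEdgeGraphSketch:
    """Hypothetical stand-in for MixedEdgeGraph: one graph per edge type."""

    def __init__(self, graphs, edge_types):
        if len(graphs) != len(edge_types):
            raise ValueError("need exactly one graph per edge type")
        self._graphs = dict(zip(edge_types, graphs))

    def add_edge(self, u, v, edge_type):
        # reject edge types the graph class was not built for
        if edge_type not in self._graphs:
            raise ValueError(f"unsupported edge type {edge_type!r}; "
                             f"allowed: {sorted(self._graphs)}")
        self._graphs[edge_type].add_edge(u, v)

    def graph_of(self, edge_type):
        return self._graphs[edge_type]


class ADMGSketch(MixedEdgeGraphSketch):
    """Fixes the allowed edge types, so invalid edges fail loudly."""

    def __init__(self):
        super().__init__([nx.DiGraph(), nx.Graph()],
                         edge_types=["directed", "bidirected"])

    def add_bidirected_edge(self, u, v):
        # convenience API for a causality-specific edge type
        self.add_edge(u, v, edge_type="bidirected")


admg = ADMGSketch()
admg.add_edge("X", "Y", edge_type="directed")
admg.add_bidirected_edge("X", "Y")
try:
    admg.add_edge("X", "Y", edge_type="undirected")  # not valid in an ADMG
except ValueError as err:
    print(err)
```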
```python
import networkx as nx

# MixedEdgeGraph is the proposed networkx class described above


class ADMG(MixedEdgeGraph):
    def __init__(self, directed_data, bidirected_data):
        super().__init__([directed_data, bidirected_data], edge_types=['directed', 'bidirected'])

    def do(self, u):
        # apply a do-intervention on u and return the intervened graph
        ...

    def soft_do(self, nodes):
        # apply a soft intervention on a set of nodes, e.g. by adding an F-node
        ...

    @property
    def c_components(self):
        return nx.connected_components(self.bidirected_edge_graph)

    ...
```
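The two causality-specific operations above can be sketched with plain networkx graphs, assuming an ADMG is stored as a `DiGraph` of directed edges plus a `Graph` of bidirected edges. One detail worth noting: `nx.connected_components` only sees nodes present in the bidirected graph, so nodes without any bidirected edge must be added first to appear as singleton c-components. The hard intervention here removes incoming directed edges and incident bidirected edges of the intervened nodes; the function names are hypothetical.

```python
import networkx as nx


def c_components(directed: nx.DiGraph, bidirected: nx.Graph):
    """C-components of an ADMG given as (directed, bidirected) graphs."""
    confounded = bidirected.copy()
    # nodes with no bidirected edge still form their own c-component
    confounded.add_nodes_from(directed.nodes)
    return [set(c) for c in nx.connected_components(confounded)]


def do(directed: nx.DiGraph, bidirected: nx.Graph, nodes):
    """Hard intervention: drop incoming directed edges of ``nodes`` and
    every bidirected edge touching them; the inputs are not modified."""
    nodes = set(nodes)
    new_directed = directed.copy()
    new_directed.remove_edges_from(list(directed.in_edges(nodes)))
    new_bidirected = bidirected.copy()
    new_bidirected.remove_edges_from(list(bidirected.edges(nodes)))
    return new_directed, new_bidirected


# X -> M -> Y with a latent confounder X <-> Y
directed = nx.DiGraph([("X", "M"), ("M", "Y")])
bidirected = nx.Graph([("X", "Y")])

print(sorted(sorted(c) for c in c_components(directed, bidirected)))
# [['M'], ['X', 'Y']]

d_do, b_do = do(directed, bidirected, ["Y"])
print(sorted(d_do.edges()), sorted(b_do.edges()))
# [('X', 'M')] []
```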
We could also subclass MixedEdgeGraph to represent equivalence classes, which again combine different types of edges.
```python
class CPDAG(MixedEdgeGraph):
    def __init__(self, directed_data, undirected_data):
        super().__init__([directed_data, undirected_data], edge_types=['directed', 'undirected'])

    # API layer for users to work with the specific types of edges
    def add_undirected_edge(self, u, v):
        super().add_edge(u, v, edge_type='undirected')

    ...
```
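The error checking of graph initialization mentioned earlier could, for a CPDAG, mean rejecting input data before the graph is built. The hypothetical check below enforces two necessary conditions: the directed part is acyclic, and no pair of nodes is joined by both a directed and an undirected edge. It is not a complete CPDAG characterization; a full check would need further conditions, such as chordality of the undirected components.

```python
import networkx as nx


def validate_cpdag_data(directed: nx.DiGraph, undirected: nx.Graph) -> None:
    """Hypothetical init-time check a CPDAG class could run.

    Raises ValueError on a directed cycle or on a node pair that has
    both a directed and an undirected edge.
    """
    if not nx.is_directed_acyclic_graph(directed):
        raise ValueError("directed edges of a CPDAG must be acyclic")
    for u, v in directed.edges():
        if undirected.has_edge(u, v):
            raise ValueError(
                f"{u!r} and {v!r} have both a directed and an undirected edge"
            )


# A -> C <- B alongside C - D passes both checks
validate_cpdag_data(nx.DiGraph([("A", "C"), ("B", "C")]), nx.Graph([("C", "D")]))

try:
    validate_cpdag_data(nx.DiGraph([("A", "B")]), nx.Graph([("A", "B")]))
except ValueError as err:
    print(err)
```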
```python
# We can also reuse the ADMG-defined methods in the PAG.
# Inherit from ADMG only: it already subclasses MixedEdgeGraph, and listing
# MixedEdgeGraph before ADMG gives an inconsistent MRO (TypeError).
class PAG(ADMG):
    def __init__(self, directed_data, undirected_data, circular_data):
        # call MixedEdgeGraph directly, since ADMG.__init__ only accepts
        # directed and bidirected data
        MixedEdgeGraph.__init__(
            self,
            [directed_data, undirected_data, circular_data],
            edge_types=['directed', 'undirected', 'circle'],
        )

    # API layer for users to work with the specific types of edges
    def add_circle_endpoint(self, u, v):
        super().add_edge(u, v, edge_type='circle')

    def potential_parents(self, u):
        # return the directed edge parents of u and circle edges pointing to u
        ...


# instantiate the PAG from empty networkx graphs
pag = PAG(nx.DiGraph(), nx.Graph(), nx.DiGraph())
```
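One possible reading of `potential_parents`, sketched with plain networkx graphs, is below. It assumes that a circle-graph edge `(v, u)` records a circle endpoint at `u`, as `add_circle_endpoint(v, u)` would store it; that encoding is an assumption, not something the proposal specifies.

```python
import networkx as nx


def potential_parents(directed: nx.DiGraph, circle: nx.DiGraph, u):
    """Directed parents of ``u`` plus nodes with a circle edge pointing to ``u``.

    Assumes a circle-graph edge (v, u) encodes a circle endpoint at u.
    """
    parents = set(directed.predecessors(u)) if u in directed else set()
    circles = set(circle.predecessors(u)) if u in circle else set()
    return parents | circles


directed = nx.DiGraph([("A", "C")])
circle = nx.DiGraph([("B", "C"), ("C", "D")])
print(sorted(potential_parents(directed, circle, "C")))  # ['A', 'B']
```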