1.0.3rc1 版本预发布
gCastle
是华为诺亚方舟实验室自研的因果结构学习工具链,主要的功能和愿景包括:
- 数据生成及处理: 包含各种模拟数据生成算子,数据读取算子,数据处理算子(如先验灌入,变量选择,CRAM)。
- 因果图构建: 提供了一个因果结构学习python算法库,包含了主流的因果学习算法以及最近兴起的基于梯度的因果结构学习算法。
- 因果评价: 提供了常用的因果结构学习性能评价指标,包括F1, SHD, FDR, TPR, FDR, NNZ等。
算法 | 分类 | 说明 | 状态 |
---|---|---|---|
PC | IID/Constraint-based | 一种基于独立性检验的经典因果发现算法 | v1.0.1 |
ANM | IID/Function-based | 一种非线性的加性噪声因果模型 | v1.0.3rc1 |
DirectLiNGAM | IID/Function-based | 一种线性非高斯无环模型的直接学习方法 | v1.0.1 |
ICALiNGAM | IID/Function-based | 一种线性非高斯无环模型的因果学习算法 | v1.0.1 |
NOTEARS | IID/Gradient-based | 一种基于梯度、针对线性数据模型的因果结构学习算法 | v1.0.1 |
NOTEARS-MLP | IID/Gradient-based | 一种深度可微分、基于神经网络建模的因果结构学习算法 | v1.0.1 |
NOTEARS-SOB | IID/Gradient-based | 一种深度可微分、基于Sobolev空间建模的因果结构学习算法 | v1.0.1 |
NOTEARS-lOW-RANK | IID/Gradient-based | 基于low rank假定、针对线性数据模型的因果结构学习算法 | v1.0.1 |
GOLEM | IID/Gradient-based | 一种基于NOTEARS、通过减少优化循环次数提升训练效率的因果结构学习算法 | v1.0.1 |
GraNDAG | IID/Gradient-based | 一种深度可微分、针对非线性加性噪声数据模型的因果结构学习算法 | v1.0.1 |
MCSL | IID/Gradient-based | 一种基于掩码梯度的因果结构学习算法 | v1.0.1 |
GAE | IID/Gradient-based | 一种基于图自编码器的因果发现算法 | v1.0.1 |
RL | IID/Gradient-based | 一种基于强化学习的因果发现算法 | v1.0.3rc1 |
CORL1 | IID/Gradient-based | 一种基于强化学习搜索因果序的因果发现方法 | v1.0.3rc1 |
CORL2 | IID/Gradient-based | 一种基于强化学习搜索因果序的因果发现方法 | v1.0.3rc1 |
TTPM | EventSequence/Function-based | 一种针对时空事件序列的基于时空Hawkes Process的因果结构学习算法 | v1.0.1 |
HPCI | EventSequence/Hybrid | 一种针对时序事件序列的基于Hawkes Process和CI tests的因果结构学习算法 | 开发中 |
- python (>= 3.6)
- tqdm (>= 4.48.2)
- numpy (>= 1.19.2)
- pandas (>= 0.22.0)
- scipy (>= 1.4.1)
- scikit-learn (>= 0.21.1)
- matplotlib (>=2.1.2)
- networkx (>= 2.5)
- torch (>= 1.4.0)
- tensorflow (>=2.6.0)
- tensorflow-probability (>=0.13.0)
pip install gcastle==1.0.3rc1
以PC算法为例:
from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import IIDSimulation, DAG
from castle.algorithms import PC
# data simulation, simulate true causal dag and train_data.
weighted_random_dag = DAG.erdos_renyi(n_nodes=10, n_edges=10,
weight_range=(0.5, 2.0), seed=1)
dataset = IIDSimulation(W=weighted_random_dag, n=2000, method='linear',
sem_type='gauss')
true_causal_matrix, X = dataset.B, dataset.X
# structure learning
pc = PC()
pc.learn(X)
# plot predict_dag and true_dag
GraphDAG(pc.causal_matrix, true_causal_matrix, 'result')
# calculate metrics
mt = MetricsDAG(pc.causal_matrix, true_causal_matrix)
print(mt.metrics)
大家可访问 examples 获取更多的示例.
欢迎大家使用gCastle
. 该项目尚处于起步阶段,欢迎各个经验等级的贡献者,近期我们将公布具体的代码贡献规范和要求。当前有任何疑问及建议,包括修改bug、贡献算法、完善文档等,请在社区提交issue,我们会及时回复交流。