-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy path_model.py
164 lines (138 loc) · 6.76 KB
/
_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import keras as ks
from ._layers import TrafoEdgeNetMessages
from kgcnn.layers.aggr import AggregateLocalEdges
from kgcnn.layers.gather import GatherNodesOutgoing, GatherNodesIngoing
from kgcnn.layers.geom import NodePosition, NodeDistanceEuclidean, GaussBasisLayer, ShiftPeriodicLattice
from kgcnn.layers.mlp import GraphMLP, MLP
from kgcnn.layers.modules import Embedding
from kgcnn.layers.message import MatMulMessages
from kgcnn.layers.pooling import PoolingNodes
from kgcnn.layers.update import GRUUpdate
from kgcnn.layers.set2set import PoolingSet2SetEncoder
def model_disjoint(inputs,
                   use_node_embedding: bool = None,
                   use_edge_embedding: bool = None,
                   input_node_embedding: dict = None,
                   input_edge_embedding: dict = None,
                   geometric_edge: bool = None,
                   make_distance: bool = None,
                   expand_distance: bool = None,
                   gauss_args: dict = None,
                   set2set_args: dict = None,
                   pooling_args: dict = None,
                   edge_mlp: dict = None,
                   use_set2set: bool = None,
                   node_dim: int = None,
                   depth: int = None,
                   output_embedding: str = None,
                   output_mlp: dict = None):
    """Build the NMPN edge-network message-passing graph on disjoint graph tensors.

    Args:
        inputs: Sequence of exactly seven tensors, unpacked in this order:
            ``(n0, ed, disjoint_indices, batch_id_node, batch_id_edge, count_nodes, count_edges)``.
            ``ed`` holds edge features, or node coordinates when ``make_distance`` is set.
            NOTE(review): shapes/dtypes are defined by the caller — confirm there.
        use_node_embedding: If truthy, pass ``n0`` through ``Embedding(**input_node_embedding)``.
        use_edge_embedding: If truthy and ``geometric_edge`` is falsy, pass ``ed`` through
            ``Embedding(**input_edge_embedding)``.
        input_node_embedding: Keyword arguments for the node ``Embedding`` layer.
        input_edge_embedding: Keyword arguments for the edge ``Embedding`` layer.
        geometric_edge: If truthy, the edge embedding step is skipped.
        make_distance: If truthy, replace ``ed`` by the Euclidean distance between the two
            node positions gathered from ``ed`` via ``disjoint_indices``.
        expand_distance: If truthy, expand ``ed`` with ``GaussBasisLayer(**gauss_args)``.
        gauss_args: Keyword arguments for ``GaussBasisLayer``.
        set2set_args: Keyword arguments for ``PoolingSet2SetEncoder``; must contain ``'channels'``.
        pooling_args: Keyword arguments shared by ``AggregateLocalEdges`` and ``PoolingNodes``.
        edge_mlp: Keyword arguments for the two (in/out) edge ``GraphMLP`` networks.
        use_set2set: For graph output, pool with Set2Set instead of ``PoolingNodes``.
        node_dim: Hidden node dimension; edge networks are reshaped to ``(node_dim, node_dim)``.
        depth: Number of message-passing steps.
        output_embedding: ``'graph'`` for a pooled graph-level output, ``'node'`` for node-level.
        output_mlp: Keyword arguments for the final ``MLP`` (graph) or ``GraphMLP`` (node).

    Returns:
        The output tensor of the final MLP.

    Raises:
        ValueError: If ``output_embedding`` is neither ``'graph'`` nor ``'node'``.
    """
    n0, ed, disjoint_indices, batch_id_node, batch_id_edge, count_nodes, count_edges = inputs

    # Fail fast, before any layer is created, on an unsupported output mode.
    if output_embedding not in ("graph", "node"):
        raise ValueError(
            "Unsupported output embedding for mode `NMPN`: "
            f"{output_embedding!r}. Expected 'graph' or 'node'.")

    # Embedding, if there is no feature dimension.
    if use_node_embedding:
        n0 = Embedding(**input_node_embedding)(n0)
    if use_edge_embedding and not geometric_edge:
        ed = Embedding(**input_edge_embedding)(ed)

    # If coordinates are in place of edges, turn them into pairwise node distances.
    if make_distance:
        pos1, pos2 = NodePosition()([ed, disjoint_indices])
        ed = NodeDistanceEuclidean()([pos1, pos2])
    if expand_distance:
        ed = GaussBasisLayer(**gauss_args)(ed)

    # Project nodes to the hidden dimension.
    n = ks.layers.Dense(node_dim, activation="linear")(n0)

    # Two independent edge networks (separate GraphMLP instances), each producing a
    # (node_dim, node_dim) transform per edge.
    edge_net_in = GraphMLP(**edge_mlp)([ed, batch_id_edge, count_edges])
    edge_net_in = TrafoEdgeNetMessages(target_shape=(node_dim, node_dim))(edge_net_in)
    edge_net_out = GraphMLP(**edge_mlp)([ed, batch_id_edge, count_edges])
    edge_net_out = TrafoEdgeNetMessages(target_shape=(node_dim, node_dim))(edge_net_out)

    # A single GRU instance is reused at every step, so its weights are shared across steps.
    gru = GRUUpdate(node_dim)
    for _ in range(depth):
        n_in = GatherNodesOutgoing()([n, disjoint_indices])
        n_out = GatherNodesIngoing()([n, disjoint_indices])
        m_in = MatMulMessages()([edge_net_in, n_in])
        m_out = MatMulMessages()([edge_net_out, n_out])
        eu = ks.layers.Concatenate(axis=-1)([m_in, m_out])
        # Aggregate edge messages per node; the reduction is configured by pooling_args.
        eu = AggregateLocalEdges(**pooling_args)([n, eu, disjoint_indices])
        n = gru([n, eu])

    # Skip connection: append the (possibly embedded) input node features.
    n = ks.layers.Concatenate(axis=-1)([n0, n])

    # Output embedding choice.
    if output_embedding == "graph":
        if use_set2set:
            n = ks.layers.Dense(units=set2set_args['channels'], activation="linear")(n)
            out = PoolingSet2SetEncoder(**set2set_args)([count_nodes, n, batch_id_node])
        else:
            out = PoolingNodes(**pooling_args)([count_nodes, n, batch_id_node])
        out = ks.layers.Flatten()(out)  # Flatten() is required for the Set2Set output.
        out = MLP(**output_mlp)(out)
    else:  # output_embedding == "node", guaranteed by the check above.
        out = GraphMLP(**output_mlp)([n, batch_id_node, count_nodes])
    return out
def model_disjoint_crystal(inputs,
                           use_node_embedding: bool = None,
                           use_edge_embedding: bool = None,
                           input_node_embedding: dict = None,
                           input_edge_embedding: dict = None,
                           geometric_edge: bool = None,
                           make_distance: bool = None,
                           expand_distance: bool = None,
                           gauss_args: dict = None,
                           set2set_args: dict = None,
                           pooling_args: dict = None,
                           edge_mlp: dict = None,
                           use_set2set: bool = None,
                           node_dim: int = None,
                           depth: int = None,
                           output_embedding: str = None,
                           output_mlp: dict = None):
    """Build the NMPN edge-network message-passing graph for periodic (crystal) disjoint graphs.

    Identical to the non-periodic variant except for the distance computation: the second
    node position of each edge is passed through ``ShiftPeriodicLattice`` together with
    ``edge_image``, ``lattice`` and ``batch_id_edge`` before the distance is taken.

    Args:
        inputs: Sequence of exactly nine tensors, unpacked in this order:
            ``(n0, ed, disjoint_indices, edge_image, lattice, batch_id_node, batch_id_edge,
            count_nodes, count_edges)``. ``ed`` holds edge features, or node coordinates when
            ``make_distance`` is set.
            NOTE(review): shapes/dtypes are defined by the caller — confirm there.
        use_node_embedding: If truthy, pass ``n0`` through ``Embedding(**input_node_embedding)``.
        use_edge_embedding: If truthy and ``geometric_edge`` is falsy, pass ``ed`` through
            ``Embedding(**input_edge_embedding)``.
        input_node_embedding: Keyword arguments for the node ``Embedding`` layer.
        input_edge_embedding: Keyword arguments for the edge ``Embedding`` layer.
        geometric_edge: If truthy, the edge embedding step is skipped.
        make_distance: If truthy, replace ``ed`` by the Euclidean distance between the two
            edge endpoint positions, the second one shifted by the periodic lattice.
        expand_distance: If truthy, expand ``ed`` with ``GaussBasisLayer(**gauss_args)``.
        gauss_args: Keyword arguments for ``GaussBasisLayer``.
        set2set_args: Keyword arguments for ``PoolingSet2SetEncoder``; must contain ``'channels'``.
        pooling_args: Keyword arguments shared by ``AggregateLocalEdges`` and ``PoolingNodes``.
        edge_mlp: Keyword arguments for the two (in/out) edge ``GraphMLP`` networks.
        use_set2set: For graph output, pool with Set2Set instead of ``PoolingNodes``.
        node_dim: Hidden node dimension; edge networks are reshaped to ``(node_dim, node_dim)``.
        depth: Number of message-passing steps.
        output_embedding: ``'graph'`` for a pooled graph-level output, ``'node'`` for node-level.
        output_mlp: Keyword arguments for the final ``MLP`` (graph) or ``GraphMLP`` (node).

    Returns:
        The output tensor of the final MLP.

    Raises:
        ValueError: If ``output_embedding`` is neither ``'graph'`` nor ``'node'``.
    """
    (n0, ed, disjoint_indices, edge_image, lattice,
     batch_id_node, batch_id_edge, count_nodes, count_edges) = inputs

    # Fail fast, before any layer is created, on an unsupported output mode.
    if output_embedding not in ("graph", "node"):
        raise ValueError(
            "Unsupported output embedding for mode `NMPN`: "
            f"{output_embedding!r}. Expected 'graph' or 'node'.")

    # Embedding, if there is no feature dimension.
    if use_node_embedding:
        n0 = Embedding(**input_node_embedding)(n0)
    if use_edge_embedding and not geometric_edge:
        ed = Embedding(**input_edge_embedding)(ed)

    # If coordinates are in place of edges, turn them into periodic pairwise distances.
    if make_distance:
        pos1, pos2 = NodePosition()([ed, disjoint_indices])
        # Presumably moves pos2 into the periodic image given by edge_image — verify in layer docs.
        pos2 = ShiftPeriodicLattice()([pos2, edge_image, lattice, batch_id_edge])
        ed = NodeDistanceEuclidean()([pos1, pos2])
    if expand_distance:
        ed = GaussBasisLayer(**gauss_args)(ed)

    # Project nodes to the hidden dimension.
    n = ks.layers.Dense(node_dim, activation="linear")(n0)

    # Two independent edge networks (separate GraphMLP instances), each producing a
    # (node_dim, node_dim) transform per edge.
    edge_net_in = GraphMLP(**edge_mlp)([ed, batch_id_edge, count_edges])
    edge_net_in = TrafoEdgeNetMessages(target_shape=(node_dim, node_dim))(edge_net_in)
    edge_net_out = GraphMLP(**edge_mlp)([ed, batch_id_edge, count_edges])
    edge_net_out = TrafoEdgeNetMessages(target_shape=(node_dim, node_dim))(edge_net_out)

    # A single GRU instance is reused at every step, so its weights are shared across steps.
    gru = GRUUpdate(node_dim)
    for _ in range(depth):
        n_in = GatherNodesOutgoing()([n, disjoint_indices])
        n_out = GatherNodesIngoing()([n, disjoint_indices])
        m_in = MatMulMessages()([edge_net_in, n_in])
        m_out = MatMulMessages()([edge_net_out, n_out])
        eu = ks.layers.Concatenate(axis=-1)([m_in, m_out])
        # Aggregate edge messages per node; the reduction is configured by pooling_args.
        eu = AggregateLocalEdges(**pooling_args)([n, eu, disjoint_indices])
        n = gru([n, eu])

    # Skip connection: append the (possibly embedded) input node features.
    n = ks.layers.Concatenate(axis=-1)([n0, n])

    # Output embedding choice.
    if output_embedding == "graph":
        if use_set2set:
            n = ks.layers.Dense(units=set2set_args['channels'], activation="linear")(n)
            out = PoolingSet2SetEncoder(**set2set_args)([count_nodes, n, batch_id_node])
        else:
            out = PoolingNodes(**pooling_args)([count_nodes, n, batch_id_node])
        out = ks.layers.Flatten()(out)  # Flatten() is required for the Set2Set output.
        out = MLP(**output_mlp)(out)
    else:  # output_embedding == "node", guaranteed by the check above.
        out = GraphMLP(**output_mlp)([n, batch_id_node, count_nodes])
    return out