Skip to content

Commit

Permalink
Merge pull request #6 from Forbu/main
Browse files Browse the repository at this point in the history
Adding correction for multigraph aspect

@pn51 & @bbartoldson I'm going ahead and merging this after you approvals.  Let me know if that's ok or not and if we should create a new (patch or otherwise) release of this.
  • Loading branch information
ksbeattie authored Aug 16, 2022
2 parents 8e6fe93 + dde5e70 commit 9c7b143
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions GNN/GNNComponents/MultiGNNComponents.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ def reset_parameters(self):
def forward(self, x, edge_indices,
edge_attrs=None,
u=None,
batch=None):
batch=None, dim_size=None):

if self.edge_models is not None:
for i, (em, ei, ea) in enumerate(zip(self.edge_models, edge_indices, edge_attrs)):
edge_attrs[i] = em(x[ei[0]], x[ei[1]], ea, u, batch if batch is None else batch[row])

if self.node_model is not None:
x = self.node_model(x, edge_indices, edge_attrs, u, batch)
x = self.node_model(x, edge_indices, edge_attrs, u, batch, dim_size=dim_size)

return x, edge_attrs, u

Expand All @@ -61,7 +61,7 @@ def __init__(self,
norm_type: normalization type; one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None
'''

#super(NodeProcessor, self).__init__()
super(NodeProcessor, self).__init__()
self.node_mlp = MLP(in_dim_node + in_dim_edge,
in_dim_node,
hidden_dim,
Expand All @@ -70,11 +70,11 @@ def __init__(self,

def forward(self, x,
edge_indices, edge_attrs,
u=None, batch=None):
u=None, batch=None, dim_size=None):

out = [x]
for ei, ea in zip(edge_indices, edge_attrs):
out.append(scatter_sum(ea, ei[1], dim=0))
out.append(scatter_sum(ea, ei[1], dim=0, dim_size=dim_size))

out = cat(out, dim=-1)
out = self.node_mlp(out)
Expand All @@ -89,7 +89,7 @@ def build_graph_processor_block(
hidden_layers_node=2, hidden_layers_edge=2,
norm_type='LayerNorm'):

edge_models = [EdgeProcessor(in_dim_node, in_dim_edge, hidden_dim_edge, hidden_layers_edge, norm_type) for _ in range(num_edge_models)]
edge_models = ModuleList([EdgeProcessor(in_dim_node, in_dim_edge, hidden_dim_edge, hidden_layers_edge, norm_type) for _ in range(num_edge_models)])
node_model = NodeProcessor(in_dim_node, in_dim_edge * num_edge_models, hidden_dim_node, hidden_layers_node, norm_type)

return MetaLayerMultigraph(
Expand All @@ -103,7 +103,8 @@ def __init__(self,
num_edge_models=1,
in_dim_node=128, in_dim_edge=128,
hidden_dim_node=128, hidden_dim_edge=128,
hidden_layers_node=2, hidden_layers_edge=2):
hidden_layers_node=2, hidden_layers_edge=2,
norm_type='LayerNorm'):

'''
Graph processor
Expand All @@ -119,7 +120,7 @@ def __init__(self,
'''

# super(GraphProcessor, self).__init__()
super(GraphProcessor, self).__init__()

self.blocks = ModuleList()
for _ in range(mp_iterations):
Expand All @@ -129,9 +130,9 @@ def __init__(self,
hidden_layers_node, hidden_layers_edge,
norm_type))

def forward(self, x, edge_indices, edge_attrs):
def forward(self, x, edge_indices, edge_attrs, dim_size=None):
for block in self.blocks:
x, edge_attrs, _ = block(x, edge_indices, edge_attrs)
x, edge_attrs, _ = block(x, edge_indices, edge_attrs, dim_size=dim_size)

return x, edge_attrs

0 comments on commit 9c7b143

Please sign in to comment.