Skip to content

Commit

Permalink
Fix pass of graph to graph layer
Browse files Browse the repository at this point in the history
  • Loading branch information
yura-hb committed Mar 7, 2024
1 parent 094808e commit d39a30a
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions diploma_thesis/agents/utils/nn/neural_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,7 @@ def forward(self, state):
if isinstance(state, GraphState) and self.graph_encoder is not None:
data = state.graph.data

if not self.is_configured:
if isinstance(data, HeteroData):
self.graph_encoder = to_hetero(self.graph_encoder, data.metadata(), aggr='sum')

encoded_graph = self.graph_encoder(data.x_dict, data.edge_index_dict)
encoded_graph = self.graph_encoder(data)

hidden = self.merge(encoded_state, encoded_graph)
output = self.output(hidden)
Expand Down

0 comments on commit d39a30a

Please sign in to comment.