Skip to content

Commit

Permalink
Enable JIT mode by default (#8772)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Jan 15, 2024
1 parent a2eda56 commit 570f08e
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 154 deletions.
26 changes: 15 additions & 11 deletions test/nn/conv/test_gcn_conv.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import copy

import pytest
import torch

Expand Down Expand Up @@ -36,7 +34,7 @@ def test_gcn_conv():
assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)

if is_full_test():
jit = torch.jit.script(conv.jittable())
jit = torch.jit.script(conv)
assert torch.allclose(jit(x, edge_index), out1, atol=1e-6)
assert torch.allclose(jit(x, edge_index, value), out2, atol=1e-6)

Expand All @@ -60,18 +58,24 @@ def test_gcn_conv_with_decomposed_layers():
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])

conv = GCNConv(16, 32)

decomposed_conv = copy.deepcopy(conv)
decomposed_conv.decomposed_layers = 2
def hook(module, inputs):
assert inputs[0]['x_j'].size() == (10, 32 // module.decomposed_layers)

conv = GCNConv(16, 32)
conv.register_message_forward_pre_hook(hook)
out1 = conv(x, edge_index)
out2 = decomposed_conv(x, edge_index)

conv.decomposed_layers = 2
assert conv.propagate.__module__.endswith('message_passing')
out2 = conv(x, edge_index)
assert torch.allclose(out1, out2)

if is_full_test():
jit = torch.jit.script(decomposed_conv.jittable())
assert torch.allclose(jit(x, edge_index), out1)
# TorchScript should still work since it relies on class methods
# (but without decomposition).
torch.jit.script(conv)

conv.decomposed_layers = 1
assert conv.propagate.__module__.endswith('GCNConv_propagate')


def test_gcn_conv_with_sparse_input_feature():
Expand Down
18 changes: 15 additions & 3 deletions test/nn/conv/test_message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,17 +551,29 @@ def test_explain_message():
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

conv = MyExplainConv()
assert conv(x, edge_index).abs().sum() != 0.

conv.explain = True
assert conv.propagate.__module__.endswith('message_passing')

with pytest.raises(ValueError, match="pre-defined 'edge_mask'"):
conv(x, edge_index)

conv._edge_mask = torch.tensor([0, 0, 0, 0], dtype=torch.float)
conv._edge_mask = torch.tensor([0.0, 0.0, 0.0, 0.0])
conv._apply_sigmoid = False
assert conv(x, edge_index).abs().sum() == 0.

conv._edge_mask = torch.tensor([1.0, 1.0, 1.0, 1.0])
conv._apply_sigmoid = False
out1 = conv(x, edge_index)

# TorchScript should still work since it relies on class methods
# (but without explainability).
torch.jit.script(conv)

conv.explain = False
assert conv.propagate.__module__.endswith('MyExplainConv_propagate')
out2 = conv(x, edge_index)
assert torch.allclose(out1, out2)


class MyAggregatorConv(MessagePassing):
def __init__(self, **kwargs):
Expand Down
59 changes: 0 additions & 59 deletions test/nn/conv/test_propagate.py

This file was deleted.

8 changes: 4 additions & 4 deletions test/nn/conv/test_sage_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_sage_conv(project, aggr):
out = assert_module(conv, x, edge_index, expected_size=(4, 32))

if is_full_test():
jit = torch.jit.script(conv.jittable())
jit = torch.jit.script(conv)
assert torch.allclose(jit(x, edge_index), out, atol=1e-6)
assert torch.allclose(jit(x, edge_index, size=(4, 4)), out, atol=1e-6)

Expand All @@ -45,7 +45,7 @@ def test_sage_conv(project, aggr):
expected_size=(2, 32))

if is_full_test():
jit = torch.jit.script(conv.jittable())
jit = torch.jit.script(conv)
assert torch.allclose(jit((x1, x2), edge_index), out1, atol=1e-6)
assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out1)
assert torch.allclose(jit((x1, None), edge_index, size=(4, 2)), out2)
Expand Down Expand Up @@ -137,12 +137,12 @@ def test_compile_multi_aggr_sage_conv(device):
in_channels=8,
out_channels=32,
aggr=['mean', 'sum', 'min', 'max', 'std'],
).jittable().to(device)
).to(device)

explanation = dynamo.explain(conv)(x, edge_index)
assert explanation.graph_break_count == 0

compiled_conv = torch_geometric.compile(conv)
compiled_conv = torch.compile(conv)

expected = conv(x, edge_index)
out = compiled_conv(x, edge_index)
Expand Down
100 changes: 54 additions & 46 deletions torch_geometric/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Inspector:
def __init__(self, cls: Type):
self._cls = cls
self._signature_dict: Dict[str, Signature] = {}
self._source: Optional[str] = None
self._source_dict: Dict[str, str] = {}

@property
def _globals(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -344,13 +344,17 @@ def collect_param_data(

# Inspecting Method Bodies ################################################

@property
def source(self) -> str:
def get_source(self, cls: Optional[Type] = None) -> str:
r"""Returns the source code of :obj:`cls`."""
if self._source is not None:
return self._source
self._source = inspect.getsource(self._cls)
return self._source
cls = cls or self._cls
if cls.__name__ in self._source_dict:
return self._source_dict[cls.__name__]
try:
source = inspect.getsource(cls)
except Exception:
source = ''
self._source_dict[cls.__name__] = source
return source

def get_params_from_method_call(
self,
Expand Down Expand Up @@ -394,47 +398,51 @@ def get_params_from_method_call(
return param_dict

# (2) Find type annotation:
match = find_parenthesis_content(self.source, f'{func_name}_type:')
if match is not None:
for arg in split(match, sep=','):
name_and_type_repr = re.split(r'\s*:\s*', arg)
if len(name_and_type_repr) != 2:
raise ValueError(f"Could not parse the argument '{arg}' "
f"of the '{func_name}_type' annoitation")

name, type_repr = name_and_type_repr
param_dict[name] = Parameter(
name=name,
type=self.eval_type(type_repr),
type_repr=type_repr,
default=inspect._empty,
)
return param_dict
for cls in self._cls.__mro__:
source = self.get_source(cls)
match = find_parenthesis_content(source, f'{func_name}_type:')
if match is not None:
for arg in split(match, sep=','):
name_and_type_repr = re.split(r'\s*:\s*', arg)
if len(name_and_type_repr) != 2:
raise ValueError(f"Could not parse argument '{arg}' "
f"of '{func_name}_type' annotation")

name, type_repr = name_and_type_repr
param_dict[name] = Parameter(
name=name,
type=self.eval_type(type_repr),
type_repr=type_repr,
default=inspect._empty,
)
return param_dict

# (3) Parse the function call:
match = find_parenthesis_content(self.source, f'self.{func_name}')
if match is not None:
for i, kwarg in enumerate(split(match, sep=',')):
if exclude is not None and i in exclude:
continue

name_and_content = re.split(r'\s*=\s*', kwarg)
if len(name_and_content) != 2:
raise ValueError(f"Could not parse the keyword argument "
f"'{kwarg}' in 'self.{func_name}(...)'")

name, _ = name_and_content

if exclude is not None and name in exclude:
continue

param_dict[name] = Parameter(
name=name,
type=Tensor,
type_repr=self.type_repr(Tensor),
default=inspect._empty,
)
return param_dict
for cls in self._cls.__mro__:
source = self.get_source(cls)
match = find_parenthesis_content(source, f'self.{func_name}')
if match is not None:
for i, kwarg in enumerate(split(match, sep=',')):
if exclude is not None and i in exclude:
continue

name_and_content = re.split(r'\s*=\s*', kwarg)
if len(name_and_content) != 2:
raise ValueError(f"Could not parse keyword argument "
f"'{kwarg}' in 'self.{func_name}()'")

name, _ = name_and_content

if exclude is not None and name in exclude:
continue

param_dict[name] = Parameter(
name=name,
type=Tensor,
type_repr=self.type_repr(Tensor),
default=inspect._empty,
)
return param_dict

return {} # (4) No function call found:

Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/nn/conv/collect.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ def {{collect_name}}(
{%- if name not in signature.param_dict and
not name.endswith('_i') and
not name.endswith('_j') and
name not in ['edge_index', 'adj_t', 'size', 'ptr', 'index', 'dim_size'] %}
name not in ['edge_index', 'adj_t', 'size', 'ptr', 'index', 'dim_size'] and
'_empty' not in param.default.__name__ %}
{{name}} = {{param.default}}
{%- endif %}
{%- endfor %}
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/nn/conv/gcn_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,5 +270,5 @@ def forward(self, x: Tensor, edge_index: Adj,
def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:
return spmm(adj_t, x, reduce=self.aggr)
Loading

0 comments on commit 570f08e

Please sign in to comment.