diff --git a/CHANGELOG.md b/CHANGELOG.md index ecbf35985ccb..1c09dc09d1cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,12 +7,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added distributed `GAT + ogbn-products` example targeting XPU device ([#8032](https://github.com/pyg-team/pytorch_geometric/pull/8032)) + ### Changed ### Deprecated ### Fixed +- Fixed sparse-sparse matrix multiplication support on Windows in `TwoHop` and `AddRandomWalkPE` transformations ([#8197](https://github.com/pyg-team/pytorch_geometric/pull/8197)) + ### Removed ## [2.4.0] - 2023-10-12 diff --git a/docs/source/advanced/compile.rst b/docs/source/advanced/compile.rst index 47693bf672d7..96fe4644d758 100644 --- a/docs/source/advanced/compile.rst +++ b/docs/source/advanced/compile.rst @@ -68,9 +68,9 @@ The :meth:`torch.compile`/:meth:`torch_geometric.compile` method provides two im torch_geometric.compile(model, dynamic=True) With this, :pytorch:`PyTorch` will up-front attempt to generate a kernel that is as dynamic as possible to avoid recompilations when sizes change across mini-batches changes. - Note that when :obj:`dynamic` is set to :obj:`False`, :pytorch:`PyTorch` will *never* generate dynamic kernels, leading to significant slowdowns in model execution on dynamic mini-batches. - As such, you should only ever not specify :obj:`dynamic=True` when graph sizes are guaranteed to never change. - Note that :obj:`dynamic=True` requires :pytorch:`PyTorch` :obj:`>= 2.1.0` to be installed. + Note that when :obj:`dynamic` is set to :obj:`False`, :pytorch:`PyTorch` will *never* generate dynamic kernels, and thus only works when graph sizes are guaranteed to never change (*e.g.*, in full-batch training on small graphs). + By default, :obj:`dynamic` is set to :obj:`None` in :pytorch:`PyTorch` :obj:`>= 2.1.0`, and :pytorch:`PyTorch` will automatically detect if dynamism has occured. + Note that support for dynamic shape tracing requires :pytorch:`PyTorch` :obj:`>= 2.1.0` to be installed. * In order to maximize speedup, graphs breaks in the compiled model should be limited. We can force compilation to raise an error upon the first graph break encountered by using the :obj:`fullgraph=True` argument: diff --git a/examples/multi_gpu/distributed_sampling_xpu.py b/examples/multi_gpu/distributed_sampling_xpu.py new file mode 100644 index 000000000000..ebd3078abaa5 --- /dev/null +++ b/examples/multi_gpu/distributed_sampling_xpu.py @@ -0,0 +1,203 @@ +""" +Distributed GAT training, targeting XPU devices. +PVC has 2 tiles, each reports itself as a separate +device. DDP approach allows us to employ both tiles. + +Additional requirements: + IPEX (intel_extension_for_pytorch) + oneCCL (oneccl_bindings_for_pytorch) + + We need to import both these modules, as they extend + torch module with XPU/oneCCL related functionality. + +Run with: + mpirun -np 2 python distributed_sampling_xpu.py +""" + +import copy +import os +import os.path as osp +from typing import Tuple, Union + +import intel_extension_for_pytorch # noqa +import oneccl_bindings_for_pytorch # noqa +import torch +import torch.distributed as dist +import torch.nn.functional as F +from ogb.nodeproppred import Evaluator, PygNodePropPredDataset +from torch import Tensor +from torch.nn import Linear as Lin +from torch.nn.parallel import DistributedDataParallel as DDP +from tqdm import tqdm + +from torch_geometric.loader import NeighborLoader +from torch_geometric.nn import GATConv + + +class GAT(torch.nn.Module): + def __init__( + self, + in_channels: int, + hidden_channels: int, + out_channels: int, + num_layers: int, + heads: int, + ): + super().__init__() + + self.num_layers = num_layers + + self.convs = torch.nn.ModuleList() + self.convs.append(GATConv(dataset.num_features, hidden_channels, + heads)) + for _ in range(num_layers - 2): + self.convs.append( + GATConv(heads * hidden_channels, hidden_channels, heads)) + self.convs.append( + GATConv(heads * hidden_channels, out_channels, heads, + concat=False)) + + self.skips = torch.nn.ModuleList() + self.skips.append(Lin(dataset.num_features, hidden_channels * heads)) + for _ in range(num_layers - 2): + self.skips.append( + Lin(hidden_channels * heads, hidden_channels * heads)) + self.skips.append(Lin(hidden_channels * heads, out_channels)) + + def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: + for i, (conv, skip) in enumerate(zip(self.convs, self.skips)): + x = conv(x, edge_index) + skip(x) + if i != self.num_layers - 1: + x = F.elu(x) + x = F.dropout(x, p=0.5, training=self.training) + return x + + def inference( + self, + x_all: Tensor, + device: Union[str, torch.device], + subgraph_loader: NeighborLoader, + ) -> Tensor: + pbar = tqdm(total=x_all.size(0) * self.num_layers) + pbar.set_description("Evaluating") + + # Compute representations of nodes layer by layer, using *all* + # available edges. This leads to faster computation in contrast to + # immediately computing the final representations of each batch. + for i in range(self.num_layers): + xs = [] + for batch in subgraph_loader: + x = x_all[batch.n_id].to(device) + edge_index = batch.edge_index.to(device) + x = self.convs[i](x, edge_index) + self.skips[i](x) + x = x[:batch.batch_size] + if i != self.num_layers - 1: + x = F.elu(x) + xs.append(x.cpu()) + + pbar.update(batch.batch_size) + + x_all = torch.cat(xs, dim=0) + + pbar.close() + + return x_all + + +def run(rank: int, world_size: int, dataset: PygNodePropPredDataset): + device = f"xpu:{rank}" + + split_idx = dataset.get_idx_split() + split_idx["train"] = (split_idx["train"].split( + split_idx["train"].size(0) // world_size, dim=0)[rank].clone()) + data = dataset[0].to(device, "x", "y") + + kwargs = dict(batch_size=1024, num_workers=0, pin_memory=True) + train_loader = NeighborLoader(data, input_nodes=split_idx["train"], + num_neighbors=[10, 10, 5], **kwargs) + + if rank == 0: + subgraph_loader = NeighborLoader(copy.copy(data), num_neighbors=[-1], + **kwargs) + evaluator = Evaluator(name="ogbn-products") + + torch.manual_seed(12345) + model = GAT(dataset.num_features, 128, dataset.num_classes, num_layers=3, + heads=4).to(device) + model = DDP(model, device_ids=[device]) + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + + for epoch in range(1, 21): + model.train() + for batch in train_loader: + optimizer.zero_grad() + out = model(batch.x, + batch.edge_index.to(device))[:batch.batch_size] + loss = F.cross_entropy(out, batch.y[:batch.batch_size].squeeze()) + loss.backward() + optimizer.step() + + dist.barrier() + + if rank == 0: + print(f"Epoch: {epoch:02d}, Loss: {loss:.4f}") + + if rank == 0 and epoch % 5 == 0: # Evaluation on a single GPU + model.eval() + with torch.no_grad(): + out = model.module.inference(data.x, device, subgraph_loader) + + y_true = data.y.to(out.device) + y_pred = out.argmax(dim=-1, keepdim=True) + + train_acc = evaluator.eval({ + "y_true": y_true[split_idx["train"]], + "y_pred": y_pred[split_idx["train"]], + })["acc"] + val_acc = evaluator.eval({ + "y_true": y_true[split_idx["valid"]], + "y_pred": y_pred[split_idx["valid"]], + })["acc"] + test_acc = evaluator.eval({ + "y_true": y_true[split_idx["test"]], + "y_pred": y_pred[split_idx["test"]], + })["acc"] + + print(f"Train: {train_acc:.4f}, Val: {val_acc:.4f}, " + f"Test: {test_acc:.4f}") + + dist.barrier() + + dist.destroy_process_group() + + +def get_dist_params() -> Tuple[int, int, str]: + master_addr = "127.0.0.1" + master_port = "29500" + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = master_port + + mpi_rank = int(os.environ.get("PMI_RANK", -1)) + mpi_world_size = int(os.environ.get("PMI_SIZE", -1)) + rank = mpi_rank if mpi_world_size > 0 else os.environ.get("RANK", 0) + world_size = (mpi_world_size if mpi_world_size > 0 else os.environ.get( + "WORLD_SIZE", 1)) + + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + + init_method = f"tcp://{master_addr}:{master_port}" + + return rank, world_size, init_method + + +if __name__ == "__main__": + rank, world_size, init_method = get_dist_params() + dist.init_process_group(backend="ccl", init_method=init_method, + world_size=world_size, rank=rank) + + path = osp.join(osp.dirname(osp.realpath(__file__)), "../../data", + "ogbn-products") + dataset = PygNodePropPredDataset("ogbn-products", path) + + run(rank, world_size, dataset) diff --git a/test/transforms/test_add_positional_encoding.py b/test/transforms/test_add_positional_encoding.py index c635f9f3d8f4..d61c781e126a 100644 --- a/test/transforms/test_add_positional_encoding.py +++ b/test/transforms/test_add_positional_encoding.py @@ -1,7 +1,6 @@ import torch from torch_geometric.data import Data -from torch_geometric.testing import onlyLinux from torch_geometric.transforms import ( AddLaplacianEigenvectorPE, AddRandomWalkPE, @@ -74,7 +73,6 @@ def test_eigenvector_permutation_invariance(): assert torch.allclose(out1.x[perm].abs(), out2.x.abs(), atol=1e-6) -@onlyLinux # TODO (matthias) Investigate CSR @ CSR support on Windows. def test_add_random_walk_pe(): x = torch.randn(6, 4) edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5], diff --git a/test/transforms/test_two_hop.py b/test/transforms/test_two_hop.py index 934a10056513..ac63e6e69e8c 100644 --- a/test/transforms/test_two_hop.py +++ b/test/transforms/test_two_hop.py @@ -1,11 +1,9 @@ import torch from torch_geometric.data import Data -from torch_geometric.testing import onlyLinux from torch_geometric.transforms import TwoHop -@onlyLinux # TODO (matthias) Investigate CSR @ CSR support on Windows. def test_two_hop(): transform = TwoHop() assert str(transform) == 'TwoHop()' diff --git a/torch_geometric/transforms/add_positional_encoding.py b/torch_geometric/transforms/add_positional_encoding.py index 1524562a3da4..6739479840a4 100644 --- a/torch_geometric/transforms/add_positional_encoding.py +++ b/torch_geometric/transforms/add_positional_encoding.py @@ -3,6 +3,7 @@ import numpy as np import torch +import torch_geometric.typing from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @@ -12,6 +13,7 @@ scatter, to_edge_index, to_scipy_sparse_matrix, + to_torch_coo_tensor, to_torch_csr_tensor, ) @@ -136,7 +138,10 @@ def forward(self, data: Data) -> Data: value = scatter(value, row, dim_size=N, reduce='sum').clamp(min=1)[row] value = 1.0 / value - adj = to_torch_csr_tensor(data.edge_index, value, size=data.size()) + if torch_geometric.typing.WITH_WINDOWS: + adj = to_torch_coo_tensor(data.edge_index, value, size=data.size()) + else: + adj = to_torch_csr_tensor(data.edge_index, value, size=data.size()) out = adj pe_list = [get_self_loop_attr(*to_edge_index(out), num_nodes=N)] diff --git a/torch_geometric/transforms/two_hop.py b/torch_geometric/transforms/two_hop.py index ec53d99167ea..4ef6315753a8 100644 --- a/torch_geometric/transforms/two_hop.py +++ b/torch_geometric/transforms/two_hop.py @@ -1,5 +1,6 @@ import torch +import torch_geometric.typing from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @@ -7,6 +8,7 @@ coalesce, remove_self_loops, to_edge_index, + to_torch_coo_tensor, to_torch_csr_tensor, ) @@ -19,8 +21,14 @@ def forward(self, data: Data) -> Data: edge_index, edge_attr = data.edge_index, data.edge_attr N = data.num_nodes - adj = to_torch_csr_tensor(edge_index, size=(N, N)) - edge_index2, _ = to_edge_index(adj @ adj) + if torch_geometric.typing.WITH_WINDOWS: + adj = to_torch_coo_tensor(edge_index, size=(N, N)) + else: + adj = to_torch_csr_tensor(edge_index, size=(N, N)) + + adj = adj @ adj + + edge_index2, _ = to_edge_index(adj) edge_index2, _ = remove_self_loops(edge_index2) edge_index = torch.cat([edge_index, edge_index2], dim=1) diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index 735e81c58287..2b3d9f2ab753 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -1,4 +1,5 @@ import inspect +import os import platform import sys import warnings @@ -14,6 +15,7 @@ WITH_PT112 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 12 WITH_PT113 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 13 +WITH_WINDOWS = os.name == 'nt' WITH_ARM = platform.machine() != 'x86_64' if not hasattr(torch, 'sparse_csc'):