
Commit

Merge branch 'master' into fix-from-triples
puririshi98 authored Jan 24, 2025
2 parents a1a8492 + ed89c94 commit 1945d57
Showing 16 changed files with 275 additions and 26 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -3,10 +3,11 @@
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

-## [2.7.0] - 2024-MM-DD
+## [2.7.0] - 2025-MM-DD

### Added

+- Added `InstructMol` dataset ([#9975](https://github.com/pyg-team/pytorch_geometric/pull/9975))
- Added support for weighted `LinkPredRecall` metric ([#9947](https://github.com/pyg-team/pytorch_geometric/pull/9947))
- Added support for weighted `LinkPredNDCG` metric ([#9945](https://github.com/pyg-team/pytorch_geometric/pull/9945))
- Added `LinkPredMetricCollection` ([#9941](https://github.com/pyg-team/pytorch_geometric/pull/9941))
@@ -25,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `use_pcst` option to `WebQSPDataset` ([#9722](https://github.com/pyg-team/pytorch_geometric/pull/9722))
- Allowed users to pass `edge_weight` to `GraphUNet` models ([#9737](https://github.com/pyg-team/pytorch_geometric/pull/9737))
- Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467))
+- Added the ComplexWebQuestions (CWQ) dataset ([#9950](https://github.com/pyg-team/pytorch_geometric/pull/9950))

### Changed

@@ -855,6 +857,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug in which `nn.models.GAT` did not produce `out_channels`-many output channels ([#4299](https://github.com/pyg-team/pytorch_geometric/pull/4299))
- Fixed mini-batching with empty lists as attributes ([#4293](https://github.com/pyg-team/pytorch_geometric/pull/4293))
- Fixed a bug in which `GCNConv` could not be combined with `to_hetero` on heterogeneous graphs with one node type ([#4279](https://github.com/pyg-team/pytorch_geometric/pull/4279))
+- Added a scheduler to the GraphSAGE OGBN example ([#9877](https://github.com/pyg-team/pytorch_geometric/pull/9877))

### Removed

2 changes: 1 addition & 1 deletion docs/source/install/installation.rst
@@ -89,7 +89,7 @@ For ease of installation of these extensions, we provide :obj:`pip` wheels for t
where :obj:`${TORCH}` and :obj:`${CUDA}` should be replaced by the specific :pytorch:`PyTorch` and CUDA versions, respectively:

-* :pytorch:`PyTorch` 2.4: :obj:`${TORCH}=2.5.0` and :obj:`${CUDA}=cpu|cu118|cu121|cu124`
* :pytorch:`PyTorch` 2.5: :obj:`${TORCH}=2.5.0` and :obj:`${CUDA}=cpu|cu118|cu121|cu124`
+* :pytorch:`PyTorch` 2.4: :obj:`${TORCH}=2.4.0` and :obj:`${CUDA}=cpu|cu118|cu121|cu124`
* :pytorch:`PyTorch` 2.3: :obj:`${TORCH}=2.3.0` and :obj:`${CUDA}=cpu|cu118|cu121`
* :pytorch:`PyTorch` 2.2: :obj:`${TORCH}=2.2.0` and :obj:`${CUDA}=cpu|cu118|cu121`
2 changes: 1 addition & 1 deletion examples/llm/README.md
@@ -6,6 +6,6 @@
| [`g_retriever_utils/`](./g_retriever_utils/) | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods. |
| [`multihop_rag/`](./multihop_rag/) | Contains starter code and an example run for building a Multi-hop dataset using WikiHop5M and 2WikiMultiHopQA |
| [`nvtx_examples/`](./nvtx_examples/) | Contains examples of how to wrap functions using the NVTX profiler for CUDA runtime analysis. |
-| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction |
+| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction. Supports the MoleculeGPT and InstructMol datasets. |
| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results |
| [`git_mol.py`](./git_mol.py) | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text |
24 changes: 17 additions & 7 deletions examples/llm/g_retriever.py
@@ -24,7 +24,7 @@
from tqdm import tqdm

from torch_geometric import seed_everything
-from torch_geometric.datasets import WebQSPDataset
+from torch_geometric.datasets import CWQDataset, WebQSPDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models import GAT, GRetriever
from torch_geometric.nn.nlp import LLM
@@ -89,7 +89,7 @@ def compute_metrics(eval_output):
    f1 = sum(all_f1) / len(all_f1)

    # Print metrics to console
-    print(f'Hit: {hit:.4f}')
+    print(f'Hit@1: {hit:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'Recall: {recall:.4f}')
    print(f'F1: {f1:.4f}')
@@ -193,9 +193,10 @@ def train(
    lr,  # Initial learning rate
    llm_model_name,  # `transformers` model name
    checkpointing=False,  # Whether to checkpoint model
+    cwq=False,  # Whether to train on the CWQ dataset
    tiny_llama=False,  # Whether to use tiny LLaMA model
):
"""Train a GNN+LLM model on WebQSP dataset.
"""Train a GNN+LLM model on WebQSP or CWQ dataset.
Args:
num_epochs (int): Total number of training epochs.
@@ -207,6 +208,8 @@ def train(
        llm_model_name (str): The name of the LLM to use.
        checkpointing (bool, optional): Whether to checkpoint model.
            Defaults to False.
+        cwq (bool, optional): Whether to train on the CWQ dataset
+            instead of WebQSP. Defaults to False.
        tiny_llama (bool, optional): Whether to use tiny LLaMA model.
            Defaults to False.
@@ -240,10 +243,16 @@ def adjust_learning_rate(param_group, LR, epoch):

    # Load dataset and create data loaders
    path = osp.dirname(osp.realpath(__file__))
-    path = osp.join(path, '..', '..', 'data', 'WebQSPDataset')
-    train_dataset = WebQSPDataset(path, split='train')
-    val_dataset = WebQSPDataset(path, split='val')
-    test_dataset = WebQSPDataset(path, split='test')
+    if not cwq:
+        path = osp.join(path, '..', '..', 'data', 'WebQSPDataset')
+        train_dataset = WebQSPDataset(path, split='train')
+        val_dataset = WebQSPDataset(path, split='val')
+        test_dataset = WebQSPDataset(path, split='test')
+    else:
+        path = osp.join(path, '..', '..', 'data', 'CWQDataset')
+        train_dataset = CWQDataset(path, split='train')
+        val_dataset = CWQDataset(path, split='val')
+        test_dataset = CWQDataset(path, split='test')

    seed_everything(42)

@@ -388,6 +397,7 @@ def adjust_learning_rate(param_group, LR, epoch):
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--eval_batch_size', type=int, default=16)
    parser.add_argument('--checkpointing', action='store_true')
+    parser.add_argument('--cwq', action='store_true')
    parser.add_argument('--tiny_llama', action='store_true')
    parser.add_argument('--llm_model_name', type=str,
                        default="meta-llama/Meta-Llama-3.1-8B-Instruct")
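
The new `--cwq` flag simply swaps the dataset class; below is a minimal sketch of loading the CWQ splits the same way the updated `train()` does. The root path is an illustrative choice, and the first instantiation downloads and processes the data:

```python
import os.path as osp

from torch_geometric.datasets import CWQDataset

# Mirrors the `else` branch added to train() above.
path = osp.join('data', 'CWQDataset')
train_dataset = CWQDataset(path, split='train')
val_dataset = CWQDataset(path, split='val')
test_dataset = CWQDataset(path, split='test')
print(len(train_dataset), len(val_dataset), len(test_dataset))
```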
2 changes: 1 addition & 1 deletion examples/llm/glem.py
@@ -79,7 +79,7 @@ def main(args):

    dataset = PygNodePropPredDataset(f'ogbn-{dataset_name}', root=root)
    split_idx = dataset.get_idx_split()
-    data = dataset.data
+    data = dataset[0]

    tag_dataset = TAGDataset(root, dataset, hf_model,
                             token_on_disk=token_on_disk)
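
The one-line `glem.py` fix is subtle enough to deserve a note: `dataset.data` exposes the dataset's internal collated storage and bypasses any transforms, while `dataset[0]` returns the first (and for OGB node-property datasets, only) graph as a proper `Data` object. A short sketch of the distinction, assuming `ogbn-arxiv` is available locally:

```python
from ogb.nodeproppred import PygNodePropPredDataset

dataset = PygNodePropPredDataset('ogbn-arxiv', root='data')
data = dataset[0]   # Transformed graph -- what glem.py now uses.
raw = dataset.data  # Internal storage; skips transforms, may warn.
print(data.num_nodes, data.num_edges)
```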
14 changes: 11 additions & 3 deletions examples/llm/molecule_gpt.py
@@ -11,7 +11,7 @@
from tqdm import tqdm

from torch_geometric import seed_everything
-from torch_geometric.datasets import MoleculeGPTDataset
+from torch_geometric.datasets import InstructMolDataset, MoleculeGPTDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv
from torch_geometric.nn.models import MoleculeGPT
@@ -44,6 +44,7 @@ def eval(model, data_loader):


def train(
+    dataset_name: str,
    num_epochs: int,
    lr: float,
    batch_size: int,
@@ -65,8 +66,11 @@ def adjust_learning_rate(param_group, LR, epoch):
    start_time = time.time()
    # Load dataset ================================================
    path = osp.dirname(osp.realpath(__file__))
-    path = osp.join(path, '..', '..', 'data', 'MoleculeGPT')
-    dataset = MoleculeGPTDataset(path)
+    path = osp.join(path, '..', '..', 'data', dataset_name)
+    if dataset_name == 'MoleculeGPT':
+        dataset = MoleculeGPTDataset(path)
+    elif dataset_name == 'InstructMol':
+        dataset = InstructMolDataset(path)
    train_size, val_size = int(0.8 * len(dataset)), int(0.1 * len(dataset))
    train_dataset = dataset[:train_size]
    val_dataset = dataset[train_size:train_size + val_size]
@@ -177,6 +181,9 @@ def adjust_learning_rate(param_group, LR, epoch):

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
+    parser.add_argument("--dataset_name", type=str, default='MoleculeGPT',
+                        choices=['MoleculeGPT', 'InstructMol'],
+                        help='Supports MoleculeGPT and InstructMol')
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--batch_size', type=int, default=2)
@@ -185,6 +192,7 @@ def adjust_learning_rate(param_group, LR, epoch):

    start_time = time.time()
    train(
+        args.dataset_name,
        args.epochs,
        args.lr,
        args.batch_size,
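
The dataset dispatch added to `train()` is guarded by the argparse `choices` shown above, so only the two supported names can reach it. A sketch of the same logic in isolation; the `load_dataset` helper is illustrative, not part of the script:

```python
from torch_geometric.datasets import InstructMolDataset, MoleculeGPTDataset

def load_dataset(dataset_name: str, root: str):
    # Same dispatch as in train(); argparse guarantees a valid name.
    if dataset_name == 'MoleculeGPT':
        return MoleculeGPTDataset(root)
    return InstructMolDataset(root)

# Equivalent to: python molecule_gpt.py --dataset_name InstructMol
dataset = load_dataset('InstructMol', 'data/InstructMol')
```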
8 changes: 7 additions & 1 deletion examples/ogbn_train.py
@@ -33,7 +33,7 @@
    action='store_true',
    help='Whether or not to use GAT model',
)
-parser.add_argument('-e', '--epochs', type=int, default=10)
+parser.add_argument('-e', '--epochs', type=int, default=50)
parser.add_argument('--num_layers', type=int, default=3)
parser.add_argument('--num_heads', type=int, default=2,
                    help='number of heads for GAT model.')
@@ -179,6 +179,8 @@ def test(loader: NeighborLoader) -> float:
    lr=args.lr,
    weight_decay=args.wd,
)
+scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',
+                                                       patience=5)

print(f'Total time before training begins took '
f'{time.perf_counter() - wall_clock_start:.4f}s')
@@ -204,6 +206,10 @@ def test(loader: NeighborLoader) -> float:
    if val_acc > best_val:
        best_val = val_acc
    times.append(time.perf_counter() - train_start)
+    for param_group in optimizer.param_groups:
+        print(f"lr: {param_group['lr']}")
+    scheduler.step(val_acc)

print(f'Average Epoch Time on training: '
f'{torch.tensor(train_times).mean():.4f}s')
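
The scheduler wiring above follows the standard `ReduceLROnPlateau` pattern: call `scheduler.step()` with the monitored metric once per epoch, and the learning rate is multiplied by a decay factor after `patience` epochs without improvement. A self-contained sketch of that behavior; the toy model and constant metric are placeholders:

```python
import torch

model = torch.nn.Linear(16, 2)  # Placeholder model.
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',
                                                       patience=5)

for epoch in range(20):
    val_acc = 0.5  # Placeholder; a real loop evaluates the model here.
    scheduler.step(val_acc)  # mode='max': LR decays once val_acc stalls.
    print(epoch, optimizer.param_groups[0]['lr'])
```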
11 changes: 11 additions & 0 deletions test/datasets/test_instruct_mol_dataset.py
@@ -0,0 +1,11 @@
from torch_geometric.datasets import InstructMolDataset
from torch_geometric.testing import onlyFullTest, withPackage


@onlyFullTest
@withPackage('rdkit')
def test_instruct_mol_dataset():
    dataset = InstructMolDataset(root='./data/InstructMol')
    assert len(dataset) == 326689
    assert dataset.num_edge_features == 4
    assert dataset.num_node_features == 6
17 changes: 17 additions & 0 deletions test/metrics/test_link_pred_metric.py
@@ -221,3 +221,20 @@ def test_link_pred_metric_collection(num_src_nodes, num_dst_nodes, num_edges):
    metric_collection.update(pred_index_mat, edge_label_index)
    assert metric_collection.compute() == expected
    metric_collection.reset()


def test_empty_ground_truth():
    pred = torch.rand(10, 5)
    pred_index_mat = pred.argsort(dim=1)
    edge_label_index = torch.empty(2, 0, dtype=torch.long)
    edge_label_weight = torch.empty(0)

    metric = LinkPredMAP(k=5)
    metric.update(pred_index_mat, edge_label_index)
    assert metric.compute() == 0
    metric.reset()

    metric = LinkPredNDCG(k=5, weighted=True)
    metric.update(pred_index_mat, edge_label_index, edge_label_weight)
    assert metric.compute() == 0
    metric.reset()
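
For contrast with the empty-ground-truth regression test above, here is a sketch of the same metric API on non-empty labels; the `torch_geometric.metrics` import path and the best-first (descending) ranking are assumptions based on typical usage, not shown in this diff:

```python
import torch

from torch_geometric.metrics import LinkPredMAP  # Assumed import path.

# 10 source nodes, 5 candidate destinations, ranked best-first:
pred_index_mat = torch.rand(10, 5).argsort(dim=1, descending=True)
edge_label_index = torch.tensor([[0, 0, 1],   # Source nodes.
                                 [2, 4, 3]])  # Their true destinations.

metric = LinkPredMAP(k=5)
metric.update(pred_index_mat, edge_label_index)
print(metric.compute())  # MAP@5 averaged over the labeled source nodes.
metric.reset()
```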
5 changes: 4 additions & 1 deletion torch_geometric/datasets/__init__.py
@@ -76,9 +76,10 @@
from .myket import MyketDataset
from .brca_tgca import BrcaTcga
from .neurograph import NeuroGraphDataset
-from .web_qsp_dataset import WebQSPDataset
+from .web_qsp_dataset import WebQSPDataset, CWQDataset
from .git_mol_dataset import GitMolDataset
from .molecule_gpt_dataset import MoleculeGPTDataset
+from .instruct_mol_dataset import InstructMolDataset
from .tag_dataset import TAGDataset

from .dbp15k import DBP15K
@@ -193,8 +194,10 @@
    'BrcaTcga',
    'NeuroGraphDataset',
    'WebQSPDataset',
+    'CWQDataset',
    'GitMolDataset',
    'MoleculeGPTDataset',
+    'InstructMolDataset',
    'TAGDataset',
]

134 changes: 134 additions & 0 deletions torch_geometric/datasets/instruct_mol_dataset.py
@@ -0,0 +1,134 @@
import json
import sys
from typing import Callable, List, Optional

import torch
from tqdm import tqdm

from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.io import fs
from torch_geometric.utils import one_hot


class InstructMolDataset(InMemoryDataset):
    r"""The dataset from the `"InstructMol: Multi-Modal Integration for
    Building a Versatile and Reliable Molecular Assistant in Drug Discovery"
    <https://arxiv.org/pdf/2311.16208>`_ paper.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in
            the final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """
    raw_url = 'https://huggingface.co/datasets/OpenMol/PubChemSFT/blob/main'

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ):
        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        return ['all_clean.json']

    @property
    def processed_file_names(self) -> List[str]:
        return ['data.pt']

    def download(self) -> None:
        print('downloading dataset...')
        fs.cp(f'{self.raw_url}/all_clean.json', self.raw_dir)

    def process(self) -> None:
        try:
            from rdkit import Chem
            from rdkit.Chem.rdchem import BondType as BT
            WITH_RDKIT = True

        except ImportError:
            WITH_RDKIT = False

        if not WITH_RDKIT:
            print(("Using a pre-processed version of the dataset. Please "
                   "install 'rdkit' to alternatively process the raw data."),
                  file=sys.stderr)

            data_list = fs.torch_load(self.raw_paths[0])
            data_list = [Data(**data_dict) for data_dict in data_list]

            if self.pre_filter is not None:
                data_list = [d for d in data_list if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            self.save(data_list, self.processed_paths[0])
            return

        # Types of atoms and bonds:
        types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknown': 5}
        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

        # Load the raw SMILES -> QA-pairs mapping:
        with open(f'{self.raw_dir}/all_clean.json') as f:
            mols = json.load(f)

        data_list = []
        for smiles, qa_pairs in tqdm(mols.items(), total=len(mols)):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue

            # One-hot atom features; unrecognized symbols map to 'Unknown':
            x: torch.Tensor = torch.tensor([
                types[atom.GetSymbol()] if atom.GetSymbol() in types else 5
                for atom in mol.GetAtoms()
            ])
            x = one_hot(x, num_classes=len(types), dtype=torch.float)

            # Each bond is stored twice (once per direction):
            rows, cols, edge_types = [], [], []
            for bond in mol.GetBonds():
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                edge_types += [bonds[bond.GetBondType()]] * 2
                rows += [i, j]
                cols += [j, i]

            edge_index = torch.tensor([rows, cols], dtype=torch.long)
            edge_type = torch.tensor(edge_types, dtype=torch.long)
            edge_attr = one_hot(edge_type, num_classes=len(bonds))

            # One Data object per (question, answer) pair for this molecule:
            for question, answer in qa_pairs:
                data = Data(
                    x=x,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    smiles=smiles,
                    instruction=question,
                    y=answer,
                )

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

        self.save(data_list, self.processed_paths[0])
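
A hypothetical usage sketch for the new dataset, based only on the fields constructed in `process()` above (the first instantiation downloads the raw JSON and, with `rdkit` installed, processes all molecules, which can take a while):

```python
from torch_geometric.datasets import InstructMolDataset
from torch_geometric.loader import DataLoader

dataset = InstructMolDataset(root='data/InstructMol')

data = dataset[0]
print(data.smiles, data.instruction, data.y)  # Molecule, question, answer.
print(data.x.shape, data.edge_attr.shape)     # One-hot atom/bond features.

loader = DataLoader(dataset, batch_size=32, shuffle=True)
batch = next(iter(loader))  # Mini-batched molecular graphs.
```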