Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add InstructMol dataset #9975

Merged
merged 5 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `InstructMol` dataset ([#9975](https://github.com/pyg-team/pytorch_geometric/pull/9975))
- Added support for weighted `LinkPredRecall` metric ([#9947](https://github.com/pyg-team/pytorch_geometric/pull/9947))
- Added support for weighted `LinkPredNDCG` metric ([#9945](https://github.com/pyg-team/pytorch_geometric/pull/9945))
- Added `LinkPredMetricCollection` ([#9941](https://github.com/pyg-team/pytorch_geometric/pull/9941))
Expand Down
2 changes: 1 addition & 1 deletion examples/llm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
| [`g_retriever_utils/`](./g_retriever_utils/) | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods. |
| [`multihop_rag/`](./multihop_rag/) | Contains starter code and an example run for building a Multi-hop dataset using WikiHop5M and 2WikiMultiHopQA |
| [`nvtx_examples/`](./nvtx_examples/) | Contains examples of how to wrap functions using the NVTX profiler for CUDA runtime analysis. |
| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction |
| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction. Supports the MoleculeGPT and InstructMol datasets. |
| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results |
| [`git_mol.py`](./git_mol.py) | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text |
14 changes: 11 additions & 3 deletions examples/llm/molecule_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tqdm import tqdm

from torch_geometric import seed_everything
from torch_geometric.datasets import MoleculeGPTDataset
from torch_geometric.datasets import InstructMolDataset, MoleculeGPTDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv
from torch_geometric.nn.models import MoleculeGPT
Expand Down Expand Up @@ -44,6 +44,7 @@ def eval(model, data_loader):


def train(
dataset_name: str,
num_epochs: int,
lr: float,
batch_size: int,
Expand All @@ -65,8 +66,11 @@ def adjust_learning_rate(param_group, LR, epoch):
start_time = time.time()
# Load dataset ================================================
path = osp.dirname(osp.realpath(__file__))
path = osp.join(path, '..', '..', 'data', 'MoleculeGPT')
dataset = MoleculeGPTDataset(path)
path = osp.join(path, '..', '..', 'data', dataset_name)
if dataset_name == 'MoleculeGPT':
dataset = MoleculeGPTDataset(path)
elif dataset_name == 'InstructMol':
dataset = InstructMolDataset(path)
train_size, val_size = int(0.8 * len(dataset)), int(0.1 * len(dataset))
train_dataset = dataset[:train_size]
val_dataset = dataset[train_size:train_size + val_size]
Expand Down Expand Up @@ -177,6 +181,9 @@ def adjust_learning_rate(param_group, LR, epoch):

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_name", type=str, default='MoleculeGPT',
choices=['MoleculeGPT', 'InstructMol'],
help='Support MoleculeGPT and InstructMol')
parser.add_argument('--epochs', type=int, default=3)
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--batch_size', type=int, default=2)
Expand All @@ -185,6 +192,7 @@ def adjust_learning_rate(param_group, LR, epoch):

start_time = time.time()
train(
args.dataset_name,
args.epochs,
args.lr,
args.batch_size,
Expand Down
11 changes: 11 additions & 0 deletions test/datasets/test_instruct_mol_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from torch_geometric.datasets import InstructMolDataset
from torch_geometric.testing import onlyFullTest, withPackage


@onlyFullTest
@withPackage('rdkit')
def test_instruct_mol_dataset():
    """Load the InstructMol dataset and check its size and feature widths."""
    ds = InstructMolDataset(root='./data/InstructMol')

    # Total number of examples exposed by the dataset.
    num_examples = len(ds)
    assert num_examples == 326689

    # Widths of the per-edge and per-node feature vectors.
    assert ds.num_edge_features == 4
    assert ds.num_node_features == 6
2 changes: 2 additions & 0 deletions torch_geometric/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from .web_qsp_dataset import WebQSPDataset, CWQDataset
from .git_mol_dataset import GitMolDataset
from .molecule_gpt_dataset import MoleculeGPTDataset
from .instruct_mol_dataset import InstructMolDataset
from .tag_dataset import TAGDataset

from .dbp15k import DBP15K
Expand Down Expand Up @@ -196,6 +197,7 @@
'CWQDataset',
'GitMolDataset',
'MoleculeGPTDataset',
'InstructMolDataset',
'TAGDataset',
]

Expand Down
134 changes: 134 additions & 0 deletions torch_geometric/datasets/instruct_mol_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import json
import sys
from typing import Callable, List, Optional

import torch
from tqdm import tqdm

from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.io import fs
from torch_geometric.utils import one_hot


class InstructMolDataset(InMemoryDataset):
    r"""The dataset from the `"InstructMol: Multi-Modal Integration for
    Building a Versatile and Reliable Molecular Assistant in Drug Discovery"
    <https://arxiv.org/pdf/2311.16208>`_ paper.

    The raw file :obj:`all_clean.json` is read as a mapping from a SMILES
    string to an iterable of :obj:`(question, answer)` pairs. Every pair
    becomes one :class:`~torch_geometric.data.Data` object holding the
    molecular graph (:obj:`x`, :obj:`edge_index`, :obj:`edge_attr`), the
    :obj:`smiles` string, the question as :obj:`instruction` and the answer
    as :obj:`y`.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """
    # Hugging Face serves raw file contents from `resolve/<revision>`; the
    # `blob/<revision>` route returns the HTML file viewer instead.
    raw_url = 'https://huggingface.co/datasets/OpenMol/PubChemSFT/resolve/main'

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ):
        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        """Names of the raw files expected in the raw directory."""
        return ['all_clean.json']

    @property
    def processed_file_names(self) -> List[str]:
        """Names of the files written to the processed directory."""
        return ['data.pt']

    def download(self) -> None:
        """Copy :obj:`all_clean.json` from :attr:`raw_url` to the raw dir."""
        print('downloading dataset...')
        fs.cp(f'{self.raw_url}/all_clean.json', self.raw_dir)

    def process(self) -> None:
        """Convert the raw JSON file into graphs and save them to disk.

        SMILES strings that RDKit cannot parse are skipped. Atoms are one-hot
        encoded over six classes (H, C, N, O, F and a catch-all), bonds over
        four classes (single, double, triple, aromatic), and every bond is
        stored as two directed edges.
        """
        try:
            from rdkit import Chem
            from rdkit.Chem.rdchem import BondType as BT
            WITH_RDKIT = True

        except ImportError:
            WITH_RDKIT = False

        if not WITH_RDKIT:
            print(("Using a pre-processed version of the dataset. Please "
                   "install 'rdkit' to alternatively process the raw data."),
                  file=sys.stderr)

            # NOTE(review): this fallback feeds the raw JSON file listed in
            # `raw_file_names` to `fs.torch_load`; confirm whether this path
            # can actually succeed without a pre-processed tensor file.
            data_list = fs.torch_load(self.raw_paths[0])
            data_list = [Data(**data_dict) for data_dict in data_list]

            if self.pre_filter is not None:
                data_list = [d for d in data_list if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            self.save(data_list, self.processed_paths[0])
            return

        # types of atom and bond
        types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5}
        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}
        unknown_type = types['Unknow']

        # load data (context manager so the file handle is always closed)
        with open(f'{self.raw_dir}/all_clean.json', encoding='utf-8') as f:
            mols = json.load(f)

        data_list = []
        for smiles, qa_pairs in tqdm(mols.items(), total=len(mols)):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue

            # Explicit integer dtype keeps the index tensor valid for
            # `one_hot` even when the molecule has no atoms.
            x: torch.Tensor = torch.tensor(
                [
                    types.get(atom.GetSymbol(), unknown_type)
                    for atom in mol.GetAtoms()
                ],
                dtype=torch.long,
            )
            x = one_hot(x, num_classes=len(types), dtype=torch.float)

            rows, cols, edge_types = [], [], []
            for bond in mol.GetBonds():
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                # One entry per direction of the bond.
                edge_types += [bonds[bond.GetBondType()]] * 2
                rows += [i, j]
                cols += [j, i]

            edge_index = torch.tensor([rows, cols], dtype=torch.long)
            edge_type = torch.tensor(edge_types, dtype=torch.long)
            edge_attr = one_hot(edge_type, num_classes=len(bonds))

            # The graph tensors are built once per molecule and shared by
            # all of its question/answer examples.
            for question, answer in qa_pairs:
                data = Data(
                    x=x,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    smiles=smiles,
                    instruction=question,
                    y=answer,
                )

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

        self.save(data_list, self.processed_paths[0])
Loading