Add MoleculeGPT (#9710)
### Issue
- #9694 
- #9698

### Feature Summary

- Add `MoleculeGPTDataset`
- Add `MoleculeGPT`, a GNN & LLM co-training model, to PyG (see the usage sketch below)
- Add an example for training and testing
- Split the PR into 3 sub-PRs (#9723, #9724, #9725)
- Due to limited hardware resources, `lmsys/vicuna-7b-v1.5` could not be loaded;
  `TinyLlama/TinyLlama-1.1B-Chat-v0.1` was used instead, and the full training
  pipeline was not tested
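
A minimal sketch of the new API, mirroring the example added in this PR (the LLM, graph encoder and SMILES encoder below simply follow `examples/llm/molecule_gpt.py`; other choices are possible):

```python
import torch

from torch_geometric.datasets import MoleculeGPTDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv
from torch_geometric.nn.models import MoleculeGPT
from torch_geometric.nn.nlp import LLM, SentenceTransformer

# Molecular graphs paired with SMILES strings, instructions and target text:
dataset = MoleculeGPTDataset(root='data/MoleculeGPT')
loader = DataLoader(dataset, batch_size=2, shuffle=True)

# Assemble the co-training model from an LLM, a graph encoder and a SMILES encoder:
model = MoleculeGPT(
    llm=LLM(model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', num_params=1,
            dtype=torch.bfloat16),
    graph_encoder=GINEConv(
        nn=torch.nn.Sequential(torch.nn.Linear(6, 768), torch.nn.ReLU(),
                               torch.nn.Linear(768, 768)),
        train_eps=True, edge_dim=4),
    smiles_encoder=SentenceTransformer(
        model_name='DeepChem/ChemBERTa-77M-MTR',
        pooling_strategy='last_hidden_state'),
)
# See examples/llm/molecule_gpt.py for the full training loop.
```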

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Giovanni Gatti <[email protected]>
Co-authored-by: Rishi Puri <[email protected]>
4 people authored Nov 20, 2024
1 parent e1a925b commit 529237c
Showing 14 changed files with 1,068 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `MoleculeGPT` example ([#9710](https://github.com/pyg-team/pytorch_geometric/pull/9710))
- Added `nn.models.GLEM` ([#9662](https://github.com/pyg-team/pytorch_geometric/pull/9662))
- Added `TAGDataset` ([#9662](https://github.com/pyg-team/pytorch_geometric/pull/9662))
- Added support for fast `Delaunay()` triangulation via the `torch_delaunay` package ([#9748](https://github.com/pyg-team/pytorch_geometric/pull/9748))
9 changes: 5 additions & 4 deletions examples/llm/README.md
@@ -1,6 +1,7 @@
# Examples for Co-training LLMs and GNNs

| Example | Description |
| ------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information |
| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results |
| Example | Description |
| -------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information |
| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction |
| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results |
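
A sketch of how the new example drives the model: the forward pass returns the training loss, while `inference()` generates text (call signatures as used in the example and test added by this PR; `model` and the DataLoader are assumed to be set up as in the sketch above):

```python
batch = next(iter(loader))  # a mini-batch from the DataLoader set up earlier

# Training: the forward pass returns a scalar training loss:
loss = model(batch.x, batch.edge_index, batch.batch, batch.edge_attr,
             batch.smiles, batch.instruction, batch.y)
loss.backward()

# Inference: generate one answer string per molecule in the batch:
pred = model.inference(batch.x, batch.edge_index, batch.batch, batch.edge_attr,
                       batch.smiles, batch.instruction)
```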
193 changes: 193 additions & 0 deletions examples/llm/molecule_gpt.py
@@ -0,0 +1,193 @@
"""This example implements the MoleculeGPT model
(https://ai4d3.github.io/papers/34.pdf) using PyG.
"""
import argparse
import math
import os.path as osp
import time

import torch
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm

from torch_geometric import seed_everything
from torch_geometric.datasets import MoleculeGPTDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv
from torch_geometric.nn.models import MoleculeGPT
from torch_geometric.nn.nlp import LLM, SentenceTransformer


def save_params_dict(model, save_path):
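    """Saves only the trainable (requires-grad) parameters of the model."""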
state_dict = model.state_dict()
param_grad_dict = {
k: v.requires_grad
for (k, v) in model.named_parameters()
}
for k in list(state_dict.keys()):
if k in param_grad_dict.keys() and not param_grad_dict[k]:
del state_dict[k] # Delete parameters that do not require gradient
torch.save(state_dict, save_path)


@torch.no_grad()
def eval(model, data_loader):
model.eval()
loss = 0

for batch in data_loader:
batch_loss = model(batch.x, batch.edge_index, batch.batch,
batch.edge_attr, batch.smiles, batch.instruction,
batch.y)
loss += batch_loss.item() / len(data_loader)
return loss


def train(
num_epochs: int,
lr: float,
batch_size: int,
checkpointing: bool,
):
def adjust_learning_rate(param_group, LR, epoch):
# Decay the learning rate with half-cycle cosine after warmup
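        # For epoch >= warmup_epochs, this evaluates to
        #   min_lr + (LR - min_lr) * 0.5 * (1 + cos(pi * progress))
        # with progress = (epoch - warmup_epochs) / (num_epochs - warmup_epochs),
        # i.e. the rate decays from LR down to min_lr over the remaining epochs.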
min_lr = 5e-6
warmup_epochs = 1
if epoch < warmup_epochs:
lr = LR
else:
lr = min_lr + (LR - min_lr) * 0.5 * (
1.0 + math.cos(math.pi * (epoch - warmup_epochs) /
(num_epochs - warmup_epochs)))
param_group['lr'] = lr
return lr

start_time = time.time()
# Load dataset ================================================
path = osp.dirname(osp.realpath(__file__))
path = osp.join(path, '..', '..', 'data', 'MoleculeGPT')
dataset = MoleculeGPTDataset(path)
train_size, val_size = int(0.8 * len(dataset)), int(0.1 * len(dataset))
train_dataset = dataset[:train_size]
val_dataset = dataset[train_size:train_size + val_size]
test_dataset = dataset[train_size + val_size:]

seed_everything(42)

train_loader = DataLoader(train_dataset, batch_size=batch_size,
drop_last=True, pin_memory=True, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size,
drop_last=False, pin_memory=True, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size,
drop_last=False, pin_memory=True, shuffle=False)

# Create model ===============================================
llm = LLM(
# model_name='lmsys/vicuna-7b-v1.5',
model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
num_params=1,
dtype=torch.bfloat16,
)

graph_encoder = GINEConv(
nn=torch.nn.Sequential(
torch.nn.Linear(6, 768),
torch.nn.ReLU(),
torch.nn.Linear(768, 768),
),
train_eps=True,
edge_dim=4,
)

smiles_encoder = SentenceTransformer(
model_name='DeepChem/ChemBERTa-77M-MTR',
pooling_strategy='last_hidden_state',
)

model = MoleculeGPT(
llm=llm,
graph_encoder=graph_encoder,
smiles_encoder=smiles_encoder,
)

# Train and eval ============================================
params = [p for _, p in model.named_parameters() if p.requires_grad]
optimizer = torch.optim.AdamW([
{
'params': params,
'lr': lr,
'weight_decay': 0.05,
},
], betas=(0.9, 0.95))
grad_steps = 2
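    # The cosine schedule above is re-applied every `grad_steps` training steps;
    # the optimizer itself steps on every iteration.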

best_epoch = 0
best_val_loss = float('inf')
for epoch in range(num_epochs):
# Train
model.train()
epoch_loss = 0
if epoch == 0:
print(f"Total Preparation Time: {time.time() - start_time:2f}s")
start_time = time.time()
print("Training beginning...")
epoch_str = f'Epoch: {epoch + 1}|{num_epochs}'
loader = tqdm(train_loader, desc=epoch_str)

for step, batch in enumerate(loader):
optimizer.zero_grad()
loss = model(batch.x, batch.edge_index, batch.batch,
batch.edge_attr, batch.smiles, batch.instruction,
batch.y)
loss.backward()
clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)

if (step + 1) % grad_steps == 0:
adjust_learning_rate(optimizer.param_groups[0], lr,
step / len(train_loader) + epoch)

optimizer.step()
epoch_loss += loss.item()

if (step + 1) % grad_steps == 0:
lr = optimizer.param_groups[0]['lr']
train_loss = epoch_loss / len(train_loader)

# Eval
val_loss = eval(model, val_loader)
print(
            f'{epoch_str}, Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}'  # noqa: E501
)

if checkpointing and val_loss < best_val_loss:
best_val_loss = val_loss
best_epoch = epoch
save_params_dict(
model,
                f'moleculegpt_epoch{best_epoch}_val_loss{best_val_loss:.4f}_ckpt.pt'  # noqa: E501
)
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()

print(f"Total Training Time: {time.time() - start_time:2f}s")
# Test
test_loss = eval(model, test_loader)
    print(f'Test loss: {test_loss:.4f}')


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=3)
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--batch_size', type=int, default=2)
    # Boolean flag; `--no-checkpointing` disables it:
    parser.add_argument('--checkpointing', action=argparse.BooleanOptionalAction,
                        default=True)
args = parser.parse_args()

start_time = time.time()
train(
args.epochs,
args.lr,
args.batch_size,
args.checkpointing,
)
    print(f'Total Time: {time.time() - start_time:.2f}s')
10 changes: 10 additions & 0 deletions test/datasets/test_molecule_gpt_dataset.py
@@ -0,0 +1,10 @@
from torch_geometric.datasets import MoleculeGPTDataset
from torch_geometric.testing import withPackage


@withPackage('transformers', 'sentencepiece', 'accelerate', 'rdkit')
def test_molecule_gpt_dataset():
dataset = MoleculeGPTDataset(root='./data/MoleculeGPT')
assert str(dataset) == f'MoleculeGPTDataset({len(dataset)})'
assert dataset.num_edge_features == 4
assert dataset.num_node_features == 6
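
A small sketch of inspecting a processed sample (attribute names taken from `examples/llm/molecule_gpt.py`; treat them as illustrative rather than a stable API):

```python
from torch_geometric.datasets import MoleculeGPTDataset

dataset = MoleculeGPTDataset(root='./data/MoleculeGPT')
data = dataset[0]

print(data.x.shape)          # [num_atoms, 6] node features
print(data.edge_attr.shape)  # [num_bonds, 4] edge features
print(data.smiles, data.instruction, data.y)  # text fields consumed by MoleculeGPT
```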
13 changes: 13 additions & 0 deletions test/nn/attention/test_qformer.py
@@ -0,0 +1,13 @@
import torch

from torch_geometric.nn.attention import QFormer


def test_qformer():
x = torch.randn(1, 4, 16)
attn = QFormer(input_dim=16, hidden_dim=16, output_dim=32, num_heads=4,
num_layers=2)
out = attn(x)

assert out.shape == (1, 4, 32)
assert str(attn) == ('QFormer(num_heads=4, num_layers=2)')
60 changes: 60 additions & 0 deletions test/nn/models/test_molecule_gpt.py
@@ -0,0 +1,60 @@
import torch
from torch.nn import Linear as Lin
from torch.nn import ReLU
from torch.nn import Sequential as Seq

from torch_geometric.nn import GINEConv, MoleculeGPT
from torch_geometric.nn.nlp import LLM, SentenceTransformer
from torch_geometric.testing import onlyFullTest, withPackage


@onlyFullTest
@withPackage('transformers', 'sentencepiece', 'accelerate')
def test_molecule_gpt() -> None:
llm = LLM(
# model_name='lmsys/vicuna-7b-v1.5',
model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
num_params=1,
dtype=torch.bfloat16,
)

graph_encoder = GINEConv(nn=Seq(Lin(16, 16), ReLU(), Lin(16, 16)),
train_eps=True, edge_dim=16)

smiles_encoder = SentenceTransformer(
model_name='DeepChem/ChemBERTa-77M-MTR',
pooling_strategy='last_hidden_state',
)

model = MoleculeGPT(
llm=llm,
graph_encoder=graph_encoder,
smiles_encoder=smiles_encoder,
)

assert str(model) == (
'MoleculeGPT(\n'
' llm=LLM(TinyLlama/TinyLlama-1.1B-Chat-v0.1),\n'
' graph=GINEConv,\n'
' smiles=SentenceTransformer(model_name=DeepChem/ChemBERTa-77M-MTR),\n' # noqa: E501
')')

x = torch.randn(10, 16)
edge_index = torch.tensor([
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 0],
])
edge_attr = torch.randn(edge_index.size(1), 16)
batch = torch.zeros(x.size(0), dtype=torch.long)
smiles = ['CCCCCCCCCC']
instructions = ['What is ∼ functional related to?']
label = ['I do not know!']

# Test train:
loss = model(x, edge_index, batch, edge_attr, smiles, instructions, label)
assert loss >= 0

# Test inference:
pred = model.inference(x, edge_index, batch, edge_attr, smiles,
instructions)
assert len(pred) == 1
2 changes: 2 additions & 0 deletions torch_geometric/datasets/__init__.py
@@ -77,6 +77,7 @@
from .brca_tgca import BrcaTcga
from .neurograph import NeuroGraphDataset
from .web_qsp_dataset import WebQSPDataset
from .molecule_gpt_dataset import MoleculeGPTDataset
from .tag_dataset import TAGDataset

from .dbp15k import DBP15K
@@ -191,6 +192,7 @@
'BrcaTcga',
'NeuroGraphDataset',
'WebQSPDataset',
'MoleculeGPTDataset',
'TAGDataset',
]
