diff --git a/CHANGELOG.md b/CHANGELOG.md index 91da66973cef..9240420208bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `MoleculeGPT` example ([#9710](https://github.com/pyg-team/pytorch_geometric/pull/9710)) - Added `nn.models.GLEM` ([#9662](https://github.com/pyg-team/pytorch_geometric/pull/9662)) - Added `TAGDataset` ([#9662](https://github.com/pyg-team/pytorch_geometric/pull/9662)) - Added support for fast `Delaunay()` triangulation via the `torch_delaunay` package ([#9748](https://github.com/pyg-team/pytorch_geometric/pull/9748)) diff --git a/examples/llm/README.md b/examples/llm/README.md index e0ac02d87f2e..d860232aa56b 100644 --- a/examples/llm/README.md +++ b/examples/llm/README.md @@ -1,6 +1,7 @@ # Examples for Co-training LLMs and GNNs -| Example | Description | -| ------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information | -| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results | +| Example | Description | +| -------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information | +| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction | +| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results | diff --git a/examples/llm/molecule_gpt.py b/examples/llm/molecule_gpt.py new file mode 100644 index 000000000000..8f6c6024014d --- /dev/null +++ b/examples/llm/molecule_gpt.py @@ -0,0 +1,193 @@ +"""This example implements the MoleculeGPT model +(https://ai4d3.github.io/papers/34.pdf) using PyG. 
+""" +import argparse +import math +import os.path as osp +import time + +import torch +from torch.nn.utils import clip_grad_norm_ +from tqdm import tqdm + +from torch_geometric import seed_everything +from torch_geometric.datasets import MoleculeGPTDataset +from torch_geometric.loader import DataLoader +from torch_geometric.nn import GINEConv +from torch_geometric.nn.models import MoleculeGPT +from torch_geometric.nn.nlp import LLM, SentenceTransformer + + +def save_params_dict(model, save_path): + state_dict = model.state_dict() + param_grad_dict = { + k: v.requires_grad + for (k, v) in model.named_parameters() + } + for k in list(state_dict.keys()): + if k in param_grad_dict.keys() and not param_grad_dict[k]: + del state_dict[k] # Delete parameters that do not require gradient + torch.save(state_dict, save_path) + + +@torch.no_grad() +def eval(model, data_loader): + model.eval() + loss = 0 + + for batch in data_loader: + batch_loss = model(batch.x, batch.edge_index, batch.batch, + batch.edge_attr, batch.smiles, batch.instruction, + batch.y) + loss += batch_loss.item() / len(data_loader) + return loss + + +def train( + num_epochs: int, + lr: float, + batch_size: int, + checkpointing: bool, +): + def adjust_learning_rate(param_group, LR, epoch): + # Decay the learning rate with half-cycle cosine after warmup + min_lr = 5e-6 + warmup_epochs = 1 + if epoch < warmup_epochs: + lr = LR + else: + lr = min_lr + (LR - min_lr) * 0.5 * ( + 1.0 + math.cos(math.pi * (epoch - warmup_epochs) / + (num_epochs - warmup_epochs))) + param_group['lr'] = lr + return lr + + start_time = time.time() + # Load dataset ================================================ + path = osp.dirname(osp.realpath(__file__)) + path = osp.join(path, '..', '..', 'data', 'MoleculeGPT') + dataset = MoleculeGPTDataset(path) + train_size, val_size = int(0.8 * len(dataset)), int(0.1 * len(dataset)) + train_dataset = dataset[:train_size] + val_dataset = dataset[train_size:train_size + val_size] + test_dataset = dataset[train_size + val_size:] + + seed_everything(42) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, + drop_last=True, pin_memory=True, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, + drop_last=False, pin_memory=True, shuffle=False) + test_loader = DataLoader(test_dataset, batch_size=batch_size, + drop_last=False, pin_memory=True, shuffle=False) + + # Create model =============================================== + llm = LLM( + # model_name='lmsys/vicuna-7b-v1.5', + model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', + num_params=1, + dtype=torch.bfloat16, + ) + + graph_encoder = GINEConv( + nn=torch.nn.Sequential( + torch.nn.Linear(6, 768), + torch.nn.ReLU(), + torch.nn.Linear(768, 768), + ), + train_eps=True, + edge_dim=4, + ) + + smiles_encoder = SentenceTransformer( + model_name='DeepChem/ChemBERTa-77M-MTR', + pooling_strategy='last_hidden_state', + ) + + model = MoleculeGPT( + llm=llm, + graph_encoder=graph_encoder, + smiles_encoder=smiles_encoder, + ) + + # Train and eval ============================================ + params = [p for _, p in model.named_parameters() if p.requires_grad] + optimizer = torch.optim.AdamW([ + { + 'params': params, + 'lr': lr, + 'weight_decay': 0.05, + }, + ], betas=(0.9, 0.95)) + grad_steps = 2 + + best_epoch = 0 + best_val_loss = float('inf') + for epoch in range(num_epochs): + # Train + model.train() + epoch_loss = 0 + if epoch == 0: + print(f"Total Preparation Time: {time.time() - start_time:2f}s") + start_time = time.time() + print("Training 
beginning...") + epoch_str = f'Epoch: {epoch + 1}|{num_epochs}' + loader = tqdm(train_loader, desc=epoch_str) + + for step, batch in enumerate(loader): + optimizer.zero_grad() + loss = model(batch.x, batch.edge_index, batch.batch, + batch.edge_attr, batch.smiles, batch.instruction, + batch.y) + loss.backward() + clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1) + + if (step + 1) % grad_steps == 0: + adjust_learning_rate(optimizer.param_groups[0], lr, + step / len(train_loader) + epoch) + + optimizer.step() + epoch_loss += loss.item() + + if (step + 1) % grad_steps == 0: + lr = optimizer.param_groups[0]['lr'] + train_loss = epoch_loss / len(train_loader) + + # Eval + val_loss = eval(model, val_loader) + print( + f'{epoch_str}, Train loss: {train_loss:4f}, Val loss: {val_loss:4f}' # noqa: E501 + ) + + if checkpointing and val_loss < best_val_loss: + best_val_loss = val_loss + best_epoch = epoch + save_params_dict( + model, + f'moleculegpt_epoch{best_epoch}_val_loss{best_val_loss:4f}_ckpt.pt' # noqa: E501 + ) + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + + print(f"Total Training Time: {time.time() - start_time:2f}s") + # Test + test_loss = eval(model, test_loader) + print(f'Test loss: {test_loss:4f}') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--epochs', type=int, default=3) + parser.add_argument('--lr', type=float, default=1e-5) + parser.add_argument('--batch_size', type=int, default=2) + parser.add_argument('--checkpointing', type=bool, default=True) + args = parser.parse_args() + + start_time = time.time() + train( + args.epochs, + args.lr, + args.batch_size, + args.checkpointing, + ) + print(f'Total Time: {time.time() - start_time:2f}s') diff --git a/test/datasets/test_molecule_gpt_dataset.py b/test/datasets/test_molecule_gpt_dataset.py new file mode 100644 index 000000000000..7c00c5efc1b6 --- /dev/null +++ b/test/datasets/test_molecule_gpt_dataset.py @@ -0,0 +1,10 @@ +from torch_geometric.datasets import MoleculeGPTDataset +from torch_geometric.testing import withPackage + + +@withPackage('transformers', 'sentencepiece', 'accelerate', 'rdkit') +def test_molecule_gpt_dataset(): + dataset = MoleculeGPTDataset(root='./data/MoleculeGPT') + assert str(dataset) == f'MoleculeGPTDataset({len(dataset)})' + assert dataset.num_edge_features == 4 + assert dataset.num_node_features == 6 diff --git a/test/nn/attention/test_qformer.py b/test/nn/attention/test_qformer.py new file mode 100644 index 000000000000..0de023708fd8 --- /dev/null +++ b/test/nn/attention/test_qformer.py @@ -0,0 +1,13 @@ +import torch + +from torch_geometric.nn.attention import QFormer + + +def test_qformer(): + x = torch.randn(1, 4, 16) + attn = QFormer(input_dim=16, hidden_dim=16, output_dim=32, num_heads=4, + num_layers=2) + out = attn(x) + + assert out.shape == (1, 4, 32) + assert str(attn) == ('QFormer(num_heads=4, num_layers=2)') diff --git a/test/nn/models/test_molecule_gpt.py b/test/nn/models/test_molecule_gpt.py new file mode 100644 index 000000000000..c9f0a53403ee --- /dev/null +++ b/test/nn/models/test_molecule_gpt.py @@ -0,0 +1,60 @@ +import torch +from torch.nn import Linear as Lin +from torch.nn import ReLU +from torch.nn import Sequential as Seq + +from torch_geometric.nn import GINEConv, MoleculeGPT +from torch_geometric.nn.nlp import LLM, SentenceTransformer +from torch_geometric.testing import onlyFullTest, withPackage + + +@onlyFullTest +@withPackage('transformers', 'sentencepiece', 'accelerate') +def test_molecule_gpt() -> None: + 
llm = LLM( + # model_name='lmsys/vicuna-7b-v1.5', + model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', + num_params=1, + dtype=torch.bfloat16, + ) + + graph_encoder = GINEConv(nn=Seq(Lin(16, 16), ReLU(), Lin(16, 16)), + train_eps=True, edge_dim=16) + + smiles_encoder = SentenceTransformer( + model_name='DeepChem/ChemBERTa-77M-MTR', + pooling_strategy='last_hidden_state', + ) + + model = MoleculeGPT( + llm=llm, + graph_encoder=graph_encoder, + smiles_encoder=smiles_encoder, + ) + + assert str(model) == ( + 'MoleculeGPT(\n' + ' llm=LLM(TinyLlama/TinyLlama-1.1B-Chat-v0.1),\n' + ' graph=GINEConv,\n' + ' smiles=SentenceTransformer(model_name=DeepChem/ChemBERTa-77M-MTR),\n' # noqa: E501 + ')') + + x = torch.randn(10, 16) + edge_index = torch.tensor([ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 0], + ]) + edge_attr = torch.randn(edge_index.size(1), 16) + batch = torch.zeros(x.size(0), dtype=torch.long) + smiles = ['CCCCCCCCCC'] + instructions = ['What is ∼ functional related to?'] + label = ['I do not know!'] + + # Test train: + loss = model(x, edge_index, batch, edge_attr, smiles, instructions, label) + assert loss >= 0 + + # Test inference: + pred = model.inference(x, edge_index, batch, edge_attr, smiles, + instructions) + assert len(pred) == 1 diff --git a/torch_geometric/datasets/__init__.py b/torch_geometric/datasets/__init__.py index 0b6569d3f92b..c086a85df779 100644 --- a/torch_geometric/datasets/__init__.py +++ b/torch_geometric/datasets/__init__.py @@ -77,6 +77,7 @@ from .brca_tgca import BrcaTcga from .neurograph import NeuroGraphDataset from .web_qsp_dataset import WebQSPDataset +from .molecule_gpt_dataset import MoleculeGPTDataset from .tag_dataset import TAGDataset from .dbp15k import DBP15K @@ -191,6 +192,7 @@ 'BrcaTcga', 'NeuroGraphDataset', 'WebQSPDataset', + 'MoleculeGPTDataset', 'TAGDataset', ] diff --git a/torch_geometric/datasets/molecule_gpt_dataset.py b/torch_geometric/datasets/molecule_gpt_dataset.py new file mode 100644 index 000000000000..b1da09f38570 --- /dev/null +++ b/torch_geometric/datasets/molecule_gpt_dataset.py @@ -0,0 +1,480 @@ +import gzip +import json +import multiprocessing +import os +import sys +from collections import defaultdict +from multiprocessing import Pool +from typing import Callable, List, Optional, Tuple + +import numpy as np +import requests +import torch +from tqdm import tqdm + +from torch_geometric.data import Data, InMemoryDataset, download_url +from torch_geometric.io import fs +from torch_geometric.nn.nlp import LLM +from torch_geometric.utils import one_hot + + +def clean_up_description(description: str) -> str: + description = description + " " + + # extra adj Pure + if description.startswith("Pure "): + description = description.replace("Pure ", "") + # fix typo + if description.startswith("Mercurycombines"): + description = description.replace("Mercurycombines", + "Mercury combines") + + # a special case + description = description.replace( + "17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione. ", + "17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione is ") + + # a special case + description = description.replace("5-Thymidylic acid. ", + "5-Thymidylic acid. is ") + + # a special case + description = description.replace( + "5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. ", + "5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. is ") + + # a special case + description = description.replace( + ("Guanosine 5'-(trihydrogen diphosphate), monoanhydride" + " with phosphorothioic acid. 
"), + ("Guanosine 5'-(trihydrogen diphosphate), monoanhydride" + " with phosphorothioic acid is ")) + + # a special case + description = description.replace("5'-Uridylic acid. ", + "5'-Uridylic acid is ") + + # a special case + description = description.replace("5'-Adenylic acid, ", + "5'-Adenylic acid is ") + + # a special case + description = description.replace( + "Uridine 5'-(tetrahydrogen triphosphate). ", + "Uridine 5'-(tetrahydrogen triphosphate). is ") + + # a special case + description = description.replace("Inosine 5'-Monophosphate. ", + "Inosine 5'-Monophosphate. is ") + + # a special case + description = description.replace("Pivaloyloxymethyl butyrate (AN-9), ", + "Pivaloyloxymethyl butyrate (AN-9) is ") + + # a special case + description = description.replace( + "4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine. ", + "4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine is ") + + # a special case + description = description.replace( + "Cardamonin (also known as Dihydroxymethoxychalcone), ", + "Cardamonin (also known as Dihydroxymethoxychalcone) is ") + + # a special case + description = description.replace("Lithium has been used to treat ", + "Lithium is ") + + # a special case + description = description.replace("4,4'-Methylenebis ", + "4,4'-Methylenebis is ") + + # a special case + description = description.replace( + "2,3,7,8-Tetrachlorodibenzo-p-dioxin", + "2,3,7,8-Tetrachlorodibenzo-p-dioxin is ") + + # a special case + description = description.replace("Exposure to 2,4,5-trichlorophenol ", + "2,4,5-Trichlorophenol exposure ") + + index = 0 + L = len(description) + if description.startswith('C.I. '): + start_index = len('C.I. ') + elif description.startswith('Nectriapyrone. D '): + start_index = len('Nectriapyrone. D ') + elif description.startswith( + 'Salmonella enterica sv. Minnesota LPS core oligosaccharide'): + start_index = len( + 'Salmonella enterica sv. Minnesota LPS core oligosaccharide') + else: + start_index = 0 + for index in range(start_index, L - 1): + if index < L - 2: + if description[index] == '.' 
and description[ + index + 1] == ' ' and 'A' <= description[index + 2] <= 'Z': + break + elif index == L - 2: + break + + first_sentence = description[:index + 1] + return first_sentence + + +def extract_name(name_raw: str, description: str) -> Tuple[str, str, str]: + first_sentence = clean_up_description(description) + + splitter = ' -- -- ' + if ' are ' in first_sentence or ' were ' in first_sentence: + replaced_words = 'These molecules' + else: + replaced_words = 'This molecule' + + first_sentence = first_sentence.replace(' is ', splitter) + first_sentence = first_sentence.replace(' are ', splitter) + first_sentence = first_sentence.replace(' was ', splitter) + first_sentence = first_sentence.replace(' were ', splitter) + first_sentence = first_sentence.replace(' appears ', splitter) + first_sentence = first_sentence.replace(' occurs ', splitter) + first_sentence = first_sentence.replace(' stands for ', splitter) + first_sentence = first_sentence.replace(' belongs to ', splitter) + first_sentence = first_sentence.replace(' exists ', + splitter) # only for CID=11443 + first_sentence = first_sentence.replace(' has been used in trials ', + splitter) + first_sentence = first_sentence.replace(' has been investigated ', + splitter) + first_sentence = first_sentence.replace(' has many uses ', splitter) + + if splitter in first_sentence: + extracted_name = first_sentence.split(splitter, 1)[0] + elif first_sentence.startswith(name_raw): + extracted_name = name_raw + elif name_raw in first_sentence: + extracted_name = name_raw + extracted_name = None + print("=====", name_raw) + print("first sentence: ", first_sentence) + else: + extracted_name = None + + if extracted_name is not None: + extracted_description = description.replace(extracted_name, + replaced_words) + else: + extracted_description = description + + return extracted_name, extracted_description, first_sentence + + +class MoleculeGPTDataset(InMemoryDataset): + r"""The dataset from the `"MoleculeGPT: Instruction Following Large + Language Models for Molecular Property Prediction" + `_ paper. + + Args: + root (str): Root directory where the dataset should be saved. + transform (callable, optional): A function/transform that takes in an + :obj:`torch_geometric.data.Data` object and returns a transformed + version. The data object will be transformed before every access. + (default: :obj:`None`) + pre_transform (callable, optional): A function/transform that takes in + an :obj:`torch_geometric.data.Data` object and returns a + transformed version. The data object will be transformed before + being saved to disk. (default: :obj:`None`) + pre_filter (callable, optional): A function that takes in an + :obj:`torch_geometric.data.Data` object and returns a boolean + value, indicating whether the data object should be included in the + final dataset. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) + total_page_num (int, optional): The number of pages from PubChem. + (default: :obj:`10`) + total_block_num (int, optional): The blocks of SDF files from PubChem. 
+ (default: :obj:`1`) + """ + description_url = ( + 'https://pubchem.ncbi.nlm.nih.gov/rest/pug_view/annotations/' + 'heading/json?heading_type=Compound&heading=Record+Description&page={}' + ) + compound_url = ('https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/' + 'CURRENT-Full/SDF') + + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + force_reload: bool = False, + total_page_num: int = 10, + total_block_num: int = 1, + ): + self.total_page_num = total_page_num + self.total_block_num = total_block_num + + super().__init__(root, transform, pre_transform, pre_filter, + force_reload=force_reload) + self.load(self.processed_paths[0]) + + @property + def raw_file_names(self) -> List[str]: + return ['pubchem.csv'] + + @property + def processed_file_names(self) -> List[str]: + return ['data.pt'] + + def download(self) -> None: + # Step 01. Extract description + step1_folder = f"{self.raw_dir}/step_01_PubChemSTM_description" + if not os.path.exists(step1_folder): + os.makedirs(step1_folder) + valid_CID_set = set() + CID2name_raw, CID2name_extracted = defaultdict(list), defaultdict( + list) + CID2text_raw, CID2text_extracted = defaultdict(list), defaultdict( + list) + + for page_index in tqdm(range(self.total_page_num)): + page_num = page_index + 1 + f_out = open( + f"{step1_folder}/Compound_description_{page_num}.txt", "w") + + description_data = requests.get( + self.description_url.format(page_num)).json() + + description_data = description_data["Annotations"] + assert description_data["Page"] == page_num + + record_list = description_data["Annotation"] + + for record in record_list: + try: + CID = record["LinkedRecords"]["CID"][0] + if "Name" in record: + name_raw = record["Name"] + CID2name_raw[CID].append(name_raw) + else: + name_raw = None + + data_list = record["Data"] + for data in data_list: + description = data["Value"]["StringWithMarkup"][0][ + "String"].strip() + + extracted_name, extracted_description, _ = extract_name( # noqa: E501 + name_raw, description) + if extracted_name is not None: + CID2name_extracted[CID].append(extracted_name) + + CID2text_raw[CID].append(description) + CID2text_extracted[CID].append( + extracted_description) + + valid_CID_set.add(CID) + f_out.write(f"{CID}\n") + f_out.write(f"{extracted_description}\n\n") + except Exception: + continue + + valid_CID_list = sorted(list(valid_CID_set)) + print(f"Total CID (with raw name) {len(CID2name_raw)}") + print(f"Total CID (with extracted name) {len(CID2name_extracted)}") + print(f"Total CID {len(valid_CID_list)}") + + with open(f"{self.raw_dir}/CID2name_raw.json", "w") as f: + json.dump(CID2name_raw, f) + + with open(f"{self.raw_dir}/CID2name.json", "w") as f: + json.dump(CID2name_extracted, f) + + with open(f"{self.raw_dir}/CID2text_raw.json", "w") as f: + json.dump(CID2text_raw, f) + + with open(f"{self.raw_dir}/CID2text.json", "w") as f: + json.dump(CID2text_extracted, f) + + # Step 02. 
Download SDF Files + step2_folder = f"{self.raw_dir}/step_02_PubChemSTM_SDF" + if not os.path.exists(step2_folder): + for block_id in tqdm(range(self.total_block_num)): + block_size = 500000 + l_id = block_id * block_size + 1 + r_id = (block_id + 1) * block_size + + compound_file_name = f"Compound_{l_id:09d}_{r_id:09d}.sdf.gz" + download_url(f"{self.compound_url}/{compound_file_name}", + step2_folder) + + def process(self, use_mp: bool = False) -> None: + try: + from rdkit import Chem + from rdkit.Chem.rdchem import BondType as BT + WITH_RDKIT = True + + except ImportError: + WITH_RDKIT = False + + if not WITH_RDKIT: + print(("Using a pre-processed version of the dataset. Please " + "install 'rdkit' to alternatively process the raw data."), + file=sys.stderr) + + data_list = fs.torch_load(self.raw_paths[0]) + data_list = [Data(**data_dict) for data_dict in data_list] + + if self.pre_filter is not None: + data_list = [d for d in data_list if self.pre_filter(d)] + + if self.pre_transform is not None: + data_list = [self.pre_transform(d) for d in data_list] + + self.save(data_list, self.processed_paths[0]) + return + + # Step 03. Filter out SDF + step2_folder = f"{self.raw_dir}/step_02_PubChemSTM_SDF" + step3_folder = f"{self.raw_dir}/step_03_PubChemSTM_filtered" + if not os.path.exists(step3_folder): + os.makedirs(step3_folder) + with open(f"{self.raw_dir}/CID2text.json") as f: + CID2text = json.load(f) + target_CID_list = set(CID2text.keys()) + + block_size = 500000 + + def extract_one_SDF_file(block_id: int) -> None: + valid_mol_count = 0 + + writer = Chem.SDWriter( + f'{step3_folder}/filtered_{block_id}.sdf') + l_id = block_id * block_size + 1 + r_id = (block_id + 1) * block_size + + compound_file_name = f"Compound_{l_id:09d}_{r_id:09d}.sdf.gz" + gzip_loader = gzip.open(f"{step2_folder}/{compound_file_name}") + suppl = Chem.ForwardSDMolSupplier(gzip_loader) + + for mol in tqdm(suppl): + if mol is None: + continue + cid = mol.GetProp("PUBCHEM_COMPOUND_CID") + + if cid not in target_CID_list: + continue + + writer.write(mol) + valid_mol_count += 1 + + print(f"block id: {block_id}\nfound {valid_mol_count}\n\n") + sys.stdout.flush() + return + + if use_mp: + num_process = multiprocessing.cpu_count() + print(f"{num_process} CPUs") + num_process = 8 + p = Pool(num_process) + + block_id_list = np.arange(self.total_block_num) + with p: + p.map(extract_one_SDF_file, block_id_list) + else: + for block_id in range(self.total_block_num): + extract_one_SDF_file(block_id) + + # Step 04. Merge SDF + with open(f"{self.raw_dir}/CID2text.json") as f: + CID2text = json.load(f) + target_CID_list = set(CID2text.keys()) + print(f'The length of target_CID_list: {len(target_CID_list)}') + + writer = Chem.SDWriter(f'{self.raw_dir}/molecules.sdf') + + found_CID_set = set() + for block_id in range(self.total_block_num + 1): + compound_file_path = f"{step3_folder}/filtered_{block_id}.sdf" + try: + suppl = Chem.SDMolSupplier(compound_file_path) + + for mol in tqdm(suppl): + writer.write(mol) + cid = mol.GetProp("PUBCHEM_COMPOUND_CID") + found_CID_set.add(cid) + except Exception: + print(f"block id: {block_id} with 0 valid SDF file") + continue + + print(f"In total: {len(found_CID_set)} molecules") + + # Step 05. 
Convert to PyG data format + types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5} + bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3} + + data_list = [] + # Real data + CID2text_file = f'{self.raw_dir}/CID2text.json' + + with open(CID2text_file) as f: + CID2text_data = json.load(f) + + suppl = Chem.SDMolSupplier(f'{self.raw_dir}/molecules.sdf') + + llm = LLM( + # model_name='lmsys/vicuna-7b-v1.5', + model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', + num_params=1, + dtype=torch.bfloat16, + ) + prompt = ("Propose a question regarding the molecule '∼' " + "whose answer is: {}:") + for mol in tqdm(suppl): + if mol.HasProp('PUBCHEM_COMPOUND_CID'): + CID = mol.GetProp("PUBCHEM_COMPOUND_CID") + CAN_SMILES = mol.GetProp("PUBCHEM_OPENEYE_CAN_SMILES") + + m: Chem.Mol = Chem.MolFromSmiles(CAN_SMILES) + if m is None: + continue + RDKit_CAN_SMILES = Chem.MolToSmiles(m) + + ground_truth = CID2text_data[CID][0] + + instruction = llm.inference([prompt.format(ground_truth)])[0] + + x: torch.Tensor = torch.tensor([ + types[atom.GetSymbol()] if atom.GetSymbol() in types else 5 + for atom in m.GetAtoms() # type: ignore + ]) + x = one_hot(x, num_classes=len(types), dtype=torch.float) + + rows, cols, edge_types = [], [], [] + for bond in m.GetBonds(): # type: ignore + i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() + edge_types += [bonds[bond.GetBondType()]] * 2 + rows += [i, j] + cols += [j, i] + + edge_index = torch.tensor([rows, cols], dtype=torch.long) + edge_type = torch.tensor(edge_types, dtype=torch.long) + edge_attr = one_hot(edge_type, num_classes=len(bonds)) + + data = Data( + x=x, + edge_index=edge_index, + edge_attr=edge_attr, + smiles=RDKit_CAN_SMILES, + instruction=instruction, + y=ground_truth, + ) + + if self.pre_filter is not None and not self.pre_filter(data): + continue + if self.pre_transform is not None: + data = self.pre_transform(data) + + data_list.append(data) + + self.save(data_list, self.processed_paths[0]) diff --git a/torch_geometric/nn/attention/__init__.py b/torch_geometric/nn/attention/__init__.py index 947d5850173b..6b4064cd34b9 100644 --- a/torch_geometric/nn/attention/__init__.py +++ b/torch_geometric/nn/attention/__init__.py @@ -1,3 +1,7 @@ from .performer import PerformerAttention +from .qformer import QFormer -__all__ = ['PerformerAttention'] +__all__ = [ + 'PerformerAttention', + 'QFormer', +] diff --git a/torch_geometric/nn/attention/qformer.py b/torch_geometric/nn/attention/qformer.py new file mode 100644 index 000000000000..3a8f512d3f83 --- /dev/null +++ b/torch_geometric/nn/attention/qformer.py @@ -0,0 +1,71 @@ +from typing import Callable + +import torch + + +class QFormer(torch.nn.Module): + r"""The Querying Transformer (Q-Former) from + `"BLIP-2: Bootstrapping Language-Image Pre-training + with Frozen Image Encoders and Large Language Models" + `_ paper. + + Args: + input_dim (int): The number of features in the input. + hidden_dim (int): The dimension of the fnn in the encoder layer. + output_dim (int): The final output dimension. + num_heads (int): The number of multi-attention-heads. + num_layers (int): The number of sub-encoder-layers in the encoder. + dropout (int): The dropout value in each encoder layer. + + + .. note:: + This is a simplified version of the original Q-Former implementation. 
+ """ + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_heads: int, + num_layers: int, + dropout: float = 0.0, + activation: Callable = torch.nn.ReLU(), + ) -> None: + + super().__init__() + self.num_layers = num_layers + self.num_heads = num_heads + + self.layer_norm = torch.nn.LayerNorm(input_dim) + self.encoder_layer = torch.nn.TransformerEncoderLayer( + d_model=input_dim, + nhead=num_heads, + dim_feedforward=hidden_dim, + dropout=dropout, + activation=activation, + batch_first=True, + ) + self.encoder = torch.nn.TransformerEncoder( + self.encoder_layer, + num_layers=num_layers, + ) + self.project = torch.nn.Linear(input_dim, output_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r"""Forward pass. + + Args: + x (torch.Tensor): Input sequence to the encoder layer. + :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with + batch-size :math:`B`, sequence length :math:`N`, + and feature dimension :math:`F`. + """ + x = self.layer_norm(x) + x = self.encoder(x) + out = self.project(x) + return out + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}(' + f'num_heads={self.num_heads}, ' + f'num_layers={self.num_layers})') diff --git a/torch_geometric/nn/models/__init__.py b/torch_geometric/nn/models/__init__.py index 5860db311ac3..9aeac020264a 100644 --- a/torch_geometric/nn/models/__init__.py +++ b/torch_geometric/nn/models/__init__.py @@ -29,6 +29,7 @@ from .neural_fingerprint import NeuralFingerprint from .visnet import ViSNet from .g_retriever import GRetriever +from .molecule_gpt import MoleculeGPT from .glem import GLEM # Deprecated: from torch_geometric.explain.algorithm.captum import (to_captum_input, @@ -77,5 +78,6 @@ 'NeuralFingerprint', 'ViSNet', 'GRetriever', + 'MoleculeGPT', 'GLEM', ] diff --git a/torch_geometric/nn/models/molecule_gpt.py b/torch_geometric/nn/models/molecule_gpt.py new file mode 100644 index 000000000000..a0ac73ad9abb --- /dev/null +++ b/torch_geometric/nn/models/molecule_gpt.py @@ -0,0 +1,222 @@ +from typing import List, Optional + +import torch +from torch import Tensor + +from torch_geometric.nn.attention import QFormer +from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS +from torch_geometric.utils import to_dense_batch + + +def pad_or_truncate(embeddings: Tensor, max_seq_len: int, + padding_value: int = 0) -> Tensor: + batch_size, current_seq_len, d = embeddings.size() + + if current_seq_len > max_seq_len: + return embeddings[:, :max_seq_len, :] + elif current_seq_len < max_seq_len: + pad_tensor = torch.full((batch_size, max_seq_len - current_seq_len, d), + padding_value, dtype=embeddings.dtype, + device=embeddings.device) + return torch.cat([embeddings, pad_tensor], dim=1) + else: + return embeddings + + +class MoleculeGPT(torch.nn.Module): + r"""The MoleculeGPT model from the `"MoleculeGPT: Instruction + Following Large Language Models for Molecular Property Prediction" + `_ paper. + + Args: + llm (LLM): The LLM to use. + graph_encoder (torch.nn.Module): Encode 2D molecule graph. + smiles_encoder (torch.nn.Module): Encode 1D SMILES. + mlp_out_channels (int, optional): The size of each embedding + after qformer encoding. (default: :obj:`32`) + max_tokens (int, optional): Max output tokens of 1D/2D encoder. + (default: :obj:`20`) + + .. warning:: + This module has been tested with the following HuggingFace models + + * :obj:`llm_to_use="lmsys/vicuna-7b-v1.5"` + + and may not work with other models. 
See other models at `HuggingFace + Models `_ and let us know if you + encounter any issues. + + .. note:: + For an example of using :class:`MoleculeGPT`, see + `examples/llm/molecule_gpt.py `_. + """ + def __init__( + self, + llm: LLM, + graph_encoder: torch.nn.Module, + smiles_encoder: torch.nn.Module, + mlp_out_channels: int = 32, + max_tokens: Optional[int] = 20, + ) -> None: + super().__init__() + self.llm = llm + self.graph_encoder = graph_encoder.to(self.llm.device) + self.smiles_encoder = smiles_encoder.to(self.llm.device) + + self.graph_qformer = QFormer( + input_dim=self.graph_encoder.nn[-1].out_features, + hidden_dim=mlp_out_channels, + output_dim=mlp_out_channels, + num_heads=4, + num_layers=2, + ).to(self.llm.device) + + self.smiles_qformer = QFormer( + input_dim=self.smiles_encoder.model.pooler.dense.out_features, + hidden_dim=mlp_out_channels, + output_dim=mlp_out_channels, + num_heads=4, + num_layers=2, + ).to(self.llm.device) + + self.max_tokens = max_tokens + + self.word_embedding = self.llm.word_embedding + self.llm_generator = self.llm.llm + + # LLMs + in_dim = 2 * mlp_out_channels * max_tokens + out_dim = self.llm.llm.model.embed_tokens.embedding_dim + self.projector = torch.nn.Sequential( + torch.nn.Linear(in_dim, in_dim), + torch.nn.Sigmoid(), + torch.nn.Linear(in_dim, out_dim), + ).to(self.llm.device) + + def encode( + self, + x: Tensor, + edge_index: Tensor, + batch: Tensor, + edge_attr: Optional[Tensor], + smiles: List[str], + ) -> Tensor: + batch_size = len(smiles) + # 2D Graph Branch: [bs, node_len, d] + x = x.to(self.llm.device) + edge_index = edge_index.to(self.llm.device) + if edge_attr is not None: + edge_attr = edge_attr.to(self.llm.device) + batch = batch.to(self.llm.device) + + x_graph = self.graph_encoder(x, edge_index, edge_attr=edge_attr) + x_graph = to_dense_batch(x_graph, batch)[0] + out_graph = self.graph_qformer(x_graph) + out_graph = pad_or_truncate(out_graph, max_seq_len=self.max_tokens, + padding_value=0) + out_graph = out_graph.view(batch_size, -1) + + # 1D SMILES Branch: [bs, seq_len, d] + x_smiles = self.smiles_encoder.encode(smiles, + output_device=self.llm.device) + out_smiles = self.smiles_qformer(x_smiles) + out_smiles = pad_or_truncate(out_smiles, max_seq_len=self.max_tokens, + padding_value=0) + out_smiles = out_smiles.view(batch_size, -1) + + # Merge into LLMs + x_cat = torch.cat([out_graph, out_smiles], dim=1) + return x_cat + + def forward( + self, + x: Tensor, + edge_index: Tensor, + batch: Tensor, + edge_attr: Optional[Tensor], + smiles: List[str], + instructions: List[str], + label: List[str], + additional_text_context: Optional[List[str]] = None, + ): + x = self.encode(x, edge_index, batch, edge_attr, smiles) + x = self.projector(x) + xs = x.split(1, dim=0) + + batch_unique = batch.unique() + batch_size = len(instructions) + if len(batch_unique) < batch_size: + xs = [ + xs[i] if i in batch_unique else None for i in range(batch_size) + ] + + ( + inputs_embeds, + attention_mask, + label_input_ids, + ) = self.llm._get_embeds(instructions, additional_text_context, xs, + label) + + with self.llm.autocast_context: + outputs = self.llm_generator( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + return_dict=True, + labels=label_input_ids, + ) + + return outputs.loss + + @torch.no_grad() + def inference( + self, + x: Tensor, + edge_index: Tensor, + batch: Tensor, + edge_attr: Optional[Tensor], + smiles: List[str], + instructions: List[str], + additional_text_context: Optional[List[str]] = None, + max_out_tokens: 
Optional[int] = MAX_NEW_TOKENS, + ): + x = self.encode(x, edge_index, batch, edge_attr, smiles) + x = self.projector(x) + xs = x.split(1, dim=0) + + # Handle questions without node features: + batch_unique = batch.unique() + batch_size = len(instructions) + if len(batch_unique) < batch_size: + xs = [ + xs[i] if i in batch_unique else None for i in range(batch_size) + ] + + inputs_embeds, attention_mask, _ = self.llm._get_embeds( + instructions, additional_text_context, xs) + + bos_token = self.llm.tokenizer( + BOS, + add_special_tokens=False, + ).input_ids[0] + + with self.llm.autocast_context: + outputs = self.llm_generator.generate( + inputs_embeds=inputs_embeds, + max_new_tokens=max_out_tokens, + attention_mask=attention_mask, + bos_token_id=bos_token, + use_cache=True # Important to set! + ) + + return self.llm.tokenizer.batch_decode( + outputs, + skip_special_tokens=True, + ) + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}(\n' + f' llm={self.llm},\n' + f' graph={self.graph_encoder.__class__.__name__},\n' + f' smiles={self.smiles_encoder},\n' + f')') diff --git a/torch_geometric/nn/nlp/llm.py b/torch_geometric/nn/nlp/llm.py index b58059f8e098..d18aa42382f7 100644 --- a/torch_geometric/nn/nlp/llm.py +++ b/torch_geometric/nn/nlp/llm.py @@ -56,7 +56,7 @@ class LLM(torch.nn.Module): allocate the correct number of GPUs needed, given the available GPU memory of your GPUs. dtype (torch.dtype, optional): The data type to use for the LLM. - (default :obj: `torch.bloat16`) + (default :obj: `torch.bfloat16`) """ def __init__( self, diff --git a/torch_geometric/nn/nlp/sentence_transformer.py b/torch_geometric/nn/nlp/sentence_transformer.py index c66677e8fa24..715f343bfc19 100644 --- a/torch_geometric/nn/nlp/sentence_transformer.py +++ b/torch_geometric/nn/nlp/sentence_transformer.py @@ -10,6 +10,7 @@ class PoolingStrategy(Enum): MEAN = 'mean' LAST = 'last' CLS = 'cls' + LAST_HIDDEN_STATE = 'last_hidden_state' class SentenceTransformer(torch.nn.Module): @@ -38,6 +39,8 @@ def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor: emb = mean_pooling(emb, attention_mask) elif self.pooling_strategy == PoolingStrategy.LAST: emb = last_pooling(emb, attention_mask) + elif self.pooling_strategy == PoolingStrategy.LAST_HIDDEN_STATE: + emb = out.last_hidden_state else: assert self.pooling_strategy == PoolingStrategy.CLS emb = emb[:, 0, :]
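
---

Usage note (not part of the patch hunks above): a minimal end-to-end sketch of how the pieces added in this PR compose. It mirrors `examples/llm/molecule_gpt.py` and the new tests, reusing the same model names and the dataset's 6 node / 4 edge feature dimensions, so treat the hyper-parameters as illustrative. It assumes the optional `transformers`, `sentencepiece`, `accelerate`, and `rdkit` dependencies are installed.

    import torch

    from torch_geometric.datasets import MoleculeGPTDataset
    from torch_geometric.loader import DataLoader
    from torch_geometric.nn import GINEConv
    from torch_geometric.nn.models import MoleculeGPT
    from torch_geometric.nn.nlp import LLM, SentenceTransformer

    # Building the dataset downloads PubChem descriptions/SDF blocks and, during
    # processing, runs the LLM once per molecule to generate an instruction, so
    # the first call is slow and network-bound.
    dataset = MoleculeGPTDataset(root='data/MoleculeGPT')
    loader = DataLoader(dataset, batch_size=2, shuffle=True)

    model = MoleculeGPT(
        llm=LLM(model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', num_params=1,
                dtype=torch.bfloat16),
        graph_encoder=GINEConv(  # 2D graph branch: 6 node features, 4 edge features
            nn=torch.nn.Sequential(torch.nn.Linear(6, 768), torch.nn.ReLU(),
                                   torch.nn.Linear(768, 768)),
            train_eps=True, edge_dim=4),
        smiles_encoder=SentenceTransformer(  # 1D SMILES branch
            model_name='DeepChem/ChemBERTa-77M-MTR',
            pooling_strategy='last_hidden_state'),
    )

    batch = next(iter(loader))
    # Forward pass returns the language-modeling loss for the instruction/label pair:
    loss = model(batch.x, batch.edge_index, batch.batch, batch.edge_attr,
                 batch.smiles, batch.instruction, batch.y)
    # Inference returns one decoded answer string per molecule in the batch:
    pred = model.inference(batch.x, batch.edge_index, batch.batch,
                           batch.edge_attr, batch.smiles, batch.instruction)

The full training loop (cosine learning-rate schedule, gradient clipping, checkpointing) lives in `examples/llm/molecule_gpt.py` and can be run directly, e.g. `python examples/llm/molecule_gpt.py --epochs 3 --batch_size 2`.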