From b254d4654d3864ee959afccd564677575871d889 Mon Sep 17 00:00:00 2001
From: LuGY <74758262+Gy-Lu@users.noreply.github.com>
Date: Tue, 23 Aug 2022 16:56:56 +0800
Subject: [PATCH] merge data workflow to main (#48)

Added data workflow for fastfold
---
 README.md                                  |  30 +++-
 docker/Dockerfile                          |   2 +-
 fastfold/workflow/__init__.py              |   1 +
 fastfold/workflow/factory/__init__.py      |   5 +
 fastfold/workflow/factory/hhblits.py       |  29 ++++
 fastfold/workflow/factory/hhfilter.py      |  33 +++++
 fastfold/workflow/factory/hhsearch.py      |  38 +++++
 fastfold/workflow/factory/jackhmmer.py     |  34 +++++
 fastfold/workflow/factory/task_factory.py  |  50 +++++++
 fastfold/workflow/template/__init__.py     |   1 +
 .../template/fastfold_data_workflow.py     | 140 ++++++++++++++++++
 fastfold/workflow/workflow_run.py          |  25 ++++
 inference.py                               |  48 ++++--
 inference.sh                               |  13 ++
 14 files changed, 432 insertions(+), 17 deletions(-)
 create mode 100644 fastfold/workflow/__init__.py
 create mode 100644 fastfold/workflow/factory/__init__.py
 create mode 100644 fastfold/workflow/factory/hhblits.py
 create mode 100644 fastfold/workflow/factory/hhfilter.py
 create mode 100644 fastfold/workflow/factory/hhsearch.py
 create mode 100644 fastfold/workflow/factory/jackhmmer.py
 create mode 100644 fastfold/workflow/factory/task_factory.py
 create mode 100644 fastfold/workflow/template/__init__.py
 create mode 100644 fastfold/workflow/template/fastfold_data_workflow.py
 create mode 100644 fastfold/workflow/workflow_run.py
 create mode 100755 inference.sh

diff --git a/README.md b/README.md
index 5cf3623b..6e720b2d 100644
--- a/README.md
+++ b/README.md
@@ -18,14 +18,16 @@ FastFold provides a **high-performance implementation of Evoformer** with the fo
 3. Ease of use
    * Huge performance gains with a few lines changes
    * You don't need to care about how the parallel part is implemented
+4. Faster data processing, about 3x faster than the original pipeline
 
 ## Installation
 
 To install and use FastFold, you will need:
 
-+ Python 3.8 or later
++ Python 3.8 or 3.9
 + [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 11.1 or above
 + PyTorch 1.10 or above
+
 For now, You can install FastFold:
 
 ### Using Conda (Recommended)
@@ -116,6 +118,32 @@ python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
     --hhsearch_binary_path `which hhsearch` \
     --kalign_binary_path `which kalign`
 ```
+Alternatively, run the script `./inference.sh`; you can change the parameters in the script, especially the data paths.
+```shell
+./inference.sh
+```
+
+#### Inference with data workflow
+AlphaFold's data pre-processing takes a lot of time, so we speed it up with a [Ray](https://docs.ray.io/en/latest/workflows/concepts.html) workflow, which achieves about a 3x speedup.
+To run inference with the Ray workflow, install the extra packages and add the `--enable_workflow` flag to the command line or to the shell script `./inference.sh`:
+```shell
+pip install ray==1.13.0 pyarrow
+```
+```shell
+python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
+    --output_dir ./ \
+    --gpus 2 \
+    --uniref90_database_path data/uniref90/uniref90.fasta \
+    --mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
+    --pdb70_database_path data/pdb70/pdb70 \
+    --uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
+    --bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
+    --jackhmmer_binary_path `which jackhmmer` \
+    --hhblits_binary_path `which hhblits` \
+    --hhsearch_binary_path `which hhsearch` \
+    --kalign_binary_path `which kalign` \
+    --enable_workflow
+```
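+
+The same pipeline can also be driven from Python. A minimal sketch (the paths and CPU count are placeholders; the constructor arguments mirror the command-line flags above):
+```python
+from fastfold.workflow.template import FastFoldDataWorkFlow
+
+data_workflow = FastFoldDataWorkFlow(
+    jackhmmer_binary_path="/usr/bin/jackhmmer",
+    hhblits_binary_path="/usr/bin/hhblits",
+    hhsearch_binary_path="/usr/bin/hhsearch",
+    uniref90_database_path="data/uniref90/uniref90.fasta",
+    mgnify_database_path="data/mgnify/mgy_clusters_2018_12.fa",
+    bfd_database_path="data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt",
+    uniclust30_database_path="data/uniclust30/uniclust30_2018_08/uniclust30_2018_08",
+    pdb70_database_path="data/pdb70/pdb70",
+    use_small_bfd=False,
+    no_cpus=16,
+)
+# Alignments are written to <output_dir>/alignment unless alignment_dir is given.
+data_workflow.run("target.fasta", output_dir="./")
+```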
+
 
 ## Performance Benchmark
 
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 7e143199..4f8d9c92 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -9,7 +9,7 @@ RUN conda install pytorch==1.10.0 torchvision torchaudio cudatoolkit=11.3 -c pyt
     && conda install hmmer==3.3.2 hhsuite=3.3.0 kalign2=2.04 -c bioconda
 
 RUN pip install biopython==1.79 dm-tree==0.1.6 ml-collections==0.1.0 numpy==1.21.2 \
-    PyYAML==5.4.1 requests==2.26.0 scipy==1.7.1 tqdm==4.62.2 typing-extensions==3.10.0.2 einops
+    PyYAML==5.4.1 requests==2.26.0 scipy==1.7.1 tqdm==4.62.2 typing-extensions==3.10.0.2 einops ray==1.13.0 pyarrow
 
 RUN pip install colossalai==0.1.8+torch1.10cu11.3 -f https://release.colossalai.org
 
diff --git a/fastfold/workflow/__init__.py b/fastfold/workflow/__init__.py
new file mode 100644
index 00000000..8a0ce50c
--- /dev/null
+++ b/fastfold/workflow/__init__.py
@@ -0,0 +1 @@
+from .workflow_run import batch_run
\ No newline at end of file
diff --git a/fastfold/workflow/factory/__init__.py b/fastfold/workflow/factory/__init__.py
new file mode 100644
index 00000000..6e70de4d
--- /dev/null
+++ b/fastfold/workflow/factory/__init__.py
@@ -0,0 +1,5 @@
+from .task_factory import TaskFactory
+from .hhblits import HHBlitsFactory
+from .hhsearch import HHSearchFactory
+from .jackhmmer import JackHmmerFactory
+from .hhfilter import HHfilterFactory
\ No newline at end of file
diff --git a/fastfold/workflow/factory/hhblits.py b/fastfold/workflow/factory/hhblits.py
new file mode 100644
index 00000000..aecfcc55
--- /dev/null
+++ b/fastfold/workflow/factory/hhblits.py
@@ -0,0 +1,29 @@
+from ray import workflow
+from typing import List
+from fastfold.workflow.factory import TaskFactory
+from ray.workflow.common import Workflow
+import fastfold.data.tools.hhblits as ffHHBlits
+
+class HHBlitsFactory(TaskFactory):
+
+    keywords = ['binary_path', 'databases', 'n_cpu']
+
+    def gen_task(self, fasta_path: str, output_path: str, after: List[Workflow]=None) -> Workflow:
+
+        self.isReady()
+
+        # setup runner
+        runner = ffHHBlits.HHBlits(
+            binary_path=self.config['binary_path'],
+            databases=self.config['databases'],
+            n_cpu=self.config['n_cpu']
+        )
+
+        # generate step function
+        @workflow.step
+        def hhblits_step(fasta_path: str, output_path: str, after: List[Workflow]) -> None:
+            result = runner.query(fasta_path)
+            with open(output_path, "w") as f:
+                f.write(result["a3m"])
+
+        return hhblits_step.step(fasta_path, output_path, after)
diff --git a/fastfold/workflow/factory/hhfilter.py b/fastfold/workflow/factory/hhfilter.py
new file mode 100644
index 00000000..de680610
--- /dev/null
+++ b/fastfold/workflow/factory/hhfilter.py
@@ -0,0 +1,33 @@
+import subprocess
+import logging
+from ray import workflow
+from typing import List
+from fastfold.workflow.factory import TaskFactory
+from ray.workflow.common import Workflow
+
+class HHfilterFactory(TaskFactory):
+
+    keywords = ['binary_path']
+
+    def gen_task(self, fasta_path: str, output_path: str, after: List[Workflow]=None) -> Workflow:
+
+        self.isReady()
+
+        # generate step function
+        @workflow.step
+        def hhfilter_step(fasta_path: str, output_path: str, after: List[Workflow]) -> None:
+
+            cmd = [
+                self.config.get('binary_path'),
+            ]
+            if 'id' in self.config:
+                cmd += ['-id', str(self.config.get('id'))]
+            if 'cov' in self.config:
+                cmd += ['-cov', str(self.config.get('cov'))]
+            cmd += ['-i', fasta_path, '-o', output_path]
+
+            logging.info(f"HHfilter start: {' '.join(cmd)}")
+
+            subprocess.run(cmd, check=True)
+
+        return hhfilter_step.step(fasta_path, output_path, after)
diff --git a/fastfold/workflow/factory/hhsearch.py b/fastfold/workflow/factory/hhsearch.py
new file mode 100644
index 00000000..d315de07
--- /dev/null
+++ b/fastfold/workflow/factory/hhsearch.py
@@ -0,0 +1,38 @@
+from fastfold.workflow.factory import TaskFactory
+from ray import workflow
+from ray.workflow.common import Workflow
+import fastfold.data.tools.hhsearch as ffHHSearch
+from typing import List
+
+class HHSearchFactory(TaskFactory):
+
+    keywords = ['binary_path', 'databases', 'n_cpu']
+
+    def gen_task(self, a3m_path: str, output_path: str, after: List[Workflow]=None) -> Workflow:
+
+        self.isReady()
+
+        # setup runner
+        runner = ffHHSearch.HHSearch(
+            binary_path=self.config['binary_path'],
+            databases=self.config['databases'],
+            n_cpu=self.config['n_cpu']
+        )
+
+        # generate step function
+        @workflow.step
+        def hhsearch_step(a3m_path: str, output_path: str, after: List[Workflow], atab_path: str = None) -> None:
+
+            with open(a3m_path, "r") as f:
+                a3m = f.read()
+            if atab_path:
+                hhsearch_result, atab = runner.query(a3m, gen_atab=True)
+            else:
+                hhsearch_result = runner.query(a3m)
+            with open(output_path, "w") as f:
+                f.write(hhsearch_result)
+            if atab_path:
+                with open(atab_path, "w") as f:
+                    f.write(atab)
+
+        return hhsearch_step.step(a3m_path, output_path, after)
diff --git a/fastfold/workflow/factory/jackhmmer.py b/fastfold/workflow/factory/jackhmmer.py
new file mode 100644
index 00000000..ebba4ba9
--- /dev/null
+++ b/fastfold/workflow/factory/jackhmmer.py
@@ -0,0 +1,34 @@
+from fastfold.workflow.factory import TaskFactory
+from ray import workflow
+from ray.workflow.common import Workflow
+import fastfold.data.tools.jackhmmer as ffJackHmmer
+from fastfold.data import parsers
+from typing import List
+
+class JackHmmerFactory(TaskFactory):
+
+    keywords = ['binary_path', 'database_path', 'n_cpu', 'uniref_max_hits']
+
+    def gen_task(self, fasta_path: str, output_path: str, after: List[Workflow]=None) -> Workflow:
+
+        self.isReady()
+
+        # setup runner
+        runner = ffJackHmmer.Jackhmmer(
+            binary_path=self.config['binary_path'],
+            database_path=self.config['database_path'],
+            n_cpu=self.config['n_cpu']
+        )
+
+        # generate step function
+        @workflow.step
+        def jackhmmer_step(fasta_path: str, output_path: str, after: List[Workflow]) -> None:
+            result = runner.query(fasta_path)[0]
+            uniref90_msa_a3m = parsers.convert_stockholm_to_a3m(
+                result['sto'],
+                max_sequences=self.config['uniref_max_hits']
+            )
+            with open(output_path, "w") as f:
+                f.write(uniref90_msa_a3m)
+
+        return jackhmmer_step.step(fasta_path, output_path, after)
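+
+# Illustrative standalone use (assumed setup: ray.init() and workflow.init()
+# have been called; paths are placeholders):
+#
+#   fac = JackHmmerFactory(config={
+#       "binary_path": "/usr/bin/jackhmmer",
+#       "database_path": "data/uniref90/uniref90.fasta",
+#       "n_cpu": 8,
+#       "uniref_max_hits": 10000,
+#   })
+#   fac.gen_task("target.fasta", "uniref90_hits.a3m").run(workflow_id="jackhmmer_only")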
diff --git a/fastfold/workflow/factory/task_factory.py b/fastfold/workflow/factory/task_factory.py
new file mode 100644
index 00000000..dd8c739e
--- /dev/null
+++ b/fastfold/workflow/factory/task_factory.py
@@ -0,0 +1,50 @@
+import json
+from ray.workflow.common import Workflow
+from typing import Any, List
+
+class TaskFactory:
+
+    keywords = []
+
+    def __init__(self, config: dict = None, config_path: str = None) -> None:
+
+        # skip if no keyword required from config file
+        if not self.__class__.keywords:
+            return
+
+        # setting config for factory
+        if config is not None:
+            self.config = config
+        elif config_path is not None:
+            self.loadConfig(config_path)
+        else:
+            self.loadConfig()
+
+    def update_config(self, config: dict, purge=False) -> None:
+        # bulk update: replace the whole config when purge=True, merge otherwise
+        if purge:
+            self.config = config
+        else:
+            self.config.update(config)
+
+    def configure(self, keyword: str, value: Any) -> None:
+        # set a single config entry, e.g. configure('database_path', ...)
+        self.config[keyword] = value
+
+    def gen_task(self, after: List[Workflow]=None, *args, **kwargs) -> Workflow:
+        raise NotImplementedError
+
+    def isReady(self):
+        for key in self.__class__.keywords:
+            if key not in self.config:
+                raise KeyError(f"{self.__class__.__name__} not ready: \"{key}\" not specified")
+
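+    # loadConfig expects a JSON file shaped roughly like this (illustrative;
+    # the section name is the factory class name minus the "Factory" suffix):
+    # {
+    #     "tools": {
+    #         "JackHmmer": {"binary_path": "...", "database_path": "...",
+    #                       "n_cpu": 8, "uniref_max_hits": 10000},
+    #         "HHBlits":   {"binary_path": "...", "databases": ["..."], "n_cpu": 8}
+    #     }
+    # }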
binary is None" + ) + + if(not all([x is None for x in self.db_map["hhsearch"]["dbs"]]) + and uniref90_database_path is None): + raise ValueError( + """uniref90_database_path must be specified in order to perform + template search""" + ) + + self.use_small_bfd = use_small_bfd + self.uniref_max_hits = uniref_max_hits + self.mgnify_max_hits = mgnify_max_hits + + if(no_cpus is None): + self.no_cpus = cpu_count() + else: + self.no_cpus = no_cpus + + def run(self, fasta_path: str, output_dir: str, alignment_dir: str=None) -> None: + + localtime = time.asctime(time.localtime(time.time())) + workflow_id = 'fastfold_data_workflow ' + str(localtime) + # clearing remaining ray workflow data + try: + workflow.cancel(workflow_id) + workflow.delete(workflow_id) + except: + print("Workflow not found. Clean. Skipping") + pass + + # prepare alignment directory for alignment outputs + if alignment_dir is None: + alignment_dir = os.path.join(output_dir, "alignment") + if not os.path.exists(alignment_dir): + os.makedirs(alignment_dir) + + # Run JackHmmer on UNIREF90 + # create JackHmmer workflow generator + jh_config = { + "binary_path": self.db_map["jackhmmer"]["binary"], + "database_path": self.db_map["jackhmmer"]["dbs"][0], + "n_cpu": self.no_cpus, + "uniref_max_hits": self.uniref_max_hits, + } + jh_fac = JackHmmerFactory(config = jh_config) + # set jackhmmer output path + uniref90_out_path = os.path.join(alignment_dir, "uniref90_hits.a3m") + # generate the workflow with i/o path + wf1 = jh_fac.gen_task(fasta_path, uniref90_out_path) + + #Run HHSearch on STEP1's result with PDB70""" + # create HHSearch workflow generator + hhs_config = { + "binary_path": self.db_map["hhsearch"]["binary"], + "databases": self.db_map["hhsearch"]["dbs"], + "n_cpu": self.no_cpus, + } + hhs_fac = HHSearchFactory(config=hhs_config) + # set HHSearch output path + pdb70_out_path = os.path.join(alignment_dir, "pdb70_hits.hhr") + # generate the workflow (STEP2 depend on STEP1) + wf2 = hhs_fac.gen_task(uniref90_out_path, pdb70_out_path, after=[wf1]) + + # Run JackHmmer on MGNIFY + # reconfigure jackhmmer factory to use MGNIFY DB instead + jh_fac.configure('database_path', self.db_map["jackhmmer"]["dbs"][1]) + # set jackhmmer output path + mgnify_out_path = os.path.join(alignment_dir, "mgnify_hits.a3m") + # generate workflow for STEP3 + wf3 = jh_fac.gen_task(fasta_path, mgnify_out_path) + + # Run HHBlits on BFD + # create HHBlits workflow generator + hhb_config = { + "binary_path": self.db_map["hhblits"]["binary"], + "databases": self.db_map["hhblits"]["dbs"], + "n_cpu": self.no_cpus, + } + hhb_fac = HHBlitsFactory(config=hhb_config) + # set HHBlits output path + bfd_out_path = os.path.join(alignment_dir, "bfd_uniclust_hits.a3m") + # generate workflow for STEP4 + wf4 = hhb_fac.gen_task(fasta_path, bfd_out_path) + + # run workflow + batch_run(wfs=[wf2, wf3, wf4], workflow_id=workflow_id) + + return \ No newline at end of file diff --git a/fastfold/workflow/workflow_run.py b/fastfold/workflow/workflow_run.py new file mode 100644 index 00000000..196dccfa --- /dev/null +++ b/fastfold/workflow/workflow_run.py @@ -0,0 +1,25 @@ +from ast import Call +from typing import Callable, List +from ray.workflow.common import Workflow +from ray import workflow + +def batch_run(wfs: List[Workflow], workflow_id: str) -> None: + + @workflow.step + def batch_step(wfs) -> None: + return + + batch_wf = batch_step.step(wfs) + + batch_wf.run(workflow_id=workflow_id) + +def wf(after: List[Workflow]=None): + def decorator(f: Callable): + + @workflow.step + def 
+
+        # run workflow
+        batch_run(wfs=[wf2, wf3, wf4], workflow_id=workflow_id)
+
+        return
\ No newline at end of file
diff --git a/fastfold/workflow/workflow_run.py b/fastfold/workflow/workflow_run.py
new file mode 100644
index 00000000..196dccfa
--- /dev/null
+++ b/fastfold/workflow/workflow_run.py
@@ -0,0 +1,25 @@
+from typing import Callable, List
+from ray.workflow.common import Workflow
+from ray import workflow
+
+def batch_run(wfs: List[Workflow], workflow_id: str) -> None:
+
+    # a no-op step whose arguments are the real workflows; running it forces
+    # Ray to execute every workflow in wfs
+    @workflow.step
+    def batch_step(wfs) -> None:
+        return
+
+    batch_wf = batch_step.step(wfs)
+
+    batch_wf.run(workflow_id=workflow_id)
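+
+# Example (illustrative): schedule wf2, wf3, wf4 together and block until all
+# finish, where wf2 already encodes its dependency on wf1:
+#   batch_run(wfs=[wf2, wf3, wf4], workflow_id="fastfold_alignments")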
+
+def wf(after: List[Workflow]=None):
+    def decorator(f: Callable):
+
+        @workflow.step
+        def step_func(after: List[Workflow]) -> None:
+            f()
+
+        return step_func.step(after)
+
+    return decorator
diff --git a/inference.py b/inference.py
index 56025c96..5ad36dd7 100644
--- a/inference.py
+++ b/inference.py
@@ -31,6 +31,7 @@
 from fastfold.config import model_config
 from fastfold.model.fastnn import set_chunk_size
 from fastfold.data import data_pipeline, feature_pipeline, templates
+from fastfold.workflow.template import FastFoldDataWorkFlow
 from fastfold.utils import inject_fastnn
 from fastfold.data.parsers import parse_fasta
 from fastfold.utils.import_weights import import_jax_weights_
@@ -74,7 +75,7 @@ def add_data_args(parser: argparse.ArgumentParser):
     )
     parser.add_argument('--obsolete_pdbs_path', type=str, default=None)
     parser.add_argument('--release_dates_path', type=str, default=None)
-
+    parser.add_argument('--enable_workflow', default=False, action='store_true', help='run data pre-processing with the Ray workflow')
 
 def inference_model(rank, world_size, result_q, batch, args):
     os.environ['RANK'] = str(rank)
@@ -157,20 +158,37 @@ def main(args):
     if (args.use_precomputed_alignments is None):
         if not os.path.exists(local_alignment_dir):
             os.makedirs(local_alignment_dir)
-
-        alignment_runner = data_pipeline.AlignmentRunner(
-            jackhmmer_binary_path=args.jackhmmer_binary_path,
-            hhblits_binary_path=args.hhblits_binary_path,
-            hhsearch_binary_path=args.hhsearch_binary_path,
-            uniref90_database_path=args.uniref90_database_path,
-            mgnify_database_path=args.mgnify_database_path,
-            bfd_database_path=args.bfd_database_path,
-            uniclust30_database_path=args.uniclust30_database_path,
-            pdb70_database_path=args.pdb70_database_path,
-            use_small_bfd=use_small_bfd,
-            no_cpus=args.cpus,
-        )
-        alignment_runner.run(fasta_path, local_alignment_dir)
+        if args.enable_workflow:
+            print("Running alignment with ray workflow...")
+            alignment_data_workflow_runner = FastFoldDataWorkFlow(
+                jackhmmer_binary_path=args.jackhmmer_binary_path,
+                hhblits_binary_path=args.hhblits_binary_path,
+                hhsearch_binary_path=args.hhsearch_binary_path,
+                uniref90_database_path=args.uniref90_database_path,
+                mgnify_database_path=args.mgnify_database_path,
+                bfd_database_path=args.bfd_database_path,
+                uniclust30_database_path=args.uniclust30_database_path,
+                pdb70_database_path=args.pdb70_database_path,
+                use_small_bfd=use_small_bfd,
+                no_cpus=args.cpus,
+            )
+            t = time.perf_counter()
+            alignment_data_workflow_runner.run(fasta_path, output_dir=output_dir_base, alignment_dir=local_alignment_dir)
+            print(f"Alignment data workflow time: {time.perf_counter() - t}")
+        else:
+            alignment_runner = data_pipeline.AlignmentRunner(
+                jackhmmer_binary_path=args.jackhmmer_binary_path,
+                hhblits_binary_path=args.hhblits_binary_path,
+                hhsearch_binary_path=args.hhsearch_binary_path,
+                uniref90_database_path=args.uniref90_database_path,
+                mgnify_database_path=args.mgnify_database_path,
+                bfd_database_path=args.bfd_database_path,
+                uniclust30_database_path=args.uniclust30_database_path,
+                pdb70_database_path=args.pdb70_database_path,
+                use_small_bfd=use_small_bfd,
+                no_cpus=args.cpus,
+            )
+            alignment_runner.run(fasta_path, local_alignment_dir)
 
     feature_dict = data_processor.process_fasta(fasta_path=fasta_path, alignment_dir=local_alignment_dir)
 
diff --git a/inference.sh b/inference.sh
new file mode 100755
index 00000000..d35c18e5
--- /dev/null
+++ b/inference.sh
@@ -0,0 +1,13 @@
+python inference.py target.fasta data/pdb_mmcif/mmcif_files \
+    --output_dir ./ \
+    --gpus 2 \
+    --uniref90_database_path data/uniref90/uniref90.fasta \
+    --mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
+    --pdb70_database_path data/pdb70/pdb70 \
+    --uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
+    --bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
+    --jackhmmer_binary_path `which jackhmmer` \
+    --hhblits_binary_path `which hhblits` \
+    --hhsearch_binary_path `which hhsearch` \
+    --kalign_binary_path `which kalign` \
+    # --enable_workflow
\ No newline at end of file