diff --git a/alphapulldown/create_individual_features_with_templates.py b/alphapulldown/create_individual_features_with_templates.py index f2f4d261..12b9d32e 100755 --- a/alphapulldown/create_individual_features_with_templates.py +++ b/alphapulldown/create_individual_features_with_templates.py @@ -4,7 +4,7 @@ # from alphapulldown.objects import MonomericObject -from alphapulldown.utils import create_uniprot_runner, get_flags_from_af +from alphapulldown.utils import create_uniprot_runner, get_flags_from_af, convert_fasta_description_to_protein_name from alphapulldown.create_custom_template_db import create_db from alphafold.data.pipeline import DataPipeline from alphafold.data.tools import hmmsearch @@ -24,8 +24,6 @@ flags.DEFINE_string("description_file", None, "Path to the text file with descriptions") -flags.DEFINE_string("path_to_fasta", None, "Path to directory with fasta files") - flags.DEFINE_string("path_to_mmt", None, "Path to directory with multimeric template mmCIF files") flags.DEFINE_float("threshold_clashes", 1000, "Threshold for VDW overlap to identify clashes " @@ -45,7 +43,7 @@ def create_arguments(flags_dict, feat, temp_dir=None): """Create arguments for alphafold.run()""" global use_small_bfd - fasta = Path(feat["fasta"]).stem + protein = feat["protein"] templates, chains = feat["templates"], feat["chains"] # Path to the Uniref30 database for use by HHblits. @@ -89,7 +87,7 @@ def create_arguments(flags_dict, feat, temp_dir=None): hb_allowance = FLAGS.hb_allowance plddt_threshold = FLAGS.plddt_threshold #local_path_to_custom_template_db = Path(".") / "custom_template_db" / fasta # DEBUG - local_path_to_custom_template_db = Path(temp_dir.name) / "custom_template_db" / fasta + local_path_to_custom_template_db = Path(temp_dir.name) / "custom_template_db" / protein logging.info(f"Path to local database: {local_path_to_custom_template_db}") create_db(local_path_to_custom_template_db, templates, chains, threashold_clashes, hb_allowance, plddt_threshold) FLAGS.pdb_seqres_database_path = os.path.join(local_path_to_custom_template_db, "pdb_seqres", "pdb_seqres.txt") @@ -105,17 +103,29 @@ def create_arguments(flags_dict, feat, temp_dir=None): flags_dict.update({"use_small_bfd": use_small_bfd}) -def parse_txt_file(csv_path, fasta_dir, mmt_dir): +def parse_csv_file(csv_path, fasta_paths, mmt_dir): """ o csv_path: Path to the text file with descriptions - features.csv: A coma-separated file with three columns: FASTA file, PDB file, chain ID. - o fasta_dir: Path to directory with fasta files + features.csv: A coma-separated file with three columns: PROTEIN name, PDB/CIF template, chain ID. + o fasta_paths: path to fasta file(s) o mmt_dir: Path to directory with multimeric template mmCIF files Returns: a list of dictionaries with the following structure: - [{"fasta": fasta_file, "templates": [pdb_files], "chains": [chain_id]}, ...] + [{"protein": protein_name, "templates": [pdb_files], "chains": [chain_id]}, ...] """ + protein_names = [] + # Check that fasta files exist + for fasta_path in fasta_paths: + logging.info(f"Parsing {fasta_path}...") + if not os.path.isfile(fasta_path): + raise FileNotFoundError(f"Fasta file {fasta_path} does not exist. Please check your input file.") + # Parse all protein names from fasta files + for curr_seq, curr_desc in iter_seqs(fasta_paths): + protein_names.append(curr_desc) + + protein_names = set(protein_names) + # Parse csv file parsed_dict = {} with open(csv_path, newline="") as csvfile: csvreader = csv.reader(csvfile) @@ -124,15 +134,20 @@ def parse_txt_file(csv_path, fasta_dir, mmt_dir): if not row: continue if len(row) == 3: - fasta, template, chain = [item.strip() for item in row] - if fasta not in parsed_dict: - parsed_dict[fasta] = { - "fasta": os.path.join(fasta_dir, fasta), + protein, template, chain = [item.strip() for item in row] + # Remove special symbols from protein name + protein = convert_fasta_description_to_protein_name(protein) + if protein not in protein_names: + raise Exception(f"Protein {protein} from description.csv is not found in the fasta file(s)." + f"List of proteins in fasta file(s): {protein_names}") + if protein not in parsed_dict: + parsed_dict[protein] = { + "protein": protein, "templates": [], "chains": [], } - parsed_dict[fasta]["templates"].append(os.path.join(mmt_dir, template)) - parsed_dict[fasta]["chains"].append(chain) + parsed_dict[protein]["templates"].append(os.path.join(mmt_dir, template)) + parsed_dict[protein]["chains"].append(chain) else: logging.error(f"Invalid line found in the file {csv_path}: {row}") sys.exit() @@ -175,22 +190,20 @@ def main(argv): logging.info("Multiple processes are trying to create the same folder now.") flags_dict = FLAGS.flag_values_dict() - feats = parse_txt_file(FLAGS.description_file, FLAGS.path_to_fasta, FLAGS.path_to_mmt) + fasta_paths = flags_dict["fasta_paths"] + feats = parse_csv_file(FLAGS.description_file, fasta_paths, FLAGS.path_to_mmt) logging.info(f"job_index: {FLAGS.job_index} feats: {feats}") for idx, feat in enumerate(feats, 1): temp_dir = (tempfile.TemporaryDirectory()) # for each fasta file, create a temp dir if (FLAGS.job_index is None) or (FLAGS.job_index == idx): - if not os.path.isfile(feat["fasta"]): - logging.error(f"Fasta file {feat['fasta']} does not exist. Please check your input file.") - sys.exit() for temp in feat["templates"]: if not os.path.isfile: logging.error(f"Template file {temp} does not exist. Please check your input file.") sys.exit() - logging.info(f"Processing {feat['fasta']}: templates: {feat['templates']} chains: {feat['chains']}") + logging.info(f"Processing {feat['protein']}: templates: {feat['templates']} chains: {feat['chains']}") create_arguments(flags_dict, feat, temp_dir) # Update flags_dict to store data about templates - flags_dict.update({f"fasta_path_{idx}": feat['fasta']}) + flags_dict.update({f"protein_{idx}": feat['protein']}) flags_dict.update({f"multimeric_templates_{idx}": feat['templates']}) flags_dict.update({f"multimeric_chains_{idx}": feat['chains']}) @@ -210,12 +223,11 @@ def main(argv): "Please make sure your data_dir has been configured correctly." ) sys.exit() - # If we are using mmseqs2, we don't need to create a pipeline else: pipeline = create_pipeline() uniprot_runner = None flags_dict = FLAGS.flag_values_dict() - for curr_seq, curr_desc in iter_seqs([feat["fasta"]]): + for curr_seq, curr_desc in iter_seqs(FLAGS.fasta_paths): if curr_desc and not curr_desc.isspace(): curr_monomer = MonomericObject(curr_desc, curr_seq) curr_monomer.uniprot_runner = uniprot_runner @@ -232,7 +244,7 @@ def main(argv): flags.mark_flags_as_required( [ "description_file", - "path_to_fasta", + "fasta_paths", "path_to_mmt", "output_dir", "max_template_date", diff --git a/alphapulldown/objects.py b/alphapulldown/objects.py index 48b2954a..34594f7d 100644 --- a/alphapulldown/objects.py +++ b/alphapulldown/objects.py @@ -222,11 +222,12 @@ def make_mmseq_features( query_seqs_cardinality, template_features, ) = get_msa_and_templates( - self.description, - self.sequence, - plPath(result_dir), - msa_mode, - use_templates, + jobname=self.description, + query_sequences=self.sequence, + a3m_lines=None, + result_dir=plPath(result_dir), + msa_mode=msa_mode, + use_templates=use_templates, custom_template_path=None, pair_mode="none", host_url=DEFAULT_API_SERVER, diff --git a/alphapulldown/utils.py b/alphapulldown/utils.py index ecd456c8..048816bc 100644 --- a/alphapulldown/utils.py +++ b/alphapulldown/utils.py @@ -510,6 +510,15 @@ def save_meta_data(flag_dict, outfile): json.dump(metadata, f, indent=2) +def convert_fasta_description_to_protein_name(line): + line = line.replace(" ", "_") + unwanted_symbols = ["|", "=", "&", "*", "@", "#", "`", ":", ";", "$", "?"] + for symbol in unwanted_symbols: + if symbol in line: + line = line.replace(symbol, "_")[1:] + return line[1:] # Remove the '>' at the beginning. + + def parse_fasta(fasta_string: str): """Parses FASTA string and returns list of strings with amino-acid sequences. @@ -534,12 +543,7 @@ def parse_fasta(fasta_string: str): line = line.strip() if line.startswith(">"): index += 1 - line = line.replace(" ", "_") - unwanted_symbols = ["|", "=", "&", "*", "@", "#", "`", ":", ";", "$", "?"] - for symbol in unwanted_symbols: - if symbol in line: - line = line.replace(symbol, "_") - descriptions.append(line[1:]) # Remove the '>' at the beginning. + descriptions.append(convert_fasta_description_to_protein_name(line)) sequences.append("") continue elif not line: diff --git a/test/test_features_with_templates.py b/test/test_features_with_templates.py index faa0e296..b660ef89 100644 --- a/test/test_features_with_templates.py +++ b/test/test_features_with_templates.py @@ -40,7 +40,9 @@ def run_features_generation(self, file_name, chain_id, file_extension): # Generate description.csv with open(f"{self.TEST_DATA_DIR}/description.csv", 'w') as desc_file: - desc_file.write(f"{file_name}_{chain_id}.fasta, {file_name}.{file_extension}, {chain_id}\n") + desc_file.write(f">{file_name}_{chain_id}, {file_name}.{file_extension}, {chain_id}\n") + + assert Path(f"{self.TEST_DATA_DIR}/fastas/{file_name}_{chain_id}.fasta").exists() # Prepare the command and arguments cmd = [ @@ -54,7 +56,7 @@ def run_features_generation(self, file_name, chain_id, file_extension): '--threshold_clashes', '1000', '--hb_allowance', '0.4', '--plddt_threshold', '0', - '--path_to_fasta', f"{self.TEST_DATA_DIR}/fastas", + '--fasta_paths', f"{self.TEST_DATA_DIR}/fastas/{file_name}_{chain_id}.fasta", '--path_to_mmt', f"{self.TEST_DATA_DIR}/templates", '--description_file', f"{self.TEST_DATA_DIR}/description.csv", '--output_dir', f"{self.TEST_DATA_DIR}/features", @@ -123,5 +125,47 @@ def test_4c_bizarre_filename(self): def test_5b_gappy_pdb(self): self.run_features_generation('GAPPY_PDB', 'B', 'pdb') + def test_6a_mmseqs2(self): + file_name = '3L4Q' + chain_id = 'A' + file_extension = 'cif' + # Ensure directories exist + (self.TEST_DATA_DIR / 'features').mkdir(parents=True, exist_ok=True) + (self.TEST_DATA_DIR / 'templates').mkdir(parents=True, exist_ok=True) + # Remove existing files (should be done by tearDown, but just in case) + pkl_path = self.TEST_DATA_DIR / 'features' / f'{file_name}_{chain_id}.pkl' + a3m_path = self.TEST_DATA_DIR / 'features' / f'{file_name}_{chain_id}.a3m' + template_path = self.TEST_DATA_DIR / 'templates' / f'{file_name}.{file_extension}' + if pkl_path.exists(): + pkl_path.unlink() + if a3m_path.exists(): + a3m_path.unlink() + + # Generate description.csv + with open(f"{self.TEST_DATA_DIR}/description.csv", 'w') as desc_file: + desc_file.write(f">{file_name}_{chain_id}, {file_name}.{file_extension}, {chain_id}\n") + + # Prepare the command and arguments + cmd = [ + 'python', + run_features_generation.__file__, + '--skip_existing', 'False', + '--data_dir', '/scratch/AlphaFold_DBs/2.3.2', + '--max_template_date', '3021-01-01', + '--threshold_clashes', '1000', + '--hb_allowance', '0.4', + '--plddt_threshold', '0', + '--fasta_paths', f"{self.TEST_DATA_DIR}/fastas/{file_name}_{chain_id}.fasta", + '--path_to_mmt', f"{self.TEST_DATA_DIR}/templates", + '--description_file', f"{self.TEST_DATA_DIR}/description.csv", + '--output_dir', f"{self.TEST_DATA_DIR}/features", + '--use_mmseqs2', 'True', + ] + print(" ".join(cmd)) + # Check the output + subprocess.run(cmd, check=True) + assert pkl_path.exists() + assert a3m_path.exists() + if __name__ == '__main__': absltest.main()