From e12e8c2f298e444964f57841a00581740a47a8a4 Mon Sep 17 00:00:00 2001
From: Amine Ghozlane
Date: Tue, 1 Oct 2024 17:49:53 +0200
Subject: [PATCH] Testing advanced strain analysis

---
 meteor/phylogeny.py            |  86 ++-------
 meteor/tests/test_phylogeny.py |   1 -
 meteor/variantcalling.py       | 321 +++++++++++++++++++++------------
 3 files changed, 222 insertions(+), 186 deletions(-)

diff --git a/meteor/phylogeny.py b/meteor/phylogeny.py
index cb384bd..22ce0b5 100644
--- a/meteor/phylogeny.py
+++ b/meteor/phylogeny.py
@@ -28,10 +28,12 @@
 from collections import OrderedDict
 from datetime import datetime
 from typing import Iterable, Tuple
-from cogent3 import load_aligned_seqs
-from cogent3.evolve.distance import EstimateDistances
-from cogent3.evolve.models import GTR
-from cogent3.cluster.UPGMA import upgma
+from cogent3 import load_unaligned_seqs  # , load_aligned_seqs
+
+# from cogent3.evolve.distance import EstimateDistances
+# from cogent3.evolve.models import GTR
+# from cogent3.cluster.UPGMA import upgma
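+# tree_align aligns the sequences and estimates the tree in a single call
+# (it returns an (alignment, tree) tuple), replacing the explicit
+# EstimateDistances + upgma pipeline commented out above.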
+from cogent3.align.progressive import tree_align
 
 
 @dataclass
@@ -116,29 +118,6 @@ def remove_edge_labels(self, newick: str) -> str:
 
     def execute(self) -> None:
         logging.info("Launch phylogeny analysis")
-        # raxml_ng_exec = run(["raxml-ng", "--version"], check=False, capture_output=True)
-        # if raxml_ng_exec.returncode != 0:
-        #     logging.error(
-        #         "Checking raxml-ng failed:\n%s", raxml_ng_exec.stderr.decode("utf-8")
-        #     )
-        #     sys.exit(1)
-        # raxml_ng_help = raxml_ng_exec.stdout.decode("utf-8")
-        # Define the regex pattern to match the version number
-        # version_pattern = re.compile(r"RAxML-NG v\. (\d+\.\d+\.\d+)")
-        # match = version_pattern.search(raxml_ng_help)
-        # Check if a match is found
-        # if not match:
-        #     logging.error("Failed to determine the raxml-ng version.")
-        #     sys.exit(1)
-        # raxml_ng_version = match.group(1)
-        # if parse(raxml_ng_version) < self.meteor.MIN_RAXML_NG_VERSION:
-        #     logging.error(
-        #         "The raxml-ng version %s is outdated for meteor. Please update raxml-ng to >= %s.",
-        #         raxml_ng_version,
-        #         self.meteor.MIN_RAXML_NG_VERSION,
-        #     )
-        #     sys.exit(1)
-        # Start phylogenies
         start = perf_counter()
         self.tree_files: list[Path] = []
@@ -165,68 +144,33 @@ def execute(self) -> None:
                         info_sites,
                         self.min_info_sites,
                     )
-                # elif len(cleaned_seqs) >= 4:
-                #     # Compute trees
-                #     logging.info("Run raxml-ng")
-                #     result = check_call(
-                #         [
-                #             "raxml-ng",
-                #             "--threads",
-                #             "auto{{{}}}".format(self.meteor.threads),
-                #             "--workers",
-                #             "auto",
-                #             "--search1",
-                #             "--msa",
-                #             temp_clean.name,
-                #             "--model",
-                #             "GTR+G",
-                #             "--redo",
-                #             "--force",
-                #             "perf,msa",  # not working with raxml-ng-mpi
-                #             "--prefix",
-                #             str(tree_file.resolve()),
-                #         ],
-                #         stdout=DEVNULL,
-                #     )
-                #     if result != 0:
-                #         logging.error("raxml-ng failed with return code %d", result)
                 else:
-                    # logging.info(
-                    #     "Less than 4 sequences, run cogent3"
-                    # )
-                    aligned_seqs = load_aligned_seqs(
-                        temp_clean.name,
-                        moltype="dna",
+                    seqs = load_unaligned_seqs(temp_clean.name, moltype="dna")
+                    # params = {"kappa": 4.0}
+                    _, tree = tree_align(
+                        "GTR",
+                        seqs,
+                        # param_vals=params,
+                        show_progress=False,
                     )
-                    d = EstimateDistances(aligned_seqs, submodel=GTR())
-                    d.run(show_progress=False)
-                    mycluster = upgma(d.get_pairwise_distances())
-                    mycluster = mycluster.unrooted_deepcopy()
+                    # print(aln)
                     with tree_file.with_suffix(".tree").open("w") as f:
                         f.write(
                             self.remove_edge_labels(
-                                mycluster.get_newick(with_distances=True)
+                                tree.get_newick(with_distances=True)
                             )
                         )
-                    # Edges get a name which is not supported by ete3
-                    # mycluster.write(
-                    #     tree_file.with_suffix(".tree"),
-                    # )
                 if tree_file.with_suffix(".tree").exists():
                     self.tree_files.append(tree_file.with_suffix(".tree"))
                     logging.info(
                         "Completed MSP tree for MSP %s",
                         msp_file.name.replace(".fasta", ""),
                     )
-                # elif tree_file.with_suffix(".raxml.bestTree").exists():
-                #     self.tree_files.append(tree_file.with_suffix(".raxml.bestTree"))
-                #     logging.info("Completed MSP tree with raxml")
                 else:
                     logging.info("No tree file generated")
         logging.info("Completed phylogeny in %f seconds", perf_counter() - start)
         logging.info(
             "Trees were generated for %d/%d MSPs", len(self.tree_files), msp_count
         )
-        # config = self.set_tree_config(raxml_ng_version)
         config = self.set_tree_config()
         self.save_config(config, self.meteor.tree_dir / "census_stage_4.json")
diff --git a/meteor/tests/test_phylogeny.py b/meteor/tests/test_phylogeny.py
index 6928916..91acbc1 100644
--- a/meteor/tests/test_phylogeny.py
+++ b/meteor/tests/test_phylogeny.py
@@ -50,7 +50,6 @@ def test_compute_site_info(phylogeny_builder: Phylogeny):
 def test_clean_sites(phylogeny_builder: Phylogeny, datadir: Path, tmpdir: Path):
     msp = tmpdir / "msp_0864_clean.fasta"
     msp_expected_file = datadir / "msp_0864_dict.pck"
-    print(msp_expected_file)
     with open(msp_expected_file, "rb") as msp_file:
         msp_expected = pickle.load(msp_file)
     with msp.open("w") as f:
diff --git a/meteor/variantcalling.py b/meteor/variantcalling.py
index 9ab5b1b..fdde991 100644
--- a/meteor/variantcalling.py
+++ b/meteor/variantcalling.py
@@ -24,14 +24,72 @@
 from time import perf_counter
 from tempfile import NamedTemporaryFile
 from packaging.version import parse
-
-# import psutil
 from pysam import AlignmentFile, FastaFile, VariantFile, faidx, tabix_index
 from collections import defaultdict
 import pandas as pd
 from typing import ClassVar
 import numpy as np
 import os
+import pickle
+
+# cogent3 is used by the process_msp_file helper below
+from cogent3 import load_unaligned_seqs
+from cogent3.align.progressive import tree_align
+
+
+# Helper function for multiprocessing
+def process_msp_file(
+    meteor_tmp_dir, meteor_tree_dir, min_info_sites, idx, msp_file, msp_count, max_gap
+):
+    """
+    Process a single MSP file and return the path to the generated tree file,
+    or None if unsuccessful. This runs in a separate process.
+    """
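+    # Kept at module level rather than as a method so that multiprocessing can
+    # pickle it when dispatching MSP files to worker processes.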
+    pattern = r"\b(edge\.\d+):\b"
+
+    logging.info(
+        "%d/%d %s: Start analysis",
+        idx,
+        msp_count,
+        msp_file.name.replace(".fasta", ""),
+    )
+
+    with NamedTemporaryFile(
+        mode="wt", dir=meteor_tmp_dir, suffix=".fasta"
+    ) as temp_clean:
+        tree_file = Path(meteor_tree_dir) / f"{msp_file.name}".replace(".fasta", "")
+
+        # Clean sites (use a dummy clean sites function for illustration)
+        logging.info("Clean sites")
+        # Replace the dummy value below with the actual site-cleaning logic:
+        info_sites = 20  # Fake number of informative sites for demo
+
+        if info_sites < min_info_sites:
+            logging.info(
+                "Only %d informative sites (<%d threshold) left after cleaning, skip.",
+                info_sites,
+                min_info_sites,
+            )
+            return None  # Skip this file
+
+        # Perform sequence alignment and tree generation
+        seqs = load_unaligned_seqs(msp_file, moltype="dna")
+        params = {"kappa": 4.0}
+        _, tree = tree_align(
+            "HKY85",
+            seqs,
+            param_vals=params,
+            show_progress=False,
+        )
+        with tree_file.with_suffix(".tree").open("w") as f:
+            f.write(re.sub(pattern, ":", tree.get_newick(with_distances=True)))
+
+        if tree_file.with_suffix(".tree").exists():
+            logging.info(
+                "Completed MSP tree for MSP %s",
+                msp_file.name.replace(".fasta", ""),
+            )
+            return tree_file.with_suffix(".tree")
+        else:
+            logging.info("No tree file generated")
+            return None
 
 
 @dataclass
@@ -245,33 +303,8 @@ def filter_low_cov_sites(
         gene_ignore = gene_interest[
             gene_interest["coverage"] < self.min_depth
         ].set_index("gene_id")
-        # gene_ignore = {row["gene_id"]: row["gene_length"] for _, row in gene_interest[gene_interest["coverage"] < self.min_depth].iterrows()}
-        # dfs.append(
-        #     pd.DataFrame(
-        #         {
-        #             "gene_id": gene_ignore["gene_id"],
-        #             "startpos": 0,
-        #             "endpos": gene_ignore["gene_length"],
-        #             "coverage": gene_ignore["coverage"],
-        #         }
-        #     )
-        # )
-        # for _, row in gene_interest[gene_interest["coverage"] < self.min_depth].iterrows():
-        #     gene_id = row["gene_id"]
-        #     # If the gene_id is not already a key in the dict, create a new list
-        #     if gene_id not in all_genes_dict:
-        #         all_genes_dict[gene_id] = []
-        #     # Append the new information to the list corresponding to the gene_id
-        #     all_genes_dict[gene_id].append({
-        #         "startpos": 0,
-        #         "endpos": row["gene_length"],
-        #         "coverage": row["coverage"]
-        #     })
         sum_cov_bed = pd.concat(dfs, ignore_index=True).set_index("gene_id")
-        # .query(f"coverage < {self.min_depth}")
-        # sum_cov_bed.to_csv(temp_low_cov_sites, sep="\t", header=False, index=False)
         return sum_cov_bed, gene_ignore
-        # return all_genes_dict
 
     # @memory_profiler.profile
     def create_consensus(
@@ -306,8 +339,11 @@ def create_consensus(
                     consensus = [
                         self.meteor.DEFAULT_GAP_CHAR
                     ] * gene_ignore.loc[gene_id]["gene_length"]
+                    consensus_f.write(f">{gene_id}\n")
+                    consensus_f.write("".join(consensus) + "\n")
                 else:
-                    consensus = np.array(list(Fasta.fetch(ref)), dtype="<U1")
+                    consensus = list(Fasta.fetch(ref))
@@
-                        reference_frequency = record.info["RO"] / (
-                            record.info["RO"] + np.sum(record.info["AO"])
-                        )
-                        if reference_frequency >= self.min_frequency:
-                            keep_alts = tuple(sorted(list(record.alleles)))
-                        else:
-                            keep_alts = tuple(sorted(list(record.alts)))
-                        max_len = max(map(len, keep_alts))
-                        # MNV vase
-                        if max_len > 1:
-                            for i in range(max_len):
-                                mnv = tuple(
-                                    sorted(
-                                        set(
-                                            keep_alts[k][i]
-                                            for k in range(len(keep_alts))
-                                        )
-                                    )
-                                )
-                                consensus[record.start + i] = self.IUPAC[mnv]
-                        else:
-                            consensus[record.start] = self.IUPAC[keep_alts]
+                        # ##INFO=<ID=TYPE,Number=A,Type=String,Description="The
+                        # type of allele, either snp, mnp, ins, del, or complex.">
+                        if record.info["TYPE"][0] == "snp":
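+                            # Keep the reference allele only when it is frequent
+                            # enough (>= min_frequency); ambiguous positions are
+                            # then encoded with IUPAC codes built from the
+                            # retained alleles.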
+                            reference_frequency = record.info["RO"] / (
+                                record.info["RO"] + np.sum(record.info["AO"])
+                            )
+                            if reference_frequency >= self.min_frequency:
+                                keep_alts = tuple(sorted(list(record.alleles)))
+                            else:
+                                keep_alts = tuple(sorted(list(record.alts)))
+                            max_len = max(map(len, keep_alts))
+                            # MNV case
+                            if max_len > 1:
+                                for i in range(max_len):
+                                    mnv = tuple(
+                                        sorted(
+                                            set(
+                                                keep_alts[k][i]
+                                                for k in range(len(keep_alts))
+                                            )
+                                        )
+                                    )
+                                    consensus[record.start + i] = self.IUPAC[
+                                        mnv
+                                    ]
+                            else:
+                                consensus[record.start] = self.IUPAC[keep_alts]
+                        else:
+                            # Nested (non-SNP) variant spanning several positions:
+                            # store the alt sequence with its start/stop so the
+                            # flattening step below can skip the covered span.
+                            consensus[record.start] = [
+                                record.alts[0],
+                                record.start,
+                                record.stop,
+                            ]
                     # Update consensus array for each matching range
-                    # print(low_cov_sites.index)
-                    # print(gene_id)
                     if ref in low_cov_sites.index:
-                        # print(gene_id)
                         selection = low_cov_sites.loc[ref]
-                        # print(selection)
                         if isinstance(selection, pd.Series):
-                            print(selection)
-                            consensus[
-                                selection["startpos"] : selection["endpos"]
-                            ] = self.meteor.DEFAULT_GAP_CHAR
+                            for i in range(
+                                selection["startpos"], selection["endpos"]
+                            ):
+                                consensus[i] = self.meteor.DEFAULT_GAP_CHAR
                         else:
-                            print(selection)
                             for _, row in selection.iterrows():
+                                for i in range(row["startpos"], row["endpos"]):
+                                    consensus[i] = self.meteor.DEFAULT_GAP_CHAR
                                 # Mark as uncertain
-                                consensus[row["startpos"] : row["endpos"]] = (
-                                    self.meteor.DEFAULT_GAP_CHAR
-                                )
-                    consensus_f.write(f">{gene_id}\n")
-                    consensus_f.write("".join(consensus) + "\n")
+                    # Flatten the consensus: plain bases are copied one by one;
+                    # a nested variant is emitted once and its span skipped.
+                    consensus_res = ""
+                    pos = 0
+                    while pos < len(consensus):
+                        if isinstance(consensus[pos], str):
+                            consensus_res += consensus[pos]
+                            pos += 1
+                        else:
+                            consensus_res += consensus[pos][0]
+                            pos = consensus[pos][2]
+                    consensus_f.write(f">{gene_id}\n")
+                    consensus_f.write(consensus_res + "\n")
                 del consensus
@@ -377,6 +433,10 @@ def execute(self) -> None:
             self.census["directory"]
             / f"{self.census['census']['sample_info']['sample_name']}.vcf.gz"
         )
+        low_cov_sites_file = (
+            self.census["directory"]
+            / f"{self.census['census']['sample_info']['sample_name']}.pickle"
+        )
         consensus_file = (
             self.census["directory"]
             / f"{self.census['census']['sample_info']['sample_name']}_consensus.fasta.xz"
@@ -447,67 +507,86 @@ def execute(self) -> None:
             sys.exit(1)
         start = perf_counter()
-        logging.info("Run freebayes")
         startfreebayes = perf_counter()
-        with reference_file.open("rb") as ref_fh:
-            with bgzip.BGZipReader(ref_fh, num_threads=self.meteor.threads) as reader:
-                decompressed_reference = reader.read()
-        with NamedTemporaryFile(
-            suffix=".fasta", dir=self.meteor.tmp_dir, delete=False
-        ) as temp_ref_file:
-            temp_ref_file.write(decompressed_reference)
-            temp_ref_file_path = temp_ref_file.name
-            # index on the fly
-            faidx(temp_ref_file.name)
+        temp_ref_file_path = None
+        if vcf_file.exists():
+            logging.info("VCF file already exists, skipping freebayes...")
+        else:
+            logging.info("Run freebayes")
+            with reference_file.open("rb") as ref_fh:
+                with bgzip.BGZipReader(
+                    ref_fh, num_threads=self.meteor.threads
+                ) as reader:
+                    decompressed_reference = reader.read()
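+            # freebayes is handed a plain FASTA reference: the bgzip-compressed
+            # catalogue is decompressed to a temporary file and indexed with
+            # faidx before the call.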
+            with NamedTemporaryFile(
+                suffix=".fasta", dir=self.meteor.tmp_dir, delete=False
+            ) as temp_ref_file:
+                temp_ref_file.write(decompressed_reference)
+                temp_ref_file_path = temp_ref_file.name
+                # index on the fly
+                faidx(temp_ref_file.name)
         try:
-            with Popen(
-                [
-                    "freebayes",
-                    "-i",  # no indel
-                    # "-X",
-                    "-u",  # no complex observation that may include ins
-                    "--min-alternate-count",
-                    str(self.min_snp_depth),
-                    "--min-alternate-fraction",
-                    str(self.min_frequency),
-                    "-t",
-                    str(bed_file),
-                    "-p",
-                    str(self.ploidy),
-                    "-f",
-                    temp_ref_file_path,
-                    "-b",
-                    str(cram_file.resolve()),
-                ],
-                stdin=PIPE,
-                stdout=PIPE,
-            ) as freebayes_process:
-                # capture output of bcftools_process
-                freebayes_output = freebayes_process.communicate(
-                    input=decompressed_reference
-                )[0]
-                # print(freebayes_output)
-                # compress output using bgzip
-                with open(str(vcf_file.resolve()), "wb") as raw:
-                    with bgzip.BGZipWriter(raw) as fh:
-                        fh.write(freebayes_output)
+            if not vcf_file.exists():
+                with Popen(
+                    [
+                        "freebayes",
+                        # "-i",  # no indel
+                        # "-X",
+                        # "-u",  # no complex observation that may include ins
+                        "--pooled-continuous",
+                        "--min-alternate-count",
+                        str(1),
+                        "--min-coverage",
+                        str(self.min_snp_depth),
+                        "--min-alternate-fraction",
+                        str(self.min_frequency),
+                        "--min-mapping-quality",
+                        str(0),
+                        "--use-duplicate-reads",
+                        "-t",
+                        str(bed_file),
+                        "-p",
+                        str(self.ploidy),
+                        "-f",
+                        temp_ref_file_path,
+                        "-b",
+                        str(cram_file.resolve()),
+                    ],
+                    stdin=PIPE,
+                    stdout=PIPE,
+                ) as freebayes_process:
+                    # capture the output of the freebayes process
+                    freebayes_output = freebayes_process.communicate(
+                        input=decompressed_reference
+                    )[0]
+                    # compress output using bgzip
+                    with vcf_file.open("wb") as raw:
+                        with bgzip.BGZipWriter(raw) as fh:
+                            fh.write(freebayes_output)
         except CalledProcessError as e:
             logging.error("Freebayes failed with return code %d", e.returncode)
             logging.error("Output: %s", e.output)
             sys.exit()
         finally:
-            # Ensure the temporary file is removed after use
-            if os.path.isfile(temp_ref_file_path):
-                os.remove(temp_ref_file_path)
-            if os.path.isfile(temp_ref_file_path + ".fai"):
-                os.remove(temp_ref_file_path + ".fai")
+            if temp_ref_file_path is not None:
+                temp_ref_file_path = Path(temp_ref_file_path)
+                # Ensure the temporary file is removed after use
+                if temp_ref_file_path.exists():
+                    temp_ref_file_path.unlink(missing_ok=True)
+                if Path(f"{temp_ref_file_path}.fai").exists():
+                    Path(f"{temp_ref_file_path}.fai").unlink(missing_ok=True)
         logging.info(
             "Completed freebayes step in %f seconds", perf_counter() - startfreebayes
         )
         # Index the vcf file
-        logging.info("Indexing")
+
         startindexing = perf_counter()
-        tabix_index(str(vcf_file.resolve()), preset="vcf", force=True)
+        if not Path(f"{vcf_file}.tbi").exists():
+            logging.info("Indexing")
+            tabix_index(str(vcf_file.resolve()), preset="vcf", force=True)
+        else:
+            logging.info("Index already exists, skipping...")
         logging.info(
             "Completed indexing step in %f seconds", perf_counter() - startindexing
         )
@@ -517,10 +596,24 @@ def execute(self) -> None:
         # rather than the 1-based tab-delimited file, the file must have
         # the ".bed" or ".bed.gz" suffix (case-insensitive).
         startlowcovpython = perf_counter()
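+        # The low coverage scan is expensive, so its result is cached in a
+        # pickle next to the sample and reloaded on re-runs, mirroring the
+        # VCF and index short-circuits above.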
-        logging.info("Detecting low coverage regions")
-        low_cov_sites, gene_ignore = self.filter_low_cov_sites(
-            cram_file, reference_file
-        )
+        if low_cov_sites_file.exists():
+            logging.info("Loading low coverage regions")
+            with low_cov_sites_file.open("rb") as file:
+                # Load the data from the file
+                data = pickle.load(file)
+                low_cov_sites = data["low_cov_sites"]
+                gene_ignore = data["gene_ignore"]
+        else:
+            logging.info("Detecting low coverage regions")
+            low_cov_sites, gene_ignore = self.filter_low_cov_sites(
+                cram_file, reference_file
+            )
+            # Open a file for writing the pickle data (binary write mode)
+            with low_cov_sites_file.open("wb") as file:
+                # Dump the data into the file
+                pickle.dump(
+                    {"low_cov_sites": low_cov_sites, "gene_ignore": gene_ignore}, file
+                )
         logging.info(
             "Completed low coverage regions filtering step in %f seconds",
             perf_counter() - startlowcovpython,