Skip to content

Commit

Permalink
Merge pull request #220 from compomics/spectrum-output
Browse files Browse the repository at this point in the history
Reimplement spectrum_output module for v4
  • Loading branch information
RalfG authored May 8, 2024
2 parents 7374689 + 8a91a57 commit bd66087
Show file tree
Hide file tree
Showing 13 changed files with 776 additions and 901 deletions.
5 changes: 0 additions & 5 deletions docs/source/api/ms2pip.constants.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@
ms2pip.constants
****************

.. py:data:: ms2pip.constants.SUPPORTED_OUTPUT_FORMATS
:type: list

Supported file formats for spectrum output

.. py:data:: ms2pip.constants.MODELS
:type: dict

Expand Down
5 changes: 5 additions & 0 deletions docs/source/api/ms2pip.spectrum-output.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,8 @@ ms2pip.spectrum_output

.. automodule:: ms2pip.spectrum_output
:members:

.. py:data:: ms2pip.spectrum_output.SUPPORTED_FORMATS
:type: dict

Supported file formats and their respective :py:class:`_Writer` classes for spectrum output.
76 changes: 33 additions & 43 deletions ms2pip/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,12 @@
from werkzeug.utils import secure_filename

import ms2pip.core
from ms2pip import __version__
from ms2pip import __version__, exceptions
from ms2pip._utils.cli import build_credits, build_prediction_table
from ms2pip.constants import MODELS, SUPPORTED_OUTPUT_FORMATS
from ms2pip.exceptions import (
InvalidXGBoostModelError,
UnknownModelError,
UnknownOutputFormatError,
UnresolvableModificationError,
)
from ms2pip.result import correlations_to_csv, results_to_csv
from ms2pip.spectrum_output import write_single_spectrum_csv, write_single_spectrum_png
from ms2pip.constants import MODELS
from ms2pip.plot import spectrum_to_png
from ms2pip.result import write_correlations
from ms2pip.spectrum_output import SUPPORTED_FORMATS, write_spectra

console = Console()
logger = logging.getLogger(__name__)
Expand All @@ -44,7 +39,8 @@ def _infer_output_name(
if output_name:
return Path(output_name)
else:
return Path(input_filename).with_suffix("")
input__filename = Path(input_filename)
return input__filename.with_name(input__filename.stem + "_predictions").with_suffix("")


@click.group()
Expand All @@ -65,49 +61,47 @@ def cli(*args, **kwargs):
@cli.command(help=ms2pip.core.predict_single.__doc__)
@click.argument("peptidoform", required=True)
@click.option("--output-name", "-o", type=str)
@click.option("--output-format", "-f", type=click.Choice(SUPPORTED_FORMATS), default="tsv")
@click.option("--model", type=click.Choice(MODELS), default="HCD")
@click.option("--model-dir")
@click.option("--plot", "-p", is_flag=True)
def predict_single(*args, **kwargs):
# Parse arguments
output_name = kwargs.pop("output_name")
output_format = kwargs.pop("output_format")
plot = kwargs.pop("plot")
if not output_name:
output_name = "ms2pip_prediction_" + secure_filename(kwargs["peptidoform"]) + ".csv"
output_name = "ms2pip_prediction_" + secure_filename(kwargs["peptidoform"])

# Predict spectrum
result = ms2pip.core.predict_single(*args, **kwargs)
predicted_spectrum, _ = result.as_spectra()

# Write output
console.print(build_prediction_table(predicted_spectrum))
write_single_spectrum_csv(predicted_spectrum, output_name)
write_spectra(output_name, [result], output_format)
if plot:
write_single_spectrum_png(predicted_spectrum, output_name)
spectrum_to_png(predicted_spectrum, output_name)


@cli.command(help=ms2pip.core.predict_batch.__doc__)
@click.argument("psms", required=True)
@click.option("--output-name", "-o", type=str)
@click.option("--output-format", "-f", type=click.Choice(SUPPORTED_OUTPUT_FORMATS))
@click.option("--output-format", "-f", type=click.Choice(SUPPORTED_FORMATS), default="tsv")
@click.option("--add-retention-time", "-r", is_flag=True)
@click.option("--model", type=click.Choice(MODELS), default="HCD")
@click.option("--model-dir")
@click.option("--processes", "-n", type=int)
def predict_batch(*args, **kwargs):
# Parse arguments
output_name = kwargs.pop("output_name")
output_format = kwargs.pop("output_format") # noqa F841 TODO
output_name = _infer_output_name(kwargs["psms"], output_name)
output_format = kwargs.pop("output_format")
output_name = _infer_output_name(kwargs["psms"], kwargs.pop("output_name"))

# Run
predictions = ms2pip.core.predict_batch(*args, **kwargs)

# Write output
output_name_csv = output_name.with_name(output_name.stem + "_predictions").with_suffix(".csv")
logger.info(f"Writing output to {output_name_csv}")
results_to_csv(predictions, output_name_csv)
# TODO: add support for other output formats
write_spectra(output_name, predictions, output_format)


@cli.command(help=ms2pip.core.predict_library.__doc__)
Expand All @@ -129,24 +123,22 @@ def predict_library(*args, **kwargs):
@click.option("--processes", "-n", type=int)
def correlate(*args, **kwargs):
# Parse arguments
output_name = kwargs.pop("output_name")
output_name = _infer_output_name(kwargs["psms"], output_name)
output_name = _infer_output_name(kwargs["psms"], kwargs.pop("output_name"))

# Run
results = ms2pip.core.correlate(*args, **kwargs)

# Write output
output_name_int = output_name.with_name(output_name.stem + "_predictions").with_suffix(".csv")
logger.info(f"Writing intensities to {output_name_int}")
results_to_csv(results, output_name_int)
# TODO: add support for other output formats
# Write intensities
logger.info(f"Writing intensities to {output_name.with_suffix('.tsv')}")
write_spectra(output_name, results, "tsv")

# Write correlations
if kwargs["compute_correlations"]:
output_name_corr = output_name.with_name(output_name.stem + "_correlations")
output_name_corr = output_name_corr.with_suffix(".csv")
output_name_corr = output_name.with_name(output_name.stem + "_correlations").with_suffix(
".tsv"
)
logger.info(f"Writing correlations to {output_name_corr}")
correlations_to_csv(results, output_name_corr)
write_correlations(results, output_name_corr)


@cli.command(help=ms2pip.core.get_training_data.__doc__)
Expand Down Expand Up @@ -188,32 +180,30 @@ def annotate_spectra(*args, **kwargs):
# Run
results = ms2pip.core.annotate_spectra(*args, **kwargs)

# Write output
output_name_int = output_name.with_name(output_name.stem + "_observations").with_suffix(".csv")
logger.info(f"Writing intensities to {output_name_int}")
results_to_csv(results, output_name_int)
# Write intensities
output_name_int = output_name.with_name(output_name.stem + "_observations").with_suffix()
logger.info(f"Writing intensities to {output_name_int.with_suffix('.tsv')}")
write_spectra(output_name, results, "tsv")


def main():
try:
cli()
except UnresolvableModificationError as e:
except exceptions.UnresolvableModificationError as e:
logger.critical(
"Unresolvable modification: `%s`. See "
"https://ms2pip.readthedocs.io/en/stable/usage/#amino-acid-modifications "
"for more info.",
e,
)
sys.exit(1)
except UnknownOutputFormatError as o:
logger.critical(
f"Unknown output format: `{o}` (supported formats: `{SUPPORTED_OUTPUT_FORMATS}`)"
)
except exceptions.UnknownOutputFormatError as o:
logger.critical(f"Unknown output format: `{o}` (supported formats: `{SUPPORTED_FORMATS}`)")
sys.exit(1)
except UnknownModelError as f:
except exceptions.UnknownModelError as f:
logger.critical(f"Unknown model: `{f}` (supported models: {set(MODELS.keys())})")
sys.exit(1)
except InvalidXGBoostModelError:
except exceptions.InvalidXGBoostModelError:
logger.critical("Could not correctly download XGBoost model\nTry a manual download.")
sys.exit(1)
except Exception:
Expand Down
4 changes: 3 additions & 1 deletion ms2pip/_utils/dlib.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Database configuration for EncyclopeDIA DLIB SQLite format."""

import zlib
from pathlib import Path
from typing import Union

import numpy
import sqlalchemy
Expand Down Expand Up @@ -91,7 +93,7 @@ def copy(self):
)


def open_sqlite(filename):
def open_sqlite(filename: Union[str, Path]) -> sqlalchemy.engine.Connection:
engine = sqlalchemy.create_engine(f"sqlite:///{filename}")
metadata.bind = engine
return engine.connect()
6 changes: 0 additions & 6 deletions ms2pip/_utils/psm_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import psm_utils.io.peptide_record
from psm_utils import PSMList

from ms2pip import exceptions

logger = logging.getLogger(__name__)


Expand All @@ -23,10 +21,6 @@ def read_psms(psms: Union[str, Path, PSMList], filetype: Union[str, None]) -> PS
else:
raise TypeError("Invalid type for psms. Should be str, Path, or PSMList.")

# Validate runs and collections
if not len(psm_list.collections) == 1 or not len(psm_list.runs) == 1:
raise exceptions.InvalidInputError("PSMs should be for a single run and collection.")

# Apply fixed modifications if any
psm_list.apply_fixed_modifications()

Expand Down
1 change: 1 addition & 0 deletions ms2pip/_utils/xgb_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def get_predictions_xgb(features, num_ions, model_params, model_dir, processes=1
for ion_type, xgb_model in xgboost_models.items():
# Get predictions from XGBoost model
preds = xgb_model.predict(features)
preds = preds.clip(min=np.log2(0.001)) # Clip negative intensities
xgb_model.__del__()

# Reshape into arrays for each peptide
Expand Down
3 changes: 0 additions & 3 deletions ms2pip/constants.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
"""Constants and fixed configurations for MS²PIP."""

# Supported output formats
SUPPORTED_OUTPUT_FORMATS = ["csv", "mgf", "msp", "bibliospec", "spectronaut", "dlib"]

# Models and their properties
# id is passed to get_predictions to select model
# ion_types is required to write the ion types in the headers of the result files
Expand Down
14 changes: 11 additions & 3 deletions ms2pip/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
from ms2pip._utils.psm_input import read_psms
from ms2pip._utils.retention_time import RetentionTime
from ms2pip._utils.xgb_models import get_predictions_xgb, validate_requested_xgb_model
from ms2pip.constants import MODELS, SUPPORTED_OUTPUT_FORMATS
from ms2pip.constants import MODELS
from ms2pip.result import ProcessingResult, calculate_correlations
from ms2pip.spectrum_input import read_spectrum_file
from ms2pip.spectrum_output import SUPPORTED_FORMATS

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -424,7 +425,7 @@ def _validate_output_formats(self, output_formats: List[str]) -> List[str]:
self.output_formats = ["csv"]
else:
for output_format in output_formats:
if output_format not in SUPPORTED_OUTPUT_FORMATS:
if output_format not in SUPPORTED_FORMATS:
raise exceptions.UnknownOutputFormatError(output_format)
self.output_formats = output_formats

Expand Down Expand Up @@ -544,6 +545,10 @@ def process_spectra(
If only peak annotations should be extracted from the spectrum file
"""
# Validate runs and collections
if not len(psm_list.collections) == 1 or not len(psm_list.runs) == 1:
raise exceptions.InvalidInputError("PSMs should be for a single run and collection.")

args = (
spectrum_file,
vector_file,
Expand Down Expand Up @@ -672,7 +677,10 @@ def _process_peptidoform(
MODELS[model]["peaks_version"],
30.0, # TODO: Remove CE feature
)
predictions = {i: np.array(p, dtype=np.float32) for i, p in zip(ion_types, predictions)}
predictions = {
i: np.array(p, dtype=np.float32).clip(min=np.log2(0.001)) # Clip negative intensities
for i, p in zip(ion_types, predictions)
}
feature_vectors = None

return ProcessingResult(
Expand Down
23 changes: 23 additions & 0 deletions ms2pip/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from pathlib import Path
from typing import Union

from ms2pip.spectrum import Spectrum

try:
import matplotlib.pyplot as plt
import spectrum_utils.plot as sup

_can_plot = True
except ImportError:
_can_plot = False


def spectrum_to_png(spectrum: Spectrum, filepath: Union[str, Path]):
    """
    Draw a single spectrum and save the figure as a PNG image.

    The title is derived from ``spectrum.peptidoform`` and the image is written to
    ``filepath`` with its suffix replaced by ``.png``.

    Raises
    ------
    ImportError
        If the optional plotting dependencies (matplotlib, spectrum_utils) could not
        be imported when this module was loaded.

    """
    # The plotting libraries are optional; fail with a clear message when absent.
    if not _can_plot:
        raise ImportError("Matplotlib and spectrum_utils are required to plot spectra.")

    # Draw on pyplot's current axes.
    axes = plt.gca()
    title = "MS²PIP prediction for " + str(spectrum.peptidoform)
    axes.set_title(title)
    sup.spectrum(spectrum.to_spectrum_utils(), ax=axes)

    # Always write a ``.png`` file, whatever suffix the caller supplied.
    png_path = Path(filepath).with_suffix(".png")
    plt.savefig(png_path)
    plt.close()
40 changes: 4 additions & 36 deletions ms2pip/result.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Definition and handling of MS²PIP results."""

from __future__ import annotations

import csv
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from psm_utils import PSM
from pydantic import ConfigDict, BaseModel
from pydantic import BaseModel, ConfigDict

try:
import spectrum_utils.plot as sup
Expand Down Expand Up @@ -115,44 +116,11 @@ def calculate_correlations(results: List[ProcessingResult]) -> None:
result.correlation = np.corrcoef(pred_int, obs_int)[0][1]


def results_to_csv(results: List["ProcessingResult"], output_file: str) -> None:
"""Write processing results to CSV file."""
with open(output_file, "wt") as f:
fieldnames = [
"psm_index",
"ion_type",
"ion_number",
"mz",
"predicted",
"observed",
]
writer = csv.DictWriter(f, fieldnames=fieldnames, lineterminator="\n")
writer.writeheader()
for result in results:
if result.theoretical_mz is not None:
for ion_type in result.theoretical_mz:
for i in range(len(result.theoretical_mz[ion_type])):
writer.writerow(
{
"psm_index": result.psm_index,
"ion_type": ion_type,
"ion_number": i + 1,
"mz": "{:.6g}".format(result.theoretical_mz[ion_type][i]),
"predicted": "{:.6g}".format(
result.predicted_intensity[ion_type][i]
) if result.predicted_intensity else None,
"observed": "{:.6g}".format(result.observed_intensity[ion_type][i])
if result.observed_intensity
else None,
}
)


def correlations_to_csv(results: List["ProcessingResult"], output_file: str) -> None:
def write_correlations(results: List["ProcessingResult"], output_file: str) -> None:
"""Write the ``psm_index`` and ``correlation`` of each result to a tab-separated (TSV) file."""
with open(output_file, "wt") as f:
fieldnames = ["psm_index", "correlation"]
writer = csv.DictWriter(f, fieldnames=fieldnames, lineterminator="\n")
writer = csv.DictWriter(f, fieldnames=fieldnames, delimiter="\t", lineterminator="\n")
writer.writeheader()
for result in results:
writer.writerow({"psm_index": result.psm_index, "correlation": result.correlation})
6 changes: 3 additions & 3 deletions ms2pip/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ def __repr__(self) -> str:
@model_validator(mode="after")
@classmethod
def check_array_lengths(cls, data: dict):
if len(data["mz"]) != len(data["intensity"]):
if len(data.mz) != len(data.intensity):
raise ValueError("Array lengths do not match.")
if data["annotations"] is not None:
if len(data["annotations"]) != len(data["intensity"]):
if data.annotations is not None:
if len(data.annotations) != len(data.intensity):
raise ValueError("Array lengths do not match.")
return data

Expand Down
Loading

0 comments on commit bd66087

Please sign in to comment.