Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Library handler #23

Merged
merged 12 commits into from
Nov 2, 2023
175 changes: 151 additions & 24 deletions library_spectra_validation/library_handler.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,163 @@
from matchms.importing import load_spectra
from matchms.filtering.SpectrumProcessor import SpectrumProcessor
from filters import PRIMARY_FILTERS
from validation_pipeline import Modification, SpectrumRepairer, SpectrumValidator

class LibraryHandler:
    """Load a spectral library, repair and validate it, and track user decisions.

    Spectra are addressed by their index (``spectrum_id``) in ``self.spectra``.

    Attributes set by ``__init__``:
        spectra: spectra loaded from ``f`` after metadata field harmonization.
        spectrum_repairer: ``SpectrumRepairer`` used to repair spectra.
        spectrum_validator: ``SpectrumValidator`` used to check requirements.
        modifications: maps spectrum_id -> list of ``Modification`` objects,
            oldest first.
        failed_requirements: maps spectrum_id -> failures reported by the
            validator.
        validated_spectra: ids with no failed requirement and no modification
            that is still unapproved by the user.
        nonvalidated_spectra: all other ids.
    """

    def __init__(self, f):
        """Load the spectra in ``f`` and immediately repair and validate them."""
        metadata_field_harmonization = SpectrumProcessor(predefined_pipeline=None,
                                                         additional_filters=PRIMARY_FILTERS)
        self.spectra = metadata_field_harmonization.process_spectrums(load_spectra(f))
        self.spectrum_repairer = SpectrumRepairer()
        self.spectrum_validator = SpectrumValidator()
        self.validated_spectra = []
        self.nonvalidated_spectra = []
        self.modifications = {}
        self.failed_requirements = {}

        self.initial_run()

    def initial_run(self):
        """Repair and validate every loaded spectrum once."""
        for spectrum_id, spectrum in enumerate(self.spectra):
            modifications, repaired = self.spectrum_repairer.process_spectrum_store_modifications(spectrum)
            self.modifications[spectrum_id] = modifications
            self.failed_requirements[spectrum_id] = (
                self.spectrum_validator.process_spectrum_store_failed_filters(repaired))
            # Replacing an element does not change the list length, so this is
            # safe while enumerating.
            self.spectra[spectrum_id] = repaired
            self.update_spectra_quality_lists(spectrum_id)
        # todo: for the dashboard, iterate over the spectra with failed
        #  requirements and collect the user's accept/change decision.

    def update_spectra_quality_lists(self, spectrum_id):
        """Will update validated_spectra and nonvalidated_spectra list for this spectrum_id"""
        has_failed_requirements = len(self.failed_requirements[spectrum_id]) != 0
        has_unapproved_repairs = any(modification.validated_by_user is False
                                     for modification in self.modifications[spectrum_id])

        if has_failed_requirements or has_unapproved_repairs:
            target, other = self.nonvalidated_spectra, self.validated_spectra
        else:
            target, other = self.validated_spectra, self.nonvalidated_spectra

        # Keep each id in exactly one of the two lists, without duplicates.
        if spectrum_id not in target:
            target.append(spectrum_id)
        if spectrum_id in other:
            other.remove(spectrum_id)

    def return_user_validation_info(self, spectrum_id):
        """Return (modifications, failed_requirements, spectrum) for spectrum_id.

        Raises AssertionError if the spectrum is not in nonvalidated_spectra.
        """
        assert spectrum_id in self.nonvalidated_spectra

        modifications = self.modifications[spectrum_id]
        failed_requirements = self.failed_requirements[spectrum_id]

        return modifications, failed_requirements, self.spectra[spectrum_id]

    def approve_repair(self, spectrum_id, field_name):
        """Accepts every modification done to a field_name"""
        for modification in self.modifications[spectrum_id]:
            if modification.metadata_field == field_name:
                modification.validated_by_user = True
        self.update_spectra_quality_lists(spectrum_id)

    def approve_all_repairs(self, spectrum_id):
        """Accepts all modifications done for a spectrum"""
        for modification in self.modifications[spectrum_id]:
            modification.validated_by_user = True
        self.update_spectra_quality_lists(spectrum_id)

    def decline_last_repair(self, spectrum_id, field_name):
        """Undo the last modification made to a field.

        The newest modification of ``field_name`` whose ``after`` value equals
        the spectrum's current value is reverted and removed. Nothing happens
        if no such modification exists. Validation is not rerun here (see
        decline_wrapper).
        """
        modifications = self.modifications[spectrum_id]
        spectrum = self.spectra[spectrum_id]
        # Walk newest -> oldest and stop at the first hit. A forward scan that
        # deletes while iterating could undo an older change or several at once.
        for mod_idx in range(len(modifications) - 1, -1, -1):
            modification = modifications[mod_idx]
            if (modification.metadata_field == field_name
                    and modification.after == spectrum.get(field_name)):
                spectrum.set(field_name, modification.before)
                del modifications[mod_idx]
                return
        # todo run validation after.

    def decline_all_repairs_on_a_field(self, spectrum_id, field_name):
        """Undoes all the repairs for a specific field.

        This is achieved by iteratively removing the last added repair"""
        nr_of_modifications_to_field = sum(
            1 for modification in self.modifications[spectrum_id]
            if modification.metadata_field == field_name)
        for _ in range(nr_of_modifications_to_field):
            self.decline_last_repair(spectrum_id, field_name)
        # todo run validation after.

    def decline_all_repairs_spectrum(self, spectrum_id):
        """Undoes all modifications made to a spectrum.

        Modifications are reverted newest first, so every field ends up with
        the ``before`` value of its oldest recorded modification. The list is
        emptied in place and the loop always terminates.
        """
        modifications = self.modifications[spectrum_id]
        spectrum = self.spectra[spectrum_id]
        while modifications:
            modification = modifications.pop()
            spectrum.set(modification.metadata_field, modification.before)

    def decline_wrapper(self, spectrum_id, field_name, only_last_repair: bool):
        """Decline repairs, then rerun validation and refresh the quality lists.

        field_name None declines every repair of the spectrum; otherwise only
        the last repair (only_last_repair True) or all repairs of that field.
        """
        if field_name is None:
            self.decline_all_repairs_spectrum(spectrum_id)
        elif only_last_repair:
            self.decline_last_repair(spectrum_id, field_name)
        else:
            self.decline_all_repairs_on_a_field(spectrum_id, field_name)

        self.failed_requirements[spectrum_id] = (
            self.spectrum_validator.process_spectrum_store_failed_filters(self.spectra[spectrum_id]))
        self.update_spectra_quality_lists(spectrum_id)

    def user_metadata_change(self, field_name, user_input, spectrum_id):
        """This function takes user defined metadata and rewrites the required field in spectra
        The info on user-defined modifications is added to modifications dictionary and mandatory
        validation is rerun.
        """
        spectrum = self.spectra[spectrum_id]
        # Record the change first so that ``before`` holds the old value.
        self.modifications[spectrum_id].append(
            Modification(metadata_field=field_name, before=spectrum.get(field_name),
                         after=user_input, logging_message="Manual change", validated_by_user=True))
        spectrum.set(field_name, user_input)
        self.failed_requirements[spectrum_id] = (
            self.spectrum_validator.process_spectrum_store_failed_filters(spectrum))
        self.update_spectra_quality_lists(spectrum_id)

    def user_rerun_repair(self, spectrum_id, rerun: bool):
        """
        The function behind user's choice to rerun the repairment and validation
        Should be linked to a button in a dashboard
        """
        if rerun:  # todo do we even need it??
            # The repairer returns (modifications, spectrum); unpack it the
            # same way initial_run does instead of storing the tuple.
            # NOTE(review): this replaces the earlier modification history of
            # the spectrum - confirm that is intended.
            modifications, spectrum = self.spectrum_repairer.process_spectrum_store_modifications(
                self.spectra[spectrum_id])
            self.modifications[spectrum_id] = modifications
            self.spectra[spectrum_id] = spectrum
            self.failed_requirements[spectrum_id] = (
                self.spectrum_validator.process_spectrum_store_failed_filters(spectrum))
            self.update_spectra_quality_lists(spectrum_id)




63 changes: 63 additions & 0 deletions library_spectra_validation/tests/test_library_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from library_handler import LibraryHandler


def test_init_library_handler():
    """Constructing a handler from the example library must not raise."""
    example_library = "./examples/test_case_correct.mgf"
    LibraryHandler(example_library)


def test_approve_repairs():
    """Approving the "inchi" repairs marks the first stored modification as validated."""
    handler = LibraryHandler("./examples/test_case_correct.mgf")
    first_id = 0
    handler.approve_repair(spectrum_id=first_id, field_name="inchi")
    first_modification = handler.modifications[first_id][0]
    assert first_modification.validated_by_user is True


def test_approve_all_repairs():
    """After approving all repairs, every modification of the spectrum is validated."""
    handler = LibraryHandler("./examples/test_case_correct.mgf")
    first_id = 0
    handler.approve_all_repairs(spectrum_id=first_id)
    assert all(entry.validated_by_user is True
               for entry in handler.modifications[first_id])


def test_decline_last_repairs():
    """Declining the last "inchi" repair leaves one modification on spectrum 0."""
    handler = LibraryHandler("./examples/test_case_correct.mgf")
    first_id = 0
    handler.decline_last_repair(spectrum_id=first_id, field_name="inchi")
    remaining = handler.modifications[0]
    assert len(remaining) == 1


def test_decline_all_repairs_on_a_field():
    """Declining every "inchi" repair leaves one modification on spectrum 0."""
    handler = LibraryHandler("./examples/test_case_correct.mgf")
    first_id = 0
    # todo add test that actually has multiple repairs for one field
    handler.decline_all_repairs_on_a_field(spectrum_id=first_id, field_name="inchi")
    remaining = handler.modifications[0]
    assert len(remaining) == 1


def test_decline_all_repairs_spectrum():
    """Declining all repairs of spectrum 0 empties its modification list."""
    handler = LibraryHandler("./examples/test_case_correct.mgf")
    first_id = 0
    handler.decline_all_repairs_spectrum(spectrum_id=first_id)
    remaining = handler.modifications[0]
    assert len(remaining) == 0

    # todo check that change is undone


def test_decline_wrapper():
    """decline_wrapper with field_name=None undoes all repairs and revalidates."""
    handler = LibraryHandler("./examples/test_case_correct.mgf")
    # NOTE(review): this reference is taken after the handler was constructed
    # and may point at the very object decline_wrapper modifies, which would
    # make the equality check below trivially true - confirm.
    original_spectrum = handler.spectra[0]
    first_id = 0
    handler.decline_wrapper(spectrum_id=first_id, field_name=None, only_last_repair=False)

    remaining = handler.modifications[0]
    assert len(remaining) == 0
    # check that changes were undone
    assert original_spectrum == handler.spectra[first_id]
    failures = handler.failed_requirements[first_id]
    assert len(failures) == 3
    assert first_id in handler.nonvalidated_spectra


def test_user_metadata_change():
    """A manual change of "smiles" is written to the stored spectrum."""
    handler = LibraryHandler("./examples/test_case_correct.mgf")
    first_id = 0
    new_smiles = "CCC"
    handler.user_metadata_change(spectrum_id=first_id, field_name="smiles", user_input=new_smiles)
    stored_smiles = handler.spectra[first_id].get("smiles")
    assert stored_smiles == "CCC"
4 changes: 0 additions & 4 deletions library_spectra_validation/tests/test_spectra_loading.py

This file was deleted.

42 changes: 31 additions & 11 deletions library_spectra_validation/validation_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,28 @@
"""

import logging
from typing import Iterable, List, Optional, Union
from typing import Iterable, List, Optional, Union, Tuple
from matchms.filtering.SpectrumProcessor import SpectrumProcessor
from matchms import Spectrum

logger = logging.getLogger("matchms")


class Modification:
    """One recorded change of a metadata field.

    Holds the field name, the value before and after the change, the logging
    message that came with it, and whether the user has approved the change.
    """

    def __init__(self, metadata_field, before, after, logging_message, validated_by_user):
        # Which field changed, and its value on either side of the change.
        self.metadata_field, self.before, self.after = metadata_field, before, after
        # self.original =
        # Accompanying message and the user's approval flag.
        self.logging_message, self.validated_by_user = logging_message, validated_by_user


class RequirementFailure:
    """A metadata field tied to a failed requirement, with its logging message."""

    def __init__(self, metadata_field, logging_message):
        # Store both values exactly as given.
        self.metadata_field, self.logging_message = metadata_field, logging_message


def find_modifications(spectrum_old, spectrum_new, logging_message: str):
"""Checks which modifications have been made in a filter step"""
modifications = []
Expand All @@ -34,7 +41,7 @@ def find_modifications(spectrum_old, spectrum_new, logging_message: str):
modifications.append(
Modification(metadata_field=metadata_field,
before=spectrum_old.get(metadata_field),
after=spectrum_new(metadata_field),
after=spectrum_new.get(metadata_field),
logging_message=logging_message,
validated_by_user=False))
return modifications
Expand All @@ -50,7 +57,7 @@ def process_spectrum(self, spectrum,
processing_report=None):
raise AttributeError("process spectrum is not a valid method of SpectrumValidator")

def process_spectrum_store_modifications(self, spectrum) -> List[Modification]:
def process_spectrum_store_modifications(self, spectrum) -> Tuple[List[Modification], Spectrum]:
if not self.filters:
raise TypeError("No filters to process")
modifications = []
Expand All @@ -64,21 +71,31 @@ def process_spectrum_store_modifications(self, spectrum) -> List[Modification]:
if spectrum_out is None:
raise AttributeError("SpectrumRepairer is only expected to repair spectra, not set to None")
spectrum = spectrum_out
return modifications
return modifications, spectrum


class SpectrumValidator(SpectrumProcessor):
def __init__(self):
    """Set up the requirement filters and the metadata fields each one checks."""
    # Maps a requirement filter name to the metadata fields it checks; used to
    # report which fields are affected when that filter rejects a spectrum.
    self.fields_checked_by_filter = {
        "require_precursor_mz": ["precursor_mz"],
        "require_valid_annotation": ["smiles", "inchi", "inchikey"],
        "require_correct_ionmode": ["ionmode", "adduct", "charge"],
        # "require_parent_mass_match_smiles": ["smiles", "parent_mass"]
    }
    # todo require adduct, precursor mz and parent mass match.
    # todo add all the checks for formatting. That everything is filled and of the expected format.
    super().__init__(predefined_pipeline=None,
                     additional_filters=("require_precursor_mz",
                                         "require_valid_annotation",
                                         ("require_correct_ionmode", {"ion_mode_to_keep": "both"}),
                                         # ("require_parent_mass_match_smiles", {'mass_tolerance': 0.1}),
                                         ))
    # todo add require parent mass match smiles after matchms release.
def process_spectrum(self, spectrum, processing_report=None):
    """Always raise AttributeError; this method is deliberately disabled here."""
    message = "process spectrum is not a valid method of SpectrumValidator"
    raise AttributeError(message)

def process_spectrum_store_failed_filters(self, spectrum) -> List[Modification]:
def process_spectrum_store_failed_filters(self, spectrum) -> List[RequirementFailure]:
if not self.filters:
raise TypeError("No filters to process")
failed_requirements = []
Expand All @@ -87,5 +104,8 @@ def process_spectrum_store_failed_filters(self, spectrum) -> List[Modification]:
logging_message = ""
spectrum_out = filter_func(spectrum)
if spectrum_out is None:
failed_requirements += logging_message
fields_changed = self.fields_checked_by_filter[filter_func.__name__]
for field_changed in fields_changed:
failed_requirements.append(RequirementFailure(field_changed,
logging_message))
return failed_requirements
Loading