Skip to content

Commit

Permalink
use the drift correction for the signal analysis, #337
Browse files Browse the repository at this point in the history
  • Loading branch information
vreuter committed Nov 24, 2024
1 parent 37030ba commit 082b2a9
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 22 deletions.
45 changes: 24 additions & 21 deletions bin/cli/run_signal_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from operator import itemgetter
from pathlib import Path
import sys
from typing import TYPE_CHECKING, Iterable, Mapping, TypeAlias, TypeVar
from typing import Iterable, Mapping, TypeAlias, TypeVar

from expression import Option, Result, compose, snd
from expression.collections import Seq, seq
Expand All @@ -23,6 +23,7 @@
from spotfishing.roi_tools import get_centroid_from_record

from looptrace import FIELD_OF_VIEW_COLUMN
from looptrace.Drifter import TIMEPOINT_COLUMN, X_PX_COARSE, Y_PX_COARSE, Z_PX_COARSE
from looptrace.ImageHandler import ImageHandler
from looptrace.SpotPicker import SpotPicker
from looptrace.utilities import find_first_option, get_either, wrap_exception, wrap_error_message
Expand Down Expand Up @@ -120,24 +121,6 @@ def from_mapping(cls, data: Mapping[str, object]) -> Result["AnalyticalSpecifica
)))


# TODO: refactor with https://github.com/gerlichlab/gertils/issues/32
def run_cross_channel_signal_analysis(
*,
img: npt.ArrayLike,
roi_diameter: int,
channels: Iterable[ImagingChannel],
points: Iterable[ImagePoint3D],
) -> Iterable[dict[str, PixelStatValue]]:
for pt in points:
yield compute_pixel_statistics(
img=img,
pt=pt,
channels=channels,
diameter=roi_diameter,
channel_column=SIGNAL_CHANNEL_COLUMN,
)


def workflow(*, rounds_config: ExtantFile, params_config: ExtantFile, maybe_signal_config: Option[ExtantFile]) -> None:
match maybe_signal_config:
case option.Option(tag="none", none=_):
Expand All @@ -159,6 +142,15 @@ def workflow(*, rounds_config: ExtantFile, params_config: ExtantFile, maybe_sign
H = ImageHandler(rounds_config=rounds_config, params_config=params_config)
S = SpotPicker(H)

spot_drift_file: Path = H.drift_correction_file__coarse
logging.info("Reading nuclei drift file: %s", spot_drift_file)
all_spot_drifts: pd.DataFrame = pd.read_csv(spot_drift_file, index_col=False)

nuclei_drift_file: Path = H.nuclei_coarse_drift_correction_file
logging.info("Reading nuclei drift file: %s", nuclei_drift_file)
all_nuclei_drifts: pd.DataFrame = pd.read_csv(nuclei_drift_file, index_col=False)

# TODO: discard regional spot timepoints from the bigger collection: https://github.com/gerlichlab/looptrace/issues/376
for spec in analysis_specs:
# Get the ROIs of this type.
roi_type: RoiType = spec.roi_type
Expand All @@ -168,15 +160,26 @@ def workflow(*, rounds_config: ExtantFile, params_config: ExtantFile, maybe_sign

# Build up the records for this ROI type, for all FOVs.
by_raw_channel: Mapping[int, list[dict]] = defaultdict
# TODO: refactor with https://github.com/gerlichlab/gertils/issues/32
for fov, img in S.iter_fov_img_pairs():
logging.info(f"Analysing signal for FOV: {fov}")
nuc_drift_curr_fov: pd.DataFrame = all_nuclei_drifts[all_nuclei_drifts[FIELD_OF_VIEW_COLUMN] == fov]
logging.debug(f"Shape of nuclei drifts ({type(nuc_drift_curr_fov).__name__}): {nuc_drift_curr_fov.shape}")
spot_drifts_curr_fov: pd.DataFrame = all_spot_drifts[all_spot_drifts[FIELD_OF_VIEW_COLUMN] == fov]
logging.debug(f"Shape of spot drifts ({type(spot_drifts_curr_fov).__name__}): {spot_drifts_curr_fov.shape}")
rois: pd.DataFrame = all_rois[all_rois[FIELD_OF_VIEW_COLUMN] == fov]
logging.debug("ROI count: %d", rois.shape[0])
for _, r in rois.iterrows():
pt: ImagePoint3D = get_centroid_from_record(r)
spot_drift = spot_drifts_curr_fov[spot_drifts_curr_fov[TIMEPOINT_COLUMN] == r[TIMEPOINT_COLUMN]]
pt0: ImagePoint3D = get_centroid_from_record(r)
dc_pt: ImagePoint3D = ImagePoint3D(
z=pt0.z - nuc_drift_curr_fov[Z_PX_COARSE] + spot_drift[Z_PX_COARSE],
y=pt0.y - nuc_drift_curr_fov[Y_PX_COARSE] + spot_drift[Y_PX_COARSE],
x=pt0.x - nuc_drift_curr_fov[X_PX_COARSE] + spot_drift[X_PX_COARSE],
)
for stats in compute_pixel_statistics(
img=img,
pt=pt,
pt=dc_pt,
channels=spec.channels,
diameter=spec.roi_diameter,
channel_column=SIGNAL_CHANNEL_COLUMN,
Expand Down
4 changes: 4 additions & 0 deletions looptrace/ImageHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,10 @@ def nuclear_masks_visualisation_data_path(self) -> Path:
def nuclei_channel(self) -> int:
return self.config["nuc_channel"]

@property
def nuclei_coarse_drift_correction_file(self) -> Path:
return self.get_dc_filepath(prefix="nuclei", suffix="_coarse.csv")

@property
def nuclei_filtered_spots_file_path(self) -> Path:
return self.proximity_accepted_spots_file_path.with_suffix(".nuclei_filtered.csv")
Expand Down
2 changes: 1 addition & 1 deletion looptrace/NucDetector.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def do_in_3d(self) -> bool:
@property
def drift_correction_file__coarse(self) -> Path:
"""Path to the file with coarse drift correction information for nuclei"""
return self.image_handler.get_dc_filepath(prefix="nuclei", suffix="_coarse.csv")
return self.image_handler.nuclei_coarse_drift_correction_file

@property
def drift_correction_file__full(self) -> Path:
Expand Down

0 comments on commit 082b2a9

Please sign in to comment.