diff --git a/bin/cli/run_signal_analysis.py b/bin/cli/run_signal_analysis.py
index f039e630..dc537814 100644
--- a/bin/cli/run_signal_analysis.py
+++ b/bin/cli/run_signal_analysis.py
@@ -26,8 +26,12 @@ from looptrace.SpotPicker import SpotPicker
 from looptrace.utilities import find_first_option, get_either, wrap_exception, wrap_error_message
 
+FieldOfViewName: TypeAlias = str
+RawTimepoint: TypeAlias = int
+
 _A = TypeVar("_A")
 _B = TypeVar("_B")
+_K = TypeVar("_K")
 
 SIGNAL_CHANNEL_COLUMN = "signalChannel"
 
@@ -153,16 +157,21 @@ def workflow(
 
     spot_drift_file: Path = H.drift_correction_file__coarse
     logging.info("Reading nuclei drift file: %s", spot_drift_file)
-    all_spot_drifts: pd.DataFrame = pd.read_csv(spot_drift_file, index_col=False)
+    all_spot_drifts: Mapping[tuple[FieldOfViewName, RawTimepoint], DriftRecord] = read_spot_drifts_file(spot_drift_file)
 
     nuclei_drift_file: Path = H.nuclei_coarse_drift_correction_file
     logging.info("Reading nuclei drift file: %s", nuclei_drift_file)
-    all_nuclei_drifts: pd.DataFrame = pd.read_csv(nuclei_drift_file, index_col=False)
+    all_nuclei_drifts: Mapping[FieldOfViewName, DriftRecord] = read_signal_drifts_file(nuclei_drift_file)
 
     # TODO: discard regional spot timepoints from the bigger collection: https://github.com/gerlichlab/looptrace/issues/376
     for spec in analysis_specs:
         # Get the ROIs of this type.
         roi_type: RoiType = spec.roi_type
+        if roi_type == RoiType.LocusSpecific:
+            # TODO: will need to account for pixels vs. nanometers
+            # TODO: will need to account for different headers (e.g., z_px and z rather than zc, yc, etc.)
+            logging.error("Cross-channel analysis for locus-specific spots isn't yet supported, skipping!")
+            continue
         logging.info("Analyzing signal for ROI type '%s'", roi_type.name)
         rois_file: Path = getattr(H, roi_type.file_attribute_on_image_handler)
         all_rois: pd.DataFrame = pd.read_csv(rois_file, index_col=False)
@@ -170,21 +179,24 @@ def workflow(
 
         # Build up the records for this ROI type, for all FOVs.
         by_raw_channel: Mapping[int, list[dict]] = defaultdict # TODO: refactor with https://github.com/gerlichlab/gertils/issues/32
-        for fov, img in S.iter_fov_img_pairs():
+        for fov, image_stack in S.iter_fov_img_pairs():
             logging.info(f"Analysing signal for FOV: {fov}")
-            nuc_drift_curr_fov: pd.DataFrame = all_nuclei_drifts[all_nuclei_drifts[FIELD_OF_VIEW_COLUMN] == fov]
-            logging.debug(f"Shape of nuclei drifts ({type(nuc_drift_curr_fov).__name__}): {nuc_drift_curr_fov.shape}")
-            spot_drifts_curr_fov: pd.DataFrame = all_spot_drifts[all_spot_drifts[FIELD_OF_VIEW_COLUMN] == fov]
-            logging.debug(f"Shape of spot drifts ({type(spot_drifts_curr_fov).__name__}): {spot_drifts_curr_fov.shape}")
+            nuc_drift: DriftRecord = all_nuclei_drifts[fov]
             rois: pd.DataFrame = all_rois[all_rois[FIELD_OF_VIEW_COLUMN] == fov]
             logging.debug("ROI count: %d", rois.shape[0])
+            print(f"ROI count: {rois.shape[0]}") # DEBUG
             for _, r in rois.iterrows():
-                spot_drift = spot_drifts_curr_fov[spot_drifts_curr_fov[TIMEPOINT_COLUMN] == r[TIMEPOINT_COLUMN]]
+                timepoint: RawTimepoint = r[TIMEPOINT_COLUMN]
+                img = image_stack[timepoint]
+                print(f"img dim: {img.shape}") # DEBUG
+                spot_drift: DriftRecord = all_spot_drifts[(fov, timepoint)]
+                print(f"Spot drift ({type(spot_drift).__name__}): {spot_drift}") # DEBUG
                 pt0: ImagePoint3D = get_centroid_from_record(r)
+                print(f"Point: {pt0}") # DEBUG
                 dc_pt: ImagePoint3D = ImagePoint3D(
-                    z=pt0.z - nuc_drift_curr_fov[Z_PX_COARSE] + spot_drift[Z_PX_COARSE],
-                    y=pt0.y - nuc_drift_curr_fov[Y_PX_COARSE] + spot_drift[Y_PX_COARSE],
-                    x=pt0.x - nuc_drift_curr_fov[X_PX_COARSE] + spot_drift[X_PX_COARSE],
+                    z=pt0.z - nuc_drift.z + spot_drift.z,
+                    y=pt0.y - nuc_drift.y + spot_drift.y,
+                    x=pt0.x - nuc_drift.x + spot_drift.x,
                 )
                 for stats in compute_pixel_statistics(
                     img=img,
@@ -223,6 +235,36 @@ def proc1(acc: State, a: _A) -> State:
     return Seq.of_iterable(inputs).fold(proc1, Result.Ok(Seq()))
 
 
+@dataclass(kw_only=True, frozen=True)
+class DriftRecord:
+    x: float
+    y: float
+    z: float
+
+
+@curry_flip(1)
+def read_drift_file(drift_file: Path, get_key: Callable[[pd.Series], _K]) -> Mapping[_K, "DriftRecord"]:
+    drifts: pd.DataFrame = pd.read_csv(drift_file, index_col=False)
+    colnames = {Z_PX_COARSE: "z", Y_PX_COARSE: "y", X_PX_COARSE: "x"}
+    result: Mapping[_K, DriftRecord] = {}
+    for key, rec in drifts.rename(columns=colnames).apply(
+        lambda row: (get_key(row), DriftRecord(**row[colnames.values()].to_dict())),
+        axis=1,
+    ):
+        if key in result:
+            raise ValueError(f"Repeated key in drift file ({drift_file}): {key}")
+        result[key] = rec
+    return result
+
+
+read_signal_drifts_file: Callable[[Path], Mapping[FieldOfViewName, "DriftRecord"]] = \
+    read_drift_file(lambda row: row[FIELD_OF_VIEW_COLUMN])
+
+read_spot_drifts_file: Callable[[Path], Mapping[tuple[FieldOfViewName, RawTimepoint], "DriftRecord"]] = \
+    read_drift_file(lambda row: (row[FIELD_OF_VIEW_COLUMN], row[TIMEPOINT_COLUMN]))
+
+
+
 def _ensure_unique(items: Iterable[_A]) -> Result[set[_A], str]:
     try:
         counts = Counter(items)