
Commit

get program working with drift parsing, #337
vreuter committed Nov 24, 2024
1 parent 9bd33a0 commit c6f1277
Showing 1 changed file with 53 additions and 11 deletions.
64 changes: 53 additions & 11 deletions bin/cli/run_signal_analysis.py
@@ -26,8 +26,12 @@
 from looptrace.SpotPicker import SpotPicker
 from looptrace.utilities import find_first_option, get_either, wrap_exception, wrap_error_message
 
+FieldOfViewName: TypeAlias = str
+RawTimepoint: TypeAlias = int
+
 _A = TypeVar("_A")
 _B = TypeVar("_B")
+_K = TypeVar("_K")
 
 SIGNAL_CHANNEL_COLUMN = "signalChannel"

@@ -153,38 +157,46 @@ def workflow(

     spot_drift_file: Path = H.drift_correction_file__coarse
     logging.info("Reading spot drift file: %s", spot_drift_file)
-    all_spot_drifts: pd.DataFrame = pd.read_csv(spot_drift_file, index_col=False)
+    all_spot_drifts: Mapping[tuple[FieldOfViewName, RawTimepoint], DriftRecord] = read_spot_drifts_file(spot_drift_file)
 
     nuclei_drift_file: Path = H.nuclei_coarse_drift_correction_file
     logging.info("Reading nuclei drift file: %s", nuclei_drift_file)
-    all_nuclei_drifts: pd.DataFrame = pd.read_csv(nuclei_drift_file, index_col=False)
+    all_nuclei_drifts: Mapping[FieldOfViewName, DriftRecord] = read_signal_drifts_file(nuclei_drift_file)
 
     # TODO: discard regional spot timepoints from the bigger collection: https://github.com/gerlichlab/looptrace/issues/376
     for spec in analysis_specs:
         # Get the ROIs of this type.
         roi_type: RoiType = spec.roi_type
         if roi_type == RoiType.LocusSpecific:
             # TODO: will need to account for pixels vs. nanometers
             # TODO: will need to account for different headers (e.g., z_px and z rather than zc, yc, etc.)
             logging.error("Cross-channel analysis for locus-specific spots isn't yet supported, skipping!")
             continue
         logging.info("Analyzing signal for ROI type '%s'", roi_type.name)
         rois_file: Path = getattr(H, roi_type.file_attribute_on_image_handler)
         all_rois: pd.DataFrame = pd.read_csv(rois_file, index_col=False)
 
         # Build up the records for this ROI type, for all FOVs.
         by_raw_channel: Mapping[int, list[dict]] = defaultdict(list)
         # TODO: refactor with https://github.com/gerlichlab/gertils/issues/32
-        for fov, img in S.iter_fov_img_pairs():
+        for fov, image_stack in S.iter_fov_img_pairs():
             logging.info(f"Analyzing signal for FOV: {fov}")
-            nuc_drift_curr_fov: pd.DataFrame = all_nuclei_drifts[all_nuclei_drifts[FIELD_OF_VIEW_COLUMN] == fov]
-            logging.debug(f"Shape of nuclei drifts ({type(nuc_drift_curr_fov).__name__}): {nuc_drift_curr_fov.shape}")
-            spot_drifts_curr_fov: pd.DataFrame = all_spot_drifts[all_spot_drifts[FIELD_OF_VIEW_COLUMN] == fov]
-            logging.debug(f"Shape of spot drifts ({type(spot_drifts_curr_fov).__name__}): {spot_drifts_curr_fov.shape}")
+            nuc_drift: DriftRecord = all_nuclei_drifts[fov]
             rois: pd.DataFrame = all_rois[all_rois[FIELD_OF_VIEW_COLUMN] == fov]
             logging.debug("ROI count: %d", rois.shape[0])
+            print(f"ROI count: {rois.shape[0]}")  # DEBUG
             for _, r in rois.iterrows():
-                spot_drift = spot_drifts_curr_fov[spot_drifts_curr_fov[TIMEPOINT_COLUMN] == r[TIMEPOINT_COLUMN]]
+                timepoint: RawTimepoint = r[TIMEPOINT_COLUMN]
+                img = image_stack[timepoint]
+                print(f"img dim: {img.shape}")  # DEBUG
+                spot_drift: DriftRecord = all_spot_drifts[(fov, timepoint)]
+                print(f"Spot drift ({type(spot_drift).__name__}): {spot_drift}")  # DEBUG
                 pt0: ImagePoint3D = get_centroid_from_record(r)
+                print(f"Point: {pt0}")  # DEBUG
                 dc_pt: ImagePoint3D = ImagePoint3D(
-                    z=pt0.z - nuc_drift_curr_fov[Z_PX_COARSE] + spot_drift[Z_PX_COARSE],
-                    y=pt0.y - nuc_drift_curr_fov[Y_PX_COARSE] + spot_drift[Y_PX_COARSE],
-                    x=pt0.x - nuc_drift_curr_fov[X_PX_COARSE] + spot_drift[X_PX_COARSE],
+                    z=pt0.z - nuc_drift.z + spot_drift.z,
+                    y=pt0.y - nuc_drift.y + spot_drift.y,
+                    x=pt0.x - nuc_drift.x + spot_drift.x,
                 )
                 for stats in compute_pixel_statistics(
                     img=img,
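
[Editor's note] For reference, a minimal standalone sketch of the correction the new loop body applies: each ROI centroid is taken out of the nucleus drift frame and into the spot drift frame by subtracting the per-FOV coarse nucleus drift and adding the per-(FOV, timepoint) coarse spot drift. The Vec3 type below is a hypothetical stand-in for the real ImagePoint3D and DriftRecord, just to show the arithmetic:

    # Hypothetical stand-in for ImagePoint3D / DriftRecord; illustrative only.
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class Vec3:
        z: float
        y: float
        x: float

    def drift_correct(pt0: Vec3, nuc_drift: Vec3, spot_drift: Vec3) -> Vec3:
        # Subtract the nucleus drift, then add the spot drift, axis by axis.
        return Vec3(
            z=pt0.z - nuc_drift.z + spot_drift.z,
            y=pt0.y - nuc_drift.y + spot_drift.y,
            x=pt0.x - nuc_drift.x + spot_drift.x,
        )

    # E.g., centroid (10, 20, 30), nucleus drift (1, 0, -2), spot drift (0.5, 1, 0)
    # yields the corrected point (9.5, 21.0, 32.0).
    print(drift_correct(Vec3(10, 20, 30), Vec3(1, 0, -2), Vec3(0.5, 1, 0)))
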
Expand Down Expand Up @@ -223,6 +235,36 @@ def proc1(acc: State, a: _A) -> State:
     return Seq.of_iterable(inputs).fold(proc1, Result.Ok(Seq()))
 
 
+@dataclass(kw_only=True, frozen=True)
+class DriftRecord:
+    x: float
+    y: float
+    z: float
+
+
+@curry_flip(1)
+def read_drift_file(drift_file: Path, get_key: Callable[[pd.Series], _K]) -> Mapping[_K, "DriftRecord"]:
+    drifts: pd.DataFrame = pd.read_csv(drift_file, index_col=False)
+    colnames = {Z_PX_COARSE: "z", Y_PX_COARSE: "y", X_PX_COARSE: "x"}
+    result: dict[_K, DriftRecord] = {}
+    for key, rec in drifts.rename(columns=colnames).apply(
+        lambda row: (get_key(row), DriftRecord(**row[list(colnames.values())].to_dict())),
+        axis=1,
+    ):
+        if key in result:
+            raise ValueError(f"Repeated key in drift file ({drift_file}): {key}")
+        result[key] = rec
+    return result
+
+
+read_signal_drifts_file: Callable[[Path], Mapping[FieldOfViewName, "DriftRecord"]] = \
+    read_drift_file(lambda row: row[FIELD_OF_VIEW_COLUMN])
+
+read_spot_drifts_file: Callable[[Path], Mapping[tuple[FieldOfViewName, RawTimepoint], "DriftRecord"]] = \
+    read_drift_file(lambda row: (row[FIELD_OF_VIEW_COLUMN], row[TIMEPOINT_COLUMN]))
+
+
 def _ensure_unique(items: Iterable[_A]) -> Result[set[_A], str]:
     try:
         counts = Counter(items)
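
[Editor's note] A usage sketch for the new reader. Because read_drift_file is curried with its first argument flipped to the end (curry_flip appears to come from the expression library, which the file already uses for Seq and Result), partially applying the key extractor yields a one-argument, Path-to-mapping reader. The plain-Python equivalent below makes that shape explicit; the column names are hypothetical placeholders for Z_PX_COARSE / Y_PX_COARSE / X_PX_COARSE and FIELD_OF_VIEW_COLUMN / TIMEPOINT_COLUMN, whose actual values are defined elsewhere in looptrace:

    # Sketch only: column names are assumptions, not taken from the diff.
    from pathlib import Path
    import pandas as pd

    def make_reader(get_key):
        # Plain-Python equivalent of the @curry_flip(1) partial application.
        def reader(drift_file: Path) -> dict:
            drifts = pd.read_csv(drift_file, index_col=False)
            result: dict = {}
            for _, row in drifts.iterrows():
                key = get_key(row)
                if key in result:
                    raise ValueError(f"Repeated key in drift file ({drift_file}): {key}")
                result[key] = (row["zDriftCoarsePixels"], row["yDriftCoarsePixels"], row["xDriftCoarsePixels"])
            return result
        return reader

    # Keyed by FOV alone (nuclei) vs. by (FOV, timepoint) (spots), as in the diff.
    read_signal = make_reader(lambda row: row["fieldOfView"])
    read_spots = make_reader(lambda row: (row["fieldOfView"], row["timepoint"]))
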

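[Editor's note] The _ensure_unique helper is truncated in this view, so its body here is an assumption; given the Counter import, the Result[set[_A], str] return type, and the file's Result.Ok usage above, a plausible completion is:

    # Assumed completion -- the real body is collapsed in the diff above.
    from collections import Counter
    from typing import Iterable, TypeVar

    from expression import Result

    _A = TypeVar("_A")

    def ensure_unique(items: Iterable[_A]) -> Result:
        try:
            counts = Counter(items)
        except TypeError as e:  # unhashable items can't be uniqueness-checked
            return Result.Error(str(e))
        repeats = {item: n for item, n in counts.items() if n > 1}
        if repeats:
            return Result.Error(f"Repeated items: {repeats}")
        return Result.Ok(set(counts))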