Skip to content

Commit

Permalink
Improve prediction widget input validation and add halo to iterative …
Browse files Browse the repository at this point in the history
…prediction
  • Loading branch information
qin-yu committed Mar 11, 2024
1 parent 39e6b47 commit b65ae75
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 52 deletions.
43 changes: 21 additions & 22 deletions plantseg/viewer/widget/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from plantseg.viewer.widget.segmentation import widget_agglomeration, widget_lifted_multicut, widget_simple_dt_ws
from plantseg.viewer.widget.utils import return_value_if_widget
from plantseg.viewer.widget.utils import start_threading_process, start_prediction_process, create_layer_name, layer_properties
from plantseg.viewer.widget.validation import change_handler, get_image_volume_from_layer, widgets_inactive
from plantseg.viewer.widget.validation import _on_prediction_input_image_change, widgets_inactive

ALL_CUDA_DEVICES = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
MPS = ['mps'] if torch.backends.mps.is_available() else []
Expand Down Expand Up @@ -109,24 +109,9 @@ def widget_unet_predictions(viewer: Viewer,
)


@change_handler(widget_unet_predictions.image, init=False)
def _image_change(image: Image):
shape = get_image_volume_from_layer(image).shape
ndim = len(shape)
widget_unet_predictions.image.tooltip = f"Shape: {shape}"

size_z = widget_unet_predictions.patch_size[0]
halo_z = widget_unet_predictions.patch_halo[0]
if ndim == 2 or (ndim == 3 and shape[0] == 1): # 2D image imported by Napari thus no Z, or by PlantSeg widget
size_z.value = 0
halo_z.value = 0
widgets_inactive(size_z, halo_z, active=False)
elif ndim == 3 and shape[0] > 1: # 3D
size_z.value = min(64, shape[0]) # TODO: fetch model default
halo_z.value = 8
widgets_inactive(size_z, halo_z, active=True)
else:
raise ValueError(f"Unsupported number of dimensions: {ndim}")
@widget_unet_predictions.image.changed.connect
def _on_widget_unet_predictions_image_change(image: Image):
_on_prediction_input_image_change(widget_unet_predictions, image)


def _on_any_metadata_changed(dimensionality, modality, output_type):
Expand Down Expand Up @@ -231,9 +216,14 @@ def on_done(result):
return future


def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patch_size, single_batch_mode, device):
func = partial(unet_predictions, model_name=model_name, patch=patch_size, single_batch_mode=single_batch_mode,
device=device)
@widget_test_all_unet_predictions.image.changed.connect
def _on_widget_test_all_unet_predictions_image_change(image: Image):
_on_prediction_input_image_change(widget_test_all_unet_predictions, image)


def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patch_size, patch_halo, single_batch_mode, device):
func = partial(unet_predictions, model_name=model_name, patch=patch_size, patch_halo=patch_halo,
single_batch_mode=single_batch_mode, device=device)
for i in range(num_iterations - 1):
pmap = func(pmap)
pmap = image_gaussian_smoothing(image=pmap, sigma=sigma)
Expand All @@ -258,6 +248,8 @@ def _compute_iterative_predictions(pmap, model_name, num_iterations, sigma, patc
'min': 0.},
patch_size={'label': 'Patch size',
'tooltip': 'Patch size use to processed the data.'},
patch_halo={'label': 'Patch halo',
'tooltip': 'Patch halo is extra padding for correct prediction on image boarder.'},
single_patch={'label': 'Single Patch',
'tooltip': 'If True, a single patch will be processed at a time to save memory.'},
device={'label': 'Device',
Expand All @@ -268,6 +260,7 @@ def widget_iterative_unet_predictions(image: Image,
num_iterations: int = 2,
sigma: float = 1.0,
patch_size: Tuple[int, int, int] = (80, 170, 170),
patch_halo: Tuple[int, int, int] = (8, 16, 16),
single_patch: bool = True,
device: str = ALL_DEVICES[0]) -> Future[LayerDataTuple]:
out_name = create_layer_name(image.name, f'iterative-{model_name}-x{num_iterations}')
Expand All @@ -281,6 +274,7 @@ def widget_iterative_unet_predictions(image: Image,
num_iterations=num_iterations,
sigma=sigma,
patch_size=patch_size,
patch_halo=patch_halo,
single_batch_mode=single_patch,
device=device)

Expand All @@ -303,6 +297,11 @@ def _on_model_name_changed_iterative(model_name: str):
widget_iterative_unet_predictions.patch_size.value = tuple(patch_size)


@widget_iterative_unet_predictions.image.changed.connect
def _on_widget_iterative_unet_predictions_image_change(image: Image):
_on_prediction_input_image_change(widget_iterative_unet_predictions, image)


@magicgui(call_button='Add Custom Model',
new_model_name={'label': 'New model name'},
model_location={'label': 'Model location',
Expand Down
59 changes: 29 additions & 30 deletions plantseg/viewer/widget/validation.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,7 @@
"""Widget input validation"""

from psygnal import Signal
from functools import wraps

def change_handler(*widgets, init=True, debug=False):
def decorator_change_handler(handler):
@wraps(handler)
def wrapper(*args):
source = Signal.sender()
emitter = Signal.current_emitter()
if debug:
# print(f"{emitter}: {source} = {args!r}")
print(f"EVENT '{str(emitter.name)}': {source.name:>20} = {args!r}")
# print(f" {source.name:>14}.value = {source.value}")
return handler(*args)

for widget in widgets:
widget.changed.connect(wrapper)
if init:
widget.changed(widget.value)
return wrapper

return decorator_change_handler


def get_image_volume_from_layer(image):
"""Used for widget parameter validation in `change_handler`s."""
image = image.data[0] if image.multiscale else image.data
if not all(hasattr(image, attr) for attr in ("shape", "ndim", "__getitem__")):
image = np.asanyarray(image)
return image
from napari.layers import Image
from magicgui.widgets import Widget


def widgets_inactive(*widgets, active):
Expand All @@ -44,3 +16,30 @@ def widgets_valid(*widgets, valid):
widget.native.setStyleSheet("" if valid else "background-color: lightcoral")


def get_image_volume_from_layer(image):
"""Used for widget parameter validation in change-handlers."""
image = image.data[0] if image.multiscale else image.data
if not all(hasattr(image, attr) for attr in ("shape", "ndim", "__getitem__")):
from numpy import asanyarray

image = asanyarray(image)
return image


def _on_prediction_input_image_change(widget: Widget, image: Image):
shape = get_image_volume_from_layer(image).shape
ndim = len(shape)
widget.image.tooltip = f"Shape: {shape}"

size_z = widget.patch_size[0]
halo_z = widget.patch_halo[0]
if ndim == 2 or (ndim == 3 and shape[0] == 1): # 2D image imported by Napari thus no Z, or by PlantSeg widget
size_z.value = 0
halo_z.value = 0
widgets_inactive(size_z, halo_z, active=False)
elif ndim == 3 and shape[0] > 1: # 3D
size_z.value = min(64, shape[0]) # TODO: fetch model default
halo_z.value = 8
widgets_inactive(size_z, halo_z, active=True)
else:
raise ValueError(f"Unsupported number of dimensions: {ndim}")

0 comments on commit b65ae75

Please sign in to comment.