Skip to content

Commit

Permalink
Add widget input validation/adaptation and add halo to Napari GUI
Browse files Browse the repository at this point in the history
  • Loading branch information
qin-yu committed Mar 7, 2024
1 parent a009df8 commit 2072f01
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 5 deletions.
6 changes: 4 additions & 2 deletions plantseg/predictions/functional/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def unet_predictions(raw: np.array, model_name: str, patch: Tuple[int, int, int]
Defaults to 'cuda'.
model_update (bool, optional): if True will update the model to the latest version. Defaults to False.
disable_tqdm (bool, optional): if True will disable tqdm progress bar. Defaults to False.
output_ndim (int, optional): output ndim, must be one of [3, 4]. Only use `4` if network output is
output_ndim (int, optional): output ndim, must be one of [3, 4]. Only use `4` if network output is
multi-channel 3D pmap. Now `4` only used in `widget_unet_predictions()`.
Returns:
Expand All @@ -45,7 +45,9 @@ def unet_predictions(raw: np.array, model_name: str, patch: Tuple[int, int, int]
state = state['model_state_dict']
model.load_state_dict(state)

patch_halo = get_patch_halo(model_name)
patch_halo = kwargs.get('patch_halo', None)
if patch_halo is None:
patch_halo = get_patch_halo(model_name)
predictor = ArrayPredictor(model=model, in_channels=model_config['in_channels'],
out_channels=model_config['out_channels'], device=device, patch=patch,
patch_halo=patch_halo, single_batch_mode=single_batch_mode, headless=False,
Expand Down
34 changes: 31 additions & 3 deletions plantseg/viewer/widget/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from plantseg.viewer.widget.segmentation import widget_agglomeration, widget_lifted_multicut, widget_simple_dt_ws
from plantseg.viewer.widget.utils import return_value_if_widget
from plantseg.viewer.widget.utils import start_threading_process, start_prediction_process, create_layer_name, layer_properties
from plantseg.viewer.widget.validation import change_handler, get_image_volume_from_layer, widgets_inactive

ALL_CUDA_DEVICES = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
MPS = ['mps'] if torch.backends.mps.is_available() else []
Expand Down Expand Up @@ -61,6 +62,8 @@ def unet_predictions_wrapper(raw, device, **kwargs):
'choices': LIST_ALL_MODELS},
patch_size={'label': 'Patch size',
'tooltip': 'Patch size use to processed the data.'},
patch_halo={'label': 'Patch halo',
'tooltip': 'Patch halo is extra padding for correct prediction on image boarder.'},
single_patch={'label': 'Single Patch',
'tooltip': 'If True, a single patch will be processed at a time to save memory.'},
device={'label': 'Device',
Expand All @@ -73,6 +76,7 @@ def widget_unet_predictions(viewer: Viewer,
modality: str = 'All',
output_type: str = 'All',
patch_size: Tuple[int, int, int] = (80, 170, 170),
patch_halo: Tuple[int, int, int] = (8, 16, 16),
single_patch: bool = True,
device: str = ALL_DEVICES[0], ) -> Future[LayerDataTuple]:
out_name = create_layer_name(image.name, model_name)
Expand All @@ -85,7 +89,7 @@ def widget_unet_predictions(viewer: Viewer,
layer_kwargs['metadata']['pmap'] = True # this is used to warn the user that the layer is a pmap

layer_type = 'image'
step_kwargs = dict(model_name=model_name, patch=patch_size, single_batch_mode=single_patch)
step_kwargs = dict(model_name=model_name, patch=patch_size, patch_halo=patch_halo, single_batch_mode=single_patch)

return start_prediction_process(unet_predictions_wrapper,
runtime_kwargs={'raw': image.data,
Expand All @@ -105,6 +109,26 @@ def widget_unet_predictions(viewer: Viewer,
)


@change_handler(widget_unet_predictions.image, init=False)
def _image_change(image: Image):
shape = get_image_volume_from_layer(image).shape
ndim = len(shape)
widget_unet_predictions.image.tooltip = f"Shape: {shape}"

size_z = widget_unet_predictions.patch_size[0]
halo_z = widget_unet_predictions.patch_halo[0]
if ndim == 2 or (ndim == 3 and shape[0] == 1): # 2D image imported by Napari thus no Z, or by PlantSeg widget
size_z.value = 0
halo_z.value = 0
widgets_inactive(size_z, halo_z, active=False)
elif ndim == 3 and shape[0] > 1: # 3D
size_z.value = min(64, shape[0]) # TODO: fetch model default
halo_z.value = 8
widgets_inactive(size_z, halo_z, active=True)
else:
raise ValueError(f"Unsupported number of dimensions: {ndim}")


def _on_any_metadata_changed(dimensionality, modality, output_type):
dimensionality = [dimensionality] if dimensionality != 'All' else None
modality = [modality] if modality != 'All' else None
Expand Down Expand Up @@ -152,7 +176,7 @@ def _on_model_name_changed(model_name: str):
widget_unet_predictions.model_name.tooltip = f'Select a pretrained model. Current model description: {description}'


def _compute_multiple_predictions(image, patch_size, device):
def _compute_multiple_predictions(image, patch_size, patch_halo, device):
out_layers = []
for i, model_name in enumerate(list_models()):

Expand All @@ -167,7 +191,7 @@ def _compute_multiple_predictions(image, patch_size, device):
layer_type = 'image'
try:
pmap = unet_predictions(raw=image.data, model_name=model_name, patch=patch_size, single_batch_mode=True,
device=device)
device=device, patch_halo=patch_halo)
out_layers.append((pmap, layer_kwargs, layer_type))

except Exception as e:
Expand All @@ -181,15 +205,19 @@ def _compute_multiple_predictions(image, patch_size, device):
'tooltip': 'Raw image to be processed with a neural network.'},
patch_size={'label': 'Patch size',
'tooltip': 'Patch size use to processed the data.'},
patch_halo={'label': 'Patch halo',
'tooltip': 'Patch halo is extra padding for correct prediction on image boarder.'},
device={'label': 'Device',
'choices': ALL_DEVICES}
)
def widget_test_all_unet_predictions(image: Image,
patch_size: Tuple[int, int, int] = (80, 170, 170),
patch_halo: Tuple[int, int, int] = (2, 4, 4),
device: str = ALL_DEVICES[0]) -> Future[List[LayerDataTuple]]:
func = thread_worker(partial(_compute_multiple_predictions,
image=image,
patch_size=patch_size,
patch_halo=patch_halo,
device=device))

future = Future()
Expand Down
46 changes: 46 additions & 0 deletions plantseg/viewer/widget/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Widget input validation"""

from psygnal import Signal
from functools import wraps

def change_handler(*widgets, init=True, debug=False):
def decorator_change_handler(handler):
@wraps(handler)
def wrapper(*args):
source = Signal.sender()
emitter = Signal.current_emitter()
if debug:
# print(f"{emitter}: {source} = {args!r}")
print(f"EVENT '{str(emitter.name)}': {source.name:>20} = {args!r}")
# print(f" {source.name:>14}.value = {source.value}")
return handler(*args)

for widget in widgets:
widget.changed.connect(wrapper)
if init:
widget.changed(widget.value)
return wrapper

return decorator_change_handler


def get_image_volume_from_layer(image):
"""Used for widget parameter validation in `change_handler`s."""
image = image.data[0] if image.multiscale else image.data
if not all(hasattr(image, attr) for attr in ("shape", "ndim", "__getitem__")):
image = np.asanyarray(image)
return image


def widgets_inactive(*widgets, active):
"""Toggle visibility of widgets."""
for widget in widgets:
widget.visible = active


def widgets_valid(*widgets, valid):
"""Toggle background warning color of widgets."""
for widget in widgets:
widget.native.setStyleSheet("" if valid else "background-color: lightcoral")


0 comments on commit 2072f01

Please sign in to comment.