From 01f53e59e963c6236b1293b20f496cf1ef4f3eb3 Mon Sep 17 00:00:00 2001
From: Adam Tyson
Date: Mon, 3 Jun 2024 10:50:44 +0000
Subject: [PATCH] Update default inference batch size (#432)

* Add timings to classification

* standardise batch size (except for training)

* Expose batch size in the GUI

* reorder GUI

* Update default batch size to 64 for inference

---
 cellfinder/core/classify/classify.py          |  8 ++++++++
 cellfinder/core/classify/cube_generator.py    |  4 ++--
 cellfinder/core/main.py                       |  2 +-
 cellfinder/napari/curation.py                 |  2 +-
 cellfinder/napari/detect/detect.py            | 10 ++++++++--
 cellfinder/napari/detect/detect_containers.py |  2 ++
 cellfinder/napari/detect/thread_worker.py     |  4 ++--
 7 files changed, 24 insertions(+), 8 deletions(-)

diff --git a/cellfinder/core/classify/classify.py b/cellfinder/core/classify/classify.py
index 1da4dabe..ec77190f 100644
--- a/cellfinder/core/classify/classify.py
+++ b/cellfinder/core/classify/classify.py
@@ -1,4 +1,5 @@
 import os
+from datetime import datetime
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import keras
@@ -50,6 +51,8 @@ def main(
     # Too many workers doesn't increase speed, and uses huge amounts of RAM
     workers = get_num_processes(min_free_cpu_cores=n_free_cpus)
 
+    start_time = datetime.now()
+
     logger.debug("Initialising cube generator")
     inference_generator = CubeGeneratorFromFile(
         points,
@@ -90,6 +93,11 @@
         cell.type = predictions[idx] + 1
         points_list.append(cell)
 
+    time_elapsed = datetime.now() - start_time
+    print(
+        "Classfication complete - all points done in : {}".format(time_elapsed)
+    )
+
     return points_list
 
 
diff --git a/cellfinder/core/classify/cube_generator.py b/cellfinder/core/classify/cube_generator.py
index 4a24467f..601a5a0b 100644
--- a/cellfinder/core/classify/cube_generator.py
+++ b/cellfinder/core/classify/cube_generator.py
@@ -40,7 +40,7 @@ def __init__(
         background_array: types.array,
         voxel_sizes: Tuple[int, int, int],
         network_voxel_sizes: Tuple[int, int, int],
-        batch_size: int = 16,
+        batch_size: int = 64,
         cube_width: int = 50,
         cube_height: int = 50,
         cube_depth: int = 20,
@@ -345,7 +345,7 @@ def __init__(
         signal_list: List[Union[str, Path]],
         background_list: List[Union[str, Path]],
         labels: Optional[List[int]] = None,  # only if training or validating
-        batch_size: int = 16,
+        batch_size: int = 64,
         shape: Tuple[int, int, int] = (50, 50, 20),
         channels: int = 2,
         classes: int = 2,
diff --git a/cellfinder/core/main.py b/cellfinder/core/main.py
index 23526a94..5aad49f7 100644
--- a/cellfinder/core/main.py
+++ b/cellfinder/core/main.py
@@ -18,7 +18,7 @@ def main(
     trained_model: Optional[os.PathLike] = None,
     model_weights: Optional[os.PathLike] = None,
     model: model_type = "resnet50_tv",
-    batch_size: int = 32,
+    batch_size: int = 64,
     n_free_cpus: int = 2,
     network_voxel_sizes: Tuple[int, int, int] = (5, 1, 1),
     soma_diameter: int = 16,
diff --git a/cellfinder/napari/curation.py b/cellfinder/napari/curation.py
index e3ce7e7b..a8dcf60b 100644
--- a/cellfinder/napari/curation.py
+++ b/cellfinder/napari/curation.py
@@ -54,7 +54,7 @@ def __init__(
         self.save_empty_cubes = save_empty_cubes
         self.max_ram = max_ram
         self.voxel_sizes = [5, 2, 2]
-        self.batch_size = 32
+        self.batch_size = 64
         self.viewer = viewer
 
         self.signal_layer = None
diff --git a/cellfinder/napari/detect/detect.py b/cellfinder/napari/detect/detect.py
index cdf36939..5f2de700 100644
--- a/cellfinder/napari/detect/detect.py
+++ b/cellfinder/napari/detect/detect.py
@@ -253,8 +253,9 @@ def widget(
         max_cluster_size: int,
         classification_options,
         skip_classification: bool,
-        trained_model: Optional[Path],
         use_pre_trained_weights: bool,
+        trained_model: Optional[Path],
+        batch_size: int,
         misc_options,
         start_plane: int,
         end_plane: int,
@@ -298,6 +299,8 @@ def widget(
             should be attempted
         use_pre_trained_weights : bool
             Select to use pre-trained model weights
+        batch_size : int
+            How many points to classify at one time
         skip_classification : bool
             If selected, the classification step is skipped and all cells
             from the detection stage are added
@@ -372,7 +375,10 @@
         if use_pre_trained_weights:
             trained_model = None
         classification_inputs = ClassificationInputs(
-            skip_classification, use_pre_trained_weights, trained_model
+            skip_classification,
+            use_pre_trained_weights,
+            trained_model,
+            batch_size,
         )
 
         if analyse_local:
diff --git a/cellfinder/napari/detect/detect_containers.py b/cellfinder/napari/detect/detect_containers.py
index 39fda163..953e6248 100644
--- a/cellfinder/napari/detect/detect_containers.py
+++ b/cellfinder/napari/detect/detect_containers.py
@@ -114,6 +114,7 @@ class ClassificationInputs(InputContainer):
     skip_classification: bool = False
     use_pre_trained_weights: bool = True
     trained_model: Optional[Path] = Path.home()
+    batch_size: int = 64
 
     def as_core_arguments(self) -> dict:
         args = super().as_core_arguments()
@@ -131,6 +132,7 @@ def widget_representation(cls) -> dict:
             skip_classification=dict(
                 value=cls.defaults()["skip_classification"]
             ),
+            batch_size=dict(value=cls.defaults()["batch_size"]),
         )
 
 
diff --git a/cellfinder/napari/detect/thread_worker.py b/cellfinder/napari/detect/thread_worker.py
index c4392860..5e01d434 100644
--- a/cellfinder/napari/detect/thread_worker.py
+++ b/cellfinder/napari/detect/thread_worker.py
@@ -72,10 +72,10 @@ def detect_finished_callback(points: list) -> None:
 
         def classify_callback(batch: int) -> None:
             self.update_progress_bar.emit(
                 "Classifying cells",
-                # Default cellfinder-core batch size is 32. This seems to give
+                # Default cellfinder-core batch size is 64. This seems to give
                 # a slight underestimate of the number of batches though, so
                 # allow for batch number to go over this
-                max(self.npoints_detected // 32 + 1, batch + 1),
+                max(self.npoints_detected // 64 + 1, batch + 1),
                 batch + 1,
             )
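
Usage sketch: with this change, callers of cellfinder.core.main.main that do
not pass batch_size will classify candidate cubes in batches of 64. Only the
batch_size keyword comes from the patch above; the positional arguments
(signal_array, background_array, voxel_sizes), the placeholder volumes and the
voxel sizes below are illustrative assumptions based on the documented
cellfinder-core entry point, not part of this commit.

    # Minimal, hedged example; random arrays stand in for real image volumes.
    import numpy as np
    from cellfinder.core.main import main as cellfinder_run

    signal_array = np.random.random((30, 512, 512)).astype(np.float32)      # placeholder signal channel
    background_array = np.random.random((30, 512, 512)).astype(np.float32)  # placeholder background channel

    detected_cells = cellfinder_run(
        signal_array,
        background_array,
        voxel_sizes=(5, 2, 2),  # z, y, x in microns (example values)
        batch_size=64,          # new default from this patch; tune to fit GPU memory
    )
    print("{} candidate cells classified".format(len(detected_cells)))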