Commit

Update default inference batch size (#432)
* Add timings to classification

* standardise batch size (except for training)

* Expose batch size in the GUI

* reorder GUI

* Update default batch size to 64 for inference
adamltyson authored Jun 3, 2024
1 parent 70ed0cb commit 01f53e5
Showing 7 changed files with 24 additions and 8 deletions.
8 changes: 8 additions & 0 deletions cellfinder/core/classify/classify.py
@@ -1,4 +1,5 @@
 import os
+from datetime import datetime
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import keras
@@ -50,6 +51,8 @@ def main(
     # Too many workers doesn't increase speed, and uses huge amounts of RAM
     workers = get_num_processes(min_free_cpu_cores=n_free_cpus)
 
+    start_time = datetime.now()
+
     logger.debug("Initialising cube generator")
     inference_generator = CubeGeneratorFromFile(
         points,
@@ -90,6 +93,11 @@ def main(
             cell.type = predictions[idx] + 1
             points_list.append(cell)
 
+    time_elapsed = datetime.now() - start_time
+    print(
+        "Classification complete - all points done in : {}".format(time_elapsed)
+    )
+
     return points_list


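The timing addition is the standard elapsed-time pattern: record datetime.now() before the work and subtract it afterwards to get a timedelta. A minimal, self-contained sketch of the same idea (the sleep call is only a stand-in for cube generation and model prediction):

    from datetime import datetime
    from time import sleep

    start_time = datetime.now()
    sleep(1.5)  # stand-in for the real classification work
    time_elapsed = datetime.now() - start_time  # a datetime.timedelta
    print("Classification complete - all points done in : {}".format(time_elapsed))
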
4 changes: 2 additions & 2 deletions cellfinder/core/classify/cube_generator.py
@@ -40,7 +40,7 @@ def __init__(
         background_array: types.array,
         voxel_sizes: Tuple[int, int, int],
         network_voxel_sizes: Tuple[int, int, int],
-        batch_size: int = 16,
+        batch_size: int = 64,
         cube_width: int = 50,
         cube_height: int = 50,
         cube_depth: int = 20,
@@ -345,7 +345,7 @@ def __init__(
         signal_list: List[Union[str, Path]],
         background_list: List[Union[str, Path]],
         labels: Optional[List[int]] = None,  # only if training or validating
-        batch_size: int = 16,
+        batch_size: int = 64,
         shape: Tuple[int, int, int] = (50, 50, 20),
         channels: int = 2,
         classes: int = 2,
2 changes: 1 addition & 1 deletion cellfinder/core/main.py
@@ -18,7 +18,7 @@ def main(
     trained_model: Optional[os.PathLike] = None,
     model_weights: Optional[os.PathLike] = None,
     model: model_type = "resnet50_tv",
-    batch_size: int = 32,
+    batch_size: int = 64,
     n_free_cpus: int = 2,
     network_voxel_sizes: Tuple[int, int, int] = (5, 1, 1),
     soma_diameter: int = 16,
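For context, a minimal sketch of calling the core API with the new default. The arrays and voxel sizes below are placeholders, and it assumes cellfinder.core.main.main takes the signal array, background array and voxel sizes as its leading arguments; batch_size can still be lowered explicitly if classification runs out of memory:

    import numpy as np
    from cellfinder.core.main import main as cellfinder_run

    signal = np.random.random((30, 500, 500)).astype(np.float32)      # placeholder image stack
    background = np.random.random((30, 500, 500)).astype(np.float32)  # placeholder image stack

    detected_cells = cellfinder_run(
        signal,
        background,
        voxel_sizes=(5, 2, 2),  # z, y, x in microns
        batch_size=64,          # now the default; pass e.g. 32 on low-memory machines
    )
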
2 changes: 1 addition & 1 deletion cellfinder/napari/curation.py
@@ -54,7 +54,7 @@ def __init__(
         self.save_empty_cubes = save_empty_cubes
         self.max_ram = max_ram
         self.voxel_sizes = [5, 2, 2]
-        self.batch_size = 32
+        self.batch_size = 64
         self.viewer = viewer
 
         self.signal_layer = None
10 changes: 8 additions & 2 deletions cellfinder/napari/detect/detect.py
@@ -253,8 +253,9 @@ def widget(
         max_cluster_size: int,
         classification_options,
         skip_classification: bool,
-        trained_model: Optional[Path],
         use_pre_trained_weights: bool,
+        trained_model: Optional[Path],
+        batch_size: int,
         misc_options,
         start_plane: int,
         end_plane: int,
@@ -298,6 +299,8 @@ def widget(
             should be attempted
         use_pre_trained_weights : bool
             Select to use pre-trained model weights
+        batch_size : int
+            How many points to classify at one time
         skip_classification : bool
             If selected, the classification step is skipped and all cells from
             the detection stage are added
@@ -372,7 +375,10 @@ def widget(
         if use_pre_trained_weights:
             trained_model = None
         classification_inputs = ClassificationInputs(
-            skip_classification, use_pre_trained_weights, trained_model
+            skip_classification,
+            use_pre_trained_weights,
+            trained_model,
+            batch_size,
         )
 
         if analyse_local:
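Because batch_size is now a parameter of the magicgui widget, it can also be adjusted programmatically before running. A hedged sketch, assuming the plugin factory is detect_widget() and the usual magicgui behaviour of exposing each parameter as a sub-widget with a .value attribute:

    from cellfinder.napari.detect.detect import detect_widget

    widget = detect_widget()      # a magicgui FunctionGui
    widget.batch_size.value = 32  # override the new default of 64 before pressing Run
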
2 changes: 2 additions & 0 deletions cellfinder/napari/detect/detect_containers.py
@@ -114,6 +114,7 @@ class ClassificationInputs(InputContainer):
     skip_classification: bool = False
     use_pre_trained_weights: bool = True
     trained_model: Optional[Path] = Path.home()
+    batch_size: int = 64
 
     def as_core_arguments(self) -> dict:
         args = super().as_core_arguments()
@@ -131,6 +132,7 @@ def widget_representation(cls) -> dict:
             skip_classification=dict(
                 value=cls.defaults()["skip_classification"]
             ),
+            batch_size=dict(value=cls.defaults()["batch_size"]),
         )


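Adding the batch_size field to ClassificationInputs is what carries the GUI value through to the core call. A simplified sketch of that idea (not cellfinder's actual InputContainer, which handles more fields and the widget representation as well): every dataclass field becomes a keyword argument, so a new option only needs a default value here plus an entry in widget_representation:

    from dataclasses import asdict, dataclass
    from pathlib import Path
    from typing import Optional

    @dataclass
    class ClassificationInputsSketch:
        skip_classification: bool = False
        use_pre_trained_weights: bool = True
        trained_model: Optional[Path] = Path.home()
        batch_size: int = 64

        def as_core_arguments(self) -> dict:
            args = asdict(self)
            # GUI-only flag, not an argument of the core API in this sketch
            args.pop("use_pre_trained_weights")
            return args

    print(ClassificationInputsSketch().as_core_arguments())
    # e.g. {'skip_classification': False, 'trained_model': PosixPath('/home/user'), 'batch_size': 64}
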
4 changes: 2 additions & 2 deletions cellfinder/napari/detect/thread_worker.py
@@ -72,10 +72,10 @@ def detect_finished_callback(points: list) -> None:
         def classify_callback(batch: int) -> None:
             self.update_progress_bar.emit(
                 "Classifying cells",
-                # Default cellfinder-core batch size is 32. This seems to give
+                # Default cellfinder-core batch size is 64. This seems to give
                 # a slight underestimate of the number of batches though, so
                 # allow for batch number to go over this
-                max(self.npoints_detected // 32 + 1, batch + 1),
+                max(self.npoints_detected // 64 + 1, batch + 1),
                 batch + 1,
             )

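The progress-bar maximum is only an estimate of how many batches classification will run. A quick check of the arithmetic with the new divisor (the point count is arbitrary):

    npoints_detected = 1000
    estimated_batches = npoints_detected // 64 + 1  # 16 batches at a batch size of 64
    # The callback reports max(estimated_batches, batch + 1), so the bar can keep
    # growing if more batches than estimated are actually run.
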
