diff --git a/dorado/nn/MetalCRFModel.cpp b/dorado/nn/MetalCRFModel.cpp
index 9232c46f5..a18913d24 100644
--- a/dorado/nn/MetalCRFModel.cpp
+++ b/dorado/nn/MetalCRFModel.cpp
@@ -634,7 +634,7 @@ class MetalCaller {
         spdlog::debug("Physical memory available {} GB", physical_memory / (size_t{1} << 30));
 
         // Constrain the maximum batch size to use about half physical memory for decode buffers,
-        // with neural network GPU buffers and CPU buffers are assumed to occupy a subset of the
+        // with neural network GPU buffers and CPU buffers assumed to occupy a subset of the
         // remaining memory. This generally constrains the batch size to use fewer than
         // the maximum GPU cores when running sup models on systems with a large GPU core
         // to system memory ratio.
@@ -645,25 +645,29 @@ class MetalCaller {
                  static_cast<size_t>(m_states) * sizeof(int16_t) +  // Posts
                  static_cast<size_t>(m_states) * sizeof(float));    // Back guides.
         spdlog::debug("decode_buffer_size_per_elem {}", decode_buffer_size_per_elem);
-        const int max_batch_size = std::min(
+        const int max_batch_size = std::clamp(
                 utils::pad_to(physical_memory / (2 * decode_buffer_size_per_elem),
                               static_cast<size_t>(MTL_CORE_BATCH_SIZE)),
+                static_cast<size_t>(MTL_CORE_BATCH_SIZE),
                 static_cast<size_t>(MTL_CORE_BATCH_SIZE * get_mtl_device_core_count()));
         spdlog::debug("max_batch_size {}", max_batch_size);
 
-        // Always try natural batch sizes for 1 GPU core and maximum we think is viable,
-        // which absent memory limits will be the full GPU core count.
-        const int min_batch_size = MTL_CORE_BATCH_SIZE;
-        std::set<int> test_batch_sizes{min_batch_size, max_batch_size};
+        // Subject to the above memory constraint, impose a minimum batch size
+        // that will use 1/4 of GPU cores for LSTM execution.
+        const int min_batch_size =
+                std::min(MTL_CORE_BATCH_SIZE * get_mtl_device_core_count() / 4, max_batch_size);
+        spdlog::debug("min_batch_size {}", min_batch_size);
+
+        std::set<int> test_batch_sizes{max_batch_size};
 
         // Add some batch sizes evenly distributed in between.
-        const int kNumIntermediateSizes = 16;
+        const int kNumSmallerSizes = 16;
         const float test_size_increment = static_cast<float>(max_batch_size - min_batch_size) /
-                                          static_cast<float>(kNumIntermediateSizes + 1);
-        for (int i = 0; i < kNumIntermediateSizes; ++i) {
-            const int test_batch_size = utils::pad_to(
-                    static_cast<int>(static_cast<float>(i + 1) * test_size_increment),
-                    MTL_CORE_BATCH_SIZE);
+                                          static_cast<float>(kNumSmallerSizes);
+        for (int i = 0; i < kNumSmallerSizes; ++i) {
+            const int test_batch_size =
+                    utils::pad_to(min_batch_size + static_cast<int>(i * test_size_increment),
+                                  static_cast<int>(MTL_CORE_BATCH_SIZE));
             test_batch_sizes.insert(test_batch_size);
         }
 
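
For reference, here is a minimal standalone sketch (not part of the patch) of the candidate-selection scheme the hunk introduces: a memory-derived cap on batch size, a minimum at roughly a quarter of the GPU cores, and 16 evenly spaced sizes padded to the per-core granularity. The function name `candidate_batch_sizes`, the local `pad_to` stand-in for `utils::pad_to`, `MTL_CORE_BATCH_SIZE = 48`, and the 16-core / 16 GiB example values are illustrative assumptions, not values taken from the dorado source.

```cpp
// Hypothetical sketch of the new test-batch-size selection; constants are assumed.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <set>

namespace {
constexpr int MTL_CORE_BATCH_SIZE = 48;  // assumed per-core batch granularity

// Round x up to the next multiple of `multiple` (stand-in for utils::pad_to).
size_t pad_to(size_t x, size_t multiple) {
    return ((x + multiple - 1) / multiple) * multiple;
}
}  // namespace

std::set<int> candidate_batch_sizes(size_t physical_memory,
                                    size_t decode_buffer_size_per_elem,
                                    int gpu_core_count) {
    // Cap the batch size so decode buffers use about half of physical memory,
    // but never go below one core's batch or above the full core count.
    const int max_batch_size = static_cast<int>(std::clamp(
            pad_to(physical_memory / (2 * decode_buffer_size_per_elem),
                   static_cast<size_t>(MTL_CORE_BATCH_SIZE)),
            static_cast<size_t>(MTL_CORE_BATCH_SIZE),
            static_cast<size_t>(MTL_CORE_BATCH_SIZE) * gpu_core_count));

    // Minimum batch size targets 1/4 of the GPU cores, subject to the cap above.
    const int min_batch_size =
            std::min(MTL_CORE_BATCH_SIZE * gpu_core_count / 4, max_batch_size);

    // Always test the maximum, plus 16 smaller sizes spaced evenly from the minimum.
    std::set<int> sizes{max_batch_size};
    const int kNumSmallerSizes = 16;
    const float increment =
            static_cast<float>(max_batch_size - min_batch_size) / kNumSmallerSizes;
    for (int i = 0; i < kNumSmallerSizes; ++i) {
        sizes.insert(static_cast<int>(
                pad_to(min_batch_size + static_cast<size_t>(i * increment),
                       static_cast<size_t>(MTL_CORE_BATCH_SIZE))));
    }
    return sizes;
}

int main() {
    // Example: 16 GiB of memory, ~5 MiB of decode buffer per batch element, 16 GPU cores.
    for (int b : candidate_batch_sizes(size_t{16} << 30, size_t{5} << 20, 16)) {
        std::cout << b << '\n';
    }
}
```

With these assumed inputs the cap works out to 768 (the full 16 cores), the minimum to 192 (4 cores), and the candidates land on multiples of 48 in between, which is the spacing behaviour the padded increment in the patch is meant to produce.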