Commit 293e4e6
Start batch size search at 1/4 of GPU cores
StuartAbercrombie committed Dec 4, 2023
1 parent 7a745e3 commit 293e4e6
Showing 1 changed file with 16 additions and 12 deletions.
dorado/nn/MetalCRFModel.cpp: 28 changes (16 additions, 12 deletions)
@@ -634,7 +634,7 @@ class MetalCaller {
         spdlog::debug("Physical memory available {} GB", physical_memory / (size_t{1} << 30));
 
         // Constrain the maximum batch size to use about half physical memory for decode buffers,
-        // with neural network GPU buffers and CPU buffers are assumed to occupy a subset of the
+        // with neural network GPU buffers and CPU buffers assumed to occupy a subset of the
         // remaining memory. This generally constrains the batch size to use fewer than
         // the maximum GPU cores when running sup models on systems with a large GPU core
         // to system memory ratio.
@@ -645,25 +645,29 @@ class MetalCaller {
                 static_cast<size_t>(m_states) * sizeof(int16_t) +  // Posts
                 static_cast<size_t>(m_states) * sizeof(float));    // Back guides.
         spdlog::debug("decode_buffer_size_per_elem {}", decode_buffer_size_per_elem);
-        const int max_batch_size = std::min(
+        const int max_batch_size = std::clamp(
                 utils::pad_to(physical_memory / (2 * decode_buffer_size_per_elem),
                               static_cast<size_t>(MTL_CORE_BATCH_SIZE)),
+                static_cast<size_t>(MTL_CORE_BATCH_SIZE),
                 static_cast<size_t>(MTL_CORE_BATCH_SIZE * get_mtl_device_core_count()));
         spdlog::debug("max_batch_size {}", max_batch_size);
 
-        // Always try natural batch sizes for 1 GPU core and maximum we think is viable,
-        // which absent memory limits will be the full GPU core count.
-        const int min_batch_size = MTL_CORE_BATCH_SIZE;
-        std::set<int> test_batch_sizes{min_batch_size, max_batch_size};
+        // Subject to the above memory constraint, impose a minimum batch size
+        // that will use 1/4 of GPU cores for LSTM execution.
+        const int min_batch_size =
+                std::min(MTL_CORE_BATCH_SIZE * get_mtl_device_core_count() / 4, max_batch_size);
+        spdlog::debug("min_batch_size {}", min_batch_size);
+
+        std::set<int> test_batch_sizes{max_batch_size};
 
         // Add some batch sizes evenly distributed in between.
-        const int kNumIntermediateSizes = 16;
+        const int kNumSmallerSizes = 16;
         const float test_size_increment = static_cast<float>(max_batch_size - min_batch_size) /
-                                          static_cast<float>(kNumIntermediateSizes + 1);
-        for (int i = 0; i < kNumIntermediateSizes; ++i) {
-            const int test_batch_size = utils::pad_to(
-                    static_cast<int>(static_cast<float>(i + 1) * test_size_increment),
-                    MTL_CORE_BATCH_SIZE);
+                                          static_cast<float>(kNumSmallerSizes);
+        for (int i = 0; i < kNumSmallerSizes; ++i) {
+            const int test_batch_size =
+                    utils::pad_to(min_batch_size + static_cast<size_t>(i * test_size_increment),
+                                  static_cast<size_t>(MTL_CORE_BATCH_SIZE));
             test_batch_sizes.insert(test_batch_size);
         }
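
The hunks above change two things: max_batch_size is now clamped from below as well as above, and the benchmark's candidate list starts at a quarter of the GPU's cores instead of at a single core's batch. The self-contained C++ sketch below replays that selection logic so the resulting candidate set can be inspected in isolation. It is a sketch under assumptions, not dorado code: the device numbers (core count, physical memory, per-element decode-buffer cost) and the value of MTL_CORE_BATCH_SIZE are made-up placeholders, and pad_to is a local stand-in for dorado's utils::pad_to.

// Minimal standalone sketch of the batch-size candidate generation after this
// commit. All device numbers below are made-up placeholders: dorado reads the
// real core count and memory from the Metal device, and MTL_CORE_BATCH_SIZE is
// assumed here, not quoted from the source.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <set>

using std::size_t;

constexpr size_t MTL_CORE_BATCH_SIZE = 48;  // assumed per-GPU-core batch granularity

// Round n up to the next multiple of divisor (stand-in for utils::pad_to).
size_t pad_to(size_t n, size_t divisor) { return ((n + divisor - 1) / divisor) * divisor; }

int main() {
    const size_t core_count = 16;                      // placeholder GPU core count
    const size_t physical_memory = size_t{16} << 30;   // placeholder: 16 GiB
    const size_t decode_buffer_size_per_elem = 40000;  // placeholder per-element cost

    // Cap decode buffers at about half of physical memory, then clamp between one
    // core's batch and the whole-GPU batch (the std::min -> std::clamp change).
    const size_t max_batch_size =
            std::clamp(pad_to(physical_memory / (2 * decode_buffer_size_per_elem),
                              MTL_CORE_BATCH_SIZE),
                       MTL_CORE_BATCH_SIZE, MTL_CORE_BATCH_SIZE * core_count);

    // New minimum: a batch occupying 1/4 of GPU cores, never above the maximum.
    const size_t min_batch_size =
            std::min(MTL_CORE_BATCH_SIZE * core_count / 4, max_batch_size);

    // Seed with the maximum, then add 16 evenly spaced sizes starting at the
    // minimum, each padded to a whole number of cores; the set deduplicates.
    std::set<size_t> test_batch_sizes{max_batch_size};
    const int kNumSmallerSizes = 16;
    const float increment = static_cast<float>(max_batch_size - min_batch_size) /
                            static_cast<float>(kNumSmallerSizes);
    for (int i = 0; i < kNumSmallerSizes; ++i) {
        test_batch_sizes.insert(
                pad_to(min_batch_size + static_cast<size_t>(static_cast<float>(i) * increment),
                       MTL_CORE_BATCH_SIZE));
    }

    for (size_t batch_size : test_batch_sizes) {
        std::cout << batch_size << '\n';
    }
    return 0;
}

With these placeholder numbers the memory bound sits far above the 768-element whole-GPU batch, so the candidates run from 192 (4 of 16 cores) up to 768 in roughly 36-element steps, each padded to a multiple of 48; sizes below a quarter of the cores are no longer benchmarked.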

