
[xla:cpu:xnn] Use persistent workers to execute pthreadpool parallel loops

PiperOrigin-RevId: 716750569
ezhulenev authored and Google-ML-Automation committed Jan 17, 2025
1 parent 4864a98 commit 1161eb1
Showing 5 changed files with 319 additions and 112 deletions.
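
Before the file-by-file diffs, a brief orientation: instead of pre-slicing an N-task loop into fixed-size chunks (the old ComputeParallelTaskConfig path), the runner now creates one work-queue partition per worker, and each persistent worker drains its own partition with a relaxed atomic counter before walking the other partitions to steal leftover tasks. The sketch below is a minimal, self-contained illustration of that idea under simplifying assumptions: plain std::thread stands in for the Eigen thread pool, the empty-queue fast path and the co-prime starting-partition trick are omitted, and all names are illustrative rather than taken from XLA.

#include <algorithm>
#include <atomic>
#include <cstddef>
#include <optional>
#include <thread>
#include <vector>

class WorkQueue {
 public:
  WorkQueue(size_t num_tasks, size_t num_partitions)
      : partitions_(num_partitions) {
    // Split [0, num_tasks) into num_partitions contiguous ranges.
    size_t size = (num_tasks + num_partitions - 1) / num_partitions;
    for (size_t i = 0, begin = 0; i < num_partitions; ++i, begin += size) {
      partitions_[i].index.store(std::min(num_tasks, begin));
      partitions_[i].end = std::min(num_tasks, begin + size);
    }
  }

  // Claims the next task in `partition`; nullopt once it is drained.
  std::optional<size_t> Pop(size_t partition) {
    Partition& p = partitions_[partition];
    size_t index = p.index.fetch_add(1, std::memory_order_relaxed);
    return index >= p.end ? std::nullopt : std::make_optional(index);
  }

  size_t num_partitions() const { return partitions_.size(); }

 private:
  struct Partition {
    std::atomic<size_t> index{0};  // next unclaimed task index
    size_t end = 0;                // one past the last task of the range
  };
  std::vector<Partition> partitions_;
};

int main() {
  constexpr size_t kNumTasks = 1000;
  constexpr size_t kNumWorkers = 4;  // real runner: min(num_tasks, num_threads())

  WorkQueue queue(kNumTasks, kNumWorkers);
  std::vector<int> done(kNumTasks, 0);  // each slot written by one thread only

  std::vector<std::thread> workers;
  for (size_t w = 0; w < kNumWorkers; ++w) {
    workers.emplace_back([&queue, &done, w] {
      // Drain our own partition first, then steal from the others in order.
      for (size_t step = 0; step < queue.num_partitions(); ++step) {
        size_t p = (w + step) % queue.num_partitions();
        while (std::optional<size_t> task = queue.Pop(p)) {
          done[*task] += 1;  // "execute" task *task
        }
      }
    });
  }
  for (std::thread& t : workers) t.join();
  // Invariant: done[i] == 1 for every i -- each task ran exactly once.
}

Because every task index is claimed by exactly one fetch_add, each task runs exactly once no matter how the workers interleave, and stealing only begins after a worker's own partition is exhausted, which keeps the atomic counters mostly uncontended.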
3 changes: 2 additions & 1 deletion xla/backends/cpu/runtime/xnnpack/BUILD
@@ -48,6 +48,7 @@ cc_library(
         "//xla/tsl/lib/math:math_util",
         "//xla/tsl/platform:logging",
         "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:fixed_array",
         "@eigen_archive//:eigen3",
     ],
 )
@@ -61,11 +62,11 @@ xla_cc_test(
         "//xla/tsl/platform:env",
         "//xla/tsl/platform:test",
         "//xla/tsl/platform:test_benchmark",
-        "//xla/tsl/platform:test_main",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/cleanup",
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
+        "@com_google_googletest//:gtest_main",
         "@eigen_archive//:eigen3",
     ],
 )
237 changes: 147 additions & 90 deletions xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc
@@ -16,10 +16,12 @@ limitations under the License.
 #include "xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h"
 
 #include <algorithm>
+#include <atomic>
 #include <cstddef>
 #include <cstdint>
 #include <functional>
 #include <limits>
+#include <optional>
 #include <utility>
 
 #include "absl/base/optimization.h"
@@ -66,73 +68,158 @@ tsl::AsyncValueRef<tsl::Chain> ParallelLoopRunner::TakeDoneEvent(
   return std::move(runner.done_event_);
 }
 
-ParallelLoopRunner::ParallelTaskConfig
-ParallelLoopRunner::ComputeParallelTaskConfig(size_t num_tasks) const {
-  // We limit the number of parallel tasks per thread to avoid excessive task
-  // scheduling overheads at run time.
-  static constexpr size_t kMaxTasksPerThread = 4;
+void ParallelLoopRunner::WorkQueue::Partition::Initialize(size_t begin,
+                                                          size_t end) {
+  index.store(begin, std::memory_order_relaxed);
+  this->begin = begin;
+  this->end = end;
+}
 
-  size_t parallel_task_size =
-      tsl::MathUtil::CeilOfRatio(num_tasks, kMaxTasksPerThread * num_threads());
-  size_t num_parallel_tasks =
-      tsl::MathUtil::CeilOfRatio(num_tasks, parallel_task_size);
-
-  return {num_tasks, parallel_task_size, num_parallel_tasks};
+ParallelLoopRunner::WorkQueue::WorkQueue(size_t num_tasks,
+                                         size_t num_partitions)
+    : partitions_(num_partitions), empty_(num_tasks == 0) {
+  size_t partition_size = tsl::MathUtil::CeilOfRatio(num_tasks, num_partitions);
+  for (size_t i = 0, begin = 0, end = partition_size; i < num_partitions;
+       ++i, begin = end, end += partition_size) {
+    partitions_[i].Initialize(std::min(num_tasks, begin),
+                              std::min(num_tasks, end));
+  }
+}
+
+std::optional<size_t> ParallelLoopRunner::WorkQueue::Pop(
+    size_t partition_index) {
+  DCHECK(partition_index < partitions_.size()) << "Invalid partition index";
+  Partition& partition = partitions_[partition_index];
+
+  size_t index = partition.index.fetch_add(1, std::memory_order_relaxed);
+  return index >= partition.end ? std::nullopt : std::make_optional(index);
+}
+
+ParallelLoopRunner::Worker::Worker(size_t worker_index, WorkQueue* queue)
+    : worker_index_(worker_index),
+      partition_index_(worker_index),
+      queue_(queue) {}
+
+std::optional<size_t> ParallelLoopRunner::Worker::Pop() {
+  std::optional<size_t> task = queue_->Pop(partition_index_);
+  if (task) return task;
+
+  while (!task.has_value()) {
+    // Wrap around to the first partition.
+    if (ABSL_PREDICT_FALSE(++partition_index_ >= queue_->num_partitions())) {
+      partition_index_ = 0;
+    }
+
+    // We checked all partitions and got back to the partition we started from.
+    if (ABSL_PREDICT_FALSE(partition_index_ == worker_index_)) {
+      queue_->empty_.store(true, std::memory_order_relaxed);
+      break;
+    }
+
+    task = queue_->Pop(partition_index_);
+  }
+
+  return task;
 }
 
 template <typename Index, typename ParallelizeContext>
 static void Parallelize(ParallelizeContext* ctx, Index start_index,
                         Index end_index) {
-  CHECK_LT(start_index, end_index) << "Invalid task index range";  // Crash OK
+  CHECK_LT(start_index, end_index) << "Invalid worker index range";
+
+  auto count_down = [&](size_t count) {
+    // If count down is completed, delete the context.
+    if (ctx->count_down.CountDown(count)) delete ctx;
+  };
 
-  // Recursively split the task into two halves and schedule the right half into
-  // the thread pool.
+  // Recursively split assigned workers into two halves and schedule the
+  // right half into the thread pool.
   while (end_index - start_index > 1) {
-    Index mid_index = (start_index + end_index) / 2;
-    ctx->device->enqueueNoNotification([ctx, mid_index, end_index] {
-      Parallelize(ctx, mid_index, end_index);
+    // If work queue is empty, we don't need to keep enqueuing more workers and
+    // can simply count down for the remaining workers.
+    if (ABSL_PREDICT_FALSE(ctx->work_queue.empty())) {
+      count_down(end_index - start_index);
+      return;
+    }
+
+    Index mid_partition = (start_index + end_index) / 2;
+    ctx->device->enqueueNoNotification([ctx, mid_partition, end_index] {
+      Parallelize(ctx, mid_partition, end_index);
     });
-    end_index = mid_index;
+    end_index = mid_partition;
   }
 
-  // Execute the `start_index` task in the caller thread.
-  ctx->parallel_task(start_index);
-
-  // If count down is completed, delete the context.
-  if (ctx->count_down.CountDown()) {
-    delete ctx;
-  }
+  // Execute the `start_index` worker in the caller thread.
+  size_t worker_partition = ctx->coprime * start_index % ctx->num_workers;
+  ParallelLoopRunner::Worker worker(worker_partition, &ctx->work_queue);
+  while (std::optional<size_t> task = worker.Pop()) {
+    ctx->parallel_task(*task);
+  }
+
+  // Count down for the one executed worker.
+  count_down(1);
 }
 
+// We use "random" walk to assign work queue partitions to workers, to minimize
+// the conflicts between workers that run out of tasks in the assigned partition
+// and start stealing tasks from the next partition.
+//
+// https://lemire.me/blog/2017/09/18/visiting-all-values-in-an-array-exactly-once-in-random-order/
+static size_t CoPrime(size_t num_workers) {
+  for (size_t i = num_workers; i > 1; --i) {
+    if (tsl::MathUtil::GCD(i, num_workers) == 1) return i;
+  }
+  return 1;
+}
+
 template <typename ParallelTask>
 void ParallelLoopRunner::Parallelize(
-    tsl::CountDownAsyncValueRef<tsl::Chain> count_down, size_t start_index,
-    size_t end_index, ParallelTask&& parallel_task) {
-  CHECK_LT(start_index, end_index) << "Invalid task index range";  // Crash OK
+    tsl::CountDownAsyncValueRef<tsl::Chain> count_down, size_t num_workers,
+    size_t num_tasks, ParallelTask&& parallel_task) {
+  DCHECK_EQ(count_down.count(), num_workers)
+      << "Number of workers must match the count down counter";
+
+  // Short-circuit single-threaded execution.
+  if (ABSL_PREDICT_FALSE(num_workers == 1)) {
+    for (size_t i = 0; i < num_tasks; ++i) {
+      parallel_task(i);
+    }
+    count_down.CountDown();
+    return;
+  }
 
   struct ParallelizeContext {
-    ParallelizeContext(tsl::CountDownAsyncValueRef<tsl::Chain> count_down,
-                       const Eigen::ThreadPoolDevice* device,
+    ParallelizeContext(const Eigen::ThreadPoolDevice* device,
+                       tsl::CountDownAsyncValueRef<tsl::Chain> count_down,
+                       size_t num_workers, size_t num_tasks,
                        ParallelTask&& parallel_task)
-        : count_down(std::move(count_down)),
-          device(device),
+        : device(device),
+          num_workers(num_workers),
+          coprime(CoPrime(num_workers)),
+          work_queue(num_tasks, /*num_partitions=*/num_workers),
+          count_down(std::move(count_down)),
           parallel_task(std::forward<ParallelTask>(parallel_task)) {}
 
-    tsl::CountDownAsyncValueRef<tsl::Chain> count_down;
     const Eigen::ThreadPoolDevice* device;
+
+    size_t num_workers;
+    size_t coprime;
+    WorkQueue work_queue;
+
+    tsl::CountDownAsyncValueRef<tsl::Chain> count_down;
     ParallelTask parallel_task;
   };
 
   auto ctx = std::make_unique<ParallelizeContext>(
-      std::move(count_down), device_,
+      device_, std::move(count_down), num_workers, num_tasks,
       std::forward<ParallelTask>(parallel_task));
 
   // We try to use uint16_t for index type because it enables small buffer
   // optimization in the constructed `std::function` tasks.
-  if (ABSL_PREDICT_TRUE(end_index <= std::numeric_limits<uint16_t>::max())) {
-    xla::cpu::Parallelize<uint16_t>(ctx.release(), start_index, end_index);
+  if (ABSL_PREDICT_TRUE(num_tasks <= std::numeric_limits<uint16_t>::max())) {
+    xla::cpu::Parallelize<uint16_t>(ctx.release(), 0, num_workers);
   } else {
-    xla::cpu::Parallelize<size_t>(ctx.release(), start_index, end_index);
+    xla::cpu::Parallelize<size_t>(ctx.release(), 0, num_workers);
   }
 }
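
A note on the fan-out pattern in the Parallelize template above: the caller keeps the left half of the worker index range and enqueues only the right half, so reaching N workers takes O(log N) sequential enqueues on the caller's critical path while the enqueued halves keep splitting in parallel. Below is a tiny stand-alone sketch of just the splitting recursion, with the pool enqueue replaced by a direct call so the visit order is observable (illustrative code, not from the commit):

#include <cstddef>
#include <functional>
#include <iostream>

// Keep the left half of [start, end), hand off the right half. The real
// runner hands it to ctx->device->enqueueNoNotification(...); calling it
// directly keeps this sketch single-threaded and deterministic.
void FanOut(size_t start, size_t end, const std::function<void(size_t)>& body) {
  while (end - start > 1) {
    size_t mid = (start + end) / 2;
    FanOut(mid, end, body);
    end = mid;
  }
  body(start);  // run worker `start` on the current thread
}

int main() {
  FanOut(0, 8, [](size_t worker) { std::cout << worker << ' '; });
  std::cout << '\n';  // prints "7 6 5 4 3 2 1 0": each worker exactly once
}

The empty() check the commit adds inside this loop is a shortcut for exactly this structure: once the queue is drained, whole subtrees of not-yet-scheduled workers are retired with a single count_down(count) call instead of being enqueued just to find no work.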

@@ -149,14 +236,19 @@ void ParallelLoopRunner::ScheduleOne(Task&& task) {
 template <typename ParallelTask>
 void ParallelLoopRunner::ScheduleAll(size_t num_tasks,
                                      ParallelTask&& parallel_task) {
-  tsl::CountDownAsyncValueRef<tsl::Chain> count_down(num_tasks);
+  // We use at most `num_threads()` workers as we can't run more parallel
+  // workers than the number of threads in the thread pool.
+  size_t num_workers = std::min(num_tasks, num_threads());
+
+  tsl::CountDownAsyncValueRef<tsl::Chain> count_down(num_workers);
   auto count_down_done = count_down.AsRef();
 
-  done_event_.AndThen([this, num_tasks, count_down = std::move(count_down),
-                       parallel_task =
-                           std::forward<ParallelTask>(parallel_task)] {
-    Parallelize(std::move(count_down), 0, num_tasks, std::move(parallel_task));
-  });
+  done_event_.AndThen(
+      [this, num_workers, num_tasks, count_down = std::move(count_down),
+       parallel_task = std::forward<ParallelTask>(parallel_task)] {
+        Parallelize(std::move(count_down), num_workers, num_tasks,
+                    std::move(parallel_task));
+      });
   done_event_ = std::move(count_down_done);
 }

Expand Down Expand Up @@ -187,13 +279,6 @@ struct Task3DTile2DIndex {

} // namespace

auto ParallelLoopRunner::ParallelTaskConfig::ParallelTaskRange(
size_t parallel_task_index) const -> TaskRange {
size_t begin = parallel_task_index * parallel_task_size;
size_t end = std::min(num_tasks, begin + parallel_task_size);
return {begin, end};
}

static Task1DTile1DIndex Delinearize(size_t task_index, size_t range,
size_t tile) {
size_t offset = task_index * tile;
Expand Down Expand Up @@ -289,16 +374,7 @@ void ParallelLoopRunner::Parallelize(size_t range, Task1D task) {
return;
}

// Schedule `parallel_config.num_parallel_tasks` into the underlying thread
// pool when done event becomes available.
auto parallel_config = ComputeParallelTaskConfig(range);
auto parallel_task = [parallel_config,
task = std::move(task)](size_t parallel_task_index) {
auto [begin, end] = parallel_config.ParallelTaskRange(parallel_task_index);
for (size_t i = begin; i < end; ++i) task(i);
};

ScheduleAll(parallel_config.num_parallel_tasks, std::move(parallel_task));
ScheduleAll(range, std::move(task));
}

void ParallelLoopRunner::Parallelize(size_t range, size_t tile,
@@ -323,19 +399,13 @@ void ParallelLoopRunner::Parallelize(size_t range, size_t tile,
     return;
   }
 
-  // Schedule `parallel_config.num_parallel_tasks` into the underlying thread
-  // pool when done event becomes available.
-  auto parallel_config = ComputeParallelTaskConfig(num_tasks);
-  auto parallel_task = [range, tile, parallel_config,
-                        task = std::move(task)](size_t parallel_task_index) {
-    auto [begin, end] = parallel_config.ParallelTaskRange(parallel_task_index);
-    for (size_t i = begin; i < end; ++i) {
-      auto x = Delinearize(i, range, tile);
-      task(x.offset, x.extent);
-    }
+  auto parallel_task = [range, tile,
+                        task = std::move(task)](size_t task_index) {
+    auto x = Delinearize(task_index, range, tile);
+    task(x.offset, x.extent);
   };
 
-  ScheduleAll(parallel_config.num_parallel_tasks, std::move(parallel_task));
+  ScheduleAll(num_tasks, std::move(parallel_task));
 }
 
 void ParallelLoopRunner::Parallelize(size_t range_i, size_t range_j,
@@ -358,19 +428,13 @@ void ParallelLoopRunner::Parallelize(size_t range_i, size_t range_j,
     return;
  }
 
-  // Schedule `parallel_config.num_parallel_tasks` into the underlying thread
-  // pool when done event becomes available.
-  auto parallel_config = ComputeParallelTaskConfig(num_tasks);
-  auto parallel_task = [range_i, range_j, tile_j, parallel_config,
-                        task = std::move(task)](size_t parallel_task_index) {
-    auto [begin, end] = parallel_config.ParallelTaskRange(parallel_task_index);
-    for (size_t i = begin; i < end; ++i) {
-      auto x = Delinearize(i, range_i, range_j, tile_j);
-      task(x.i, x.offset_j, x.extent_j);
-    }
+  auto parallel_task = [range_i, range_j, tile_j,
+                        task = std::move(task)](size_t task_index) {
+    auto x = Delinearize(task_index, range_i, range_j, tile_j);
+    task(x.i, x.offset_j, x.extent_j);
   };
 
-  ScheduleAll(parallel_config.num_parallel_tasks, std::move(parallel_task));
+  ScheduleAll(num_tasks, std::move(parallel_task));
 }
 
 void ParallelLoopRunner::Parallelize(size_t range_i, size_t range_j,
@@ -397,20 +461,13 @@ void ParallelLoopRunner::Parallelize(size_t range_i, size_t range_j,
     return;
   }
 
-  // Schedule `parallel_config.num_parallel_tasks` into the underlying thread
-  // pool when done event becomes available.
-  auto parallel_config = ComputeParallelTaskConfig(num_tasks);
   auto parallel_task = [range_i, range_j, range_k, tile_j, tile_k,
-                        parallel_config,
-                        task = std::move(task)](size_t parallel_task_index) {
-    auto [begin, end] = parallel_config.ParallelTaskRange(parallel_task_index);
-    for (size_t i = begin; i < end; ++i) {
-      auto x = Delinearize(i, range_i, range_j, range_k, tile_j, tile_k);
-      task(x.i, x.offset_j, x.offset_k, x.extent_j, x.extent_k);
-    }
+                        task = std::move(task)](size_t task_index) {
+    auto x = Delinearize(task_index, range_i, range_j, range_k, tile_j, tile_k);
+    task(x.i, x.offset_j, x.offset_k, x.extent_j, x.extent_k);
   };
 
-  ScheduleAll(parallel_config.num_parallel_tasks, std::move(parallel_task));
+  ScheduleAll(num_tasks, std::move(parallel_task));
 }
 
 }  // namespace xla::cpu
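
One detail of the new Parallelize worth unpacking: worker w starts on partition coprime * w % num_workers, and the CoPrime helper guarantees gcd(coprime, num_workers) == 1, so by the argument in the linked Lemire post the mapping is a bijection on {0, ..., num_workers - 1} and no two workers begin on the same partition. A small stand-alone check of that property, assuming std::gcd from <numeric> as a stand-in for tsl::MathUtil::GCD:

#include <cstddef>
#include <iostream>
#include <numeric>

// Largest i <= n with gcd(i, n) == 1 (falls back to 1), mirroring the
// CoPrime helper in the diff above.
size_t CoPrime(size_t n) {
  for (size_t i = n; i > 1; --i) {
    if (std::gcd(i, n) == 1) return i;
  }
  return 1;
}

int main() {
  const size_t num_workers = 10;
  const size_t coprime = CoPrime(num_workers);  // 9 for num_workers == 10
  std::cout << "start partitions:";
  for (size_t w = 0; w < num_workers; ++w) {
    std::cout << ' ' << coprime * w % num_workers;
  }
  std::cout << '\n';  // prints a permutation of 0..9: 0 9 8 7 6 5 4 3 2 1
}

Scattering the start points this way means a worker that finishes early begins stealing far from its neighbors, which is exactly the contention-avoidance rationale given in the comment above CoPrime.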