Skip to content

Commit

Permalink
[cpp] Start metal compute sim3d engine
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiashienzsch committed Sep 27, 2024
1 parent 3f88bb8 commit 3aabd93
Show file tree
Hide file tree
Showing 8 changed files with 481 additions and 214 deletions.
1 change: 1 addition & 0 deletions src/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ endif()

if(PFFDTD_HAS_METAL)
target_compile_definitions(pffdtd PUBLIC PFFDTD_HAS_METAL=1)
target_sources(pffdtd PRIVATE pffdtd/engine_metal_2d.hpp pffdtd/engine_metal_2d.mm)
target_sources(pffdtd PRIVATE pffdtd/engine_metal_3d.hpp pffdtd/engine_metal_3d.mm)

set(SHADER_FILES "${CMAKE_CURRENT_SOURCE_DIR}/pffdtd/engine_metal.metal")
Expand Down
4 changes: 4 additions & 0 deletions src/cpp/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#endif

#if defined(PFFDTD_HAS_METAL)
#include "pffdtd/engine_metal_2d.hpp"
#include "pffdtd/engine_metal_3d.hpp"
#endif

Expand All @@ -37,6 +38,9 @@ namespace {
using Callback = std::function<stdex::mdarray<double, stdex::dextents<size_t, 2>>(Simulation2D const&)>;
auto engines = std::map<std::string, Callback>{};
engines["native"] = pffdtd::EngineCPU2D{};
#if defined(PFFDTD_HAS_METAL)
engines["metal"] = pffdtd::EngineMETAL2D{};
#endif
#if defined(PFFDTD_HAS_SYCL)
engines["sycl"] = pffdtd::EngineSYCL2D{};
#endif
Expand Down
14 changes: 8 additions & 6 deletions src/cpp/pffdtd/engine_metal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,20 @@ namespace pffdtd {

// Parameter block handed to the sim2d Metal kernels, which bind it as
// `constant Constants2D<float>&`.
// NOTE(review): presumably the host fills an identically laid-out struct, so
// the field order and types should not be changed without confirming both sides.
template<typename Real>
struct Constants2D {
// presumably the grid extent along x — TODO confirm against the host code
long Nx;
// presumably the grid extent along y — TODO confirm against the host code
long Ny;
// Row stride of the output buffer in sim2d::readOutput (out[id * Nt + timestep]);
// presumably the number of time steps — confirm.
long Nt;
// Flat index into u0 at which sim2d::addSource adds the source sample.
long in_ixy;
// NOTE(review): presumably the loss coefficient used by the lossy boundary kernel — confirm.
Real lossFactor;
};

template<typename Real>
struct Constants3D {
long n;
Real l;
Real lo2;
Real sl2;
Real a1;
Real a2;
long Nx;
long Ny;
long Nz;
Expand All @@ -28,11 +35,6 @@ struct Constants3D {
long Ns;
long Nr;
long Nt;
Real l;
Real lo2;
Real sl2;
Real a1;
Real a2;
};

template<typename T>
Expand Down
97 changes: 54 additions & 43 deletions src/cpp/pffdtd/engine_metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace sim2d {
device float const* u1 [[buffer(1)]],
device float const* u2 [[buffer(2)]],
device uint8_t const* in_mask [[buffer(3)]],
constant Constants2D<float> const& constants [[buffer(4)]],
constant Constants2D<float>& constants [[buffer(4)]],
uint2 id [[thread_position_in_grid]]
) {
int64_t const x = id.x + 1;
Expand All @@ -39,7 +39,7 @@ namespace sim2d {
device float const* u2 [[buffer(2)]],
device int64_t const* bn_ixy [[buffer(3)]],
device int64_t const* adj_bn [[buffer(4)]],
constant Constants2D<float> const& constants [[buffer(5)]],
constant Constants2D<float>& constants [[buffer(5)]],
uint id [[thread_position_in_grid]]
) {
int64_t const ib = bn_ixy[id];
Expand All @@ -62,7 +62,7 @@ namespace sim2d {
device float const* u2 [[buffer(1)]],
device int64_t const* bn_ixy [[buffer(2)]],
device int64_t const* adj_bn [[buffer(3)]],
constant Constants2D<float> const& constants [[buffer(4)]],
constant Constants2D<float>& constants [[buffer(4)]],
uint id [[thread_position_in_grid]]
) {
int64_t const ib = bn_ixy[id];
Expand All @@ -76,6 +76,27 @@ namespace sim2d {
u0[ib] = (current + lossFactor * K4 * prev) / (1.0 + lossFactor * K4);
}

// Adds one sample of the source signal to the field buffer.
//
//   u0         field buffer, updated in place
//   src_sig    source signal, indexed by time step
//   constants  supplies in_ixy, the flat index in u0 that receives the sample
//   timestep   index of the sample to read from src_sig
//   id         thread position in the grid (not used)
//
// NOTE(review): because `id` is not used, every dispatched thread performs the
// same `+=` on the same element; presumably the host dispatches exactly one
// thread for this kernel — confirm.
[[kernel]] void addSource(
    device float* u0 [[buffer(0)]],
    device float const* src_sig [[buffer(1)]],
    constant Constants2D<float>& constants [[buffer(2)]],
    constant int64_t& timestep [[buffer(3)]],
    uint id [[thread_position_in_grid]]
) {
    auto const target = constants.in_ixy;
    auto const sample = src_sig[timestep];
    u0[target] += sample;
}

// Copies the current field value at one output location into the output buffer.
//
//   out        output buffer; thread `id` writes element id * Nt + timestep
//   u0         field buffer that is sampled
//   out_ixy    per-thread flat index into u0
//   constants  supplies Nt, used as the row stride of `out`
//   timestep   column written within the row of this thread
//   id         thread position in the grid; selects the entry of out_ixy and the row of `out`
[[kernel]] void readOutput(
    device float* out [[buffer(0)]],
    device float const* u0 [[buffer(1)]],
    device int64_t const* out_ixy [[buffer(2)]],
    constant Constants2D<float>& constants [[buffer(3)]],
    constant int64_t& timestep [[buffer(4)]],
    uint id [[thread_position_in_grid]]
) {
    auto const from = out_ixy[id];
    auto const to   = id * constants.Nt + timestep;
    out[to] = u0[from];
}

} // namespace sim2d

namespace sim3d {
Expand All @@ -84,46 +105,45 @@ namespace sim3d {
device float* u0 [[buffer(0)]],
device float const* u1 [[buffer(1)]],
device uint8_t const* bn_mask [[buffer(2)]],
constant Constants3D<float> const& constants [[buffer(3)]],
constant Constants3D<float>& constants [[buffer(3)]],
uint3 id [[thread_position_in_grid]]
) {
float const a1 = constants.a1;
float const a2 = constants.a2;

int64_t const Nz = constants.Nz;
int64_t const NzNy = constants.NzNy;
float const a1 = constants.a1;
float const a2 = constants.a2;

int64_t const ix = id.x + 1;
int64_t const iy = id.y + 1;
int64_t const iz = id.z + 1;
int64_t const ii = ix * NzNy + iy * Nz + iz;
int64_t const x = id.x + 1;
int64_t const y = id.y + 1;
int64_t const z = id.z + 1;
int64_t const i = x * NzNy + y * Nz + z;

if (get_bit<int>(bn_mask[ii >> 3], ii % 8) != 0) {
if (get_bit<int>(bn_mask[i / 8], i % 8) != 0) {
return;
}

float partial = a1 * u1[ii] - u0[ii];
partial += a2 * u1[ii + NzNy];
partial += a2 * u1[ii - NzNy];
partial += a2 * u1[ii + Nz];
partial += a2 * u1[ii - Nz];
partial += a2 * u1[ii + 1];
partial += a2 * u1[ii - 1];
u0[ii] = partial;

// u0[ii] = a1 + a2;
float partial = a1 * u1[i] - u0[i];
partial += a2 * u1[i + NzNy];
partial += a2 * u1[i - NzNy];
partial += a2 * u1[i + Nz];
partial += a2 * u1[i - Nz];
partial += a2 * u1[i + 1];
partial += a2 * u1[i - 1];
u0[i] = partial;
}

[[kernel]] void rigidUpdateCart(
device float* u0 [[buffer(0)]],
device float const* u1 [[buffer(1)]],
device int64_t const* bn_ixyz [[buffer(2)]],
device uint16_t const* adj_bn [[buffer(3)]],
constant Constants3D<float> const& constants [[buffer(4)]],
constant Constants3D<float>& constants [[buffer(4)]],
uint id [[thread_position_in_grid]]
) {
int64_t const ii = bn_ixyz[id];
auto const adj = adj_bn[id];
auto const Kint = metal::popcount(adj);
int64_t const ii = bn_ixyz[id];

float const _2 = 2.0;
float const K = Kint;
Expand All @@ -150,7 +170,7 @@ namespace sim3d {
device float const* mat_beta [[buffer(6)]],
device MatQuad<float> const* mat_quads [[buffer(7)]],
device uint8_t const* Mb [[buffer(8)]],
constant Constants3D<float> const& constants [[buffer(9)]],
constant Constants3D<float>& constants [[buffer(9)]],
uint id [[thread_position_in_grid]]
) {
auto nb = static_cast<int64_t>(id);
Expand Down Expand Up @@ -194,7 +214,7 @@ namespace sim3d {
device float const* u2ba [[buffer(1)]],
device float const* Q_bna [[buffer(2)]],
device int64_t const* bna_ixyz [[buffer(3)]],
constant Constants3D<float> const& constants [[buffer(4)]],
constant Constants3D<float>& constants [[buffer(4)]],
uint id [[thread_position_in_grid]]
) {
auto const lQ = constants.l * Q_bna[id];
Expand All @@ -204,7 +224,7 @@ namespace sim3d {

[[kernel]] void flipHaloXY(
device float* u1 [[buffer(0)]],
constant Constants3D<float> const& constants [[buffer(1)]],
constant Constants3D<float>& constants [[buffer(1)]],
uint2 id [[thread_position_in_grid]]
) {
auto const Nz = constants.Nz;
Expand All @@ -220,7 +240,7 @@ namespace sim3d {

[[kernel]] void flipHaloXZ(
device float* u1 [[buffer(0)]],
constant Constants3D<float> const& constants [[buffer(1)]],
constant Constants3D<float>& constants [[buffer(1)]],
uint2 id [[thread_position_in_grid]]
) {
auto const Ny = constants.Ny;
Expand All @@ -236,7 +256,7 @@ namespace sim3d {

[[kernel]] void flipHaloYZ(
device float* u1 [[buffer(0)]],
constant Constants3D<float> const& constants [[buffer(1)]],
constant Constants3D<float>& constants [[buffer(1)]],
uint2 id [[thread_position_in_grid]]
) {
auto const Nx = constants.Nx;
Expand Down Expand Up @@ -272,33 +292,24 @@ namespace sim3d {
device float* u0 [[buffer(0)]],
device float const* in_sigs [[buffer(1)]],
device int64_t const* in_ixyz [[buffer(2)]],
constant Constants3D<float> const& constants [[buffer(3)]],
constant Constants3D<float>& constants [[buffer(3)]],
constant int64_t& timestep [[buffer(4)]],
uint id [[thread_position_in_grid]]
) {
auto const ii = in_ixyz[id];
u0[ii] += in_sigs[id * constants.Nt + constants.n];
u0[in_ixyz[id]] += in_sigs[id * constants.Nt + timestep];
}

[[kernel]] void readOutput(
device float* u_out [[buffer(0)]],
device float const* u1 [[buffer(1)]],
device int64_t const* out_ixyz [[buffer(2)]],
constant Constants3D<float> const& constants [[buffer(3)]],
constant Constants3D<float>& constants [[buffer(3)]],
constant int64_t& timestep [[buffer(4)]],
uint id [[thread_position_in_grid]]
) {
auto const ii = out_ixyz[id];
u_out[id * constants.Nt + constants.n] = u1[ii];
u_out[id * constants.Nt + timestep] = u1[out_ixyz[id]];
}

} // namespace sim3d

// Placeholder kernel: writes in[id] + 42.0 to out[id], one element per thread.
[[kernel]] void foo(
    device float const* in [[buffer(0)]],
    device float* out [[buffer(1)]],
    uint id [[thread_position_in_grid]]
) {
    auto const value = in[id];
    out[id] = value + 42.0;
}

// Placeholder kernel: writes in[id] - 1.0 to out[id], one element per thread.
[[kernel]] void bar(
    device float const* in [[buffer(0)]],
    device float* out [[buffer(1)]],
    uint id [[thread_position_in_grid]]
) {
    auto const value = in[id];
    out[id] = value - 1.0;
}
} // namespace pffdtd
19 changes: 19 additions & 0 deletions src/cpp/pffdtd/engine_metal_2d.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2024 Tobias Hienzsch

#pragma once

#if not defined(PFFDTD_HAS_METAL)
#error "METAL must be enabled in CMake via PFFDTD_ENABLE_METAL"
#endif

#include "pffdtd/mdspan.hpp"
#include "pffdtd/simulation_2d.hpp"

namespace pffdtd {

/// Callable entry point of the Metal 2D engine.
///
/// The call operator takes the simulation description by const reference and
/// returns a rank-2 array of doubles; the result is marked [[nodiscard]].
/// NOTE(review): the meaning of the two result extents is set by the
/// implementation, which is not visible here — presumably receivers x time
/// steps; confirm.
struct EngineMETAL2D {
[[nodiscard]] auto operator()(Simulation2D const& sim) const -> stdex::mdarray<double, stdex::dextents<size_t, 2>>;
};

} // namespace pffdtd
Loading

0 comments on commit 3aabd93

Please sign in to comment.