Skip to content

Commit

Permalink
printer good
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Jul 26, 2023
1 parent 1b9b756 commit 1508a20
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 41 deletions.
30 changes: 15 additions & 15 deletions include/luisa/dsl/fmt_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,23 @@
#include <luisa/dsl/struct.h>


LUISA_DERIVE_FMT(luisa::compute::float2, x, y);
LUISA_DERIVE_FMT(luisa::compute::float3, x, y, z);
LUISA_DERIVE_FMT(luisa::compute::float4, x, y, z, w);
LUISA_DERIVE_FMT(luisa::compute::float2, float2,x, y);
LUISA_DERIVE_FMT(luisa::compute::float3, float3,x, y, z);
LUISA_DERIVE_FMT(luisa::compute::float4, float4,x, y, z, w);

LUISA_DERIVE_FMT(luisa::compute::int2, x, y);
LUISA_DERIVE_FMT(luisa::compute::int3, x, y, z);
LUISA_DERIVE_FMT(luisa::compute::int4, x, y, z, w);
LUISA_DERIVE_FMT(luisa::compute::int2, int2,x, y);
LUISA_DERIVE_FMT(luisa::compute::int3, int3,x, y, z);
LUISA_DERIVE_FMT(luisa::compute::int4, int4,x, y, z, w);

LUISA_DERIVE_FMT(luisa::compute::uint2, x, y);
LUISA_DERIVE_FMT(luisa::compute::uint3, x, y, z);
LUISA_DERIVE_FMT(luisa::compute::uint4, x, y, z, w);
LUISA_DERIVE_FMT(luisa::compute::uint2, uint2, x, y);
LUISA_DERIVE_FMT(luisa::compute::uint3, uint3, x, y, z);
LUISA_DERIVE_FMT(luisa::compute::uint4, uint4, x, y, z, w);

LUISA_DERIVE_FMT(luisa::compute::bool2, x, y);
LUISA_DERIVE_FMT(luisa::compute::bool3, x, y, z);
LUISA_DERIVE_FMT(luisa::compute::bool4, x, y, z, w);
LUISA_DERIVE_FMT(luisa::compute::bool2, bool2, x, y);
LUISA_DERIVE_FMT(luisa::compute::bool3, bool3, x, y, z);
LUISA_DERIVE_FMT(luisa::compute::bool4, bool4, x, y, z, w);


LUISA_DERIVE_FMT(luisa::compute::float2x2, cols[0], cols[1]);
LUISA_DERIVE_FMT(luisa::compute::float3x3, cols[0], cols[1], cols[2]);
LUISA_DERIVE_FMT(luisa::compute::float4x4, cols[0], cols[1], cols[2], cols[3]);
LUISA_DERIVE_FMT(luisa::compute::float2x2, float2x2, cols[0], cols[1]);
LUISA_DERIVE_FMT(luisa::compute::float3x3, float3x3, cols[0], cols[1], cols[2]);
LUISA_DERIVE_FMT(luisa::compute::float4x4, float4x4, cols[0], cols[1], cols[2], cols[3]);
25 changes: 2 additions & 23 deletions include/luisa/dsl/printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,6 @@ class LC_DSL_API Printer {
_buffer->write(offset + index + i, data[i]);
}
index += N;
// if constexpr (std::is_same_v<T, bool> || std::is_same_v<T, int> || std::is_same_v<T, uint>) {
// _buffer->write(offset + index, cast<uint>(curr));
// } else if constexpr (std::is_same_v<T, float>) {
// _buffer->write(offset + index, as<uint>(curr));
// } else {

// // static_assert(always_false_v<T>, "unsupported type for printing in kernel.");
// }
}
_log_to_buffer(offset, index, other...);
}
Expand Down Expand Up @@ -133,14 +125,7 @@ class LC_DSL_API Printer {
template<typename... Args>
void Printer::_log(luisa::log_level level, luisa::string fmt, const Args &...args) noexcept {
std::array<uint, sizeof...(Args)> count_per_arg{};
// for (constexpr int i = 0; i < sizeof...(Args); i++) {
// if constexpr (is_dsl_v<std::tuple_element_t<i, std::tuple<Args...>>>) {
// using T = expr_value_t<std::tuple_element_t<i, std::tuple<Args...>>>;
// count_per_arg[i] = (sizeof(T) + sizeof(uint) - 1) / sizeof(uint);
// } else {
// count_per_arg[i] = 0;
// }
// }

auto do_count = [&]<size_t... i>(std::index_sequence<i...>) noexcept {
auto impl = [&]<size_t j>() noexcept {
if constexpr (is_dsl_v<std::tuple_element_t<j, std::tuple<Args...>>>) {
Expand All @@ -158,7 +143,7 @@ void Printer::_log(luisa::log_level level, luisa::string fmt, const Args &...arg
if constexpr (sizeof...(Args) > 0) {
count_by_arg[0] = 0;
for (int i = 1; i < sizeof...(Args); i++) {
count_by_arg[i] = count_by_arg[i - 1] + count_per_arg[i];
count_by_arg[i] = count_by_arg[i - 1] + count_per_arg[i - 1];
}
}

Expand Down Expand Up @@ -193,12 +178,6 @@ void Printer::_log(luisa::log_level level, luisa::string fmt, const Args &...arg
T raw{};
std::memcpy(&raw, arg_data.data(), sizeof(T));
return raw;
// if constexpr (std::is_same_v<T, bool> || std::is_same_v<T, int> || std::is_same_v<T, uint>) {
// return static_cast<T>(raw);
// } else {
// return luisa::bit_cast<T>(raw);
// }
// return luisa::bit_cast<T>(raw);
} else {
return std::get<i>(args);
}
Expand Down
6 changes: 3 additions & 3 deletions include/luisa/dsl/struct.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ using c_array_to_std_array_t = typename c_array_to_std_array<T>::type;

#define LUISA_DERIVE_FMT_STRUCT_FIELD_FMT(x) #x "={} "
#define LUISA_DERIVE_FMT_MAP_STRUCT_FIELD(x) input.x
#define LUISA_DERIVE_FMT(Struct, ...) \
#define LUISA_DERIVE_FMT(Struct, DisplayName, ...) \
template<> \
struct fmt::formatter<Struct> { \
constexpr auto parse(format_parse_context &ctx) -> decltype(ctx.begin()) { \
Expand All @@ -62,13 +62,13 @@ using c_array_to_std_array_t = typename c_array_to_std_array<T>::type;
template<typename FormatContext> \
auto format(const Struct &input, FormatContext &ctx) -> decltype(ctx.out()) { \
return fmt::format_to(ctx.out(), \
#Struct " {{ " LUISA_MAP(LUISA_DERIVE_FMT_STRUCT_FIELD_FMT, __VA_ARGS__) "}}", \
#DisplayName " {{ " LUISA_MAP(LUISA_DERIVE_FMT_STRUCT_FIELD_FMT, __VA_ARGS__) "}}", \
LUISA_MAP_LIST(LUISA_DERIVE_FMT_MAP_STRUCT_FIELD, __VA_ARGS__)); \
} \
};

#define LUISA_STRUCT(S, ...) \
LUISA_DERIVE_FMT(S, __VA_ARGS__) \
LUISA_DERIVE_FMT(S, S, __VA_ARGS__) \
LUISA_STRUCT_REFLECT(S, __VA_ARGS__) \
template<> \
struct luisa_compute_extension<S>; \
Expand Down

0 comments on commit 1508a20

Please sign in to comment.