diff --git a/include/luisa/dsl/fmt_impl.h b/include/luisa/dsl/fmt_impl.h index 584d16192..3b4f0740b 100644 --- a/include/luisa/dsl/fmt_impl.h +++ b/include/luisa/dsl/fmt_impl.h @@ -3,23 +3,23 @@ #include -LUISA_DERIVE_FMT(luisa::compute::float2, x, y); -LUISA_DERIVE_FMT(luisa::compute::float3, x, y, z); -LUISA_DERIVE_FMT(luisa::compute::float4, x, y, z, w); +LUISA_DERIVE_FMT(luisa::compute::float2, float2,x, y); +LUISA_DERIVE_FMT(luisa::compute::float3, float3,x, y, z); +LUISA_DERIVE_FMT(luisa::compute::float4, float4,x, y, z, w); -LUISA_DERIVE_FMT(luisa::compute::int2, x, y); -LUISA_DERIVE_FMT(luisa::compute::int3, x, y, z); -LUISA_DERIVE_FMT(luisa::compute::int4, x, y, z, w); +LUISA_DERIVE_FMT(luisa::compute::int2, int2,x, y); +LUISA_DERIVE_FMT(luisa::compute::int3, int3,x, y, z); +LUISA_DERIVE_FMT(luisa::compute::int4, int4,x, y, z, w); -LUISA_DERIVE_FMT(luisa::compute::uint2, x, y); -LUISA_DERIVE_FMT(luisa::compute::uint3, x, y, z); -LUISA_DERIVE_FMT(luisa::compute::uint4, x, y, z, w); +LUISA_DERIVE_FMT(luisa::compute::uint2, uint2, x, y); +LUISA_DERIVE_FMT(luisa::compute::uint3, uint3, x, y, z); +LUISA_DERIVE_FMT(luisa::compute::uint4, uint4, x, y, z, w); -LUISA_DERIVE_FMT(luisa::compute::bool2, x, y); -LUISA_DERIVE_FMT(luisa::compute::bool3, x, y, z); -LUISA_DERIVE_FMT(luisa::compute::bool4, x, y, z, w); +LUISA_DERIVE_FMT(luisa::compute::bool2, bool2, x, y); +LUISA_DERIVE_FMT(luisa::compute::bool3, bool3, x, y, z); +LUISA_DERIVE_FMT(luisa::compute::bool4, bool4, x, y, z, w); -LUISA_DERIVE_FMT(luisa::compute::float2x2, cols[0], cols[1]); -LUISA_DERIVE_FMT(luisa::compute::float3x3, cols[0], cols[1], cols[2]); -LUISA_DERIVE_FMT(luisa::compute::float4x4, cols[0], cols[1], cols[2], cols[3]); \ No newline at end of file +LUISA_DERIVE_FMT(luisa::compute::float2x2, float2x2, cols[0], cols[1]); +LUISA_DERIVE_FMT(luisa::compute::float3x3, float3x3, cols[0], cols[1], cols[2]); +LUISA_DERIVE_FMT(luisa::compute::float4x4, float4x4, cols[0], cols[1], cols[2], cols[3]); \ No newline at end of file diff --git a/include/luisa/dsl/printer.h b/include/luisa/dsl/printer.h index 9faf6c723..71c025a64 100644 --- a/include/luisa/dsl/printer.h +++ b/include/luisa/dsl/printer.h @@ -50,14 +50,6 @@ class LC_DSL_API Printer { _buffer->write(offset + index + i, data[i]); } index += N; - // if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { - // _buffer->write(offset + index, cast(curr)); - // } else if constexpr (std::is_same_v) { - // _buffer->write(offset + index, as(curr)); - // } else { - - // // static_assert(always_false_v, "unsupported type for printing in kernel."); - // } } _log_to_buffer(offset, index, other...); } @@ -133,14 +125,7 @@ class LC_DSL_API Printer { template void Printer::_log(luisa::log_level level, luisa::string fmt, const Args &...args) noexcept { std::array count_per_arg{}; - // for (constexpr int i = 0; i < sizeof...(Args); i++) { - // if constexpr (is_dsl_v>>) { - // using T = expr_value_t>>; - // count_per_arg[i] = (sizeof(T) + sizeof(uint) - 1) / sizeof(uint); - // } else { - // count_per_arg[i] = 0; - // } - // } + auto do_count = [&](std::index_sequence) noexcept { auto impl = [&]() noexcept { if constexpr (is_dsl_v>>) { @@ -158,7 +143,7 @@ void Printer::_log(luisa::log_level level, luisa::string fmt, const Args &...arg if constexpr (sizeof...(Args) > 0) { count_by_arg[0] = 0; for (int i = 1; i < sizeof...(Args); i++) { - count_by_arg[i] = count_by_arg[i - 1] + count_per_arg[i]; + count_by_arg[i] = count_by_arg[i - 1] + count_per_arg[i - 1]; } } @@ -193,12 +178,6 @@ void Printer::_log(luisa::log_level level, luisa::string fmt, const Args &...arg T raw{}; std::memcpy(&raw, arg_data.data(), sizeof(T)); return raw; - // if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { - // return static_cast(raw); - // } else { - // return luisa::bit_cast(raw); - // } - // return luisa::bit_cast(raw); } else { return std::get(args); } diff --git a/include/luisa/dsl/struct.h b/include/luisa/dsl/struct.h index d5a01246e..dcd1ef553 100644 --- a/include/luisa/dsl/struct.h +++ b/include/luisa/dsl/struct.h @@ -53,7 +53,7 @@ using c_array_to_std_array_t = typename c_array_to_std_array::type; #define LUISA_DERIVE_FMT_STRUCT_FIELD_FMT(x) #x "={} " #define LUISA_DERIVE_FMT_MAP_STRUCT_FIELD(x) input.x -#define LUISA_DERIVE_FMT(Struct, ...) \ +#define LUISA_DERIVE_FMT(Struct, DisplayName, ...) \ template<> \ struct fmt::formatter { \ constexpr auto parse(format_parse_context &ctx) -> decltype(ctx.begin()) { \ @@ -62,13 +62,13 @@ using c_array_to_std_array_t = typename c_array_to_std_array::type; template \ auto format(const Struct &input, FormatContext &ctx) -> decltype(ctx.out()) { \ return fmt::format_to(ctx.out(), \ - #Struct " {{ " LUISA_MAP(LUISA_DERIVE_FMT_STRUCT_FIELD_FMT, __VA_ARGS__) "}}", \ + #DisplayName " {{ " LUISA_MAP(LUISA_DERIVE_FMT_STRUCT_FIELD_FMT, __VA_ARGS__) "}}", \ LUISA_MAP_LIST(LUISA_DERIVE_FMT_MAP_STRUCT_FIELD, __VA_ARGS__)); \ } \ }; #define LUISA_STRUCT(S, ...) \ - LUISA_DERIVE_FMT(S, __VA_ARGS__) \ + LUISA_DERIVE_FMT(S, S, __VA_ARGS__) \ LUISA_STRUCT_REFLECT(S, __VA_ARGS__) \ template<> \ struct luisa_compute_extension; \