Skip to content

Commit

Permalink
better xir constant data filling
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed Nov 14, 2024
1 parent a336239 commit 974dc37
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 15 deletions.
10 changes: 8 additions & 2 deletions include/luisa/ast/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -425,10 +425,13 @@ class LC_AST_API Type {
[[nodiscard]] bool is_uint32() const noexcept;
[[nodiscard]] bool is_int64() const noexcept;
[[nodiscard]] bool is_uint64() const noexcept;
[[nodiscard]] bool is_float16() const noexcept;
[[nodiscard]] bool is_float32() const noexcept;
[[nodiscard]] bool is_float64() const noexcept;
[[nodiscard]] bool is_int8() const noexcept;
[[nodiscard]] bool is_uint8() const noexcept;
[[nodiscard]] bool is_int16() const noexcept;
[[nodiscard]] bool is_uint16() const noexcept;
[[nodiscard]] bool is_float16() const noexcept;
/// Arithmetic = float || int || uint
[[nodiscard]] bool is_arithmetic() const noexcept;

Expand All @@ -439,12 +442,15 @@ class LC_AST_API Type {
[[nodiscard]] bool is_bool_vector() const noexcept;
[[nodiscard]] bool is_int32_vector() const noexcept;
[[nodiscard]] bool is_uint32_vector() const noexcept;
[[nodiscard]] bool is_float16_vector() const noexcept;
[[nodiscard]] bool is_float32_vector() const noexcept;
[[nodiscard]] bool is_float64_vector() const noexcept;
[[nodiscard]] bool is_int8_vector() const noexcept;
[[nodiscard]] bool is_uint8_vector() const noexcept;
[[nodiscard]] bool is_int16_vector() const noexcept;
[[nodiscard]] bool is_uint16_vector() const noexcept;
[[nodiscard]] bool is_int64_vector() const noexcept;
[[nodiscard]] bool is_uint64_vector() const noexcept;
[[nodiscard]] bool is_float16_vector() const noexcept;
[[nodiscard]] bool is_matrix() const noexcept;
[[nodiscard]] bool is_structure() const noexcept;
[[nodiscard]] bool is_buffer() const noexcept;
Expand Down
2 changes: 2 additions & 0 deletions include/luisa/core/stl/hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,7 @@ class LC_CORE_API Hash128 {
[[nodiscard]] luisa::string to_string() const noexcept;
[[nodiscard]] auto operator==(const Hash128 &rhs) const noexcept { return _data == rhs._data; }
};

LC_CORE_API Hash128 hash128(const void *ptr, size_t size, uint64_t seed) noexcept;

}// namespace luisa
1 change: 1 addition & 0 deletions include/luisa/xir/constant.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class LC_XIR_API Constant : public IntrusiveForwardNode<Constant, DerivedValue<D
std::byte _small[sizeof(void *)] = {};
void *_large;
};
uint64_t _hash = {};

[[nodiscard]] bool _is_small() const noexcept;
[[noreturn]] void _error_cannot_change_type() const noexcept;
Expand Down
25 changes: 16 additions & 9 deletions src/ast/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -542,16 +542,16 @@ uint Type::dimension() const noexcept {
bool Type::is_scalar() const noexcept {
switch (tag()) {
case Tag::BOOL:
case Tag::FLOAT32:
case Tag::INT32:
case Tag::UINT32:
case Tag::INT64:
case Tag::UINT64:
case Tag::FLOAT16:
case Tag::INT16:
case Tag::UINT16:
case Tag::INT8:
case Tag::UINT8:
case Tag::FLOAT16:
case Tag::FLOAT32:
case Tag::FLOAT64:
return true;
default:
Expand All @@ -561,16 +561,17 @@ bool Type::is_scalar() const noexcept {

bool Type::is_arithmetic() const noexcept {
switch (tag()) {
case Tag::FLOAT16:
case Tag::FLOAT32:
case Tag::FLOAT64:
case Tag::INT8:
case Tag::UINT8:
case Tag::INT16:
case Tag::UINT16:
case Tag::INT32:
case Tag::UINT32:
case Tag::INT64:
case Tag::UINT64:
case Tag::FLOAT16:
case Tag::INT16:
case Tag::UINT16:
case Tag::INT8:
case Tag::UINT8:
return true;
default:
return false;
Expand Down Expand Up @@ -700,20 +701,26 @@ const Type *Type::custom(luisa::string_view name) noexcept {
bool Type::is_bool() const noexcept { return tag() == Tag::BOOL; }
bool Type::is_int32() const noexcept { return tag() == Tag::INT32; }
bool Type::is_uint32() const noexcept { return tag() == Tag::UINT32; }
bool Type::is_float16() const noexcept { return tag() == Tag::FLOAT16; }
bool Type::is_float32() const noexcept { return tag() == Tag::FLOAT32; }
bool Type::is_float64() const noexcept { return tag() == Tag::FLOAT64; }
bool Type::is_int8() const noexcept { return tag() == Tag::INT8; }
bool Type::is_uint8() const noexcept { return tag() == Tag::UINT8; }
bool Type::is_int16() const noexcept { return tag() == Tag::INT16; }
bool Type::is_uint16() const noexcept { return tag() == Tag::UINT16; }
bool Type::is_int64() const noexcept { return tag() == Tag::INT64; }
bool Type::is_uint64() const noexcept { return tag() == Tag::UINT64; }
bool Type::is_float16() const noexcept { return tag() == Tag::FLOAT16; }

bool Type::is_bool_vector() const noexcept { return is_vector() && element()->is_bool(); }
bool Type::is_int32_vector() const noexcept { return is_vector() && element()->is_int32(); }
bool Type::is_uint32_vector() const noexcept { return is_vector() && element()->is_uint32(); }
bool Type::is_float16_vector() const noexcept { return is_vector() && element()->is_float16(); }
bool Type::is_float32_vector() const noexcept { return is_vector() && element()->is_float32(); }
bool Type::is_float64_vector() const noexcept { return is_vector() && element()->is_float64(); }
bool Type::is_int8_vector() const noexcept { return is_vector() && element()->is_int8(); }
bool Type::is_uint8_vector() const noexcept { return is_vector() && element()->is_uint8(); }
bool Type::is_int16_vector() const noexcept { return is_vector() && element()->is_int16(); }
bool Type::is_uint16_vector() const noexcept { return is_vector() && element()->is_uint16(); }
bool Type::is_float16_vector() const noexcept { return is_vector() && element()->is_float16(); }
bool Type::is_int64_vector() const noexcept { return is_vector() && element()->is_int64(); }
bool Type::is_uint64_vector() const noexcept { return is_vector() && element()->is_uint64(); }

Expand Down
3 changes: 3 additions & 0 deletions src/xir/builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ void Builder::_check_valid_insertion_point() const noexcept {
Builder::Builder() noexcept = default;

IfInst *Builder::if_(Value *cond) noexcept {
LUISA_ASSERT(cond != nullptr && cond->type() == Type::of<bool>(), "Invalid condition.");
return _create_and_append_instruction<IfInst>(cond);
}

SwitchInst *Builder::switch_(Value *value) noexcept {
LUISA_ASSERT(value != nullptr, "Switch value cannot be null.");
return _create_and_append_instruction<SwitchInst>(value);
}

Expand All @@ -33,6 +35,7 @@ BranchInst *Builder::br(BasicBlock *target) noexcept {
}

ConditionalBranchInst *Builder::cond_br(Value *cond, BasicBlock *true_target, BasicBlock *false_target) noexcept {
LUISA_ASSERT(cond != nullptr && cond->type() == Type::of<bool>(), "Invalid condition.");
auto inst = _create_and_append_instruction<ConditionalBranchInst>(cond);
inst->set_true_target(true_target);
inst->set_false_target(false_target);
Expand Down
72 changes: 68 additions & 4 deletions src/xir/constant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,75 @@ void Constant::_error_cannot_change_type() const noexcept {
LUISA_ERROR_WITH_LOCATION("Constant type cannot be changed.");
}

void Constant::set_data(const void *data) noexcept {
if (data == nullptr) {
memset(this->data(), 0, type()->size());
namespace detail {

static void xir_constant_fill_data(const Type *t, const void *raw, void *data) noexcept {
LUISA_DEBUG_ASSERT(t != nullptr && raw != nullptr && data != nullptr, "Invalid arguments.");
if (t->is_bool()) {
auto value = static_cast<uint8_t>(*static_cast<const bool *>(raw) ? 1u : 0u);
memmove(data, &value, 1);
} else if (t->is_scalar()) {
memmove(data, raw, t->size());
} else {
memmove(this->data(), data, type()->size());
switch (t->tag()) {
case Type::Tag::VECTOR: {
auto elem_type = t->element();
auto dim = t->dimension();
for (auto i = 0u; i < dim; i++) {
auto offset = i * elem_type->size();
auto raw_elem = static_cast<const std::byte *>(raw) + offset;
auto data_elem = static_cast<std::byte *>(data) + offset;
xir_constant_fill_data(elem_type, raw_elem, data_elem);
}
break;
}
case Type::Tag::MATRIX: {
auto elem_type = t->element();
auto dim = t->dimension();
auto col_type = Type::vector(elem_type, dim);
for (auto i = 0u; i < dim; i++) {
auto offset = i * col_type->size();
auto raw_col = static_cast<const std::byte *>(raw) + offset;
auto data_col = static_cast<std::byte *>(data) + offset;
xir_constant_fill_data(col_type, raw_col, data_col);
}
break;
}
case Type::Tag::ARRAY: {
auto elem_type = t->element();
auto dim = t->dimension();
for (auto i = 0u; i < dim; i++) {
auto offset = i * elem_type->size();
auto raw_elem = static_cast<const std::byte *>(raw) + offset;
auto data_elem = static_cast<std::byte *>(data) + offset;
xir_constant_fill_data(elem_type, raw_elem, data_elem);
}
break;
}
case Type::Tag::STRUCTURE: {
size_t offset = 0u;
for (auto m : t->members()) {
offset = luisa::align(offset, m->alignment());
auto raw_member = static_cast<const std::byte *>(raw) + offset;
auto data_member = static_cast<std::byte *>(data) + offset;
xir_constant_fill_data(m, raw_member, data_member);
offset += m->size();
}
break;
}
default: LUISA_ERROR_WITH_LOCATION("Unsupported constant type.");
}
}
}

}// namespace detail

void Constant::set_data(const void *data) noexcept {
memset(this->data(), 0, type()->size());
_hash = 0u;
if (data != nullptr) {
detail::xir_constant_fill_data(type(), data, this->data());
_hash = luisa::hash64(this->data(), type()->size(), luisa::hash64_default_seed);
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/xir/translators/ast2xir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,7 @@ class AST2XIRContext {

void _translate_if_stmt(Builder &b, const IfStmt *ast_if, luisa::span<const Statement *const> cdr) noexcept {
auto cond = _translate_expression(b, ast_if->condition(), true);
cond = b.static_cast_if_necessary(Type::of<bool>(), cond);
auto inst = _commented(b.if_(cond));
auto merge_block = inst->create_merge_block();
// true branch
Expand Down Expand Up @@ -788,6 +789,7 @@ class AST2XIRContext {
{
b.set_insertion_point(prepare_block);
auto cond = _translate_expression(b, ast_for->condition(), true);
cond = b.static_cast_if_necessary(Type::of<bool>(), cond);
b.cond_br(cond, body_block, merge_block);
}
// body block
Expand Down

0 comments on commit 974dc37

Please sign in to comment.