Skip to content

Commit

Permalink
Add iterables + more list types
Browse files Browse the repository at this point in the history
  • Loading branch information
ccummingsNV committed Aug 12, 2024
1 parent 1a6b891 commit 772d77f
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 37 deletions.
12 changes: 12 additions & 0 deletions src/sgl/core/string.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,18 @@ inline std::string list_to_string(const std::vector<T>& list, std::string_view i
return list_to_string(std::span{list}, indentation);
}

template<typename T>
inline std::string iterable_to_string(const T& iterable)
{
std::string result = "[\n";
for (const auto& item : iterable) {
result += " ";
result += string::indent(item->to_string());
result += ",\n";
}
return result;
}

/**
* Remove leading whitespace.
* \param str Input string.
Expand Down
8 changes: 7 additions & 1 deletion src/sgl/device/fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,19 @@ class ShaderProgram;
// reflection.h

class DeclReflection;
class DeclReflectionChildList;
class DeclReflectionIndexedChildList;
class TypeReflection;
class TypeReflectionFieldList;
class TypeLayoutReflection;
class TypeLayoutReflectionFieldList;
class FunctionReflection;
class FunctionReflectionParameterList;
class VariableReflection;
class VariableLayoutReflection;
class ProgramLayout;
class EntryPointLayout;
class EntryPointLayoutParameterList;
class ProgramLayout;

// kernel.h

Expand Down
2 changes: 2 additions & 0 deletions src/sgl/device/python/reflection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ SGL_PY_EXPORT(device_reflection)
.def("unwrap_array", &TypeLayoutReflection::unwrap_array)
.def("__repr__", &TypeLayoutReflection::to_string);

build_list_type<TypeLayoutReflectionFieldList>(m, "TypeLayoutReflectionFieldList");

nb::class_<FunctionReflection, BaseReflectionObject>(m, "FunctionReflection")
.def_prop_ro("name", &FunctionReflection::name)
.def_prop_ro("return_type", &FunctionReflection::return_type)
Expand Down
19 changes: 17 additions & 2 deletions src/sgl/device/reflection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,11 @@ std::string TypeReflection::to_string() const
return str;
}

TypeLayoutReflectionFieldList TypeLayoutReflection::fields() const
{
return TypeLayoutReflectionFieldList(ref(this));
}

std::string TypeLayoutReflection::to_string() const
{
switch (kind()) {
Expand All @@ -183,7 +188,7 @@ std::string TypeLayoutReflection::to_string() const
kind(),
size(),
stride(),
string::indent(string::list_to_string(fields()))
string::indent(string::iterable_to_string(fields()))
);
break;
case TypeReflection::Kind::resource:
Expand Down Expand Up @@ -237,6 +242,11 @@ std::string TypeLayoutReflection::to_string() const
}
}

FunctionReflectionParameterList FunctionReflection::parameters() const
{
return FunctionReflectionParameterList(ref(this));
}

std::string VariableLayoutReflection::to_string() const
{
return fmt::format(
Expand All @@ -249,6 +259,11 @@ std::string VariableLayoutReflection::to_string() const
);
}

EntryPointLayoutParameterList EntryPointLayout::parameters() const
{
return EntryPointLayoutParameterList(ref(this));
}

std::string EntryPointLayout::to_string() const
{
return fmt::format(
Expand All @@ -263,7 +278,7 @@ std::string EntryPointLayout::to_string() const
c_str_to_string(name_override()),
stage(),
compute_thread_group_size(),
string::indent(string::list_to_string(parameters()))
string::indent(string::iterable_to_string(parameters()))
);
}

Expand Down
169 changes: 135 additions & 34 deletions src/sgl/device/reflection.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,35 @@ template<class ParentType, class ChildType>
class BaseReflectionList {

public:
class Iterator {
public:
explicit Iterator(const BaseReflectionList* list, uint32_t index)
: m_list(list)
, m_index(index)
{
}
Iterator& operator++()
{
if (m_index >= m_list->size())
SGL_THROW("Iterator out of range");
m_index++;
return *this;
}
Iterator operator++(int)
{
Iterator retval = *this;
++(*this);
return retval;
}
bool operator==(Iterator other) const { return m_list == other.m_list && m_index == other.m_index; }
bool operator!=(Iterator other) const { return !(*this == other); }
ref<const ChildType> operator*() const { return (*m_list)[m_index]; }

private:
const BaseReflectionList* m_list;
uint32_t m_index;
};

BaseReflectionList(ref<const ParentType> owner)
: m_owner(std::move(owner))
{
Expand All @@ -96,7 +125,13 @@ class BaseReflectionList {
virtual uint32_t size() const = 0;

/// Index operator
ref<const ChildType> operator[](uint32_t index) { return evaluate(index); }
ref<const ChildType> operator[](uint32_t index) const { return evaluate(index); }

/// Begin iterator
Iterator begin() const { return Iterator(this, 0); }

/// End iterator
Iterator end() const { return Iterator(this, size()); }

protected:
ref<const ParentType> m_owner;
Expand All @@ -112,6 +147,35 @@ class BaseReflectionList {
template<class ParentType, class ChildType>
class BaseReflectionIndexedList {
public:
class Iterator {
public:
explicit Iterator(const BaseReflectionIndexedList* list, uint32_t index)
: m_list(list)
, m_index(index)
{
}
Iterator& operator++()
{
if (m_index >= m_list->size())
SGL_THROW("Iterator out of range");
m_index++;
return *this;
}
Iterator operator++(int)
{
Iterator retval = *this;
++(*this);
return retval;
}
bool operator==(Iterator other) const { return m_list == other.m_list && m_index == other.m_index; }
bool operator!=(Iterator other) const { return !(*this == other); }
ref<const ChildType> operator*() const { return (*m_list)[m_index]; }

private:
const BaseReflectionIndexedList* m_list;
uint32_t m_index;
};

BaseReflectionIndexedList(ref<const ParentType> owner, std::vector<uint32_t> indices)
: m_owner(std::move(owner))
, m_indices(std::move(indices))
Expand All @@ -134,7 +198,13 @@ class BaseReflectionIndexedList {
uint32_t size() const { return static_cast<uint32_t>(m_indices.size()); }

/// Index operator
ref<const ChildType> operator[](uint32_t index) { return evaluate(m_indices[index]); }
ref<const ChildType> operator[](uint32_t index) const { return evaluate(m_indices[index]); }

/// Begin iterator
Iterator begin() const { return Iterator(this, 0); }

/// End iterator
Iterator end() const { return Iterator(this, size()); }

protected:
ref<const ParentType> m_owner;
Expand All @@ -155,9 +225,6 @@ class SGL_API BaseReflectionObject : public Object {
ref<const Object> m_owner;
};

class DeclReflectionChildList;
class DeclReflectionIndexedChildList;

class SGL_API DeclReflection : public BaseReflectionObject {

public:
Expand Down Expand Up @@ -242,7 +309,7 @@ class SGL_API DeclReflection : public BaseReflectionObject {
};
SGL_ENUM_REGISTER(DeclReflection::Kind);

/// DeclReflection lazy child list evaluation implementation
/// DeclReflection lazy child list evaluation.
class SGL_API DeclReflectionChildList : public BaseReflectionList<DeclReflection, DeclReflection> {

public:
Expand All @@ -257,7 +324,7 @@ class SGL_API DeclReflectionChildList : public BaseReflectionList<DeclReflection
ref<const DeclReflection> evaluate(uint32_t index) const override { return m_owner->child(index); }
};

/// DeclReflection lazy search result evaluation implementation
/// DeclReflection lazy search result evaluation.
class SGL_API DeclReflectionIndexedChildList : public BaseReflectionIndexedList<DeclReflection, DeclReflection> {
public:
DeclReflectionIndexedChildList(ref<const DeclReflection> owner, std::vector<uint32_t> results)
Expand All @@ -268,8 +335,6 @@ class SGL_API DeclReflectionIndexedChildList : public BaseReflectionIndexedList<
ref<const DeclReflection> evaluate(uint32_t index) const override { return m_owner->child(index); }
};

class TypeReflectionFieldList;

class SGL_API TypeReflection : public BaseReflectionObject {
public:
enum class Kind {
Expand Down Expand Up @@ -582,7 +647,7 @@ SGL_ENUM_REGISTER(TypeReflection::ResourceShape);
SGL_ENUM_REGISTER(TypeReflection::ResourceAccess);
SGL_ENUM_REGISTER(TypeReflection::ParameterCategory);

/// TypeReflectionChildList lazy field list evaluation implementation
/// TypeReflection lazy field list evaluation.
class SGL_API TypeReflectionFieldList : public BaseReflectionList<TypeReflection, VariableReflection> {

public:
Expand Down Expand Up @@ -646,14 +711,7 @@ class SGL_API TypeLayoutReflection : public BaseReflectionObject {
return nullptr;
}

std::vector<ref<const VariableLayoutReflection>> fields() const
{
std::vector<ref<const VariableLayoutReflection>> result;
for (uint32_t i = 0; i < m_target->getFieldCount(); ++i) {
result.push_back(detail::from_slang(m_owner, m_target->getFieldByIndex(i)));
}
return result;
}
TypeLayoutReflectionFieldList fields() const;

bool is_array() const { return type()->is_array(); }

Expand Down Expand Up @@ -705,6 +763,25 @@ class SGL_API TypeLayoutReflection : public BaseReflectionObject {
slang::TypeLayoutReflection* m_target;
};

/// TypeLayoutReflection lazy field list evaluation.
class SGL_API TypeLayoutReflectionFieldList
: public BaseReflectionList<TypeLayoutReflection, VariableLayoutReflection> {

public:
TypeLayoutReflectionFieldList(ref<const TypeLayoutReflection> owner)
: BaseReflectionList(std::move(owner)){};

/// Number of entries in list.
uint32_t size() const override { return m_owner->field_count(); }

protected:
/// Get a specific child.
ref<const VariableLayoutReflection> evaluate(uint32_t index) const override
{
return m_owner->get_field_by_index(index);
}
};

class SGL_API FunctionReflection : public BaseReflectionObject {
public:
FunctionReflection(ref<const Object> owner, slang::FunctionReflection* target)
Expand All @@ -723,14 +800,7 @@ class SGL_API FunctionReflection : public BaseReflectionObject {
return detail::from_slang(m_owner, m_target->getParameterByIndex(index));
}

std::vector<ref<const VariableReflection>> parameters() const
{
std::vector<ref<const VariableReflection>> result;
for (uint32_t i = 0; i < m_target->getParameterCount(); ++i) {
result.push_back(detail::from_slang(m_owner, m_target->getParameterByIndex(i)));
}
return result;
}
FunctionReflectionParameterList parameters() const;

/// Check if variable has a given modifier (e.g. 'inout').
bool has_modifier(ModifierID modifier) const
Expand All @@ -742,6 +812,24 @@ class SGL_API FunctionReflection : public BaseReflectionObject {
slang::FunctionReflection* m_target;
};

/// FunctionReflection lazy parameter list evaluation.
class SGL_API FunctionReflectionParameterList : public BaseReflectionList<FunctionReflection, VariableReflection> {

public:
FunctionReflectionParameterList(ref<const FunctionReflection> owner)
: BaseReflectionList(std::move(owner)){};

/// Number of entries in list.
uint32_t size() const override { return m_owner->parameter_count(); }

protected:
/// Get a specific child.
ref<const VariableReflection> evaluate(uint32_t index) const override
{
return m_owner->get_parameter_by_index(index);
}
};

class SGL_API VariableReflection : public BaseReflectionObject {
public:
static ref<const VariableReflection>
Expand Down Expand Up @@ -819,14 +907,7 @@ class SGL_API EntryPointLayout : public BaseReflectionObject {
return detail::from_slang(m_owner, m_target->getParameterByIndex(index));
}

std::vector<ref<const VariableLayoutReflection>> parameters() const
{
std::vector<ref<const VariableLayoutReflection>> result;
for (uint32_t i = 0; i < m_target->getParameterCount(); ++i) {
result.push_back(detail::from_slang(m_owner, m_target->getParameterByIndex(i)));
}
return result;
}
EntryPointLayoutParameterList parameters() const;

uint3 compute_thread_group_size() const
{
Expand All @@ -843,6 +924,26 @@ class SGL_API EntryPointLayout : public BaseReflectionObject {
slang::EntryPointLayout* m_target;
};


/// EntryPointLayout lazy parameter list evaluation.
class SGL_API EntryPointLayoutParameterList : public BaseReflectionList<EntryPointLayout, VariableLayoutReflection> {

public:
EntryPointLayoutParameterList(ref<const EntryPointLayout> owner)
: BaseReflectionList(std::move(owner)){};

/// Number of entries in list.
uint32_t size() const override { return m_owner->parameter_count(); }

protected:
/// Get a specific child.
ref<const VariableLayoutReflection> evaluate(uint32_t index) const override
{
return m_owner->get_parameter_by_index(index);
}
};


class SGL_API ProgramLayout : public BaseReflectionObject {
public:
static ref<const ProgramLayout> from_slang(ref<const Object> owner, slang::ProgramLayout* program_layout)
Expand Down

0 comments on commit 772d77f

Please sign in to comment.