diff --git a/librapid/include/librapid/array/arrayContainer.hpp b/librapid/include/librapid/array/arrayContainer.hpp index f49b8efd..cbf6063f 100644 --- a/librapid/include/librapid/array/arrayContainer.hpp +++ b/librapid/include/librapid/array/arrayContainer.hpp @@ -50,6 +50,7 @@ namespace librapid { using Scalar = typename TypeInfo::Scalar; using Packet = std::false_type; using Backend = typename TypeInfo::Backend; + using ShapeType = ShapeType_; static constexpr int64_t packetWidth = 1; static constexpr bool supportsArithmetic = TypeInfo::supportsArithmetic; static constexpr bool supportsLogical = TypeInfo::supportsLogical; diff --git a/librapid/include/librapid/array/assignOps.hpp b/librapid/include/librapid/array/assignOps.hpp index 920bfcad..747f022e 100644 --- a/librapid/include/librapid/array/assignOps.hpp +++ b/librapid/include/librapid/array/assignOps.hpp @@ -124,15 +124,15 @@ namespace librapid { using Function = detail::Function; using Scalar = typename array::ArrayContainer>::Scalar; - constexpr int64_t packetWidth = typetraits::TypeInfo::packetWidth; + constexpr size_t packetWidth = typetraits::TypeInfo::packetWidth; constexpr bool allowVectorisation = typetraits::TypeInfo< detail::Function>::allowVectorisation && Function::argsAreSameType; - const int64_t size = function.shape().size(); - const int64_t vectorSize = size - (size % packetWidth); + const size_t size = function.shape().size(); + const size_t vectorSize = size - (size % packetWidth); LIBRAPID_ASSUME(vectorSize % packetWidth == 0); diff --git a/librapid/include/librapid/array/function.hpp b/librapid/include/librapid/array/function.hpp index 2bfc3928..bc52202b 100644 --- a/librapid/include/librapid/array/function.hpp +++ b/librapid/include/librapid/array/function.hpp @@ -2,281 +2,282 @@ #define LIBRAPID_ARRAY_FUNCTION_HPP namespace librapid { - namespace typetraits { - // Extract allowVectorisation from the input types - template - constexpr bool checkAllowVectorisation() { - if constexpr (sizeof...(T) == 0) { - return TypeInfo>::allowVectorisation; - } else { - using T1 = typename TypeInfo>::Scalar; - return TypeInfo>::allowVectorisation && - checkAllowVectorisation() && - (std::is_same_v>::Scalar> && ...); - } - } - - template - constexpr auto commonBackend() { - using FirstBackend = typename TypeInfo>::Backend; - if constexpr (sizeof...(Rest) == 0) { - return FirstBackend {}; - } else { - using RestBackend = decltype(commonBackend()); - if constexpr (std::is_same_v || - std::is_same_v) { - return backend::OpenCLIfAvailable {}; - } else if constexpr (std::is_same_v || - std::is_same_v) { - return backend::CUDAIfAvailable {}; - } else { - return backend::CPU {}; - } - } - } - - template - struct TypeInfo<::librapid::detail::Function> { - static constexpr detail::LibRapidType type = detail::LibRapidType::ArrayFunction; - using Scalar = decltype(std::declval()( - std::declval>::Scalar>()...)); - using Backend = decltype(commonBackend()); - - static constexpr bool allowVectorisation = checkAllowVectorisation(); - - static constexpr bool supportsArithmetic = TypeInfo::supportsArithmetic; - static constexpr bool supportsLogical = TypeInfo::supportsLogical; - static constexpr bool supportsBinary = TypeInfo::supportsBinary; - }; - - LIBRAPID_DEFINE_AS_TYPE(typename desc COMMA typename Functor_ COMMA typename... Args, - ::librapid::detail::Function); - } // namespace typetraits - - namespace detail { - // Descriptor is defined in "forward.hpp" - - template< - typename Packet, typename T, - typename std::enable_if_t< - typetraits::TypeInfo::type != ::librapid::detail::LibRapidType::Scalar, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packetExtractor(const T &obj, - size_t index) { - static_assert(std::is_same_v, - "Packet types do not match"); - return obj.packet(index); - } - - template< - typename Packet, typename T, - typename std::enable_if_t< - typetraits::TypeInfo::type == ::librapid::detail::LibRapidType::Scalar, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packetExtractor(const T &obj, size_t) { - return Packet(obj); - } - - template::type != - ::librapid::detail::LibRapidType::Scalar, - int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalarExtractor(const T &obj, size_t index) { - return obj.scalar(index); - } - - template::type == - ::librapid::detail::LibRapidType::Scalar, - int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalarExtractor(const T &obj, size_t) { - return obj; - } - - template - constexpr auto scalarTypesAreSame() { - if constexpr (sizeof...(Rest) == 0) { - using Scalar = typename typetraits::TypeInfo>::Scalar; - return Scalar {}; - } else { - using RestType = decltype(scalarTypesAreSame()); - if constexpr (std::is_same_v< - typename typetraits::TypeInfo>::Scalar, - RestType>) { - return RestType {}; - } else { - return std::false_type {}; - } - } - } - - template - class Function { - public: - using Type = Function; - using Functor = Functor_; - using ShapeType = Shape; - using StrideType = ShapeType; - using Scalar = typename typetraits::TypeInfo::Scalar; - using Backend = typename typetraits::TypeInfo::Backend; - using Packet = typename typetraits::TypeInfo::Packet; - using Iterator = detail::ArrayIterator; - - using Descriptor = desc; - static constexpr bool argsAreSameType = - !std::is_same_v()), std::false_type>; - - Function() = default; - - /// Constructs a Function from a functor and arguments. - /// \param functor The functor to use. - /// \param args The arguments to use. - LIBRAPID_ALWAYS_INLINE explicit Function(Functor &&functor, Args &&...args); - - /// Constructs a Function from another function. - /// \param other The Function to copy. - LIBRAPID_ALWAYS_INLINE Function(const Function &other) = default; - - /// Construct a Function from a temporary function. - /// \param other The Function to move. - LIBRAPID_ALWAYS_INLINE Function(Function &&other) noexcept = default; - - /// Assigns a Function to this function. - /// \param other The Function to copy. - /// \return A reference to this Function. - LIBRAPID_ALWAYS_INLINE Function &operator=(const Function &other) = default; - - /// Assigns a temporary Function to this Function. - /// \param other The Function to move. - /// \return A reference to this Function. - LIBRAPID_ALWAYS_INLINE Function &operator=(Function &&other) noexcept = default; - - /// Return the shape of the Function's result - /// \return The shape of the Function's result - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto shape() const; - - /// Return the arguments in the Function - /// \return The arguments in the Function - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto &args() const; - - /// Return an evaluated Array object - /// \return - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto eval() const; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator[](int64_t index) const; - - /// Evaluates the function at the given index, returning a Packet result. - /// \param index The index to evaluate at. - /// \return The result of the function (vectorized). - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packet(size_t index) const; - - /// Evaluates the function at the given index, returning a Scalar result. - /// \param index The index to evaluate at. - /// \return The result of the function (scalar). - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar scalar(size_t index) const; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator begin() const; - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator end() const; - - template - void str(const fmt::formatter &format, char bracket, char separator, - Ctx &ctx) const; - - private: - /// Implementation detail -- evaluates the function at the given index, - /// returning a Packet result. - /// \tparam I The index sequence. - /// \param index The index to evaluate at. - /// \return The result of the function (vectorized). - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packetImpl(std::index_sequence, - size_t index) const; - - /// Implementation detail -- evaluates the function at the given index, - /// returning a Scalar result. - /// \tparam I The index sequence. - /// \param index The index to evaluate at. - /// \return The result of the function (scalar). - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar scalarImpl(std::index_sequence, - size_t index) const; - - Functor m_functor; - std::tuple m_args; - }; - - template - Function::Function(Functor &&functor, Args &&...args) : - m_functor(std::forward(functor)), m_args(std::forward(args)...) {} - - template - auto Function::shape() const { - return typetraits::TypeInfo::getShape(m_args); - } - - template - auto &Function::args() const { - return m_args; - } - - template - auto Function::operator[](int64_t index) const { - return array::GeneralArrayView(*this)[index]; - } - - template - auto Function::eval() const { - auto res = Array(shape()); - res = *this; - return res; - } - - template - typename Function::Packet - Function::packet(size_t index) const { - return packetImpl(std::make_index_sequence(), index); - } - - template - template - auto Function::packetImpl(std::index_sequence, - size_t index) const -> Packet { - return m_functor.packet(packetExtractor(std::get(m_args), index)...); - } - - template - auto Function::scalar(size_t index) const -> Scalar { - return scalarImpl(std::make_index_sequence(), index); - } - - template - template - auto Function::scalarImpl(std::index_sequence, - size_t index) const -> Scalar { - return m_functor(scalarExtractor(std::get(m_args), index)...); - } - - template - auto Function::begin() const -> Iterator { - return Iterator(*this, 0); - } - - template - auto Function::end() const -> Iterator { - return Iterator(*this, shape()[0]); - } - - template - template - void Function::str(const fmt::formatter &format, - char bracket, char separator, Ctx &ctx) const { - array::GeneralArrayView(*this).str(format, bracket, separator, ctx); - } - } // namespace detail + namespace typetraits { + // Extract allowVectorisation from the input types + template + constexpr bool checkAllowVectorisation() { + if constexpr (sizeof...(T) == 0) { + return TypeInfo>::allowVectorisation; + } else { + using T1 = typename TypeInfo>::Scalar; + return TypeInfo>::allowVectorisation && + checkAllowVectorisation() && + (std::is_same_v>::Scalar> && ...); + } + } + + template + constexpr auto commonBackend() { + using FirstBackend = typename TypeInfo>::Backend; + if constexpr (sizeof...(Rest) == 0) { + return FirstBackend {}; + } else { + using RestBackend = decltype(commonBackend()); + if constexpr (std::is_same_v || + std::is_same_v) { + return backend::OpenCLIfAvailable {}; + } else if constexpr (std::is_same_v || + std::is_same_v) { + return backend::CUDAIfAvailable {}; + } else { + return backend::CPU {}; + } + } + } + + template + struct TypeInfo<::librapid::detail::Function> { + static constexpr detail::LibRapidType type = detail::LibRapidType::ArrayFunction; + using Scalar = decltype(std::declval()( + std::declval>::Scalar>()...)); + using Backend = decltype(commonBackend()); + + static constexpr bool allowVectorisation = checkAllowVectorisation(); + + static constexpr bool supportsArithmetic = TypeInfo::supportsArithmetic; + static constexpr bool supportsLogical = TypeInfo::supportsLogical; + static constexpr bool supportsBinary = TypeInfo::supportsBinary; + }; + + LIBRAPID_DEFINE_AS_TYPE(typename desc COMMA typename Functor_ COMMA typename... Args, + ::librapid::detail::Function); + } // namespace typetraits + + namespace detail { + // Descriptor is defined in "forward.hpp" + + template< + typename Packet, typename T, + typename std::enable_if_t< + typetraits::TypeInfo::type != ::librapid::detail::LibRapidType::Scalar, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packetExtractor(const T &obj, + size_t index) { + static_assert(std::is_same_v, + "Packet types do not match"); + return obj.packet(index); + } + + template< + typename Packet, typename T, + typename std::enable_if_t< + typetraits::TypeInfo::type == ::librapid::detail::LibRapidType::Scalar, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packetExtractor(const T &obj, size_t) { + return Packet(obj); + } + + template::type != + ::librapid::detail::LibRapidType::Scalar, + int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalarExtractor(const T &obj, size_t index) { + return obj.scalar(index); + } + + template::type == + ::librapid::detail::LibRapidType::Scalar, + int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalarExtractor(const T &obj, size_t) { + return obj; + } + + template + constexpr auto scalarTypesAreSame() { + if constexpr (sizeof...(Rest) == 0) { + using Scalar = typename typetraits::TypeInfo>::Scalar; + return Scalar {}; + } else { + using RestType = decltype(scalarTypesAreSame()); + if constexpr (std::is_same_v< + typename typetraits::TypeInfo>::Scalar, + RestType>) { + return RestType {}; + } else { + return std::false_type {}; + } + } + } + + template + class Function { + public: + using Type = Function; + using Functor = Functor_; + using ShapeType = + detail::ShapeTypeHelper::ShapeType...>::Type; + using StrideType = ShapeType; + using Scalar = typename typetraits::TypeInfo::Scalar; + using Backend = typename typetraits::TypeInfo::Backend; + using Packet = typename typetraits::TypeInfo::Packet; + using Iterator = detail::ArrayIterator; + + using Descriptor = desc; + static constexpr bool argsAreSameType = + !std::is_same_v()), std::false_type>; + + Function() = default; + + /// Constructs a Function from a functor and arguments. + /// \param functor The functor to use. + /// \param args The arguments to use. + LIBRAPID_ALWAYS_INLINE explicit Function(Functor &&functor, Args &&...args); + + /// Constructs a Function from another function. + /// \param other The Function to copy. + LIBRAPID_ALWAYS_INLINE Function(const Function &other) = default; + + /// Construct a Function from a temporary function. + /// \param other The Function to move. + LIBRAPID_ALWAYS_INLINE Function(Function &&other) noexcept = default; + + /// Assigns a Function to this function. + /// \param other The Function to copy. + /// \return A reference to this Function. + LIBRAPID_ALWAYS_INLINE Function &operator=(const Function &other) = default; + + /// Assigns a temporary Function to this Function. + /// \param other The Function to move. + /// \return A reference to this Function. + LIBRAPID_ALWAYS_INLINE Function &operator=(Function &&other) noexcept = default; + + /// Return the shape of the Function's result + /// \return The shape of the Function's result + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto shape() const; + + /// Return the arguments in the Function + /// \return The arguments in the Function + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto &args() const; + + /// Return an evaluated Array object + /// \return + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto eval() const; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator[](int64_t index) const; + + /// Evaluates the function at the given index, returning a Packet result. + /// \param index The index to evaluate at. + /// \return The result of the function (vectorized). + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packet(size_t index) const; + + /// Evaluates the function at the given index, returning a Scalar result. + /// \param index The index to evaluate at. + /// \return The result of the function (scalar). + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar scalar(size_t index) const; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator begin() const; + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator end() const; + + template + void str(const fmt::formatter &format, char bracket, char separator, + Ctx &ctx) const; + + private: + /// Implementation detail -- evaluates the function at the given index, + /// returning a Packet result. + /// \tparam I The index sequence. + /// \param index The index to evaluate at. + /// \return The result of the function (vectorized). + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packetImpl(std::index_sequence, + size_t index) const; + + /// Implementation detail -- evaluates the function at the given index, + /// returning a Scalar result. + /// \tparam I The index sequence. + /// \param index The index to evaluate at. + /// \return The result of the function (scalar). + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar scalarImpl(std::index_sequence, + size_t index) const; + + Functor m_functor; + std::tuple m_args; + }; + + template + Function::Function(Functor &&functor, Args &&...args) : + m_functor(std::forward(functor)), m_args(std::forward(args)...) {} + + template + auto Function::shape() const { + return typetraits::TypeInfo::getShape(m_args); + } + + template + auto &Function::args() const { + return m_args; + } + + template + auto Function::operator[](int64_t index) const { + return array::GeneralArrayView(*this)[index]; + } + + template + auto Function::eval() const { + auto res = Array(shape()); + res = *this; + return res; + } + + template + typename Function::Packet + Function::packet(size_t index) const { + return packetImpl(std::make_index_sequence(), index); + } + + template + template + auto Function::packetImpl(std::index_sequence, + size_t index) const -> Packet { + return m_functor.packet(packetExtractor(std::get(m_args), index)...); + } + + template + auto Function::scalar(size_t index) const -> Scalar { + return scalarImpl(std::make_index_sequence(), index); + } + + template + template + auto Function::scalarImpl(std::index_sequence, + size_t index) const -> Scalar { + return m_functor(scalarExtractor(std::get(m_args), index)...); + } + + template + auto Function::begin() const -> Iterator { + return Iterator(*this, 0); + } + + template + auto Function::end() const -> Iterator { + return Iterator(*this, shape()[0]); + } + + template + template + void Function::str(const fmt::formatter &format, + char bracket, char separator, Ctx &ctx) const { + array::GeneralArrayView(*this).str(format, bracket, separator, ctx); + } + } // namespace detail } // namespace librapid // Support FMT printing #ifdef FMT_API ARRAY_TYPE_FMT_IML(typename desc COMMA typename Functor COMMA typename... Args, - librapid::detail::Function) + librapid::detail::Function) LIBRAPID_SIMPLE_IO_NORANGE(typename desc COMMA typename Functor COMMA typename... Args, - librapid::detail::Function) + librapid::detail::Function) #endif // FMT_API #endif // LIBRAPID_ARRAY_FUNCTION_HPP \ No newline at end of file diff --git a/librapid/include/librapid/array/shape.hpp b/librapid/include/librapid/array/shape.hpp index 3c0009d0..39ef4f75 100644 --- a/librapid/include/librapid/array/shape.hpp +++ b/librapid/include/librapid/array/shape.hpp @@ -10,7 +10,7 @@ namespace librapid { namespace typetraits { LIBRAPID_DEFINE_AS_TYPE(size_t N, Shape); LIBRAPID_DEFINE_AS_TYPE_NO_TEMPLATE(MatrixShape); - } + } // namespace typetraits template class Shape { @@ -492,10 +492,10 @@ namespace librapid { /// \param shapes Remaining (optional) inputs /// \return True if all inputs have the same shape, false otherwise template::value && - typetraits::IsSizeType::value && - (typetraits::IsSizeType::value && ...), - int> = 0> + typename std::enable_if_t::value && + typetraits::IsSizeType::value && + (typetraits::IsSizeType::value && ...), + int> = 0> LIBRAPID_NODISCARD LIBRAPID_INLINE bool shapesMatch(const First &first, const Second &second, const Rest &...shapes) { if constexpr (sizeof...(Rest) == 0) { @@ -507,10 +507,10 @@ namespace librapid { /// \sa shapesMatch template::value && - typetraits::IsSizeType::value && - (typetraits::IsSizeType::value && ...), - int> = 0> + typename std::enable_if_t::value && + typetraits::IsSizeType::value && + (typetraits::IsSizeType::value && ...), + int> = 0> LIBRAPID_NODISCARD LIBRAPID_INLINE bool shapesMatch(const std::tuple &shapes) { if constexpr (sizeof...(Rest) == 0) { @@ -521,6 +521,44 @@ namespace librapid { [](auto, auto, auto... rest) { return std::make_tuple(rest...); }, shapes)); } } + + namespace detail { + template + struct ShapeTypeHelperImpl { + using Type = std::false_type; + }; + + template + struct ShapeTypeHelperImpl, Shape> { + using Type = Shape<(N > M ? N : M)>; + }; + + template + struct ShapeTypeHelperImpl, MatrixShape> { + using Type = Shape; + }; + + template + struct ShapeTypeHelperImpl> { + using Type = Shape; + }; + + template<> + struct ShapeTypeHelperImpl { + using Type = MatrixShape; + }; + + template + struct ShapeTypeHelper { + using FirstResult = typename ShapeTypeHelperImpl::Type; + using Type = typename ShapeTypeHelper::Type; + }; + + template + struct ShapeTypeHelper { + using Type = typename ShapeTypeHelperImpl::Type; + }; + } // namespace detail } // namespace librapid // Support FMT printing