Skip to content

Commit

Permalink
Continue updating to C++23
Browse files Browse the repository at this point in the history
  • Loading branch information
Pencilcaseman committed Aug 13, 2023
1 parent 0cb3f69 commit 2f9f8c7
Show file tree
Hide file tree
Showing 18 changed files with 367 additions and 371 deletions.
57 changes: 2 additions & 55 deletions librapid/include/librapid/array/arrayContainer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -845,61 +845,8 @@ namespace librapid {

// Support FMT printing
#ifdef FMT_API
template<typename ShapeType_, typename StorageType_>
struct fmt::formatter<librapid::array::ArrayContainer<ShapeType_, StorageType_>> {
using Type = librapid::array::ArrayContainer<ShapeType_, StorageType_>;
using Scalar = typename librapid::typetraits::TypeInfo<Type>::Scalar;
using Formatter = fmt::formatter<Scalar>;
Formatter m_formatter;
char m_bracket = 's';
char m_separator = ' ';

template<typename ParseContext>
FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * {
// Custom format options:
// - "~r" for round brackets
// - "~s" for square brackets
// - "~c" for curly brackets
// - "~a" for angle brackets
// - "~p" for pipe brackets
// - "-," for comma separator
// - "-;" for semicolon separator
// - "-:" for colon separator
// - "-|" for pipe separator
// - "-_" for underscore separator

auto it = ctx.begin(), end = ctx.end();
if (it != end && *it == '~') {
++it;
if (it != end && (*it == 'r' || *it == 's' || *it == 'c' || *it == 'a' || *it == 'p')) {
m_bracket = *it++;
}
}

if (it != end && *it == '-') {
++it;
if (it != end) { m_separator = *it++; }
}

ctx.advance_to(it);

return m_formatter.parse(ctx);
}

template<typename FormatContext>
FMT_CONSTEXPR auto format(const Type &val, FormatContext &ctx) const -> decltype(ctx.out()) {
val.str(m_formatter, m_bracket, m_separator, ctx);
return ctx.out();
}
};

template<typename ShapeType_, typename StorageType_>
auto operator<<(std::ostream &os,
const librapid::array::ArrayContainer<ShapeType_, StorageType_> &object)
-> std::ostream & {
os << fmt::format("{}", object);
return os;
}
ARRAY_TYPE_FMT_IML(typename ShapeType_ COMMA typename StorageType_,
librapid::array::ArrayContainer<ShapeType_ COMMA StorageType_>)

LIBRAPID_SIMPLE_IO_NORANGE(typename ShapeType_ COMMA typename StorageType_,
librapid::array::ArrayContainer<ShapeType_ COMMA StorageType_>)
Expand Down
22 changes: 11 additions & 11 deletions librapid/include/librapid/array/arrayTypeDef.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,17 @@ namespace librapid {
\
template<typename ParseContext> \
FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * { \
/* Custom format options: */ \
/* - "~r" for round brackets */ \
/* - "~s" for square brackets */ \
/* - "~c" for curly brackets */ \
/* - "~a" for angle brackets */ \
/* - "~p" for pipe brackets */ \
/* - "-," for comma separator */ \
/* - "-;" for semicolon separator */ \
/* - "-:" for colon separator */ \
/* - "-|" for pipe separator */ \
/* - "-_" for underscore separator */ \
/* Custom format options: */ \
/* - "~r" for round brackets */ \
/* - "~s" for square brackets */ \
/* - "~c" for curly brackets */ \
/* - "~a" for angle brackets */ \
/* - "~p" for pipe brackets */ \
/* - "-," for comma separator */ \
/* - "-;" for semicolon separator */ \
/* - "-:" for colon separator */ \
/* - "-|" for pipe separator */ \
/* - "-_" for underscore separator */ \
\
auto it = ctx.begin(), end = ctx.end(); \
if (it != end && *it == '~') { \
Expand Down
54 changes: 7 additions & 47 deletions librapid/include/librapid/array/arrayView.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ namespace librapid {

/// Copy an ArrayView object
/// \param array The array to copy
explicit ArrayView(ArrayViewType &array);
ArrayView(ArrayViewType &array);

/// Copy an ArrayView object (not const)
/// \param array The array to copy
explicit ArrayView(ArrayViewType &&array) = delete;
ArrayView(ArrayViewType &&array);

/// Copy an ArrayView object (const)
/// \param other The array to copy
Expand Down Expand Up @@ -148,6 +148,10 @@ namespace librapid {
ArrayView<ArrayViewType>::ArrayView(ArrayViewType &array) :
m_ref(array), m_shape(array.shape()), m_stride(array.shape()) {}

template<typename ArrayViewType>
ArrayView<ArrayViewType>::ArrayView(ArrayViewType &&array) :
m_ref(array), m_shape(array.shape()), m_stride(array.shape()) {}

template<typename T>
ArrayView<T> &ArrayView<T>::operator=(const Scalar &scalar) {
LIBRAPID_ASSERT(m_shape.ndim() == 0, "Cannot assign to a non-scalar ArrayView.");
Expand Down Expand Up @@ -322,51 +326,7 @@ namespace librapid {

// Support FMT printing
#ifdef FMT_API
template<typename ArrayViewType>
struct fmt::formatter<librapid::array::ArrayView<ArrayViewType>> {
using Type = librapid::array::ArrayView<ArrayViewType>;
using Scalar = typename librapid::typetraits::TypeInfo<Type>::Scalar;
using Formatter = fmt::formatter<Scalar>;
Formatter m_formatter;
char m_bracket = 's';
char m_separator = ' ';

template<typename ParseContext>
FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * {
// Same formatting options as for the ArrayContainer type

auto it = ctx.begin(), end = ctx.end();
if (it != end && *it == '~') {
++it;
if (it != end && (*it == 'r' || *it == 's' || *it == 'c' || *it == 'a' || *it == 'p')) {
m_bracket = *it++;
}
}

if (it != end && *it == '-') {
++it;
if (it != end) { m_separator = *it++; }
}

ctx.advance_to(it);

return m_formatter.parse(ctx);
}

template<typename FormatContext>
FMT_CONSTEXPR auto format(const Type &val, FormatContext &ctx) const -> decltype(ctx.out()) {
val.str(m_formatter, m_bracket, m_separator, ctx);
return ctx.out();
}
};

template<typename ArrayViewType>
auto operator<<(std::ostream &os, const librapid::array::ArrayView<ArrayViewType> &object)
-> std::ostream & {
os << fmt::format("{}", object);
return os;
}

ARRAY_TYPE_FMT_IML(typename T, librapid::array::ArrayView<T>)
LIBRAPID_SIMPLE_IO_NORANGE(typename T, librapid::array::ArrayView<T>)
#endif // FMT_API

Expand Down
6 changes: 5 additions & 1 deletion librapid/include/librapid/array/arrayViewString.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ namespace librapid {
if (i > 0) fmt::format_to(ctx.out(), "{}", std::string(indent + 1, ' '));
arrayViewToString(view[i], formatter, bracket, separator, indent + 1, ctx);
if (i != view.shape()[0] - 1) {
fmt::format_to(ctx.out(), "{}\n", separator);
if (separator == ' ') {
fmt::format_to(ctx.out(), "\n");
} else {
fmt::format_to(ctx.out(), "{}\n", separator);
}
if (view.ndim() > 2) { fmt::format_to(ctx.out(), "\n"); }
}
}
Expand Down
18 changes: 9 additions & 9 deletions librapid/include/librapid/array/function.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,9 @@ namespace librapid {
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator begin() const;
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator end() const;

/// Return a string representation of the Function
/// \param format The format to use.
/// \return A string representation of the Function
LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const;
template<typename T, typename Char, typename Ctx>
void str(const fmt::formatter<T, Char> &format, char bracket, char separator,
Ctx &ctx) const;

private:
/// Implementation detail -- evaluates the function at the given index,
Expand Down Expand Up @@ -264,17 +263,18 @@ namespace librapid {
}

template<typename desc, typename Functor, typename... Args>
std::string Function<desc, Functor, Args...>::str(const std::string &format) const {
return eval().str(format);
template<typename T, typename Char, typename Ctx>
void Function<desc, Functor, Args...>::str(const fmt::formatter<T, Char> &format,
char bracket, char separator, Ctx &ctx) const {
array::ArrayView(*this).str(format, bracket, separator, ctx);
}
} // namespace detail
} // namespace librapid

// Support FMT printing
#ifdef FMT_API
LIBRAPID_SIMPLE_IO_IMPL(typename desc COMMA typename Functor COMMA typename... Args,
librapid::detail::Function<desc COMMA Functor COMMA Args...>)

ARRAY_TYPE_FMT_IML(typename desc COMMA typename Functor COMMA typename... Args,
librapid::detail::Function<desc COMMA Functor COMMA Args...>)
LIBRAPID_SIMPLE_IO_NORANGE(typename desc COMMA typename Functor COMMA typename... Args,
librapid::detail::Function<desc COMMA Functor COMMA Args...>)
#endif // FMT_API
Expand Down
24 changes: 11 additions & 13 deletions librapid/include/librapid/array/linalg/arrayMultiply.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,9 @@ namespace librapid {
template<typename StorageType>
void applyTo(array::ArrayContainer<ShapeType, StorageType> &out) const;

/// \brief String representation of the array multiplication
/// \param format Format string for each element
/// \return String representation of the array multiplication
LIBRAPID_NODISCARD std::string str(const std::string &format) const;
template<typename T, typename Char, typename Ctx>
void str(const fmt::formatter<T, Char> &format, char bracket, char separator,
Ctx &ctx) const;

private:
bool m_transA; // Transpose state of A
Expand Down Expand Up @@ -497,10 +496,10 @@ namespace librapid {

template<typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
typename StorageTypeB, typename Alpha, typename Beta>
std::string
ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha, Beta>::str(
const std::string &format) const {
return eval().str(format);
template<typename T, typename Char, typename Ctx>
void ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha, Beta>::str(
const fmt::formatter<T, Char> &format, char bracket, char separator, Ctx &ctx) const {
eval().str(format, bracket, separator, ctx);
}
} // namespace linalg

Expand Down Expand Up @@ -693,11 +692,10 @@ namespace librapid {
} // namespace typetraits
} // namespace librapid

LIBRAPID_SIMPLE_IO_IMPL(
typename ShapeTypeA COMMA typename StorageTypeA COMMA typename ShapeTypeB COMMA
typename StorageTypeB COMMA typename Alpha COMMA typename Beta,
librapid::linalg::ArrayMultiply<
ShapeTypeA COMMA StorageTypeA COMMA ShapeTypeB COMMA StorageTypeB COMMA Alpha COMMA Beta>)
ARRAY_TYPE_FMT_IML(typename ShapeTypeA COMMA typename StorageTypeA COMMA typename ShapeTypeB COMMA
typename StorageTypeB COMMA typename Alpha COMMA typename Beta,
librapid::linalg::ArrayMultiply<ShapeTypeA COMMA StorageTypeA COMMA ShapeTypeB
COMMA StorageTypeB COMMA Alpha COMMA Beta>)

LIBRAPID_SIMPLE_IO_NORANGE(
typename ShapeTypeA COMMA typename StorageTypeA COMMA typename ShapeTypeB COMMA
Expand Down
3 changes: 2 additions & 1 deletion librapid/include/librapid/array/linalg/linalg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ namespace librapid::typetraits {

#include "transpose.hpp"

#include "level3/gemm.hpp" // Included before gemv, since gemm is used in some gemv implementations

#include "level2/gemv.hpp"

#include "level3/geam.hpp"
#include "level3/gemm.hpp"

#include "arrayMultiply.hpp"

Expand Down
30 changes: 20 additions & 10 deletions librapid/include/librapid/array/linalg/transpose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,11 +422,11 @@ namespace librapid {
} // namespace detail

namespace array {
template<typename T>
template<typename TransposeType>
class Transpose {
public:
using ArrayType = T;
using BaseType = typename std::decay_t<T>;
using ArrayType = TransposeType;
using BaseType = typename std::decay_t<TransposeType>;
using Scalar = typename typetraits::TypeInfo<BaseType>::Scalar;
using Reference = BaseType &;
using ConstReference = const BaseType &;
Expand All @@ -446,7 +446,7 @@ namespace librapid {
/// Create a Transpose object from an array/operation
/// \param array The array to copy
/// \param axes The transposition axes
Transpose(const T &array, const ShapeType &axes, Scalar alpha = Scalar(1.0));
Transpose(const TransposeType &array, const ShapeType &axes, Scalar alpha = Scalar(1.0));

/// Copy a Transpose object
Transpose(const Transpose &other) = default;
Expand All @@ -457,7 +457,7 @@ namespace librapid {
/// Assign another Transpose object to this one
/// \param other The Transpose to assign
/// \return *this;
Transpose &operator=(const Transpose &other) = default;
auto operator=(const Transpose &other) -> Transpose & = default;

/// Access sub-array of this Transpose object
/// \param index Array index
Expand Down Expand Up @@ -505,7 +505,9 @@ namespace librapid {
/// the given format string
/// \param format Format string
/// \return Stringified object
LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const;
template<typename T, typename Char, typename Ctx>
LIBRAPID_ALWAYS_INLINE void str(const fmt::formatter<T, Char> &format, char bracket,
char separator, Ctx &ctx) const;

private:
ArrayType m_array;
Expand Down Expand Up @@ -537,6 +539,12 @@ namespace librapid {
return m_outputShape.ndim();
}

template<typename T>
auto Transpose<T>::scalar(int64_t index) const -> auto {
// TODO: This is a heinously inefficient way of doing this. Fix it.
return eval().scalar(index);
}

template<typename T>
auto Transpose<T>::axes() const -> const ShapeType & {
return m_axes;
Expand Down Expand Up @@ -617,9 +625,11 @@ namespace librapid {
return res;
}

template<typename T>
std::string Transpose<T>::str(const std::string &format) const {
return eval().str(format);
template<typename TransposeType>
template<typename T, typename Char, typename Ctx>
void Transpose<TransposeType>::str(const fmt::formatter<T, Char> &format, char bracket, char separator,
Ctx &ctx) const {
eval().str(format, bracket, separator, ctx);
}
}; // namespace array

Expand Down Expand Up @@ -716,7 +726,7 @@ namespace librapid {

// Support FMT printing
#ifdef FMT_API
LIBRAPID_SIMPLE_IO_IMPL(typename T, librapid::array::Transpose<T>)
ARRAY_TYPE_FMT_IML(typename T, librapid::array::Transpose<T>)
LIBRAPID_SIMPLE_IO_NORANGE(typename T, librapid::array::Transpose<T>)
#endif // FMT_API

Expand Down
Loading

0 comments on commit 2f9f8c7

Please sign in to comment.