Skip to content

Commit

Permalink
Add fmt::formatter specializations (#203)
Browse files Browse the repository at this point in the history
  • Loading branch information
calcmogul authored Dec 6, 2023
1 parent fbf4fa3 commit f023c4e
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 16 deletions.
21 changes: 5 additions & 16 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -123,17 +123,6 @@ add_custom_command(

# fmt dependency
if(NOT USE_SYSTEM_FMT)
set(BUILD_SHARED_LIBS_SAVE ${BUILD_SHARED_LIBS})
set(BUILD_SHARED_LIBS OFF)
set(CMAKE_POSITION_INDEPENDENT_CODE ${BUILD_SHARED_LIBS_SAVE})

# Since fmt is an internal dependency, only install it for Emscripten compiler
if(${CMAKE_SYSTEM_NAME} MATCHES "Emscripten")
set(FMT_INSTALL ON)
else()
set(FMT_INSTALL OFF)
endif()

fetchcontent_declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
Expand All @@ -145,7 +134,7 @@ else()
find_package(fmt CONFIG REQUIRED)
endif()

target_link_libraries(Sleipnir PRIVATE fmt::fmt)
target_link_libraries(Sleipnir PUBLIC fmt::fmt)

target_include_directories(
Sleipnir
Expand Down Expand Up @@ -218,7 +207,7 @@ if(BUILD_BENCHMARKING)
)
target_link_libraries(
${benchmark}ScalabilityBenchmark
PRIVATE Sleipnir fmt::fmt casadi
PRIVATE Sleipnir casadi
)
endforeach()
endif()
Expand Down Expand Up @@ -268,7 +257,7 @@ if(BUILD_TESTING)
)
target_link_libraries(
SleipnirTest
PRIVATE Sleipnir fmt::fmt GTest::gtest GTest::gtest_main
PRIVATE Sleipnir GTest::gtest GTest::gtest_main
)
if(NOT CMAKE_TOOLCHAIN_FILE)
gtest_discover_tests(SleipnirTest)
Expand All @@ -289,7 +278,7 @@ foreach(example ${EXAMPLES})
${CMAKE_CURRENT_SOURCE_DIR}/examples/${example}/include
${CMAKE_CURRENT_SOURCE_DIR}/thirdparty/units/include
)
target_link_libraries(${example} PRIVATE Sleipnir fmt::fmt)
target_link_libraries(${example} PRIVATE Sleipnir)

# Build example test if files exist for it
if(
Expand All @@ -315,7 +304,7 @@ foreach(example ${EXAMPLES})
)
target_link_libraries(
${example}Test
PRIVATE Sleipnir fmt::fmt GTest::gtest GTest::gtest_main
PRIVATE Sleipnir GTest::gtest GTest::gtest_main
)
if(NOT CMAKE_TOOLCHAIN_FILE)
gtest_discover_tests(${example}Test)
Expand Down
105 changes: 105 additions & 0 deletions include/sleipnir/Formatters.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Copyright (c) Sleipnir contributors

#pragma once

#include <concepts>

#include <Eigen/Core>
#include <Eigen/SparseCore>
#include <fmt/format.h>

#include "sleipnir/autodiff/Variable.hpp"
#include "sleipnir/autodiff/VariableBlock.hpp"
#include "sleipnir/autodiff/VariableMatrix.hpp"

// FIXME: Doxygen gives internal inconsistency errors:
// scope for class sleipnir::fmt::formatter< Derived, CharT > not found!
// scope for class sleipnir::fmt::formatter< sleipnir::Variable > not found!
// scope for class sleipnir::fmt::formatter< T > not found!

//! @cond Doxygen_Suppress

/**
* Formatter for classes derived from Eigen::MatrixBase<Derived> or
* Eigen::SparseCompressedBase<Derived>.
*/
template <typename Derived, typename CharT>
requires std::derived_from<Derived, Eigen::MatrixBase<Derived>> ||
std::derived_from<Derived, Eigen::SparseCompressedBase<Derived>>
struct fmt::formatter<Derived, CharT> {
constexpr auto parse(fmt::format_parse_context& ctx) {
return m_underlying.parse(ctx);
}

auto format(const Derived& mat, fmt::format_context& ctx) const {
auto out = ctx.out();

for (int row = 0; row < mat.rows(); ++row) {
for (int col = 0; col < mat.cols(); ++col) {
out = fmt::format_to(out, " ");
out = m_underlying.format(mat.coeff(row, col), ctx);
}

if (row < mat.rows() - 1) {
out = fmt::format_to(out, "\n");
}
}

return out;
}

private:
fmt::formatter<typename Derived::Scalar, CharT> m_underlying;
};

/**
* Formatter for sleipnir::Variable.
*/
template <>
struct fmt::formatter<sleipnir::Variable> {
constexpr auto parse(fmt::format_parse_context& ctx) {
return m_underlying.parse(ctx);
}

auto format(const sleipnir::Variable& variable,
fmt::format_context& ctx) const {
return m_underlying.format(variable.Value(), ctx);
}

private:
fmt::formatter<double> m_underlying;
};

/**
* Formatter for sleipnir::VariableBlock or sleipnir::VariableMatrix.
*/
template <typename T>
requires std::same_as<T, sleipnir::VariableBlock<sleipnir::VariableMatrix>> ||
std::same_as<T, sleipnir::VariableMatrix>
struct fmt::formatter<T> {
constexpr auto parse(fmt::format_parse_context& ctx) {
return m_underlying.parse(ctx);
}

auto format(const T& mat, fmt::format_context& ctx) const {
auto out = ctx.out();

for (int row = 0; row < mat.Rows(); ++row) {
for (int col = 0; col < mat.Cols(); ++col) {
out = fmt::format_to(out, " ");
out = m_underlying.format(mat(row, col).Value(), ctx);
}

if (row < mat.Rows() - 1) {
out = fmt::format_to(out, "\n");
}
}

return out;
}

private:
fmt::formatter<double> m_underlying;
};

//! @endcond
60 changes: 60 additions & 0 deletions test/src/FormattersTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (c) Sleipnir contributors

#include <vector>

#include <fmt/format.h>
#include <gtest/gtest.h>
#include <sleipnir/Formatters.hpp>

TEST(FormattersTest, Eigen) {
Eigen::Matrix<double, 3, 2> A{{0.0, 1.0}, {2.0, 3.0}, {4.0, 5.0}};
EXPECT_EQ(
" 0.000000 1.000000\n"
" 2.000000 3.000000\n"
" 4.000000 5.000000",
fmt::format("{:f}", A));

Eigen::MatrixXd B{{0.0, 1.0}, {2.0, 3.0}, {4.0, 5.0}};
EXPECT_EQ(
" 0.000000 1.000000\n"
" 2.000000 3.000000\n"
" 4.000000 5.000000",
fmt::format("{:f}", B));

Eigen::SparseMatrix<double> C{3, 2};
std::vector<Eigen::Triplet<double>> triplets;
triplets.emplace_back(0, 1, 1.0);
triplets.emplace_back(1, 0, 2.0);
triplets.emplace_back(1, 1, 3.0);
triplets.emplace_back(2, 0, 4.0);
triplets.emplace_back(2, 1, 5.0);
C.setFromTriplets(triplets.begin(), triplets.end());
EXPECT_EQ(
" 0.000000 1.000000\n"
" 2.000000 3.000000\n"
" 4.000000 5.000000",
fmt::format("{:f}", C));
}

TEST(FormattersTest, Variable) {
EXPECT_EQ("4.000000", fmt::format("{:f}", sleipnir::Variable{4.0}));
}

TEST(FormattersTest, VariableMatrix) {
Eigen::Matrix<double, 3, 2> A{{0.0, 1.0}, {2.0, 3.0}, {4.0, 5.0}};

sleipnir::VariableMatrix B{3, 2};
B = A;
EXPECT_EQ(
" 0.000000 1.000000\n"
" 2.000000 3.000000\n"
" 4.000000 5.000000",
fmt::format("{:f}", B));

sleipnir::VariableBlock<sleipnir::VariableMatrix> C{B};
EXPECT_EQ(
" 0.000000 1.000000\n"
" 2.000000 3.000000\n"
" 4.000000 5.000000",
fmt::format("{:f}", C));
}

0 comments on commit f023c4e

Please sign in to comment.