Skip to content

Commit

Permalink
Make ExpressionGraph use Variables (#694)
Browse files Browse the repository at this point in the history
  • Loading branch information
calcmogul authored Jan 18, 2025
1 parent e9d3941 commit 2fbe99a
Show file tree
Hide file tree
Showing 13 changed files with 139 additions and 128 deletions.
10 changes: 5 additions & 5 deletions cart-pole-scalability-results-casadi.csv
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Samples,Setup time (ms),Solve time (ms)
100,37.949,1807.38
150,61.714,3580.86
200,93.581,4774.61
250,133.32,7050.54
300,170.672,9341.64
100,38.637,1826.95
150,63.45,3635.56
200,94.95,4734.67
250,131.586,7043.58
300,178.675,9238.07
10 changes: 5 additions & 5 deletions cart-pole-scalability-results-sleipnir.csv
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Samples,Setup time (ms),Solve time (ms)
100,2.058,485.048
150,1.234,1519.07
200,1.737,2329.58
250,2.154,2652.35
300,2.717,3986.88
100,1.69,469.773
150,1.297,1432.54
200,1.813,2161.16
250,2.29,2545.94
300,2.815,3945
Binary file modified cart-pole-scalability-results.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 14 additions & 14 deletions flywheel-scalability-results-casadi.csv
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
Samples,Setup time (ms),Solve time (ms)
100,1.966,24.513
200,3.544,35.982
300,5.327,56.478
400,7.167,79.479
500,8.857,97.752
600,10.694,121.422
700,12.15,140.138
800,15.563,169.252
900,16.313,183.834
1000,19.386,218.392
2000,37.511,543.019
3000,54.425,1002.59
4000,73.979,1537.35
5000,91.42,2217.48
100,2.057,21.727
200,3.646,36.476
300,5.402,59.129
400,7.351,78.306
500,8.968,102.043
600,10.933,128.042
700,12.609,146.622
800,14.693,169.145
900,16.269,189.67
1000,19.966,215.395
2000,36.407,553.81
3000,56.344,1018.4
4000,75.164,1546.43
5000,93.18,2225.34
28 changes: 14 additions & 14 deletions flywheel-scalability-results-sleipnir.csv
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
Samples,Setup time (ms),Solve time (ms)
100,0.381,4.529
200,0.111,4.367
300,0.161,6.514
400,0.219,9.043
500,0.28,11.576
600,0.346,16.241
700,0.4,17.232
800,0.488,20.081
900,0.501,21.872
1000,0.567,25.952
2000,1.153,57.788
3000,1.673,105.301
4000,2.475,141.008
5000,2.834,211.878
100,0.353,1.989
200,1.093,5.297
300,0.178,6.186
400,0.237,8.388
500,0.296,10.557
600,0.355,15.161
700,0.423,15.865
800,0.533,20.869
900,0.504,19.665
1000,0.545,23.423
2000,1.096,50.603
3000,1.644,87.56
4000,2.279,113.649
5000,2.838,156.224
Binary file modified flywheel-scalability-results.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
52 changes: 18 additions & 34 deletions include/sleipnir/autodiff/ExpressionGraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
#pragma once

#include <ranges>
#include <span>
#include <utility>

#include "sleipnir/autodiff/Expression.hpp"
#include "sleipnir/autodiff/Variable.hpp"
#include "sleipnir/autodiff/VariableMatrix.hpp"
#include "sleipnir/util/FunctionRef.hpp"
#include "sleipnir/util/small_vector.hpp"

Expand All @@ -22,10 +23,10 @@ class ExpressionGraph {
*
* @param root The root node of the expression.
*/
explicit ExpressionGraph(ExpressionPtr& root) {
explicit ExpressionGraph(Variable& root) {
// If the root type is a constant, Update() is a no-op, so there's no work
// to do
if (root == nullptr || root->Type() == ExpressionType::kConstant) {
if (root.expr == nullptr || root.Type() == ExpressionType::kConstant) {
return;
}

Expand All @@ -38,7 +39,7 @@ class ExpressionGraph {
// BFS list sorted from parent to child.
small_vector<Expression*> stack;

stack.emplace_back(root.Get());
stack.emplace_back(root.expr.Get());

// Initialize the number of instances of each node in the tree
// (Expression::duplications)
Expand All @@ -60,7 +61,7 @@ class ExpressionGraph {
}
}

stack.emplace_back(root.Get());
stack.emplace_back(root.expr.Get());

while (!stack.empty()) {
auto node = stack.back();
Expand Down Expand Up @@ -117,27 +118,13 @@ class ExpressionGraph {
*
* @param wrt Variables with respect to which to compute the gradient.
*/
small_vector<ExpressionPtr> GenerateGradientTree(
std::span<const ExpressionPtr> wrt) const {
VariableMatrix GenerateGradientTree(const VariableMatrix& wrt) const {
// Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation
// for background on reverse accumulation automatic differentiation.

for (size_t row = 0; row < wrt.size(); ++row) {
wrt[row]->row = row;
}

small_vector<ExpressionPtr> grad;
grad.reserve(wrt.size());
for (size_t row = 0; row < wrt.size(); ++row) {
grad.emplace_back(MakeExpressionPtr<ConstExpression>());
}

// Zero adjoints. The root node's adjoint is 1.0 as df/df is always 1.
if (m_adjointList.size() > 0) {
m_adjointList[0]->adjointExpr = MakeExpressionPtr<ConstExpression>(1.0);
for (auto& node : m_adjointList | std::views::drop(1)) {
node->adjointExpr = MakeExpressionPtr<ConstExpression>();
}
}

// df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
Expand All @@ -148,19 +135,19 @@ class ExpressionGraph {
auto& lhs = node->args[0];
auto& rhs = node->args[1];

if (lhs != nullptr && !lhs->IsConstant(0.0)) {
if (lhs != nullptr) {
lhs->adjointExpr =
lhs->adjointExpr + node->GradientLhs(lhs, rhs, node->adjointExpr);
if (rhs != nullptr) {
rhs->adjointExpr =
rhs->adjointExpr + node->GradientRhs(lhs, rhs, node->adjointExpr);
}
}
if (rhs != nullptr && !rhs->IsConstant(0.0)) {
rhs->adjointExpr =
rhs->adjointExpr + node->GradientRhs(lhs, rhs, node->adjointExpr);
}
}

// If variable is a leaf node, assign its adjoint to the gradient.
if (node->row != -1) {
grad[node->row] = node->adjointExpr;
}
VariableMatrix grad(VariableMatrix::empty, wrt.size(), 1);
for (int row = 0; row < grad.Rows(); ++row) {
grad(row) = Variable{std::move(wrt(row).expr->adjointExpr)};
}

// Unlink adjoints to avoid circular references between them and their
Expand All @@ -172,10 +159,7 @@ class ExpressionGraph {
arg->adjointExpr = nullptr;
}
}
}

for (size_t row = 0; row < wrt.size(); ++row) {
wrt[row]->row = -1;
node->adjointExpr = nullptr;
}

return grad;
Expand Down
21 changes: 2 additions & 19 deletions include/sleipnir/autodiff/Hessian.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

#pragma once

#include <utility>

#include <Eigen/Core>
#include <Eigen/SparseCore>

Expand All @@ -13,7 +11,6 @@
#include "sleipnir/autodiff/Variable.hpp"
#include "sleipnir/autodiff/VariableMatrix.hpp"
#include "sleipnir/util/SymbolExports.hpp"
#include "sleipnir/util/small_vector.hpp"

namespace sleipnir {

Expand All @@ -37,22 +34,8 @@ class SLEIPNIR_DLLEXPORT Hessian {
: m_jacobian{
[&] {
m_profiler.StartSetup();

small_vector<detail::ExpressionPtr> wrtVec;
wrtVec.reserve(wrt.size());
for (auto& elem : wrt) {
wrtVec.emplace_back(elem.expr);
}

auto grad =
detail::ExpressionGraph{variable.expr}.GenerateGradientTree(
wrtVec);

VariableMatrix ret{wrt.Rows()};
for (int row = 0; row < ret.Rows(); ++row) {
ret(row) = Variable{std::move(grad[row])};
}
return ret;
return detail::ExpressionGraph{variable}.GenerateGradientTree(
wrt);
}(),
wrt} {
m_profiler.StopSetup();
Expand Down
18 changes: 9 additions & 9 deletions include/sleipnir/autodiff/Jacobian.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class SLEIPNIR_DLLEXPORT Jacobian {
}

for (auto& variable : m_variables) {
m_graphs.emplace_back(variable.expr);
m_graphs.emplace_back(variable);
}

for (int row = 0; row < m_variables.Rows(); ++row) {
Expand Down Expand Up @@ -78,16 +78,16 @@ class SLEIPNIR_DLLEXPORT Jacobian {
VariableMatrix Get() const {
VariableMatrix result{m_variables.Rows(), m_wrt.Rows()};

small_vector<detail::ExpressionPtr> wrtVec;
wrtVec.reserve(m_wrt.size());
for (auto& elem : m_wrt) {
wrtVec.emplace_back(elem.expr);
}

for (int row = 0; row < m_variables.Rows(); ++row) {
auto grad = m_graphs[row].GenerateGradientTree(wrtVec);
for (auto& node : m_wrt) {
node.expr->adjointExpr = nullptr;
}

auto grad = m_graphs[row].GenerateGradientTree(m_wrt);
for (int col = 0; col < m_wrt.Rows(); ++col) {
result(row, col) = Variable{std::move(grad[col])};
if (grad(col).expr != nullptr) {
result(row, col) = Variable{std::move(grad(col))};
}
}
}

Expand Down
13 changes: 5 additions & 8 deletions include/sleipnir/autodiff/Variable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include <Eigen/Core>

#include "sleipnir/autodiff/Expression.hpp"
#include "sleipnir/autodiff/ExpressionGraph.hpp"
#include "sleipnir/util/Assert.hpp"
#include "sleipnir/util/Concepts.hpp"
#include "sleipnir/util/Print.hpp"
Expand All @@ -22,6 +21,9 @@
namespace sleipnir {

// Forward declarations for friend declarations in Variable
namespace detail {
class ExpressionGraph;
} // namespace detail
class SLEIPNIR_DLLEXPORT Hessian;
class SLEIPNIR_DLLEXPORT Jacobian;

Expand Down Expand Up @@ -208,13 +210,7 @@ class SLEIPNIR_DLLEXPORT Variable {
/**
* Returns the value of this variable.
*/
double Value() {
// Updates the value of this variable based on the values of its dependent
// variables
detail::ExpressionGraph{expr}.Update();

return expr->value;
}
double Value();

/**
* Returns the type of this expression (constant, linear, quadratic, or
Expand Down Expand Up @@ -252,6 +248,7 @@ class SLEIPNIR_DLLEXPORT Variable {
friend SLEIPNIR_DLLEXPORT Variable hypot(const Variable& x, const Variable& y,
const Variable& z);

friend class detail::ExpressionGraph;
friend class SLEIPNIR_DLLEXPORT Hessian;
friend class SLEIPNIR_DLLEXPORT Jacobian;
};
Expand Down
18 changes: 17 additions & 1 deletion include/sleipnir/autodiff/VariableMatrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ namespace sleipnir {
*/
class SLEIPNIR_DLLEXPORT VariableMatrix {
public:
struct empty_t {};
static constexpr empty_t empty{};

/**
* Constructs an empty VariableMatrix.
*/
Expand All @@ -45,7 +48,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix {
}

/**
* Constructs a VariableMatrix with the given dimensions.
* Constructs a zero-initialized VariableMatrix with the given dimensions.
*
* @param rows The number of matrix rows.
* @param cols The number of matrix columns.
Expand All @@ -57,6 +60,19 @@ class SLEIPNIR_DLLEXPORT VariableMatrix {
}
}

/**
* Constructs an empty VariableMatrix with the given dimensions.
*
* @param rows The number of matrix rows.
* @param cols The number of matrix columns.
*/
VariableMatrix(empty_t, int rows, int cols) : m_rows{rows}, m_cols{cols} {
m_storage.reserve(Rows() * Cols());
for (int index = 0; index < Rows() * Cols(); ++index) {
m_storage.emplace_back(nullptr);
}
}

/**
* Constructs a scalar VariableMatrix from a nested list of Variables.
*
Expand Down
Loading

0 comments on commit 2fbe99a

Please sign in to comment.