Refactor adjoint calculation (#702)
calcmogul authored Jan 20, 2025
1 parent 6439a20 commit c88b3fa
Showing 8 changed files with 63 additions and 58 deletions.
10 changes: 5 additions & 5 deletions cart-pole-scalability-results-sleipnir.csv
@@ -1,6 +1,6 @@
 Samples,Setup time (ms),Solve time (ms)
-100,1.69,469.773
-150,1.297,1432.54
-200,1.813,2161.16
-250,2.29,2545.94
-300,2.815,3945
+100,1.515,411.42
+150,0.999,1235.13
+200,1.346,1967.54
+250,1.713,2342.6
+300,2.139,3613.15
Binary file modified cart-pole-scalability-results.png
28 changes: 14 additions & 14 deletions flywheel-scalability-results-sleipnir.csv
@@ -1,15 +1,15 @@
 Samples,Setup time (ms),Solve time (ms)
-100,0.353,1.989
-200,1.093,5.297
-300,0.178,6.186
-400,0.237,8.388
-500,0.296,10.557
-600,0.355,15.161
-700,0.423,15.865
-800,0.533,20.869
-900,0.504,19.665
-1000,0.545,23.423
-2000,1.096,50.603
-3000,1.644,87.56
-4000,2.279,113.649
-5000,2.838,156.224
+100,0.224,1.927
+200,0.082,6.263
+300,0.129,5.932
+400,0.168,8.273
+500,0.217,9.847
+600,0.256,14.863
+700,0.305,17.475
+800,0.404,18.325
+900,0.43,21.534
+1000,0.439,23.686
+2000,0.919,50.931
+3000,1.375,85.562
+4000,1.814,116.722
+5000,2.329,157.785
Binary file modified flywheel-scalability-results.png
54 changes: 27 additions & 27 deletions include/sleipnir/autodiff/ExpressionGraph.hpp
@@ -36,21 +36,17 @@ class ExpressionGraph {
     //
     // https://en.wikipedia.org/wiki/Breadth-first_search
 
-    // BFS list sorted from parent to child.
     small_vector<Expression*> stack;
 
-    stack.emplace_back(root.expr.Get());
-
     // Initialize the number of instances of each node in the tree
     // (Expression::duplications)
+    stack.emplace_back(root.expr.Get());
     while (!stack.empty()) {
       auto node = stack.back();
       stack.pop_back();
 
       for (auto& arg : node->args) {
-        // Only continue if the node is not a constant and hasn't already been
-        // explored.
-        if (arg != nullptr && arg->Type() != ExpressionType::kConstant) {
+        if (arg != nullptr) {
           // If this is the first instance of the node encountered (it hasn't
           // been explored yet), add it to stack so it's recursed upon
           if (arg->duplications == 0) {
@@ -61,13 +57,12 @@
       }
     }
 
+    // Generate BFS lists sorted from parent to child
     stack.emplace_back(root.expr.Get());
-
     while (!stack.empty()) {
       auto node = stack.back();
       stack.pop_back();
 
-      // BFS lists sorted from parent to child.
       m_rowList.emplace_back(node->row);
       m_adjointList.emplace_back(node);
       if (node->args[0] != nullptr) {
@@ -77,9 +72,7 @@
       }
 
       for (auto& arg : node->args) {
-        // Only add node if it's not a constant and doesn't already exist in the
-        // tape.
-        if (arg != nullptr && arg->Type() != ExpressionType::kConstant) {
+        if (arg != nullptr) {
           // Once the number of node visitations equals the number of
           // duplications (the counter hits zero), add it to the stack. Note
           // that this means the node is only enqueued once.
@@ -122,11 +115,13 @@
     // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation
     // for background on reverse accumulation automatic differentiation.
 
-    // Zero adjoints. The root node's adjoint is 1.0 as df/df is always 1.
-    if (m_adjointList.size() > 0) {
-      m_adjointList[0]->adjointExpr = MakeExpressionPtr<ConstExpression>(1.0);
+    if (m_adjointList.empty()) {
+      return VariableMatrix(wrt.size(), 1);
     }
 
+    // Set root node's adjoint to 1 since df/df is 1
+    m_adjointList[0]->adjointExpr = MakeExpressionPtr<ConstExpression>(1.0);
+
     // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
     // multiplied by dy/dx. If there are multiple "paths" from the root node to
     // variable; the variable's adjoint is the sum of each path's adjoint
@@ -145,6 +140,7 @@
       }
     }
 
+    // Move gradient tree to return value
     VariableMatrix grad(VariableMatrix::empty, wrt.size(), 1);
     for (int row = 0; row < grad.Rows(); ++row) {
       grad(row) = Variable{std::move(wrt(row).expr->adjointExpr)};
@@ -154,31 +150,30 @@
     // parent expressions. This ensures all expressions are returned to the free
     // list.
     for (auto& node : m_adjointList) {
-      for (auto& arg : node->args) {
-        if (arg != nullptr) {
-          arg->adjointExpr = nullptr;
-        }
-      }
       node->adjointExpr = nullptr;
     }
 
     return grad;
   }
 
   /**
-   * Updates the adjoints in the expression graph, effectively computing the
-   * gradient.
+   * Updates the adjoints in the expression graph (computes the gradient) then
+   * appends the adjoints of wrt to the sparse matrix triplets via a callback.
    *
    * @param func A function that takes two arguments: an int for the gradient
    *             row, and a double for the adjoint (gradient value).
    */
-  void ComputeAdjoints(function_ref<void(int row, double adjoint)> func) {
-    // Zero adjoints. The root node's adjoint is 1.0 as df/df is always 1.
-    m_adjointList[0]->adjoint = 1.0;
-    for (auto& node : m_adjointList | std::views::drop(1)) {
-      node->adjoint = 0.0;
+  void AppendAdjointTriplets(function_ref<void(int row, double adjoint)> func) {
+    // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation
+    // for background on reverse accumulation automatic differentiation.
+
+    if (m_adjointList.empty()) {
+      return;
     }
 
+    // Set root node's adjoint to 1 since df/df is 1
+    m_adjointList[0]->adjoint = 1.0;
+
     // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
     // multiplied by dy/dx. If there are multiple "paths" from the root node to
     // variable; the variable's adjoint is the sum of each path's adjoint
@@ -200,12 +195,17 @@
         }
       }
 
-      // If variable is a leaf node, assign its adjoint to the gradient.
+      // Append adjoints of wrt to sparse matrix triplets
       int row = m_rowList[col];
       if (row != -1) {
        func(row, node->adjoint);
       }
     }
+
+    // Zero adjoints for next run
+    for (auto& node : m_adjointList) {
+      node->adjoint = 0.0;
+    }
   }
 
  private:
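For background, the ExpressionGraph changes above implement reverse accumulation over a tape sorted from parent to child: the root's adjoint is seeded to 1 (since df/df = 1), and each child's adjoint accumulates (parent adjoint) × (partial derivative of the parent with respect to that child), which naturally sums the contributions of multiple paths to the same variable. The following minimal, self-contained C++ sketch illustrates that scheme; the Node struct and the hand-built tape for f(x, y) = x*y + x are illustrative assumptions, not Sleipnir's actual types.

// Minimal reverse-accumulation sketch over a hand-built tape for
// f(x, y) = x*y + x. Illustrative only; not Sleipnir's Expression type.
#include <cstdio>
#include <utility>
#include <vector>

struct Node {
  double value = 0.0;
  double adjoint = 0.0;
  // Each child paired with d(this node)/d(child).
  std::vector<std::pair<Node*, double>> args;
};

int main() {
  Node x{3.0};
  Node y{2.0};

  Node mul{x.value * y.value};
  mul.args = {{&x, y.value}, {&y, x.value}};  // d(x*y)/dx = y, d(x*y)/dy = x

  Node add{mul.value + x.value};
  add.args = {{&mul, 1.0}, {&x, 1.0}};  // d(u + x)/du = 1, d(u + x)/dx = 1

  // Tape sorted parent to child (root first), like m_adjointList.
  std::vector<Node*> tape{&add, &mul, &x, &y};

  // Set root node's adjoint to 1 since df/df is 1.
  tape[0]->adjoint = 1.0;

  // Propagate adjoints; x accumulates from two paths (via mul and via add).
  for (Node* node : tape) {
    for (auto& [child, partial] : node->args) {
      child->adjoint += node->adjoint * partial;
    }
  }

  std::printf("df/dx = %g\n", x.adjoint);  // y + 1 = 3
  std::printf("df/dy = %g\n", y.adjoint);  // x = 3
}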
1 change: 0 additions & 1 deletion include/sleipnir/autodiff/Hessian.hpp
@@ -2,7 +2,6 @@
 
 #pragma once
 
-#include <Eigen/Core>
 #include <Eigen/SparseCore>
 
 #include "sleipnir/autodiff/ExpressionGraph.hpp"
21 changes: 13 additions & 8 deletions include/sleipnir/autodiff/Jacobian.hpp
@@ -48,7 +48,7 @@ class SLEIPNIR_DLLEXPORT Jacobian {
         // If the row is linear, compute its gradient once here and cache its
         // triplets. Constant rows are ignored because their gradients have no
         // nonzero triplets.
-        m_graphs[row].ComputeAdjoints([&](int col, double adjoint) {
+        m_graphs[row].AppendAdjointTriplets([&](int col, double adjoint) {
           m_cachedTriplets.emplace_back(row, col, adjoint);
         });
       } else if (m_variables(row).Type() > ExpressionType::kLinear) {
@@ -76,17 +76,16 @@
    * them.
    */
   VariableMatrix Get() const {
-    VariableMatrix result{m_variables.Rows(), m_wrt.Rows()};
+    VariableMatrix result{VariableMatrix::empty, m_variables.Rows(),
+                          m_wrt.Rows()};
 
     for (int row = 0; row < m_variables.Rows(); ++row) {
-      for (auto& node : m_wrt) {
-        node.expr->adjointExpr = nullptr;
-      }
-
       auto grad = m_graphs[row].GenerateGradientTree(m_wrt);
       for (int col = 0; col < m_wrt.Rows(); ++col) {
         if (grad(col).expr != nullptr) {
           result(row, col) = std::move(grad(col));
+        } else {
+          result(row, col) = Variable{0.0};
         }
       }
     }
@@ -115,12 +114,18 @@
 
     // Compute each nonlinear row of the Jacobian
     for (int row : m_nonlinearRows) {
-      m_graphs[row].ComputeAdjoints([&](int col, double adjoint) {
+      m_graphs[row].AppendAdjointTriplets([&](int col, double adjoint) {
         triplets.emplace_back(row, col, adjoint);
       });
     }
 
-    m_J.setFromTriplets(triplets.begin(), triplets.end());
+    if (triplets.size() > 0) {
+      m_J.setFromTriplets(triplets.begin(), triplets.end());
+    } else {
+      // setFromTriplets() is a no-op on empty triplets, so explicitly zero out
+      // the storage
+      m_J.setZero();
+    }
 
     m_profiler.StopSolve();
 
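The empty-triplets guard added above matters because, per the commit's own comment, setFromTriplets() is a no-op on an empty range, which would otherwise leave stale values from a previous solve in m_J. A standalone sketch of the same pattern, assuming Eigen and arbitrary illustrative dimensions:

// Sketch of the empty-triplets guard used above (illustrative dimensions).
#include <Eigen/SparseCore>

#include <vector>

int main() {
  Eigen::SparseMatrix<double> J(3, 3);
  std::vector<Eigen::Triplet<double>> triplets;

  // ... AppendAdjointTriplets() would emplace (row, col, adjoint) here ...

  if (triplets.size() > 0) {
    J.setFromTriplets(triplets.begin(), triplets.end());
  } else {
    // setFromTriplets() is a no-op on empty triplets, so explicitly zero out
    // the storage to drop any stale nonzeros.
    J.setZero();
  }
}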
7 changes: 4 additions & 3 deletions jormungandr/cpp/Docstrings.hpp
@@ -2184,9 +2184,10 @@ expression's computational graph in a way that skips duplicates.)doc";
 
 static const char *__doc_sleipnir_detail_ExpressionGraph_2 = R"doc()doc";
 
-static const char *__doc_sleipnir_detail_ExpressionGraph_ComputeAdjoints =
-R"doc(Updates the adjoints in the expression graph, effectively computing
-the gradient.
+static const char *__doc_sleipnir_detail_ExpressionGraph_AppendAdjointTriplets =
+R"doc(Updates the adjoints in the expression graph (computes the gradient)
+then appends the adjoints of wrt to the sparse matrix triplets via a
+callback.
 
 Parameter ``func``:
     A function that takes two arguments: an int for the gradient row,
