Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Discrete Error #66

Merged
merged 12 commits into from
Nov 25, 2023
16 changes: 16 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,22 @@ namespace gtsam {
return error(values.discrete());
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> DecisionTreeFactor::error() const {
// Get all possible assignments
DiscreteKeys dkeys = discreteKeys();
// Reverse to make cartesian product output a more natural ordering.
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
const auto assignments = DiscreteValues::CartesianProduct(rdkeys);

// Construct vector with error values
std::vector<double> errors;
for (const auto& assignment : assignments) {
errors.push_back(error(assignment));
}
return AlgebraicDecisionTree<Key>(dkeys, errors);
}

/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
Expand Down
3 changes: 3 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,9 @@ namespace gtsam {
*/
double error(const HybridValues& values) const override;

/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> error() const override;

/// @}

private:
Expand Down
15 changes: 10 additions & 5 deletions gtsam/discrete/DiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@

#pragma once

#include <gtsam/base/Testable.h>
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/inference/Factor.h>
#include <gtsam/base/Testable.h>

#include <string>
namespace gtsam {
Expand All @@ -35,7 +36,7 @@ class HybridValues;
*
* @ingroup discrete
*/
class GTSAM_EXPORT DiscreteFactor: public Factor {
class GTSAM_EXPORT DiscreteFactor : public Factor {
public:
// typedefs needed to play nice with gtsam
typedef DiscreteFactor This; ///< This class
Expand Down Expand Up @@ -103,15 +104,19 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
*/
double error(const HybridValues& c) const override;

/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
/// Compute error for each assignment and return as a tree
virtual AlgebraicDecisionTree<Key> error() const = 0;

/// Multiply in a DecisionTreeFactor and return the result as
/// DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;

virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;

/// @}
/// @name Wrapper support
/// @{

/// Translation table from values to strings.
using Names = DiscreteValues::Names;

Expand Down Expand Up @@ -175,4 +180,4 @@ template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
std::vector<double> expNormalize(const std::vector<double> &logProbs);


}// namespace gtsam
} // namespace gtsam
5 changes: 5 additions & 0 deletions gtsam/discrete/TableFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ double TableFactor::error(const HybridValues& values) const {
return error(values.discrete());
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> TableFactor::error() const {
return toDecisionTreeFactor().error();
}

/* ************************************************************************ */
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
Expand Down
3 changes: 3 additions & 0 deletions gtsam/discrete/TableFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
*/
double error(const HybridValues& values) const override;

/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> error() const override;

/// @}
};

Expand Down
18 changes: 18 additions & 0 deletions gtsam/discrete/tests/testDecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,24 @@ TEST( DecisionTreeFactor, constructors)
EXPECT_DOUBLES_EQUAL(0.8, f4(x121), 1e-9);
}

/* ************************************************************************* */
TEST(DecisionTreeFactor, Error) {
// Declare a bunch of keys
DiscreteKey X(0,2), Y(1,3), Z(2,2);

// Create factors
DecisionTreeFactor f(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");

auto errors = f.error();
// regression
AlgebraicDecisionTree<Key> expected(
{X, Y, Z},
vector<double>{-0.69314718, -1.6094379, -1.0986123, -1.7917595,
-1.3862944, -1.9459101, -3.2188758, -4.0073332, -3.5553481,
-4.1743873, -3.8066625, -4.3174881});
EXPECT(assert_equal(expected, errors, 1e-6));
}

/* ************************************************************************* */
TEST(DecisionTreeFactor, multiplication) {
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
Expand Down
30 changes: 30 additions & 0 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,36 @@ HybridValues HybridBayesNet::sample() const {
return sample(&kRandomNumberGenerator);
}

/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::error(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> result(0.0);

// Iterate over each conditional.
for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, compute error for all assignments.
result = result + gm->error(continuousValues);

} else if (auto gc = conditional->asGaussian()) {
// If continuous, get the error and add it to the result
double error = gc->error(continuousValues);
// Add the computed error to every leaf of the result tree.
result = result.apply(
[error](double leaf_value) { return leaf_value + error; });

} else if (auto dc = conditional->asDiscrete()) {
// If discrete, add the discrete error in the right branch
result = result.apply(
[dc](const Assignment<Key> &assignment, double leaf_value) {
return leaf_value + dc->error(DiscreteValues(assignment));
});
}
}

return result;
}

/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
const VectorValues &continuousValues) const {
Expand Down
16 changes: 16 additions & 0 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,22 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;

/**
* @brief Error method using HybridValues which returns specific error for
* assignment.
*/
using Base::error;

/**
* @brief Compute log probability for each discrete assignment,
* and return as a tree.
*
* @param continuousValues Continuous values at which
* to compute the log probability.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> logProbability(
const VectorValues &continuousValues) const;

Expand Down
80 changes: 79 additions & 1 deletion gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,85 @@ const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph) {
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
}

/* ************************************************************************ */
void HybridGaussianFactorGraph::printErrors(
const HybridValues &values, const std::string &str,
const KeyFormatter &keyFormatter,
const std::function<bool(const Factor * /*factor*/,
double /*whitenedError*/, size_t /*index*/)>
&printCondition) const {
std::cout << str << "size: " << size() << std::endl << std::endl;

std::stringstream ss;

for (size_t i = 0; i < factors_.size(); i++) {
auto &&factor = factors_[i];
std::cout << "Factor " << i << ": ";

// Clear the stringstream
ss.str(std::string());

if (auto gmf = std::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
gmf->error(values.continuous()).print("", keyFormatter);
std::cout << std::endl;
}
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);

if (hc->isContinuous()) {
std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
} else if (hc->isDiscrete()) {
std::cout << "error = ";
hc->asDiscrete()->error().print("", keyFormatter);
std::cout << "\n";
} else {
// Is hybrid
std::cout << "error = ";
hc->asMixture()->error(values.continuous()).print();
std::cout << "\n";
}
}
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
const double errorValue = (factor != nullptr ? gf->error(values) : .0);
if (!printCondition(factor.get(), errorValue, i))
continue; // User-provided filter did not pass

if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << errorValue << "\n";
}
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
df->error().print("", keyFormatter);
}

} else {
continue;
}

std::cout << "\n";
}
std::cout.flush();
}

/* ************************************************************************ */
static GaussianFactorGraphTree addGaussian(
const GaussianFactorGraphTree &gfgTree,
Expand All @@ -96,7 +175,6 @@ static GaussianFactorGraphTree addGaussian(
// TODO(dellaert): it's probably more efficient to first collect the discrete
// keys, and then loop over all assignments to populate a vector.
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {

GaussianFactorGraphTree result;

for (auto &f : factors_) {
Expand Down
16 changes: 13 additions & 3 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
/// @{

// TODO(dellaert): customize print and equals.
// void print(const std::string& s = "HybridGaussianFactorGraph",
// const KeyFormatter& keyFormatter = DefaultKeyFormatter) const
// override;
// void print(
// const std::string& s = "HybridGaussianFactorGraph",
// const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;

void printErrors(
const HybridValues& values,
const std::string& str = "HybridGaussianFactorGraph: ",
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const std::function<bool(const Factor* /*factor*/,
double /*whitenedError*/, size_t /*index*/)>&
printCondition =
[](const Factor*, double, size_t) { return true; }) const;

// bool equals(const This& fg, double tol = 1e-9) const override;

/// @}
Expand Down
Loading
Loading