Skip to content

Commit

Permalink
bug fix Pareto UB-subtract, expose UB to python, perfect tree example
Browse files Browse the repository at this point in the history
  • Loading branch information
kjgm committed Nov 21, 2023
1 parent c001dc3 commit 9bec94d
Show file tree
Hide file tree
Showing 15 changed files with 141 additions and 25 deletions.
31 changes: 31 additions & 0 deletions examples/perfect_tree_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pystreed import STreeDClassifier
import pandas as pd
from warnings import simplefilter
simplefilter(action='ignore', category=UserWarning)

# Read data

df = pd.read_csv("data/classification/segment.csv", sep=" ", header=None)
X = df[df.columns[1:]].values
y = df[0].values

# Fit the model
for max_depth in range(0, 5):
model = STreeDClassifier(max_depth = max_depth, upper_bound=0, time_limit=100, verbose=False)
model.fit(X,y)
if model.is_fitted(): break
print(f"No perfect tree for d = {max_depth}")


for max_num_nodes in range(0, 2**max_depth):
model = STreeDClassifier(max_depth = max_depth, max_num_nodes=max_num_nodes, upper_bound=0, time_limit=100, verbose=False)
model.fit(X,y)
if model.is_fitted(): break
print(f"No perfect tree for d = {max_depth}, n = {max_num_nodes}")


print(f"\nSmallest perfect tree with d = {max_depth}, n = {max_num_nodes}")

print("")

model.print_tree()
3 changes: 3 additions & 0 deletions include/model/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ namespace STreeD {

// Get the solutions in this container
inline const std::vector<Node<OT>>& GetSolutions() const { return solutions; }

// Get the ith solution in this container
inline const Node<OT>& Get(size_t ix) const { return solutions[ix]; }

// Get a mutable reference to the solution at index ix
inline Node<OT>& GetMutable(size_t ix) { return solutions[ix]; }
Expand Down
4 changes: 4 additions & 0 deletions include/tasks/eq_opp.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ namespace STreeD {
};
}

inline static std::vector<EqOppSol> ExtremePoints() {
return { {INT32_MAX, 0, 0}, {0, 1, 0}, {0, 0, 1} };
}

inline bool SatisfiesConstraint(const Node<EqOpp>& sol, const BranchContext& context) {
double disc = std::max(sol.solution.group0_score, sol.solution.group1_score) - 1;
return disc <= discrimination_limit;
Expand Down
4 changes: 4 additions & 0 deletions include/tasks/f1score.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ namespace STreeD {
inline static void MergeInv(const F1ScoreSol& s1, const F1ScoreSol& s2, F1ScoreSol& out) {
out = { std::max(s1.false_negatives, s2.false_negatives), std::max(s1.false_positives, s2.false_positives) };
}

inline static std::vector<F1ScoreSol> ExtremePoints() {
return { {0, INT32_MAX}, {INT32_MAX, 0} };
}
};

}
Expand Down
4 changes: 4 additions & 0 deletions include/tasks/group_fairness.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ namespace STreeD {
};
}

inline static std::vector<GroupFairnessSol> ExtremePoints() {
return { {INT32_MAX, 0, 0}, {0, 1, 0}, {0, 0, 1} };
}

inline bool SatisfiesConstraint(const Node<GroupFairness>& sol, const BranchContext& context) {
double disc = std::max(sol.solution.group0_score, sol.solution.group1_score) - 1;
return disc <= discrimination_limit;
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ target-version = "py38"

[project]
name = "pystreed"
version = "1.0.0"
version = "1.1.0"
description = "Python Implementation of STreeD: Dynamic Programming Approach for Optimal Decision Trees with Separable objectives and Constraints"
license= {file = "LICENSE"}
authors = [
Expand Down
17 changes: 14 additions & 3 deletions pystreed/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class BaseSTreeDSolver(BaseEstimator):
"time_limit": [Interval(numbers.Real, 0, None, closed="neither")],
"cost_complexity": [Interval(numbers.Real, 0, 1, closed="both")],
"feature_ordering": [StrOptions({"in-order", "gini"})],
"upper_bound": [Interval(numbers.Real, 0, None, closed="left")],
"random_seed": [Interval(numbers.Integral, -1, None, closed="left")],
"n_thresholds": [Interval(numbers.Integral, 1, None, closed="left")],
"n_categories": [Interval(numbers.Integral, 2, None, closed="left")],
Expand All @@ -45,6 +46,7 @@ def __init__(self,
use_similarity_lower_bound: bool = True,
use_upper_bound: bool = True,
use_lower_bound: bool = True,
upper_bound: float = 2**31-1,
verbose: bool = False,
random_seed: int = 27,
continuous_binarize_strategy: str = 'quantile',
Expand All @@ -68,6 +70,7 @@ def __init__(self,
use_similarity_lower_bound: Enable/Disable the similarity lower bound (Enabled typically results in a large runtime advantage)
use_upper_bound: Enable/Disable the use of upper bounds (Enabled is typically faster)
use_lower_bound: Enable/Disable the use of lower bounds (Enabled is typically faster)
upper_bound: Search for a tree better than the provided upper bound
verbose: Enable/Disable verbose output
random_seed: the random seed used by the solver (for example when creating folds)
continuous_binarization_strategy: the strategy used for binarizing continuous features
Expand All @@ -89,6 +92,7 @@ def __init__(self,
self.use_similarity_lower_bound = use_similarity_lower_bound
self.use_upper_bound = use_upper_bound
self.use_lower_bound = use_lower_bound
self.upper_bound = upper_bound
self.verbose = verbose
self.random_seed = random_seed
self.continuous_binarize_strategy = continuous_binarize_strategy
Expand Down Expand Up @@ -130,6 +134,7 @@ def _initialize_param_handler(self):
self._params.use_similarity_lower_bound = self.use_similarity_lower_bound
self._params.use_upper_bound = self.use_upper_bound
self._params.use_lower_bound = self.use_lower_bound
self._params.upper_bound = self.upper_bound

def get_solver_params(self):
return self._solver._get_parameters()
Expand Down Expand Up @@ -258,17 +263,23 @@ def fit(self, X, y, extra_data=None, categorical=None):

if duration > self.time_limit:
warnings.warn("Fitting exceeds time limit.", stacklevel=2)
if not self.fit_result.is_feasible():
warnings.warn("No feasible tree found.", stacklevel=2)
delattr(self, "fit_result")
else:
self.tree_ = self._solver._get_tree(self.fit_result)

self.tree_ = self._solver._get_tree(self.fit_result)

if self.verbose:
if self.is_fitted() and self.verbose:
print("Training score: ", self.fit_result.score())
print("Tree depth: ", self.fit_result.tree_depth(), " \tBranching nodes: ", self.fit_result.tree_nodes())
if not self.fit_result.is_optimal():
print("No proof of optimality!")

return self

def is_fitted(self):
return hasattr(self, "fit_result")

def predict(self, X, extra_data=None):
"""
Predicts the target variable for the given input feature data.
Expand Down
6 changes: 6 additions & 0 deletions pystreed/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pystreed.utils import _color_brew
import numpy as np
import numbers
import warnings

class STreeDClassifier(BaseSTreeDSolver):
"""
Expand All @@ -26,6 +27,7 @@ def __init__(self,
use_similarity_lower_bound: bool = True,
use_upper_bound: bool = True,
use_lower_bound: bool = True,
upper_bound: float = 2**31 -1,
verbose : bool = False,
random_seed: int = 27,
continuous_binarize_strategy: str = 'quantile',
Expand All @@ -50,6 +52,7 @@ def __init__(self,
use_similarity_lower_bound: Enable/Disable the similarity lower bound (Enabled typically results in a large runtime advantage)
use_upper_bound: Enable/Disable the use of upper bounds (Enabled is typically faster)
use_lower_bound: Enable/Disable the use of lower bounds (Enabled is typically faster)
upper_bound: Search for a tree better than the provided upper bound
verbose: Enable/Disable verbose output
random_seed: the random seed used by the solver (for example when creating folds)
continuous_binarization_strategy: the strategy used for binarizing continuous features
Expand All @@ -72,11 +75,14 @@ def __init__(self,
use_similarity_lower_bound=use_similarity_lower_bound,
use_upper_bound=use_upper_bound,
use_lower_bound=use_lower_bound,
upper_bound=upper_bound,
verbose=verbose,
random_seed=random_seed,
continuous_binarize_strategy=continuous_binarize_strategy,
n_thresholds=n_thresholds,
n_categories=n_categories)
if optimization_task == "f1-score" and upper_bound != 2**31-1:
warnings.warn(f"upper_bound parameter is ignored for f1-score", stacklevel=2)

def _initialize_param_handler(self):
super()._initialize_param_handler()
Expand Down
3 changes: 3 additions & 0 deletions pystreed/cost_sensitive_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self,
use_terminal_solver: bool = True,
use_upper_bound: bool = True,
use_lower_bound: bool = True,
upper_bound: float = 2**31-1,
verbose : bool = False,
random_seed: int = 27):
"""
Expand All @@ -33,6 +34,7 @@ def __init__(self,
use_terminal_solver: Enable/Disable the depth-two solver (Enabled typically results in a large runtime advantage)
use_upper_bound: Enable/Disable the use of upper bounds (Enabled is typically faster)
use_lower_bound: Enable/Disable the use of lower bounds (Enabled is typically faster)
upper_bound: Search for a tree better than the provided upper bound
verbose: Enable/Disable verbose output
random_seed: the random seed used by the solver (for example when creating folds)
"""
Expand All @@ -50,6 +52,7 @@ def __init__(self,
use_similarity_lower_bound=False,
use_upper_bound=use_upper_bound,
use_lower_bound=use_lower_bound,
upper_bound=upper_bound,
verbose=verbose,
random_seed=random_seed)

Expand Down
3 changes: 3 additions & 0 deletions pystreed/prescriptive_policy_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(self,
use_similarity_lower_bound: bool = True,
use_upper_bound: bool = True,
use_lower_bound: bool = True,
upper_bound: float = 2**31 - 1,
verbose : bool = False,
random_seed : int = 27,
continuous_binarize_strategy: str = 'quantile',
Expand All @@ -47,6 +48,7 @@ def __init__(self,
use_similarity_lower_bound: Enable/Disable the similarity lower bound (Enabled typically results in a large runtime advantage)
use_upper_bound: Enable/Disable the use of upper bounds (Enabled is typically faster)
use_lower_bound: Enable/Disable the use of lower bounds (Enabled is typically faster)
upper_bound: Search for a tree better than the provided upper bound
verbose: Enable/Disable verbose output
random_seed: the random seed used by the solver (for example when creating folds)
continuous_binarization_strategy: the strategy used for binarizing continuous features
Expand All @@ -67,6 +69,7 @@ def __init__(self,
use_similarity_lower_bound=use_similarity_lower_bound,
use_upper_bound=use_upper_bound,
use_lower_bound=use_lower_bound,
upper_bound=upper_bound,
verbose=verbose,
random_seed=random_seed,
continuous_binarize_strategy=continuous_binarize_strategy,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# Define package metadata
package_name = 'pystreed'
extension_name = 'cstreed'
__version__ = "0.0.1"
__version__ = "1.1.0"

ext_modules = [
Pybind11Extension(package_name + '.' + extension_name,
Expand Down
9 changes: 6 additions & 3 deletions src/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ void NumpyToSTreeDData(const py::array_t<int, py::array::c_style>& _X,
AData& data, ADataView& data_view) {
const bool regression = std::is_same<LT, double>::value;
std::vector<const AInstance*> instances;
auto X = _X.template unchecked<2>(); // Template keyword because of a bug in the clang compiler
auto y = _y.template unchecked<1>(); // Template keyword because of a bug in the clang compiler
auto X = _X.template unchecked<2>();
auto y = _y.template unchecked<1>();
const int num_instances = int(X.shape(0));
const int num_features = int(X.shape(1));

Expand Down Expand Up @@ -74,7 +74,7 @@ void NumpyToSTreeDData(const py::array_t<int, py::array::c_style>& _X,
}

std::vector<bool> NumpyRowToBoolVector(const py::array_t<int, py::array::c_style>& _X) {
auto X = _X.template unchecked<1>(); // Template keyword because of a bug in the clang compiler
auto X = _X.template unchecked<1>();
std::vector<bool> v(X.shape(0));
for (py::size_t j = 0; j < X.shape(0); j++) {
v[j] = X(j);
Expand Down Expand Up @@ -192,6 +192,8 @@ PYBIND11_MODULE(cstreed, m) {
************************************/
py::class_<SolverResult, std::shared_ptr<SolverResult>> solver_result(m, "SolverResult");

solver_result.def("is_feasible", &SolverResult::IsFeasible);

solver_result.def("is_optimal", [](const SolverResult &solver_result) {
py::scoped_ostream_redirect stream(std::cout, py::module_::import("sys").attr("stdout"));
return solver_result.IsProvenOptimal();
Expand Down Expand Up @@ -240,6 +242,7 @@ PYBIND11_MODULE(cstreed, m) {
ExposeBooleanProperty(parameter_handler, "use-similarity-lower-bound", "use_similarity_lower_bound");
ExposeBooleanProperty(parameter_handler, "use-upper-bound", "use_upper_bound");
ExposeBooleanProperty(parameter_handler, "use-lower-bound", "use_lower_bound");
ExposeFloatProperty(parameter_handler, "upper-bound", "upper_bound");
ExposeStringProperty(parameter_handler, "ppg-teacher-method", "ppg_teacher_method");
ExposeFloatProperty(parameter_handler, "discrimination-limit", "discrimination_limit");

Expand Down
2 changes: 1 addition & 1 deletion src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ int main(int argc, char* argv[]) {
std::cout << std::setprecision (std::numeric_limits<double>::digits10 + 1) << test_score->score << " \t";
std::cout << test_score->average_path_length << std::endl;

std::cout << "Tree " << i << ": " << result->tree_strings[i];
std::cout << "Tree " << i << ": " << result->tree_strings[i] << std::endl;

}
} else {
Expand Down
9 changes: 9 additions & 0 deletions src/solver/define_parameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,15 @@ namespace STreeD {
"Algorithmic Parameters"
);

parameters.DefineFloatParameter(
"upper-bound",
"Search for a tree better than the provided upper bound (numeric).",
INT32_MAX, // default
"Algorithmic Parameters",
0.0, // min
DBL_MAX // max
);

parameters.DefineStringParameter
(
"feature-ordering",
Expand Down
67 changes: 51 additions & 16 deletions src/solver/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ namespace STreeD {
auto result = InitializeSol<OT>();
if (CheckEmptySol<OT>(global_UB)) {
global_UB = InitializeSol<OT>();
// If an upper bound is provided, and the objective is numeric, add it to the UB
if constexpr (std::is_same<Solver<OT>::SolType, double>::value || std::is_same<Solver<OT>::SolType, int>::value) {
AddSol<OT>(global_UB, Node<OT>(parameters.GetFloatParameter("upper-bound")));
}
result = SolveLeafNode(train_data, root_context, global_UB);
}

Expand All @@ -77,12 +81,14 @@ namespace STreeD {
auto solver_result = std::make_shared<SolverTaskResult<OT>>();
solver_result->is_proven_optimal = stopwatch.IsWithinTimeLimit();
if constexpr (OT::total_order) {
clock_t clock_start = clock();
auto tree = ConstructOptimalTree(result, train_data, root_context, int(parameters.GetIntegerParameter("max-depth")), result.NumNodes());
stats.time_reconstructing += double(clock() - clock_start) / CLOCKS_PER_SEC;
auto score = InternalTrainScore<OT>::ComputeTrainPerformance(&data_splitter, task, tree.get(), train_data);
tree->FlipFlippedFeatures(flipped_features);
solver_result->AddSolution(tree, score);
if (result.IsFeasible()) {
clock_t clock_start = clock();
auto tree = ConstructOptimalTree(result, train_data, root_context, int(parameters.GetIntegerParameter("max-depth")), result.NumNodes());
stats.time_reconstructing += double(clock() - clock_start) / CLOCKS_PER_SEC;
auto score = InternalTrainScore<OT>::ComputeTrainPerformance(&data_splitter, task, tree.get(), train_data);
tree->FlipFlippedFeatures(flipped_features);
solver_result->AddSolution(tree, score);
}
} else {
for (auto& s : result->GetSolutions()) {
clock_t clock_start = clock();
Expand Down Expand Up @@ -635,19 +641,48 @@ namespace STreeD {
sols_ptr = &small_sols;
}


// For each solution, substract it from the current UB
for (size_t i = 0; i < (*UB_ptr)->Size(); i++) {
Solver<OT>::SolContainer corner_union = InitializeSol<OT>();
Solver<OT>::SolContainer sub_union = InitializeSol<OT>();
auto extreme_points = OT::ExtremePoints();
for (size_t j = 0; j < (*sols_ptr)->Size(); j++) {
Solver<OT>::SolContainer sub_ub = InitializeSol<OT>();
for (size_t j = 0; j < (*sols_ptr)->Size(); j++) {
OT::Subtract((*UB_ptr)->GetSolutions()[i].solution, (*sols_ptr)->GetSolutions()[j].solution, diffsol.solution);
// For each current UB-solution - other branch solution, keep only the non-dominated solutions
// If exceeding the size, Merge two solutions such that the combined solution is worse than both
sub_ub->AddOrInvMerge(diffsol, MAX_SIZE);
for (size_t i = 0; i < (*UB_ptr)->Size(); i++) {
OT::Subtract((*UB_ptr)->Get(i).solution, (*sols_ptr)->Get(j).solution, diffsol.solution);
//sub_ub->AddOrInvMerge(diffsol, MAX_SIZE);
sub_ub->Add(diffsol);
}
// For each set of solutions, add to the new UB and use reverse-nondom to decided what to keep
updatedUB->AddInvOrInvMerge(*(sub_ub.get()), MAX_SIZE);
for (auto& ep : extreme_points) {
sub_ub->Add(Node<OT>(ep));
}

// Compute 'staircase corners'
Solver<OT>::SolContainer corners = InitializeSol<OT>();
for (size_t i = 0; i < sub_ub->Size(); i++) {
for (size_t k = i + 1; k < sub_ub->Size(); k++) {
OT::MergeInv(sub_ub->Get(i).solution, sub_ub->Get(k).solution, diffsol.solution);
//corners->AddOrInvMerge(diffsol, MAX_SIZE);
corners->Add(diffsol);
}
}
//corner_union->AddInvOrInvMerge(*(corners.get()), MAX_SIZE);
//sub_union->AddInvOrInvMerge(*(sub_ub.get()), MAX_SIZE);

corner_union->AddInv(*(corners.get()));
sub_union->AddInv(*(sub_ub.get()));
}
// Compute 'staircase corners'
for (auto& ep : extreme_points) {
corner_union->Add(Node<OT>(ep));
}
for (size_t i = 0; i < corner_union->Size(); i++) {
for (size_t k = i + 1; k < corner_union->Size(); k++) {
OT::Merge(corner_union->Get(i).solution, corner_union->Get(k).solution, diffsol.solution);
updatedUB->AddInvOrInvMerge(diffsol, MAX_SIZE);
}
}
updatedUB->AddInvOrInvMerge(*(sub_union.get()), MAX_SIZE);
//updatedUB->AddOrInvMerge(*(UB_ptr->get()), MAX_SIZE);

}
// In the root nod of the search, feasible solutions can be relaxed by removing information that is related
// to constraint satisfaction from the solution
Expand Down

0 comments on commit 9bec94d

Please sign in to comment.