Reproducible summation.
RAMitchell committed Aug 3, 2023
1 parent c876af1 commit b030bfe
Showing 3 changed files with 67 additions and 6 deletions.
11 changes: 11 additions & 0 deletions legateboost/legateboost.py
@@ -64,6 +64,16 @@ class TreeStructure(_PickleCunumericMixin):
    gain: cn.ndarray
    hessian: cn.ndarray

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TreeStructure):
            return NotImplemented
        eq = [cn.all(self.leaf_value == other.leaf_value)]
        eq.append(cn.all(self.feature == other.feature))
        eq.append(cn.all(self.split_value == other.split_value))
        eq.append(cn.all(self.gain == other.gain))
        eq.append(cn.all(self.hessian == other.hessian))
        return all(eq)

    def is_leaf(self, id: int) -> Any:
        return self.feature[id] == -1

@@ -397,6 +407,7 @@ def _partial_fit(
        g, h = self._objective_instance.gradient(
            y, self._objective_instance.transform(pred)
        )

        assert g.ndim == h.ndim == 2
        assert g.dtype == h.dtype == cn.float64, "g.dtype={}, h.dtype={}".format(
            g.dtype, h.dtype
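
The `__eq__` added to `TreeStructure` above makes two trees compare equal exactly when every stored array matches element for element; the new determinism tests further down rely on this when they compare the fitted trees in `models_` across repeated runs. A minimal sketch of that usage pattern, assuming two already-fitted estimators (the helper name `ensembles_equal` is made up for illustration and is not part of this commit):

def ensembles_equal(model_a, model_b) -> bool:
    # models_ is the list of fitted TreeStructure objects on an estimator,
    # as used by the tests below.
    if len(model_a.models_) != len(model_b.models_):
        return False
    # TreeStructure.__eq__ reduces its per-array comparisons with all(),
    # so each element of the generator below is a plain bool.
    return all(a == b for a, b in zip(model_a.models_, model_b.models_))
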
25 changes: 25 additions & 0 deletions legateboost/objectives.py
@@ -6,6 +6,28 @@
from .metrics import BaseMetric, ExponentialMetric, LogLossMetric, MSEMetric


def preround(func):
    """Apply this decorator to the gradient method of an objective to ensure
    reproducible floating-point summation.

    Uses Algorithm 5 (Reproducible Sequential Sum) from 'Fast Reproducible
    Floating-Point Summation' by Demmel and Nguyen.
    """

    def round(x):
        # Snap every element onto a common grid so that later sums are exact
        # and therefore independent of the order in which elements are added.
        m = cn.max(cn.abs(x))
        n = x.size
        delta = cn.floor(n * m / (1 - 2 * n * cn.finfo(x.dtype).eps))
        M = 2 ** cn.ceil(cn.log2(delta))
        return (x + M) - M

    def inner(self, y: cn.ndarray, pred: cn.ndarray) -> Tuple[cn.ndarray, cn.ndarray]:
        g, h = func(self, y, pred)
        return round(g), round(h)

    return inner


class BaseObjective(ABC):
"""The base class for objective functions.
@@ -73,6 +95,7 @@ class SquaredErrorObjective(BaseObjective):
    :class:`legateboost.metrics.MSEMetric`
    """

    @preround
    def gradient(
        self, y: cn.ndarray, pred: cn.ndarray
    ) -> Tuple[cn.ndarray, cn.ndarray]:
@@ -100,6 +123,7 @@ class LogLossObjective(BaseObjective):
    :class:`legateboost.metrics.LogLossMetric`
    """

    @preround
    def gradient(
        self, y: cn.ndarray, pred: cn.ndarray
    ) -> Tuple[cn.ndarray, cn.ndarray]:
@@ -153,6 +177,7 @@ class ExponentialObjective(BaseObjective):
    [1] Hastie, Trevor, et al. "Multi-class adaboost." Statistics and its Interface 2.3 (2009): 349-360.
    """  # noqa: E501

    @preround
    def gradient(self, y: cn.ndarray, pred: cn.ndarray) -> cn.ndarray:
        assert pred.ndim == 2

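
The `preround` decorator above is the core of the change: gradients and hessians are snapped onto a common grid before they are ever summed, so the totals no longer depend on the order in which elements are added. A self-contained sketch of the same trick, with numpy standing in for cunumeric and `preround_array` a made-up name used only for illustration:

import numpy as np


def preround_array(x: np.ndarray) -> np.ndarray:
    # The same rounding step as the decorator (Demmel & Nguyen, Algorithm 5).
    m = np.max(np.abs(x))
    n = x.size
    delta = np.floor(n * m / (1 - 2 * n * np.finfo(x.dtype).eps))
    M = 2.0 ** np.ceil(np.log2(delta))
    return (x + M) - M


rng = np.random.default_rng(0)
g = rng.standard_normal(10_000)

r = preround_array(g)
# After rounding, every partial sum is exactly representable, so sequential,
# reversed, and pairwise reductions all agree bit for bit.
assert np.sum(r) == np.sum(r[::-1]) == float(sum(r.tolist()))

The trade-off is a small loss of precision in each individual gradient value (everything is rounded to a spacing on the order of M times machine epsilon), in exchange for training runs that can be compared tree by tree with `==`, as the tests below do.
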
37 changes: 31 additions & 6 deletions legateboost/test/test_estimator.py
@@ -61,16 +61,22 @@ def test_regressor_weights(num_outputs):
assert loss[-1] < 1e-5


def test_regressor_determinism():
X = cn.random.random((100, 10))
y = cn.random.random(X.shape[0])
@pytest.mark.parametrize("num_outputs", [1, 5])
def test_regressor_determinism(num_outputs):
X = cn.random.random((10000, 10))
y = cn.random.random((X.shape[0], num_outputs))
preds = []
params = {"max_depth": 12, "random_state": 84, "objective": "squared_error"}
preds = []
models = []
for _ in range(0, 10):
model = lb.LBRegressor(n_estimators=2, random_state=83).fit(X, y)
model = lb.LBRegressor(n_estimators=10, **params).fit(X, y)
models.append(model)
p = model.predict(X)
if preds:
assert cn.all(p == preds[-1])
preds.append(model.predict(X))
assert cn.allclose(p, preds[-1]), cn.max(cn.abs(p - preds[-1]))
if models:
assert cn.all([a == b for a, b in zip(models[0].models_, model.models_)])


def test_regressor_vs_sklearn():
@@ -168,3 +174,22 @@ def test_classifier_improving_with_depth(num_class, objective):
loss = next(iter(eval_result["train"].values()))
metrics.append(loss[-1])
assert utils.non_increasing(metrics)


@pytest.mark.parametrize("num_class", [2, 5])
@pytest.mark.parametrize("objective", ["log_loss", "exp"])
def test_classifier_determinism(num_class, objective):
np.random.seed(3)
X = cn.random.random((10000, 20))
y = cn.random.randint(0, num_class, X.shape[0])
params = {"max_depth": 12, "random_state": 84, "objective": objective}
preds = []
models = []
for _ in range(0, 10):
model = lb.LBClassifier(n_estimators=10, **params).fit(X, y)
models.append(model)
p = model.predict_proba(X)
if preds:
assert cn.allclose(p, preds[-1]), cn.max(cn.abs(p - preds[-1]))
if models:
assert cn.all([a == b for a, b in zip(models[0].models_, model.models_)])
