From 29f62dbb4eb465ce2f0474f6ea7ff892a3f79f8e Mon Sep 17 00:00:00 2001
From: Rory Mitchell
Date: Tue, 10 Oct 2023 13:58:13 +0000
Subject: [PATCH] Basic KRR implementation.

---
 legateboost/models/__init__.py           |  1 +
 legateboost/models/krr.py                | 78 ++++++++++++++++++++++++
 legateboost/models/linear.py             | 51 +---------------
 legateboost/objectives.py                |  1 -
 legateboost/test/models/test_krr.py      | 49 +++++++++++++++
 legateboost/test/test_estimator.py       |  8 ++-
 legateboost/test/test_with_hypothesis.py | 15 ++++-
 legateboost/test/utils.py                | 12 +++-
 legateboost/utils.py                     | 48 ++++++++++++++-
 9 files changed, 205 insertions(+), 58 deletions(-)
 create mode 100644 legateboost/models/krr.py
 create mode 100644 legateboost/test/models/test_krr.py

diff --git a/legateboost/models/__init__.py b/legateboost/models/__init__.py
index 9462d9aa..fc5369d3 100644
--- a/legateboost/models/__init__.py
+++ b/legateboost/models/__init__.py
@@ -1,3 +1,4 @@
 from .tree import Tree
 from .linear import Linear
+from .krr import KRR
 from .base_model import BaseModel
diff --git a/legateboost/models/krr.py b/legateboost/models/krr.py
new file mode 100644
index 00000000..c55a2f57
--- /dev/null
+++ b/legateboost/models/krr.py
@@ -0,0 +1,78 @@
+import cunumeric as cn
+
+from ..utils import solve_singular
+from .base_model import BaseModel
+
+
+def l2(X, Y):
+    XX = cn.einsum("ij,ij->i", X, X)[:, cn.newaxis]
+    YY = cn.einsum("ij,ij->i", Y, Y)
+    XY = 2 * cn.dot(X, Y.T)
+    return XX + YY - XY
+
+
+def rbf_kernel(X, Y, sigma=1.0):
+    K = l2(X, Y)
+    return cn.exp(-K / (2 * sigma**2))
+
+
+class KRR(BaseModel):
+    def __init__(self, n_components=10, alpha=1.0):
+        self.num_components = n_components
+        self.alpha = alpha
+
+    def _fit_components(self, X, g, h) -> "KRR":
+        # fit with fixed set of components
+        K = rbf_kernel(X, self.X_train)
+        num_outputs = g.shape[1]
+        self.bias_ = cn.zeros(num_outputs)
+        self.betas_ = cn.zeros((self.X_train.shape[0], num_outputs))
+
+        for k in range(num_outputs):
+            W = cn.sqrt(h[:, k])
+            Kw = K * W[:, cn.newaxis]
+            diag = cn.eye(Kw.shape[1]) * self.alpha
+            KtK = cn.dot(Kw.T, Kw) + diag
+            yw = W * (-g[:, k] / h[:, k])
+            self.betas_[:, k] = solve_singular(KtK, cn.dot(Kw.T, yw))
+        return self
+
+    def fit(
+        self,
+        X: cn.ndarray,
+        g: cn.ndarray,
+        h: cn.ndarray,
+    ) -> "KRR":
+        usable_num_components = min(X.shape[0], self.num_components)
+        self.indices = self.random_state.permutation(X.shape[0])[:usable_num_components]
+        self.X_train = X[self.indices]
+        return self._fit_components(X, g, h)
+
+    def predict(self, X):
+        K = rbf_kernel(X, self.X_train)
+        return K.dot(self.betas_)
+
+    def clear(self) -> None:
+        self.betas_.fill(0)
+
+    def update(
+        self,
+        X: cn.ndarray,
+        g: cn.ndarray,
+        h: cn.ndarray,
+    ) -> "KRR":
+        return self._fit_components(X, g, h)
+
+    def __str__(self) -> str:
+        return (
+            "Components: "
+            + str(self.X_train)
+            + "\nCoefficients: "
+            + str(self.betas_)
+            + "\n"
+        )
+
+    def __eq__(self, other: object) -> bool:
+        return (other.betas_ == self.betas_).all() and (
+            other.X_train == self.X_train
+        ).all()
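The `l2` helper above computes squared Euclidean distances through the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y rather than materialising all pairwise differences, and `rbf_kernel` maps those distances through exp(-d^2 / (2 sigma^2)). A standalone NumPy sketch of the same identity, checked against a naive implementation (illustrative only, not part of the patch):

    import numpy as np

    X = np.random.rand(4, 3)
    Y = np.random.rand(5, 3)

    # expansion used by l2(): ||x||^2 + ||y||^2 - 2 x.y
    fast = (X**2).sum(axis=1)[:, None] + (Y**2).sum(axis=1) - 2 * X @ Y.T
    # naive pairwise differences, O(n * m * d) memory
    naive = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
    assert np.allclose(fast, naive)

    # RBF kernel on top of the squared distances, sigma = 1.0 as in the patch
    K = np.exp(-fast / 2.0)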
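`_fit_components` is a weighted ridge solve against the Newton targets -g/h: each row of the kernel matrix is scaled by sqrt(h_i) and the coefficients come from the regularised normal equations. A hedged pure-NumPy equivalent for a single output column (the function name is invented for illustration, and plain `np.linalg.solve` stands in for `solve_singular`):

    import numpy as np

    def weighted_kernel_ridge(K, g, h, alpha):
        # minimise sum_i h_i * (K[i] @ beta + g_i / h_i)^2 + alpha * ||beta||^2
        w = np.sqrt(h)
        Kw = K * w[:, None]  # rows weighted by sqrt(h_i)
        yw = w * (-g / h)    # weighted Newton targets -g_i / h_i
        KtK = Kw.T @ Kw + alpha * np.eye(K.shape[1])
        return np.linalg.solve(KtK, Kw.T @ yw)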
diff --git a/legateboost/models/linear.py b/legateboost/models/linear.py
index ef742b52..0e45889b 100644
--- a/legateboost/models/linear.py
+++ b/legateboost/models/linear.py
@@ -1,8 +1,6 @@
-import numpy as np
-
 import cunumeric as cn
 
-from legate.core import get_legate_runtime
+from ..utils import solve_singular
 from .base_model import BaseModel
 
 
@@ -31,51 +29,6 @@ class Linear(BaseModel):
     def __init__(self, alpha: float = 0.0) -> None:
         self.alpha = alpha
 
-    def solve_singular(self, a, b):
-        """Solve a singular linear system Ax = b for x.
-        The same as np.linalg.solve, but if A is singular,
-        then we use Algorithm 3.3 from:
-
-        Nocedal, Jorge, and Stephen J. Wright, eds.
-        Numerical optimization. New York, NY: Springer New York, 1999.
-
-        This progressively adds to the diagonal of the matrix until it is non-singular.
-        """
-        # ensure we are doing all calculations in float 64 for stability
-        a = a.astype(np.float64)
-        b = b.astype(np.float64)
-        # try first without modification
-        try:
-            res = cn.linalg.solve(a, b)
-            get_legate_runtime().raise_exceptions()
-            if np.isnan(res).any():
-                raise np.linalg.LinAlgError
-            return res
-        except (np.linalg.LinAlgError, cn.linalg.LinAlgError):
-            pass
-
-        # if that fails, try adding to the diagonal
-        eps = 1e-3
-        min_diag = a[::].min()
-        if min_diag > 0:
-            tau = eps
-        else:
-            tau = -min_diag + eps
-        while True:
-            try:
-                res = cn.linalg.solve(a + cn.eye(a.shape[0]) * tau, b)
-                get_legate_runtime().raise_exceptions()
-                if np.isnan(res).any():
-                    raise np.linalg.LinAlgError
-                return res
-            except (np.linalg.LinAlgError, cn.linalg.LinAlgError):
-                tau = max(tau * 2, eps)
-                if tau > 1e10:
-                    raise ValueError(
-                        "Numerical instability in linear model solve. "
-                        "Consider normalising your data."
-                    )
-
     def fit(
         self,
         X: cn.ndarray,
@@ -95,7 +48,7 @@ def fit(
             diag[0, 0] = 0
             XtX = cn.dot(Xw.T, Xw) + diag
             yw = W * (-g[:, k] / h[:, k])
-            result = self.solve_singular(XtX, cn.dot(Xw.T, yw))
+            result = solve_singular(XtX, cn.dot(Xw.T, yw))
             self.bias_[k] = result[0]
             self.betas_[:, k] = result[1:]
 
diff --git a/legateboost/objectives.py b/legateboost/objectives.py
index 86d03397..3f87e175 100644
--- a/legateboost/objectives.py
+++ b/legateboost/objectives.py
@@ -253,7 +253,6 @@ def gradient(
         # multi-class case
         label = y.astype(cn.int32).squeeze()
         h = pred * (1.0 - pred)
-        print(pred.min(), pred.max())
         g = pred.copy()
         mod_col_by_idx(g, label, -1.0)  # g[cn.arange(y.size), label] -= 1.0
 
diff --git a/legateboost/test/models/test_krr.py b/legateboost/test/models/test_krr.py
new file mode 100644
index 00000000..11774637
--- /dev/null
+++ b/legateboost/test/models/test_krr.py
@@ -0,0 +1,49 @@
+import numpy as np
+import pytest
+
+import cunumeric as cn
+import legateboost as lb
+
+from ..utils import non_increasing
+
+
+@pytest.mark.parametrize("num_outputs", [1, 5])
+def test_improving_with_components(num_outputs):
+    rs = cn.random.RandomState(0)
+    X = rs.random((100, 10))
+    g = rs.normal(size=(X.shape[0], num_outputs))
+    h = rs.random(g.shape) + 0.1
+    X, g, h = cn.array(X), cn.array(g), cn.array(h)
+    y = -g / h
+    metrics = []
+    for n_components in range(1, 15):
+        model = (
+            lb.models.KRR(n_components=n_components)
+            .set_random_state(np.random.RandomState(2))
+            .fit(X, g, h)
+        )
+        predict = model.predict(X)
+        loss = ((predict - y) ** 2 * h).sum(axis=0) / h.sum(axis=0)
+        loss = loss.mean()
+        metrics.append(loss)
+
+    assert non_increasing(metrics)
+
+
+@pytest.mark.parametrize("num_outputs", [1, 5])
+def test_alpha(num_outputs):
+    # higher alpha hyperparameter should lead to smaller coefficients
+    rs = cn.random.RandomState(0)
+    X = rs.random((100, 10))
+    g = rs.normal(size=(X.shape[0], num_outputs))
+    h = rs.random(g.shape) + 0.1
+    X, g, h = cn.array(X), cn.array(g), cn.array(h)
+    norms = []
+    for alpha in np.linspace(0.0, 2.5, 5):
+        model = (
+            lb.models.KRR(alpha=alpha)
+            .set_random_state(np.random.RandomState(2))
+            .fit(X, g, h)
+        )
+        norms.append(np.linalg.norm(model.betas_))
+    assert non_increasing(norms)
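For an end-to-end smoke test of the new model, a minimal usage sketch in the style of the tests above (hedged: the synthetic data and the estimator keyword arguments other than `base_models` are illustrative assumptions):

    import cunumeric as cn
    import legateboost as lb

    rs = cn.random.RandomState(0)
    X = rs.random((200, 5))
    y = X[:, 0] + 0.1 * rs.normal(size=200)

    # boost with the kernel ridge regression base model added in this patch
    model = lb.LBRegressor(
        base_models=(lb.models.KRR(n_components=16, alpha=0.5),),
        n_estimators=20,
    ).fit(X, y)
    pred = model.predict(X)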
diff --git a/legateboost/test/test_estimator.py b/legateboost/test/test_estimator.py
index afaf2830..c8dab99d 100644
--- a/legateboost/test/test_estimator.py
+++ b/legateboost/test/test_estimator.py
@@ -5,7 +5,7 @@
 import cunumeric as cn
 import legateboost as lb
 
-from .utils import non_increasing, sanity_check_tree_stats
+from .utils import non_increasing, sanity_check_models
 
 
 def test_init():
@@ -59,6 +59,7 @@ def test_update():
         (lb.models.Tree(max_depth=5),),
         (lb.models.Linear(),),
         (lb.models.Tree(max_depth=1), lb.models.Linear()),
+        (lb.models.KRR(),),
     ],
 )
 def test_regressor(num_outputs, objective, base_models):
@@ -80,7 +81,7 @@ def test_regressor(num_outputs, objective, base_models):
     loss = next(iter(eval_result["train"].values()))
     assert np.isclose(loss[-1], loss_recomputed)
     assert non_increasing(loss)
-    sanity_check_tree_stats(model)
+    sanity_check_models(model)
 
 
 @pytest.fixture
@@ -103,6 +104,7 @@ def test_sklearn_compatible_estimator(estimator, check, test_name):
         (lb.models.Tree(max_depth=5),),
         (lb.models.Linear(),),
         (lb.models.Tree(max_depth=1), lb.models.Linear()),
+        (lb.models.KRR(),),
     ],
 )
 def test_classifier(num_class, objective, base_models):
@@ -124,7 +126,7 @@ def test_classifier(num_class, objective, base_models):
     assert non_increasing(train_loss)
     # better than random guessing accuracy
     assert model.score(X, y) > 1 / num_class
-    sanity_check_tree_stats(model)
+    sanity_check_models(model)
 
 
 def test_normal_distribution():
diff --git a/legateboost/test/test_with_hypothesis.py b/legateboost/test/test_with_hypothesis.py
index 133e72fb..248555f3 100644
--- a/legateboost/test/test_with_hypothesis.py
+++ b/legateboost/test/test_with_hypothesis.py
@@ -3,7 +3,7 @@
 
 import legateboost as lb
 
-from .utils import non_increasing, sanity_check_tree_stats
+from .utils import non_increasing, sanity_check_models
 
 np.set_printoptions(threshold=10, edgeitems=1)
 
@@ -32,12 +32,21 @@ def linear_strategy(draw):
     return lb.models.Linear(alpha=alpha)
 
 
+@st.composite
+def krr_strategy(draw):
+    alpha = draw(st.floats(0.0, 1.0))
+    components = draw(st.integers(1, 10))
+    return lb.models.KRR(n_components=components, alpha=alpha)
+
+
 @st.composite
 def base_model_strategy(draw):
     n = draw(st.integers(1, 5))
     base_models = ()
     for _ in range(n):
-        base_models += (draw(st.one_of([tree_strategy(), linear_strategy()])),)
+        base_models += (
+            draw(st.one_of([tree_strategy(), linear_strategy(), krr_strategy()])),
+        )
     return base_models
 
 
@@ -120,7 +129,7 @@ def test_regressor(model_params, regression_params, regression_dataset):
     model.predict(X)
     loss = next(iter(eval_result["train"].values()))
     assert non_increasing(loss)
-    sanity_check_tree_stats(model)
+    sanity_check_models(model)
 
 
 classification_param_strategy = st.fixed_dictionaries(
diff --git a/legateboost/test/utils.py b/legateboost/test/utils.py
index fcb2e2af..a1f4fa2c 100644
--- a/legateboost/test/utils.py
+++ b/legateboost/test/utils.py
@@ -12,8 +12,11 @@ def non_decreasing(x):
     return all(x <= y for x, y in zip(x, x[1:]))
 
 
-def sanity_check_tree_stats(model):
+def sanity_check_models(model):
     trees = [m for m in model.models_ if isinstance(m, lb.models.Tree)]
+    linear_models = [m for m in model.models_ if isinstance(m, lb.models.Linear)]
+    krr_models = [m for m in model.models_ if isinstance(m, lb.models.KRR)]
+
     for m in trees:
         # Check that we have no 0 hessian splits
         split_nodes = m.feature != -1
@@ -26,3 +29,10 @@
         leaves = (m.feature == -1) & (m.hessian[:, 0] > 0.0)
         leaf_sum = m.hessian[leaves].sum(axis=0)
         assert np.isclose(leaf_sum, m.hessian[0]).all()
+
+    for m in linear_models:
+        assert cn.isfinite(m.betas_).all()
+        assert cn.isfinite(m.bias_).all()
+
+    for m in krr_models:
+        assert cn.isfinite(m.betas_).all()
diff --git a/legateboost/utils.py b/legateboost/utils.py
index bbebcd16..5614a6c9 100644
--- a/legateboost/utils.py
+++ b/legateboost/utils.py
@@ -3,7 +3,7 @@
 import numpy as np
 
 import cunumeric as cn
-from legate.core import Store
+from legate.core import Store, get_legate_runtime
 
 
 class PickleCunumericMixin:
@@ -99,3 +99,49 @@ def get_store(input: Any) -> Store:
         array = data[field]
         _, store = array.stores()
         return store
+
+
+def solve_singular(a, b):
+    """Solve a singular linear system Ax = b for x.
+    The same as np.linalg.solve, but if A is singular,
+    then we use Algorithm 3.3 from:
+
+    Nocedal, Jorge, and Stephen J. Wright, eds.
+    Numerical optimization. New York, NY: Springer New York, 1999.
+
+    This progressively adds to the diagonal of the matrix until it is non-singular.
+    """
+    # ensure we are doing all calculations in float 64 for stability
+    a = a.astype(np.float64)
+    b = b.astype(np.float64)
+    # try first without modification
+    try:
+        res = cn.linalg.solve(a, b)
+        get_legate_runtime().raise_exceptions()
+        if np.isnan(res).any():
+            raise np.linalg.LinAlgError
+        return res
+    except (np.linalg.LinAlgError, cn.linalg.LinAlgError):
+        pass
+
+    # if that fails, try adding to the diagonal
+    eps = 1e-3
+    min_diag = a[::].min()
+    if min_diag > 0:
+        tau = eps
+    else:
+        tau = -min_diag + eps
+    while True:
+        try:
+            res = cn.linalg.solve(a + cn.eye(a.shape[0]) * tau, b)
+            get_legate_runtime().raise_exceptions()
+            if np.isnan(res).any():
+                raise np.linalg.LinAlgError
+            return res
+        except (np.linalg.LinAlgError, cn.linalg.LinAlgError):
+            tau = max(tau * 2, eps)
+            if tau > 1e10:
+                raise ValueError(
+                    "Numerical instability in linear model solve. "
+                    "Consider normalising your data."
+                )
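`solve_singular` follows the diagonal-modification strategy of Nocedal & Wright's Algorithm 3.3: attempt a plain solve, and if it fails or yields NaNs, add tau * I with geometrically growing tau until the system becomes solvable. A plain-NumPy sketch of the same loop (the function name is illustrative, and NumPy's error handling stands in for the cunumeric/legate calls):

    import numpy as np

    def solve_with_diagonal_shift(a, b, eps=1e-3, max_tau=1e10):
        a = a.astype(np.float64)
        b = b.astype(np.float64)
        # try first without modification
        try:
            res = np.linalg.solve(a, b)
            if np.isfinite(res).all():
                return res
        except np.linalg.LinAlgError:
            pass
        # progressively enlarge the diagonal until the solve succeeds
        tau = eps if a.min() > 0 else -a.min() + eps
        while tau <= max_tau:
            try:
                res = np.linalg.solve(a + np.eye(a.shape[0]) * tau, b)
                if np.isfinite(res).all():
                    return res
            except np.linalg.LinAlgError:
                pass
            tau *= 2
        raise ValueError("numerically unstable solve; consider normalising the data")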