Basic KRR implementation.
RAMitchell committed Oct 10, 2023
1 parent 7d1d541 commit 29f62db
Showing 9 changed files with 205 additions and 58 deletions.
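The new KRR (kernel ridge regression) base model is fit directly against gradient and hessian arrays, like the existing Tree and Linear base models. A minimal standalone sketch of the interface, mirroring the new test file legateboost/test/models/test_krr.py (the synthetic data and seeds here are illustrative only, not part of the commit):

import numpy as np

import cunumeric as cn
import legateboost as lb

# synthetic gradients/hessians standing in for those produced by an objective
rs = cn.random.RandomState(0)
X = cn.array(rs.random((100, 10)))
g = cn.array(rs.normal(size=(100, 1)))
h = cn.array(rs.random((100, 1)) + 0.1)

model = (
    lb.models.KRR(n_components=10, alpha=1.0)
    .set_random_state(np.random.RandomState(2))
    .fit(X, g, h)  # solves a weighted ridge problem against the target -g / h
)
pred = model.predict(X)  # evaluates the RBF kernel against the selected components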
1 change: 1 addition & 0 deletions legateboost/models/__init__.py
@@ -1,3 +1,4 @@
from .tree import Tree
from .linear import Linear
from .krr import KRR
from .base_model import BaseModel
78 changes: 78 additions & 0 deletions legateboost/models/krr.py
@@ -0,0 +1,78 @@
import cunumeric as cn

from ..utils import solve_singular
from .base_model import BaseModel


def l2(X, Y):
    XX = cn.einsum("ij,ij->i", X, X)[:, cn.newaxis]
    YY = cn.einsum("ij,ij->i", Y, Y)
    XY = 2 * cn.dot(X, Y.T)
    return XX + YY - XY


def rbf_kernel(X, Y, sigma=1.0):
    K = l2(X, Y)
    return cn.exp(-K / (2 * sigma**2))


class KRR(BaseModel):
    def __init__(self, n_components=10, alpha=1.0):
        self.num_components = n_components
        self.alpha = alpha

    def _fit_components(self, X, g, h) -> "KRR":
        # fit with fixed set of components
        K = rbf_kernel(X, self.X_train)
        num_outputs = g.shape[1]
        self.bias_ = cn.zeros(num_outputs)
        self.betas_ = cn.zeros((self.X_train.shape[0], num_outputs))

        for k in range(num_outputs):
            W = cn.sqrt(h[:, k])
            Kw = K * W[:, cn.newaxis]
            diag = cn.eye(Kw.shape[1]) * self.alpha
            KtK = cn.dot(Kw.T, Kw) + diag
            yw = W * (-g[:, k] / h[:, k])
            self.betas_[:, k] = solve_singular(KtK, cn.dot(Kw.T, yw))
        return self

    def fit(
        self,
        X: cn.ndarray,
        g: cn.ndarray,
        h: cn.ndarray,
    ) -> "KRR":
        usable_num_components = min(X.shape[0], self.num_components)
        self.indices = self.random_state.permutation(X.shape[0])[:usable_num_components]
        self.X_train = X[self.indices]
        return self._fit_components(X, g, h)

    def predict(self, X):
        K = rbf_kernel(X, self.X_train)
        return K.dot(self.betas_)

    def clear(self) -> None:
        self.betas_.fill(0)

    def update(
        self,
        X: cn.ndarray,
        g: cn.ndarray,
        h: cn.ndarray,
    ) -> "KRR":
        return self._fit_components(X, g, h)

    def __str__(self) -> str:
        return (
            "Components: "
            + str(self.X_train)
            + "\nCoefficients: "
            + str(self.betas_)
            + "\n"
        )

    def __eq__(self, other: object) -> bool:
        return (other.betas_ == self.betas_).all() and (
            other.X_train == self.X_train
        ).all()
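For reference, my reading of what _fit_components computes for each output column k (a summary of the code above, not text from the commit): with kernel entries K_{ij} = \exp(-\lVert x_i - c_j \rVert^2 / (2\sigma^2)) over the randomly selected component rows c_j, and targets y_{:,k} = -g_{:,k} / h_{:,k}, the coefficients are the hessian-weighted kernel ridge solution

\beta_{:,k} = \left( K^{\top} \mathrm{diag}(h_{:,k})\, K + \alpha I \right)^{-1} K^{\top} \mathrm{diag}(h_{:,k})\, y_{:,k}

which is the normal-equations form obtained by scaling rows with W = \sqrt{h_{:,k}} before calling solve_singular.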
51 changes: 2 additions & 49 deletions legateboost/models/linear.py
@@ -1,8 +1,6 @@
import numpy as np

import cunumeric as cn
from legate.core import get_legate_runtime

from ..utils import solve_singular
from .base_model import BaseModel


@@ -31,51 +29,6 @@ class Linear(BaseModel):
    def __init__(self, alpha: float = 0.0) -> None:
        self.alpha = alpha

    def solve_singular(self, a, b):
        """Solve a singular linear system Ax = b for x.
        The same as np.linalg.solve, but if A is singular,
        then we use Algorithm 3.3 from:
        Nocedal, Jorge, and Stephen J. Wright, eds.
        Numerical optimization. New York, NY: Springer New York, 1999.
        This progressively adds to the diagonal of the matrix until it is non-singular.
        """
        # ensure we are doing all calculations in float 64 for stability
        a = a.astype(np.float64)
        b = b.astype(np.float64)
        # try first without modification
        try:
            res = cn.linalg.solve(a, b)
            get_legate_runtime().raise_exceptions()
            if np.isnan(res).any():
                raise np.linalg.LinAlgError
            return res
        except (np.linalg.LinAlgError, cn.linalg.LinAlgError):
            pass

        # if that fails, try adding to the diagonal
        eps = 1e-3
        min_diag = a[::].min()
        if min_diag > 0:
            tau = eps
        else:
            tau = -min_diag + eps
        while True:
            try:
                res = cn.linalg.solve(a + cn.eye(a.shape[0]) * tau, b)
                get_legate_runtime().raise_exceptions()
                if np.isnan(res).any():
                    raise np.linalg.LinAlgError
                return res
            except (np.linalg.LinAlgError, cn.linalg.LinAlgError):
                tau = max(tau * 2, eps)
                if tau > 1e10:
                    raise ValueError(
                        "Numerical instability in linear model solve. "
                        "Consider normalising your data."
                    )

    def fit(
        self,
        X: cn.ndarray,
@@ -95,7 +48,7 @@ def fit(
            diag[0, 0] = 0
            XtX = cn.dot(Xw.T, Xw) + diag
            yw = W * (-g[:, k] / h[:, k])
            result = self.solve_singular(XtX, cn.dot(Xw.T, yw))
            result = solve_singular(XtX, cn.dot(Xw.T, yw))
            self.bias_[k] = result[0]
            self.betas_[:, k] = result[1:]

1 change: 0 additions & 1 deletion legateboost/objectives.py
@@ -253,7 +253,6 @@ def gradient(
            # multi-class case
            label = y.astype(cn.int32).squeeze()
            h = pred * (1.0 - pred)
            print(pred.min(), pred.max())
            g = pred.copy()
            mod_col_by_idx(g, label, -1.0)
            # g[cn.arange(y.size), label] -= 1.0
49 changes: 49 additions & 0 deletions legateboost/test/models/test_krr.py
@@ -0,0 +1,49 @@
import numpy as np
import pytest

import cunumeric as cn
import legateboost as lb

from ..utils import non_increasing


@pytest.mark.parametrize("num_outputs", [1, 5])
def test_improving_with_components(num_outputs):
    rs = cn.random.RandomState(0)
    X = rs.random((100, 10))
    g = rs.normal(size=(X.shape[0], num_outputs))
    h = rs.random(g.shape) + 0.1
    X, g, h = cn.array(X), cn.array(g), cn.array(h)
    y = -g / h
    metrics = []
    for n_components in range(1, 15):
        model = (
            lb.models.KRR(n_components=n_components)
            .set_random_state(np.random.RandomState(2))
            .fit(X, g, h)
        )
        predict = model.predict(X)
        loss = ((predict - y) ** 2 * h).sum(axis=0) / h.sum(axis=0)
        loss = loss.mean()
        metrics.append(loss)

    assert non_increasing(metrics)


@pytest.mark.parametrize("num_outputs", [1, 5])
def test_alpha(num_outputs):
    # higher alpha hyperparameter should lead to smaller coefficients
    rs = cn.random.RandomState(0)
    X = rs.random((100, 10))
    g = rs.normal(size=(X.shape[0], num_outputs))
    h = rs.random(g.shape) + 0.1
    X, g, h = cn.array(X), cn.array(g), cn.array(h)
    norms = []
    for alpha in np.linspace(0.0, 2.5, 5):
        model = (
            lb.models.KRR(alpha=alpha)
            .set_random_state(np.random.RandomState(2))
            .fit(X, g, h)
        )
        norms.append(np.linalg.norm(model.betas_))
    assert non_increasing(norms)
8 changes: 5 additions & 3 deletions legateboost/test/test_estimator.py
@@ -5,7 +5,7 @@
import cunumeric as cn
import legateboost as lb

from .utils import non_increasing, sanity_check_tree_stats
from .utils import non_increasing, sanity_check_models


def test_init():
@@ -59,6 +59,7 @@ def test_update():
        (lb.models.Tree(max_depth=5),),
        (lb.models.Linear(),),
        (lb.models.Tree(max_depth=1), lb.models.Linear()),
        (lb.models.KRR(),),
    ],
)
def test_regressor(num_outputs, objective, base_models):
@@ -80,7 +81,7 @@ def test_regressor(num_outputs, objective, base_models):
    loss = next(iter(eval_result["train"].values()))
    assert np.isclose(loss[-1], loss_recomputed)
    assert non_increasing(loss)
    sanity_check_tree_stats(model)
    sanity_check_models(model)


@pytest.fixture
@@ -103,6 +104,7 @@ def test_sklearn_compatible_estimator(estimator, check, test_name):
        (lb.models.Tree(max_depth=5),),
        (lb.models.Linear(),),
        (lb.models.Tree(max_depth=1), lb.models.Linear()),
        (lb.models.KRR(),),
    ],
)
def test_classifier(num_class, objective, base_models):
@@ -124,7 +126,7 @@ def test_classifier(num_class, objective, base_models):
    assert non_increasing(train_loss)
    # better than random guessing accuracy
    assert model.score(X, y) > 1 / num_class
    sanity_check_tree_stats(model)
    sanity_check_models(model)


def test_normal_distribution():
15 changes: 12 additions & 3 deletions legateboost/test/test_with_hypothesis.py
@@ -3,7 +3,7 @@

import legateboost as lb

from .utils import non_increasing, sanity_check_tree_stats
from .utils import non_increasing, sanity_check_models

np.set_printoptions(threshold=10, edgeitems=1)

@@ -32,12 +32,21 @@ def linear_strategy(draw):
    return lb.models.Linear(alpha=alpha)


@st.composite
def krr_strategy(draw):
    alpha = draw(st.floats(0.0, 1.0))
    components = draw(st.integers(1, 10))
    return lb.models.KRR(n_components=components, alpha=alpha)


@st.composite
def base_model_strategy(draw):
    n = draw(st.integers(1, 5))
    base_models = ()
    for _ in range(n):
        base_models += (draw(st.one_of([tree_strategy(), linear_strategy()])),)
        base_models += (
            draw(st.one_of([tree_strategy(), linear_strategy(), krr_strategy()])),
        )
    return base_models


@@ -120,7 +129,7 @@ def test_regressor(model_params, regression_params, regression_dataset):
    model.predict(X)
    loss = next(iter(eval_result["train"].values()))
    assert non_increasing(loss)
    sanity_check_tree_stats(model)
    sanity_check_models(model)


classification_param_strategy = st.fixed_dictionaries(
12 changes: 11 additions & 1 deletion legateboost/test/utils.py
@@ -12,8 +12,11 @@ def non_decreasing(x):
    return all(x <= y for x, y in zip(x, x[1:]))


def sanity_check_tree_stats(model):
def sanity_check_models(model):
    trees = [m for m in model.models_ if isinstance(m, lb.models.Tree)]
    linear_models = [m for m in model.models_ if isinstance(m, lb.models.Linear)]
    krr_models = [m for m in model.models_ if isinstance(m, lb.models.KRR)]

    for m in trees:
        # Check that we have no 0 hessian splits
        split_nodes = m.feature != -1
@@ -26,3 +29,10 @@ def sanity_check_tree_stats(model):
        leaves = (m.feature == -1) & (m.hessian[:, 0] > 0.0)
        leaf_sum = m.hessian[leaves].sum(axis=0)
        assert np.isclose(leaf_sum, m.hessian[0]).all()

    for m in linear_models:
        assert cn.isfinite(m.betas_).all()
        assert cn.isfinite(m.bias_).all()

    for m in krr_models:
        assert cn.isfinite(m.betas_).all()
48 changes: 47 additions & 1 deletion legateboost/utils.py
@@ -3,7 +3,7 @@
import numpy as np

import cunumeric as cn
from legate.core import Store
from legate.core import Store, get_legate_runtime


class PickleCunumericMixin:
@@ -99,3 +99,49 @@ def get_store(input: Any) -> Store:
    array = data[field]
    _, store = array.stores()
    return store


def solve_singular(a, b):
    """Solve a singular linear system Ax = b for x.
    The same as np.linalg.solve, but if A is singular,
    then we use Algorithm 3.3 from:
    Nocedal, Jorge, and Stephen J. Wright, eds.
    Numerical optimization. New York, NY: Springer New York, 1999.
    This progressively adds to the diagonal of the matrix until it is non-singular.
    """
    # ensure we are doing all calculations in float 64 for stability
    a = a.astype(np.float64)
    b = b.astype(np.float64)
    # try first without modification
    try:
        res = cn.linalg.solve(a, b)
        get_legate_runtime().raise_exceptions()
        if np.isnan(res).any():
            raise np.linalg.LinAlgError
        return res
    except (np.linalg.LinAlgError, cn.linalg.LinAlgError):
        pass

    # if that fails, try adding to the diagonal
    eps = 1e-3
    min_diag = a[::].min()
    if min_diag > 0:
        tau = eps
    else:
        tau = -min_diag + eps
    while True:
        try:
            res = cn.linalg.solve(a + cn.eye(a.shape[0]) * tau, b)
            get_legate_runtime().raise_exceptions()
            if np.isnan(res).any():
                raise np.linalg.LinAlgError
            return res
        except (np.linalg.LinAlgError, cn.linalg.LinAlgError):
            tau = max(tau * 2, eps)
            if tau > 1e10:
                raise ValueError(
                    "Numerical instability in linear model solve. "
                    "Consider normalising your data."
                )
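A quick illustration of the fallback path in solve_singular (a hypothetical example, not part of the commit): a rank-deficient system makes the unmodified cn.linalg.solve attempt fail, after which the function repeatedly solves (A + tau*I)x = b, doubling tau until the factorization succeeds or tau exceeds 1e10.

import cunumeric as cn

from legateboost.utils import solve_singular

A = cn.array([[1.0, 1.0], [1.0, 1.0]])  # singular: rank 1
b = cn.array([2.0, 2.0])
x = solve_singular(A, b)  # falls back to solving (A + tau * I) x = b
print(x)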
