Skip to content

Commit

Permalink
Fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Nov 4, 2023
1 parent 863f5f6 commit 874e8da
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion legateboost/test/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def test_gamma_deviance() -> None:
rng = cn.random.default_rng(0)

X = rng.normal(size=(100, 10))
y = rng.gamma(1.0, 1.0, size=100)
y = rng.gamma(3.0, 1.0, size=100)
w = rng.uniform(0.0, 1.0, size=y.shape[0])

reg = lb.LBRegressor()
Expand Down
2 changes: 1 addition & 1 deletion legateboost/test/test_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_normal() -> None:
def test_gamma_deviance() -> None:
obj = lb.GammaDevianceObjective()
n_samples = 8196
with pytest.raises(ValueError, match="greater"):
with pytest.raises(ValueError, match="positive"):
y = cn.empty(shape=(n_samples,))
y[:] = -1
obj.initialise_prediction(y, None, True)
Expand Down
7 changes: 4 additions & 3 deletions legateboost/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,13 @@ def sample_average(
Returns 0 if sum weight is zero or if the input is empty.
"""
if y.ndim > 2:
raise ValueError("Expecting a 1-dim or 2-dim input.")
if y.shape[0] == 0:
return cn.zeros(shape=(1,))
if sample_weight is None:
return cn.sum(y, axis=0) / np.full(
shape=y.shape[1], fill_value=float(y.shape[0])
)
n_columns = y.shape[1:] if y.ndim > 1 else 1
return cn.sum(y, axis=0) / cn.full(shape=n_columns, value=float(y.shape[0]))
if sample_weight.ndim > 1:
raise ValueError("Expecting 1-dim sample weight")
sum_w = sample_weight.sum()
Expand Down

0 comments on commit 874e8da

Please sign in to comment.