
[python-package] Expose ObjectiveFunction class #6586

Open
wants to merge 18 commits into base: master
43 changes: 43 additions & 0 deletions include/LightGBM/c_api.h
@@ -31,6 +31,7 @@ typedef void* DatasetHandle;  /*!< \brief Handle of dataset. */
typedef void* BoosterHandle; /*!< \brief Handle of booster. */
typedef void* FastConfigHandle; /*!< \brief Handle of FastConfig. */
typedef void* ByteBufferHandle; /*!< \brief Handle of ByteBuffer. */
typedef void* ObjectiveFunctionHandle; /*!< \brief Handle of ObjectiveFunction. */

#define C_API_DTYPE_FLOAT32 (0) /*!< \brief float32 (single precision float). */
#define C_API_DTYPE_FLOAT64 (1) /*!< \brief float64 (double precision float). */
@@ -1563,6 +1564,48 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetUpperBoundValue(BoosterHandle handle,
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLowerBoundValue(BoosterHandle handle,
                                                     double* out_results);

/*!
 * \brief Create an objective function.
 * \param typ Type of the objective function
 * \param parameter Parameters for the objective function
 * \param[out] out Handle pointing to the created objective function
 * \return 0 when succeed, -1 when failure happens
 */
LIGHTGBM_C_EXPORT int LGBM_ObjectiveFunctionCreate(const char *typ,
                                                   const char *parameter,
                                                   ObjectiveFunctionHandle *out);

/*!
 * \brief Initialize an objective function with the dataset.
 * \param handle Handle of the objective function
 * \param dataset Handle of the dataset used for initialization
 * \param[out] num_data Number of data points; this may be modified within the function
 * \return 0 when succeed, -1 when failure happens
 */
LIGHTGBM_C_EXPORT int LGBM_ObjectiveFunctionInit(ObjectiveFunctionHandle handle,
                                                 DatasetHandle dataset,
                                                 int *num_data);

/*!
 * \brief Evaluate the objective function given model scores.
 * \param handle Handle of the objective function
 * \param score Array of scores predicted by the model
 * \param[out] grad Gradient result array
 * \param[out] hess Hessian result array
 * \return 0 when succeed, -1 when failure happens
 */
LIGHTGBM_C_EXPORT int LGBM_ObjectiveFunctionEval(ObjectiveFunctionHandle handle,
                                                 const double* score,
                                                 float* grad,
                                                 float* hess);

/*!
 * \brief Free the memory allocated for an objective function.
 * \param handle Handle of the objective function
 * \return 0 when succeed, -1 when failure happens
 */
LIGHTGBM_C_EXPORT int LGBM_ObjectiveFunctionFree(ObjectiveFunctionHandle handle);

/*!
 * \brief Initialize the network.
 * \param machines List of machines in format 'ip1:port1,ip2:port2'
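For orientation (not part of the diff), here is a minimal Python sketch of the intended call sequence for the four entry points above: Create, Init, Eval, Free. It drives the C API through LightGBM's private ctypes helpers (_LIB, _c_str, _safe_call in lightgbm.basic), so every name here is an implementation detail, and the sketch assumes a build that already contains this PR.

import ctypes
from lightgbm.basic import _LIB, _c_str, _safe_call

# Create -> Init -> Eval -> Free; Init and Eval are elided because they need
# a constructed DatasetHandle and correctly sized score/grad/hess buffers.
handle = ctypes.c_void_p()
_safe_call(_LIB.LGBM_ObjectiveFunctionCreate(
    _c_str("regression"),            # "typ": name of the objective
    _c_str("objective=regression"),  # parameter string
    ctypes.byref(handle),
))
_safe_call(_LIB.LGBM_ObjectiveFunctionFree(handle))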
3 changes: 2 additions & 1 deletion python-package/lightgbm/__init__.py
@@ -6,7 +6,7 @@

from pathlib import Path

from .basic import Booster, Dataset, Sequence, register_logger
from .basic import Booster, Dataset, ObjectiveFunction, Sequence, register_logger
from .callback import EarlyStopException, early_stopping, log_evaluation, record_evaluation, reset_parameter
from .engine import CVBooster, cv, train

@@ -31,6 +31,7 @@
__all__ = [
    "Dataset",
    "Booster",
    "ObjectiveFunction",
    "CVBooster",
    "Sequence",
    "register_logger",
124 changes: 124 additions & 0 deletions python-package/lightgbm/basic.py
@@ -5281,3 +5281,127 @@ def __get_eval_info(self) -> None:
        self.__higher_better_inner_eval = [
            name.startswith(("auc", "ndcg@", "map@", "average_precision")) for name in self.__name_inner_eval
        ]


class ObjectiveFunction:
    """
    ObjectiveFunction in LightGBM.

    This class exposes LightGBM's built-in objective functions so that gradients
    and Hessians can be evaluated on external datasets. LightGBM does not use this
    wrapper during its own training; it uses the underlying C++ class directly.
    """

    def __init__(self, name: str, params: Dict[str, Any]):
        """
        Initialize the ObjectiveFunction.

        Parameters
        ----------
        name : str
            The name of the objective function.
        params : dict
            Dictionary of parameters for the objective function.
            These are the parameters that would have been passed to ``lgb.train``.
            The ``name`` should be consistent with the ``params["objective"]`` field.
        """
        self.name = name
        self.params = params
        self.num_data = None
        self.num_class = params.get("num_class", 1)

        if "objective" in params and params["objective"] != self.name:
            raise ValueError('The name should be consistent with the params["objective"] field.')

        self.__create()

    def init(self, dataset: Dataset) -> "ObjectiveFunction":
        """
        Initialize the objective function using the provided dataset.

        Parameters
        ----------
        dataset : Dataset
            The dataset object used for initialization.

        Returns
        -------
        self : ObjectiveFunction
            Initialized objective function object.
        """
        return self.__init_from_dataset(dataset)

    def __call__(self, y_pred: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Evaluate the objective function given model predictions.

        Parameters
        ----------
        y_pred : numpy.ndarray
            Predicted scores from the model.

        Returns
        -------
        (grad, hess) : Tuple[np.ndarray, np.ndarray]
            A tuple containing gradients and Hessians.
        """
        if self._handle is None:
            raise ValueError("ObjectiveFunction has been freed or was never created")

        if self.num_data is None or self.num_class is None:
            raise ValueError("ObjectiveFunction was not created properly")

        data_shape = self.num_data * self.num_class
        grad = np.zeros(dtype=np.float32, shape=data_shape)
        hess = np.zeros(dtype=np.float32, shape=data_shape)

        _safe_call(
            _LIB.LGBM_ObjectiveFunctionEval(
                self._handle,
                y_pred.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
                grad.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
                hess.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
            )
        )

        return (grad, hess)

    def __create(self) -> None:
        self._handle = ctypes.c_void_p()
        _safe_call(
            _LIB.LGBM_ObjectiveFunctionCreate(
                _c_str(self.name),
                _c_str(_param_dict_to_str(self.params)),
                ctypes.byref(self._handle),
            )
        )

    def __init_from_dataset(self, dataset: Dataset) -> "ObjectiveFunction":
        if dataset._handle is None:
            raise ValueError("Cannot create ObjectiveFunction from an uninitialized Dataset")

        if self._handle is None:
            raise ValueError("A deallocated ObjectiveFunction cannot be initialized")

        tmp_num_data = ctypes.c_int(0)
        _safe_call(
            _LIB.LGBM_ObjectiveFunctionInit(
                self._handle,
                dataset._handle,
                ctypes.byref(tmp_num_data),
            )
        )
        self.num_data = tmp_num_data.value
        return self

    def __del__(self) -> None:
        try:
            self._free_handle()
        except AttributeError:
            pass

    def _free_handle(self) -> "ObjectiveFunction":
        if self._handle is not None:
            _safe_call(_LIB.LGBM_ObjectiveFunctionFree(self._handle))
            self._handle = None
        return self
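
As a usage illustration (not part of the diff), a minimal sketch of the wrapper above, assuming a LightGBM build that includes this PR. The Dataset must be constructed before init(), since __init_from_dataset requires a live dataset handle:

import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
y = X[:, 0] + rng.normal(scale=0.1, size=100)
ds = lgb.Dataset(X, y).construct()  # construct() so the dataset handle exists

fobj = lgb.ObjectiveFunction("regression", {"objective": "regression"})
fobj.init(ds)
grad, hess = fobj(np.zeros(fobj.num_data))  # gradients/Hessians at an all-zero score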
51 changes: 51 additions & 0 deletions src/c_api.cpp
@@ -907,6 +907,7 @@ using LightGBM::kZeroThreshold;
using LightGBM::LGBM_APIHandleException;
using LightGBM::Log;
using LightGBM::Network;
using LightGBM::ObjectiveFunction;
using LightGBM::Random;
using LightGBM::ReduceScatterFunction;
using LightGBM::SingleRowPredictor;
@@ -2747,6 +2748,56 @@ int LGBM_BoosterGetLowerBoundValue(BoosterHandle handle,
  API_END();
}

LIGHTGBM_C_EXPORT int LGBM_ObjectiveFunctionCreate(const char *typ,
                                                   const char *parameter,
                                                   ObjectiveFunctionHandle *out) {
  API_BEGIN();
  auto param = Config::Str2Map(parameter);
  Config config(param);
  if (config.device_type != std::string("cpu")) {
    Log::Fatal("Currently the ObjectiveFunction class is only exposed for CPU devices.");
  } else {
    *out = ObjectiveFunction::CreateObjectiveFunction(std::string(typ), config);
  }
  API_END();
}

LIGHTGBM_C_EXPORT int LGBM_ObjectiveFunctionInit(ObjectiveFunctionHandle handle,
                                                 DatasetHandle dataset,
                                                 int *num_data) {
  API_BEGIN();
  ObjectiveFunction* ref_fobj = reinterpret_cast<ObjectiveFunction*>(handle);
  Dataset* ref_dataset = reinterpret_cast<Dataset*>(dataset);
  ref_fobj->Init(ref_dataset->metadata(), ref_dataset->num_data());
  *num_data = ref_dataset->num_data();
  API_END();
}

LIGHTGBM_C_EXPORT int LGBM_ObjectiveFunctionEval(ObjectiveFunctionHandle handle,
                                                 const double* score,
                                                 float* grad,
                                                 float* hess) {
  API_BEGIN();
#ifdef SCORE_T_USE_DOUBLE
  (void) handle;  // UNUSED VARIABLE
  (void) score;  // UNUSED VARIABLE
  (void) grad;  // UNUSED VARIABLE
  (void) hess;  // UNUSED VARIABLE
  Log::Fatal("Don't support evaluating objective function when SCORE_T_USE_DOUBLE is enabled");
#else
  ObjectiveFunction* ref_fobj = reinterpret_cast<ObjectiveFunction*>(handle);
  ref_fobj->GetGradients(score, grad, hess);
#endif
  API_END();
}

Collaborator: Why is that?

Contributor Author: This would require a huge amount of work on the Python side, so I am leaving it as is. There is already precedent for this in the file.

LIGHTGBM_C_EXPORT int LGBM_ObjectiveFunctionFree(ObjectiveFunctionHandle handle) {
  API_BEGIN();
  delete reinterpret_cast<ObjectiveFunction*>(handle);
  API_END();
}

int LGBM_NetworkInit(const char* machines,
                     int local_listen_port,
                     int listen_time_out,
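A note on the output layout of LGBM_ObjectiveFunctionEval, inferred from the reshape logic in tests/python_package_test/utils.py below: for multiclass objectives the gradients and Hessians come back as flat buffers of length num_class * num_data in class-major order. Converting them to the (num_data, num_class) layout used by Python custom objectives is a reshape plus transpose:

import numpy as np

num_class, num_data = 3, 5
flat_grad = np.arange(num_class * num_data, dtype=np.float32)  # stand-in for the C API output
grad = flat_grad.reshape((num_class, -1)).transpose()          # shape (num_data, num_class)
assert grad.shape == (num_data, num_class)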
61 changes: 53 additions & 8 deletions tests/python_package_test/test_engine.py
@@ -24,15 +24,16 @@

from .utils import (
    SERIALIZERS,
    builtin_objective,
    dummy_obj,
    load_breast_cancer,
    load_digits,
    load_iris,
    logistic_sigmoid,
    make_synthetic_regression,
    mse_obj,
    multiclass_custom_objective,
    pickle_and_unpickle_object,
    sklearn_multiclass_custom_objective,
    softmax,
)

Expand Down Expand Up @@ -2926,12 +2927,6 @@ def test_default_objective_and_metric():

@pytest.mark.parametrize("use_weight", [True, False])
def test_multiclass_custom_objective(use_weight):
    def custom_obj(y_pred, ds):
        y_true = ds.get_label()
        weight = ds.get_weight()
        grad, hess = sklearn_multiclass_custom_objective(y_true, y_pred, weight)
        return grad, hess

    centers = [[-4, -4], [4, 4], [-4, 4]]
    X, y = make_blobs(n_samples=1_000, centers=centers, random_state=42)
    weight = np.full_like(y, 2)
@@ -2942,7 +2937,7 @@ def custom_obj(y_pred, ds):
    builtin_obj_bst = lgb.train(params, ds, num_boost_round=10)
    builtin_obj_preds = builtin_obj_bst.predict(X)

    params["objective"] = custom_obj
    params["objective"] = multiclass_custom_objective
    custom_obj_bst = lgb.train(params, ds, num_boost_round=10)
    custom_obj_preds = softmax(custom_obj_bst.predict(X))

@@ -4397,3 +4392,53 @@ def test_quantized_training():
    quant_bst = lgb.train(bst_params, ds, num_boost_round=10)
    quant_rmse = np.sqrt(np.mean((quant_bst.predict(X) - y) ** 2))
    assert quant_rmse < rmse + 6.0


@pytest.mark.parametrize("use_weight", [False, True])
@pytest.mark.parametrize(
"test_data",
[
{
"custom_objective": mse_obj,
"objective_name": "regression",
"df": make_synthetic_regression(),
"num_class": 1,
},
{
"custom_objective": multiclass_custom_objective,
"objective_name": "multiclass",
"df": make_blobs(n_samples=100, centers=[[-4, -4], [4, 4], [-4, 4]], random_state=42),
"num_class": 3,
},
],
)
@pytest.mark.parametrize("num_boost_round", [5, 15])
@pytest.mark.skipif(getenv("TASK", "") == "cuda", reason="Skip due to ObjectiveFunction not exposed for cuda devices.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why couldn't this also be exposed for the CUDA implementation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It segfaults on the CI tests, and I cannot build the CUDA version on MacOS.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where exactly does it segfault? 🤔 it seems like this should work 😅

def test_objective_function_class(use_weight, test_data, num_boost_round):
X, y = test_data["df"]
rng = np.random.default_rng()
weight = rng.choice([1, 2], y.shape) if use_weight else None
lgb_train = lgb.Dataset(X, y, weight=weight, init_score=np.zeros((len(y), test_data["num_class"])))

params = {
"verbose": -1,
"objective": test_data["objective_name"],
"num_class": test_data["num_class"],
"device": "cpu",
}
builtin_loss = builtin_objective(test_data["objective_name"], copy.deepcopy(params))

params["objective"] = builtin_loss
booster_exposed = lgb.train(params, lgb_train, num_boost_round=num_boost_round)

params["objective"] = test_data["objective_name"]
booster = lgb.train(params, lgb_train, num_boost_round=num_boost_round)

params["objective"] = test_data["custom_objective"]
booster_custom = lgb.train(params, lgb_train, num_boost_round=num_boost_round)

np.testing.assert_allclose(booster_exposed.predict(X), booster.predict(X, raw_score=True))
neNasko1 marked this conversation as resolved.
Show resolved Hide resolved
np.testing.assert_allclose(booster_exposed.predict(X), booster_custom.predict(X))

y_pred = np.zeros_like(booster.predict(X, raw_score=True))
neNasko1 marked this conversation as resolved.
Show resolved Hide resolved
np.testing.assert_allclose(builtin_loss(y_pred, lgb_train), test_data["custom_objective"](y_pred, lgb_train))
25 changes: 25 additions & 0 deletions tests/python_package_test/utils.py
@@ -130,6 +130,9 @@ def mse_obj(y_pred, dtrain):
    y_true = dtrain.get_label()
    grad = y_pred - y_true
    hess = np.ones(len(grad))
    weight = dtrain.get_weight()
    if weight is not None:
        grad *= weight
        hess *= weight
    return grad, hess


@@ -158,6 +161,28 @@ def sklearn_multiclass_custom_objective(y_true, y_pred, weight=None):
    return grad, hess


def multiclass_custom_objective(y_pred, ds):
    y_true = ds.get_label()
    weight = ds.get_weight()
    grad, hess = sklearn_multiclass_custom_objective(y_true, y_pred, weight)
    return grad, hess


def builtin_objective(name, params):
    """Mimics the builtin objective functions to mock training."""

    def wrapper(y_pred, dtrain):
        fobj = lgb.ObjectiveFunction(name, params)
        fobj.init(dtrain)
        (grad, hess) = fobj(y_pred)
        if fobj.num_class != 1:
            # The C API returns flat class-major buffers; reshape to the
            # (num_data, num_class) layout expected by custom objectives.
            grad = grad.reshape((fobj.num_class, -1)).transpose()
            hess = hess.reshape((fobj.num_class, -1)).transpose()
        return (grad, hess)

    return wrapper
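
A sketch of how this helper plugs into training as a custom objective (mirroring test_objective_function_class in test_engine.py; the synthetic data is illustrative and assumes a build containing this PR):

import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
y = rng.integers(0, 3, size=200)
lgb_train = lgb.Dataset(X, y, init_score=np.zeros((200, 3)))

params = {
    "verbose": -1,
    "num_class": 3,
    "device": "cpu",
    "objective": builtin_objective("multiclass", {"objective": "multiclass", "num_class": 3}),
}
bst = lgb.train(params, lgb_train, num_boost_round=5)  # trains through the exposed builtin objective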


def pickle_obj(obj, filepath, serializer):
    if serializer == "pickle":
        with open(filepath, "wb") as f: