Skip to content

Commit

Permalink
Simplify checkpointing callback for Ray Train/Tune integration (#53)
Browse files Browse the repository at this point in the history
This PR removes some deprecated APIs and unifies the Tune integration to be centered around a single `TuneReportCheckpointCallback`.

---------

Signed-off-by: Justin Yu <[email protected]>
  • Loading branch information
justinvyu authored Feb 4, 2024
1 parent 47fe3fc commit 4c4d341
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 172 deletions.
41 changes: 32 additions & 9 deletions lightgbm_ray/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,26 @@
concat_dataframes,
force_on_current_node,
get_current_placement_group,
is_session_enabled,
pickle,
)
from xgboost_ray.session import put_queue

from lightgbm_ray.tune import _try_add_tune_callback, _TuneLGBMRank0Mixin
from lightgbm_ray.util import find_free_port, is_port_free, lgbm_network_free

# Optional Ray Train/Tune integration. lightgbm_ray must keep working when the
# Tune extras are not installed, so the Tune-specific helpers are imported
# behind a feature flag instead of unconditionally at the top of the module.
RAY_TUNE_INSTALLED = True

try:
    import ray.train
    import ray.tune
except (ImportError, ModuleNotFoundError):
    # Ray is present without the Train/Tune subpackages (or missing entirely).
    RAY_TUNE_INSTALLED = False

if RAY_TUNE_INSTALLED:
    from lightgbm_ray.tune import _try_add_tune_callback, _TuneLGBMRank0Mixin
else:
    # Placeholders so module-level references remain defined; call sites are
    # expected to check RAY_TUNE_INSTALLED before using either name.
    _try_add_tune_callback = _TuneLGBMRank0Mixin = None


logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

Expand Down Expand Up @@ -125,6 +137,12 @@ def _check_cpus_per_actor_at_least_2(cpus_per_actor: int, suppress_exception: bo
)


def _in_ray_tune_session() -> bool:
    """Return ``True`` when executing inside a Ray Tune trial.

    Short-circuits to ``False`` when the Tune integration is unavailable,
    so ``ray.train`` is only touched when it was successfully imported.
    """
    if not RAY_TUNE_INSTALLED:
        return False
    # Trial resources are only populated when running under a Tune session.
    return ray.train.get_context().get_trial_resources() is not None


def _get_data_dict(data: RayDMatrix, param: Dict) -> Dict:
if not LEGACY_MATRIX and isinstance(data, RayDeviceQuantileDMatrix):
# If we only got a single data shard, create a list so we can
Expand Down Expand Up @@ -411,9 +429,10 @@ def train(
callbacks.append(self._save_checkpoint_callback(is_rank_0=return_bst))
callbacks.append(self._stop_callback(is_rank_0=return_bst))
callbacks.append(record_evaluation(evals_result))
for callback in callbacks:
if isinstance(callback, _TuneLGBMRank0Mixin):
callback.is_rank_0 = return_bst
if RAY_TUNE_INSTALLED:
for callback in callbacks:
if isinstance(callback, _TuneLGBMRank0Mixin):
callback.is_rank_0 = return_bst
kwargs["callbacks"] = callbacks

if LIGHTGBM_VERSION < Version("3.3.0"):
Expand Down Expand Up @@ -1045,7 +1064,7 @@ def train(
os.environ.setdefault("RAY_IGNORE_UNHANDLED_ERRORS", "1")

if _remote is None:
_remote = _is_client_connected() and not is_session_enabled()
_remote = _is_client_connected() and not _in_ray_tune_session()

if not ray.is_initialized():
ray.init()
Expand Down Expand Up @@ -1153,7 +1172,11 @@ def _wrapped(*args, **kwargs):
"`dtrain = RayDMatrix(data=data, label=label)`.".format(type(dtrain))
)

added_tune_callback = _try_add_tune_callback(kwargs)
if RAY_TUNE_INSTALLED:
added_tune_callback = _try_add_tune_callback(kwargs)
else:
added_tune_callback = False

# LightGBM currently does not support elastic training.
if ray_params.elastic_training:
raise ValueError(
Expand Down Expand Up @@ -1270,7 +1293,7 @@ def _wrapped(*args, **kwargs):
evals.append((valid_data, f"valid_{i}"))

if evals:
for (deval, _name) in evals:
for deval, _name in evals:
if not isinstance(deval, RayDMatrix):
raise ValueError(
"Evaluation data must be a `RayDMatrix`, got " f"{type(deval)}."
Expand Down Expand Up @@ -1554,7 +1577,7 @@ def predict(
os.environ.setdefault("RAY_IGNORE_UNHANDLED_ERRORS", "1")

if _remote is None:
_remote = _is_client_connected() and not is_session_enabled()
_remote = _is_client_connected() and not _in_ray_tune_session()

if not ray.is_initialized():
ray.init()
Expand Down
10 changes: 9 additions & 1 deletion lightgbm_ray/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ def start_client_server_5_cpus():
yield client


@pytest.fixture
def start_client_server_5_cpus_modin(monkeypatch):
    """Start a 5-CPU Ray cluster plus client server with Modin's pandas auto-import enabled."""
    # The env var is set both in the driver process (via monkeypatch, auto-undone
    # on teardown) and in the cluster runtime_env so Ray workers see it too.
    monkeypatch.setenv("__MODIN_AUTOIMPORT_PANDAS__", "1")
    ray.init(num_cpus=5, runtime_env={"env_vars": {"__MODIN_AUTOIMPORT_PANDAS__": "1"}})
    with ray_start_client_server() as client:
        yield client


def test_simple_train(start_client_server_4_cpus):
assert ray.util.client.ray.is_connected()
from lightgbm_ray.examples.simple import main
Expand All @@ -41,7 +49,7 @@ def test_simple_dask(start_client_server_5_cpus):
main(cpus_per_actor=2, num_actors=2)


def test_simple_modin(start_client_server_5_cpus):
def test_simple_modin(start_client_server_5_cpus_modin):
assert ray.util.client.ray.is_connected()
from lightgbm_ray.examples.simple_modin import main

Expand Down
73 changes: 13 additions & 60 deletions lightgbm_ray/tests/test_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,17 @@
import shutil
import tempfile
import unittest
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import numpy as np
import ray
from ray import tune

try:
from ray.tune.integration.lightgbm import (
TuneReportCallback as OrigTuneReportCallback,
)
from ray.tune.integration.lightgbm import (
TuneReportCheckpointCallback as OrigTuneReportCheckpointCallback,
)
except ImportError:
OrigTuneReportCallback = OrigTuneReportCheckpointCallback = None

from lightgbm_ray import RayDMatrix, RayParams, RayShardingMode, train
from lightgbm_ray.tune import (
TuneReportCallback,
TuneReportCheckpointCallback,
_try_add_tune_callback,
from ray.tune.integration.lightgbm import (
TuneReportCheckpointCallback as OrigTuneReportCheckpointCallback,
)

try:
from ray.air import Checkpoint
except Exception:

class Checkpoint:
pass
from lightgbm_ray import RayDMatrix, RayParams, RayShardingMode, train
from lightgbm_ray.tune import TuneReportCheckpointCallback, _try_add_tune_callback


class LightGBMRayTuneTest(unittest.TestCase):
Expand Down Expand Up @@ -59,8 +41,8 @@ def setUp(self):
"num_boost_round": tune.choice([1, 3]),
}

def train_func(ray_params, callbacks=None, **kwargs):
def _inner_train(config, checkpoint_dir):
def train_func(ray_params, callbacks=None):
def _inner_train(config):
train_set = RayDMatrix(x, y, sharding=RayShardingMode.BATCH)
train(
config["lgbm"],
Expand All @@ -69,7 +51,6 @@ def _inner_train(config, checkpoint_dir):
num_boost_round=config["num_boost_round"],
evals=[(train_set, "train")],
callbacks=callbacks,
**kwargs
)

return _inner_train
Expand Down Expand Up @@ -116,41 +97,22 @@ def testNumItersClient(self):
self.assertTrue(ray.util.client.ray.is_connected())
self.testNumIters(init=False)

@unittest.skipIf(
OrigTuneReportCallback is None, "integration.lightgbm not yet in ray.tune"
)
def testReplaceTuneCheckpoints(self):
"""Test if ray.tune.integration.lightgbm callbacks are replaced"""
ray.init(num_cpus=4)
# Report callback
in_cp = [OrigTuneReportCallback(metrics="met")]
in_dict = {"callbacks": in_cp}

with patch("lightgbm_ray.tune.is_session_enabled") as mocked:
mocked.return_value = True
_try_add_tune_callback(in_dict)

replaced = in_dict["callbacks"][0]
self.assertTrue(isinstance(replaced, TuneReportCallback))
self.assertSequenceEqual(replaced._metrics, ["met"])

# Report and checkpointing callback
in_cp = [OrigTuneReportCheckpointCallback(metrics="met", filename="test")]
in_cp = [OrigTuneReportCheckpointCallback(metrics="met")]
in_dict = {"callbacks": in_cp}

with patch("lightgbm_ray.tune.is_session_enabled") as mocked:
mocked.return_value = True
with patch("ray.train.get_context") as mocked:
mocked.return_value = MagicMock(return_value=True)
_try_add_tune_callback(in_dict)

replaced = in_dict["callbacks"][0]
self.assertTrue(isinstance(replaced, TuneReportCheckpointCallback))

if getattr(replaced, "_report", None):
self.assertSequenceEqual(replaced._report._metrics, ["met"])
self.assertEqual(replaced._checkpoint._filename, "test")
else:
self.assertSequenceEqual(replaced._metrics, ["met"])
self.assertEqual(replaced._filename, "test")
self.assertSequenceEqual(replaced._metrics, ["met"])

def testEndToEndCheckpointing(self):
ray.init(num_cpus=4)
Expand All @@ -168,14 +130,8 @@ def testEndToEndCheckpointing(self):
local_dir=self.experiment_dir,
)

if isinstance(analysis.best_checkpoint, Checkpoint):
self.assertTrue(analysis.best_checkpoint)
else:
self.assertTrue(os.path.exists(analysis.best_checkpoint))
self.assertTrue(os.path.exists(analysis.best_checkpoint.path))

@unittest.skipIf(
OrigTuneReportCallback is None, "integration.lightgbm not yet in ray.tune"
)
def testEndToEndCheckpointingOrigTune(self):
ray.init(num_cpus=4)
ray_params = RayParams(cpus_per_actor=2, num_actors=1)
Expand All @@ -192,10 +148,7 @@ def testEndToEndCheckpointingOrigTune(self):
local_dir=self.experiment_dir,
)

if isinstance(analysis.best_checkpoint, Checkpoint):
self.assertTrue(analysis.best_checkpoint)
else:
self.assertTrue(os.path.exists(analysis.best_checkpoint))
self.assertTrue(os.path.exists(analysis.best_checkpoint.path))


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 4c4d341

Please sign in to comment.