Skip to content

Commit

Permalink
Update Ray core APIs (#35)
Browse files Browse the repository at this point in the history
* Update Ray core APIs

* Fix CI

Signed-off-by: Antoni Baum <[email protected]>

* Bump required xgboost-ray version

Signed-off-by: Antoni Baum <[email protected]>

* Fix

Signed-off-by: Antoni Baum <[email protected]>

Signed-off-by: Antoni Baum <[email protected]>
Co-authored-by: Antoni Baum <[email protected]>
  • Loading branch information
krfricke and Yard1 authored Aug 16, 2022
1 parent e3a35f7 commit 9797696
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 20 deletions.
39 changes: 23 additions & 16 deletions lightgbm_ray/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

import ray
from ray.util.annotations import PublicAPI
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from xgboost_ray.main import (
_handle_queue, RayXGBoostActor, LEGACY_MATRIX, RayDeviceQuantileDMatrix,
Expand All @@ -60,7 +61,8 @@
_Checkpoint, _create_communication_processes, RayTaskError,
RayXGBoostActorAvailable, RayXGBoostTrainingError, _create_placement_group,
_shutdown, PlacementGroup, ActorHandle, combine_data, _trigger_data_load,
DEFAULT_PG, _autodetect_resources as _autodetect_resources_base)
DEFAULT_PG, _autodetect_resources as _autodetect_resources_base,
_ray_get_actor_cpus)
from xgboost_ray.session import put_queue
from xgboost_ray import RayDMatrix

Expand Down Expand Up @@ -329,9 +331,8 @@ def train(self, return_bst: bool, params: Dict[str, Any],
local_params = _choose_param_value(
main_param_name="num_threads",
params=params,
default_value=num_threads if num_threads > 0 else
sum(num
for _, num in ray.worker.get_resource_ids().get("CPU", [])))
default_value=num_threads
if num_threads > 0 else _ray_get_actor_cpus())

if "init_model" in kwargs:
if isinstance(kwargs["init_model"], bytes):
Expand Down Expand Up @@ -537,19 +538,23 @@ def _create_actor(
# Send DEFAULT_PG here, which changed in Ray > 1.4.0
# If we send `None`, this will ignore the parent placement group and
# lead to errors e.g. when used within Ray Tune
return _RemoteRayLightGBMActor.options(
actor_cls = _RemoteRayLightGBMActor.options(
num_cpus=num_cpus_per_actor,
num_gpus=num_gpus_per_actor,
resources=resources_per_actor,
placement_group_capture_child_tasks=True,
placement_group=placement_group or DEFAULT_PG).remote(
rank=rank,
num_actors=num_actors,
model_factory=model_factory,
queue=queue,
checkpoint_frequency=checkpoint_frequency,
distributed_callbacks=distributed_callbacks,
network_params={"local_listen_port": port} if port else None)
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=placement_group or DEFAULT_PG,
placement_group_capture_child_tasks=True,
))

return actor_cls.remote(
rank=rank,
num_actors=num_actors,
model_factory=model_factory,
queue=queue,
checkpoint_frequency=checkpoint_frequency,
distributed_callbacks=distributed_callbacks,
network_params={"local_listen_port": port} if port else None)


def _train(params: Dict,
Expand Down Expand Up @@ -734,7 +739,9 @@ def handle_actor_failure(actor_id):
# conflict, it can try and choose a new one. Most of the time
# it will complete in one iteration.
machines = None
for n in range(5):
max_attempts = 5
i = 0
for i in range(max_attempts):
addresses = ray.get(
[actor.find_free_address.remote() for actor in live_actors])
if addresses:
Expand All @@ -750,7 +757,7 @@ def handle_actor_failure(actor_id):
else:
logger.debug("Couldn't obtain unique addresses, trying again.")
if machines:
logger.debug(f"Obtained unique addresses in {n} attempts.")
logger.debug(f"Obtained unique addresses in {i} attempts.")
else:
raise ValueError(
f"Couldn't obtain enough unique addresses for {len(live_actors)}."
Expand Down
20 changes: 17 additions & 3 deletions lightgbm_ray/tests/test_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
from lightgbm_ray.tune import TuneReportCallback,\
TuneReportCheckpointCallback, _try_add_tune_callback

try:
from ray.air import Checkpoint
except Exception:

class Checkpoint:
pass


class LightGBMRayTuneTest(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -145,7 +152,10 @@ def testEndToEndCheckpointing(self):
log_to_file=True,
local_dir=self.experiment_dir)

self.assertTrue(os.path.exists(analysis.best_checkpoint))
if isinstance(analysis.best_checkpoint, Checkpoint):
self.assertTrue(analysis.best_checkpoint)
else:
self.assertTrue(os.path.exists(analysis.best_checkpoint))

@unittest.skipIf(OrigTuneReportCallback is None,
"integration.lightgbmnot yet in ray.tune")
Expand All @@ -154,7 +164,8 @@ def testEndToEndCheckpointingOrigTune(self):
ray_params = RayParams(cpus_per_actor=2, num_actors=1)
analysis = tune.run(
self.train_func(
ray_params, callbacks=[OrigTuneReportCheckpointCallback()]),
ray_params,
callbacks=[OrigTuneReportCheckpointCallback(frequency=1)]),
config=self.params,
resources_per_trial=ray_params.get_tune_resources(),
num_samples=1,
Expand All @@ -163,7 +174,10 @@ def testEndToEndCheckpointingOrigTune(self):
log_to_file=True,
local_dir=self.experiment_dir)

self.assertTrue(os.path.exists(analysis.best_checkpoint))
if isinstance(analysis.best_checkpoint, Checkpoint):
self.assertTrue(analysis.best_checkpoint)
else:
self.assertTrue(os.path.exists(analysis.best_checkpoint))


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
long_description="A distributed backend for LightGBM built on top of "
"distributed computing framework Ray.",
url="https://github.com/ray-project/lightgbm_ray",
install_requires=["lightgbm>=3.2.1", "xgboost_ray>=0.1.9"])
install_requires=["lightgbm>=3.2.1", "xgboost_ray>=0.1.10"])

0 comments on commit 9797696

Please sign in to comment.