Skip to content

Commit

Permalink
fix ray train session check
Browse files Browse the repository at this point in the history
Signed-off-by: Justin Yu <[email protected]>
  • Loading branch information
justinvyu committed Jan 26, 2024
1 parent 005dc3e commit 2e5841a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
12 changes: 8 additions & 4 deletions lightgbm_ray/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,12 @@ def _check_cpus_per_actor_at_least_2(cpus_per_actor: int, suppress_exception: bo
)


def _in_ray_tune_session() -> bool:
return (
RAY_TUNE_INSTALLED and ray.train.get_context().get_trial_resources() is not None
)


def _get_data_dict(data: RayDMatrix, param: Dict) -> Dict:
if not LEGACY_MATRIX and isinstance(data, RayDeviceQuantileDMatrix):
# If we only got a single data shard, create a list so we can
Expand Down Expand Up @@ -1058,8 +1064,7 @@ def train(
os.environ.setdefault("RAY_IGNORE_UNHANDLED_ERRORS", "1")

if _remote is None:
in_ray_tune_session = RAY_TUNE_INSTALLED and ray.train.get_context()
_remote = _is_client_connected() and not in_ray_tune_session
_remote = _is_client_connected() and not _in_ray_tune_session()

if not ray.is_initialized():
ray.init()
Expand Down Expand Up @@ -1572,8 +1577,7 @@ def predict(
os.environ.setdefault("RAY_IGNORE_UNHANDLED_ERRORS", "1")

if _remote is None:
in_ray_tune_session = RAY_TUNE_INSTALLED and ray.train.get_context()
_remote = _is_client_connected() and not in_ray_tune_session
_remote = _is_client_connected() and not _in_ray_tune_session()

if not ray.is_initialized():
ray.init()
Expand Down
4 changes: 3 additions & 1 deletion lightgbm_ray/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def __new__(cls: type, *args, **kwargs):


def _try_add_tune_callback(kwargs: Dict):
ray_train_context_initialized = ray.train.get_context()
ray_train_context_initialized = (
ray.train.get_context().get_trial_resources() is not None
)
if ray_train_context_initialized:
callbacks = kwargs.get("callbacks", []) or []
new_callbacks = []
Expand Down

0 comments on commit 2e5841a

Please sign in to comment.