diff --git a/eolearn/core/eoexecution.py b/eolearn/core/eoexecution.py
index 6c7d83d8..23051692 100644
--- a/eolearn/core/eoexecution.py
+++ b/eolearn/core/eoexecution.py
@@ -204,13 +204,12 @@ def run(self, workers: int | None = 1, multiprocess: bool = True, **tqdm_kwargs:
 
         return full_execution_results
 
-    @classmethod
     def _run_execution(
-        cls, processing_args: list[_ProcessingData], run_params: _ExecutionRunParams
+        self, processing_args: list[_ProcessingData], run_params: _ExecutionRunParams
     ) -> list[WorkflowResults]:
         """Parallelizes the execution for each item of processing_args list."""
         return parallelize(
-            cls._execute_workflow,
+            self._execute_workflow,
             processing_args,
             workers=run_params.workers,
             multiprocess=run_params.multiprocess,
diff --git a/eolearn/core/extra/ray.py b/eolearn/core/extra/ray.py
index ff2959ad..4ea10cca 100644
--- a/eolearn/core/extra/ray.py
+++ b/eolearn/core/extra/ray.py
@@ -10,15 +10,20 @@
 """
 from __future__ import annotations
 
-from typing import Any, Callable, Collection, Generator, Iterable, List, TypeVar, cast
+from logging import FileHandler, Filter
+from typing import Any, Callable, Collection, Generator, Iterable, List, Sequence, TypeVar, cast
+
+from fs.base import FS
+
+from eolearn.core.eonode import EONode
 
 try:
     import ray
 except ImportError as exception:
     raise ImportError("This module requires an installation of Ray Python package") from exception
 
-from ..eoexecution import EOExecutor, _ExecutionRunParams, _ProcessingData
-from ..eoworkflow import WorkflowResults
+from ..eoexecution import EOExecutor, _ExecutionRunParams, _HandlerFactoryType, _ProcessingData
+from ..eoworkflow import EOWorkflow, WorkflowResults
 from ..utils.parallelize import _base_join_futures_iter
 
 # pylint: disable=invalid-name
@@ -29,6 +34,33 @@
 class RayExecutor(EOExecutor):
     """A special type of `EOExecutor` that works with Ray framework"""
 
+    def __init__(
+        self,
+        workflow: EOWorkflow,
+        execution_kwargs: Sequence[dict[EONode, dict[str, object]]],
+        *,
+        execution_names: list[str] | None = None,
+        save_logs: bool = False,
+        logs_folder: str = ".",
+        filesystem: FS | None = None,
+        logs_filter: Filter | None = None,
+        logs_handler_factory: _HandlerFactoryType = FileHandler,
+        raise_on_temporal_mismatch: bool = False,
+        ray_remote_kwargs: dict[str, Any] | None = None,
+    ):
+        super().__init__(
+            workflow,
+            execution_kwargs,
+            execution_names=execution_names,
+            save_logs=save_logs,
+            logs_folder=logs_folder,
+            filesystem=filesystem,
+            logs_filter=logs_filter,
+            logs_handler_factory=logs_handler_factory,
+            raise_on_temporal_mismatch=raise_on_temporal_mismatch,
+        )
+        self.ray_remote_kwargs = ray_remote_kwargs
+
     def run(self, **tqdm_kwargs: Any) -> list[WorkflowResults]:  # type: ignore[override]
         """Runs the executor using a Ray cluster
 
@@ -43,12 +75,13 @@ def run(self, **tqdm_kwargs: Any) -> list[WorkflowResults]:  # type: ignore[over
         workers = ray.available_resources().get("CPU")
         return super().run(workers=workers, multiprocess=True, **tqdm_kwargs)
 
-    @classmethod
     def _run_execution(
-        cls, processing_args: list[_ProcessingData], run_params: _ExecutionRunParams
+        self, processing_args: list[_ProcessingData], run_params: _ExecutionRunParams
     ) -> list[WorkflowResults]:
         """Runs ray execution"""
-        futures = [_ray_workflow_executor.remote(workflow_args) for workflow_args in processing_args]
+        remote_kwargs = self.ray_remote_kwargs or {}
+        exec_func = _ray_workflow_executor.options(**remote_kwargs)  # type: ignore[attr-defined]
+        futures = [exec_func.remote(workflow_args) for workflow_args in processing_args]
 
         return join_ray_futures(futures, **run_params.tqdm_kwargs)
 
@@ -60,7 +93,10 @@ def _ray_workflow_executor(workflow_args: _ProcessingData) -> WorkflowResults:
 
 
 def parallelize_with_ray(
-    function: Callable[[InputType], OutputType], *params: Iterable[InputType], **tqdm_kwargs: Any
+    function: Callable[[InputType], OutputType],
+    *params: Iterable[InputType],
+    ray_remote_kwargs: dict[str, Any] | None = None,
+    **tqdm_kwargs: Any,
 ) -> list[OutputType]:
     """Parallelizes function execution with Ray.
 
@@ -69,13 +105,15 @@ def parallelize_with_ray(
 
     :param function: A normal function that is not yet decorated by `ray.remote`.
     :param params: Iterables of parameters that will be used with given function.
+    :param ray_remote_kwargs: Keyword arguments passed to `ray.remote`.
     :param tqdm_kwargs: Keyword arguments that will be propagated to `tqdm` progress bar.
     :return: A list of results in the order that corresponds with the order of the given input `params`.
     """
+    ray_remote_kwargs = ray_remote_kwargs or {}
     if not ray.is_initialized():
         raise RuntimeError("Please initialize a Ray cluster before calling this method")
 
-    ray_function = ray.remote(function)
+    ray_function = ray.remote(function).options(**ray_remote_kwargs)
     futures = [ray_function.remote(*function_params) for function_params in zip(*params)]
     return join_ray_futures(futures, **tqdm_kwargs)
diff --git a/tests/core/test_extra/test_ray.py b/tests/core/test_extra/test_ray.py
index 679c28e2..02d721db 100644
--- a/tests/core/test_extra/test_ray.py
+++ b/tests/core/test_extra/test_ray.py
@@ -56,7 +56,7 @@ def filter(self, record):
 
 @pytest.fixture(name="_simple_cluster", scope="module")
 def _simple_cluster_fixture():
-    ray.init(log_to_driver=False)
+    ray.init(log_to_driver=False, resources={"resourceA": 1})
     yield
     ray.shutdown()
 
@@ -103,6 +103,7 @@ def test_read_logs(filter_logs, execution_names, workflow, execution_kwargs):
             logs_folder=tmp_dir_name,
             logs_filter=CustomLogFilter() if filter_logs else None,
             execution_names=execution_names,
+            ray_remote_kwargs={"resources": {"resourceA": 0.5}},
         )
         executor.run()
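
Usage note (illustrative, not part of the patch): a minimal sketch of the new `ray_remote_kwargs` hook, assuming eo-learn's standard `EOTask`/`EONode`/`EOWorkflow` construction. `DummyTask` and the `resourceA` amounts are made-up placeholders that mirror the custom resource registered in the test fixture above.

```python
import ray

from eolearn.core import EONode, EOTask, EOWorkflow
from eolearn.core.extra.ray import RayExecutor, parallelize_with_ray


class DummyTask(EOTask):
    """A stand-in task that simply returns the value it is given."""

    def execute(self, *, value):
        return value


# Register a custom resource so Ray can schedule tasks that request it.
ray.init(resources={"resourceA": 2})

node = EONode(DummyTask())
workflow = EOWorkflow([node])
execution_kwargs = [{node: {"value": i}} for i in range(8)]

# Executor-level: every workflow execution is submitted as a Ray task that
# holds 0.5 units of "resourceA", so at most 4 executions run concurrently.
executor = RayExecutor(
    workflow,
    execution_kwargs,
    ray_remote_kwargs={"resources": {"resourceA": 0.5}},
)
results = executor.run()

# Function-level: the same dict is forwarded as Ray task options.
squares = parallelize_with_ray(
    lambda x: x**2,
    range(8),
    ray_remote_kwargs={"resources": {"resourceA": 0.5}},
)

ray.shutdown()
```

Requesting fractions of a custom resource is a common Ray pattern for capping concurrency without pinning whole CPUs, which is presumably why the test registers `resourceA` on the module-scoped cluster and then requests 0.5 units per task.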