diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 3ff0dac9..98475e06 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -33,4 +33,4 @@ jobs: - name: Lint run: | . py3/bin/activate - black --check --diff . + black -S --check --diff . --exclude='fed/grpc|py3' diff --git a/.isort.cfg b/.isort.cfg index 5f06acbb..eb16c8b5 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -9,6 +9,5 @@ use_parentheses=True float_to_top=True filter_files=True -known_local_folder=ray -known_third_party=grpc +known_local_folder=fed sections=FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER \ No newline at end of file diff --git a/benchmarks/many_tiny_tasks_benchmark.py b/benchmarks/many_tiny_tasks_benchmark.py index 5fbbf387..ae450041 100644 --- a/benchmarks/many_tiny_tasks_benchmark.py +++ b/benchmarks/many_tiny_tasks_benchmark.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import ray -import time import sys +import time + +import ray + import fed @@ -53,8 +55,8 @@ def main(party): if i % 100 == 0: print(f"Running {i}th call") print(f"num calls: {num_calls}") - print("total time (ms) = ", (time.time() - start)*1000) - print("per task overhead (ms) =", (time.time() - start)*1000/num_calls) + print("total time (ms) = ", (time.time() - start) * 1000) + print("per task overhead (ms) =", (time.time() - start) * 1000 / num_calls) fed.shutdown() ray.shutdown() diff --git a/fed/__init__.py b/fed/__init__.py index 7636dd4a..f7870bc4 100644 --- a/fed/__init__.py +++ b/fed/__init__.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from fed.api import (get, init, kill, remote, - shutdown) -from fed.proxy.barriers import recv, send -from fed.fed_object import FedObject +from fed.api import get, init, kill, remote, shutdown from fed.exceptions import FedRemoteError +from fed.fed_object import FedObject +from fed.proxy.barriers import recv, send __all__ = [ "get", @@ -27,5 +26,5 @@ "recv", "send", "FedObject", - "FedRemoteError" + "FedRemoteError", ] diff --git a/fed/_private/compatible_utils.py b/fed/_private/compatible_utils.py index 83d2a098..ab5eb154 100644 --- a/fed/_private/compatible_utils.py +++ b/fed/_private/compatible_utils.py @@ -13,10 +13,11 @@ # limitations under the License. import abc -import ray -import fed._private.constants as fed_constants +import ray import ray.experimental.internal_kv as ray_internal_kv + +import fed._private.constants as fed_constants from fed._private import constants @@ -41,15 +42,14 @@ def _compare_version_strings(version1, version2): def _ray_version_less_than_2_0_0(): - """ Whther the current ray version is less 2.0.0. - """ + """Whther the current ray version is less 2.0.0.""" return _compare_version_strings( - fed_constants.RAY_VERSION_2_0_0_STR, ray.__version__) + fed_constants.RAY_VERSION_2_0_0_STR, ray.__version__ + ) def init_ray(address: str = None, **kwargs): - """A compatible API to init Ray. - """ + """A compatible API to init Ray.""" if address == 'local' and _ray_version_less_than_2_0_0(): # Ignore the `local` when ray < 2.0.0 ray.init(**kwargs) @@ -58,8 +58,7 @@ def init_ray(address: str = None, **kwargs): def _get_gcs_address_from_ray_worker(): - """A compatible API to get the gcs address from the ray worker module. - """ + """A compatible API to get the gcs address from the ray worker module.""" try: return ray._private.worker._global_node.gcs_address except AttributeError: @@ -67,19 +66,19 @@ def _get_gcs_address_from_ray_worker(): def wrap_kv_key(job_name, key: str): - """Add an prefix to the key to avoid conflict with other jobs. - """ - assert isinstance(key, str), \ - f"The key of KV data must be `str` type, got {type(key)}." + """Add an prefix to the key to avoid conflict with other jobs.""" + assert isinstance( + key, str + ), f"The key of KV data must be `str` type, got {type(key)}." - return constants.RAYFED_JOB_KV_DATA_KEY_FMT.format( - job_name, key) + return constants.RAYFED_JOB_KV_DATA_KEY_FMT.format(job_name, key) class AbstractInternalKv(abc.ABC): - """ An abstract class that represents for bridging Ray internal kv in + """An abstract class that represents for bridging Ray internal kv in both Ray client mode and non Ray client mode. """ + def __init__(self) -> None: pass @@ -105,8 +104,8 @@ def reset(self): class InternalKv(AbstractInternalKv): - """The internal kv class for non Ray client mode. - """ + """The internal kv class for non Ray client mode.""" + def __init__(self, job_name: str) -> None: super().__init__() self._job_name = job_name @@ -120,21 +119,18 @@ def initialize(self): from ray._raylet import GcsClient gcs_client = GcsClient( - address=_get_gcs_address_from_ray_worker(), - nums_reconnect_retry=10) + address=_get_gcs_address_from_ray_worker(), nums_reconnect_retry=10 + ) return ray_internal_kv._initialize_internal_kv(gcs_client) def put(self, k, v): - return ray_internal_kv._internal_kv_put( - wrap_kv_key(self._job_name, k), v) + return ray_internal_kv._internal_kv_put(wrap_kv_key(self._job_name, k), v) def get(self, k): - return ray_internal_kv._internal_kv_get( - wrap_kv_key(self._job_name, k)) + return ray_internal_kv._internal_kv_get(wrap_kv_key(self._job_name, k)) def delete(self, k): - return ray_internal_kv._internal_kv_del( - wrap_kv_key(self._job_name, k)) + return ray_internal_kv._internal_kv_del(wrap_kv_key(self._job_name, k)) def reset(self): return ray_internal_kv._internal_kv_reset() @@ -144,8 +140,8 @@ def _ping(self): class ClientModeInternalKv(AbstractInternalKv): - """The internal kv class for Ray client mode. - """ + """The internal kv class for Ray client mode.""" + def __init__(self) -> None: super().__init__() self._internal_kv_actor = ray.get_actor("_INTERNAL_KV_ACTOR") @@ -176,9 +172,13 @@ def _init_internal_kv(job_name): global kv if kv is None: from ray._private.client_mode_hook import is_client_mode_enabled + if is_client_mode_enabled: - kv_actor = ray.remote(InternalKv).options( - name="_INTERNAL_KV_ACTOR").remote(job_name) + kv_actor = ( + ray.remote(InternalKv) + .options(name="_INTERNAL_KV_ACTOR") + .remote(job_name) + ) response = kv_actor._ping.remote() ray.get(response) kv = ClientModeInternalKv() if is_client_mode_enabled else InternalKv(job_name) @@ -192,6 +192,7 @@ def _clear_internal_kv(): kv.delete(constants.KEY_OF_JOB_CONFIG) kv.reset() from ray._private.client_mode_hook import is_client_mode_enabled + if is_client_mode_enabled: _internal_kv_actor = ray.get_actor("_INTERNAL_KV_ACTOR") ray.kill(_internal_kv_actor) diff --git a/fed/_private/constants.py b/fed/_private/constants.py index f21f3d06..a77201b7 100644 --- a/fed/_private/constants.py +++ b/fed/_private/constants.py @@ -27,7 +27,7 @@ KEY_OF_CROSS_SILO_COMM_CONFIG_DICT = "CROSS_SILO_COMM_CONFIG_DICT" -RAYFED_LOG_FMT = "%(asctime)s.%(msecs)03d %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- [%(jobname)s] %(message)s" # noqa +RAYFED_LOG_FMT = "%(asctime)s.%(msecs)03d %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- [%(jobname)s] %(message)s" # noqa RAYFED_DATE_FMT = "%Y-%m-%d %H:%M:%S" diff --git a/fed/_private/fed_actor.py b/fed/_private/fed_actor.py index dd880579..6e84bfa2 100644 --- a/fed/_private/fed_actor.py +++ b/fed/_private/fed_actor.py @@ -16,6 +16,7 @@ import ray from ray.util.client.common import ClientActorHandle + from fed._private.fed_call_holder import FedCallHolder from fed.fed_object import FedObject @@ -90,19 +91,17 @@ def _execute_impl(self, cls_args, cls_kwargs): ) def _execute_remote_method( - self, - method_name, - options, - _ray_wrappered_method, - args, - kwargs, + self, + method_name, + options, + _ray_wrappered_method, + args, + kwargs, ): num_returns = 1 if options and 'num_returns' in options: num_returns = options['num_returns'] - logger.debug( - f"Actor method call: {method_name}, num_returns: {num_returns}" - ) + logger.debug(f"Actor method call: {method_name}, num_returns: {num_returns}") return _ray_wrappered_method.options( name='', diff --git a/fed/_private/fed_call_holder.py b/fed/_private/fed_call_holder.py index 9e349c46..80c2a350 100644 --- a/fed/_private/fed_call_holder.py +++ b/fed/_private/fed_call_holder.py @@ -14,15 +14,16 @@ import logging -# Set config in the very beginning to avoid being overwritten by other packages. -logging.basicConfig(level=logging.INFO) - +import fed.config as fed_config from fed._private.global_context import get_global_context -from fed.proxy.barriers import send from fed.fed_object import FedObject -from fed.utils import resolve_dependencies +from fed.proxy.barriers import send from fed.tree_util import tree_flatten -import fed.config as fed_config +from fed.utils import resolve_dependencies + +# Set config in the very beginning to avoid being overwritten by other packages. +logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) diff --git a/fed/_private/global_context.py b/fed/_private/global_context.py index cd1337fc..8e5c5506 100644 --- a/fed/_private/global_context.py +++ b/fed/_private/global_context.py @@ -12,22 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from fed.cleanup import CleanupManager -from typing import Callable import threading +from typing import Callable + +from fed.cleanup import CleanupManager class GlobalContext: - def __init__(self, job_name: str, - current_party: str, - failure_handler: Callable[[], None]) -> None: + def __init__( + self, job_name: str, current_party: str, failure_handler: Callable[[], None] + ) -> None: self._job_name = job_name self._seq_count = 0 self._failure_handler = failure_handler self._atomic_shutdown_flag_lock = threading.Lock() self._atomic_shutdown_flag = True self._cleanup_manager = CleanupManager( - current_party, self.acquire_shutdown_flag) + current_party, self.acquire_shutdown_flag + ) def next_seq_id(self) -> int: self._seq_count += 1 @@ -65,9 +67,9 @@ def acquire_shutdown_flag(self) -> bool: _global_context = None -def init_global_context(current_party: str, - job_name: str, - failure_handler: Callable[[], None] = None) -> None: +def init_global_context( + current_party: str, job_name: str, failure_handler: Callable[[], None] = None +) -> None: global _global_context if _global_context is None: _global_context = GlobalContext(job_name, current_party, failure_handler) diff --git a/fed/_private/message_queue.py b/fed/_private/message_queue.py index c4a1b6eb..9fdff7dc 100644 --- a/fed/_private/message_queue.py +++ b/fed/_private/message_queue.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import threading -from collections import deque import time -import logging - +from collections import deque logger = logging.getLogger(__name__) @@ -55,7 +54,8 @@ def _loop(): if self._thread is None or not self._thread.is_alive(): logger.debug( - f"Starting new thread[{self._thread_name}] for message polling.") + f"Starting new thread[{self._thread_name}] for message polling." + ) self._queue = deque() self._thread = threading.Thread(target=_loop, name=self._thread_name) self._thread.start() @@ -79,9 +79,11 @@ def stop(self): If False: forcelly kill the for-loop sub-thread. """ if threading.current_thread() == self._thread: - logger.error(f"Can't stop the message queue in the message " - f"polling thread[{self._thread_name}]. Ignore it as this" - f"could bring unknown time sequence problems.") + logger.error( + f"Can't stop the message queue in the message " + f"polling thread[{self._thread_name}]. Ignore it as this" + f"could bring unknown time sequence problems." + ) raise RuntimeError("Thread can't kill itself") # TODO(NKcqx): Force kill sub-thread by calling `._stop()` will diff --git a/fed/_private/serialization_utils.py b/fed/_private/serialization_utils.py index c1b2e73e..0c0f3c61 100644 --- a/fed/_private/serialization_utils.py +++ b/fed/_private/serialization_utils.py @@ -13,11 +13,11 @@ # limitations under the License. import io + import cloudpickle import fed.config as fed_config - _pickle_whitelist = None diff --git a/fed/cleanup.py b/fed/cleanup.py index d68b290b..91a0739a 100644 --- a/fed/cleanup.py +++ b/fed/cleanup.py @@ -16,11 +16,12 @@ import os import signal import threading -from fed._private.message_queue import MessageQueueManager -from fed.exceptions import FedRemoteError -from ray.exceptions import RayError import ray +from ray.exceptions import RayError + +from fed._private.message_queue import MessageQueueManager +from fed.exceptions import FedRemoteError logger = logging.getLogger(__name__) @@ -44,11 +45,13 @@ class CleanupManager: def __init__(self, current_party, acquire_shutdown_flag) -> None: self._sending_data_q = MessageQueueManager( lambda msg: self._process_data_sending_task_return(msg), - thread_name='DataSendingQueueThread') + thread_name='DataSendingQueueThread', + ) self._sending_error_q = MessageQueueManager( lambda msg: self._process_error_sending_task_return(msg), - thread_name="ErrorSendingQueueThread") + thread_name="ErrorSendingQueueThread", + ) self._monitor_thread = None @@ -80,12 +83,14 @@ def stop(self): self._sending_data_q.stop() self._sending_error_q.stop() - def push_to_sending(self, - obj_ref: ray.ObjectRef, - dest_party: str = None, - upstream_seq_id: int = -1, - downstream_seq_id: int = -1, - is_error: bool = False): + def push_to_sending( + self, + obj_ref: ray.ObjectRef, + dest_party: str = None, + upstream_seq_id: int = -1, + downstream_seq_id: int = -1, + is_error: bool = False, + ): """ Push the sending remote task's return value, i.e. `obj_ref` to the corresponding message queue. @@ -104,7 +109,7 @@ def push_to_sending(self, queue instead. """ msg_pack = (obj_ref, dest_party, upstream_seq_id, downstream_seq_id) - if (is_error): + if is_error: self._sending_error_q.append(msg_pack) else: self._sending_data_q.append(msg_pack) @@ -123,7 +128,7 @@ def _signal_exit(self): # will cause dead lock. In order to ensure executing `shutdown` exactly # once and avoid dead lock, the lock must be checked before sending # signals. - if (self._acquire_shutdown_flag()): + if self._acquire_shutdown_flag(): logger.debug("Signal SIGINT to exit.") os.kill(os.getpid(), signal.SIGINT) @@ -151,16 +156,24 @@ def _process_data_sending_task_return(self, message): try: res = ray.get(obj_ref) except Exception as e: - logger.warn(f'Failed to send {obj_ref} to {dest_party}, error: {e},' - f'upstream_seq_id: {upstream_seq_id}, ' - f'downstream_seq_id: {downstream_seq_id}.') - if (isinstance(e, RayError)): + logger.warn( + f'Failed to send {obj_ref} to {dest_party}, error: {e},' + f'upstream_seq_id: {upstream_seq_id}, ' + f'downstream_seq_id: {downstream_seq_id}.' + ) + if isinstance(e, RayError): logger.info(f"Sending error {e.cause} to {dest_party}.") from fed.proxy.barriers import send + # TODO(NKcqx): Cascade broadcast to all parties error_trace = e.cause if self._expose_error_trace else None - send(dest_party, FedRemoteError(self._current_party, error_trace), - upstream_seq_id, downstream_seq_id, True) + send( + dest_party, + FedRemoteError(self._current_party, error_trace), + upstream_seq_id, + downstream_seq_id, + True, + ) res = False @@ -183,10 +196,12 @@ def _process_error_sending_task_return(self, error_msg): res = False if not res: - logger.warning(f"Failed to send error {error_ref} to {dest_party}, " - f"upstream_seq_id: {upstream_seq_id} " - f"downstream_seq_id: {downstream_seq_id}. " - "In this case, other parties won't sense " - "this error and may cause unknown behaviour.") + logger.warning( + f"Failed to send error {error_ref} to {dest_party}, " + f"upstream_seq_id: {upstream_seq_id} " + f"downstream_seq_id: {downstream_seq_id}. " + "In this case, other parties won't sense " + "this error and may cause unknown behaviour." + ) # Return True so that remaining error objects can be sent return True diff --git a/fed/exceptions.py b/fed/exceptions.py index dad4abfc..bd21f9a6 100644 --- a/fed/exceptions.py +++ b/fed/exceptions.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + class FedRemoteError(Exception): def __init__(self, src_party: str, cause: Exception) -> None: self._src_party = src_party diff --git a/fed/fed_object.py b/fed/fed_object.py index 6e62faa9..aaa7e1dd 100644 --- a/fed/fed_object.py +++ b/fed/fed_object.py @@ -17,6 +17,7 @@ class FedObjectSendingContext: """The class that's used for holding the all contexts about sending side.""" + def __init__(self) -> None: # This field holds the target(downstream) parties that this fed object # is sending or sent to. @@ -33,6 +34,7 @@ def was_sending_or_sent_to_party(self, target_party: str): class FedObjectReceivingContext: """The class that's used for holding the all contexts about receiving side.""" + pass diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index aa207ae4..ee614463 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -51,9 +51,9 @@ def set_receiver_proxy_actor_name(name: str): _RECEIVER_PROXY_ACTOR_NAME = name -def set_proxy_actor_name(job_name: str, - use_global_proxy: bool, - sender_recvr_proxy: bool = False): +def set_proxy_actor_name( + job_name: str, use_global_proxy: bool, sender_recvr_proxy: bool = False +): """ Generate the name of the proxy actor. @@ -136,7 +136,8 @@ def __init__( job_config = fed_config.get_job_config(job_name) cross_silo_comm_config = job_config.cross_silo_comm_config_dict self._proxy_instance: SenderProxy = proxy_cls( - addresses, party, job_name, tls_config, cross_silo_comm_config) + addresses, party, job_name, tls_config, cross_silo_comm_config + ) async def is_ready(self): res = await self._proxy_instance.is_ready() @@ -207,7 +208,8 @@ def __init__( job_config = fed_config.get_job_config(job_name) cross_silo_comm_config = job_config.cross_silo_comm_config_dict self._proxy_instance: ReceiverProxy = proxy_cls( - listening_address, party, job_name, tls_config, cross_silo_comm_config) + listening_address, party, job_name, tls_config, cross_silo_comm_config + ) async def start(self): await self._proxy_instance.start() @@ -222,9 +224,11 @@ async def get_data(self, src_party, upstream_seq_id, curr_seq_id): src_party, upstream_seq_id, curr_seq_id ) if isinstance(data, Exception): - logger.debug(f"Receiving exception: {type(data)}, {data} from {src_party}, " - f"upstream_seq_id: {upstream_seq_id}, " - f"curr_seq_id: {curr_seq_id}. Re-raise it.") + logger.debug( + f"Receiving exception: {type(data)}, {data} from {src_party}, " + f"upstream_seq_id: {upstream_seq_id}, " + f"curr_seq_id: {curr_seq_id}. Re-raise it." + ) raise data return data @@ -437,13 +441,14 @@ def _start_sender_receiver_proxy( global _SENDER_RECEIVER_PROXY_ACTOR _SENDER_RECEIVER_PROXY_ACTOR = SenderReceiverProxyActor.options( - **actor_options).remote( - addresses=addresses, - party=party, - job_name=job_name, - tls_config=tls_config, - logging_level=logging_level, - proxy_cls=proxy_cls, + **actor_options + ).remote( + addresses=addresses, + party=party, + job_name=job_name, + tls_config=tls_config, + logging_level=logging_level, + proxy_cls=proxy_cls, ) _SENDER_RECEIVER_PROXY_ACTOR.start.remote() server_state = ray.get( @@ -453,13 +458,7 @@ def _start_sender_receiver_proxy( logger.info("Succeeded to create receiver proxy actor.") -def send( - dest_party, - data, - upstream_seq_id, - downstream_seq_id, - is_error=False -): +def send(dest_party, data, upstream_seq_id, downstream_seq_id, is_error=False): """ Args: is_error: Whether the `data` is an error object or not. Default is False. @@ -473,7 +472,8 @@ def send( downstream_seq_id=downstream_seq_id, ) get_global_context().get_cleanup_manager().push_to_sending( - res, dest_party, upstream_seq_id, downstream_seq_id, is_error) + res, dest_party, upstream_seq_id, downstream_seq_id, is_error + ) return res diff --git a/fed/proxy/base_proxy.py b/fed/proxy/base_proxy.py index b2eba265..c62285b8 100644 --- a/fed/proxy/base_proxy.py +++ b/fed/proxy/base_proxy.py @@ -51,7 +51,7 @@ def __init__( party: str, job_name: str, tls_config: Dict, - proxy_config: CrossSiloMessageConfig = None + proxy_config: CrossSiloMessageConfig = None, ) -> None: self._listen_addr = listen_addr self._party = party diff --git a/fed/proxy/grpc/grpc_options.py b/fed/proxy/grpc/grpc_options.py index 6e4b2d14..064824b3 100644 --- a/fed/proxy/grpc/grpc_options.py +++ b/fed/proxy/grpc/grpc_options.py @@ -14,7 +14,6 @@ import json - _GRPC_SERVICE = "GrpcService" _DEFAULT_GRPC_RETRY_POLICY = { @@ -34,17 +33,16 @@ 'grpc.so_reuseport': 0, 'grpc.max_send_message_length': _DEFAULT_GRPC_MAX_SEND_MESSAGE_LENGTH, 'grpc.max_receive_message_length': _DEFAULT_GRPC_MAX_RECEIVE_MESSAGE_LENGTH, - 'grpc.service_config': - json.dumps( - { - 'methodConfig': [ - { - 'name': [{'service': _GRPC_SERVICE}], - 'retryPolicy': _DEFAULT_GRPC_RETRY_POLICY, - } - ] - } - ), + 'grpc.service_config': json.dumps( + { + 'methodConfig': [ + { + 'name': [{'service': _GRPC_SERVICE}], + 'retryPolicy': _DEFAULT_GRPC_RETRY_POLICY, + } + ] + } + ), } diff --git a/fed/proxy/grpc/grpc_proxy.py b/fed/proxy/grpc/grpc_proxy.py index 1fdfa4e2..e47c62dd 100644 --- a/fed/proxy/grpc/grpc_proxy.py +++ b/fed/proxy/grpc/grpc_proxy.py @@ -14,27 +14,29 @@ import asyncio import copy -import cloudpickle -import grpc +import json import logging import threading -import json from typing import Dict -import fed.utils as fed_utils +import cloudpickle +import grpc -from fed.config import CrossSiloMessageConfig, GrpcCrossSiloMessageConfig import fed._private.compatible_utils as compatible_utils -from fed.proxy.grpc.grpc_options import _DEFAULT_GRPC_CHANNEL_OPTIONS, _GRPC_SERVICE +import fed.utils as fed_utils +from fed.config import CrossSiloMessageConfig, GrpcCrossSiloMessageConfig from fed.proxy.barriers import ( add_two_dim_dict, get_from_two_dim_dict, - pop_from_two_dim_dict, key_exists_in_two_dim_dict, + pop_from_two_dim_dict, ) -from fed.proxy.base_proxy import SenderProxy, ReceiverProxy +from fed.proxy.base_proxy import ReceiverProxy, SenderProxy +from fed.proxy.grpc.grpc_options import _DEFAULT_GRPC_CHANNEL_OPTIONS, _GRPC_SERVICE + if compatible_utils._compare_version_strings( - fed_utils.get_package_version('protobuf'), '4.0.0'): + fed_utils.get_package_version('protobuf'), '4.0.0' +): from fed.grpc.pb4 import fed_pb2 as fed_pb2 from fed.grpc.pb4 import fed_pb2_grpc as fed_pb2_grpc else: @@ -67,43 +69,44 @@ def parse_grpc_options(proxy_config: CrossSiloMessageConfig): # However, `GrpcCrossSiloMessageConfig` provides a more flexible way # to configure grpc channel options, i.e. the `grpc_channel_options` # field, which may override the `messages_max_size_in_bytes` field. - if (isinstance(proxy_config, CrossSiloMessageConfig)): - if (proxy_config.messages_max_size_in_bytes is not None): - grpc_channel_options.update({ - 'grpc.max_send_message_length': - proxy_config.messages_max_size_in_bytes, - 'grpc.max_receive_message_length': - proxy_config.messages_max_size_in_bytes, - }) + if isinstance(proxy_config, CrossSiloMessageConfig): + if proxy_config.messages_max_size_in_bytes is not None: + grpc_channel_options.update( + { + 'grpc.max_send_message_length': proxy_config.messages_max_size_in_bytes, + 'grpc.max_receive_message_length': proxy_config.messages_max_size_in_bytes, + } + ) if isinstance(proxy_config, GrpcCrossSiloMessageConfig): if proxy_config.grpc_channel_options is not None: grpc_channel_options.update(proxy_config.grpc_channel_options) if proxy_config.grpc_retry_policy is not None: - grpc_channel_options.update({ - 'grpc.service_config': - json.dumps( - { - 'methodConfig': [ - { - 'name': [{'service': _GRPC_SERVICE}], - 'retryPolicy': proxy_config.grpc_retry_policy, - } - ] - } - ), - }) + grpc_channel_options.update( + { + 'grpc.service_config': json.dumps( + { + 'methodConfig': [ + { + 'name': [{'service': _GRPC_SERVICE}], + 'retryPolicy': proxy_config.grpc_retry_policy, + } + ] + } + ), + } + ) return grpc_channel_options class GrpcSenderProxy(SenderProxy): def __init__( - self, - cluster: Dict, - party: str, - job_name: str, - tls_config: Dict, - proxy_config: Dict = None + self, + cluster: Dict, + party: str, + job_name: str, + tls_config: Dict, + proxy_config: Dict = None, ) -> None: proxy_config = GrpcCrossSiloMessageConfig.from_dict(proxy_config) super().__init__(cluster, party, job_name, tls_config, proxy_config) @@ -113,29 +116,27 @@ def __init__( # Mapping the destination party name to the reused client stub. self._stubs = {} - async def send( - self, - dest_party, - data, - upstream_seq_id, - downstream_seq_id): + async def send(self, dest_party, data, upstream_seq_id, downstream_seq_id): dest_addr = self._addresses[dest_party] grpc_metadata, grpc_channel_options = self.get_grpc_config_by_party(dest_party) tls_enabled = fed_utils.tls_enabled(self._tls_config) if dest_party not in self._stubs: if tls_enabled: ca_cert, private_key, cert_chain = fed_utils.load_cert_config( - self._tls_config) + self._tls_config + ) credentials = grpc.ssl_channel_credentials( certificate_chain=cert_chain, private_key=private_key, root_certificates=ca_cert, ) channel = grpc.aio.secure_channel( - dest_addr, credentials, options=grpc_channel_options) + dest_addr, credentials, options=grpc_channel_options + ) else: channel = grpc.aio.insecure_channel( - dest_addr, options=grpc_channel_options) + dest_addr, options=grpc_channel_options + ) stub = fed_pb2_grpc.GrpcServiceStub(channel) self._stubs[dest_party] = stub @@ -153,8 +154,7 @@ async def send( return response.result def get_grpc_config_by_party(self, dest_party): - """Overide global config by party specific config - """ + """Overide global config by party specific config""" grpc_metadata = self._grpc_metadata grpc_options = self._grpc_options @@ -162,14 +162,9 @@ def get_grpc_config_by_party(self, dest_party): if dest_party_msg_config is not None: if dest_party_msg_config.http_header is not None: dest_party_grpc_metadata = dict(dest_party_msg_config.http_header) - grpc_metadata = { - **grpc_metadata, - **dest_party_grpc_metadata - } + grpc_metadata = {**grpc_metadata, **dest_party_grpc_metadata} dest_party_grpc_options = parse_grpc_options(dest_party_msg_config) - grpc_options = { - **grpc_options, **dest_party_grpc_options - } + grpc_options = {**grpc_options, **dest_party_grpc_options} return grpc_metadata, fed_utils.dict2tuple(grpc_options) async def get_proxy_config(self, dest_party=None): @@ -188,8 +183,10 @@ def handle_response_error(self, response): if 400 <= response.code < 500: # Request error should also be identified as a sending failure, # though the request was physically sent. - logger.warning(f"Request was successfully sent but got error response, " - f"code: {response.code}, message: {response.result}.") + logger.warning( + f"Request was successfully sent but got error response, " + f"code: {response.code}, message: {response.result}." + ) raise RuntimeError(response.result) @@ -225,12 +222,12 @@ async def send_data_grpc( class GrpcReceiverProxy(ReceiverProxy): def __init__( - self, - listen_addr: str, - party: str, - job_name: str, - tls_config: Dict, - proxy_config: Dict + self, + listen_addr: str, + party: str, + job_name: str, + tls_config: Dict, + proxy_config: Dict, ) -> None: proxy_config = GrpcCrossSiloMessageConfig.from_dict(proxy_config) super().__init__(listen_addr, party, job_name, tls_config, proxy_config) @@ -260,9 +257,11 @@ async def start(self): fed_utils.dict2tuple(self._grpc_options), ) except RuntimeError as err: - msg = f'Grpc server failed to listen to port: {port}' \ - f' Try another port by setting `listen_addr` into `cluster` config' \ - f' when calling `fed.init`. Grpc error msg: {err}' + msg = ( + f'Grpc server failed to listen to port: {port}' + f' Try another port by setting `listen_addr` into `cluster` config' + f' when calling `fed.init`. Grpc error msg: {err}' + ) self._server_ready_future.set_result((False, msg)) async def is_ready(self): @@ -289,6 +288,7 @@ async def get_data(self, src_party, upstream_seq_id, curr_seq_id): # NOTE(qwang): This is used to avoid the conflict with pickle5 in Ray. import fed._private.serialization_utils as fed_ser_utils + fed_ser_utils._apply_loads_function_with_whitelist() return cloudpickle.loads(data) @@ -309,12 +309,15 @@ def __init__(self, all_events, all_data, party, lock, job_name): async def SendData(self, request, context): job_name = request.job_name if job_name != self._job_name: - logger.warning(f"Receive data from job {job_name}, ignore it. " - f"The reason may be that the ReceiverProxy is listening " - f"on the same address with that job.") + logger.warning( + f"Receive data from job {job_name}, ignore it. " + f"The reason may be that the ReceiverProxy is listening " + f"on the same address with that job." + ) return fed_pb2.SendDataResponse( code=417, - result=f"JobName mis-match, expected {self._job_name}, got {job_name}.") + result=f"JobName mis-match, expected {self._job_name}, got {job_name}.", + ) upstream_seq_id = request.upstream_seq_id downstream_seq_id = request.downstream_seq_id logger.debug( @@ -340,8 +343,15 @@ async def SendData(self, request, context): async def _run_grpc_server( - port, event, all_data, party, lock, job_name, - server_ready_future, tls_config=None, grpc_options=None + port, + event, + all_data, + party, + lock, + job_name, + server_ready_future, + tls_config=None, + grpc_options=None, ): logger.info(f"ReceiverProxy binding port {port}, options: {grpc_options}...") server = grpc.aio.server(options=grpc_options) diff --git a/fed/tests/client_mode_tests/test_basic_client_mode.py b/fed/tests/client_mode_tests/test_basic_client_mode.py index 98078029..35f3d851 100644 --- a/fed/tests/client_mode_tests/test_basic_client_mode.py +++ b/fed/tests/client_mode_tests/test_basic_client_mode.py @@ -16,9 +16,10 @@ import pytest import ray + import fed import fed._private.compatible_utils as compatible_utils -from fed.tests.test_utils import ray_client_mode_setup # noqa +from fed.tests.test_utils import ray_client_mode_setup # noqa @fed.remote @@ -49,10 +50,13 @@ def mean(x, y): def run(party): import time + if party == 'alice': time.sleep(1.4) - address = 'ray://127.0.0.1:21012' if party == 'alice' else 'ray://127.0.0.1:21011' # noqa + address = ( + 'ray://127.0.0.1:21012' if party == 'alice' else 'ray://127.0.0.1:21011' + ) # noqa compatible_utils.init_ray(address=address) addresses = { @@ -83,7 +87,7 @@ def run(party): ray.shutdown() -def test_fed_get_in_2_parties(ray_client_mode_setup): # noqa +def test_fed_get_in_2_parties(ray_client_mode_setup): # noqa p_alice = multiprocessing.Process(target=run, args=('alice',)) p_bob = multiprocessing.Process(target=run, args=('bob',)) p_alice.start() diff --git a/fed/tests/multi-jobs/test_ignore_other_job_msg.py b/fed/tests/multi-jobs/test_ignore_other_job_msg.py index a75d8f78..d83acff3 100644 --- a/fed/tests/multi-jobs/test_ignore_other_job_msg.py +++ b/fed/tests/multi-jobs/test_ignore_other_job_msg.py @@ -13,32 +13,30 @@ # limitations under the License. import multiprocessing -import fed -import ray + import grpc import pytest -import fed.utils as fed_utils +import ray + +import fed import fed._private.compatible_utils as compatible_utils +import fed.utils as fed_utils from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy, send_data_grpc + if compatible_utils._compare_version_strings( - fed_utils.get_package_version('protobuf'), '4.0.0'): + fed_utils.get_package_version('protobuf'), '4.0.0' +): from fed.grpc.pb4 import fed_pb2_grpc as fed_pb2_grpc else: from fed.grpc.pb3 import fed_pb2_grpc as fed_pb2_grpc class TestGrpcSenderProxy(GrpcSenderProxy): - async def send( - self, - dest_party, - data, - upstream_seq_id, - downstream_seq_id): + async def send(self, dest_party, data, upstream_seq_id, downstream_seq_id): dest_addr = self._addresses[dest_party] grpc_metadata, grpc_channel_options = self.get_grpc_config_by_party(dest_party) if dest_party not in self._stubs: - channel = grpc.aio.insecure_channel( - dest_addr, options=grpc_channel_options) + channel = grpc.aio.insecure_channel(dest_addr, options=grpc_channel_options) stub = fed_pb2_grpc.GrpcServiceStub(channel) self._stubs[dest_party] = stub @@ -81,14 +79,17 @@ def agg_fn(obj1, obj2): def run(party, job_name): ray.init(address='local') - fed.init(addresses=addresses, - party=party, - job_name=job_name, - sender_proxy_cls=TestGrpcSenderProxy, - config={ - 'cross_silo_comm': { - 'exit_on_sending_failure': True, - }}) + fed.init( + addresses=addresses, + party=party, + job_name=job_name, + sender_proxy_cls=TestGrpcSenderProxy, + config={ + 'cross_silo_comm': { + 'exit_on_sending_failure': True, + } + }, + ) # 'bob' only needs to start the proxy actors if party == 'alice': ds1, ds2 = [123, 789] @@ -103,6 +104,7 @@ def run(party, job_name): fed.shutdown() ray.shutdown() import time + # Wait for SIGTERM as failure on sending. time.sleep(86400) diff --git a/fed/tests/multi-jobs/test_multi_proxy_actor.py b/fed/tests/multi-jobs/test_multi_proxy_actor.py index 5021ec0d..a540086c 100644 --- a/fed/tests/multi-jobs/test_multi_proxy_actor.py +++ b/fed/tests/multi-jobs/test_multi_proxy_actor.py @@ -13,32 +13,30 @@ # limitations under the License. import multiprocessing -import fed -import ray + import grpc import pytest -import fed.utils as fed_utils +import ray + +import fed import fed._private.compatible_utils as compatible_utils +import fed.utils as fed_utils from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy, send_data_grpc + if compatible_utils._compare_version_strings( - fed_utils.get_package_version('protobuf'), '4.0.0'): + fed_utils.get_package_version('protobuf'), '4.0.0' +): from fed.grpc.pb4 import fed_pb2_grpc as fed_pb2_grpc else: from fed.grpc.pb3 import fed_pb2_grpc as fed_pb2_grpc class TestGrpcSenderProxy(GrpcSenderProxy): - async def send( - self, - dest_party, - data, - upstream_seq_id, - downstream_seq_id): + async def send(self, dest_party, data, upstream_seq_id, downstream_seq_id): dest_addr = self._addresses[dest_party] grpc_metadata, grpc_channel_options = self.get_grpc_config_by_party(dest_party) if dest_party not in self._stubs: - channel = grpc.aio.insecure_channel( - dest_addr, options=grpc_channel_options) + channel = grpc.aio.insecure_channel(dest_addr, options=grpc_channel_options) stub = fed_pb2_grpc.GrpcServiceStub(channel) self._stubs[dest_party] = stub @@ -87,16 +85,19 @@ def agg_fn(obj1, obj2): def run(party, job_name): ray.init(address='local') - fed.init(addresses=addresses[job_name], - party=party, - job_name=job_name, - sender_proxy_cls=TestGrpcSenderProxy, - config={ - 'cross_silo_comm': { - 'exit_on_sending_failure': True, - # Create unique proxy for current job - 'use_global_proxy': False - }}) + fed.init( + addresses=addresses[job_name], + party=party, + job_name=job_name, + sender_proxy_cls=TestGrpcSenderProxy, + config={ + 'cross_silo_comm': { + 'exit_on_sending_failure': True, + # Create unique proxy for current job + 'use_global_proxy': False, + } + }, + ) sender_proxy_actor_name = f"SenderProxyActor_{job_name}" receiver_proxy_actor_name = f"ReceiverProxyActor_{job_name}" diff --git a/fed/tests/simple_example.py b/fed/tests/simple_example.py index 4ca1095f..88c9ce00 100644 --- a/fed/tests/simple_example.py +++ b/fed/tests/simple_example.py @@ -13,9 +13,11 @@ # limitations under the License. import multiprocessing -import fed + import ray +import fed + @fed.remote class MyActor: diff --git a/fed/tests/test_api.py b/fed/tests/test_api.py index 9a3b00af..c9eaec16 100644 --- a/fed/tests/test_api.py +++ b/fed/tests/test_api.py @@ -13,10 +13,12 @@ # limitations under the License. import multiprocessing + import pytest +import ray + import fed import fed._private.compatible_utils as compatible_utils -import ray import fed.config as fed_config diff --git a/fed/tests/test_async_startup_2_clusters.py b/fed/tests/test_async_startup_2_clusters.py index 9542f87e..af04776b 100644 --- a/fed/tests/test_async_startup_2_clusters.py +++ b/fed/tests/test_async_startup_2_clusters.py @@ -15,8 +15,8 @@ import multiprocessing import pytest - import ray + import fed import fed._private.compatible_utils as compatible_utils diff --git a/fed/tests/test_basic_pass_fed_objects.py b/fed/tests/test_basic_pass_fed_objects.py index ebca911f..223097af 100644 --- a/fed/tests/test_basic_pass_fed_objects.py +++ b/fed/tests/test_basic_pass_fed_objects.py @@ -16,6 +16,7 @@ import pytest import ray + import fed import fed._private.compatible_utils as compatible_utils diff --git a/fed/tests/test_cross_silo_error.py b/fed/tests/test_cross_silo_error.py index 17547a09..f2f623eb 100644 --- a/fed/tests/test_cross_silo_error.py +++ b/fed/tests/test_cross_silo_error.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import ray import multiprocessing +import sys +from unittest.mock import Mock import pytest +import ray + import fed import fed._private.compatible_utils as compatible_utils -import sys - -from unittest.mock import Mock from fed.exceptions import FedRemoteError @@ -59,10 +59,10 @@ def run(party): 'cross_silo_comm': { 'exit_on_sending_failure': True, 'timeout_ms': 20 * 1000, - 'expose_error_trace': True + 'expose_error_trace': True, }, }, - failure_handler=my_failure_handler + failure_handler=my_failure_handler, ) # Both party should catch the error @@ -107,10 +107,10 @@ def run2(party): 'cross_silo_comm': { 'exit_on_sending_failure': True, 'timeout_ms': 20 * 1000, - 'expose_error_trace': True + 'expose_error_trace': True, }, }, - failure_handler=my_failure_handler + failure_handler=my_failure_handler, ) # Both party should catch the error @@ -162,7 +162,7 @@ def run3(party): 'timeout_ms': 20 * 1000, }, }, - failure_handler=my_failure_handler + failure_handler=my_failure_handler, ) # Both party should catch the error diff --git a/fed/tests/test_enable_tls_across_parties.py b/fed/tests/test_enable_tls_across_parties.py index 5c3c1c38..59921a1b 100644 --- a/fed/tests/test_enable_tls_across_parties.py +++ b/fed/tests/test_enable_tls_across_parties.py @@ -16,8 +16,8 @@ import os import pytest - import ray + import fed import fed._private.compatible_utils as compatible_utils diff --git a/fed/tests/test_exit_on_failure_sending.py b/fed/tests/test_exit_on_failure_sending.py index ab1a8613..3da50ae7 100644 --- a/fed/tests/test_exit_on_failure_sending.py +++ b/fed/tests/test_exit_on_failure_sending.py @@ -13,17 +13,16 @@ # limitations under the License. import multiprocessing +import os +import signal +import sys import pytest import ray + import fed import fed._private.compatible_utils as compatible_utils -import signal - -import os -import sys - def signal_handler(sig, frame): if sig == signal.SIGTERM.value: @@ -73,7 +72,7 @@ def run(party): 'timeout_ms': 20 * 1000, }, }, - failure_handler=lambda : os.kill(os.getpid(), signal.SIGTERM) + failure_handler=lambda: os.kill(os.getpid(), signal.SIGTERM), ) o = f.party("alice").remote() diff --git a/fed/tests/test_fed_get.py b/fed/tests/test_fed_get.py index 5752f778..8a7e72d1 100644 --- a/fed/tests/test_fed_get.py +++ b/fed/tests/test_fed_get.py @@ -16,6 +16,7 @@ import pytest import ray + import fed import fed._private.compatible_utils as compatible_utils @@ -48,6 +49,7 @@ def mean(x, y): def run(party): import time + if party == 'alice': time.sleep(1.4) diff --git a/fed/tests/test_grpc_options_on_proxies.py b/fed/tests/test_grpc_options_on_proxies.py index cb14e922..22e31592 100644 --- a/fed/tests/test_grpc_options_on_proxies.py +++ b/fed/tests/test_grpc_options_on_proxies.py @@ -13,11 +13,12 @@ # limitations under the License. import multiprocessing + import pytest -import fed -import fed._private.compatible_utils as compatible_utils import ray +import fed +import fed._private.compatible_utils as compatible_utils from fed.proxy.barriers import receiver_proxy_actor_name, sender_proxy_actor_name @@ -131,7 +132,7 @@ def run3(party): "messages_max_size_in_bytes": 100, "grpc_channel_options": [ ('grpc.max_send_message_length', 200), - ], + ], }, }, ) diff --git a/fed/tests/test_internal_kv.py b/fed/tests/test_internal_kv.py index bb048239..65a6e4fa 100644 --- a/fed/tests/test_internal_kv.py +++ b/fed/tests/test_internal_kv.py @@ -1,10 +1,12 @@ import multiprocessing +import time + import pytest import ray +import ray.experimental.internal_kv as ray_internal_kv + import fed -import time import fed._private.compatible_utils as compatible_utils -import ray.experimental.internal_kv as ray_internal_kv def run(party): @@ -21,8 +23,10 @@ def run(party): # Test that a prefix key name is added under the hood. assert ray_internal_kv._internal_kv_get(b"test_key") is None - assert ray_internal_kv._internal_kv_get( - b"RAYFED#test_job_name#test_key") == b"test_val" + assert ( + ray_internal_kv._internal_kv_get(b"RAYFED#test_job_name#test_key") + == b"test_val" + ) time.sleep(5) fed.shutdown() diff --git a/fed/tests/test_listening_address.py b/fed/tests/test_listening_address.py index a2787226..5f9891ca 100644 --- a/fed/tests/test_listening_address.py +++ b/fed/tests/test_listening_address.py @@ -16,6 +16,7 @@ import pytest import ray + import fed import fed._private.compatible_utils as compatible_utils @@ -36,9 +37,7 @@ def _run(party): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.bind(("127.0.0.1", occupied_port)) - addresses = { - 'alice': f'127.0.0.1:{occupied_port}' - } + addresses = {'alice': f'127.0.0.1:{occupied_port}'} # Starting grpc server on an used port will cause AssertionError with pytest.raises(AssertionError): diff --git a/fed/tests/test_options.py b/fed/tests/test_options.py index 13bbb6f3..6ab9107b 100644 --- a/fed/tests/test_options.py +++ b/fed/tests/test_options.py @@ -16,6 +16,7 @@ import pytest import ray + import fed import fed._private.compatible_utils as compatible_utils diff --git a/fed/tests/test_pass_fed_objects_in_containers_in_actor.py b/fed/tests/test_pass_fed_objects_in_containers_in_actor.py index 08d3a2d6..4d9e4c82 100644 --- a/fed/tests/test_pass_fed_objects_in_containers_in_actor.py +++ b/fed/tests/test_pass_fed_objects_in_containers_in_actor.py @@ -15,8 +15,8 @@ import multiprocessing import pytest - import ray + import fed import fed._private.compatible_utils as compatible_utils diff --git a/fed/tests/test_pass_fed_objects_in_containers_in_normal_tasks.py b/fed/tests/test_pass_fed_objects_in_containers_in_normal_tasks.py index 7b78cb60..866370d2 100644 --- a/fed/tests/test_pass_fed_objects_in_containers_in_normal_tasks.py +++ b/fed/tests/test_pass_fed_objects_in_containers_in_normal_tasks.py @@ -15,8 +15,8 @@ import multiprocessing import pytest - import ray + import fed import fed._private.compatible_utils as compatible_utils diff --git a/fed/tests/test_ping_others.py b/fed/tests/test_ping_others.py index 4753ddec..a734104e 100644 --- a/fed/tests/test_ping_others.py +++ b/fed/tests/test_ping_others.py @@ -12,15 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest import multiprocessing +import time + +import pytest +import ray + import fed import fed._private.compatible_utils as compatible_utils -import ray -import time from fed.proxy.barriers import ping_others - addresses = { 'alice': '127.0.0.1:11012', 'bob': '127.0.0.1:11011', @@ -31,7 +32,7 @@ def test_ping_non_started_party(): def run(party): compatible_utils.init_ray(address='local') fed.init(addresses=addresses, party=party) - if (party == 'alice'): + if party == 'alice': with pytest.raises(RuntimeError): ping_others(addresses, party, 5) @@ -47,7 +48,7 @@ def test_ping_started_party(): def run(party): compatible_utils.init_ray(address='local') fed.init(addresses=addresses, party=party) - if (party == 'alice'): + if party == 'alice': ping_success = ping_others(addresses, party, 5) assert ping_success is True else: diff --git a/fed/tests/test_repeat_init.py b/fed/tests/test_repeat_init.py index 8926c786..1e68622d 100644 --- a/fed/tests/test_repeat_init.py +++ b/fed/tests/test_repeat_init.py @@ -16,9 +16,10 @@ import multiprocessing import pytest +import ray + import fed import fed._private.compatible_utils as compatible_utils -import ray @fed.remote diff --git a/fed/tests/test_reset_context.py b/fed/tests/test_reset_context.py index 95c6e53f..6350851e 100644 --- a/fed/tests/test_reset_context.py +++ b/fed/tests/test_reset_context.py @@ -1,8 +1,10 @@ import multiprocessing -import fed + +import pytest import ray + +import fed import fed._private.compatible_utils as compatible_utils -import pytest addresses = { 'alice': '127.0.0.1:11012', @@ -21,9 +23,7 @@ def get(self): def run(party): compatible_utils.init_ray(address='local') - fed.init( - addresses=addresses, - party=party) + fed.init(addresses=addresses, party=party) actor = A.party('alice').remote(10) alice_fed_obj = actor.get.remote() @@ -45,9 +45,7 @@ def run(party): compatible_utils.kv.put("key2", "val2") compatible_utils.init_ray(address='local') - fed.init( - addresses=addresses, - party=party) + fed.init(addresses=addresses, party=party) actor = A.party('alice').remote(10) alice_fed_obj = actor.get.remote() @@ -70,8 +68,8 @@ def run(party): def test_reset_context(): - p_alice = multiprocessing.Process(target=run, args=('alice', )) - p_bob = multiprocessing.Process(target=run, args=('bob', )) + p_alice = multiprocessing.Process(target=run, args=('alice',)) + p_bob = multiprocessing.Process(target=run, args=('bob',)) p_alice.start() import time @@ -85,4 +83,5 @@ def test_reset_context(): if __name__ == "__main__": import sys + sys.exit(pytest.main(["-sv", __file__])) diff --git a/fed/tests/test_retry_policy.py b/fed/tests/test_retry_policy.py index 574ce9fb..093d7631 100644 --- a/fed/tests/test_retry_policy.py +++ b/fed/tests/test_retry_policy.py @@ -15,6 +15,7 @@ import multiprocessing from unittest import TestCase + import pytest import ray diff --git a/fed/tests/test_transport_proxy.py b/fed/tests/test_transport_proxy.py index bb6f3f2d..de37e865 100644 --- a/fed/tests/test_transport_proxy.py +++ b/fed/tests/test_transport_proxy.py @@ -31,7 +31,8 @@ from fed.proxy.grpc.grpc_proxy import GrpcReceiverProxy, GrpcSenderProxy if compatible_utils._compare_version_strings( - fed_utils.get_package_version('protobuf'), '4.0.0'): + fed_utils.get_package_version('protobuf'), '4.0.0' +): from fed.grpc.pb4 import fed_pb2 as fed_pb2 from fed.grpc.pb4 import fed_pb2_grpc as fed_pb2_grpc else: @@ -99,8 +100,9 @@ def test_n_to_1_transport(): class TestSendDataService(fed_pb2_grpc.GrpcServiceServicer): - def __init__(self, all_events, all_data, party, lock, - expected_metadata, expected_jobname): + def __init__( + self, all_events, all_data, party, lock, expected_metadata, expected_jobname + ): self.expected_metadata = expected_metadata or {} self._expected_jobname = expected_jobname or "" @@ -109,8 +111,9 @@ async def SendData(self, request, context): assert self._expected_jobname == job_name metadata = dict(context.invocation_metadata()) for k, v in self.expected_metadata.items(): - assert k in metadata, \ - f"The expected key {k} is not in the metadata keys: {metadata.keys()}." + assert ( + k in metadata + ), f"The expected key {k} is not in the metadata keys: {metadata.keys()}." assert v == metadata[k] event = asyncio.Event() event.set() @@ -129,9 +132,10 @@ async def _test_run_grpc_server( ): server = grpc.aio.server(options=grpc_options) fed_pb2_grpc.add_GrpcServiceServicer_to_server( - TestSendDataService(event, all_data, party, lock, - expected_metadata, expected_jobname), - server + TestSendDataService( + event, all_data, party, lock, expected_metadata, expected_jobname + ), + server, ) server.add_insecure_port(f'[::]:{port}') await server.start() @@ -154,13 +158,13 @@ def __init__( async def run_grpc_server(self): return await _test_run_grpc_server( - self._listen_addr[self._listen_addr.index(':') + 1:], + self._listen_addr[self._listen_addr.index(':') + 1 :], None, None, self._party, None, expected_metadata=self._expected_metadata, - expected_jobname=self._expected_jobname + expected_jobname=self._expected_jobname, ) async def is_ready(self): @@ -178,9 +182,12 @@ def _test_start_receiver_proxy( address = addresses[party] receiver_proxy_actor = TestReceiverProxyActor.options( name=receiver_proxy_actor_name(), max_concurrency=1000 - ).remote(listen_addr=address, party=party, - expected_metadata=expected_metadata, - expected_jobname=expected_jobname) + ).remote( + listen_addr=address, + party=party, + expected_metadata=expected_metadata, + expected_jobname=expected_jobname, + ) receiver_proxy_actor.run_grpc_server.remote() assert ray.get(receiver_proxy_actor.is_ready.remote()) @@ -214,7 +221,7 @@ def test_send_grpc_with_meta(): addresses, party_name, expected_metadata=metadata, - expected_jobname=test_job_name + expected_jobname=test_job_name, ) _start_sender_proxy( addresses, diff --git a/fed/tests/test_utils.py b/fed/tests/test_utils.py index f17f1a6f..d42bde8b 100644 --- a/fed/tests/test_utils.py +++ b/fed/tests/test_utils.py @@ -13,15 +13,16 @@ # limitations under the License. import time + import pytest import fed.utils as fed_utils def start_ray_cluster( - ray_port, - client_server_port, - dashboard_port, + ray_port, + client_server_port, + dashboard_port, ): command = [ 'ray', @@ -45,8 +46,9 @@ def start_ray_cluster( # container, you can increase /dev/shm size by passing '--shm-size=1.97gb' to # 'docker run' (or add it to the run_options list in a Ray cluster config). # Make sure to set this to more than 0% of available RAM. - assert 'Overwriting previous Ray address' in str(e) \ - or 'WARNING: The object store is using /tmp instead of /dev/shm' in str(e) + assert 'Overwriting previous Ray address' in str( + e + ) or 'WARNING: The object store is using /tmp instead of /dev/shm' in str(e) @pytest.fixture diff --git a/fed/tests/without_ray_tests/test_tree_utils.py b/fed/tests/without_ray_tests/test_tree_utils.py index 41729ad0..416b5d3a 100644 --- a/fed/tests/without_ray_tests/test_tree_utils.py +++ b/fed/tests/without_ray_tests/test_tree_utils.py @@ -1,4 +1,3 @@ - # Copyright 2023 The RayFed Team # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, List, Tuple, Union + import pytest -from typing import Any, Union, List, Tuple, Dict import fed.tree_util as tree_utils @@ -28,7 +28,6 @@ def test_flatten_none(): def test_flatten_single_primivite_elements(): - def _assert_flatten_single_element(target: Any): li, tree_def = tree_utils.tree_flatten(target) assert isinstance(li, list) diff --git a/fed/tests/without_ray_tests/test_utils.py b/fed/tests/without_ray_tests/test_utils.py index b00e042e..98233250 100644 --- a/fed/tests/without_ray_tests/test_utils.py +++ b/fed/tests/without_ray_tests/test_utils.py @@ -1,4 +1,3 @@ - # Copyright 2023 The RayFed Team # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,18 +17,21 @@ import fed -@pytest.mark.parametrize("input_address, is_valid_address", [ - ("192.168.0.1:8080", True), - ("sa127032as:80", True), - ("https://www.example.com", True), - ("http://www.example.com", True), - ("local", True), - ("localhost", True), - (None, False), - ("invalid_string", False), - ("http", False), - ("example.com", False), -]) +@pytest.mark.parametrize( + "input_address, is_valid_address", + [ + ("192.168.0.1:8080", True), + ("sa127032as:80", True), + ("https://www.example.com", True), + ("http://www.example.com", True), + ("local", True), + ("localhost", True), + (None, False), + ("invalid_string", False), + ("http", False), + ("example.com", False), + ], +) def test_validate_address(input_address, is_valid_address): if is_valid_address: fed.utils.validate_address(input_address) @@ -43,4 +45,5 @@ def test_validate_address(input_address, is_valid_address): if __name__ == "__main__": import sys + sys.exit(pytest.main(["-sv", __file__])) diff --git a/fed/tree_util.py b/fed/tree_util.py index 14a99fc3..033960d2 100644 --- a/fed/tree_util.py +++ b/fed/tree_util.py @@ -14,10 +14,9 @@ # Most codes are copied from https://github.com/pytorch/pytorch/blob/c263bd43e8e8502d4726643bc6fd046f0130ac0e/torch/utils/_pytree.py # noqa -from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, cast, TypeVar -from collections import namedtuple, OrderedDict +from collections import OrderedDict, namedtuple from dataclasses import dataclass - +from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Type, TypeVar, cast T = TypeVar('T') S = TypeVar('S') @@ -63,9 +62,8 @@ class NodeDef(NamedTuple): def _register_pytree_node( - typ: Any, - flatten_fn: FlattenFunc, - unflatten_fn: UnflattenFunc) -> None: + typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc +) -> None: SUPPORTED_NODES[typ] = NodeDef(flatten_fn, unflatten_fn) @@ -161,8 +159,11 @@ def __repr__(self, indent: int = 0) -> str: children_specs_str += self.children_specs[0].__repr__(indent) children_specs_str += ',' if len(self.children_specs) > 1 else '' children_specs_str += ','.join( - ['\n' + ' ' * indent + child.__repr__(indent) - for child in self.children_specs[1:]]) + [ + '\n' + ' ' * indent + child.__repr__(indent) + for child in self.children_specs[1:] + ] + ) repr_suffix: str = f'{children_specs_str}])' return repr_prefix + repr_suffix @@ -188,8 +189,8 @@ def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: child_pytrees, context = flatten_fn(pytree) # Recursively flatten the children - result : List[Any] = [] - children_specs : List['TreeSpec'] = [] + result: List[Any] = [] + children_specs: List['TreeSpec'] = [] for child in child_pytrees: flat, child_spec = tree_flatten(child) result += flat @@ -205,12 +206,14 @@ def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree: if not isinstance(spec, TreeSpec): raise ValueError( f'tree_unflatten(values, spec): Expected `spec` to be instance of ' - f'TreeSpec but got item of type {type(spec)}.') + f'TreeSpec but got item of type {type(spec)}.' + ) if len(values) != spec.num_leaves: raise ValueError( f'tree_unflatten(values, spec): `values` has length {len(values)} ' f'but the spec refers to a pytree that holds {spec.num_leaves} ' - f'items ({spec}).') + f'items ({spec}).' + ) if isinstance(spec, LeafSpec): return values[0] diff --git a/fed/utils.py b/fed/utils.py index b5450f26..59b6fe79 100644 --- a/fed/utils.py +++ b/fed/utils.py @@ -14,8 +14,8 @@ import logging import re -import sys import subprocess +import sys import ray @@ -239,17 +239,16 @@ def validate_addresses(addresses: dict): validate_address(address) -def start_command(command: str, timeout=60) : +def start_command(command: str, timeout=60): """ A util to start a shell command. """ process = subprocess.Popen( - command, - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) output, error = process.communicate(timeout=timeout) if len(error) != 0: raise RuntimeError( - f'Failed to start command [{command}], the error is:\n {error.decode()}') + f'Failed to start command [{command}], the error is:\n {error.decode()}' + ) return output diff --git a/setup.py b/setup.py index 468ff6ad..1305b97d 100644 --- a/setup.py +++ b/setup.py @@ -66,8 +66,10 @@ def run(self): name=package_name, version=VERSION, license='Apache 2.0', - description='A multiple parties joint, distributed execution engine based on Ray,' - 'to help build your own federated learning frameworks in minutes.', + description=( + 'A multiple parties joint, distributed execution engine based on Ray,' + 'to help build your own federated learning frameworks in minutes.' + ), long_description=long_description, long_description_content_type='text/markdown', author='RayFed Team', diff --git a/tool/generate_tls_certs.py b/tool/generate_tls_certs.py index 7b9a7897..a0372007 100644 --- a/tool/generate_tls_certs.py +++ b/tool/generate_tls_certs.py @@ -13,9 +13,9 @@ # limitations under the License. import datetime +import errno import os import socket -import errno def try_make_directory_shared(directory_path):