diff --git a/fed/_private/constants.py b/fed/_private/constants.py index 5af4c88..f21f3d0 100644 --- a/fed/_private/constants.py +++ b/fed/_private/constants.py @@ -36,3 +36,9 @@ RAYFED_DEFAULT_JOB_NAME = "Anonymous_job" RAYFED_JOB_KV_DATA_KEY_FMT = "RAYFED#{}#{}" + +RAYFED_DEFAULT_SENDER_PROXY_ACTOR_NAME = "SenderProxyActor" + +RAYFED_DEFAULT_RECEIVER_PROXY_ACTOR_NAME = "ReceiverProxyActor" + +RAYFED_DEFAULT_SENDER_RECEIVER_PROXY_ACTOR_NAME = "SenderReceiverProxyActor" diff --git a/fed/api.py b/fed/api.py index 6b020eb..692bec6 100644 --- a/fed/api.py +++ b/fed/api.py @@ -40,8 +40,7 @@ _start_receiver_proxy, _start_sender_proxy, _start_sender_receiver_proxy, - set_receiver_proxy_actor_name, - set_sender_proxy_actor_name, + set_proxy_actor_name, ) from fed.proxy.base_proxy import SenderProxy, ReceiverProxy, SenderReceiverProxy from fed.config import CrossSiloMessageConfig @@ -116,6 +115,7 @@ def init( "timeout_in_ms": 1000, "exit_on_sending_failure": True, "expose_error_trace": True, + "use_global_proxy": True, }, "barrier_on_initializing": True, } @@ -170,7 +170,6 @@ def init( 'cert' in tls_config and 'key' in tls_config ), 'Cert or key are not in tls_config.' - cross_silo_comm_dict = config.get("cross_silo_comm", {}) # A Ray private accessing, should be replaced in public API. compatible_utils._init_internal_kv(job_name) @@ -180,6 +179,7 @@ def init( constants.KEY_OF_TLS_CONFIG: tls_config, } + cross_silo_comm_dict = config.get("cross_silo_comm", {}) job_config = { constants.KEY_OF_CROSS_SILO_COMM_CONFIG_DICT: cross_silo_comm_dict, } @@ -206,10 +206,10 @@ def init( exit_on_sending_failure=cross_silo_comm_config.exit_on_sending_failure, expose_error_trace=cross_silo_comm_config.expose_error_trace ) + if receiver_sender_proxy_cls is not None: - proxy_actor_name = 'sender_recevier_actor' - set_sender_proxy_actor_name(proxy_actor_name) - set_receiver_proxy_actor_name(proxy_actor_name) + set_proxy_actor_name( + job_name, cross_silo_comm_dict.get("use_global_proxy", True), True) _start_sender_receiver_proxy( addresses=addresses, party=party, @@ -230,6 +230,8 @@ def init( from fed.proxy.grpc.grpc_proxy import GrpcReceiverProxy receiver_proxy_cls = GrpcReceiverProxy + set_proxy_actor_name( + job_name, cross_silo_comm_dict.get("use_global_proxy", True)) _start_receiver_proxy( addresses=addresses, party=party, @@ -242,12 +244,13 @@ def init( if sender_proxy_cls is None: logger.debug( - "No sender proxy class specified, use `GrpcRecvProxy` by " + "No sender proxy class specified, use `GrpcSenderProxy` by " "default." ) from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy sender_proxy_cls = GrpcSenderProxy + _start_sender_proxy( addresses=addresses, party=party, diff --git a/fed/cleanup.py b/fed/cleanup.py index 811f21a..d68b290 100644 --- a/fed/cleanup.py +++ b/fed/cleanup.py @@ -34,11 +34,11 @@ class CleanupManager: The main logic path is: A. If `fed.shutdown()` is invoked in the main thread and every thing works well, the `stop()` will be invoked as well and the checking thread will be - notifiled to exit gracefully. + notified to exit gracefully. - B. If the main thread are broken before sending the notification flag to the - sending thread, the monitor thread will detect that and it joins until the main - thread exited, then notifys the checking thread. + B. If the main thread are broken before sending the stop flag to the sending + thread, the monitor thread will detect that and then notifys the checking + thread. """ def __init__(self, current_party, acquire_shutdown_flag) -> None: diff --git a/fed/config.py b/fed/config.py index 386984c..c447efb 100644 --- a/fed/config.py +++ b/fed/config.py @@ -102,6 +102,8 @@ class CrossSiloMessageConfig: This won't override basic tcp headers, such as `user-agent`, but concat them together. max_concurrency: the max_concurrency of the sender/receiver proxy actor. + use_global_proxy: Whether using the global proxy actor or create new proxy + actor for current job. """ proxy_max_restarts: int = None @@ -114,6 +116,7 @@ class CrossSiloMessageConfig: http_header: Optional[Dict[str, str]] = None max_concurrency: Optional[int] = None expose_error_trace: Optional[bool] = False + use_global_proxy: Optional[bool] = True def __json__(self): return json.dumps(self.__dict__) diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index d6101c6..aef0812 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -51,6 +51,39 @@ def set_receiver_proxy_actor_name(name: str): _RECEIVER_PROXY_ACTOR_NAME = name +def set_proxy_actor_name(job_name: str, + use_global_proxy: bool, + sender_recvr_proxy: bool = False): + """ + Generate the name of the proxy actor. + + Args: + job_name: The name of the job, used for actor name's postfix + use_global_proxy: Whether + to use a single proxy actor or not. If True, the name of the proxy + actor will be the default global name, otherwise the name will be + added with a postfix. + sender_recvr_proxy: Whether to use the sender-receiver proxy actor or + not. If True, since there's only one proxy actor, make two actor name + the same. + """ + sender_actor_name = ( + constants.RAYFED_DEFAULT_SENDER_PROXY_ACTOR_NAME + if not sender_recvr_proxy + else constants.RAYFED_DEFAULT_SENDER_RECEIVER_PROXY_ACTOR_NAME + ) + receiver_actor_name = ( + constants.RAYFED_DEFAULT_RECEIVER_PROXY_ACTOR_NAME + if not sender_recvr_proxy + else constants.RAYFED_DEFAULT_SENDER_RECEIVER_PROXY_ACTOR_NAME + ) + if not use_global_proxy: + sender_actor_name = f"{sender_actor_name}_{job_name}" + receiver_actor_name = f"{receiver_actor_name}_{job_name}" + set_sender_proxy_actor_name(sender_actor_name) + set_receiver_proxy_actor_name(receiver_actor_name) + + def key_exists_in_two_dim_dict(the_dict, key_a, key_b) -> bool: key_a, key_b = str(key_a), str(key_b) if key_a not in the_dict: @@ -217,22 +250,20 @@ def _start_receiver_proxy( ready_timeout_second: int = 60, ): actor_options = copy.deepcopy(_DEFAULT_RECEIVER_PROXY_OPTIONS) - if proxy_config: - proxy_config = fed_config.CrossSiloMessageConfig.from_dict(proxy_config) - if proxy_config.recv_resource_label is not None: - actor_options.update({"resources": proxy_config.recv_resource_label}) - if proxy_config.max_concurrency: - actor_options.update({"max_concurrency": proxy_config.max_concurrency}) + proxy_config = fed_config.CrossSiloMessageConfig.from_dict(proxy_config) + if proxy_config.recv_resource_label is not None: + actor_options.update({"resources": proxy_config.recv_resource_label}) + if proxy_config.max_concurrency: + actor_options.update({"max_concurrency": proxy_config.max_concurrency}) + actor_options.update({"name": receiver_proxy_actor_name()}) logger.debug(f"Starting ReceiverProxyActor with options: {actor_options}") + job_name = get_global_context().get_job_name() - global _RECEIVER_PROXY_ACTOR_NAME - receiver_proxy_actor = ReceiverProxyActor.options( - name=_RECEIVER_PROXY_ACTOR_NAME, **actor_options - ).remote( + receiver_proxy_actor = ReceiverProxyActor.options(**actor_options).remote( listening_address=addresses[party], party=party, - job_name=get_global_context().get_job_name(), + job_name=job_name, tls_config=tls_config, logging_level=logging_level, proxy_cls=proxy_cls, @@ -260,30 +291,28 @@ def _start_sender_proxy( proxy_config: Dict = None, ready_timeout_second: int = 60, ): - if proxy_config: - proxy_config = fed_config.GrpcCrossSiloMessageConfig.from_dict(proxy_config) actor_options = copy.deepcopy(_DEFAULT_SENDER_PROXY_OPTIONS) - if proxy_config: - if proxy_config.proxy_max_restarts: - actor_options.update( - { - "max_task_retries": proxy_config.proxy_max_restarts, - "max_restarts": 1, - } - ) - if proxy_config.send_resource_label: - actor_options.update({"resources": proxy_config.send_resource_label}) - if proxy_config.max_concurrency: - actor_options.update({"max_concurrency": proxy_config.max_concurrency}) + proxy_config = fed_config.GrpcCrossSiloMessageConfig.from_dict(proxy_config) + if proxy_config.proxy_max_restarts: + actor_options.update( + { + "max_task_retries": proxy_config.proxy_max_restarts, + "max_restarts": 1, + } + ) + if proxy_config.send_resource_label: + actor_options.update({"resources": proxy_config.send_resource_label}) + if proxy_config.max_concurrency: + actor_options.update({"max_concurrency": proxy_config.max_concurrency}) + + job_name = get_global_context().get_job_name() + actor_options.update({"name": sender_proxy_actor_name()}) logger.debug(f"Starting SenderProxyActor with options: {actor_options}") global _SENDER_PROXY_ACTOR - global _SENDER_PROXY_ACTOR_NAME - _SENDER_PROXY_ACTOR = SenderProxyActor.options( - name=_SENDER_PROXY_ACTOR_NAME, **actor_options - ) - job_name = get_global_context().get_job_name() + _SENDER_PROXY_ACTOR = SenderProxyActor.options(**actor_options) + _SENDER_PROXY_ACTOR = _SENDER_PROXY_ACTOR.remote( addresses=addresses, party=party, @@ -389,32 +418,32 @@ def _start_sender_receiver_proxy( ): global _DEFAULT_SENDER_RECEIVER_PROXY_OPTIONS actor_options = copy.deepcopy(_DEFAULT_SENDER_RECEIVER_PROXY_OPTIONS) - if proxy_config: - proxy_config = fed_config.CrossSiloMessageConfig.from_dict(proxy_config) - if proxy_config.proxy_max_restarts: - actor_options.update( - { - "max_task_retries": proxy_config.proxy_max_restarts, - "max_restarts": 1, - } - ) - if proxy_config.max_concurrency: - actor_options.update({"max_concurrency": proxy_config.max_concurrency}) + proxy_config = fed_config.CrossSiloMessageConfig.from_dict(proxy_config) + if proxy_config.proxy_max_restarts: + actor_options.update( + { + "max_task_retries": proxy_config.proxy_max_restarts, + "max_restarts": 1, + } + ) + if proxy_config.max_concurrency: + actor_options.update({"max_concurrency": proxy_config.max_concurrency}) + # NOTE(NKcqx): sender & receiver have the same name + actor_options.update({"name": receiver_proxy_actor_name()}) logger.debug(f"Starting ReceiverProxyActor with options: {actor_options}") job_name = get_global_context().get_job_name() global _SENDER_RECEIVER_PROXY_ACTOR - global _RECEIVER_PROXY_ACTOR_NAME + _SENDER_RECEIVER_PROXY_ACTOR = SenderReceiverProxyActor.options( - name=_RECEIVER_PROXY_ACTOR_NAME, **actor_options - ).remote( - addresses=addresses, - party=party, - job_name=job_name, - tls_config=tls_config, - logging_level=logging_level, - proxy_cls=proxy_cls, + **actor_options).remote( + addresses=addresses, + party=party, + job_name=job_name, + tls_config=tls_config, + logging_level=logging_level, + proxy_cls=proxy_cls, ) _SENDER_RECEIVER_PROXY_ACTOR.start.remote() server_state = ray.get( @@ -436,8 +465,7 @@ def send( is_error: Whether the `data` is an error object or not. Default is False. If True, the data will be sent to the error message queue. """ - global _SENDER_PROXY_ACTOR_NAME - sender_proxy = ray.get_actor(_SENDER_PROXY_ACTOR_NAME) + sender_proxy = ray.get_actor(sender_proxy_actor_name()) res = sender_proxy.send.remote( dest_party=dest_party, data=data, @@ -451,8 +479,7 @@ def send( def recv(party: str, src_party: str, upstream_seq_id, curr_seq_id): assert party, 'Party can not be None.' - global _RECEIVER_PROXY_ACTOR_NAME - receiver_proxy = ray.get_actor(_RECEIVER_PROXY_ACTOR_NAME) + receiver_proxy = ray.get_actor(receiver_proxy_actor_name()) return receiver_proxy.get_data.remote(src_party, upstream_seq_id, curr_seq_id) diff --git a/fed/proxy/grpc/grpc_proxy.py b/fed/proxy/grpc/grpc_proxy.py index 85b7256..1fdfa4e 100644 --- a/fed/proxy/grpc/grpc_proxy.py +++ b/fed/proxy/grpc/grpc_proxy.py @@ -343,7 +343,7 @@ async def _run_grpc_server( port, event, all_data, party, lock, job_name, server_ready_future, tls_config=None, grpc_options=None ): - logger.info(f"ReceiveProxy binding port {port}, options: {grpc_options}...") + logger.info(f"ReceiverProxy binding port {port}, options: {grpc_options}...") server = grpc.aio.server(options=grpc_options) fed_pb2_grpc.add_GrpcServiceServicer_to_server( SendDataService(event, all_data, party, lock, job_name), server diff --git a/fed/tests/multi-jobs/test_multi_proxy_actor.py b/fed/tests/multi-jobs/test_multi_proxy_actor.py new file mode 100644 index 0000000..5021ec0 --- /dev/null +++ b/fed/tests/multi-jobs/test_multi_proxy_actor.py @@ -0,0 +1,122 @@ +# Copyright 2023 The RayFed Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing +import fed +import ray +import grpc +import pytest +import fed.utils as fed_utils +import fed._private.compatible_utils as compatible_utils +from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy, send_data_grpc +if compatible_utils._compare_version_strings( + fed_utils.get_package_version('protobuf'), '4.0.0'): + from fed.grpc.pb4 import fed_pb2_grpc as fed_pb2_grpc +else: + from fed.grpc.pb3 import fed_pb2_grpc as fed_pb2_grpc + + +class TestGrpcSenderProxy(GrpcSenderProxy): + async def send( + self, + dest_party, + data, + upstream_seq_id, + downstream_seq_id): + dest_addr = self._addresses[dest_party] + grpc_metadata, grpc_channel_options = self.get_grpc_config_by_party(dest_party) + if dest_party not in self._stubs: + channel = grpc.aio.insecure_channel( + dest_addr, options=grpc_channel_options) + stub = fed_pb2_grpc.GrpcServiceStub(channel) + self._stubs[dest_party] = stub + + timeout = self._proxy_config.timeout_in_ms / 1000 + response: str = await send_data_grpc( + data=data, + stub=self._stubs[dest_party], + upstream_seq_id=upstream_seq_id, + downstream_seq_id=downstream_seq_id, + job_name=self._job_name, + timeout=timeout, + metadata=grpc_metadata, + ) + assert response.code == 417 + assert "JobName mis-match" in response.result + # So that process can exit + raise RuntimeError(response.result) + + +@fed.remote +class MyActor: + def __init__(self, party, data): + self.__data = data + self._party = party + + def f(self): + return f"f({self._party}, ip is {ray.util.get_node_ip_address()})" + + +@fed.remote +def agg_fn(obj1, obj2): + return f"agg-{obj1}-{obj2}" + + +addresses = { + 'job1': { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', + }, + 'job2': { + 'alice': '127.0.0.1:12012', + 'bob': '127.0.0.1:12011', + }, +} + + +def run(party, job_name): + ray.init(address='local') + fed.init(addresses=addresses[job_name], + party=party, + job_name=job_name, + sender_proxy_cls=TestGrpcSenderProxy, + config={ + 'cross_silo_comm': { + 'exit_on_sending_failure': True, + # Create unique proxy for current job + 'use_global_proxy': False + }}) + + sender_proxy_actor_name = f"SenderProxyActor_{job_name}" + receiver_proxy_actor_name = f"ReceiverProxyActor_{job_name}" + assert ray.get_actor(sender_proxy_actor_name) + assert ray.get_actor(receiver_proxy_actor_name) + + fed.shutdown() + ray.shutdown() + + +def test_multi_proxy_actor(): + p_alice_job1 = multiprocessing.Process(target=run, args=('alice', 'job1')) + p_alice_job2 = multiprocessing.Process(target=run, args=('alice', 'job2')) + p_alice_job1.start() + p_alice_job2.start() + p_alice_job1.join() + p_alice_job2.join() + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-sv", __file__])) diff --git a/fed/tests/test_ping_others.py b/fed/tests/test_ping_others.py index 0782bb7..4753dde 100644 --- a/fed/tests/test_ping_others.py +++ b/fed/tests/test_ping_others.py @@ -17,6 +17,7 @@ import fed import fed._private.compatible_utils as compatible_utils import ray +import time from fed.proxy.barriers import ping_others @@ -49,7 +50,10 @@ def run(party): if (party == 'alice'): ping_success = ping_others(addresses, party, 5) assert ping_success is True - + else: + # Wait for alice to ping, otherwise, bob may + # exit before alice when started first. + time.sleep(10) fed.shutdown() ray.shutdown()