diff --git a/fed/_private/constants.py b/fed/_private/constants.py index 09af728..fc24708 100644 --- a/fed/_private/constants.py +++ b/fed/_private/constants.py @@ -25,7 +25,7 @@ KEY_OF_TLS_CONFIG = "TLS_CONFIG" -KEY_OF_CROSS_SILO_MSG_CONFIG = "CROSS_SILO_MSG_CONFIG" +KEY_OF_CROSS_SILO_MESSAGE_CONFIG = "CROSS_SILO_MESSAGE_CONFIG" RAYFED_LOG_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- %(message)s" # noqa diff --git a/fed/_private/serialization_utils.py b/fed/_private/serialization_utils.py index 07182cf..646a3c3 100644 --- a/fed/_private/serialization_utils.py +++ b/fed/_private/serialization_utils.py @@ -64,7 +64,7 @@ def _apply_loads_function_with_whitelist(): global _pickle_whitelist _pickle_whitelist = fed_config.get_job_config() \ - .cross_silo_msg_config.serializing_allowed_list + .cross_silo_message_config.serializing_allowed_list if _pickle_whitelist is None: return diff --git a/fed/api.py b/fed/api.py index 22dff33..209cb49 100644 --- a/fed/api.py +++ b/fed/api.py @@ -31,11 +31,11 @@ ping_others, recv, send, - start_recv_proxy, - start_send_proxy, + _start_receiver_proxy, + _start_sender_proxy, ) -from fed.proxy.grpc.grpc_proxy import SendProxy, RecvProxy -from fed.config import CrossSiloMsgConfig +from fed.proxy.grpc.grpc_proxy import SenderProxy, ReceiverProxy +from fed.config import CrossSiloMessageConfig from fed.fed_object import FedObject from fed.utils import is_ray_object_refs, setup_logger @@ -48,9 +48,9 @@ def init( tls_config: Dict = None, logging_level: str = 'info', enable_waiting_for_other_parties_ready: bool = False, - send_proxy_cls: SendProxy = None, - recv_proxy_cls: RecvProxy = None, - global_cross_silo_msg_config: Optional[CrossSiloMsgConfig] = None, + sender_proxy_cls: SenderProxy = None, + receiver_proxy_cls: ReceiverProxy = None, + global_cross_silo_message_config: Optional[CrossSiloMessageConfig] = None, **kwargs, ): """ @@ -67,7 +67,7 @@ def init( # (Optional) the listen address, the `address` will be # used if not provided. 'listen_addr': '0.0.0.0:10001', - 'cross_silo_msg_config': CrossSiloMsgConfig + 'cross_silo_message_config': CrossSiloMessageConfig }, 'bob': { # The address for other parties. @@ -111,9 +111,9 @@ def init( `warning`, `error`, `critical`, not case sensititive. enable_waiting_for_other_parties_ready: ping other parties until they are all ready if True. - global_cross_silo_msg_config: Global cross-silo message related + global_cross_silo_message_config: Global cross-silo message related configs that are applied to all connections. Supported configs - can refer to CrossSiloMsgConfig in config.py. + can refer to CrossSiloMessageConfig in config.py. Examples: >>> import fed @@ -139,8 +139,8 @@ def init( 'cert' in tls_config and 'key' in tls_config ), 'Cert or key are not in tls_config.' - global_cross_silo_msg_config = \ - global_cross_silo_msg_config or CrossSiloMsgConfig() + global_cross_silo_message_config = \ + global_cross_silo_message_config or CrossSiloMessageConfig() # A Ray private accessing, should be replaced in public API. compatible_utils._init_internal_kv() @@ -151,8 +151,8 @@ def init( } job_config = { - constants.KEY_OF_CROSS_SILO_MSG_CONFIG: - global_cross_silo_msg_config, + constants.KEY_OF_CROSS_SILO_MESSAGE_CONFIG: + global_cross_silo_message_config, } compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config)) @@ -170,35 +170,36 @@ def init( logger.info(f'Started rayfed with {cluster_config}') get_global_context().get_cleanup_manager().start( - exit_when_failure_sending=global_cross_silo_msg_config.exit_on_sending_failure) + exit_when_failure_sending=global_cross_silo_message_config.exit_on_sending_failure) # noqa - if recv_proxy_cls is None: + if receiver_proxy_cls is None: logger.debug( - "Not declaring recver proxy class, using `GrpcRecvProxy` as default.") - from fed.proxy.grpc.grpc_proxy import GrpcRecvProxy - recv_proxy_cls = GrpcRecvProxy - # Start recv proxy - start_recv_proxy( + "There is no receiver proxy class specified, it uses `GrpcRecvProxy` by " + "default.") + from fed.proxy.grpc.grpc_proxy import GrpcReceiverProxy + receiver_proxy_cls = GrpcReceiverProxy + _start_receiver_proxy( cluster=cluster, party=party, logging_level=logging_level, tls_config=tls_config, - proxy_cls=recv_proxy_cls, - proxy_config=global_cross_silo_msg_config + proxy_cls=receiver_proxy_cls, + proxy_config=global_cross_silo_message_config ) - if send_proxy_cls is None: + if sender_proxy_cls is None: logger.debug( - "Not declaring send proxy class, using `GrpcSendProxy` as default.") - from fed.proxy.grpc.grpc_proxy import GrpcSendProxy - send_proxy_cls = GrpcSendProxy - start_send_proxy( + "There is no sender proxy class specified, it uses `GrpcRecvProxy` by " + "default.") + from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy + sender_proxy_cls = GrpcSenderProxy + _start_sender_proxy( cluster=cluster, party=party, logging_level=logging_level, tls_config=tls_config, - proxy_cls=send_proxy_cls, - proxy_config=global_cross_silo_msg_config + proxy_cls=sender_proxy_cls, + proxy_config=global_cross_silo_message_config ) if enable_waiting_for_other_parties_ready: diff --git a/fed/config.py b/fed/config.py index 84af784..15130bb 100644 --- a/fed/config.py +++ b/fed/config.py @@ -39,10 +39,10 @@ def __init__(self, raw_bytes: bytes) -> None: self._data = cloudpickle.loads(raw_bytes) @property - def cross_silo_msg_config(self): + def cross_silo_message_config(self): return self._data.get( - fed_constants.KEY_OF_CROSS_SILO_MSG_CONFIG, - CrossSiloMsgConfig()) + fed_constants.KEY_OF_CROSS_SILO_MESSAGE_CONFIG, + CrossSiloMessageConfig()) # A module level cache for the cluster configurations. @@ -74,7 +74,7 @@ def get_job_config(): @dataclass -class CrossSiloMsgConfig: +class CrossSiloMessageConfig: """A class to store parameters used for Proxy Actor Attributes: @@ -82,14 +82,16 @@ class CrossSiloMsgConfig: serializing_allowed_list: The package or class list allowed for serializing(deserializating) cross silos. It's used for avoiding pickle deserializing execution attack when crossing solis. - send_resource_label: Customized resource label, the SendProxyActor + send_resource_label: Customized resource label, the SenderProxyActor will be scheduled based on the declared resource label. For example, - when setting to `{"my_label": 1}`, then the SendProxyActor will be started - only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1. - recv_resource_label: Customized resource label, the RecverProxyActor + when setting to `{"my_label": 1}`, then the sender proxy actor will be + started only on Nodes with `{"resource": {"my_label": $NUM}}` where + $NUM >= 1. + recv_resource_label: Customized resource label, the ReceiverProxyActor will be scheduled based on the declared resource label. For example, - when setting to `{"my_label": 1}`, then the RecverProxyActor will be started - only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1. + when setting to `{"my_label": 1}`, then the receiver proxy actor will be + started only on Nodes with `{"resource": {"my_label": $NUM}}` where + $NUM >= 1. exit_on_sending_failure: whether exit when failure on cross-silo sending. If True, a SIGTERM will be signaled to self if failed to sending cross-silo data. @@ -121,13 +123,13 @@ def from_json(cls, json_str): @classmethod def from_dict(cls, data: Dict): - """Initialize CrossSiloMsgConfig from a dictionary. + """Initialize CrossSiloMessageConfig from a dictionary. Args: data (Dict): Dictionary with keys as member variable names. Returns: - CrossSiloMsgConfig: An instance of CrossSiloMsgConfig. + CrossSiloMessageConfig: An instance of CrossSiloMessageConfig. """ # Get the attributes of the class attrs = {attr for attr, _ in cls.__annotations__.items()} @@ -137,7 +139,7 @@ def from_dict(cls, data: Dict): @dataclass -class GrpcCrossSiloMsgConfig(CrossSiloMsgConfig): +class GrpcCrossSiloMessageConfig(CrossSiloMessageConfig): """A class to store parameters used for GRPC communication Attributes: diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index a1a6573..647e8d2 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -21,7 +21,7 @@ import fed.config as fed_config from fed.config import get_job_config -from fed.proxy.base_proxy import SendProxy, RecvProxy +from fed.proxy.base_proxy import SenderProxy, ReceiverProxy from fed.utils import setup_logger from fed._private import constants from fed._private.global_context import get_global_context @@ -55,7 +55,7 @@ def pop_from_two_dim_dict(the_dict, key_a, key_b): @ray.remote -class SendProxyActor: +class SenderProxyActor: def __init__( self, cluster: Dict, @@ -75,9 +75,10 @@ def __init__( self._cluster = cluster self._party = party self._tls_config = tls_config - cross_silo_msg_config = fed_config.get_job_config().cross_silo_msg_config - self._proxy_instance: SendProxy = proxy_cls( - cluster, party, tls_config, cross_silo_msg_config) + job_config = fed_config.get_job_config() + cross_silo_message_config = job_config.cross_silo_message_config + self._proxy_instance: SenderProxy = proxy_cls( + cluster, party, tls_config, cross_silo_message_config) async def is_ready(self): res = await self._proxy_instance.is_ready() @@ -122,7 +123,7 @@ async def _get_proxy_config(self, dest_party=None): @ray.remote -class RecverProxyActor: +class ReceiverProxyActor: def __init__( self, listen_addr: str, @@ -141,9 +142,10 @@ def __init__( self._listen_addr = listen_addr self._party = party self._tls_config = tls_config - cross_silo_msg_config = fed_config.get_job_config().cross_silo_msg_config - self._proxy_instance: RecvProxy = proxy_cls( - listen_addr, party, tls_config, cross_silo_msg_config) + job_config = fed_config.get_job_config() + cross_silo_message_config = job_config.cross_silo_message_config + self._proxy_instance: ReceiverProxy = proxy_cls( + listen_addr, party, tls_config, cross_silo_message_config) async def start(self): await self._proxy_instance.start() @@ -165,18 +167,18 @@ async def _get_proxy_config(self): return await self._proxy_instance.get_proxy_config() -_DEFAULT_RECV_PROXY_OPTIONS = { +_DEFAULT_RECEIVER_PROXY_OPTIONS = { "max_concurrency": 1000, } -def start_recv_proxy( +def _start_receiver_proxy( cluster: str, party: str, logging_level: str, tls_config=None, proxy_cls=None, - proxy_config: Optional[fed_config.CrossSiloMsgConfig] = None + proxy_config: Optional[fed_config.CrossSiloMessageConfig] = None ): # Create RecevrProxyActor @@ -187,14 +189,14 @@ def start_recv_proxy( if not listen_addr: listen_addr = party_addr['address'] - actor_options = copy.deepcopy(_DEFAULT_RECV_PROXY_OPTIONS) + actor_options = copy.deepcopy(_DEFAULT_RECEIVER_PROXY_OPTIONS) if proxy_config is not None and proxy_config.recv_resource_label is not None: actor_options.update({"resources": proxy_config.recv_resource_label}) - logger.debug(f"Starting RecvProxyActor with options: {actor_options}") + logger.debug(f"Starting ReceiverProxyActor with options: {actor_options}") - recver_proxy_actor = RecverProxyActor.options( - name=f"RecverProxyActor-{party}", **actor_options + receiver_proxy_actor = ReceiverProxyActor.options( + name=f"ReceiverProxyActor-{party}", **actor_options ).remote( listen_addr=listen_addr, party=party, @@ -202,31 +204,31 @@ def start_recv_proxy( logging_level=logging_level, proxy_cls=proxy_cls ) - recver_proxy_actor.start.remote() + receiver_proxy_actor.start.remote() timeout = proxy_config.timeout_in_ms / 1000 if proxy_config is not None else 60 - server_state = ray.get(recver_proxy_actor.is_ready.remote(), timeout=timeout) + server_state = ray.get(receiver_proxy_actor.is_ready.remote(), timeout=timeout) assert server_state[0], server_state[1] - logger.info("RecverProxy has successfully created.") + logger.info("Succeeded to create receiver proxy actor.") -_SEND_PROXY_ACTOR = None -_DEFAULT_SEND_PROXY_OPTIONS = { +_SENDER_PROXY_ACTOR = None +_DEFAULT_SENDER_PROXY_OPTIONS = { "max_concurrency": 1000, } -def start_send_proxy( +def _start_sender_proxy( cluster: Dict, party: str, logging_level: str, tls_config: Dict = None, proxy_cls=None, - proxy_config: Optional[fed_config.CrossSiloMsgConfig] = None + proxy_config: Optional[fed_config.CrossSiloMessageConfig] = None ): - # Create SendProxyActor - global _SEND_PROXY_ACTOR + # Create SenderProxyActor + global _SENDER_PROXY_ACTOR - actor_options = copy.deepcopy(_DEFAULT_SEND_PROXY_OPTIONS) + actor_options = copy.deepcopy(_DEFAULT_SENDER_PROXY_OPTIONS) if proxy_config and proxy_config.proxy_max_restarts: actor_options.update({ "max_task_retries": proxy_config.proxy_max_restarts, @@ -235,20 +237,20 @@ def start_send_proxy( if proxy_config and proxy_config.send_resource_label: actor_options.update({"resources": proxy_config.send_resource_label}) - logger.debug(f"Starting SendProxyActor with options: {actor_options}") - _SEND_PROXY_ACTOR = SendProxyActor.options( - name="SendProxyActor", **actor_options) + logger.debug(f"Starting SenderProxyActor with options: {actor_options}") + _SENDER_PROXY_ACTOR = SenderProxyActor.options( + name="SenderProxyActor", **actor_options) - _SEND_PROXY_ACTOR = _SEND_PROXY_ACTOR.remote( + _SENDER_PROXY_ACTOR = _SENDER_PROXY_ACTOR.remote( cluster=cluster, party=party, tls_config=tls_config, logging_level=logging_level, proxy_cls=proxy_cls ) - timeout = get_job_config().cross_silo_msg_config.timeout_in_ms / 1000 - assert ray.get(_SEND_PROXY_ACTOR.is_ready.remote(), timeout=timeout) - logger.info("SendProxyActor has successfully created.") + timeout = get_job_config().cross_silo_message_config.timeout_in_ms / 1000 + assert ray.get(_SENDER_PROXY_ACTOR.is_ready.remote(), timeout=timeout) + logger.info("SenderProxyActor has successfully created.") def send( @@ -257,8 +259,8 @@ def send( upstream_seq_id, downstream_seq_id, ): - send_proxy = ray.get_actor("SendProxyActor") - res = send_proxy.send.remote( + sender_proxy = ray.get_actor("SenderProxyActor") + res = sender_proxy.send.remote( dest_party=dest_party, data=data, upstream_seq_id=upstream_seq_id, @@ -270,7 +272,7 @@ def send( def recv(party: str, src_party: str, upstream_seq_id, curr_seq_id): assert party, 'Party can not be None.' - receiver_proxy = ray.get_actor(f"RecverProxyActor-{party}") + receiver_proxy = ray.get_actor(f"ReceiverProxyActor-{party}") return receiver_proxy.get_data.remote(src_party, upstream_seq_id, curr_seq_id) diff --git a/fed/proxy/base_proxy.py b/fed/proxy/base_proxy.py index 51c0a2f..ededca8 100644 --- a/fed/proxy/base_proxy.py +++ b/fed/proxy/base_proxy.py @@ -15,16 +15,16 @@ import abc from typing import Dict -from fed.config import CrossSiloMsgConfig +from fed.config import CrossSiloMessageConfig -class SendProxy(abc.ABC): +class SenderProxy(abc.ABC): def __init__( self, cluster: Dict, party: str, tls_config: Dict, - proxy_config: CrossSiloMsgConfig = None + proxy_config: CrossSiloMessageConfig = None ) -> None: self._cluster = cluster self._party = party @@ -48,13 +48,13 @@ async def get_proxy_config(self, dest_party=None): return self._proxy_config -class RecvProxy(abc.ABC): +class ReceiverProxy(abc.ABC): def __init__( self, listen_addr: str, party: str, tls_config: Dict, - proxy_config: CrossSiloMsgConfig = None + proxy_config: CrossSiloMessageConfig = None ) -> None: self._listen_addr = listen_addr self._party = party diff --git a/fed/proxy/grpc/grpc_proxy.py b/fed/proxy/grpc/grpc_proxy.py index 2502c8f..3831714 100644 --- a/fed/proxy/grpc/grpc_proxy.py +++ b/fed/proxy/grpc/grpc_proxy.py @@ -24,7 +24,7 @@ import fed.utils as fed_utils -from fed.config import CrossSiloMsgConfig, GrpcCrossSiloMsgConfig +from fed.config import CrossSiloMessageConfig, GrpcCrossSiloMessageConfig import fed._private.compatible_utils as compatible_utils from fed.proxy.grpc.grpc_options import _DEFAULT_GRPC_CHANNEL_OPTIONS, _GRPC_SERVICE from fed.proxy.barriers import ( @@ -33,7 +33,7 @@ pop_from_two_dim_dict, key_exists_in_two_dim_dict, ) -from fed.proxy.base_proxy import SendProxy, RecvProxy +from fed.proxy.base_proxy import SenderProxy, ReceiverProxy if compatible_utils._compare_version_strings( fed_utils.get_package_version('protobuf'), '4.0.0'): from fed.grpc import fed_pb2_in_protobuf4 as fed_pb2 @@ -46,7 +46,7 @@ logger = logging.getLogger(__name__) -def parse_grpc_options(proxy_config: CrossSiloMsgConfig): +def parse_grpc_options(proxy_config: CrossSiloMessageConfig): """ Extract certain fields in `CrossSiloGrpcCommConfig` into the "grpc_channel_options". Note that the resulting dict's key @@ -54,7 +54,7 @@ def parse_grpc_options(proxy_config: CrossSiloMsgConfig): option name. Args: - proxy_config (CrossSiloMsgConfig): The proxy configuration + proxy_config (CrossSiloMessageConfig): The proxy configuration from which to extract the gRPC options. Returns: @@ -62,8 +62,8 @@ def parse_grpc_options(proxy_config: CrossSiloMsgConfig): """ grpc_channel_options = {} if proxy_config is not None and isinstance( - proxy_config, GrpcCrossSiloMsgConfig): - if isinstance(proxy_config, GrpcCrossSiloMsgConfig): + proxy_config, GrpcCrossSiloMessageConfig): + if isinstance(proxy_config, GrpcCrossSiloMessageConfig): if proxy_config.grpc_channel_options is not None: grpc_channel_options.update(proxy_config.grpc_channel_options) if proxy_config.grpc_retry_policy is not None: @@ -84,13 +84,13 @@ def parse_grpc_options(proxy_config: CrossSiloMsgConfig): return grpc_channel_options -class GrpcSendProxy(SendProxy): +class GrpcSenderProxy(SenderProxy): def __init__( self, cluster: Dict, party: str, tls_config: Dict, - proxy_config: CrossSiloMsgConfig = None + proxy_config: CrossSiloMessageConfig = None ) -> None: super().__init__(cluster, party, tls_config, proxy_config) self._grpc_metadata = proxy_config.http_header or {} @@ -143,7 +143,7 @@ def get_grpc_config_by_party(self, dest_party): grpc_options = self._grpc_options dest_party_msg_config = self._cluster[dest_party].get( - 'cross_silo_msg_config', None) + 'cross_silo_message_config', None) if dest_party_msg_config is not None: if dest_party_msg_config.http_header is not None: dest_party_grpc_metadata = dict(dest_party_msg_config.http_header) @@ -194,13 +194,13 @@ async def send_data_grpc( return response.result -class GrpcRecvProxy(RecvProxy): +class GrpcReceiverProxy(ReceiverProxy): def __init__( self, listen_addr: str, party: str, tls_config: Dict, - proxy_config: CrossSiloMsgConfig + proxy_config: CrossSiloMessageConfig ) -> None: super().__init__(listen_addr, party, tls_config, proxy_config) self._grpc_options = copy.deepcopy(_DEFAULT_GRPC_CHANNEL_OPTIONS) @@ -319,7 +319,6 @@ async def _run_grpc_server( server.add_secure_port(f'[::]:{port}', server_credentials) else: server.add_insecure_port(f'[::]:{port}') - # server.add_insecure_port(f'[::]:{port}') msg = f"Succeeded to add port {port}." await server.start() diff --git a/tests/serializations_tests/test_unpickle_with_whitelist.py b/tests/serializations_tests/test_unpickle_with_whitelist.py index ec4f846..98f39b5 100644 --- a/tests/serializations_tests/test_unpickle_with_whitelist.py +++ b/tests/serializations_tests/test_unpickle_with_whitelist.py @@ -19,7 +19,7 @@ import multiprocessing import numpy -from fed.config import CrossSiloMsgConfig +from fed.config import CrossSiloMessageConfig @fed.remote @@ -53,7 +53,7 @@ def run(party): fed.init( cluster=cluster, party=party, - global_cross_silo_msg_config=CrossSiloMsgConfig( + global_cross_silo_message_config=CrossSiloMessageConfig( serializing_allowed_list=allowed_list )) diff --git a/tests/test_cache_fed_objects.py b/tests/test_cache_fed_objects.py index 535e5f5..22b0233 100644 --- a/tests/test_cache_fed_objects.py +++ b/tests/test_cache_fed_objects.py @@ -48,11 +48,11 @@ def run(party): assert c == "hello2" if party == "bob": - proxy_actor = ray.get_actor(f"RecverProxyActor-{party}") + proxy_actor = ray.get_actor(f"ReceiverProxyActor-{party}") stats = ray.get(proxy_actor._get_stats.remote()) assert stats["receive_op_count"] == 1 if party == "alice": - proxy_actor = ray.get_actor("SendProxyActor") + proxy_actor = ray.get_actor("SenderProxyActor") stats = ray.get(proxy_actor._get_stats.remote()) assert stats["send_op_count"] == 1 fed.shutdown() diff --git a/tests/test_exit_on_failure_sending.py b/tests/test_exit_on_failure_sending.py index 33b3f98..594fc89 100644 --- a/tests/test_exit_on_failure_sending.py +++ b/tests/test_exit_on_failure_sending.py @@ -19,7 +19,7 @@ import fed import fed._private.compatible_utils as compatible_utils -from fed.config import GrpcCrossSiloMsgConfig +from fed.config import GrpcCrossSiloMessageConfig import signal @@ -63,7 +63,7 @@ def run(party, is_inner_party): "backoffMultiplier": 1, "retryableStatusCodes": ["UNAVAILABLE"], } - cross_silo_msg_config = GrpcCrossSiloMsgConfig( + cross_silo_message_config = GrpcCrossSiloMessageConfig( grpc_retry_policy=retry_policy, exit_on_sending_failure=True ) @@ -71,7 +71,7 @@ def run(party, is_inner_party): cluster=cluster, party=party, logging_level='debug', - global_cross_silo_msg_config=cross_silo_msg_config + global_cross_silo_message_config=cross_silo_message_config ) o = f.party("alice").remote() diff --git a/tests/test_grpc_options_on_proxies.py b/tests/test_grpc_options_on_proxies.py index 34735b8..a43007d 100644 --- a/tests/test_grpc_options_on_proxies.py +++ b/tests/test_grpc_options_on_proxies.py @@ -18,7 +18,7 @@ import fed._private.compatible_utils as compatible_utils import ray -from fed.config import GrpcCrossSiloMsgConfig +from fed.config import GrpcCrossSiloMessageConfig @fed.remote @@ -35,7 +35,7 @@ def run(party): fed.init( cluster=cluster, party=party, - global_cross_silo_msg_config=GrpcCrossSiloMsgConfig( + global_cross_silo_message_config=GrpcCrossSiloMessageConfig( grpc_channel_options=[( 'grpc.max_send_message_length', 100 )] @@ -48,10 +48,10 @@ def _assert_on_proxy(proxy_actor): assert ("grpc.max_send_message_length", 100) in options assert ('grpc.so_reuseport', 0) in options - send_proxy = ray.get_actor("SendProxyActor") - recver_proxy = ray.get_actor(f"RecverProxyActor-{party}") - _assert_on_proxy(send_proxy) - _assert_on_proxy(recver_proxy) + sender_proxy = ray.get_actor("SenderProxyActor") + receiver_proxy = ray.get_actor(f"ReceiverProxyActor-{party}") + _assert_on_proxy(sender_proxy) + _assert_on_proxy(receiver_proxy) a = dummpy.party('alice').remote() b = dummpy.party('bob').remote() diff --git a/tests/test_grpc_options_per_party.py b/tests/test_grpc_options_per_party.py index 95b1465..018294f 100644 --- a/tests/test_grpc_options_per_party.py +++ b/tests/test_grpc_options_per_party.py @@ -18,7 +18,7 @@ import fed._private.compatible_utils as compatible_utils import ray -from fed.config import GrpcCrossSiloMsgConfig +from fed.config import GrpcCrossSiloMessageConfig @fed.remote @@ -31,7 +31,7 @@ def run(party): cluster = { 'alice': { 'address': '127.0.0.1:11010', - 'cross_silo_msg_config': GrpcCrossSiloMsgConfig( + 'cross_silo_message_config': GrpcCrossSiloMessageConfig( grpc_channel_options=[ ('grpc.default_authority', 'alice'), ('grpc.max_send_message_length', 200) @@ -42,14 +42,14 @@ def run(party): fed.init( cluster=cluster, party=party, - global_cross_silo_msg_config=GrpcCrossSiloMsgConfig( + global_cross_silo_message_config=GrpcCrossSiloMessageConfig( grpc_channel_options=[( 'grpc.max_send_message_length', 100 )] ) ) - def _assert_on_send_proxy(proxy_actor): + def _assert_on_sender_proxy(proxy_actor): alice_config = ray.get(proxy_actor._get_proxy_config.remote('alice')) # print(f"【NKcqx】alice config: {alice_config}") assert 'grpc_options' in alice_config @@ -63,8 +63,8 @@ def _assert_on_send_proxy(proxy_actor): assert ('grpc.max_send_message_length', 100) in bob_options assert not any(o[0] == 'grpc.default_authority' for o in bob_options) - send_proxy = ray.get_actor("SendProxyActor") - _assert_on_send_proxy(send_proxy) + sender_proxy = ray.get_actor("SenderProxyActor") + _assert_on_sender_proxy(sender_proxy) a = dummpy.party('alice').remote() b = dummpy.party('bob').remote() @@ -89,7 +89,7 @@ def party_grpc_options(party): cluster = { 'alice': { 'address': '127.0.0.1:11010', - 'cross_silo_msg_config': GrpcCrossSiloMsgConfig( + 'cross_silo_message_config': GrpcCrossSiloMessageConfig( grpc_channel_options=[ ('grpc.default_authority', 'alice'), ('grpc.max_send_message_length', 51 * 1024 * 1024) @@ -97,7 +97,7 @@ def party_grpc_options(party): }, 'bob': { 'address': '127.0.0.1:11011', - 'cross_silo_msg_config': GrpcCrossSiloMsgConfig( + 'cross_silo_message_config': GrpcCrossSiloMessageConfig( grpc_channel_options=[ ('grpc.default_authority', 'bob'), ('grpc.max_send_message_length', 50 * 1024 * 1024) @@ -107,14 +107,14 @@ def party_grpc_options(party): fed.init( cluster=cluster, party=party, - global_cross_silo_msg_config=GrpcCrossSiloMsgConfig( + global_cross_silo_message_config=GrpcCrossSiloMessageConfig( grpc_channel_options=[( 'grpc.max_send_message_length', 100 )] ) ) - def _assert_on_send_proxy(proxy_actor): + def _assert_on_sender_proxy(proxy_actor): alice_config = ray.get(proxy_actor._get_proxy_config.remote('alice')) assert 'grpc_options' in alice_config alice_options = alice_config['grpc_options'] @@ -127,8 +127,8 @@ def _assert_on_send_proxy(proxy_actor): assert ('grpc.max_send_message_length', 50 * 1024 * 1024) in bob_options assert ('grpc.default_authority', 'bob') in bob_options - send_proxy = ray.get_actor("SendProxyActor") - _assert_on_send_proxy(send_proxy) + sender_proxy = ray.get_actor("SenderProxyActor") + _assert_on_sender_proxy(sender_proxy) a = dummpy.party('alice').remote() b = dummpy.party('bob').remote() diff --git a/tests/test_retry_policy.py b/tests/test_retry_policy.py index af6d9e8..9450dac 100644 --- a/tests/test_retry_policy.py +++ b/tests/test_retry_policy.py @@ -20,7 +20,7 @@ import fed._private.compatible_utils as compatible_utils import ray -from fed.config import GrpcCrossSiloMsgConfig +from fed.config import GrpcCrossSiloMessageConfig @fed.remote @@ -53,7 +53,7 @@ def run(party, is_inner_party): fed.init( cluster=cluster, party=party, - global_cross_silo_msg_config=GrpcCrossSiloMsgConfig( + global_cross_silo_message_config=GrpcCrossSiloMessageConfig( grpc_retry_policy=retry_policy ) ) diff --git a/tests/test_setup_proxy_actor.py b/tests/test_setup_proxy_actor.py index 00215e3..2a136cb 100644 --- a/tests/test_setup_proxy_actor.py +++ b/tests/test_setup_proxy_actor.py @@ -20,7 +20,7 @@ import fed._private.compatible_utils as compatible_utils import ray -from fed.config import CrossSiloMsgConfig +from fed.config import CrossSiloMessageConfig def run(party): @@ -29,21 +29,21 @@ def run(party): 'alice': {'address': '127.0.0.1:11010'}, 'bob': {'address': '127.0.0.1:11011'}, } - send_proxy_resources = { + sender_proxy_resources = { "127.0.0.1": 1 } - recv_proxy_resources = { + receiver_proxy_resources = { "127.0.0.1": 1 } fed.init( cluster=cluster, party=party, - cross_silo_send_resource_label=send_proxy_resources, - cross_silo_recv_resource_label=recv_proxy_resources, + cross_silo_send_resource_label=sender_proxy_resources, + cross_silo_recv_resource_label=receiver_proxy_resources, ) - assert ray.get_actor("SendProxyActor") is not None - assert ray.get_actor(f"RecverProxyActor-{party}") is not None + assert ray.get_actor("SenderProxyActor") is not None + assert ray.get_actor(f"ReceiverProxyActor-{party}") is not None fed.shutdown() ray.shutdown() @@ -55,19 +55,19 @@ def run_failure(party): 'alice': {'address': '127.0.0.1:11010'}, 'bob': {'address': '127.0.0.1:11011'}, } - send_proxy_resources = { + sender_proxy_resources = { "127.0.0.2": 1 # Insufficient resource } - recv_proxy_resources = { + receiver_proxy_resources = { "127.0.0.2": 1 # Insufficient resource } with pytest.raises(ray.exceptions.GetTimeoutError): fed.init( cluster=cluster, party=party, - global_cross_silo_msg_config=CrossSiloMsgConfig( - send_resource_label=send_proxy_resources, - recv_resource_label=recv_proxy_resources, + global_cross_silo_message_config=CrossSiloMessageConfig( + send_resource_label=sender_proxy_resources, + recv_resource_label=receiver_proxy_resources, timeout_in_ms=10*1000, ) ) diff --git a/tests/test_transport_proxy.py b/tests/test_transport_proxy.py index 54aa3c2..7e9c3fe 100644 --- a/tests/test_transport_proxy.py +++ b/tests/test_transport_proxy.py @@ -20,15 +20,15 @@ import fed.utils as fed_utils import fed._private.compatible_utils as compatible_utils -from fed.config import CrossSiloMsgConfig, GrpcCrossSiloMsgConfig +from fed.config import CrossSiloMessageConfig, GrpcCrossSiloMessageConfig from fed._private import constants from fed._private import global_context from fed.proxy.barriers import ( send, - start_recv_proxy, - start_send_proxy + _start_receiver_proxy, + _start_sender_proxy ) -from fed.proxy.grpc.grpc_proxy import GrpcSendProxy, GrpcRecvProxy +from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy, GrpcReceiverProxy if compatible_utils._compare_version_strings( fed_utils.get_package_version('protobuf'), '4.0.0'): from fed.grpc import fed_pb2_in_protobuf4 as fed_pb2 @@ -40,8 +40,8 @@ def test_n_to_1_transport(): """This case is used to test that we have N send_op barriers, - sending data to the target recver proxy, and there also have - N receivers to `get_data` from Recver proxy at that time. + sending data to the target receiver proxy, and there also have + N receivers to `get_data` from receiver proxy at that time. """ compatible_utils.init_ray(address='local') @@ -59,29 +59,29 @@ def test_n_to_1_transport(): SERVER_ADDRESS = "127.0.0.1:12344" party = 'test_party' cluster_config = {'test_party': {'address': SERVER_ADDRESS}} - config = GrpcCrossSiloMsgConfig() - start_recv_proxy( + config = GrpcCrossSiloMessageConfig() + _start_receiver_proxy( cluster_config, party, logging_level='info', - proxy_cls=GrpcRecvProxy, + proxy_cls=GrpcReceiverProxy, proxy_config=config ) - start_send_proxy( + _start_sender_proxy( cluster_config, party, logging_level='info', - proxy_cls=GrpcSendProxy, + proxy_cls=GrpcSenderProxy, proxy_config=config ) sent_objs = [] get_objs = [] - recver_proxy_actor = ray.get_actor(f"RecverProxyActor-{party}") + receiver_proxy_actor = ray.get_actor(f"ReceiverProxyActor-{party}") for i in range(NUM_DATA): sent_obj = send(party, f"data-{i}", i, i + 1) sent_objs.append(sent_obj) - get_obj = recver_proxy_actor.get_data.remote(party, i, i + 1) + get_obj = receiver_proxy_actor.get_data.remote(party, i, i + 1) get_objs.append(get_obj) for result in ray.get(sent_objs): assert result @@ -122,7 +122,7 @@ async def _test_run_grpc_server( @ray.remote -class TestRecverProxyActor: +class TestReceiverProxyActor: def __init__( self, listen_addr: str, @@ -147,7 +147,7 @@ async def is_ready(self): return True -def _test_start_recv_proxy( +def _test_start_receiver_proxy( cluster: str, party: str, logging_level: str, @@ -160,15 +160,15 @@ def _test_start_recv_proxy( if not listen_addr: listen_addr = party_addr['address'] - recver_proxy_actor = TestRecverProxyActor.options( - name=f"RecverProxyActor-{party}", max_concurrency=1000 + receiver_proxy_actor = TestReceiverProxyActor.options( + name=f"ReceiverProxyActor-{party}", max_concurrency=1000 ).remote( listen_addr=listen_addr, party=party, expected_metadata=expected_metadata ) - recver_proxy_actor.run_grpc_server.remote() - assert ray.get(recver_proxy_actor.is_ready.remote()) + receiver_proxy_actor.run_grpc_server.remote() + assert ray.get(receiver_proxy_actor.is_ready.remote()) def test_send_grpc_with_meta(): @@ -179,12 +179,12 @@ def test_send_grpc_with_meta(): constants.KEY_OF_TLS_CONFIG: "", } metadata = {"key": "value"} - send_proxy_config = CrossSiloMsgConfig( + sender_proxy_config = CrossSiloMessageConfig( http_header=metadata ) job_config = { - constants.KEY_OF_CROSS_SILO_MSG_CONFIG: - send_proxy_config, + constants.KEY_OF_CROSS_SILO_MESSAGE_CONFIG: + sender_proxy_config, } compatible_utils._init_internal_kv() compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG, @@ -196,16 +196,16 @@ def test_send_grpc_with_meta(): SERVER_ADDRESS = "127.0.0.1:12344" party = 'test_party' cluster_config = {'test_party': {'address': SERVER_ADDRESS}} - _test_start_recv_proxy( + _test_start_receiver_proxy( cluster_config, party, logging_level='info', expected_metadata=metadata, ) - start_send_proxy( + _start_sender_proxy( cluster_config, party, logging_level='info', - proxy_cls=GrpcSendProxy, - proxy_config=GrpcCrossSiloMsgConfig()) + proxy_cls=GrpcSenderProxy, + proxy_config=GrpcCrossSiloMessageConfig()) sent_objs = [] sent_obj = send(party, "data", 0, 1) sent_objs.append(sent_obj) @@ -224,11 +224,11 @@ def test_send_grpc_with_party_specific_meta(): constants.KEY_OF_CURRENT_PARTY_NAME: "", constants.KEY_OF_TLS_CONFIG: "", } - send_proxy_config = CrossSiloMsgConfig( + sender_proxy_config = CrossSiloMessageConfig( http_header={"key": "value"}) job_config = { - constants.KEY_OF_CROSS_SILO_MSG_CONFIG: - send_proxy_config, + constants.KEY_OF_CROSS_SILO_MESSAGE_CONFIG: + sender_proxy_config, } compatible_utils._init_internal_kv() compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG, @@ -242,20 +242,20 @@ def test_send_grpc_with_party_specific_meta(): cluster_parties_config = { 'test_party': { 'address': SERVER_ADDRESS, - 'cross_silo_msg_config': CrossSiloMsgConfig( + 'cross_silo_message_config': CrossSiloMessageConfig( http_header={"token": "test-party-token"}) } } - _test_start_recv_proxy( + _test_start_receiver_proxy( cluster_parties_config, party, logging_level='info', expected_metadata={"key": "value", "token": "test-party-token"}, ) - start_send_proxy( + _start_sender_proxy( cluster_parties_config, party, logging_level='info', - proxy_cls=GrpcSendProxy, - proxy_config=send_proxy_config) + proxy_cls=GrpcSenderProxy, + proxy_config=sender_proxy_config) sent_objs = [] sent_obj = send(party, "data", 0, 1) sent_objs.append(sent_obj) diff --git a/tests/test_transport_proxy_tls.py b/tests/test_transport_proxy_tls.py index 71f10ef..4105673 100644 --- a/tests/test_transport_proxy_tls.py +++ b/tests/test_transport_proxy_tls.py @@ -21,15 +21,15 @@ import fed._private.compatible_utils as compatible_utils from fed._private import constants from fed._private import global_context -from fed.proxy.barriers import send, start_recv_proxy, start_send_proxy -from fed.proxy.grpc.grpc_proxy import GrpcSendProxy, GrpcRecvProxy -from fed.config import GrpcCrossSiloMsgConfig +from fed.proxy.barriers import send, _start_receiver_proxy, _start_sender_proxy +from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy, GrpcReceiverProxy +from fed.config import GrpcCrossSiloMessageConfig def test_n_to_1_transport(): """This case is used to test that we have N send_op barriers, - sending data to the target recver proxy, and there also have - N receivers to `get_data` from Recver proxy at that time. + sending data to the target receiver proxy, and there also have + N receivers to `get_data` from receiver proxy at that time. """ compatible_utils.init_ray(address='local') @@ -57,27 +57,27 @@ def test_n_to_1_transport(): SERVER_ADDRESS = "127.0.0.1:65422" party = 'test_party' cluster_config = {'test_party': {'address': SERVER_ADDRESS}} - config = GrpcCrossSiloMsgConfig() - start_recv_proxy( + config = GrpcCrossSiloMessageConfig() + _start_receiver_proxy( cluster_config, party, logging_level='info', tls_config=tls_config, - proxy_cls=GrpcRecvProxy, + proxy_cls=GrpcReceiverProxy, proxy_config=config ) - start_send_proxy( + _start_sender_proxy( cluster_config, party, logging_level='info', tls_config=tls_config, - proxy_cls=GrpcSendProxy, + proxy_cls=GrpcSenderProxy, proxy_config=config ) sent_objs = [] get_objs = [] - recver_proxy_actor = ray.get_actor(f"RecverProxyActor-{party}") + receiver_proxy_actor = ray.get_actor(f"ReceiverProxyActor-{party}") for i in range(NUM_DATA): sent_obj = send( party, @@ -86,7 +86,7 @@ def test_n_to_1_transport(): i + 1, ) sent_objs.append(sent_obj) - get_obj = recver_proxy_actor.get_data.remote(party, i, i + 1) + get_obj = receiver_proxy_actor.get_data.remote(party, i, i + 1) get_objs.append(get_obj) for result in ray.get(sent_objs): assert result