diff --git a/fed/_private/constants.py b/fed/_private/constants.py index 862fbaeb..09af728f 100644 --- a/fed/_private/constants.py +++ b/fed/_private/constants.py @@ -25,11 +25,7 @@ KEY_OF_TLS_CONFIG = "TLS_CONFIG" -KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST = "CROSS_SILO_SERIALIZING_ALLOWED_LIST" # noqa - -KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES = "CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES" # noqa - -KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS = "CROSS_SILO_TIMEOUT_IN_SECONDS" +KEY_OF_CROSS_SILO_MSG_CONFIG = "CROSS_SILO_MSG_CONFIG" RAYFED_LOG_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- %(message)s" # noqa diff --git a/fed/_private/serialization_utils.py b/fed/_private/serialization_utils.py index fafd173e..07182cf6 100644 --- a/fed/_private/serialization_utils.py +++ b/fed/_private/serialization_utils.py @@ -63,7 +63,8 @@ def find_class(self, module, name): def _apply_loads_function_with_whitelist(): global _pickle_whitelist - _pickle_whitelist = fed_config.get_cluster_config().serializing_allowed_list + _pickle_whitelist = fed_config.get_job_config() \ + .cross_silo_msg_config.serializing_allowed_list if _pickle_whitelist is None: return diff --git a/fed/api.py b/fed/api.py index 460fbde6..22dff338 100644 --- a/fed/api.py +++ b/fed/api.py @@ -15,7 +15,7 @@ import functools import inspect import logging -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Union, Optional import cloudpickle import ray @@ -34,6 +34,8 @@ start_recv_proxy, start_send_proxy, ) +from fed.proxy.grpc.grpc_proxy import SendProxy, RecvProxy +from fed.config import CrossSiloMsgConfig from fed.fed_object import FedObject from fed.utils import is_ray_object_refs, setup_logger @@ -45,16 +47,10 @@ def init( party: str = None, tls_config: Dict = None, logging_level: str = 'info', - cross_silo_grpc_retry_policy: Dict = None, - cross_silo_send_max_retries: int = None, - cross_silo_serializing_allowed_list: Dict = None, - cross_silo_send_resource_label: Dict = None, - cross_silo_recv_resource_label: Dict = None, - exit_on_failure_cross_silo_sending: bool = False, - cross_silo_messages_max_size_in_bytes: int = None, - cross_silo_timeout_in_seconds: int = 60, enable_waiting_for_other_parties_ready: bool = False, - grpc_metadata: Dict = None, + send_proxy_cls: SendProxy = None, + recv_proxy_cls: RecvProxy = None, + global_cross_silo_msg_config: Optional[CrossSiloMsgConfig] = None, **kwargs, ): """ @@ -71,12 +67,7 @@ def init( # (Optional) the listen address, the `address` will be # used if not provided. 'listen_addr': '0.0.0.0:10001', - # (Optional) The party specific metadata sent with the grpc request - 'grpc_metadata': (('token', 'alice-token'),), - 'grpc_options': [ - ('grpc.default_authority', 'alice'), - ('grpc.max_send_message_length', 50 * 1024 * 1024) - ] + 'cross_silo_msg_config': CrossSiloMsgConfig }, 'bob': { # The address for other parties. @@ -84,7 +75,7 @@ def init( # (Optional) the listen address, the `address` will be # used if not provided. 'listen_addr': '0.0.0.0:10002', - # (Optional) The party specific metadata sent with the grpc request + # (Optional) The party specific metadata sent with grpc requests 'grpc_metadata': (('token', 'bob-token'),), }, 'carol': { @@ -93,7 +84,7 @@ def init( # (Optional) the listen address, the `address` will be # used if not provided. 'listen_addr': '0.0.0.0:10003', - # (Optional) The party specific metadata sent with the grpc request + # (Optional) The party specific metadata sent with grpc requests 'grpc_metadata': (('token', 'carol-token'),), }, } @@ -116,48 +107,13 @@ def init( "cert": "bob's server cert", "key": "bob's server cert key", } - logging_level: optional; the logging level, could be `debug`, `info`, `warning`, `error`, `critical`, not case sensititive. - cross_silo_grpc_retry_policy: a dict descibes the retry policy for - cross silo rpc call. If None, the following default retry policy - will be used. More details please refer to - `retry-policy `_. # noqa - - .. code:: python - { - "maxAttempts": 4, - "initialBackoff": "0.1s", - "maxBackoff": "1s", - "backoffMultiplier": 2, - "retryableStatusCodes": [ - "UNAVAILABLE" - ] - } - cross_silo_send_max_retries: the max retries for sending data cross silo. - cross_silo_serializing_allowed_list: The package or class list allowed for - serializing(deserializating) cross silos. It's used for avoiding pickle - deserializing execution attack when crossing solis. - cross_silo_send_resource_label: Customized resource label, the SendProxyActor - will be scheduled based on the declared resource label. For example, - when setting to `{"my_label": 1}`, then the SendProxyActor will be started - only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1. - cross_silo_recv_resource_label: Customized resource label, the RecverProxyActor - will be scheduled based on the declared resource label. For example, - when setting to `{"my_label": 1}`, then the RecverProxyActor will be started - only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1. - exit_on_failure_cross_silo_sending: whether exit when failure on - cross-silo sending. If True, a SIGTERM will be signaled to self - if failed to sending cross-silo data. - cross_silo_messages_max_size_in_bytes: The maximum length in bytes of - cross-silo messages. - If None, the default value of 500 MB is specified. - cross_silo_timeout_in_seconds: The timeout in seconds of a cross-silo RPC call. - It's 60 by default. enable_waiting_for_other_parties_ready: ping other parties until they are all ready if True. - grpc_metadata: optional; The metadata sent with the grpc request. This won't override - basic tcp headers, such as `user-agent`, but aggregate them together. + global_cross_silo_msg_config: Global cross-silo message related + configs that are applied to all connections. Supported configs + can refer to CrossSiloMsgConfig in config.py. Examples: >>> import fed @@ -182,6 +138,9 @@ def init( assert ( 'cert' in tls_config and 'key' in tls_config ), 'Cert or key are not in tls_config.' + + global_cross_silo_msg_config = \ + global_cross_silo_msg_config or CrossSiloMsgConfig() # A Ray private accessing, should be replaced in public API. compatible_utils._init_internal_kv() @@ -189,15 +148,11 @@ def init( constants.KEY_OF_CLUSTER_ADDRESSES: cluster, constants.KEY_OF_CURRENT_PARTY_NAME: party, constants.KEY_OF_TLS_CONFIG: tls_config, - constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST: - cross_silo_serializing_allowed_list, - constants.KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES: - cross_silo_messages_max_size_in_bytes, - constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS: cross_silo_timeout_in_seconds, } job_config = { - constants.KEY_OF_GRPC_METADATA : grpc_metadata, + constants.KEY_OF_CROSS_SILO_MSG_CONFIG: + global_cross_silo_msg_config, } compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config)) @@ -215,30 +170,35 @@ def init( logger.info(f'Started rayfed with {cluster_config}') get_global_context().get_cleanup_manager().start( - exit_when_failure_sending=exit_on_failure_cross_silo_sending) + exit_when_failure_sending=global_cross_silo_msg_config.exit_on_sending_failure) - recv_actor_config = fed_config.ProxyActorConfig( - resource_label=cross_silo_recv_resource_label) + if recv_proxy_cls is None: + logger.debug( + "Not declaring recver proxy class, using `GrpcRecvProxy` as default.") + from fed.proxy.grpc.grpc_proxy import GrpcRecvProxy + recv_proxy_cls = GrpcRecvProxy # Start recv proxy start_recv_proxy( cluster=cluster, party=party, logging_level=logging_level, tls_config=tls_config, - retry_policy=cross_silo_grpc_retry_policy, - actor_config=recv_actor_config + proxy_cls=recv_proxy_cls, + proxy_config=global_cross_silo_msg_config ) - send_actor_config = fed_config.ProxyActorConfig( - resource_label=cross_silo_send_resource_label) + if send_proxy_cls is None: + logger.debug( + "Not declaring send proxy class, using `GrpcSendProxy` as default.") + from fed.proxy.grpc.grpc_proxy import GrpcSendProxy + send_proxy_cls = GrpcSendProxy start_send_proxy( cluster=cluster, party=party, logging_level=logging_level, tls_config=tls_config, - retry_policy=cross_silo_grpc_retry_policy, - max_retries=cross_silo_send_max_retries, - actor_config=send_actor_config + proxy_cls=send_proxy_cls, + proxy_config=global_cross_silo_msg_config ) if enable_waiting_for_other_parties_ready: diff --git a/fed/config.py b/fed/config.py index f3946c16..84af784a 100644 --- a/fed/config.py +++ b/fed/config.py @@ -7,7 +7,10 @@ import fed._private.compatible_utils as compatible_utils import fed._private.constants as fed_constants import cloudpickle -from typing import Dict, Optional +import json + +from typing import Dict, List, Optional +from dataclasses import dataclass class ClusterConfig: @@ -27,18 +30,6 @@ def current_party(self): def tls_config(self): return self._data[fed_constants.KEY_OF_TLS_CONFIG] - @property - def serializing_allowed_list(self): - return self._data[fed_constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST] - - @property - def cross_silo_timeout(self): - return self._data[fed_constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS] - - @property - def cross_silo_messages_max_size(self): - return self._data[fed_constants.KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES] - class JobConfig: def __init__(self, raw_bytes: bytes) -> None: @@ -48,8 +39,10 @@ def __init__(self, raw_bytes: bytes) -> None: self._data = cloudpickle.loads(raw_bytes) @property - def grpc_metadata(self): - return self._data.get(fed_constants.KEY_OF_GRPC_METADATA, {}) + def cross_silo_msg_config(self): + return self._data.get( + fed_constants.KEY_OF_CROSS_SILO_MSG_CONFIG, + CrossSiloMsgConfig()) # A module level cache for the cluster configurations. @@ -80,14 +73,94 @@ def get_job_config(): return _job_config -class ProxyActorConfig: +@dataclass +class CrossSiloMsgConfig: """A class to store parameters used for Proxy Actor Attributes: - resource_label: The customized resources for the actor. This will be - filled into the "resource" field of Ray ActorClass.options. + proxy_max_restarts: The max restart times for the send proxy. + serializing_allowed_list: The package or class list allowed for + serializing(deserializating) cross silos. It's used for avoiding pickle + deserializing execution attack when crossing solis. + send_resource_label: Customized resource label, the SendProxyActor + will be scheduled based on the declared resource label. For example, + when setting to `{"my_label": 1}`, then the SendProxyActor will be started + only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1. + recv_resource_label: Customized resource label, the RecverProxyActor + will be scheduled based on the declared resource label. For example, + when setting to `{"my_label": 1}`, then the RecverProxyActor will be started + only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1. + exit_on_sending_failure: whether exit when failure on + cross-silo sending. If True, a SIGTERM will be signaled to self + if failed to sending cross-silo data. + messages_max_size_in_bytes: The maximum length in bytes of + cross-silo messages. + If None, the default value of 500 MB is specified. + timeout_in_ms: The timeout in mili-seconds of a cross-silo RPC call. + It's 60000 by default. + http_header: The HTTP header, e.g. metadata in grpc, sent with the RPC request. + This won't override basic tcp headers, such as `user-agent`, but concat + them together. + """ + proxy_max_restarts: int = None + timeout_in_ms: int = 60000 + messages_max_size_in_bytes: int = None + exit_on_sending_failure: Optional[bool] = False + serializing_allowed_list: Optional[Dict[str, str]] = None + send_resource_label: Optional[Dict[str, str]] = None + recv_resource_label: Optional[Dict[str, str]] = None + http_header: Optional[Dict[str, str]] = None + + def __json__(self): + return json.dumps(self.__dict__) + + @classmethod + def from_json(cls, json_str): + data = json.loads(json_str) + return cls(**data) + + @classmethod + def from_dict(cls, data: Dict): + """Initialize CrossSiloMsgConfig from a dictionary. + + Args: + data (Dict): Dictionary with keys as member variable names. + + Returns: + CrossSiloMsgConfig: An instance of CrossSiloMsgConfig. + """ + # Get the attributes of the class + attrs = {attr for attr, _ in cls.__annotations__.items()} + # Filter the dictionary to only include keys that are attributes of the class + filtered_data = {key: value for key, value in data.items() if key in attrs} + return cls(**filtered_data) + + +@dataclass +class GrpcCrossSiloMsgConfig(CrossSiloMsgConfig): + """A class to store parameters used for GRPC communication + + Attributes: + grpc_retry_policy: a dict descibes the retry policy for + cross silo rpc call. If None, the following default retry policy + will be used. More details please refer to + `retry-policy `_. # noqa + + .. code:: python + { + "maxAttempts": 4, + "initialBackoff": "0.1s", + "maxBackoff": "1s", + "backoffMultiplier": 2, + "retryableStatusCodes": [ + "UNAVAILABLE" + ] + } + grpc_channel_options: A list of tuples to store GRPC channel options, + e.g. [ + ('grpc.enable_retries', 1), + ('grpc.max_send_message_length', 50 * 1024 * 1024) + ] """ - def __init__( - self, - resource_label: Optional[Dict[str, str]] = None) -> None: - self.resource_label = resource_label + grpc_channel_options: List = None + grpc_retry_policy: Dict[str, str] = None diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index 66028064..a1a65736 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -12,31 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import logging -import threading import time import copy from typing import Dict, Optional -import cloudpickle -import grpc import ray import fed.config as fed_config -import fed.utils as fed_utils -from fed._private import constants -from fed._private.grpc_options import get_grpc_options, set_max_message_length -import fed._private.compatible_utils as compatible_utils -from fed.config import get_cluster_config -if compatible_utils._compare_version_strings( - fed_utils.get_package_version('protobuf'), '4.0.0'): - from fed.grpc import fed_pb2_in_protobuf4 as fed_pb2 - from fed.grpc import fed_pb2_grpc_in_protobuf4 as fed_pb2_grpc -else: - from fed.grpc import fed_pb2_in_protobuf3 as fed_pb2 - from fed.grpc import fed_pb2_grpc_in_protobuf3 as fed_pb2_grpc +from fed.config import get_job_config +from fed.proxy.base_proxy import SendProxy, RecvProxy from fed.utils import setup_logger +from fed._private import constants from fed._private.global_context import get_global_context logger = logging.getLogger(__name__) @@ -67,96 +54,6 @@ def pop_from_two_dim_dict(the_dict, key_a, key_b): return the_dict[key_a].pop(key_b) -class SendDataService(fed_pb2_grpc.GrpcServiceServicer): - def __init__(self, all_events, all_data, party, lock): - self._events = all_events - self._all_data = all_data - self._party = party - self._lock = lock - - async def SendData(self, request, context): - upstream_seq_id = request.upstream_seq_id - downstream_seq_id = request.downstream_seq_id - logger.debug( - f'Received a grpc data request from {upstream_seq_id} to ' - f'{downstream_seq_id}.' - ) - - with self._lock: - add_two_dim_dict( - self._all_data, upstream_seq_id, downstream_seq_id, request.data - ) - if not key_exists_in_two_dim_dict( - self._events, upstream_seq_id, downstream_seq_id - ): - event = asyncio.Event() - add_two_dim_dict( - self._events, upstream_seq_id, downstream_seq_id, event - ) - event = get_from_two_dim_dict(self._events, upstream_seq_id, downstream_seq_id) - event.set() - logger.debug(f"Event set for {upstream_seq_id}") - return fed_pb2.SendDataResponse(result="OK") - - -async def _run_grpc_server( - port, event, all_data, party, lock, - server_ready_future, tls_config=None, grpc_options=None -): - server = grpc.aio.server(options=grpc_options) - fed_pb2_grpc.add_GrpcServiceServicer_to_server( - SendDataService(event, all_data, party, lock), server - ) - - tls_enabled = fed_utils.tls_enabled(tls_config) - if tls_enabled: - ca_cert, private_key, cert_chain = fed_utils.load_cert_config(tls_config) - server_credentials = grpc.ssl_server_credentials( - [(private_key, cert_chain)], - root_certificates=ca_cert, - require_client_auth=ca_cert is not None, - ) - server.add_secure_port(f'[::]:{port}', server_credentials) - else: - server.add_insecure_port(f'[::]:{port}') - - msg = f"Succeeded to add port {port}." - await server.start() - logger.info( - f'Successfully start Grpc service with{"out" if not tls_enabled else ""} ' - 'credentials.' - ) - server_ready_future.set_result((True, msg)) - await server.wait_for_termination() - - -async def send_data_grpc( - data, - stub, - upstream_seq_id, - downstream_seq_id, - metadata=None, -): - cluster_config = fed_config.get_cluster_config() - data = cloudpickle.dumps(data) - request = fed_pb2.SendDataRequest( - data=data, - upstream_seq_id=str(upstream_seq_id), - downstream_seq_id=str(downstream_seq_id), - ) - # Waiting for the reply from downstream. - response = await stub.SendData( - request, - metadata=fed_utils.dict2tuple(metadata), - timeout=cluster_config.cross_silo_timeout, - ) - logger.debug( - f'Received data response from seq_id {downstream_seq_id}, ' - f'result: {response.result}.' - ) - return response.result - - @ray.remote class SendProxyActor: def __init__( @@ -165,7 +62,7 @@ def __init__( party: str, tls_config: Dict = None, logging_level: str = None, - retry_policy: Dict = None, + proxy_cls=None ): setup_logger( logging_level=logging_level, @@ -173,19 +70,18 @@ def __init__( date_format=constants.RAYFED_DATE_FMT, party_val=party, ) + self._stats = {"send_op_count": 0} self._cluster = cluster self._party = party self._tls_config = tls_config - self.retry_policy = retry_policy - self._grpc_metadata = fed_config.get_job_config().grpc_metadata - # Mapping the destination party name to the reused client stub. - self._stubs = {} - cluster_config = fed_config.get_cluster_config() - set_max_message_length(cluster_config.cross_silo_messages_max_size) + cross_silo_msg_config = fed_config.get_job_config().cross_silo_msg_config + self._proxy_instance: SendProxy = proxy_cls( + cluster, party, tls_config, cross_silo_msg_config) async def is_ready(self): - return True + res = await self._proxy_instance.is_ready() + return res async def send( self, @@ -206,72 +102,24 @@ async def send( f'Sending {send_log_msg} with{"out" if not self._tls_config else ""}' ' credentials.' ) - dest_addr = self._cluster[dest_party]['address'] - dest_party_grpc_config = self.setup_grpc_config(dest_party) try: - tls_enabled = fed_utils.tls_enabled(self._tls_config) - grpc_options = dest_party_grpc_config['grpc_options'] - grpc_options = get_grpc_options(retry_policy=self.retry_policy) if \ - grpc_options is None else fed_utils.dict2tuple(grpc_options) - - if dest_party not in self._stubs: - if tls_enabled: - ca_cert, private_key, cert_chain = fed_utils.load_cert_config( - self._tls_config) - credentials = grpc.ssl_channel_credentials( - certificate_chain=cert_chain, - private_key=private_key, - root_certificates=ca_cert, - ) - channel = grpc.aio.secure_channel( - dest_addr, credentials, options=grpc_options) - else: - channel = grpc.aio.insecure_channel(dest_addr, options=grpc_options) - stub = fed_pb2_grpc.GrpcServiceStub(channel) - self._stubs[dest_party] = stub - - response = await send_data_grpc( - data=data, - stub=self._stubs[dest_party], - upstream_seq_id=upstream_seq_id, - downstream_seq_id=downstream_seq_id, - metadata=dest_party_grpc_config['grpc_metadata'], - ) + response = await self._proxy_instance.send( + dest_party, data, upstream_seq_id, downstream_seq_id) except Exception as e: logger.error(f'Failed to {send_log_msg}, error: {e}') return False logger.debug(f"Succeeded to send {send_log_msg}. Response is {response}") return True # True indicates it's sent successfully. - def setup_grpc_config(self, dest_party): - dest_party_grpc_config = {} - global_grpc_metadata = ( - dict(self._grpc_metadata) if self._grpc_metadata is not None else {} - ) - dest_party_grpc_metadata = dict( - self._cluster[dest_party].get('grpc_metadata', {}) - ) - # merge grpc metadata - dest_party_grpc_config['grpc_metadata'] = { - **global_grpc_metadata, **dest_party_grpc_metadata} - - global_grpc_options = dict(get_grpc_options(self.retry_policy)) - dest_party_grpc_options = dict( - self._cluster[dest_party].get('grpc_options', {}) - ) - dest_party_grpc_config['grpc_options'] = { - **global_grpc_options, **dest_party_grpc_options} - return dest_party_grpc_config - async def _get_stats(self): return self._stats - async def _get_grpc_options(self): - return get_grpc_options() - async def _get_cluster_info(self): return self._cluster + async def _get_proxy_config(self, dest_party=None): + return await self._proxy_instance.get_proxy_config(dest_party) + @ray.remote class RecverProxyActor: @@ -281,7 +129,7 @@ def __init__( party: str, logging_level: str, tls_config=None, - retry_policy: Dict = None, + proxy_cls=None, ): setup_logger( logging_level=logging_level, @@ -293,70 +141,28 @@ def __init__( self._listen_addr = listen_addr self._party = party self._tls_config = tls_config - self.retry_policy = retry_policy - config = fed_config.get_cluster_config() - set_max_message_length(config.cross_silo_messages_max_size) - # Workaround the threading coordinations - - # Flag to see whether grpc server starts - self._server_ready_future = asyncio.Future() + cross_silo_msg_config = fed_config.get_job_config().cross_silo_msg_config + self._proxy_instance: RecvProxy = proxy_cls( + listen_addr, party, tls_config, cross_silo_msg_config) - # All events for grpc waitting usage. - self._events = {} # map from (upstream_seq_id, downstream_seq_id) to event - self._all_data = {} # map from (upstream_seq_id, downstream_seq_id) to data - self._lock = threading.Lock() - - async def run_grpc_server(self): - try: - port = self._listen_addr[self._listen_addr.index(':') + 1 :] - await _run_grpc_server( - port, - self._events, - self._all_data, - self._party, - self._lock, - self._server_ready_future, - self._tls_config, - get_grpc_options(self.retry_policy), - ) - except RuntimeError as err: - msg = f'Grpc server failed to listen to port: {port}' \ - f' Try another port by setting `listen_addr` into `cluster` config' \ - f' when calling `fed.init`. Grpc error msg: {err}' - self._server_ready_future.set_result((False, msg)) + async def start(self): + await self._proxy_instance.start() async def is_ready(self): - await self._server_ready_future - return self._server_ready_future.result() + res = await self._proxy_instance.is_ready() + return res - async def get_data(self, src_aprty, upstream_seq_id, curr_seq_id): + async def get_data(self, src_party, upstream_seq_id, curr_seq_id): self._stats["receive_op_count"] += 1 - data_log_msg = f"data for {curr_seq_id} from {upstream_seq_id} of {src_aprty}" - logger.debug(f"Getting {data_log_msg}") - with self._lock: - if not key_exists_in_two_dim_dict( - self._events, upstream_seq_id, curr_seq_id - ): - add_two_dim_dict( - self._events, upstream_seq_id, curr_seq_id, asyncio.Event() - ) - curr_event = get_from_two_dim_dict(self._events, upstream_seq_id, curr_seq_id) - await curr_event.wait() - logging.debug(f"Waited {data_log_msg}.") - with self._lock: - data = pop_from_two_dim_dict(self._all_data, upstream_seq_id, curr_seq_id) - pop_from_two_dim_dict(self._events, upstream_seq_id, curr_seq_id) - - # NOTE(qwang): This is used to avoid the conflict with pickle5 in Ray. - import fed._private.serialization_utils as fed_ser_utils - fed_ser_utils._apply_loads_function_with_whitelist() - return cloudpickle.loads(data) + data = await self._proxy_instance.get_data( + src_party, upstream_seq_id, curr_seq_id) + return data async def _get_stats(self): return self._stats - async def _get_grpc_options(self): - return get_grpc_options() + async def _get_proxy_config(self): + return await self._proxy_instance.get_proxy_config() _DEFAULT_RECV_PROXY_OPTIONS = { @@ -369,8 +175,8 @@ def start_recv_proxy( party: str, logging_level: str, tls_config=None, - retry_policy=None, - actor_config: Optional[fed_config.ProxyActorConfig] = None + proxy_cls=None, + proxy_config: Optional[fed_config.CrossSiloMsgConfig] = None ): # Create RecevrProxyActor @@ -382,8 +188,8 @@ def start_recv_proxy( listen_addr = party_addr['address'] actor_options = copy.deepcopy(_DEFAULT_RECV_PROXY_OPTIONS) - if actor_config is not None and actor_config.resource_label is not None: - actor_options.update({"resources": actor_config.resource_label}) + if proxy_config is not None and proxy_config.recv_resource_label is not None: + actor_options.update({"resources": proxy_config.recv_resource_label}) logger.debug(f"Starting RecvProxyActor with options: {actor_options}") @@ -394,10 +200,10 @@ def start_recv_proxy( party=party, tls_config=tls_config, logging_level=logging_level, - retry_policy=retry_policy, + proxy_cls=proxy_cls ) - recver_proxy_actor.run_grpc_server.remote() - timeout = get_cluster_config().cross_silo_timeout + recver_proxy_actor.start.remote() + timeout = proxy_config.timeout_in_ms / 1000 if proxy_config is not None else 60 server_state = ray.get(recver_proxy_actor.is_ready.remote(), timeout=timeout) assert server_state[0], server_state[1] logger.info("RecverProxy has successfully created.") @@ -414,21 +220,20 @@ def start_send_proxy( party: str, logging_level: str, tls_config: Dict = None, - retry_policy=None, - max_retries=None, - actor_config: Optional[fed_config.ProxyActorConfig] = None + proxy_cls=None, + proxy_config: Optional[fed_config.CrossSiloMsgConfig] = None ): # Create SendProxyActor global _SEND_PROXY_ACTOR actor_options = copy.deepcopy(_DEFAULT_SEND_PROXY_OPTIONS) - if max_retries is not None: + if proxy_config and proxy_config.proxy_max_restarts: actor_options.update({ - "max_task_retries": max_retries, + "max_task_retries": proxy_config.proxy_max_restarts, "max_restarts": 1, }) - if actor_config is not None and actor_config.resource_label is not None: - actor_options.update({"resources": actor_config.resource_label}) + if proxy_config and proxy_config.send_resource_label: + actor_options.update({"resources": proxy_config.send_resource_label}) logger.debug(f"Starting SendProxyActor with options: {actor_options}") _SEND_PROXY_ACTOR = SendProxyActor.options( @@ -439,9 +244,9 @@ def start_send_proxy( party=party, tls_config=tls_config, logging_level=logging_level, - retry_policy=retry_policy, + proxy_cls=proxy_cls ) - timeout = get_cluster_config().cross_silo_timeout + timeout = get_job_config().cross_silo_msg_config.timeout_in_ms / 1000 assert ray.get(_SEND_PROXY_ACTOR.is_ready.remote(), timeout=timeout) logger.info("SendProxyActor has successfully created.") diff --git a/fed/proxy/base_proxy.py b/fed/proxy/base_proxy.py new file mode 100644 index 00000000..51c0a2fa --- /dev/null +++ b/fed/proxy/base_proxy.py @@ -0,0 +1,80 @@ +# Copyright 2023 The RayFed Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Dict + +from fed.config import CrossSiloMsgConfig + + +class SendProxy(abc.ABC): + def __init__( + self, + cluster: Dict, + party: str, + tls_config: Dict, + proxy_config: CrossSiloMsgConfig = None + ) -> None: + self._cluster = cluster + self._party = party + self._tls_config = tls_config + self._proxy_config = proxy_config + + @abc.abstractmethod + async def send( + self, + dest_party, + data, + upstream_seq_id, + downstream_seq_id + ): + pass + + async def is_ready(self): + return True + + async def get_proxy_config(self, dest_party=None): + return self._proxy_config + + +class RecvProxy(abc.ABC): + def __init__( + self, + listen_addr: str, + party: str, + tls_config: Dict, + proxy_config: CrossSiloMsgConfig = None + ) -> None: + self._listen_addr = listen_addr + self._party = party + self._tls_config = tls_config + self._proxy_config = proxy_config + + @abc.abstractmethod + def start(self): + pass + + @abc.abstractmethod + async def get_data( + self, + src_party, + upstream_seq_id, + curr_seq_id): + pass + + async def is_ready(self): + return True + + async def get_proxy_config(self): + return self._proxy_config diff --git a/fed/_private/grpc_options.py b/fed/proxy/grpc/grpc_options.py similarity index 61% rename from fed/_private/grpc_options.py rename to fed/proxy/grpc/grpc_options.py index 9b7d30fd..6e4b2d14 100644 --- a/fed/_private/grpc_options.py +++ b/fed/proxy/grpc/grpc_options.py @@ -14,7 +14,10 @@ import json -_GRPC_RETRY_POLICY = { + +_GRPC_SERVICE = "GrpcService" + +_DEFAULT_GRPC_RETRY_POLICY = { "maxAttempts": 5, "initialBackoff": "5s", "maxBackoff": "30s", @@ -22,49 +25,38 @@ "retryableStatusCodes": ["UNAVAILABLE"], } -_GRPC_SERVICE = "GrpcService" _DEFAULT_GRPC_MAX_SEND_MESSAGE_LENGTH = 500 * 1024 * 1024 _DEFAULT_GRPC_MAX_RECEIVE_MESSAGE_LENGTH = 500 * 1024 * 1024 -_GRPC_MAX_SEND_MESSAGE_LENGTH = _DEFAULT_GRPC_MAX_SEND_MESSAGE_LENGTH -_GRPC_MAX_RECEIVE_MESSAGE_LENGTH = _DEFAULT_GRPC_MAX_RECEIVE_MESSAGE_LENGTH - - -def set_max_message_length(max_size_in_bytes): - """Set the maximum length in bytes of gRPC messages. - - NOTE: The default maximum length is 500MB(500 * 1024 * 1024) - """ - global _GRPC_MAX_SEND_MESSAGE_LENGTH - global _GRPC_MAX_RECEIVE_MESSAGE_LENGTH - if not max_size_in_bytes: - return - if max_size_in_bytes < 0: - raise ValueError("Negative max size is not allowed") - _GRPC_MAX_SEND_MESSAGE_LENGTH = max_size_in_bytes - _GRPC_MAX_RECEIVE_MESSAGE_LENGTH = max_size_in_bytes - - -def get_grpc_max_send_message_length(): - global _GRPC_MAX_SEND_MESSAGE_LENGTH - return _GRPC_MAX_SEND_MESSAGE_LENGTH - - -def get_grpc_max_recieve_message_length(): - global _GRPC_MAX_SEND_MESSAGE_LENGTH - return _GRPC_MAX_SEND_MESSAGE_LENGTH +_DEFAULT_GRPC_CHANNEL_OPTIONS = { + 'grpc.enable_retries': 1, + 'grpc.so_reuseport': 0, + 'grpc.max_send_message_length': _DEFAULT_GRPC_MAX_SEND_MESSAGE_LENGTH, + 'grpc.max_receive_message_length': _DEFAULT_GRPC_MAX_RECEIVE_MESSAGE_LENGTH, + 'grpc.service_config': + json.dumps( + { + 'methodConfig': [ + { + 'name': [{'service': _GRPC_SERVICE}], + 'retryPolicy': _DEFAULT_GRPC_RETRY_POLICY, + } + ] + } + ), +} def get_grpc_options( retry_policy=None, max_send_message_length=None, max_receive_message_length=None ): if not retry_policy: - retry_policy = _GRPC_RETRY_POLICY + retry_policy = _DEFAULT_GRPC_RETRY_POLICY if not max_send_message_length: - max_send_message_length = get_grpc_max_send_message_length() + max_send_message_length = _DEFAULT_GRPC_MAX_SEND_MESSAGE_LENGTH if not max_receive_message_length: - max_receive_message_length = get_grpc_max_recieve_message_length() + max_receive_message_length = _DEFAULT_GRPC_MAX_RECEIVE_MESSAGE_LENGTH return [ ( diff --git a/fed/proxy/grpc/grpc_proxy.py b/fed/proxy/grpc/grpc_proxy.py new file mode 100644 index 00000000..2502c8f4 --- /dev/null +++ b/fed/proxy/grpc/grpc_proxy.py @@ -0,0 +1,331 @@ +# Copyright 2023 The RayFed Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import copy +import cloudpickle +import grpc +import logging +import threading +import json +from typing import Dict + + +import fed.utils as fed_utils + +from fed.config import CrossSiloMsgConfig, GrpcCrossSiloMsgConfig +import fed._private.compatible_utils as compatible_utils +from fed.proxy.grpc.grpc_options import _DEFAULT_GRPC_CHANNEL_OPTIONS, _GRPC_SERVICE +from fed.proxy.barriers import ( + add_two_dim_dict, + get_from_two_dim_dict, + pop_from_two_dim_dict, + key_exists_in_two_dim_dict, +) +from fed.proxy.base_proxy import SendProxy, RecvProxy +if compatible_utils._compare_version_strings( + fed_utils.get_package_version('protobuf'), '4.0.0'): + from fed.grpc import fed_pb2_in_protobuf4 as fed_pb2 + from fed.grpc import fed_pb2_grpc_in_protobuf4 as fed_pb2_grpc +else: + from fed.grpc import fed_pb2_in_protobuf3 as fed_pb2 + from fed.grpc import fed_pb2_grpc_in_protobuf3 as fed_pb2_grpc + + +logger = logging.getLogger(__name__) + + +def parse_grpc_options(proxy_config: CrossSiloMsgConfig): + """ + Extract certain fields in `CrossSiloGrpcCommConfig` into the + "grpc_channel_options". Note that the resulting dict's key + may not be identical to the config name, but a grpc-supported + option name. + + Args: + proxy_config (CrossSiloMsgConfig): The proxy configuration + from which to extract the gRPC options. + + Returns: + dict: A dictionary containing the gRPC channel options. + """ + grpc_channel_options = {} + if proxy_config is not None and isinstance( + proxy_config, GrpcCrossSiloMsgConfig): + if isinstance(proxy_config, GrpcCrossSiloMsgConfig): + if proxy_config.grpc_channel_options is not None: + grpc_channel_options.update(proxy_config.grpc_channel_options) + if proxy_config.grpc_retry_policy is not None: + grpc_channel_options.update({ + 'grpc.service_config': + json.dumps( + { + 'methodConfig': [ + { + 'name': [{'service': _GRPC_SERVICE}], + 'retryPolicy': proxy_config.grpc_retry_policy, + } + ] + } + ), + }) + + return grpc_channel_options + + +class GrpcSendProxy(SendProxy): + def __init__( + self, + cluster: Dict, + party: str, + tls_config: Dict, + proxy_config: CrossSiloMsgConfig = None + ) -> None: + super().__init__(cluster, party, tls_config, proxy_config) + self._grpc_metadata = proxy_config.http_header or {} + self._grpc_options = copy.deepcopy(_DEFAULT_GRPC_CHANNEL_OPTIONS) + self._grpc_options.update(parse_grpc_options(self._proxy_config)) + # Mapping the destination party name to the reused client stub. + self._stubs = {} + + async def send( + self, + dest_party, + data, + upstream_seq_id, + downstream_seq_id): + dest_addr = self._cluster[dest_party]['address'] + grpc_metadata, grpc_channel_options = self.get_grpc_config_by_party(dest_party) + tls_enabled = fed_utils.tls_enabled(self._tls_config) + if dest_party not in self._stubs: + if tls_enabled: + ca_cert, private_key, cert_chain = fed_utils.load_cert_config( + self._tls_config) + credentials = grpc.ssl_channel_credentials( + certificate_chain=cert_chain, + private_key=private_key, + root_certificates=ca_cert, + ) + channel = grpc.aio.secure_channel( + dest_addr, credentials, options=grpc_channel_options) + else: + channel = grpc.aio.insecure_channel( + dest_addr, options=grpc_channel_options) + stub = fed_pb2_grpc.GrpcServiceStub(channel) + self._stubs[dest_party] = stub + + timeout = self._proxy_config.timeout_in_ms / 1000 + response = await send_data_grpc( + data=data, + stub=self._stubs[dest_party], + upstream_seq_id=upstream_seq_id, + downstream_seq_id=downstream_seq_id, + timeout=timeout, + metadata=grpc_metadata, + ) + return response + + def get_grpc_config_by_party(self, dest_party): + """Overide global config by party specific config + """ + grpc_metadata = self._grpc_metadata + grpc_options = self._grpc_options + + dest_party_msg_config = self._cluster[dest_party].get( + 'cross_silo_msg_config', None) + if dest_party_msg_config is not None: + if dest_party_msg_config.http_header is not None: + dest_party_grpc_metadata = dict(dest_party_msg_config.http_header) + grpc_metadata = { + **grpc_metadata, + **dest_party_grpc_metadata + } + dest_party_grpc_options = parse_grpc_options(dest_party_msg_config) + grpc_options = { + **grpc_options, **dest_party_grpc_options + } + return grpc_metadata, fed_utils.dict2tuple(grpc_options) + + async def get_proxy_config(self, dest_party=None): + if dest_party is None: + grpc_options = fed_utils.dict2tuple(self._grpc_options) + else: + _, grpc_options = self.get_grpc_config_by_party(dest_party) + proxy_config = self._proxy_config.__dict__ + proxy_config.update({'grpc_options': grpc_options}) + return proxy_config + + +async def send_data_grpc( + data, + stub, + upstream_seq_id, + downstream_seq_id, + timeout, + metadata=None, +): + data = cloudpickle.dumps(data) + request = fed_pb2.SendDataRequest( + data=data, + upstream_seq_id=str(upstream_seq_id), + downstream_seq_id=str(downstream_seq_id), + ) + # Waiting for the reply from downstream. + response = await stub.SendData( + request, + metadata=fed_utils.dict2tuple(metadata), + timeout=timeout, + ) + logger.debug( + f'Received data response from seq_id {downstream_seq_id}, ' + f'result: {response.result}.' + ) + return response.result + + +class GrpcRecvProxy(RecvProxy): + def __init__( + self, + listen_addr: str, + party: str, + tls_config: Dict, + proxy_config: CrossSiloMsgConfig + ) -> None: + super().__init__(listen_addr, party, tls_config, proxy_config) + self._grpc_options = copy.deepcopy(_DEFAULT_GRPC_CHANNEL_OPTIONS) + self._grpc_options.update(parse_grpc_options(self._proxy_config)) + + # Flag to see whether grpc server starts + self._server_ready_future = asyncio.Future() + + # All events for grpc waitting usage. + self._events = {} # map from (upstream_seq_id, downstream_seq_id) to event + self._all_data = {} # map from (upstream_seq_id, downstream_seq_id) to data + self._lock = threading.Lock() + + async def start(self): + port = self._listen_addr[self._listen_addr.index(':') + 1 :] + try: + await _run_grpc_server( + port, + self._events, + self._all_data, + self._party, + self._lock, + self._server_ready_future, + self._tls_config, + fed_utils.dict2tuple(self._grpc_options), + ) + except RuntimeError as err: + msg = f'Grpc server failed to listen to port: {port}' \ + f' Try another port by setting `listen_addr` into `cluster` config' \ + f' when calling `fed.init`. Grpc error msg: {err}' + self._server_ready_future.set_result((False, msg)) + + async def is_ready(self): + await self._server_ready_future + res = self._server_ready_future.result() + return res + + async def get_data(self, src_party, upstream_seq_id, curr_seq_id): + data_log_msg = f"data for {curr_seq_id} from {upstream_seq_id} of {src_party}" + logger.debug(f"Getting {data_log_msg}") + with self._lock: + if not key_exists_in_two_dim_dict( + self._events, upstream_seq_id, curr_seq_id + ): + add_two_dim_dict( + self._events, upstream_seq_id, curr_seq_id, asyncio.Event() + ) + curr_event = get_from_two_dim_dict(self._events, upstream_seq_id, curr_seq_id) + await curr_event.wait() + logger.debug(f"Waited {data_log_msg}.") + with self._lock: + data = pop_from_two_dim_dict(self._all_data, upstream_seq_id, curr_seq_id) + pop_from_two_dim_dict(self._events, upstream_seq_id, curr_seq_id) + + # NOTE(qwang): This is used to avoid the conflict with pickle5 in Ray. + import fed._private.serialization_utils as fed_ser_utils + fed_ser_utils._apply_loads_function_with_whitelist() + return cloudpickle.loads(data) + + async def get_proxy_config(self): + proxy_config = self._proxy_config.__dict__ + proxy_config.update({'grpc_options': fed_utils.dict2tuple(self._grpc_options)}) + return proxy_config + + +class SendDataService(fed_pb2_grpc.GrpcServiceServicer): + def __init__(self, all_events, all_data, party, lock): + self._events = all_events + self._all_data = all_data + self._party = party + self._lock = lock + + async def SendData(self, request, context): + upstream_seq_id = request.upstream_seq_id + downstream_seq_id = request.downstream_seq_id + logger.debug( + f'Received a grpc data request from {upstream_seq_id} to ' + f'{downstream_seq_id}.' + ) + + with self._lock: + add_two_dim_dict( + self._all_data, upstream_seq_id, downstream_seq_id, request.data + ) + if not key_exists_in_two_dim_dict( + self._events, upstream_seq_id, downstream_seq_id + ): + event = asyncio.Event() + add_two_dim_dict( + self._events, upstream_seq_id, downstream_seq_id, event + ) + event = get_from_two_dim_dict(self._events, upstream_seq_id, downstream_seq_id) + event.set() + logger.debug(f"Event set for {upstream_seq_id}") + return fed_pb2.SendDataResponse(result="OK") + + +async def _run_grpc_server( + port, event, all_data, party, lock, + server_ready_future, tls_config=None, grpc_options=None +): + print(f"ReceiveProxy binding port {port}, options: {grpc_options}...") + server = grpc.aio.server(options=grpc_options) + fed_pb2_grpc.add_GrpcServiceServicer_to_server( + SendDataService(event, all_data, party, lock), server + ) + + tls_enabled = fed_utils.tls_enabled(tls_config) + if tls_enabled: + ca_cert, private_key, cert_chain = fed_utils.load_cert_config(tls_config) + server_credentials = grpc.ssl_server_credentials( + [(private_key, cert_chain)], + root_certificates=ca_cert, + require_client_auth=ca_cert is not None, + ) + server.add_secure_port(f'[::]:{port}', server_credentials) + else: + server.add_insecure_port(f'[::]:{port}') + # server.add_insecure_port(f'[::]:{port}') + + msg = f"Succeeded to add port {port}." + await server.start() + logger.info( + f'Successfully start Grpc service with{"out" if not tls_enabled else ""} ' + 'credentials.' + ) + server_ready_future.set_result((True, msg)) + await server.wait_for_termination() diff --git a/tests/serializations_tests/test_unpickle_with_whitelist.py b/tests/serializations_tests/test_unpickle_with_whitelist.py index c4d0a458..ec4f8463 100644 --- a/tests/serializations_tests/test_unpickle_with_whitelist.py +++ b/tests/serializations_tests/test_unpickle_with_whitelist.py @@ -19,6 +19,8 @@ import multiprocessing import numpy +from fed.config import CrossSiloMsgConfig + @fed.remote def generate_wrong_type(): @@ -51,7 +53,9 @@ def run(party): fed.init( cluster=cluster, party=party, - cross_silo_serializing_allowed_list=allowed_list) + global_cross_silo_msg_config=CrossSiloMsgConfig( + serializing_allowed_list=allowed_list + )) # Test passing an allowed type. o1 = generate_allowed_type.party("alice").remote() diff --git a/tests/test_exit_on_failure_sending.py b/tests/test_exit_on_failure_sending.py index 1479a7d6..33b3f98c 100644 --- a/tests/test_exit_on_failure_sending.py +++ b/tests/test_exit_on_failure_sending.py @@ -19,6 +19,8 @@ import fed import fed._private.compatible_utils as compatible_utils +from fed.config import GrpcCrossSiloMsgConfig + import signal import os @@ -61,12 +63,15 @@ def run(party, is_inner_party): "backoffMultiplier": 1, "retryableStatusCodes": ["UNAVAILABLE"], } + cross_silo_msg_config = GrpcCrossSiloMsgConfig( + grpc_retry_policy=retry_policy, + exit_on_sending_failure=True + ) fed.init( cluster=cluster, party=party, logging_level='debug', - cross_silo_grpc_retry_policy=retry_policy, - exit_on_failure_cross_silo_sending=True, + global_cross_silo_msg_config=cross_silo_msg_config ) o = f.party("alice").remote() diff --git a/tests/test_grpc_options_on_proxies.py b/tests/test_grpc_options_on_proxies.py index 993167ce..34735b8d 100644 --- a/tests/test_grpc_options_on_proxies.py +++ b/tests/test_grpc_options_on_proxies.py @@ -18,6 +18,8 @@ import fed._private.compatible_utils as compatible_utils import ray +from fed.config import GrpcCrossSiloMsgConfig + @fed.remote def dummpy(): @@ -33,13 +35,17 @@ def run(party): fed.init( cluster=cluster, party=party, - cross_silo_messages_max_size_in_bytes=100, + global_cross_silo_msg_config=GrpcCrossSiloMsgConfig( + grpc_channel_options=[( + 'grpc.max_send_message_length', 100 + )] + ) ) def _assert_on_proxy(proxy_actor): - options = ray.get(proxy_actor._get_grpc_options.remote()) - assert options[0][0] == "grpc.max_send_message_length" - assert options[0][1] == 100 + config = ray.get(proxy_actor._get_proxy_config.remote()) + options = config['grpc_options'] + assert ("grpc.max_send_message_length", 100) in options assert ('grpc.so_reuseport', 0) in options send_proxy = ray.get_actor("SendProxyActor") diff --git a/tests/test_grpc_options_per_party.py b/tests/test_grpc_options_per_party.py index b81dbc4b..95b1465c 100644 --- a/tests/test_grpc_options_per_party.py +++ b/tests/test_grpc_options_per_party.py @@ -18,6 +18,8 @@ import fed._private.compatible_utils as compatible_utils import ray +from fed.config import GrpcCrossSiloMsgConfig + @fed.remote def dummpy(): @@ -29,38 +31,37 @@ def run(party): cluster = { 'alice': { 'address': '127.0.0.1:11010', - 'grpc_options': [ - ('grpc.default_authority', 'alice'), - ('grpc.max_send_message_length', 200) - ] + 'cross_silo_msg_config': GrpcCrossSiloMsgConfig( + grpc_channel_options=[ + ('grpc.default_authority', 'alice'), + ('grpc.max_send_message_length', 200) + ]) }, 'bob': {'address': '127.0.0.1:11011'}, } fed.init( cluster=cluster, party=party, - cross_silo_messages_max_size_in_bytes=100, + global_cross_silo_msg_config=GrpcCrossSiloMsgConfig( + grpc_channel_options=[( + 'grpc.max_send_message_length', 100 + )] + ) ) def _assert_on_send_proxy(proxy_actor): - alice_config = ray.get(proxy_actor.setup_grpc_config.remote('alice')) + alice_config = ray.get(proxy_actor._get_proxy_config.remote('alice')) # print(f"【NKcqx】alice config: {alice_config}") assert 'grpc_options' in alice_config alice_options = alice_config['grpc_options'] - assert 'grpc.max_send_message_length' in alice_options - # This should be overwritten by cluster config - assert alice_options['grpc.max_send_message_length'] == 200 - assert 'grpc.default_authority' in alice_options - assert alice_options['grpc.default_authority'] == 'alice' - - bob_config = ray.get(proxy_actor.setup_grpc_config.remote('bob')) - # print(f"【NKcqx】bob config: {bob_config}") + assert ('grpc.max_send_message_length', 200) in alice_options + assert ('grpc.default_authority', 'alice') in alice_options + + bob_config = ray.get(proxy_actor._get_proxy_config.remote('bob')) assert 'grpc_options' in bob_config bob_options = bob_config['grpc_options'] - assert "grpc.max_send_message_length" in bob_options - # Not setting bob's grpc_options, should be the same with global - assert bob_options["grpc.max_send_message_length"] == 100 - assert 'grpc.default_authority' not in bob_options + assert ('grpc.max_send_message_length', 100) in bob_options + assert not any(o[0] == 'grpc.default_authority' for o in bob_options) send_proxy = ray.get_actor("SendProxyActor") _assert_on_send_proxy(send_proxy) @@ -83,6 +84,72 @@ def test_grpc_options(): assert p_alice.exitcode == 0 and p_bob.exitcode == 0 +def party_grpc_options(party): + compatible_utils.init_ray(address='local') + cluster = { + 'alice': { + 'address': '127.0.0.1:11010', + 'cross_silo_msg_config': GrpcCrossSiloMsgConfig( + grpc_channel_options=[ + ('grpc.default_authority', 'alice'), + ('grpc.max_send_message_length', 51 * 1024 * 1024) + ]) + }, + 'bob': { + 'address': '127.0.0.1:11011', + 'cross_silo_msg_config': GrpcCrossSiloMsgConfig( + grpc_channel_options=[ + ('grpc.default_authority', 'bob'), + ('grpc.max_send_message_length', 50 * 1024 * 1024) + ]) + }, + } + fed.init( + cluster=cluster, + party=party, + global_cross_silo_msg_config=GrpcCrossSiloMsgConfig( + grpc_channel_options=[( + 'grpc.max_send_message_length', 100 + )] + ) + ) + + def _assert_on_send_proxy(proxy_actor): + alice_config = ray.get(proxy_actor._get_proxy_config.remote('alice')) + assert 'grpc_options' in alice_config + alice_options = alice_config['grpc_options'] + assert ('grpc.max_send_message_length', 51 * 1024 * 1024) in alice_options + assert ('grpc.default_authority', 'alice') in alice_options + + bob_config = ray.get(proxy_actor._get_proxy_config.remote('bob')) + assert 'grpc_options' in bob_config + bob_options = bob_config['grpc_options'] + assert ('grpc.max_send_message_length', 50 * 1024 * 1024) in bob_options + assert ('grpc.default_authority', 'bob') in bob_options + + send_proxy = ray.get_actor("SendProxyActor") + _assert_on_send_proxy(send_proxy) + + a = dummpy.party('alice').remote() + b = dummpy.party('bob').remote() + fed.get([a, b]) + + fed.shutdown() + ray.shutdown() + + +def test_party_specific_grpc_options(): + p_alice = multiprocessing.Process( + target=party_grpc_options, args=('alice',)) + p_bob = multiprocessing.Process( + target=party_grpc_options, args=('bob',)) + p_alice.start() + p_bob.start() + p_alice.join() + p_bob.join() + assert p_alice.exitcode == 0 and p_bob.exitcode == 0 + + if __name__ == "__main__": import sys diff --git a/tests/test_listen_addr.py b/tests/test_listen_addr.py index 72753e2a..05960e2f 100644 --- a/tests/test_listen_addr.py +++ b/tests/test_listen_addr.py @@ -72,29 +72,36 @@ def run(party): compatible_utils.init_ray(address='local') occupied_port = 11020 - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + # NOTE(NKcqx): Firstly try to bind IPv6 because the grpc server will do so. + # Otherwise this UT will false because socket bind $occupied_port + # on IPv4 address while grpc server listendn Ipv6 address. + try: + s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) # Pre-occuping the port - s.bind(("localhost", occupied_port)) - - cluster = { - 'alice': { - 'address': '127.0.0.1:11012', - 'listen_addr': f'0.0.0.0:{occupied_port}'}, - 'bob': { - 'address': '127.0.0.1:11011', - 'listen_addr': '0.0.0.0:11011'}, - } - - # Starting grpc server on an used port will cause AssertionError - with pytest.raises(AssertionError): - fed.init(cluster=cluster, party=party) - - import time - - time.sleep(5) - s.close() - fed.shutdown() - ray.shutdown() + s.bind(("::", occupied_port)) + except OSError: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("127.0.0.1", occupied_port)) + + cluster = { + 'alice': { + 'address': '127.0.0.1:11012', + 'listen_addr': f'0.0.0.0:{occupied_port}'}, + 'bob': { + 'address': '127.0.0.1:11011', + 'listen_addr': '0.0.0.0:11011'}, + } + + # Starting grpc server on an used port will cause AssertionError + with pytest.raises(AssertionError): + fed.init(cluster=cluster, party=party) + + import time + + time.sleep(5) + s.close() + fed.shutdown() + ray.shutdown() p_alice = multiprocessing.Process(target=run, args=('alice',)) p_alice.start() @@ -103,6 +110,7 @@ def run(party): if __name__ == "__main__": - import sys + # import sys - sys.exit(pytest.main(["-sv", __file__])) + # sys.exit(pytest.main(["-sv", __file__])) + test_listen_used_addr() diff --git a/tests/test_party_specific_grpc_options.py b/tests/test_party_specific_grpc_options.py deleted file mode 100644 index 12936b72..00000000 --- a/tests/test_party_specific_grpc_options.py +++ /dev/null @@ -1,75 +0,0 @@ -import multiprocessing -import pytest -import fed -import fed._private.compatible_utils as compatible_utils -import ray - - -@fed.remote -def dummpy(): - return 2 - - -def party_grpc_options(party): - compatible_utils.init_ray(address='local') - cluster = { - 'alice': { - 'address': '127.0.0.1:11010', - 'grpc_channel_option': [ - ('grpc.default_authority', 'alice'), - ('grpc.max_send_message_length', 51 * 1024 * 1024) - ]}, - 'bob': { - 'address': '127.0.0.1:11011', - 'grpc_channel_option': [ - ('grpc.default_authority', 'bob'), - ('grpc.max_send_message_length', 50 * 1024 * 1024) - ]}, - } - fed.init( - cluster=cluster, - party=party, - cross_silo_messages_max_size_in_bytes=100 - ) - - def _assert_on_proxy(proxy_actor): - cluster_info = ray.get(proxy_actor._get_cluster_info.remote()) - assert cluster_info['alice'] is not None - assert cluster_info['alice']['grpc_channel_option'] is not None - alice_channel_options = cluster_info['alice']['grpc_channel_option'] - assert ('grpc.default_authority', 'alice') in alice_channel_options - assert ('grpc.max_send_message_length', 51 * 1024 * 1024) in alice_channel_options # noqa - - assert cluster_info['bob'] is not None - assert cluster_info['bob']['grpc_channel_option'] is not None - bob_channel_options = cluster_info['bob']['grpc_channel_option'] - assert ('grpc.default_authority', 'bob') in bob_channel_options - assert ('grpc.max_send_message_length', 50 * 1024 * 1024) in bob_channel_options # noqa - - send_proxy = ray.get_actor("SendProxyActor") - _assert_on_proxy(send_proxy) - - a = dummpy.party('alice').remote() - b = dummpy.party('bob').remote() - fed.get([a, b]) - - fed.shutdown() - ray.shutdown() - - -def test_party_specific_grpc_options(): - p_alice = multiprocessing.Process( - target=party_grpc_options, args=('alice',)) - p_bob = multiprocessing.Process( - target=party_grpc_options, args=('bob',)) - p_alice.start() - p_bob.start() - p_alice.join() - p_bob.join() - assert p_alice.exitcode == 0 and p_bob.exitcode == 0 - - -if __name__ == "__main__": - import sys - - sys.exit(pytest.main(["-sv", __file__])) diff --git a/tests/test_retry_policy.py b/tests/test_retry_policy.py index 7c33be6a..af6d9e83 100644 --- a/tests/test_retry_policy.py +++ b/tests/test_retry_policy.py @@ -20,6 +20,8 @@ import fed._private.compatible_utils as compatible_utils import ray +from fed.config import GrpcCrossSiloMsgConfig + @fed.remote def f(): @@ -51,7 +53,9 @@ def run(party, is_inner_party): fed.init( cluster=cluster, party=party, - cross_silo_grpc_retry_policy=retry_policy, + global_cross_silo_msg_config=GrpcCrossSiloMsgConfig( + grpc_retry_policy=retry_policy + ) ) o = f.party("alice").remote() diff --git a/tests/test_setup_proxy_actor.py b/tests/test_setup_proxy_actor.py index 5b8cf8ba..00215e38 100644 --- a/tests/test_setup_proxy_actor.py +++ b/tests/test_setup_proxy_actor.py @@ -20,6 +20,8 @@ import fed._private.compatible_utils as compatible_utils import ray +from fed.config import CrossSiloMsgConfig + def run(party): compatible_utils.init_ray(address='local', resources={"127.0.0.1": 2}) @@ -63,9 +65,11 @@ def run_failure(party): fed.init( cluster=cluster, party=party, - cross_silo_send_resource_label=send_proxy_resources, - cross_silo_recv_resource_label=recv_proxy_resources, - cross_silo_timeout_in_seconds=10, # Quick fail in test + global_cross_silo_msg_config=CrossSiloMsgConfig( + send_resource_label=send_proxy_resources, + recv_resource_label=recv_proxy_resources, + timeout_in_ms=10*1000, + ) ) fed.shutdown() diff --git a/tests/test_transport_proxy.py b/tests/test_transport_proxy.py index 6559d2a8..54aa3c24 100644 --- a/tests/test_transport_proxy.py +++ b/tests/test_transport_proxy.py @@ -20,8 +20,15 @@ import fed.utils as fed_utils import fed._private.compatible_utils as compatible_utils +from fed.config import CrossSiloMsgConfig, GrpcCrossSiloMsgConfig from fed._private import constants from fed._private import global_context +from fed.proxy.barriers import ( + send, + start_recv_proxy, + start_send_proxy +) +from fed.proxy.grpc.grpc_proxy import GrpcSendProxy, GrpcRecvProxy if compatible_utils._compare_version_strings( fed_utils.get_package_version('protobuf'), '4.0.0'): from fed.grpc import fed_pb2_in_protobuf4 as fed_pb2 @@ -29,7 +36,6 @@ else: from fed.grpc import fed_pb2_in_protobuf3 as fed_pb2 from fed.grpc import fed_pb2_grpc_in_protobuf3 as fed_pb2_grpc -from fed.proxy.barriers import send, start_recv_proxy, start_send_proxy def test_n_to_1_transport(): @@ -44,9 +50,6 @@ def test_n_to_1_transport(): constants.KEY_OF_CLUSTER_ADDRESSES: "", constants.KEY_OF_CURRENT_PARTY_NAME: "", constants.KEY_OF_TLS_CONFIG: "", - constants.KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES: None, - constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST: {}, - constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS: 60, } compatible_utils._init_internal_kv() compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG, @@ -56,12 +59,21 @@ def test_n_to_1_transport(): SERVER_ADDRESS = "127.0.0.1:12344" party = 'test_party' cluster_config = {'test_party': {'address': SERVER_ADDRESS}} + config = GrpcCrossSiloMsgConfig() start_recv_proxy( cluster_config, party, logging_level='info', + proxy_cls=GrpcRecvProxy, + proxy_config=config + ) + start_send_proxy( + cluster_config, + party, + logging_level='info', + proxy_cls=GrpcSendProxy, + proxy_config=config ) - start_send_proxy(cluster_config, party, logging_level='info') sent_objs = [] get_objs = [] @@ -153,7 +165,7 @@ def _test_start_recv_proxy( ).remote( listen_addr=listen_addr, party=party, - expected_metadata=expected_metadata, + expected_metadata=expected_metadata ) recver_proxy_actor.run_grpc_server.remote() assert ray.get(recver_proxy_actor.is_ready.remote()) @@ -165,14 +177,14 @@ def test_send_grpc_with_meta(): constants.KEY_OF_CLUSTER_ADDRESSES: "", constants.KEY_OF_CURRENT_PARTY_NAME: "", constants.KEY_OF_TLS_CONFIG: "", - constants.KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES: None, - constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST: {}, - constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS: 60, } + metadata = {"key": "value"} + send_proxy_config = CrossSiloMsgConfig( + http_header=metadata + ) job_config = { - constants.KEY_OF_GRPC_METADATA: { - "key": "value" - } + constants.KEY_OF_CROSS_SILO_MSG_CONFIG: + send_proxy_config, } compatible_utils._init_internal_kv() compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG, @@ -186,9 +198,14 @@ def test_send_grpc_with_meta(): cluster_config = {'test_party': {'address': SERVER_ADDRESS}} _test_start_recv_proxy( cluster_config, party, logging_level='info', - expected_metadata={"key": "value"}, + expected_metadata=metadata, ) - start_send_proxy(cluster_config, party, logging_level='info') + start_send_proxy( + cluster_config, + party, + logging_level='info', + proxy_cls=GrpcSendProxy, + proxy_config=GrpcCrossSiloMsgConfig()) sent_objs = [] sent_obj = send(party, "data", 0, 1) sent_objs.append(sent_obj) @@ -206,14 +223,12 @@ def test_send_grpc_with_party_specific_meta(): constants.KEY_OF_CLUSTER_ADDRESSES: "", constants.KEY_OF_CURRENT_PARTY_NAME: "", constants.KEY_OF_TLS_CONFIG: "", - constants.KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES: None, - constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST: {}, - constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS: 60, } + send_proxy_config = CrossSiloMsgConfig( + http_header={"key": "value"}) job_config = { - constants.KEY_OF_GRPC_METADATA: { - "key": "value" - } + constants.KEY_OF_CROSS_SILO_MSG_CONFIG: + send_proxy_config, } compatible_utils._init_internal_kv() compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG, @@ -227,14 +242,20 @@ def test_send_grpc_with_party_specific_meta(): cluster_parties_config = { 'test_party': { 'address': SERVER_ADDRESS, - 'grpc_metadata': (('token', 'test-party-token'),) + 'cross_silo_msg_config': CrossSiloMsgConfig( + http_header={"token": "test-party-token"}) } } _test_start_recv_proxy( cluster_parties_config, party, logging_level='info', expected_metadata={"key": "value", "token": "test-party-token"}, ) - start_send_proxy(cluster_parties_config, party, logging_level='info') + start_send_proxy( + cluster_parties_config, + party, + logging_level='info', + proxy_cls=GrpcSendProxy, + proxy_config=send_proxy_config) sent_objs = [] sent_obj = send(party, "data", 0, 1) sent_objs.append(sent_obj) diff --git a/tests/test_transport_proxy_tls.py b/tests/test_transport_proxy_tls.py index 65d24f9b..71f10ef7 100644 --- a/tests/test_transport_proxy_tls.py +++ b/tests/test_transport_proxy_tls.py @@ -22,6 +22,8 @@ from fed._private import constants from fed._private import global_context from fed.proxy.barriers import send, start_recv_proxy, start_send_proxy +from fed.proxy.grpc.grpc_proxy import GrpcSendProxy, GrpcRecvProxy +from fed.config import GrpcCrossSiloMsgConfig def test_n_to_1_transport(): @@ -44,9 +46,6 @@ def test_n_to_1_transport(): constants.KEY_OF_CLUSTER_ADDRESSES: "", constants.KEY_OF_CURRENT_PARTY_NAME: "", constants.KEY_OF_TLS_CONFIG: tls_config, - constants.KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES: None, - constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST: {}, - constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS: 60, } global_context.get_global_context().get_cleanup_manager().start() @@ -58,17 +57,22 @@ def test_n_to_1_transport(): SERVER_ADDRESS = "127.0.0.1:65422" party = 'test_party' cluster_config = {'test_party': {'address': SERVER_ADDRESS}} + config = GrpcCrossSiloMsgConfig() start_recv_proxy( cluster_config, party, logging_level='info', tls_config=tls_config, + proxy_cls=GrpcRecvProxy, + proxy_config=config ) start_send_proxy( cluster_config, party, logging_level='info', tls_config=tls_config, + proxy_cls=GrpcSendProxy, + proxy_config=config ) sent_objs = []