Skip to content

Commit

Permalink
Make proxy's RPC framework replaceable (#140)
Browse files Browse the repository at this point in the history
* union rpc related config into one

Signed-off-by: NKcqx <[email protected]>

* union config in KV

Signed-off-by: NKcqx <[email protected]>

* update docstr

Signed-off-by: NKcqx <[email protected]>

* Pluggable cross_silo rpc impl
---------

Signed-off-by: NKcqx <[email protected]>
Signed-off-by: paer <[email protected]>
Co-authored-by: paer <[email protected]>
  • Loading branch information
NKcqx and paer authored Jul 18, 2023
1 parent 316d493 commit a84a2da
Show file tree
Hide file tree
Showing 18 changed files with 810 additions and 524 deletions.
6 changes: 1 addition & 5 deletions fed/_private/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,7 @@

KEY_OF_TLS_CONFIG = "TLS_CONFIG"

KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST = "CROSS_SILO_SERIALIZING_ALLOWED_LIST" # noqa

KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES = "CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES" # noqa

KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS = "CROSS_SILO_TIMEOUT_IN_SECONDS"
KEY_OF_CROSS_SILO_MSG_CONFIG = "CROSS_SILO_MSG_CONFIG"

RAYFED_LOG_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- %(message)s" # noqa

Expand Down
3 changes: 2 additions & 1 deletion fed/_private/serialization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def find_class(self, module, name):
def _apply_loads_function_with_whitelist():
global _pickle_whitelist

_pickle_whitelist = fed_config.get_cluster_config().serializing_allowed_list
_pickle_whitelist = fed_config.get_job_config() \
.cross_silo_msg_config.serializing_allowed_list
if _pickle_whitelist is None:
return

Expand Down
104 changes: 32 additions & 72 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import functools
import inspect
import logging
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Union, Optional

import cloudpickle
import ray
Expand All @@ -34,6 +34,8 @@
start_recv_proxy,
start_send_proxy,
)
from fed.proxy.grpc.grpc_proxy import SendProxy, RecvProxy
from fed.config import CrossSiloMsgConfig
from fed.fed_object import FedObject
from fed.utils import is_ray_object_refs, setup_logger

Expand All @@ -45,16 +47,10 @@ def init(
party: str = None,
tls_config: Dict = None,
logging_level: str = 'info',
cross_silo_grpc_retry_policy: Dict = None,
cross_silo_send_max_retries: int = None,
cross_silo_serializing_allowed_list: Dict = None,
cross_silo_send_resource_label: Dict = None,
cross_silo_recv_resource_label: Dict = None,
exit_on_failure_cross_silo_sending: bool = False,
cross_silo_messages_max_size_in_bytes: int = None,
cross_silo_timeout_in_seconds: int = 60,
enable_waiting_for_other_parties_ready: bool = False,
grpc_metadata: Dict = None,
send_proxy_cls: SendProxy = None,
recv_proxy_cls: RecvProxy = None,
global_cross_silo_msg_config: Optional[CrossSiloMsgConfig] = None,
**kwargs,
):
"""
Expand All @@ -71,20 +67,15 @@ def init(
# (Optional) the listen address, the `address` will be
# used if not provided.
'listen_addr': '0.0.0.0:10001',
# (Optional) The party specific metadata sent with the grpc request
'grpc_metadata': (('token', 'alice-token'),),
'grpc_options': [
('grpc.default_authority', 'alice'),
('grpc.max_send_message_length', 50 * 1024 * 1024)
]
'cross_silo_msg_config': CrossSiloMsgConfig
},
'bob': {
# The address for other parties.
'address': '127.0.0.1:10002',
# (Optional) the listen address, the `address` will be
# used if not provided.
'listen_addr': '0.0.0.0:10002',
# (Optional) The party specific metadata sent with the grpc request
# (Optional) The party specific metadata sent with grpc requests
'grpc_metadata': (('token', 'bob-token'),),
},
'carol': {
Expand All @@ -93,7 +84,7 @@ def init(
# (Optional) the listen address, the `address` will be
# used if not provided.
'listen_addr': '0.0.0.0:10003',
# (Optional) The party specific metadata sent with the grpc request
# (Optional) The party specific metadata sent with grpc requests
'grpc_metadata': (('token', 'carol-token'),),
},
}
Expand All @@ -116,48 +107,13 @@ def init(
"cert": "bob's server cert",
"key": "bob's server cert key",
}
logging_level: optional; the logging level, could be `debug`, `info`,
`warning`, `error`, `critical`, not case sensititive.
cross_silo_grpc_retry_policy: a dict descibes the retry policy for
cross silo rpc call. If None, the following default retry policy
will be used. More details please refer to
`retry-policy <https://github.com/grpc/proposal/blob/master/A6-client-retries.md#retry-policy>`_. # noqa
.. code:: python
{
"maxAttempts": 4,
"initialBackoff": "0.1s",
"maxBackoff": "1s",
"backoffMultiplier": 2,
"retryableStatusCodes": [
"UNAVAILABLE"
]
}
cross_silo_send_max_retries: the max retries for sending data cross silo.
cross_silo_serializing_allowed_list: The package or class list allowed for
serializing(deserializating) cross silos. It's used for avoiding pickle
deserializing execution attack when crossing solis.
cross_silo_send_resource_label: Customized resource label, the SendProxyActor
will be scheduled based on the declared resource label. For example,
when setting to `{"my_label": 1}`, then the SendProxyActor will be started
only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1.
cross_silo_recv_resource_label: Customized resource label, the RecverProxyActor
will be scheduled based on the declared resource label. For example,
when setting to `{"my_label": 1}`, then the RecverProxyActor will be started
only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1.
exit_on_failure_cross_silo_sending: whether exit when failure on
cross-silo sending. If True, a SIGTERM will be signaled to self
if failed to sending cross-silo data.
cross_silo_messages_max_size_in_bytes: The maximum length in bytes of
cross-silo messages.
If None, the default value of 500 MB is specified.
cross_silo_timeout_in_seconds: The timeout in seconds of a cross-silo RPC call.
It's 60 by default.
enable_waiting_for_other_parties_ready: ping other parties until they
are all ready if True.
grpc_metadata: optional; The metadata sent with the grpc request. This won't override
basic tcp headers, such as `user-agent`, but aggregate them together.
global_cross_silo_msg_config: Global cross-silo message related
configs that are applied to all connections. Supported configs
can refer to CrossSiloMsgConfig in config.py.
Examples:
>>> import fed
Expand All @@ -182,22 +138,21 @@ def init(
assert (
'cert' in tls_config and 'key' in tls_config
), 'Cert or key are not in tls_config.'

global_cross_silo_msg_config = \
global_cross_silo_msg_config or CrossSiloMsgConfig()
# A Ray private accessing, should be replaced in public API.
compatible_utils._init_internal_kv()

cluster_config = {
constants.KEY_OF_CLUSTER_ADDRESSES: cluster,
constants.KEY_OF_CURRENT_PARTY_NAME: party,
constants.KEY_OF_TLS_CONFIG: tls_config,
constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST:
cross_silo_serializing_allowed_list,
constants.KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES:
cross_silo_messages_max_size_in_bytes,
constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS: cross_silo_timeout_in_seconds,
}

job_config = {
constants.KEY_OF_GRPC_METADATA : grpc_metadata,
constants.KEY_OF_CROSS_SILO_MSG_CONFIG:
global_cross_silo_msg_config,
}
compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG,
cloudpickle.dumps(cluster_config))
Expand All @@ -215,30 +170,35 @@ def init(

logger.info(f'Started rayfed with {cluster_config}')
get_global_context().get_cleanup_manager().start(
exit_when_failure_sending=exit_on_failure_cross_silo_sending)
exit_when_failure_sending=global_cross_silo_msg_config.exit_on_sending_failure)

recv_actor_config = fed_config.ProxyActorConfig(
resource_label=cross_silo_recv_resource_label)
if recv_proxy_cls is None:
logger.debug(
"Not declaring recver proxy class, using `GrpcRecvProxy` as default.")
from fed.proxy.grpc.grpc_proxy import GrpcRecvProxy
recv_proxy_cls = GrpcRecvProxy
# Start recv proxy
start_recv_proxy(
cluster=cluster,
party=party,
logging_level=logging_level,
tls_config=tls_config,
retry_policy=cross_silo_grpc_retry_policy,
actor_config=recv_actor_config
proxy_cls=recv_proxy_cls,
proxy_config=global_cross_silo_msg_config
)

send_actor_config = fed_config.ProxyActorConfig(
resource_label=cross_silo_send_resource_label)
if send_proxy_cls is None:
logger.debug(
"Not declaring send proxy class, using `GrpcSendProxy` as default.")
from fed.proxy.grpc.grpc_proxy import GrpcSendProxy
send_proxy_cls = GrpcSendProxy
start_send_proxy(
cluster=cluster,
party=party,
logging_level=logging_level,
tls_config=tls_config,
retry_policy=cross_silo_grpc_retry_policy,
max_retries=cross_silo_send_max_retries,
actor_config=send_actor_config
proxy_cls=send_proxy_cls,
proxy_config=global_cross_silo_msg_config
)

if enable_waiting_for_other_parties_ready:
Expand Down
117 changes: 95 additions & 22 deletions fed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
import fed._private.compatible_utils as compatible_utils
import fed._private.constants as fed_constants
import cloudpickle
from typing import Dict, Optional
import json

from typing import Dict, List, Optional
from dataclasses import dataclass


class ClusterConfig:
Expand All @@ -27,18 +30,6 @@ def current_party(self):
def tls_config(self):
return self._data[fed_constants.KEY_OF_TLS_CONFIG]

@property
def serializing_allowed_list(self):
return self._data[fed_constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST]

@property
def cross_silo_timeout(self):
return self._data[fed_constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS]

@property
def cross_silo_messages_max_size(self):
return self._data[fed_constants.KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES]


class JobConfig:
def __init__(self, raw_bytes: bytes) -> None:
Expand All @@ -48,8 +39,10 @@ def __init__(self, raw_bytes: bytes) -> None:
self._data = cloudpickle.loads(raw_bytes)

@property
def grpc_metadata(self):
return self._data.get(fed_constants.KEY_OF_GRPC_METADATA, {})
def cross_silo_msg_config(self):
return self._data.get(
fed_constants.KEY_OF_CROSS_SILO_MSG_CONFIG,
CrossSiloMsgConfig())


# A module level cache for the cluster configurations.
Expand Down Expand Up @@ -80,14 +73,94 @@ def get_job_config():
return _job_config


class ProxyActorConfig:
@dataclass
class CrossSiloMsgConfig:
"""A class to store parameters used for Proxy Actor
Attributes:
resource_label: The customized resources for the actor. This will be
filled into the "resource" field of Ray ActorClass.options.
proxy_max_restarts: The max restart times for the send proxy.
serializing_allowed_list: The package or class list allowed for
serializing(deserializating) cross silos. It's used for avoiding pickle
deserializing execution attack when crossing solis.
send_resource_label: Customized resource label, the SendProxyActor
will be scheduled based on the declared resource label. For example,
when setting to `{"my_label": 1}`, then the SendProxyActor will be started
only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1.
recv_resource_label: Customized resource label, the RecverProxyActor
will be scheduled based on the declared resource label. For example,
when setting to `{"my_label": 1}`, then the RecverProxyActor will be started
only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1.
exit_on_sending_failure: whether exit when failure on
cross-silo sending. If True, a SIGTERM will be signaled to self
if failed to sending cross-silo data.
messages_max_size_in_bytes: The maximum length in bytes of
cross-silo messages.
If None, the default value of 500 MB is specified.
timeout_in_ms: The timeout in mili-seconds of a cross-silo RPC call.
It's 60000 by default.
http_header: The HTTP header, e.g. metadata in grpc, sent with the RPC request.
This won't override basic tcp headers, such as `user-agent`, but concat
them together.
"""
proxy_max_restarts: int = None
timeout_in_ms: int = 60000
messages_max_size_in_bytes: int = None
exit_on_sending_failure: Optional[bool] = False
serializing_allowed_list: Optional[Dict[str, str]] = None
send_resource_label: Optional[Dict[str, str]] = None
recv_resource_label: Optional[Dict[str, str]] = None
http_header: Optional[Dict[str, str]] = None

def __json__(self):
return json.dumps(self.__dict__)

@classmethod
def from_json(cls, json_str):
data = json.loads(json_str)
return cls(**data)

@classmethod
def from_dict(cls, data: Dict):
"""Initialize CrossSiloMsgConfig from a dictionary.
Args:
data (Dict): Dictionary with keys as member variable names.
Returns:
CrossSiloMsgConfig: An instance of CrossSiloMsgConfig.
"""
# Get the attributes of the class
attrs = {attr for attr, _ in cls.__annotations__.items()}
# Filter the dictionary to only include keys that are attributes of the class
filtered_data = {key: value for key, value in data.items() if key in attrs}
return cls(**filtered_data)


@dataclass
class GrpcCrossSiloMsgConfig(CrossSiloMsgConfig):
"""A class to store parameters used for GRPC communication
Attributes:
grpc_retry_policy: a dict descibes the retry policy for
cross silo rpc call. If None, the following default retry policy
will be used. More details please refer to
`retry-policy <https://github.com/grpc/proposal/blob/master/A6-client-retries.md#retry-policy>`_. # noqa
.. code:: python
{
"maxAttempts": 4,
"initialBackoff": "0.1s",
"maxBackoff": "1s",
"backoffMultiplier": 2,
"retryableStatusCodes": [
"UNAVAILABLE"
]
}
grpc_channel_options: A list of tuples to store GRPC channel options,
e.g. [
('grpc.enable_retries', 1),
('grpc.max_send_message_length', 50 * 1024 * 1024)
]
"""
def __init__(
self,
resource_label: Optional[Dict[str, str]] = None) -> None:
self.resource_label = resource_label
grpc_channel_options: List = None
grpc_retry_policy: Dict[str, str] = None
Loading

0 comments on commit a84a2da

Please sign in to comment.