[Multi-Job][2/N] Make proxy's RPC framework replaceable #140

Merged: 25 commits, Jul 18, 2023
Changes from 4 commits
2 changes: 2 additions & 0 deletions fed/_private/constants.py
@@ -25,6 +25,8 @@

KEY_OF_TLS_CONFIG = "TLS_CONFIG"

KEY_OF_CROSS_SILO_COMM_CONFIG = "CROSS_SILO_COMM_CONFIG"
Collaborator:
Suggested change
KEY_OF_CROSS_SILO_COMM_CONFIG = "CROSS_SILO_COMM_CONFIG"
KEY_OF_CROSS_SILO_COMMON_CONFIG = "CROSS_SILO_COMMON_CONFIG"


KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST = "CROSS_SILO_SERIALIZING_ALLOWED_LIST" # noqa

KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES = "CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES" # noqa
3 changes: 2 additions & 1 deletion fed/_private/serialization_utils.py
@@ -63,7 +63,8 @@ def find_class(self, module, name):
def _apply_loads_function_with_whitelist():
global _pickle_whitelist

_pickle_whitelist = fed_config.get_cluster_config().serializing_allowed_list
_pickle_whitelist = fed_config.get_job_config() \
Collaborator:
As discussed, we should rename the function to get_config(), but we could defer that for later.

.cross_silo_comm_config.serializing_allowed_list
Collaborator:
Should this be job config?
@zhouaihui @fengsp CC

Collaborator (author):
Now that we separate ray.init from fed.init, there's no way to reach the cluster-level information, since each fed.init starts one and only one job session.

Unless there's a global actor (or service job) that can break the job isolation and filter out each job's tasks' invalid parameter types.

Collaborator:
If we don't have a cluster config, why not just use config? The question we should answer before this gets finalized is whether, at a high level, we will need the cluster config in the future.

Collaborator (author):
Let's discuss in #156

if _pickle_whitelist is None:
return
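The whitelist loaded above backs a restricted unpickler (the hunk context shows a `find_class` override in serialization_utils.py). A minimal sketch of that pattern, with an illustrative whitelist mapping module names to allowed class names; the real rayfed helper may differ in detail:

```python
import io
import pickle
from collections import Counter, OrderedDict

# Illustrative whitelist; rayfed loads this from the job config instead.
_pickle_whitelist = {"collections": ["OrderedDict"]}

class RestrictedUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Resolve only classes explicitly allowed for cross-silo deserialization.
        if name not in _pickle_whitelist.get(module, []):
            raise pickle.UnpicklingError(f"{module}.{name} is not allowed")
        return super().find_class(module, name)

# A whitelisted class round-trips normally.
ok = RestrictedUnpickler(io.BytesIO(pickle.dumps(OrderedDict(a=1)))).load()
assert ok == OrderedDict(a=1)

# A class outside the whitelist is rejected at load time.
blocked = False
try:
    RestrictedUnpickler(io.BytesIO(pickle.dumps(Counter()))).load()
except pickle.UnpicklingError:
    blocked = True
assert blocked
```

This is the standard defense against pickle deserialization code-execution attacks: the attacker-controlled bytes can only ever resolve names you listed.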

79 changes: 15 additions & 64 deletions fed/api.py
@@ -15,7 +15,7 @@
import functools
import inspect
import logging
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Union, Optional

import cloudpickle
import ray
@@ -29,6 +29,7 @@
from fed._private.global_context import get_global_context, clear_global_context
from fed.barriers import ping_others, recv, send, start_recv_proxy, start_send_proxy
from fed.cleanup import set_exit_on_failure_sending, wait_sending
from fed.config import CrossSiloCommConfig
Collaborator:
I believe the more natural name is CommonCrossSiloConfig.

Collaborator (author):
"Comm" is short for "Communication" 😅

Collaborator (@jovany-wang, Jul 14, 2023):
😅 All the more reason not to use it, then, since it's confusing even in this conversation.

Collaborator (author):
Renamed to CrossSiloMsgConfig and GrpcCrossSiloMsgConfig; any other name suggestions?

Collaborator:
Proposing CrossSiloMessageConfig.

from fed.fed_object import FedObject
from fed.utils import is_ray_object_refs, setup_logger

@@ -40,16 +41,8 @@ def init(
party: str = None,
tls_config: Dict = None,
logging_level: str = 'info',
cross_silo_grpc_retry_policy: Dict = None,
cross_silo_send_max_retries: int = None,
cross_silo_serializing_allowed_list: Dict = None,
cross_silo_send_resource_label: Dict = None,
cross_silo_recv_resource_label: Dict = None,
exit_on_failure_cross_silo_sending: bool = False,
cross_silo_messages_max_size_in_bytes: int = None,
cross_silo_timeout_in_seconds: int = 60,
enable_waiting_for_other_parties_ready: bool = False,
grpc_metadata: Dict = None,
cross_silo_comm_config: Optional[CrossSiloCommConfig] = None,
**kwargs,
):
"""
@@ -111,48 +104,12 @@
"cert": "bob's server cert",
"key": "bob's server cert key",
}

logging_level: optional; the logging level, could be `debug`, `info`,
`warning`, `error`, `critical`, not case sensitive.
cross_silo_grpc_retry_policy: a dict that describes the retry policy for
cross-silo RPC calls. If None, the following default retry policy
will be used. For more details, please refer to
`retry-policy <https://github.com/grpc/proposal/blob/master/A6-client-retries.md#retry-policy>`_. # noqa

.. code:: python
{
"maxAttempts": 4,
"initialBackoff": "0.1s",
"maxBackoff": "1s",
"backoffMultiplier": 2,
"retryableStatusCodes": [
"UNAVAILABLE"
]
}
cross_silo_send_max_retries: the max retries for sending data cross silo.
cross_silo_serializing_allowed_list: The list of packages or classes allowed for
serializing (deserializing) across silos. It's used to avoid pickle
deserialization code-execution attacks across silos.
cross_silo_send_resource_label: Customized resource label, the SendProxyActor
will be scheduled based on the declared resource label. For example,
when setting to `{"my_label": 1}`, then the SendProxyActor will be started
only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1.
cross_silo_recv_resource_label: Customized resource label, the RecverProxyActor
will be scheduled based on the declared resource label. For example,
when setting to `{"my_label": 1}`, then the RecverProxyActor will be started
only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1.
exit_on_failure_cross_silo_sending: whether to exit when cross-silo
sending fails. If True, a SIGTERM will be sent to the process
if sending cross-silo data fails.
cross_silo_messages_max_size_in_bytes: The maximum length in bytes of
cross-silo messages.
If None, the default value of 500 MB is specified.
cross_silo_timeout_in_seconds: The timeout in seconds of a cross-silo RPC call.
It's 60 by default.
enable_waiting_for_other_parties_ready: ping other parties until they
Collaborator:
This does not match the code

Collaborator (author):
Not quite sure I understand.
If you're referring to the functionality, I think it doesn't belong in the scope of this PR.

are all ready if True.
grpc_metadata: optional; The metadata sent with the grpc request. This won't override
basic tcp headers, such as `user-agent`, but will be aggregated with them.
cross_silo_comm_config: Cross-silo communication related config; for supported
options, refer to CrossSiloCommConfig in config.py.

Examples:
>>> import fed
@@ -177,22 +134,20 @@
assert (
'cert' in tls_config and 'key' in tls_config
), 'Cert or key are not in tls_config.'

cross_silo_comm_config = cross_silo_comm_config or CrossSiloCommConfig()
# A Ray private accessing, should be replaced in public API.
compatible_utils._init_internal_kv()

cluster_config = {
constants.KEY_OF_CLUSTER_ADDRESSES: cluster,
constants.KEY_OF_CURRENT_PARTY_NAME: party,
constants.KEY_OF_TLS_CONFIG: tls_config,
constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST:
cross_silo_serializing_allowed_list,
constants.KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES:
cross_silo_messages_max_size_in_bytes,
constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS: cross_silo_timeout_in_seconds,
}

job_config = {
constants.KEY_OF_GRPC_METADATA : grpc_metadata,
constants.KEY_OF_CROSS_SILO_COMM_CONFIG:
cross_silo_comm_config,
}
compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG,
cloudpickle.dumps(cluster_config))
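The lines above persist the cluster and job configs as pickled bytes in an internal key-value store. A minimal sketch of that round trip, using stdlib pickle in place of cloudpickle and a plain dict standing in for the internal kv store (both are illustrative assumptions):

```python
import pickle

_kv = {}  # a plain dict standing in for the internal kv store (assumption)

def kv_put(key, value):
    # Mirrors compatible_utils.kv.put(key, cloudpickle.dumps(value)) above.
    _kv[key] = pickle.dumps(value)

def kv_get(key):
    # The reader side loads the bytes back into the original dict.
    return pickle.loads(_kv[key])

cluster_config = {
    "CLUSTER_ADDRESSES": {"alice": "127.0.0.1:11010", "bob": "127.0.0.1:11011"},
    "CURRENT_PARTY_NAME": "alice",
    "TLS_CONFIG": None,
}
kv_put("CLUSTER_CONFIG", cluster_config)
assert kv_get("CLUSTER_CONFIG") == cluster_config
```

Storing the config as opaque pickled bytes is what lets the proxy actors, which run in separate processes, rehydrate the same configuration later.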
@@ -209,29 +164,25 @@
)

logger.info(f'Started rayfed with {cluster_config}')
set_exit_on_failure_sending(exit_on_failure_cross_silo_sending)
recv_actor_config = fed_config.ProxyActorConfig(
resource_label=cross_silo_recv_resource_label)
set_exit_on_failure_sending(cross_silo_comm_config.exit_on_sending_failure)
# Start recv proxy
start_recv_proxy(
Collaborator:
It's not related to this PR's changes, but the start_recv_proxy function should be private.

cluster=cluster,
party=party,
logging_level=logging_level,
tls_config=tls_config,
retry_policy=cross_silo_grpc_retry_policy,
actor_config=recv_actor_config
retry_policy=cross_silo_comm_config.grpc_retry_policy,
actor_config=cross_silo_comm_config
)

send_actor_config = fed_config.ProxyActorConfig(
resource_label=cross_silo_send_resource_label)
start_send_proxy(
cluster=cluster,
party=party,
logging_level=logging_level,
tls_config=tls_config,
retry_policy=cross_silo_grpc_retry_policy,
max_retries=cross_silo_send_max_retries,
actor_config=send_actor_config
retry_policy=cross_silo_comm_config.grpc_retry_policy,
max_retries=cross_silo_comm_config.proxier_fo_max_retries,
actor_config=cross_silo_comm_config
)

if enable_waiting_for_other_parties_ready:
36 changes: 19 additions & 17 deletions fed/barriers.py
@@ -28,7 +28,7 @@
from fed._private import constants
from fed._private.grpc_options import get_grpc_options, set_max_message_length
from fed.cleanup import push_to_sending
from fed.config import get_cluster_config
from fed.config import get_job_config
from fed.grpc import fed_pb2, fed_pb2_grpc
from fed.utils import setup_logger

@@ -136,7 +136,7 @@ async def send_data_grpc(
grpc_options = get_grpc_options(retry_policy=retry_policy) if \
grpc_options is None else fed_utils.dict2tuple(grpc_options)
tls_enabled = fed_utils.tls_enabled(tls_config)
cluster_config = fed_config.get_cluster_config()
timeout = get_job_config().cross_silo_comm_config.timeout_in_seconds
metadata = fed_utils.dict2tuple(metadata)
if tls_enabled:
ca_cert, private_key, cert_chain = fed_utils.load_cert_config(tls_config)
@@ -160,7 +160,7 @@
)
# wait for downstream's reply
response = await stub.SendData(
request, metadata=metadata, timeout=cluster_config.cross_silo_timeout)
request, metadata=metadata, timeout=timeout)
logger.debug(
f'Received data response from seq_id {downstream_seq_id}, '
f'result: {response.result}.'
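Both grpc_options and metadata above pass through fed_utils.dict2tuple before reaching gRPC, which expects channel options and call metadata as sequences of key-value pairs rather than dicts. A minimal sketch of such a helper (an assumption about its implementation; the real fed.utils.dict2tuple may differ):

```python
def dict2tuple(d):
    """Flatten an options/metadata dict into the ((key, value), ...) form
    that gRPC channel constructors and call metadata expect."""
    if d is None:
        return None
    return tuple((k, v) for k, v in d.items())

opts = dict2tuple({"grpc.max_send_message_length": 500 * 1024 * 1024})
# → (('grpc.max_send_message_length', 524288000),)
assert dict2tuple(None) is None
```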
@@ -177,7 +177,7 @@
)
# wait for downstream's reply
response = await stub.SendData(
request, metadata=metadata, timeout=cluster_config.cross_silo_timeout)
request, metadata=metadata, timeout=timeout)
logger.debug(
f'Received data response from seq_id {downstream_seq_id} '
f'result: {response.result}.'
@@ -201,14 +201,15 @@ def __init__(
date_format=constants.RAYFED_DATE_FMT,
party_val=party,
)

self._stats = {"send_op_count": 0}
self._cluster = cluster
self._party = party
self._tls_config = tls_config
self.retry_policy = retry_policy
self._grpc_metadata = fed_config.get_job_config().grpc_metadata
cluster_config = fed_config.get_cluster_config()
set_max_message_length(cluster_config.cross_silo_messages_max_size)
cross_silo_comm_config = fed_config.get_job_config().cross_silo_comm_config
self._grpc_metadata = cross_silo_comm_config.http_header
set_max_message_length(cross_silo_comm_config.messages_max_size_in_bytes)

async def is_ready(self):
return True
@@ -220,6 +221,7 @@ async def send(
upstream_seq_id,
downstream_seq_id,
):

self._stats["send_op_count"] += 1
assert (
dest_party in self._cluster
@@ -302,8 +304,8 @@ def __init__(
self._party = party
self._tls_config = tls_config
self.retry_policy = retry_policy
config = fed_config.get_cluster_config()
set_max_message_length(config.cross_silo_messages_max_size)
cross_silo_comm_config = fed_config.get_job_config().cross_silo_comm_config
set_max_message_length(cross_silo_comm_config.messages_max_size_in_bytes)
# Workaround the threading coordinations

# Flag to see whether grpc server starts
@@ -378,7 +380,7 @@ def start_recv_proxy(
logging_level: str,
tls_config=None,
retry_policy=None,
actor_config: Optional[fed_config.ProxyActorConfig] = None
actor_config: Optional[fed_config.CrossSiloCommConfig] = None
):

# Create RecverProxyActor
@@ -390,8 +392,8 @@
listen_addr = party_addr['address']

actor_options = copy.deepcopy(_DEFAULT_RECV_PROXY_OPTIONS)
if actor_config is not None and actor_config.resource_label is not None:
actor_options.update({"resources": actor_config.resource_label})
if actor_config is not None and actor_config.recv_resource_label is not None:
actor_options.update({"resources": actor_config.recv_resource_label})

logger.debug(f"Starting RecvProxyActor with options: {actor_options}")

@@ -405,7 +407,7 @@
retry_policy=retry_policy,
)
recver_proxy_actor.run_grpc_server.remote()
timeout = get_cluster_config().cross_silo_timeout
timeout = get_job_config().cross_silo_comm_config.timeout_in_seconds
server_state = ray.get(recver_proxy_actor.is_ready.remote(), timeout=timeout)
assert server_state[0], server_state[1]
logger.info("RecverProxy has successfully created.")
@@ -424,7 +426,7 @@ def start_send_proxy(
tls_config: Dict = None,
retry_policy=None,
max_retries=None,
actor_config: Optional[fed_config.ProxyActorConfig] = None
actor_config: Optional[fed_config.CrossSiloCommConfig] = None
):
# Create SendProxyActor
global _SEND_PROXY_ACTOR
@@ -435,8 +437,8 @@
"max_task_retries": max_retries,
"max_restarts": 1,
})
if actor_config is not None and actor_config.resource_label is not None:
actor_options.update({"resources": actor_config.resource_label})
if actor_config is not None and actor_config.send_resource_label is not None:
actor_options.update({"resources": actor_config.send_resource_label})

logger.debug(f"Starting SendProxyActor with options: {actor_options}")
_SEND_PROXY_ACTOR = SendProxyActor.options(
Expand All @@ -449,7 +451,7 @@ def start_send_proxy(
logging_level=logging_level,
retry_policy=retry_policy,
)
timeout = get_cluster_config().cross_silo_timeout
timeout = get_job_config().cross_silo_comm_config.timeout_in_seconds
assert ray.get(_SEND_PROXY_ACTOR.is_ready.remote(), timeout=timeout)
logger.info("SendProxyActor has successfully created.")

74 changes: 67 additions & 7 deletions fed/config.py
@@ -8,6 +8,7 @@
import fed._private.constants as fed_constants
import cloudpickle
from typing import Dict, Optional
import json


class ClusterConfig:
@@ -48,8 +49,8 @@ def __init__(self, raw_bytes: bytes) -> None:
self._data = cloudpickle.loads(raw_bytes)

@property
def grpc_metadata(self):
return self._data.get(fed_constants.KEY_OF_GRPC_METADATA, {})
def cross_silo_comm_config(self):
return self._data.get(fed_constants.KEY_OF_CROSS_SILO_COMM_CONFIG, {})


# A module level cache for the cluster configurations.
@@ -80,14 +81,73 @@ def get_job_config():
return _job_config
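get_job_config above follows a module-level cache pattern: the first call loads the pickled config from the internal kv store and every later call returns the cached object. A stdlib sketch of that pattern (the injected load function is an illustrative stand-in for the kv lookup):

```python
import pickle

_job_config = None  # module-level cache, mirroring fed/config.py above

def get_job_config(load=lambda: pickle.dumps({"CROSS_SILO_COMM_CONFIG": None})):
    """Lazily load and cache the job config; `load` stands in for the
    internal kv store lookup (an illustrative assumption)."""
    global _job_config
    if _job_config is None:
        _job_config = pickle.loads(load())
    return _job_config

first = get_job_config()
assert get_job_config() is first  # subsequent calls hit the cache
```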


class ProxyActorConfig:
class CrossSiloCommConfig:
"""A class to store parameters used for Proxy Actor

Attributes:
resource_label: The customized resources for the actor. This will be
filled into the "resource" field of Ray ActorClass.options.
grpc_retry_policy: a dict that describes the retry policy for
cross-silo RPC calls. If None, the following default retry policy
will be used. For more details, please refer to
`retry-policy <https://github.com/grpc/proposal/blob/master/A6-client-retries.md#retry-policy>`_. # noqa

.. code:: python
{
"maxAttempts": 4,
"initialBackoff": "0.1s",
"maxBackoff": "1s",
"backoffMultiplier": 2,
"retryableStatusCodes": [
"UNAVAILABLE"
]
}
proxier_fo_max_retries: The max number of restarts for the send proxy.
serializing_allowed_list: The list of packages or classes allowed for
serializing (deserializing) across silos. It's used to avoid pickle
deserialization code-execution attacks across silos.
send_resource_label: Customized resource label, the SendProxyActor
will be scheduled based on the declared resource label. For example,
when setting to `{"my_label": 1}`, then the SendProxyActor will be started
only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1.
recv_resource_label: Customized resource label, the RecverProxyActor
will be scheduled based on the declared resource label. For example,
when setting to `{"my_label": 1}`, then the RecverProxyActor will be started
only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1.
exit_on_sending_failure: whether to exit when cross-silo sending fails.
If True, a SIGTERM will be sent to the process if sending
cross-silo data fails.
messages_max_size_in_bytes: The maximum length in bytes of
cross-silo messages.
If None, the default value of 500 MB is used.
timeout_in_seconds: The timeout in seconds of a cross-silo RPC call.
It's 60 by default.
http_header: The HTTP headers, e.g. metadata in grpc, sent with the RPC request.
These won't override basic tcp headers, such as `user-agent`, but will be aggregated with them.
"""
def __init__(
self,
resource_label: Optional[Dict[str, str]] = None) -> None:
self.resource_label = resource_label
grpc_retry_policy: Dict = None,
proxier_fo_max_retries: int = None,
timeout_in_seconds: int = 60,
messages_max_size_in_bytes: int = None,
exit_on_sending_failure: Optional[bool] = False,
serializing_allowed_list: Optional[Dict[str, str]] = None,
send_resource_label: Optional[Dict[str, str]] = None,
recv_resource_label: Optional[Dict[str, str]] = None,
http_header: Optional[Dict[str, str]] = None) -> None:
self.grpc_retry_policy = grpc_retry_policy
self.proxier_fo_max_retries = proxier_fo_max_retries
self.timeout_in_seconds = timeout_in_seconds
self.messages_max_size_in_bytes = messages_max_size_in_bytes
self.exit_on_sending_failure = exit_on_sending_failure
self.serializing_allowed_list = serializing_allowed_list
self.send_resource_label = send_resource_label
self.recv_resource_label = recv_resource_label
self.http_header = http_header

def __json__(self):
return json.dumps(self.__dict__)

@classmethod
def from_json(cls, json_str):
data = json.loads(json_str)
return cls(**data)
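Since CrossSiloCommConfig serializes via its plain __dict__ and from_json splats the parsed dict back into the constructor, a config survives a JSON round trip as long as every field is JSON-serializable. A trimmed, self-contained sketch of the class above (subset of its fields) demonstrating the round trip:

```python
import json

class CrossSiloCommConfig:
    """Trimmed sketch of the config class above (subset of its fields)."""
    def __init__(self, timeout_in_seconds=60, messages_max_size_in_bytes=None,
                 exit_on_sending_failure=False, http_header=None):
        self.timeout_in_seconds = timeout_in_seconds
        self.messages_max_size_in_bytes = messages_max_size_in_bytes
        self.exit_on_sending_failure = exit_on_sending_failure
        self.http_header = http_header

    def __json__(self):
        # Every field is a JSON-friendly scalar or dict, so __dict__ suffices.
        return json.dumps(self.__dict__)

    @classmethod
    def from_json(cls, json_str):
        # Keys in the JSON match the constructor parameters by name.
        return cls(**json.loads(json_str))

cfg = CrossSiloCommConfig(timeout_in_seconds=30, http_header={"x-job": "demo"})
restored = CrossSiloCommConfig.from_json(cfg.__json__())
assert restored.timeout_in_seconds == 30
assert restored.http_header == {"x-job": "demo"}
```

This by-name round trip is also why renaming a field is a compatibility concern: old serialized configs would no longer match the constructor's parameters.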