Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename functions, variables and classes. #157

Merged
merged 2 commits into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fed/_private/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

KEY_OF_TLS_CONFIG = "TLS_CONFIG"

KEY_OF_CROSS_SILO_MSG_CONFIG = "CROSS_SILO_MSG_CONFIG"
KEY_OF_CROSS_SILO_MESSAGE_CONFIG = "CROSS_SILO_MESSAGE_CONFIG"

RAYFED_LOG_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- %(message)s" # noqa

Expand Down
2 changes: 1 addition & 1 deletion fed/_private/serialization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _apply_loads_function_with_whitelist():
global _pickle_whitelist

_pickle_whitelist = fed_config.get_job_config() \
.cross_silo_msg_config.serializing_allowed_list
.cross_silo_message_config.serializing_allowed_list
if _pickle_whitelist is None:
return

Expand Down
61 changes: 31 additions & 30 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@
ping_others,
recv,
send,
start_recv_proxy,
start_send_proxy,
_start_receiver_proxy,
_start_sender_proxy,
)
from fed.proxy.grpc.grpc_proxy import SendProxy, RecvProxy
from fed.config import CrossSiloMsgConfig
from fed.proxy.grpc.grpc_proxy import SenderProxy, ReceiverProxy
from fed.config import CrossSiloMessageConfig
from fed.fed_object import FedObject
from fed.utils import is_ray_object_refs, setup_logger

Expand All @@ -48,9 +48,9 @@ def init(
tls_config: Dict = None,
logging_level: str = 'info',
enable_waiting_for_other_parties_ready: bool = False,
send_proxy_cls: SendProxy = None,
recv_proxy_cls: RecvProxy = None,
global_cross_silo_msg_config: Optional[CrossSiloMsgConfig] = None,
sender_proxy_cls: SenderProxy = None,
receiver_proxy_cls: ReceiverProxy = None,
global_cross_silo_message_config: Optional[CrossSiloMessageConfig] = None,
**kwargs,
):
"""
Expand All @@ -67,7 +67,7 @@ def init(
# (Optional) the listen address, the `address` will be
# used if not provided.
'listen_addr': '0.0.0.0:10001',
'cross_silo_msg_config': CrossSiloMsgConfig
'cross_silo_message_config': CrossSiloMessageConfig
},
'bob': {
# The address for other parties.
Expand Down Expand Up @@ -111,9 +111,9 @@ def init(
`warning`, `error`, `critical`, not case sensititive.
enable_waiting_for_other_parties_ready: ping other parties until they
are all ready if True.
global_cross_silo_msg_config: Global cross-silo message related
global_cross_silo_message_config: Global cross-silo message related
configs that are applied to all connections. Supported configs
can refer to CrossSiloMsgConfig in config.py.
can refer to CrossSiloMessageConfig in config.py.

Examples:
>>> import fed
Expand All @@ -139,8 +139,8 @@ def init(
'cert' in tls_config and 'key' in tls_config
), 'Cert or key are not in tls_config.'

global_cross_silo_msg_config = \
global_cross_silo_msg_config or CrossSiloMsgConfig()
global_cross_silo_message_config = \
global_cross_silo_message_config or CrossSiloMessageConfig()
# A Ray private accessing, should be replaced in public API.
compatible_utils._init_internal_kv()

Expand All @@ -151,8 +151,8 @@ def init(
}

job_config = {
constants.KEY_OF_CROSS_SILO_MSG_CONFIG:
global_cross_silo_msg_config,
constants.KEY_OF_CROSS_SILO_MESSAGE_CONFIG:
global_cross_silo_message_config,
}
compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG,
cloudpickle.dumps(cluster_config))
Expand All @@ -170,35 +170,36 @@ def init(

logger.info(f'Started rayfed with {cluster_config}')
get_global_context().get_cleanup_manager().start(
exit_when_failure_sending=global_cross_silo_msg_config.exit_on_sending_failure)
exit_when_failure_sending=global_cross_silo_message_config.exit_on_sending_failure) # noqa

if recv_proxy_cls is None:
if receiver_proxy_cls is None:
logger.debug(
"Not declaring recver proxy class, using `GrpcRecvProxy` as default.")
from fed.proxy.grpc.grpc_proxy import GrpcRecvProxy
recv_proxy_cls = GrpcRecvProxy
# Start recv proxy
start_recv_proxy(
"There is no receiver proxy class specified, it uses `GrpcRecvProxy` by "
"default.")
from fed.proxy.grpc.grpc_proxy import GrpcReceiverProxy
receiver_proxy_cls = GrpcReceiverProxy
_start_receiver_proxy(
cluster=cluster,
party=party,
logging_level=logging_level,
tls_config=tls_config,
proxy_cls=recv_proxy_cls,
proxy_config=global_cross_silo_msg_config
proxy_cls=receiver_proxy_cls,
proxy_config=global_cross_silo_message_config
)

if send_proxy_cls is None:
if sender_proxy_cls is None:
logger.debug(
"Not declaring send proxy class, using `GrpcSendProxy` as default.")
from fed.proxy.grpc.grpc_proxy import GrpcSendProxy
send_proxy_cls = GrpcSendProxy
start_send_proxy(
"There is no sender proxy class specified, it uses `GrpcRecvProxy` by "
"default.")
from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy
sender_proxy_cls = GrpcSenderProxy
_start_sender_proxy(
cluster=cluster,
party=party,
logging_level=logging_level,
tls_config=tls_config,
proxy_cls=send_proxy_cls,
proxy_config=global_cross_silo_msg_config
proxy_cls=sender_proxy_cls,
proxy_config=global_cross_silo_message_config
)

if enable_waiting_for_other_parties_ready:
Expand Down
28 changes: 15 additions & 13 deletions fed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def __init__(self, raw_bytes: bytes) -> None:
self._data = cloudpickle.loads(raw_bytes)

@property
def cross_silo_msg_config(self):
def cross_silo_message_config(self):
return self._data.get(
fed_constants.KEY_OF_CROSS_SILO_MSG_CONFIG,
CrossSiloMsgConfig())
fed_constants.KEY_OF_CROSS_SILO_MESSAGE_CONFIG,
CrossSiloMessageConfig())


# A module level cache for the cluster configurations.
Expand Down Expand Up @@ -74,22 +74,24 @@ def get_job_config():


@dataclass
class CrossSiloMsgConfig:
class CrossSiloMessageConfig:
"""A class to store parameters used for Proxy Actor

Attributes:
proxy_max_restarts: The max restart times for the send proxy.
serializing_allowed_list: The package or class list allowed for
serializing(deserializating) cross silos. It's used for avoiding pickle
deserializing execution attack when crossing solis.
send_resource_label: Customized resource label, the SendProxyActor
send_resource_label: Customized resource label, the SenderProxyActor
will be scheduled based on the declared resource label. For example,
when setting to `{"my_label": 1}`, then the SendProxyActor will be started
only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1.
recv_resource_label: Customized resource label, the RecverProxyActor
when setting to `{"my_label": 1}`, then the sender proxy actor will be
started only on Nodes with `{"resource": {"my_label": $NUM}}` where
$NUM >= 1.
recv_resource_label: Customized resource label, the ReceiverProxyActor
will be scheduled based on the declared resource label. For example,
when setting to `{"my_label": 1}`, then the RecverProxyActor will be started
only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1.
when setting to `{"my_label": 1}`, then the receiver proxy actor will be
started only on Nodes with `{"resource": {"my_label": $NUM}}` where
$NUM >= 1.
exit_on_sending_failure: whether exit when failure on
cross-silo sending. If True, a SIGTERM will be signaled to self
if failed to sending cross-silo data.
Expand Down Expand Up @@ -121,13 +123,13 @@ def from_json(cls, json_str):

@classmethod
def from_dict(cls, data: Dict):
"""Initialize CrossSiloMsgConfig from a dictionary.
"""Initialize CrossSiloMessageConfig from a dictionary.

Args:
data (Dict): Dictionary with keys as member variable names.

Returns:
CrossSiloMsgConfig: An instance of CrossSiloMsgConfig.
CrossSiloMessageConfig: An instance of CrossSiloMessageConfig.
"""
# Get the attributes of the class
attrs = {attr for attr, _ in cls.__annotations__.items()}
Expand All @@ -137,7 +139,7 @@ def from_dict(cls, data: Dict):


@dataclass
class GrpcCrossSiloMsgConfig(CrossSiloMsgConfig):
class GrpcCrossSiloMessageConfig(CrossSiloMessageConfig):
"""A class to store parameters used for GRPC communication

Attributes:
Expand Down
74 changes: 38 additions & 36 deletions fed/proxy/barriers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import fed.config as fed_config
from fed.config import get_job_config
from fed.proxy.base_proxy import SendProxy, RecvProxy
from fed.proxy.base_proxy import SenderProxy, ReceiverProxy
from fed.utils import setup_logger
from fed._private import constants
from fed._private.global_context import get_global_context
Expand Down Expand Up @@ -55,7 +55,7 @@ def pop_from_two_dim_dict(the_dict, key_a, key_b):


@ray.remote
class SendProxyActor:
class SenderProxyActor:
def __init__(
self,
cluster: Dict,
Expand All @@ -75,9 +75,10 @@ def __init__(
self._cluster = cluster
self._party = party
self._tls_config = tls_config
cross_silo_msg_config = fed_config.get_job_config().cross_silo_msg_config
self._proxy_instance: SendProxy = proxy_cls(
cluster, party, tls_config, cross_silo_msg_config)
job_config = fed_config.get_job_config()
cross_silo_message_config = job_config.cross_silo_message_config
self._proxy_instance: SenderProxy = proxy_cls(
cluster, party, tls_config, cross_silo_message_config)

async def is_ready(self):
res = await self._proxy_instance.is_ready()
Expand Down Expand Up @@ -122,7 +123,7 @@ async def _get_proxy_config(self, dest_party=None):


@ray.remote
class RecverProxyActor:
class ReceiverProxyActor:
def __init__(
self,
listen_addr: str,
Expand All @@ -141,9 +142,10 @@ def __init__(
self._listen_addr = listen_addr
self._party = party
self._tls_config = tls_config
cross_silo_msg_config = fed_config.get_job_config().cross_silo_msg_config
self._proxy_instance: RecvProxy = proxy_cls(
listen_addr, party, tls_config, cross_silo_msg_config)
job_config = fed_config.get_job_config()
cross_silo_message_config = job_config.cross_silo_message_config
self._proxy_instance: ReceiverProxy = proxy_cls(
listen_addr, party, tls_config, cross_silo_message_config)

async def start(self):
await self._proxy_instance.start()
Expand All @@ -165,18 +167,18 @@ async def _get_proxy_config(self):
return await self._proxy_instance.get_proxy_config()


_DEFAULT_RECV_PROXY_OPTIONS = {
_DEFAULT_RECEIVER_PROXY_OPTIONS = {
"max_concurrency": 1000,
}


def start_recv_proxy(
def _start_receiver_proxy(
cluster: str,
party: str,
logging_level: str,
tls_config=None,
proxy_cls=None,
proxy_config: Optional[fed_config.CrossSiloMsgConfig] = None
proxy_config: Optional[fed_config.CrossSiloMessageConfig] = None
):

# Create RecevrProxyActor
Expand All @@ -187,46 +189,46 @@ def start_recv_proxy(
if not listen_addr:
listen_addr = party_addr['address']

actor_options = copy.deepcopy(_DEFAULT_RECV_PROXY_OPTIONS)
actor_options = copy.deepcopy(_DEFAULT_RECEIVER_PROXY_OPTIONS)
if proxy_config is not None and proxy_config.recv_resource_label is not None:
actor_options.update({"resources": proxy_config.recv_resource_label})

logger.debug(f"Starting RecvProxyActor with options: {actor_options}")
logger.debug(f"Starting ReceiverProxyActor with options: {actor_options}")

recver_proxy_actor = RecverProxyActor.options(
name=f"RecverProxyActor-{party}", **actor_options
receiver_proxy_actor = ReceiverProxyActor.options(
name=f"ReceiverProxyActor-{party}", **actor_options
).remote(
listen_addr=listen_addr,
party=party,
tls_config=tls_config,
logging_level=logging_level,
proxy_cls=proxy_cls
)
recver_proxy_actor.start.remote()
receiver_proxy_actor.start.remote()
timeout = proxy_config.timeout_in_ms / 1000 if proxy_config is not None else 60
server_state = ray.get(recver_proxy_actor.is_ready.remote(), timeout=timeout)
server_state = ray.get(receiver_proxy_actor.is_ready.remote(), timeout=timeout)
assert server_state[0], server_state[1]
logger.info("RecverProxy has successfully created.")
logger.info("Succeeded to create receiver proxy actor.")


_SEND_PROXY_ACTOR = None
_DEFAULT_SEND_PROXY_OPTIONS = {
_SENDER_PROXY_ACTOR = None
_DEFAULT_SENDER_PROXY_OPTIONS = {
"max_concurrency": 1000,
}


def start_send_proxy(
def _start_sender_proxy(
cluster: Dict,
party: str,
logging_level: str,
tls_config: Dict = None,
proxy_cls=None,
proxy_config: Optional[fed_config.CrossSiloMsgConfig] = None
proxy_config: Optional[fed_config.CrossSiloMessageConfig] = None
):
# Create SendProxyActor
global _SEND_PROXY_ACTOR
# Create SenderProxyActor
global _SENDER_PROXY_ACTOR

actor_options = copy.deepcopy(_DEFAULT_SEND_PROXY_OPTIONS)
actor_options = copy.deepcopy(_DEFAULT_SENDER_PROXY_OPTIONS)
if proxy_config and proxy_config.proxy_max_restarts:
actor_options.update({
"max_task_retries": proxy_config.proxy_max_restarts,
Expand All @@ -235,20 +237,20 @@ def start_send_proxy(
if proxy_config and proxy_config.send_resource_label:
actor_options.update({"resources": proxy_config.send_resource_label})

logger.debug(f"Starting SendProxyActor with options: {actor_options}")
_SEND_PROXY_ACTOR = SendProxyActor.options(
name="SendProxyActor", **actor_options)
logger.debug(f"Starting SenderProxyActor with options: {actor_options}")
_SENDER_PROXY_ACTOR = SenderProxyActor.options(
name="SenderProxyActor", **actor_options)

_SEND_PROXY_ACTOR = _SEND_PROXY_ACTOR.remote(
_SENDER_PROXY_ACTOR = _SENDER_PROXY_ACTOR.remote(
cluster=cluster,
party=party,
tls_config=tls_config,
logging_level=logging_level,
proxy_cls=proxy_cls
)
timeout = get_job_config().cross_silo_msg_config.timeout_in_ms / 1000
assert ray.get(_SEND_PROXY_ACTOR.is_ready.remote(), timeout=timeout)
logger.info("SendProxyActor has successfully created.")
timeout = get_job_config().cross_silo_message_config.timeout_in_ms / 1000
assert ray.get(_SENDER_PROXY_ACTOR.is_ready.remote(), timeout=timeout)
logger.info("SenderProxyActor has successfully created.")


def send(
Expand All @@ -257,8 +259,8 @@ def send(
upstream_seq_id,
downstream_seq_id,
):
send_proxy = ray.get_actor("SendProxyActor")
res = send_proxy.send.remote(
sender_proxy = ray.get_actor("SenderProxyActor")
res = sender_proxy.send.remote(
dest_party=dest_party,
data=data,
upstream_seq_id=upstream_seq_id,
Expand All @@ -270,7 +272,7 @@ def send(

def recv(party: str, src_party: str, upstream_seq_id, curr_seq_id):
assert party, 'Party can not be None.'
receiver_proxy = ray.get_actor(f"RecverProxyActor-{party}")
receiver_proxy = ray.get_actor(f"ReceiverProxyActor-{party}")
return receiver_proxy.get_data.remote(src_party, upstream_seq_id, curr_seq_id)


Expand Down
Loading