From 652f2ab8ac12ba147bb9263754887ca65e29ee87 Mon Sep 17 00:00:00 2001 From: paer Date: Tue, 25 Jul 2023 15:47:37 +0800 Subject: [PATCH 01/16] tmp save --- fed/_private/compatible_utils.py | 32 +++++++++++++++++++++++--------- fed/_private/global_context.py | 12 +++++++++++- fed/api.py | 17 +++++++++++++---- 3 files changed, 47 insertions(+), 14 deletions(-) diff --git a/fed/_private/compatible_utils.py b/fed/_private/compatible_utils.py index 04fe5535..5b6ca68f 100644 --- a/fed/_private/compatible_utils.py +++ b/fed/_private/compatible_utils.py @@ -13,6 +13,7 @@ # limitations under the License. import abc +import cloudpickle import ray import fed._private.constants as fed_constants @@ -97,8 +98,9 @@ def reset(self): class InternalKv(AbstractInternalKv): """The internal kv class for non Ray client mode. """ - def __init__(self) -> None: + def __init__(self, job_id) -> None: super().__init__() + self._job_id = job_id def initialize(self): try: @@ -114,15 +116,29 @@ def initialize(self): return ray_internal_kv._initialize_internal_kv(gcs_client) def put(self, k, v): - return ray_internal_kv._internal_kv_put(k, v) + content = ray_internal_kv._internal_kv_get(self._job_id) + if content is None: + content = {} + else: + content = cloudpickle.loads(content) + content[k] = v + return ray_internal_kv._internal_kv_put(self._job_id, cloudpickle.dumps(content)) def get(self, k): - return ray_internal_kv._internal_kv_get(k) + content = ray_internal_kv._internal_kv_get(self._job_id) + content = cloudpickle.loads(content) + return content.get(k, None) def delete(self, k): - return ray_internal_kv._internal_kv_del(k) + content = self.get(self._job_id) + if k in content: + del content[k] + self.put(self._job_id, content) + return 1 + return 0 def reset(self): + ray_internal_kv._internal_kv_del(self._job_id) return ray_internal_kv._internal_kv_reset() def _ping(self): @@ -157,25 +173,23 @@ def reset(self): return ray.get(o) -def _init_internal_kv(): +def _init_internal_kv(job_id): """An internal API that initialize the internal kv object.""" global kv if kv is None: from ray._private.client_mode_hook import is_client_mode_enabled if is_client_mode_enabled: kv_actor = ray.remote(InternalKv).options( - name="_INTERNAL_KV_ACTOR").remote() + name="_INTERNAL_KV_ACTOR").remote(job_id) response = kv_actor._ping.remote() ray.get(response) - kv = ClientModeInternalKv() if is_client_mode_enabled else InternalKv() + kv = ClientModeInternalKv() if is_client_mode_enabled else InternalKv(job_id) kv.initialize() def _clear_internal_kv(): global kv if kv is not None: - kv.delete(constants.KEY_OF_CLUSTER_CONFIG) - kv.delete(constants.KEY_OF_JOB_CONFIG) kv.reset() from ray._private.client_mode_hook import is_client_mode_enabled if is_client_mode_enabled: diff --git a/fed/_private/global_context.py b/fed/_private/global_context.py index b0367b10..a465ed07 100644 --- a/fed/_private/global_context.py +++ b/fed/_private/global_context.py @@ -16,7 +16,8 @@ class GlobalContext: - def __init__(self) -> None: + def __init__(self, job_id: str) -> None: + self._job_id = job_id self._seq_count = 0 self._cleanup_manager = CleanupManager() @@ -27,10 +28,19 @@ def next_seq_id(self) -> int: def get_cleanup_manager(self) -> CleanupManager: return self._cleanup_manager + def job_id(self) -> str: + return self._job_id + _global_context = None +def init_global_context(job_id: str) -> None: + global _global_context + if _global_context is None: + _global_context = GlobalContext(job_id) + + def get_global_context(): global _global_context if _global_context is None: diff --git a/fed/api.py b/fed/api.py index fdf97fbe..4768423b 100644 --- a/fed/api.py +++ b/fed/api.py @@ -26,7 +26,11 @@ from fed._private import constants from fed._private.fed_actor import FedActorHandle from fed._private.fed_call_holder import FedCallHolder -from fed._private.global_context import get_global_context, clear_global_context +from fed._private.global_context import ( + init_global_context, + get_global_context, + clear_global_context +) from fed.proxy.barriers import ( ping_others, recv, @@ -50,6 +54,7 @@ def init( logging_level: str = 'info', sender_proxy_cls: SenderProxy = None, receiver_proxy_cls: ReceiverProxy = None, + job_id: str = 'anonymous' ): """ Initialize a RayFed client. @@ -87,7 +92,11 @@ def init( } logging_level: optional; the logging level, could be `debug`, `info`, `warning`, `error`, `critical`, not case sensititive. - + job_id: optional; the job id of the current job. Note that, the job id + must be identical in all parties, otherwise, messages will be ignored + because of the job id mismatch. If the job id is not provided, messages + of this job will not be distinguished from other jobs, which should + only be used in the single job scenario or simulation mode. Examples: >>> import fed >>> import ray @@ -105,7 +114,7 @@ def init( assert party in addresses, f"Party {party} is not in the addresses {addresses}." fed_utils.validate_addresses(addresses) - + init_global_context(job_id=job_id) tls_config = {} if tls_config is None else tls_config if tls_config: assert ( @@ -116,7 +125,7 @@ def init( cross_silo_message_config = GrpcCrossSiloMessageConfig.from_dict( cross_silo_message_dict) # A Ray private accessing, should be replaced in public API. - compatible_utils._init_internal_kv() + compatible_utils._init_internal_kv(job_id) cluster_config = { constants.KEY_OF_CLUSTER_ADDRESSES: addresses, From 8654dbee271bb33efe470357620ff8bf094b4524 Mon Sep 17 00:00:00 2001 From: paer Date: Wed, 26 Jul 2023 15:16:37 +0800 Subject: [PATCH 02/16] split kv data of jobs Signed-off-by: paer --- fed/_private/compatible_utils.py | 35 ++++++++++++++++---------------- fed/api.py | 27 +++++++++++++----------- fed/config.py | 12 +++++------ fed/proxy/barriers.py | 13 +++++++++--- tests/test_internal_kv.py | 9 ++++++-- 5 files changed, 55 insertions(+), 41 deletions(-) diff --git a/fed/_private/compatible_utils.py b/fed/_private/compatible_utils.py index 5b6ca68f..b5941a17 100644 --- a/fed/_private/compatible_utils.py +++ b/fed/_private/compatible_utils.py @@ -67,6 +67,14 @@ def _get_gcs_address_from_ray_worker(): return ray.worker._global_node.gcs_address +def wrap_kv_key(job_id, key): + """Add an prefix to the key to avoid conflict with other jobs. + """ + if (type(key) == bytes): + key = key.decode("utf-8") + return f"RAYFED#{job_id}#{key}".encode("utf-8") + + class AbstractInternalKv(abc.ABC): """ An abstract class that represents for bridging Ray internal kv in both Ray client mode and non Ray client mode. @@ -98,7 +106,7 @@ def reset(self): class InternalKv(AbstractInternalKv): """The internal kv class for non Ray client mode. """ - def __init__(self, job_id) -> None: + def __init__(self, job_id:str) -> None: super().__init__() self._job_id = job_id @@ -116,29 +124,18 @@ def initialize(self): return ray_internal_kv._initialize_internal_kv(gcs_client) def put(self, k, v): - content = ray_internal_kv._internal_kv_get(self._job_id) - if content is None: - content = {} - else: - content = cloudpickle.loads(content) - content[k] = v - return ray_internal_kv._internal_kv_put(self._job_id, cloudpickle.dumps(content)) + return ray_internal_kv._internal_kv_put( + wrap_kv_key(self._job_id, k), v) def get(self, k): - content = ray_internal_kv._internal_kv_get(self._job_id) - content = cloudpickle.loads(content) - return content.get(k, None) + return ray_internal_kv._internal_kv_get( + wrap_kv_key(self._job_id, k)) def delete(self, k): - content = self.get(self._job_id) - if k in content: - del content[k] - self.put(self._job_id, content) - return 1 - return 0 + return ray_internal_kv._internal_kv_del( + wrap_kv_key(self._job_id, k)) def reset(self): - ray_internal_kv._internal_kv_del(self._job_id) return ray_internal_kv._internal_kv_reset() def _ping(self): @@ -190,6 +187,8 @@ def _init_internal_kv(job_id): def _clear_internal_kv(): global kv if kv is not None: + kv.delete(constants.KEY_OF_CLUSTER_CONFIG) + kv.delete(constants.KEY_OF_JOB_CONFIG) kv.reset() from ray._private.client_mode_hook import is_client_mode_enabled if is_client_mode_enabled: diff --git a/fed/api.py b/fed/api.py index 4768423b..fe3535dc 100644 --- a/fed/api.py +++ b/fed/api.py @@ -148,7 +148,7 @@ def init( logging_level=logging_level, logging_format=constants.RAYFED_LOG_FMT, date_format=constants.RAYFED_DATE_FMT, - party_val=_get_party(), + party_val=_get_party(job_id), ) logger.info(f'Started rayfed with {cluster_config}') @@ -200,25 +200,25 @@ def shutdown(): logger.info('Shutdowned rayfed.') -def _get_addresses(): +def _get_addresses(job_id: str=None): """ Get the RayFed addresses configration. """ - return fed_config.get_cluster_config().cluster_addresses + return fed_config.get_cluster_config(job_id).cluster_addresses -def _get_party(): +def _get_party(job_id: str=None): """ A private util function to get the current party name. """ - return fed_config.get_cluster_config().current_party + return fed_config.get_cluster_config(job_id).current_party -def _get_tls(): +def _get_tls(job_id: str=None): """ Get the tls configurations on this party. """ - return fed_config.get_cluster_config().tls_config + return fed_config.get_cluster_config(job_id).tls_config class FedRemoteFunction: @@ -272,11 +272,12 @@ def options(self, **options): def remote(self, *cls_args, **cls_kwargs): fed_class_task_id = get_global_context().next_seq_id() + job_id = get_global_context().job_id() fed_actor_handle = FedActorHandle( fed_class_task_id, - _get_addresses(), + _get_addresses(job_id), self._cls, - _get_party(), + _get_party(job_id), self._party, self._options, ) @@ -325,8 +326,9 @@ def get( # A fake fed_task_id for a `fed.get()` operator. This is useful # to help contruct the whole DAG within `fed.get`. fake_fed_task_id = get_global_context().next_seq_id() - addresses = _get_addresses() - current_party = _get_party() + job_id = get_global_context().job_id() + addresses = _get_addresses(job_id) + current_party = _get_party(job_id) is_individual_id = isinstance(fed_objects, FedObject) if is_individual_id: fed_objects = [fed_objects] @@ -381,7 +383,8 @@ def get( def kill(actor: FedActorHandle, *, no_restart=True): - current_party = _get_party() + job_id = get_global_context().job_id() + current_party = _get_party(job_id) if actor._node_party == current_party: handler = actor._actor_handle ray.kill(handler, no_restart=no_restart) diff --git a/fed/config.py b/fed/config.py index 3ad9eb4d..cdabc160 100644 --- a/fed/config.py +++ b/fed/config.py @@ -51,23 +51,23 @@ def cross_silo_message_config(self): _job_config = None -def get_cluster_config(): +def get_cluster_config(job_id: str = None): """This function is not thread safe to use.""" global _cluster_config if _cluster_config is None: - compatible_utils._init_internal_kv() - compatible_utils.kv.initialize() + assert job_id is not None, "Initializing internal kv need to provide job_id." + compatible_utils._init_internal_kv(job_id) raw_dict = compatible_utils.kv.get(fed_constants.KEY_OF_CLUSTER_CONFIG) _cluster_config = ClusterConfig(raw_dict) return _cluster_config -def get_job_config(): +def get_job_config(job_id: str = None): """This config still acts like cluster config for now""" global _job_config if _job_config is None: - compatible_utils._init_internal_kv() - compatible_utils.kv.initialize() + assert job_id is not None, "Initializing internal kv need to provide job_id." + compatible_utils._init_internal_kv(job_id) raw_dict = compatible_utils.kv.get(fed_constants.KEY_OF_JOB_CONFIG) _job_config = JobConfig(raw_dict) return _job_config diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index 0599cef3..28fb1ed5 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -60,6 +60,7 @@ def __init__( self, addresses: Dict, party: str, + job_id: str, tls_config: Dict = None, logging_level: str = None, proxy_cls=None @@ -74,8 +75,9 @@ def __init__( self._stats = {"send_op_count": 0} self._addresses = addresses self._party = party + self._job_id = job_id self._tls_config = tls_config - job_config = fed_config.get_job_config() + job_config = fed_config.get_job_config(job_id) cross_silo_message_config = job_config.cross_silo_message_config self._proxy_instance: SenderProxy = proxy_cls( addresses, party, tls_config, cross_silo_message_config) @@ -128,6 +130,7 @@ def __init__( self, listening_address: str, party: str, + job_id: str, logging_level: str, tls_config=None, proxy_cls=None, @@ -141,8 +144,9 @@ def __init__( self._stats = {"receive_op_count": 0} self._listening_address = listening_address self._party = party + self._job_id = job_id self._tls_config = tls_config - job_config = fed_config.get_job_config() + job_config = fed_config.get_job_config(job_id) cross_silo_message_config = job_config.cross_silo_message_config self._proxy_instance: ReceiverProxy = proxy_cls( listening_address, party, tls_config, cross_silo_message_config) @@ -200,6 +204,7 @@ def _start_receiver_proxy( ).remote( listening_address=listening_address, party=party, + job_id=get_global_context().job_id(), tls_config=tls_config, logging_level=logging_level, proxy_cls=proxy_cls @@ -241,14 +246,16 @@ def _start_sender_proxy( _SENDER_PROXY_ACTOR = SenderProxyActor.options( name="SenderProxyActor", **actor_options) + job_id = get_global_context().job_id() _SENDER_PROXY_ACTOR = _SENDER_PROXY_ACTOR.remote( addresses=addresses, party=party, + job_id=job_id, tls_config=tls_config, logging_level=logging_level, proxy_cls=proxy_cls ) - timeout = get_job_config().cross_silo_message_config.timeout_in_ms / 1000 + timeout = get_job_config(job_id).cross_silo_message_config.timeout_in_ms / 1000 assert ray.get(_SENDER_PROXY_ACTOR.is_ready.remote(), timeout=timeout) logger.info("SenderProxyActor has successfully created.") diff --git a/tests/test_internal_kv.py b/tests/test_internal_kv.py index f2b5372b..326b6ab0 100644 --- a/tests/test_internal_kv.py +++ b/tests/test_internal_kv.py @@ -4,7 +4,7 @@ import fed import time import fed._private.compatible_utils as compatible_utils - +import ray.experimental.internal_kv as ray_internal_kv def run(party): compatible_utils.init_ray("local") @@ -13,11 +13,16 @@ def run(party): 'bob': '127.0.0.1:11011', } assert compatible_utils.kv is None - fed.init(addresses=addresses, party=party) + fed.init(addresses=addresses, party=party, job_id="test_job_id") assert compatible_utils.kv assert not compatible_utils.kv.put(b"test_key", b"test_val") assert compatible_utils.kv.get(b"test_key") == b"test_val" + # Test that a prefix key name is added under the hood. + assert ray_internal_kv._internal_kv_get(b"test_key") is None + assert ray_internal_kv._internal_kv_get( + b"RAYFED#test_job_id#test_key") == b"test_val" + time.sleep(5) fed.shutdown() From 98b348e2d86c161ebe66a8952905153f651abd8c Mon Sep 17 00:00:00 2001 From: paer Date: Wed, 26 Jul 2023 17:42:02 +0800 Subject: [PATCH 03/16] send/recvmsg with job_id Signed-off-by: paer --- fed/grpc/fed.proto | 1 + fed/proxy/barriers.py | 4 ++-- fed/proxy/base_proxy.py | 4 ++++ fed/proxy/grpc/grpc_proxy.py | 33 ++++++++++++++++++++++++++------- 4 files changed, 33 insertions(+), 9 deletions(-) diff --git a/fed/grpc/fed.proto b/fed/grpc/fed.proto index f19a8ca2..2dcda338 100644 --- a/fed/grpc/fed.proto +++ b/fed/grpc/fed.proto @@ -10,6 +10,7 @@ message SendDataRequest { bytes data = 1; string upstream_seq_id = 2; string downstream_seq_id = 3; + string job_id = 4; }; message SendDataResponse { diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index 28fb1ed5..0ca32b6d 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -80,7 +80,7 @@ def __init__( job_config = fed_config.get_job_config(job_id) cross_silo_message_config = job_config.cross_silo_message_config self._proxy_instance: SenderProxy = proxy_cls( - addresses, party, tls_config, cross_silo_message_config) + addresses, party, job_id, tls_config, cross_silo_message_config) async def is_ready(self): res = await self._proxy_instance.is_ready() @@ -149,7 +149,7 @@ def __init__( job_config = fed_config.get_job_config(job_id) cross_silo_message_config = job_config.cross_silo_message_config self._proxy_instance: ReceiverProxy = proxy_cls( - listening_address, party, tls_config, cross_silo_message_config) + listening_address, party, job_id, tls_config, cross_silo_message_config) async def start(self): await self._proxy_instance.start() diff --git a/fed/proxy/base_proxy.py b/fed/proxy/base_proxy.py index ecd4a336..c98a1fc1 100644 --- a/fed/proxy/base_proxy.py +++ b/fed/proxy/base_proxy.py @@ -23,6 +23,7 @@ def __init__( self, addresses: Dict, party: str, + job_id: str, tls_config: Dict, proxy_config: CrossSiloMessageConfig = None ) -> None: @@ -30,6 +31,7 @@ def __init__( self._party = party self._tls_config = tls_config self._proxy_config = proxy_config + self._job_id = job_id @abc.abstractmethod async def send( @@ -53,6 +55,7 @@ def __init__( self, listen_addr: str, party: str, + job_id: str, tls_config: Dict, proxy_config: CrossSiloMessageConfig = None ) -> None: @@ -60,6 +63,7 @@ def __init__( self._party = party self._tls_config = tls_config self._proxy_config = proxy_config + self._job_id = job_id @abc.abstractmethod def start(self): diff --git a/fed/proxy/grpc/grpc_proxy.py b/fed/proxy/grpc/grpc_proxy.py index 50ad372f..4911bda0 100644 --- a/fed/proxy/grpc/grpc_proxy.py +++ b/fed/proxy/grpc/grpc_proxy.py @@ -21,7 +21,6 @@ import json from typing import Dict - import fed.utils as fed_utils from fed.config import CrossSiloMessageConfig, GrpcCrossSiloMessageConfig @@ -89,10 +88,11 @@ def __init__( self, cluster: Dict, party: str, + job_id: str, tls_config: Dict, proxy_config: CrossSiloMessageConfig = None ) -> None: - super().__init__(cluster, party, tls_config, proxy_config) + super().__init__(cluster, party, job_id, tls_config, proxy_config) self._grpc_metadata = proxy_config.http_header or {} self._grpc_options = copy.deepcopy(_DEFAULT_GRPC_CHANNEL_OPTIONS) self._grpc_options.update(parse_grpc_options(self._proxy_config)) @@ -131,6 +131,7 @@ async def send( stub=self._stubs[dest_party], upstream_seq_id=upstream_seq_id, downstream_seq_id=downstream_seq_id, + job_id=self._job_id, timeout=timeout, metadata=grpc_metadata, ) @@ -172,6 +173,7 @@ async def send_data_grpc( upstream_seq_id, downstream_seq_id, timeout, + job_id, metadata=None, ): data = cloudpickle.dumps(data) @@ -179,6 +181,7 @@ async def send_data_grpc( data=data, upstream_seq_id=str(upstream_seq_id), downstream_seq_id=str(downstream_seq_id), + job_id=job_id, ) # Waiting for the reply from downstream. response = await stub.SendData( @@ -198,10 +201,11 @@ def __init__( self, listen_addr: str, party: str, + job_id: str, tls_config: Dict, proxy_config: CrossSiloMessageConfig ) -> None: - super().__init__(listen_addr, party, tls_config, proxy_config) + super().__init__(listen_addr, party, job_id, tls_config, proxy_config) self._grpc_options = copy.deepcopy(_DEFAULT_GRPC_CHANNEL_OPTIONS) self._grpc_options.update(parse_grpc_options(self._proxy_config)) @@ -216,12 +220,23 @@ def __init__( async def start(self): port = self._listen_addr[self._listen_addr.index(':') + 1 :] try: + print(f"[Debug] params list: " + f"port: {port}, " + f"self._events: {self._events}, " + f"self._all_data: {self._all_data}, " + f"self._party: {self._party}, " + f"self._lock: {self._lock}, " + f"self._job_id: {self._job_id}, " + f"self._server_ready_future: {self._server_ready_future}, " + f"self._tls_config: {self._tls_config}, " + f"self._grpc_options: {fed_utils.dict2tuple(self._grpc_options)}") await _run_grpc_server( port, self._events, self._all_data, self._party, self._lock, + self._job_id, self._server_ready_future, self._tls_config, fed_utils.dict2tuple(self._grpc_options), @@ -266,13 +281,17 @@ async def get_proxy_config(self): class SendDataService(fed_pb2_grpc.GrpcServiceServicer): - def __init__(self, all_events, all_data, party, lock): + def __init__(self, all_events, all_data, party, lock, job_id): self._events = all_events self._all_data = all_data self._party = party self._lock = lock + self._job_id = job_id - async def SendData(self, request, context): + async def SendData(self, request): + job_id = request.job_id + if job_id != self._job_id: + return fed_pb2.SendDataResponse(result="ERROR") upstream_seq_id = request.upstream_seq_id downstream_seq_id = request.downstream_seq_id logger.debug( @@ -298,13 +317,13 @@ async def SendData(self, request, context): async def _run_grpc_server( - port, event, all_data, party, lock, + port, event, all_data, party, lock, job_id, server_ready_future, tls_config=None, grpc_options=None ): print(f"ReceiveProxy binding port {port}, options: {grpc_options}...") server = grpc.aio.server(options=grpc_options) fed_pb2_grpc.add_GrpcServiceServicer_to_server( - SendDataService(event, all_data, party, lock), server + SendDataService(event, all_data, party, lock, job_id), server ) tls_enabled = fed_utils.tls_enabled(tls_config) From 30494d8bf85309efcf2a4b7fd06c036a6bf6d869 Mon Sep 17 00:00:00 2001 From: paer Date: Fri, 28 Jul 2023 17:32:28 +0800 Subject: [PATCH 04/16] retrieve cluster_config by id for normal task Signed-off-by: paer --- fed/_private/fed_call_holder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fed/_private/fed_call_holder.py b/fed/_private/fed_call_holder.py index 1fc339e9..b4996005 100644 --- a/fed/_private/fed_call_holder.py +++ b/fed/_private/fed_call_holder.py @@ -46,7 +46,8 @@ def __init__( submit_ray_task_func, options={}, ) -> None: - self._party = fed_config.get_cluster_config().current_party + job_id = get_global_context().job_id() + self._party = fed_config.get_cluster_config(job_id).current_party self._node_party = node_party self._options = options self._submit_ray_task_func = submit_ray_task_func From cb589758ceeaff59d370ffb77752ac60dcfd1e7f Mon Sep 17 00:00:00 2001 From: paer Date: Wed, 9 Aug 2023 14:16:08 +0800 Subject: [PATCH 05/16] job_id in proto --- fed/grpc/pb3/__init__.py | 13 ++++ fed/grpc/pb3/fed_pb2.py | 147 ++++------------------------------- fed/grpc/pb3/fed_pb2_grpc.py | 14 ---- fed/grpc/pb4/__init__.py | 13 ++++ fed/proxy/barriers.py | 2 +- fed/proxy/grpc/grpc_proxy.py | 2 +- 6 files changed, 43 insertions(+), 148 deletions(-) create mode 100644 fed/grpc/pb3/__init__.py create mode 100644 fed/grpc/pb4/__init__.py diff --git a/fed/grpc/pb3/__init__.py b/fed/grpc/pb3/__init__.py new file mode 100644 index 00000000..1d7f5f9e --- /dev/null +++ b/fed/grpc/pb3/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The RayFed Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/fed/grpc/pb3/fed_pb2.py b/fed/grpc/pb3/fed_pb2.py index ffc235ee..04181276 100644 --- a/fed/grpc/pb3/fed_pb2.py +++ b/fed/grpc/pb3/fed_pb2.py @@ -1,22 +1,9 @@ -# Copyright 2023 The RayFed Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: fed.proto """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database @@ -27,99 +14,12 @@ -DESCRIPTOR = _descriptor.FileDescriptor( - name='fed.proto', - package='', - syntax='proto3', - serialized_options=b'\200\001\001', - create_key=_descriptor._internal_create_key, - serialized_pb=b'\n\tfed.proto\"S\n\x0fSendDataRequest\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x17\n\x0fupstream_seq_id\x18\x02 \x01(\t\x12\x19\n\x11\x64ownstream_seq_id\x18\x03 \x01(\t\"\"\n\x10SendDataResponse\x12\x0e\n\x06result\x18\x01 \x01(\t2@\n\x0bGrpcService\x12\x31\n\x08SendData\x12\x10.SendDataRequest\x1a\x11.SendDataResponse\"\x00\x42\x03\x80\x01\x01\x62\x06proto3' -) - - - - -_SENDDATAREQUEST = _descriptor.Descriptor( - name='SendDataRequest', - full_name='SendDataRequest', - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name='data', full_name='SendDataRequest.data', index=0, - number=1, type=12, cpp_type=9, label=1, - has_default_value=False, default_value=b"", - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='upstream_seq_id', full_name='SendDataRequest.upstream_seq_id', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='downstream_seq_id', full_name='SendDataRequest.downstream_seq_id', index=2, - number=3, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=13, - serialized_end=96, -) +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tfed.proto\"c\n\x0fSendDataRequest\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x17\n\x0fupstream_seq_id\x18\x02 \x01(\t\x12\x19\n\x11\x64ownstream_seq_id\x18\x03 \x01(\t\x12\x0e\n\x06job_id\x18\x04 \x01(\t\"\"\n\x10SendDataResponse\x12\x0e\n\x06result\x18\x01 \x01(\t2@\n\x0bGrpcService\x12\x31\n\x08SendData\x12\x10.SendDataRequest\x1a\x11.SendDataResponse\"\x00\x42\x03\x80\x01\x01\x62\x06proto3') -_SENDDATARESPONSE = _descriptor.Descriptor( - name='SendDataResponse', - full_name='SendDataResponse', - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name='result', full_name='SendDataResponse.result', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=98, - serialized_end=132, -) - -DESCRIPTOR.message_types_by_name['SendDataRequest'] = _SENDDATAREQUEST -DESCRIPTOR.message_types_by_name['SendDataResponse'] = _SENDDATARESPONSE -_sym_db.RegisterFileDescriptor(DESCRIPTOR) +_SENDDATAREQUEST = DESCRIPTOR.message_types_by_name['SendDataRequest'] +_SENDDATARESPONSE = DESCRIPTOR.message_types_by_name['SendDataResponse'] SendDataRequest = _reflection.GeneratedProtocolMessageType('SendDataRequest', (_message.Message,), { 'DESCRIPTOR' : _SENDDATAREQUEST, '__module__' : 'fed_pb2' @@ -134,32 +34,15 @@ }) _sym_db.RegisterMessage(SendDataResponse) - -DESCRIPTOR._options = None - -_GRPCSERVICE = _descriptor.ServiceDescriptor( - name='GrpcService', - full_name='GrpcService', - file=DESCRIPTOR, - index=0, - serialized_options=None, - create_key=_descriptor._internal_create_key, - serialized_start=134, - serialized_end=198, - methods=[ - _descriptor.MethodDescriptor( - name='SendData', - full_name='GrpcService.SendData', - index=0, - containing_service=None, - input_type=_SENDDATAREQUEST, - output_type=_SENDDATARESPONSE, - serialized_options=None, - create_key=_descriptor._internal_create_key, - ), -]) -_sym_db.RegisterServiceDescriptor(_GRPCSERVICE) - -DESCRIPTOR.services_by_name['GrpcService'] = _GRPCSERVICE - +_GRPCSERVICE = DESCRIPTOR.services_by_name['GrpcService'] +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\200\001\001' + _SENDDATAREQUEST._serialized_start=13 + _SENDDATAREQUEST._serialized_end=112 + _SENDDATARESPONSE._serialized_start=114 + _SENDDATARESPONSE._serialized_end=148 + _GRPCSERVICE._serialized_start=150 + _GRPCSERVICE._serialized_end=214 # @@protoc_insertion_point(module_scope) diff --git a/fed/grpc/pb3/fed_pb2_grpc.py b/fed/grpc/pb3/fed_pb2_grpc.py index 55c674d4..830c8b88 100644 --- a/fed/grpc/pb3/fed_pb2_grpc.py +++ b/fed/grpc/pb3/fed_pb2_grpc.py @@ -1,17 +1,3 @@ -# Copyright 2023 The RayFed Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc diff --git a/fed/grpc/pb4/__init__.py b/fed/grpc/pb4/__init__.py new file mode 100644 index 00000000..1d7f5f9e --- /dev/null +++ b/fed/grpc/pb4/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The RayFed Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index 07d2657a..82c805f0 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -285,7 +285,7 @@ def _start_sender_proxy( logging_level=logging_level, proxy_cls=proxy_cls, ) - timeout = fed_config.get_job_config(job_id).cross_silo_message_config.timeout_in_ms / 1000 + # timeout = fed_config.get_job_config(job_id).cross_silo_comm_config_dict.timeout_in_ms / 1000 # assert ray.get(_SENDER_PROXY_ACTOR.is_ready.remote(), timeout=timeout) assert ray.get(_SENDER_PROXY_ACTOR.is_ready.remote(), timeout=ready_timeout_second) logger.info("SenderProxyActor has successfully created.") diff --git a/fed/proxy/grpc/grpc_proxy.py b/fed/proxy/grpc/grpc_proxy.py index eb058b7e..680f6df5 100644 --- a/fed/proxy/grpc/grpc_proxy.py +++ b/fed/proxy/grpc/grpc_proxy.py @@ -290,7 +290,7 @@ def __init__(self, all_events, all_data, party, lock, job_id): self._lock = lock self._job_id = job_id - async def SendData(self, request): + async def SendData(self, request, context): job_id = request.job_id if job_id != self._job_id: return fed_pb2.SendDataResponse(result="ERROR") From e04b7c39278ecf2c7c40bb9486b0cc1d1eb3f842 Mon Sep 17 00:00:00 2001 From: paer Date: Wed, 9 Aug 2023 14:18:07 +0800 Subject: [PATCH 06/16] rename to job_name --- fed/_private/compatible_utils.py | 20 ++++++++--------- fed/_private/fed_call_holder.py | 4 ++-- fed/_private/global_context.py | 12 +++++----- fed/api.py | 38 ++++++++++++++++---------------- fed/config.py | 12 +++++----- fed/grpc/fed.proto | 2 +- fed/grpc/pb3/fed_pb2.py | 12 +++++----- fed/proxy/barriers.py | 24 ++++++++++---------- fed/proxy/base_proxy.py | 8 +++---- fed/proxy/grpc/grpc_proxy.py | 30 ++++++++++++------------- tests/test_internal_kv.py | 4 ++-- 11 files changed, 83 insertions(+), 83 deletions(-) diff --git a/fed/_private/compatible_utils.py b/fed/_private/compatible_utils.py index b5941a17..13de44dc 100644 --- a/fed/_private/compatible_utils.py +++ b/fed/_private/compatible_utils.py @@ -67,12 +67,12 @@ def _get_gcs_address_from_ray_worker(): return ray.worker._global_node.gcs_address -def wrap_kv_key(job_id, key): +def wrap_kv_key(job_name, key): """Add an prefix to the key to avoid conflict with other jobs. """ if (type(key) == bytes): key = key.decode("utf-8") - return f"RAYFED#{job_id}#{key}".encode("utf-8") + return f"RAYFED#{job_name}#{key}".encode("utf-8") class AbstractInternalKv(abc.ABC): @@ -106,9 +106,9 @@ def reset(self): class InternalKv(AbstractInternalKv): """The internal kv class for non Ray client mode. """ - def __init__(self, job_id:str) -> None: + def __init__(self, job_name:str) -> None: super().__init__() - self._job_id = job_id + self._job_name = job_name def initialize(self): try: @@ -125,15 +125,15 @@ def initialize(self): def put(self, k, v): return ray_internal_kv._internal_kv_put( - wrap_kv_key(self._job_id, k), v) + wrap_kv_key(self._job_name, k), v) def get(self, k): return ray_internal_kv._internal_kv_get( - wrap_kv_key(self._job_id, k)) + wrap_kv_key(self._job_name, k)) def delete(self, k): return ray_internal_kv._internal_kv_del( - wrap_kv_key(self._job_id, k)) + wrap_kv_key(self._job_name, k)) def reset(self): return ray_internal_kv._internal_kv_reset() @@ -170,17 +170,17 @@ def reset(self): return ray.get(o) -def _init_internal_kv(job_id): +def _init_internal_kv(job_name): """An internal API that initialize the internal kv object.""" global kv if kv is None: from ray._private.client_mode_hook import is_client_mode_enabled if is_client_mode_enabled: kv_actor = ray.remote(InternalKv).options( - name="_INTERNAL_KV_ACTOR").remote(job_id) + name="_INTERNAL_KV_ACTOR").remote(job_name) response = kv_actor._ping.remote() ray.get(response) - kv = ClientModeInternalKv() if is_client_mode_enabled else InternalKv(job_id) + kv = ClientModeInternalKv() if is_client_mode_enabled else InternalKv(job_name) kv.initialize() diff --git a/fed/_private/fed_call_holder.py b/fed/_private/fed_call_holder.py index b4996005..376f6214 100644 --- a/fed/_private/fed_call_holder.py +++ b/fed/_private/fed_call_holder.py @@ -46,8 +46,8 @@ def __init__( submit_ray_task_func, options={}, ) -> None: - job_id = get_global_context().job_id() - self._party = fed_config.get_cluster_config(job_id).current_party + job_name = get_global_context().job_name() + self._party = fed_config.get_cluster_config(job_name).current_party self._node_party = node_party self._options = options self._submit_ray_task_func = submit_ray_task_func diff --git a/fed/_private/global_context.py b/fed/_private/global_context.py index a465ed07..673bd00d 100644 --- a/fed/_private/global_context.py +++ b/fed/_private/global_context.py @@ -16,8 +16,8 @@ class GlobalContext: - def __init__(self, job_id: str) -> None: - self._job_id = job_id + def __init__(self, job_name: str) -> None: + self._job_name = job_name self._seq_count = 0 self._cleanup_manager = CleanupManager() @@ -28,17 +28,17 @@ def next_seq_id(self) -> int: def get_cleanup_manager(self) -> CleanupManager: return self._cleanup_manager - def job_id(self) -> str: - return self._job_id + def job_name(self) -> str: + return self._job_name _global_context = None -def init_global_context(job_id: str) -> None: +def init_global_context(job_name: str) -> None: global _global_context if _global_context is None: - _global_context = GlobalContext(job_id) + _global_context = GlobalContext(job_name) def get_global_context(): diff --git a/fed/api.py b/fed/api.py index 2de6e6fa..051cce4a 100644 --- a/fed/api.py +++ b/fed/api.py @@ -58,7 +58,7 @@ def init( sender_proxy_cls: SenderProxy = None, receiver_proxy_cls: ReceiverProxy = None, receiver_sender_proxy_cls: SenderReceiverProxy = None, - job_id: str = 'anonymous' + job_name: str = 'anonymous' ): """ Initialize a RayFed client. @@ -96,7 +96,7 @@ def init( } logging_level: optional; the logging level, could be `debug`, `info`, `warning`, `error`, `critical`, not case sensititive. - job_id: optional; the job id of the current job. Note that, the job id + job_name: optional; the job id of the current job. Note that, the job id must be identical in all parties, otherwise, messages will be ignored because of the job id mismatch. If the job id is not provided, messages of this job will not be distinguished from other jobs, which should @@ -118,7 +118,7 @@ def init( assert party in addresses, f"Party {party} is not in the addresses {addresses}." fed_utils.validate_addresses(addresses) - init_global_context(job_id=job_id) + init_global_context(job_name=job_name) tls_config = {} if tls_config is None else tls_config if tls_config: assert ( @@ -127,7 +127,7 @@ def init( cross_silo_comm_dict = config.get("cross_silo_comm", {}) # A Ray private accessing, should be replaced in public API. - compatible_utils._init_internal_kv(job_id) + compatible_utils._init_internal_kv(job_name) cluster_config = { constants.KEY_OF_CLUSTER_ADDRESSES: addresses, @@ -150,7 +150,7 @@ def init( logging_level=logging_level, logging_format=constants.RAYFED_LOG_FMT, date_format=constants.RAYFED_DATE_FMT, - party_val=_get_party(job_id), + party_val=_get_party(job_name), ) logger.info(f'Started rayfed with {cluster_config}') @@ -224,25 +224,25 @@ def shutdown(): logger.info('Shutdowned rayfed.') -def _get_addresses(job_id: str=None): +def _get_addresses(job_name: str=None): """ Get the RayFed addresses configration. """ - return fed_config.get_cluster_config(job_id).cluster_addresses + return fed_config.get_cluster_config(job_name).cluster_addresses -def _get_party(job_id: str=None): +def _get_party(job_name: str=None): """ A private util function to get the current party name. """ - return fed_config.get_cluster_config(job_id).current_party + return fed_config.get_cluster_config(job_name).current_party -def _get_tls(job_id: str=None): +def _get_tls(job_name: str=None): """ Get the tls configurations on this party. """ - return fed_config.get_cluster_config(job_id).tls_config + return fed_config.get_cluster_config(job_name).tls_config class FedRemoteFunction: @@ -296,12 +296,12 @@ def options(self, **options): def remote(self, *cls_args, **cls_kwargs): fed_class_task_id = get_global_context().next_seq_id() - job_id = get_global_context().job_id() + job_name = get_global_context().job_name() fed_actor_handle = FedActorHandle( fed_class_task_id, - _get_addresses(job_id), + _get_addresses(job_name), self._cls, - _get_party(job_id), + _get_party(job_name), self._party, self._options, ) @@ -350,9 +350,9 @@ def get( # A fake fed_task_id for a `fed.get()` operator. This is useful # to help contruct the whole DAG within `fed.get`. fake_fed_task_id = get_global_context().next_seq_id() - job_id = get_global_context().job_id() - addresses = _get_addresses(job_id) - current_party = _get_party(job_id) + job_name = get_global_context().job_name() + addresses = _get_addresses(job_name) + current_party = _get_party(job_name) is_individual_id = isinstance(fed_objects, FedObject) if is_individual_id: fed_objects = [fed_objects] @@ -407,8 +407,8 @@ def get( def kill(actor: FedActorHandle, *, no_restart=True): - job_id = get_global_context().job_id() - current_party = _get_party(job_id) + job_name = get_global_context().job_name() + current_party = _get_party(job_name) if actor._node_party == current_party: handler = actor._actor_handle ray.kill(handler, no_restart=no_restart) diff --git a/fed/config.py b/fed/config.py index 6c95aa01..509fb7b7 100644 --- a/fed/config.py +++ b/fed/config.py @@ -48,23 +48,23 @@ def cross_silo_comm_config_dict(self) -> Dict: _job_config = None -def get_cluster_config(job_id: str = None): +def get_cluster_config(job_name: str = None): """This function is not thread safe to use.""" global _cluster_config if _cluster_config is None: - assert job_id is not None, "Initializing internal kv need to provide job_id." - compatible_utils._init_internal_kv(job_id) + assert job_name is not None, "Initializing internal kv need to provide job_name." + compatible_utils._init_internal_kv(job_name) raw_dict = compatible_utils.kv.get(fed_constants.KEY_OF_CLUSTER_CONFIG) _cluster_config = ClusterConfig(raw_dict) return _cluster_config -def get_job_config(job_id: str = None): +def get_job_config(job_name: str = None): """This config still acts like cluster config for now""" global _job_config if _job_config is None: - assert job_id is not None, "Initializing internal kv need to provide job_id." - compatible_utils._init_internal_kv(job_id) + assert job_name is not None, "Initializing internal kv need to provide job_name." + compatible_utils._init_internal_kv(job_name) raw_dict = compatible_utils.kv.get(fed_constants.KEY_OF_JOB_CONFIG) _job_config = JobConfig(raw_dict) return _job_config diff --git a/fed/grpc/fed.proto b/fed/grpc/fed.proto index 2dcda338..d0e2ee64 100644 --- a/fed/grpc/fed.proto +++ b/fed/grpc/fed.proto @@ -10,7 +10,7 @@ message SendDataRequest { bytes data = 1; string upstream_seq_id = 2; string downstream_seq_id = 3; - string job_id = 4; + string job_name = 4; }; message SendDataResponse { diff --git a/fed/grpc/pb3/fed_pb2.py b/fed/grpc/pb3/fed_pb2.py index 04181276..28669ff7 100644 --- a/fed/grpc/pb3/fed_pb2.py +++ b/fed/grpc/pb3/fed_pb2.py @@ -14,7 +14,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tfed.proto\"c\n\x0fSendDataRequest\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x17\n\x0fupstream_seq_id\x18\x02 \x01(\t\x12\x19\n\x11\x64ownstream_seq_id\x18\x03 \x01(\t\x12\x0e\n\x06job_id\x18\x04 \x01(\t\"\"\n\x10SendDataResponse\x12\x0e\n\x06result\x18\x01 \x01(\t2@\n\x0bGrpcService\x12\x31\n\x08SendData\x12\x10.SendDataRequest\x1a\x11.SendDataResponse\"\x00\x42\x03\x80\x01\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tfed.proto\"e\n\x0fSendDataRequest\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x17\n\x0fupstream_seq_id\x18\x02 \x01(\t\x12\x19\n\x11\x64ownstream_seq_id\x18\x03 \x01(\t\x12\x10\n\x08job_name\x18\x04 \x01(\t\"\"\n\x10SendDataResponse\x12\x0e\n\x06result\x18\x01 \x01(\t2@\n\x0bGrpcService\x12\x31\n\x08SendData\x12\x10.SendDataRequest\x1a\x11.SendDataResponse\"\x00\x42\x03\x80\x01\x01\x62\x06proto3') @@ -40,9 +40,9 @@ DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b'\200\001\001' _SENDDATAREQUEST._serialized_start=13 - _SENDDATAREQUEST._serialized_end=112 - _SENDDATARESPONSE._serialized_start=114 - _SENDDATARESPONSE._serialized_end=148 - _GRPCSERVICE._serialized_start=150 - _GRPCSERVICE._serialized_end=214 + _SENDDATAREQUEST._serialized_end=114 + _SENDDATARESPONSE._serialized_start=116 + _SENDDATARESPONSE._serialized_end=150 + _GRPCSERVICE._serialized_start=152 + _GRPCSERVICE._serialized_end=216 # @@protoc_insertion_point(module_scope) diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index 82c805f0..05e43f28 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -82,7 +82,7 @@ def __init__( self, addresses: Dict, party: str, - job_id: str, + job_name: str, tls_config: Dict = None, logging_level: str = None, proxy_cls=None, @@ -97,12 +97,12 @@ def __init__( self._stats = {"send_op_count": 0} self._addresses = addresses self._party = party - self._job_id = job_id + self._job_name = job_name self._tls_config = tls_config - job_config = fed_config.get_job_config(job_id) + job_config = fed_config.get_job_config(job_name) cross_silo_comm_config = job_config.cross_silo_comm_config_dict self._proxy_instance: SenderProxy = proxy_cls( - addresses, party, job_id, tls_config, cross_silo_comm_config) + addresses, party, job_name, tls_config, cross_silo_comm_config) async def is_ready(self): res = await self._proxy_instance.is_ready() @@ -153,7 +153,7 @@ def __init__( self, listening_address: str, party: str, - job_id: str, + job_name: str, logging_level: str, tls_config=None, proxy_cls=None, @@ -167,12 +167,12 @@ def __init__( self._stats = {"receive_op_count": 0} self._listening_address = listening_address self._party = party - self._job_id = job_id + self._job_name = job_name self._tls_config = tls_config - job_config = fed_config.get_job_config(job_id) + job_config = fed_config.get_job_config(job_name) cross_silo_comm_config = job_config.cross_silo_comm_config_dict self._proxy_instance: ReceiverProxy = proxy_cls( - listening_address, party, job_id, tls_config, cross_silo_comm_config) + listening_address, party, job_name, tls_config, cross_silo_comm_config) async def start(self): await self._proxy_instance.start() @@ -225,7 +225,7 @@ def _start_receiver_proxy( ).remote( listening_address=addresses[party], party=party, - job_id=get_global_context().job_id(), + job_name=get_global_context().job_name(), tls_config=tls_config, logging_level=logging_level, proxy_cls=proxy_cls, @@ -276,16 +276,16 @@ def _start_sender_proxy( name=_SENDER_PROXY_ACTOR_NAME, **actor_options ) - job_id = get_global_context().job_id() + job_name = get_global_context().job_name() _SENDER_PROXY_ACTOR = _SENDER_PROXY_ACTOR.remote( addresses=addresses, party=party, - job_id=job_id, + job_name=job_name, tls_config=tls_config, logging_level=logging_level, proxy_cls=proxy_cls, ) - # timeout = fed_config.get_job_config(job_id).cross_silo_comm_config_dict.timeout_in_ms / 1000 + # timeout = fed_config.get_job_config(job_name).cross_silo_comm_config_dict.timeout_in_ms / 1000 # assert ray.get(_SENDER_PROXY_ACTOR.is_ready.remote(), timeout=timeout) assert ray.get(_SENDER_PROXY_ACTOR.is_ready.remote(), timeout=ready_timeout_second) logger.info("SenderProxyActor has successfully created.") diff --git a/fed/proxy/base_proxy.py b/fed/proxy/base_proxy.py index cf6e6fb1..b2eba265 100644 --- a/fed/proxy/base_proxy.py +++ b/fed/proxy/base_proxy.py @@ -23,7 +23,7 @@ def __init__( self, addresses: Dict, party: str, - job_id: str, + job_name: str, tls_config: Dict, proxy_config: CrossSiloMessageConfig = None, ) -> None: @@ -31,7 +31,7 @@ def __init__( self._party = party self._tls_config = tls_config self._proxy_config = proxy_config - self._job_id = job_id + self._job_name = job_name @abc.abstractmethod async def send(self, dest_party, data, upstream_seq_id, downstream_seq_id): @@ -49,7 +49,7 @@ def __init__( self, listen_addr: str, party: str, - job_id: str, + job_name: str, tls_config: Dict, proxy_config: CrossSiloMessageConfig = None ) -> None: @@ -57,7 +57,7 @@ def __init__( self._party = party self._tls_config = tls_config self._proxy_config = proxy_config - self._job_id = job_id + self._job_name = job_name @abc.abstractmethod def start(self): diff --git a/fed/proxy/grpc/grpc_proxy.py b/fed/proxy/grpc/grpc_proxy.py index 680f6df5..bb89b740 100644 --- a/fed/proxy/grpc/grpc_proxy.py +++ b/fed/proxy/grpc/grpc_proxy.py @@ -88,12 +88,12 @@ def __init__( self, cluster: Dict, party: str, - job_id: str, + job_name: str, tls_config: Dict, proxy_config: Dict = None ) -> None: proxy_config = GrpcCrossSiloMessageConfig.from_dict(proxy_config) - super().__init__(cluster, party, job_id, tls_config, proxy_config) + super().__init__(cluster, party, job_name, tls_config, proxy_config) self._grpc_metadata = proxy_config.http_header or {} self._grpc_options = copy.deepcopy(_DEFAULT_GRPC_CHANNEL_OPTIONS) self._grpc_options.update(parse_grpc_options(self._proxy_config)) @@ -132,7 +132,7 @@ async def send( stub=self._stubs[dest_party], upstream_seq_id=upstream_seq_id, downstream_seq_id=downstream_seq_id, - job_id=self._job_id, + job_name=self._job_name, timeout=timeout, metadata=grpc_metadata, ) @@ -174,7 +174,7 @@ async def send_data_grpc( upstream_seq_id, downstream_seq_id, timeout, - job_id, + job_name, metadata=None, ): data = cloudpickle.dumps(data) @@ -182,7 +182,7 @@ async def send_data_grpc( data=data, upstream_seq_id=str(upstream_seq_id), downstream_seq_id=str(downstream_seq_id), - job_id=job_id, + job_name=job_name, ) # Waiting for the reply from downstream. response = await stub.SendData( @@ -202,12 +202,12 @@ def __init__( self, listen_addr: str, party: str, - job_id: str, + job_name: str, tls_config: Dict, proxy_config: Dict ) -> None: proxy_config = GrpcCrossSiloMessageConfig.from_dict(proxy_config) - super().__init__(listen_addr, party, job_id, tls_config, proxy_config) + super().__init__(listen_addr, party, job_name, tls_config, proxy_config) self._grpc_options = copy.deepcopy(_DEFAULT_GRPC_CHANNEL_OPTIONS) self._grpc_options.update(parse_grpc_options(self._proxy_config)) @@ -228,7 +228,7 @@ async def start(self): f"self._all_data: {self._all_data}, " f"self._party: {self._party}, " f"self._lock: {self._lock}, " - f"self._job_id: {self._job_id}, " + f"self._job_name: {self._job_name}, " f"self._server_ready_future: {self._server_ready_future}, " f"self._tls_config: {self._tls_config}, " f"self._grpc_options: {fed_utils.dict2tuple(self._grpc_options)}") @@ -238,7 +238,7 @@ async def start(self): self._all_data, self._party, self._lock, - self._job_id, + self._job_name, self._server_ready_future, self._tls_config, fed_utils.dict2tuple(self._grpc_options), @@ -283,16 +283,16 @@ async def get_proxy_config(self): class SendDataService(fed_pb2_grpc.GrpcServiceServicer): - def __init__(self, all_events, all_data, party, lock, job_id): + def __init__(self, all_events, all_data, party, lock, job_name): self._events = all_events self._all_data = all_data self._party = party self._lock = lock - self._job_id = job_id + self._job_name = job_name async def SendData(self, request, context): - job_id = request.job_id - if job_id != self._job_id: + job_name = request.job_name + if job_name != self._job_name: return fed_pb2.SendDataResponse(result="ERROR") upstream_seq_id = request.upstream_seq_id downstream_seq_id = request.downstream_seq_id @@ -319,13 +319,13 @@ async def SendData(self, request, context): async def _run_grpc_server( - port, event, all_data, party, lock, job_id, + port, event, all_data, party, lock, job_name, server_ready_future, tls_config=None, grpc_options=None ): logger.info(f"ReceiveProxy binding port {port}, options: {grpc_options}...") server = grpc.aio.server(options=grpc_options) fed_pb2_grpc.add_GrpcServiceServicer_to_server( - SendDataService(event, all_data, party, lock, job_id), server + SendDataService(event, all_data, party, lock, job_name), server ) tls_enabled = fed_utils.tls_enabled(tls_config) diff --git a/tests/test_internal_kv.py b/tests/test_internal_kv.py index 326b6ab0..eeec5303 100644 --- a/tests/test_internal_kv.py +++ b/tests/test_internal_kv.py @@ -13,7 +13,7 @@ def run(party): 'bob': '127.0.0.1:11011', } assert compatible_utils.kv is None - fed.init(addresses=addresses, party=party, job_id="test_job_id") + fed.init(addresses=addresses, party=party, job_name="test_job_name") assert compatible_utils.kv assert not compatible_utils.kv.put(b"test_key", b"test_val") assert compatible_utils.kv.get(b"test_key") == b"test_val" @@ -21,7 +21,7 @@ def run(party): # Test that a prefix key name is added under the hood. assert ray_internal_kv._internal_kv_get(b"test_key") is None assert ray_internal_kv._internal_kv_get( - b"RAYFED#test_job_id#test_key") == b"test_val" + b"RAYFED#test_job_name#test_key") == b"test_val" time.sleep(5) fed.shutdown() From 4e6018611954e124032cda74c75c3fc2aa026016 Mon Sep 17 00:00:00 2001 From: paer Date: Wed, 9 Aug 2023 14:46:27 +0800 Subject: [PATCH 07/16] add job_name to logger msg --- fed/_private/constants.py | 2 +- fed/api.py | 1 + fed/proxy/barriers.py | 6 ++++++ fed/utils.py | 17 +++++++++++++++-- 4 files changed, 23 insertions(+), 3 deletions(-) diff --git a/fed/_private/constants.py b/fed/_private/constants.py index 72ebb583..18c87fca 100644 --- a/fed/_private/constants.py +++ b/fed/_private/constants.py @@ -27,7 +27,7 @@ KEY_OF_CROSS_SILO_COMM_CONFIG_DICT = "CROSS_SILO_COMM_CONFIG_DICT" -RAYFED_LOG_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- %(message)s" # noqa +RAYFED_LOG_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- [%(jobname)s] %(message)s" # noqa RAYFED_DATE_FMT = "%Y-%m-%d %H:%M:%S" diff --git a/fed/api.py b/fed/api.py index 051cce4a..33631d2f 100644 --- a/fed/api.py +++ b/fed/api.py @@ -151,6 +151,7 @@ def init( logging_format=constants.RAYFED_LOG_FMT, date_format=constants.RAYFED_DATE_FMT, party_val=_get_party(job_name), + job_name=job_name ) logger.info(f'Started rayfed with {cluster_config}') diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index 05e43f28..f998cb51 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -92,6 +92,7 @@ def __init__( logging_format=constants.RAYFED_LOG_FMT, date_format=constants.RAYFED_DATE_FMT, party_val=party, + job_name=job_name, ) self._stats = {"send_op_count": 0} @@ -163,6 +164,7 @@ def __init__( logging_format=constants.RAYFED_LOG_FMT, date_format=constants.RAYFED_DATE_FMT, party_val=party, + job_name=job_name, ) self._stats = {"receive_op_count": 0} self._listening_address = listening_address @@ -303,6 +305,7 @@ def __init__( self, addresses: Dict, party: str, + job_name: str, tls_config: Dict = None, logging_level: str = None, proxy_cls: SenderReceiverProxy = None, @@ -312,6 +315,7 @@ def __init__( logging_format=constants.RAYFED_LOG_FMT, date_format=constants.RAYFED_DATE_FMT, party_val=party, + job_name=job_name, ) self._stats = {'send_op_count': 0, 'receive_op_count': 0} @@ -396,6 +400,7 @@ def _start_sender_receiver_proxy( logger.debug(f"Starting ReceiverProxyActor with options: {actor_options}") + job_name = get_global_context().job_name() global _SENDER_RECEIVER_PROXY_ACTOR global _RECEIVER_PROXY_ACTOR_NAME _SENDER_RECEIVER_PROXY_ACTOR = SenderReceiverProxyActor.options( @@ -403,6 +408,7 @@ def _start_sender_receiver_proxy( ).remote( addresses=addresses, party=party, + job_name=job_name, tls_config=tls_config, logging_level=logging_level, proxy_cls=proxy_cls, diff --git a/fed/utils.py b/fed/utils.py index 481fcfdc..0a82bc03 100644 --- a/fed/utils.py +++ b/fed/utils.py @@ -101,6 +101,7 @@ def setup_logger( date_format, log_dir=None, party_val=None, + job_name=None, ): class PartyRecordFilter(logging.Filter): def __init__(self, party_val=None) -> None: @@ -112,6 +113,16 @@ def filter(self, record) -> bool: record.party = self._party_val return True + class JobNameRecordFilter(logging.Filter): + def __init__(self, job_name=None) -> None: + self._job_name = job_name + super().__init__("JobNameRecordFilter") + + def filter(self, record) -> bool: + if not hasattr(record, "jobname"): + record.jobname = self._job_name + return True + logger = logging.getLogger() # Remove default handlers otherwise a msg will be printed twice. @@ -123,11 +134,13 @@ def filter(self, record) -> bool: logger.setLevel(logging_level) _formatter = logging.Formatter(fmt=logging_format, datefmt=date_format) - _filter = PartyRecordFilter(party_val=party_val) + _party_filter = PartyRecordFilter(party_val=party_val) + _job_name_fitler = JobNameRecordFilter(job_name=job_name) _customed_handler = logging.StreamHandler() _customed_handler.setFormatter(_formatter) - _customed_handler.addFilter(_filter) + _customed_handler.addFilter(_party_filter) + _customed_handler.addFilter(_job_name_fitler) logger.addHandler(_customed_handler) From d1b3e5c5ec39efcebc0339f12095bb247e751f29 Mon Sep 17 00:00:00 2001 From: paer Date: Wed, 9 Aug 2023 18:12:00 +0800 Subject: [PATCH 08/16] fix UT Signed-off-by: paer --- fed/_private/compatible_utils.py | 3 ++- fed/_private/constants.py | 4 ++++ fed/_private/fed_call_holder.py | 2 ++ fed/api.py | 2 +- tests/test_transport_proxy.py | 29 +++++++++++++++++++++++------ 5 files changed, 32 insertions(+), 8 deletions(-) diff --git a/fed/_private/compatible_utils.py b/fed/_private/compatible_utils.py index 13de44dc..18a27754 100644 --- a/fed/_private/compatible_utils.py +++ b/fed/_private/compatible_utils.py @@ -72,7 +72,8 @@ def wrap_kv_key(job_name, key): """ if (type(key) == bytes): key = key.decode("utf-8") - return f"RAYFED#{job_name}#{key}".encode("utf-8") + return constants.RAYFED_JOB_KV_DATA_KEY_FMT.format( + job_name, key).encode("utf-8") class AbstractInternalKv(abc.ABC): diff --git a/fed/_private/constants.py b/fed/_private/constants.py index 18c87fca..d2529abe 100644 --- a/fed/_private/constants.py +++ b/fed/_private/constants.py @@ -32,3 +32,7 @@ RAYFED_DATE_FMT = "%Y-%m-%d %H:%M:%S" RAY_VERSION_2_0_0_STR = "2.0.0" + +RAYFED_DEFAULT_JOB_NAME = "Anonymous" + +RAYFED_JOB_KV_DATA_KEY_FMT = "RAYFED#{}#{}" \ No newline at end of file diff --git a/fed/_private/fed_call_holder.py b/fed/_private/fed_call_holder.py index 376f6214..abb2be48 100644 --- a/fed/_private/fed_call_holder.py +++ b/fed/_private/fed_call_holder.py @@ -46,6 +46,8 @@ def __init__( submit_ray_task_func, options={}, ) -> None: + # Note(NKcqx): FedCallHolder will only be created in driver process, where + # the GlobalContext must has been initialized. job_name = get_global_context().job_name() self._party = fed_config.get_cluster_config(job_name).current_party self._node_party = node_party diff --git a/fed/api.py b/fed/api.py index 33631d2f..4f61c221 100644 --- a/fed/api.py +++ b/fed/api.py @@ -58,7 +58,7 @@ def init( sender_proxy_cls: SenderProxy = None, receiver_proxy_cls: ReceiverProxy = None, receiver_sender_proxy_cls: SenderReceiverProxy = None, - job_name: str = 'anonymous' + job_name: str = constants.RAYFED_DEFAULT_JOB_NAME ): """ Initialize a RayFed client. diff --git a/tests/test_transport_proxy.py b/tests/test_transport_proxy.py index 7df25284..dbcea269 100644 --- a/tests/test_transport_proxy.py +++ b/tests/test_transport_proxy.py @@ -46,14 +46,15 @@ def test_n_to_1_transport(): N receivers to `get_data` from receiver proxy at that time. """ compatible_utils.init_ray(address='local') - + test_job_name = 'test_n_to_1_transport' + global_context.init_global_context(test_job_name) global_context.get_global_context().get_cleanup_manager().start() cluster_config = { constants.KEY_OF_CLUSTER_ADDRESSES: "", constants.KEY_OF_CURRENT_PARTY_NAME: "", constants.KEY_OF_TLS_CONFIG: "", } - compatible_utils._init_internal_kv() + compatible_utils._init_internal_kv(test_job_name) compatible_utils.kv.put( constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config) ) @@ -97,10 +98,14 @@ def test_n_to_1_transport(): class TestSendDataService(fed_pb2_grpc.GrpcServiceServicer): - def __init__(self, all_events, all_data, party, lock, expected_metadata): + def __init__(self, all_events, all_data, party, lock, + expected_metadata, expected_jobname): self.expected_metadata = expected_metadata or {} + self._expected_jobname = expected_jobname or "" async def SendData(self, request, context): + job_name = request.job_name + assert self._expected_jobname == job_name metadata = dict(context.invocation_metadata()) for k, v in self.expected_metadata.items(): assert k in metadata @@ -118,10 +123,13 @@ async def _test_run_grpc_server( lock, grpc_options=None, expected_metadata=None, + expected_jobname=None, ): server = grpc.aio.server(options=grpc_options) fed_pb2_grpc.add_GrpcServiceServicer_to_server( - TestSendDataService(event, all_data, party, lock, expected_metadata), server + TestSendDataService(event, all_data, party, lock, + expected_metadata, expected_jobname), + server ) server.add_insecure_port(f'[::]:{port}') await server.start() @@ -135,10 +143,12 @@ def __init__( listen_addr: str, party: str, expected_metadata: dict, + expected_jobname: str, ): self._listen_addr = listen_addr self._party = party self._expected_metadata = expected_metadata + self._expected_jobname = expected_jobname async def run_grpc_server(self): return await _test_run_grpc_server( @@ -148,6 +158,7 @@ async def run_grpc_server(self): self._party, None, expected_metadata=self._expected_metadata, + expected_jobname=self._expected_jobname ) async def is_ready(self): @@ -158,13 +169,16 @@ def _test_start_receiver_proxy( addresses: str, party: str, expected_metadata: dict, + expected_jobname: str, ): # Create RecevrProxyActor # Not that this is now a threaded actor. address = addresses[party] receiver_proxy_actor = TestReceiverProxyActor.options( name=receiver_proxy_actor_name(), max_concurrency=1000 - ).remote(listen_addr=address, party=party, expected_metadata=expected_metadata) + ).remote(listen_addr=address, party=party, + expected_metadata=expected_metadata, + expected_jobname=expected_jobname) receiver_proxy_actor.run_grpc_server.remote() assert ray.get(receiver_proxy_actor.is_ready.remote()) @@ -181,7 +195,9 @@ def test_send_grpc_with_meta(): job_config = { constants.KEY_OF_CROSS_SILO_COMM_CONFIG_DICT: config, } - compatible_utils._init_internal_kv() + test_job_name = 'test_send_grpc_with_meta' + global_context.init_global_context(test_job_name) + compatible_utils._init_internal_kv(test_job_name) compatible_utils.kv.put( constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config) ) @@ -195,6 +211,7 @@ def test_send_grpc_with_meta(): addresses, party_name, expected_metadata=metadata, + expected_jobname=test_job_name ) _start_sender_proxy( addresses, From 0c81850fee9f38888c2db55987be40537a391fbd Mon Sep 17 00:00:00 2001 From: paer Date: Thu, 10 Aug 2023 11:00:43 +0800 Subject: [PATCH 09/16] lint codes Signed-off-by: paer --- fed/_private/compatible_utils.py | 3 +-- fed/_private/constants.py | 2 +- fed/api.py | 6 +++--- fed/config.py | 6 ++++-- fed/proxy/barriers.py | 2 -- tests/test_internal_kv.py | 1 + tests/test_transport_proxy.py | 2 +- 7 files changed, 11 insertions(+), 11 deletions(-) diff --git a/fed/_private/compatible_utils.py b/fed/_private/compatible_utils.py index 18a27754..b0e70758 100644 --- a/fed/_private/compatible_utils.py +++ b/fed/_private/compatible_utils.py @@ -13,7 +13,6 @@ # limitations under the License. import abc -import cloudpickle import ray import fed._private.constants as fed_constants @@ -107,7 +106,7 @@ def reset(self): class InternalKv(AbstractInternalKv): """The internal kv class for non Ray client mode. """ - def __init__(self, job_name:str) -> None: + def __init__(self, job_name: str) -> None: super().__init__() self._job_name = job_name diff --git a/fed/_private/constants.py b/fed/_private/constants.py index d2529abe..5f9f5c47 100644 --- a/fed/_private/constants.py +++ b/fed/_private/constants.py @@ -35,4 +35,4 @@ RAYFED_DEFAULT_JOB_NAME = "Anonymous" -RAYFED_JOB_KV_DATA_KEY_FMT = "RAYFED#{}#{}" \ No newline at end of file +RAYFED_JOB_KV_DATA_KEY_FMT = "RAYFED#{}#{}" diff --git a/fed/api.py b/fed/api.py index 4f61c221..21eb59b4 100644 --- a/fed/api.py +++ b/fed/api.py @@ -225,21 +225,21 @@ def shutdown(): logger.info('Shutdowned rayfed.') -def _get_addresses(job_name: str=None): +def _get_addresses(job_name: str = None): """ Get the RayFed addresses configration. """ return fed_config.get_cluster_config(job_name).cluster_addresses -def _get_party(job_name: str=None): +def _get_party(job_name: str = None): """ A private util function to get the current party name. """ return fed_config.get_cluster_config(job_name).current_party -def _get_tls(job_name: str=None): +def _get_tls(job_name: str = None): """ Get the tls configurations on this party. """ diff --git a/fed/config.py b/fed/config.py index 509fb7b7..c3edded8 100644 --- a/fed/config.py +++ b/fed/config.py @@ -52,7 +52,8 @@ def get_cluster_config(job_name: str = None): """This function is not thread safe to use.""" global _cluster_config if _cluster_config is None: - assert job_name is not None, "Initializing internal kv need to provide job_name." + assert job_name is not None, \ + "Initializing internal kv need to provide job_name." compatible_utils._init_internal_kv(job_name) raw_dict = compatible_utils.kv.get(fed_constants.KEY_OF_CLUSTER_CONFIG) _cluster_config = ClusterConfig(raw_dict) @@ -63,7 +64,8 @@ def get_job_config(job_name: str = None): """This config still acts like cluster config for now""" global _job_config if _job_config is None: - assert job_name is not None, "Initializing internal kv need to provide job_name." + assert job_name is not None, \ + "Initializing internal kv need to provide job_name." compatible_utils._init_internal_kv(job_name) raw_dict = compatible_utils.kv.get(fed_constants.KEY_OF_JOB_CONFIG) _job_config = JobConfig(raw_dict) diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index f998cb51..65510336 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -287,8 +287,6 @@ def _start_sender_proxy( logging_level=logging_level, proxy_cls=proxy_cls, ) - # timeout = fed_config.get_job_config(job_name).cross_silo_comm_config_dict.timeout_in_ms / 1000 - # assert ray.get(_SENDER_PROXY_ACTOR.is_ready.remote(), timeout=timeout) assert ray.get(_SENDER_PROXY_ACTOR.is_ready.remote(), timeout=ready_timeout_second) logger.info("SenderProxyActor has successfully created.") diff --git a/tests/test_internal_kv.py b/tests/test_internal_kv.py index eeec5303..6cc9318b 100644 --- a/tests/test_internal_kv.py +++ b/tests/test_internal_kv.py @@ -6,6 +6,7 @@ import fed._private.compatible_utils as compatible_utils import ray.experimental.internal_kv as ray_internal_kv + def run(party): compatible_utils.init_ray("local") addresses = { diff --git a/tests/test_transport_proxy.py b/tests/test_transport_proxy.py index dbcea269..ab26e221 100644 --- a/tests/test_transport_proxy.py +++ b/tests/test_transport_proxy.py @@ -108,7 +108,7 @@ async def SendData(self, request, context): assert self._expected_jobname == job_name metadata = dict(context.invocation_metadata()) for k, v in self.expected_metadata.items(): - assert k in metadata + assert k in metadata, f"{k} not in {metadata.keys()}" assert v == metadata[k] event = asyncio.Event() event.set() From bbd2f03f73746db8d6bcf79c3a3a63cb2a7ae23c Mon Sep 17 00:00:00 2001 From: paer Date: Thu, 10 Aug 2023 16:05:50 +0800 Subject: [PATCH 10/16] fix two job UT failure Signed-off-by: paer --- tests/test_retry_policy.py | 3 ++- tests/test_transport_proxy.py | 3 ++- tests/test_transport_proxy_tls.py | 5 +++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test_retry_policy.py b/tests/test_retry_policy.py index b6ee9b55..574ce9fb 100644 --- a/tests/test_retry_policy.py +++ b/tests/test_retry_policy.py @@ -50,6 +50,7 @@ def run(): "backoffMultiplier": 1, "retryableStatusCodes": ["UNAVAILABLE"], } + test_job_name = 'test_retry_policy' fed.init( addresses=addresses, party='alice', @@ -60,7 +61,7 @@ def run(): }, ) - job_config = config.get_job_config() + job_config = config.get_job_config(test_job_name) cross_silo_comm_config = job_config.cross_silo_comm_config_dict TestCase().assertDictEqual( cross_silo_comm_config['grpc_retry_policy'], retry_policy diff --git a/tests/test_transport_proxy.py b/tests/test_transport_proxy.py index ab26e221..4f008fcb 100644 --- a/tests/test_transport_proxy.py +++ b/tests/test_transport_proxy.py @@ -94,6 +94,7 @@ def test_n_to_1_transport(): global_context.get_global_context().get_cleanup_manager().graceful_stop() global_context.clear_global_context() + compatible_utils._clear_internal_kv() ray.shutdown() @@ -218,7 +219,7 @@ def test_send_grpc_with_meta(): party_name, logging_level='info', proxy_cls=GrpcSenderProxy, - proxy_config={}, + proxy_config=config, ) sent_objs = [] sent_obj = send(party_name, "data", 0, 1) diff --git a/tests/test_transport_proxy_tls.py b/tests/test_transport_proxy_tls.py index 1ea525eb..ec83c782 100644 --- a/tests/test_transport_proxy_tls.py +++ b/tests/test_transport_proxy_tls.py @@ -35,7 +35,7 @@ def test_n_to_1_transport(): N receivers to `get_data` from receiver proxy at that time. """ compatible_utils.init_ray(address='local') - + test_job_name = 'test_n_to_1_transport' cert_dir = os.path.join( os.path.dirname(os.path.abspath(__file__)), "/tmp/rayfed/test-certs/" ) @@ -52,7 +52,7 @@ def test_n_to_1_transport(): } global_context.get_global_context().get_cleanup_manager().start() - compatible_utils._init_internal_kv() + compatible_utils._init_internal_kv(test_job_name) compatible_utils.kv.put( constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config) ) @@ -99,6 +99,7 @@ def test_n_to_1_transport(): global_context.get_global_context().get_cleanup_manager().graceful_stop() global_context.clear_global_context() + compatible_utils._clear_internal_kv() ray.shutdown() From c4c743400285a6ba3508c67417222ddfa5154b7e Mon Sep 17 00:00:00 2001 From: paer Date: Thu, 10 Aug 2023 16:44:10 +0800 Subject: [PATCH 11/16] rm debug code Signed-off-by: paer --- fed/grpc/pb3/fed_pb2.py | 14 ++++++++++++++ fed/grpc/pb3/fed_pb2_grpc.py | 14 ++++++++++++++ fed/proxy/grpc/grpc_proxy.py | 10 ---------- 3 files changed, 28 insertions(+), 10 deletions(-) diff --git a/fed/grpc/pb3/fed_pb2.py b/fed/grpc/pb3/fed_pb2.py index 28669ff7..8f7d9924 100644 --- a/fed/grpc/pb3/fed_pb2.py +++ b/fed/grpc/pb3/fed_pb2.py @@ -1,3 +1,17 @@ +# Copyright 2023 The RayFed Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: fed.proto diff --git a/fed/grpc/pb3/fed_pb2_grpc.py b/fed/grpc/pb3/fed_pb2_grpc.py index 830c8b88..55c674d4 100644 --- a/fed/grpc/pb3/fed_pb2_grpc.py +++ b/fed/grpc/pb3/fed_pb2_grpc.py @@ -1,3 +1,17 @@ +# Copyright 2023 The RayFed Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc diff --git a/fed/proxy/grpc/grpc_proxy.py b/fed/proxy/grpc/grpc_proxy.py index bb89b740..97a7412e 100644 --- a/fed/proxy/grpc/grpc_proxy.py +++ b/fed/proxy/grpc/grpc_proxy.py @@ -222,16 +222,6 @@ def __init__( async def start(self): port = self._listen_addr[self._listen_addr.index(':') + 1 :] try: - print(f"[Debug] params list: " - f"port: {port}, " - f"self._events: {self._events}, " - f"self._all_data: {self._all_data}, " - f"self._party: {self._party}, " - f"self._lock: {self._lock}, " - f"self._job_name: {self._job_name}, " - f"self._server_ready_future: {self._server_ready_future}, " - f"self._tls_config: {self._tls_config}, " - f"self._grpc_options: {fed_utils.dict2tuple(self._grpc_options)}") await _run_grpc_server( port, self._events, From 3d923e47ac31880bd3b817c73d155bd64a0d8604 Mon Sep 17 00:00:00 2001 From: paer Date: Thu, 10 Aug 2023 19:44:08 +0800 Subject: [PATCH 12/16] fix UT that missing job_name Signed-off-by: paer --- fed/_private/compatible_utils.py | 2 +- tests/test_transport_proxy_tls.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fed/_private/compatible_utils.py b/fed/_private/compatible_utils.py index b0e70758..5f8a055a 100644 --- a/fed/_private/compatible_utils.py +++ b/fed/_private/compatible_utils.py @@ -69,7 +69,7 @@ def _get_gcs_address_from_ray_worker(): def wrap_kv_key(job_name, key): """Add an prefix to the key to avoid conflict with other jobs. """ - if (type(key) == bytes): + if (isinstance(key, bytes)): key = key.decode("utf-8") return constants.RAYFED_JOB_KV_DATA_KEY_FMT.format( job_name, key).encode("utf-8") diff --git a/tests/test_transport_proxy_tls.py b/tests/test_transport_proxy_tls.py index ec83c782..e8c10523 100644 --- a/tests/test_transport_proxy_tls.py +++ b/tests/test_transport_proxy_tls.py @@ -50,7 +50,7 @@ def test_n_to_1_transport(): constants.KEY_OF_CURRENT_PARTY_NAME: "", constants.KEY_OF_TLS_CONFIG: tls_config, } - + global_context.init_global_context(test_job_name) global_context.get_global_context().get_cleanup_manager().start() compatible_utils._init_internal_kv(test_job_name) compatible_utils.kv.put( From d643823e9935a3149cb82f5d3cf67e1f3b296204 Mon Sep 17 00:00:00 2001 From: paer Date: Fri, 11 Aug 2023 01:17:01 +0800 Subject: [PATCH 13/16] add UT to test ignore msg Signed-off-by: paer --- tests/multi-jobs/test_job_msg_ignore.py | 121 ++++++++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 tests/multi-jobs/test_job_msg_ignore.py diff --git a/tests/multi-jobs/test_job_msg_ignore.py b/tests/multi-jobs/test_job_msg_ignore.py new file mode 100644 index 00000000..a2085e7e --- /dev/null +++ b/tests/multi-jobs/test_job_msg_ignore.py @@ -0,0 +1,121 @@ +# Copyright 2023 The RayFed Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing +import fed +import ray +import grpc +import pytest +import fed.utils as fed_utils +import fed._private.compatible_utils as compatible_utils +from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy, send_data_grpc +if compatible_utils._compare_version_strings( + fed_utils.get_package_version('protobuf'), '4.0.0'): + from fed.grpc.pb4 import fed_pb2_grpc as fed_pb2_grpc +else: + from fed.grpc.pb3 import fed_pb2_grpc as fed_pb2_grpc + + +class TestGrpcSenderProxy(GrpcSenderProxy): + async def send( + self, + dest_party, + data, + upstream_seq_id, + downstream_seq_id): + dest_addr = self._addresses[dest_party] + grpc_metadata, grpc_channel_options = self.get_grpc_config_by_party(dest_party) + if dest_party not in self._stubs: + channel = grpc.aio.insecure_channel( + dest_addr, options=grpc_channel_options) + stub = fed_pb2_grpc.GrpcServiceStub(channel) + self._stubs[dest_party] = stub + + timeout = self._proxy_config.timeout_in_ms / 1000 + response = await send_data_grpc( + data=data, + stub=self._stubs[dest_party], + upstream_seq_id=upstream_seq_id, + downstream_seq_id=downstream_seq_id, + job_name=self._job_name, + timeout=timeout, + metadata=grpc_metadata, + ) + assert response == "ERROR" + # So that process can exit + raise RuntimeError() + + +@fed.remote +class MyActor: + def __init__(self, party, data): + self.__data = data + self._party = party + + def f(self): + return f"f({self._party}, ip is {ray.util.get_node_ip_address()})" + + +@fed.remote +def agg_fn(obj1, obj2): + return f"agg-{obj1}-{obj2}" + + +addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', +} + + +def run(party, job_name): + ray.init(address='local') + fed.init(addresses=addresses, + party=party, + job_name=job_name, + sender_proxy_cls=TestGrpcSenderProxy, + config={ + 'cross_silo_comm': { + 'exit_on_sending_failure': True, + }}) + # 'bob' only needs to start the proxy actors + if party == 'alice': + ds1, ds2 = [123, 789] + actor_alice = MyActor.party("alice").remote(party, ds1) + actor_bob = MyActor.party("bob").remote(party, ds2) + + obj_alice_f = actor_alice.f.remote() + obj_bob_f = actor_bob.f.remote() + + obj = agg_fn.party("bob").remote(obj_alice_f, obj_bob_f) + fed.get(obj) + fed.shutdown() + ray.shutdown() + import time + # Wait for SIGTERM as failure on sending. + time.sleep(86400) + + +def test_multi_job_ignore_msg(): + p_alice = multiprocessing.Process(target=run, args=('alice', 'job1')) + p_bob = multiprocessing.Process(target=run, args=('bob', 'job2')) + p_alice.start() + p_bob.start() + p_alice.join() + p_bob.join() + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-sv", __file__])) From 2557c5367258af1b107d09c08715f6701654b270 Mon Sep 17 00:00:00 2001 From: NKcqx <892670992@qq.com> Date: Fri, 11 Aug 2023 14:24:42 +0800 Subject: [PATCH 14/16] update pb4 Signed-off-by: NKcqx <892670992@qq.com> --- fed/grpc/pb4/fed_pb2.py | 18 +++++++++--------- fed/grpc/pb4/fed_pb2_grpc.py | 14 +++++++------- fed/proxy/grpc/grpc_proxy.py | 2 +- tests/multi-jobs/test_job_msg_ignore.py | 4 ++-- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/fed/grpc/pb4/fed_pb2.py b/fed/grpc/pb4/fed_pb2.py index 3ea5f51a..bd69d268 100644 --- a/fed/grpc/pb4/fed_pb2.py +++ b/fed/grpc/pb4/fed_pb2.py @@ -14,7 +14,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! -# source: fed_4.proto +# source: fed.proto """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -27,19 +27,19 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0b\x66\x65\x64_4.proto\"S\n\x0fSendDataRequest\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x17\n\x0fupstream_seq_id\x18\x02 \x01(\t\x12\x19\n\x11\x64ownstream_seq_id\x18\x03 \x01(\t\"\"\n\x10SendDataResponse\x12\x0e\n\x06result\x18\x01 \x01(\t2@\n\x0bGrpcService\x12\x31\n\x08SendData\x12\x10.SendDataRequest\x1a\x11.SendDataResponse\"\x00\x42\x03\x80\x01\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tfed.proto\"e\n\x0fSendDataRequest\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x17\n\x0fupstream_seq_id\x18\x02 \x01(\t\x12\x19\n\x11\x64ownstream_seq_id\x18\x03 \x01(\t\x12\x10\n\x08job_name\x18\x04 \x01(\t\"\"\n\x10SendDataResponse\x12\x0e\n\x06result\x18\x01 \x01(\t2@\n\x0bGrpcService\x12\x31\n\x08SendData\x12\x10.SendDataRequest\x1a\x11.SendDataResponse\"\x00\x42\x03\x80\x01\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'fed_4_pb2', _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'fed_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b'\200\001\001' - _globals['_SENDDATAREQUEST']._serialized_start=15 - _globals['_SENDDATAREQUEST']._serialized_end=98 - _globals['_SENDDATARESPONSE']._serialized_start=100 - _globals['_SENDDATARESPONSE']._serialized_end=134 - _globals['_GRPCSERVICE']._serialized_start=136 - _globals['_GRPCSERVICE']._serialized_end=200 + _globals['_SENDDATAREQUEST']._serialized_start=13 + _globals['_SENDDATAREQUEST']._serialized_end=114 + _globals['_SENDDATARESPONSE']._serialized_start=116 + _globals['_SENDDATARESPONSE']._serialized_end=150 + _globals['_GRPCSERVICE']._serialized_start=152 + _globals['_GRPCSERVICE']._serialized_end=216 # @@protoc_insertion_point(module_scope) diff --git a/fed/grpc/pb4/fed_pb2_grpc.py b/fed/grpc/pb4/fed_pb2_grpc.py index a8cfbff0..a76c956e 100644 --- a/fed/grpc/pb4/fed_pb2_grpc.py +++ b/fed/grpc/pb4/fed_pb2_grpc.py @@ -16,7 +16,7 @@ """Client and server classes corresponding to protobuf-defined services.""" import grpc -import fed.grpc.pb4.fed_pb2 as fed__4__pb2 +import fed.grpc.pb4.fed_pb2 as fed__pb2 class GrpcServiceStub(object): @@ -30,8 +30,8 @@ def __init__(self, channel): """ self.SendData = channel.unary_unary( '/GrpcService/SendData', - request_serializer=fed__4__pb2.SendDataRequest.SerializeToString, - response_deserializer=fed__4__pb2.SendDataResponse.FromString, + request_serializer=fed__pb2.SendDataRequest.SerializeToString, + response_deserializer=fed__pb2.SendDataResponse.FromString, ) @@ -49,8 +49,8 @@ def add_GrpcServiceServicer_to_server(servicer, server): rpc_method_handlers = { 'SendData': grpc.unary_unary_rpc_method_handler( servicer.SendData, - request_deserializer=fed__4__pb2.SendDataRequest.FromString, - response_serializer=fed__4__pb2.SendDataResponse.SerializeToString, + request_deserializer=fed__pb2.SendDataRequest.FromString, + response_serializer=fed__pb2.SendDataResponse.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -74,7 +74,7 @@ def SendData(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/GrpcService/SendData', - fed__4__pb2.SendDataRequest.SerializeToString, - fed__4__pb2.SendDataResponse.FromString, + fed__pb2.SendDataRequest.SerializeToString, + fed__pb2.SendDataResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/fed/proxy/grpc/grpc_proxy.py b/fed/proxy/grpc/grpc_proxy.py index 97a7412e..d83e178c 100644 --- a/fed/proxy/grpc/grpc_proxy.py +++ b/fed/proxy/grpc/grpc_proxy.py @@ -283,7 +283,7 @@ def __init__(self, all_events, all_data, party, lock, job_name): async def SendData(self, request, context): job_name = request.job_name if job_name != self._job_name: - return fed_pb2.SendDataResponse(result="ERROR") + return fed_pb2.SendDataResponse(result=f"JobName mis-match, expected {self._job_name}, got {job_name}.") upstream_seq_id = request.upstream_seq_id downstream_seq_id = request.downstream_seq_id logger.debug( diff --git a/tests/multi-jobs/test_job_msg_ignore.py b/tests/multi-jobs/test_job_msg_ignore.py index a2085e7e..6e2986b8 100644 --- a/tests/multi-jobs/test_job_msg_ignore.py +++ b/tests/multi-jobs/test_job_msg_ignore.py @@ -43,7 +43,7 @@ async def send( self._stubs[dest_party] = stub timeout = self._proxy_config.timeout_in_ms / 1000 - response = await send_data_grpc( + response: str = await send_data_grpc( data=data, stub=self._stubs[dest_party], upstream_seq_id=upstream_seq_id, @@ -52,7 +52,7 @@ async def send( timeout=timeout, metadata=grpc_metadata, ) - assert response == "ERROR" + assert "JobName mis-match" in response # So that process can exit raise RuntimeError() From 0b787cf18814f4ba819b5967dda944c13432873a Mon Sep 17 00:00:00 2001 From: paer Date: Fri, 11 Aug 2023 17:38:58 +0800 Subject: [PATCH 15/16] more clear comment Signed-off-by: paer --- fed/api.py | 9 +++++---- fed/proxy/grpc/grpc_proxy.py | 6 +++++- ...st_job_msg_ignore.py => test_ignore_other_job_msg.py} | 2 +- tests/test_transport_proxy.py | 5 +++-- 4 files changed, 14 insertions(+), 8 deletions(-) rename tests/multi-jobs/{test_job_msg_ignore.py => test_ignore_other_job_msg.py} (99%) diff --git a/fed/api.py b/fed/api.py index 21eb59b4..4d0bab2f 100644 --- a/fed/api.py +++ b/fed/api.py @@ -96,11 +96,12 @@ def init( } logging_level: optional; the logging level, could be `debug`, `info`, `warning`, `error`, `critical`, not case sensititive. - job_name: optional; the job id of the current job. Note that, the job id + job_name: optional; the job name of the current job. Note that, the job name must be identical in all parties, otherwise, messages will be ignored - because of the job id mismatch. If the job id is not provided, messages - of this job will not be distinguished from other jobs, which should - only be used in the single job scenario or simulation mode. + because of the job name mismatch. If the job name is not provided, an + default fixed name will be assigned, therefore messages of all anonymous + jobs will be mixed together, which should only be used in the single job + scenario or test mode. Examples: >>> import fed >>> import ray diff --git a/fed/proxy/grpc/grpc_proxy.py b/fed/proxy/grpc/grpc_proxy.py index d83e178c..743fe298 100644 --- a/fed/proxy/grpc/grpc_proxy.py +++ b/fed/proxy/grpc/grpc_proxy.py @@ -283,7 +283,11 @@ def __init__(self, all_events, all_data, party, lock, job_name): async def SendData(self, request, context): job_name = request.job_name if job_name != self._job_name: - return fed_pb2.SendDataResponse(result=f"JobName mis-match, expected {self._job_name}, got {job_name}.") + logger.warning(f"Receive data from job {job_name}, ignore it. " + f"The reason may be that the ReceiverProxy is listening " + f"on the same address with that job.") + return fed_pb2.SendDataResponse( + result=f"JobName mis-match, expected {self._job_name}, got {job_name}.") upstream_seq_id = request.upstream_seq_id downstream_seq_id = request.downstream_seq_id logger.debug( diff --git a/tests/multi-jobs/test_job_msg_ignore.py b/tests/multi-jobs/test_ignore_other_job_msg.py similarity index 99% rename from tests/multi-jobs/test_job_msg_ignore.py rename to tests/multi-jobs/test_ignore_other_job_msg.py index 6e2986b8..cbb45ac0 100644 --- a/tests/multi-jobs/test_job_msg_ignore.py +++ b/tests/multi-jobs/test_ignore_other_job_msg.py @@ -106,7 +106,7 @@ def run(party, job_name): time.sleep(86400) -def test_multi_job_ignore_msg(): +def test_ignore_other_job_msg(): p_alice = multiprocessing.Process(target=run, args=('alice', 'job1')) p_bob = multiprocessing.Process(target=run, args=('bob', 'job2')) p_alice.start() diff --git a/tests/test_transport_proxy.py b/tests/test_transport_proxy.py index 4f008fcb..368223a4 100644 --- a/tests/test_transport_proxy.py +++ b/tests/test_transport_proxy.py @@ -109,7 +109,8 @@ async def SendData(self, request, context): assert self._expected_jobname == job_name metadata = dict(context.invocation_metadata()) for k, v in self.expected_metadata.items(): - assert k in metadata, f"{k} not in {metadata.keys()}" + assert k in metadata, \ + f"The expected key {k} is not in the metadata keys: {metadata.keys()}." assert v == metadata[k] event = asyncio.Event() event.set() @@ -153,7 +154,7 @@ def __init__( async def run_grpc_server(self): return await _test_run_grpc_server( - self._listen_addr[self._listen_addr.index(':') + 1 :], + self._listen_addr[self._listen_addr.index(':') + 1:], None, None, self._party, From 333f59211770a1b18c9bc47d635da65d0c33914c Mon Sep 17 00:00:00 2001 From: paer Date: Mon, 14 Aug 2023 17:53:25 +0800 Subject: [PATCH 16/16] restrict kv key type to str only Signed-off-by: paer --- fed/_private/compatible_utils.py | 9 +++++---- fed/_private/constants.py | 6 +++--- tests/test_internal_kv.py | 4 ++-- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/fed/_private/compatible_utils.py b/fed/_private/compatible_utils.py index 5f8a055a..83d2a098 100644 --- a/fed/_private/compatible_utils.py +++ b/fed/_private/compatible_utils.py @@ -66,13 +66,14 @@ def _get_gcs_address_from_ray_worker(): return ray.worker._global_node.gcs_address -def wrap_kv_key(job_name, key): +def wrap_kv_key(job_name, key: str): """Add an prefix to the key to avoid conflict with other jobs. """ - if (isinstance(key, bytes)): - key = key.decode("utf-8") + assert isinstance(key, str), \ + f"The key of KV data must be `str` type, got {type(key)}." + return constants.RAYFED_JOB_KV_DATA_KEY_FMT.format( - job_name, key).encode("utf-8") + job_name, key) class AbstractInternalKv(abc.ABC): diff --git a/fed/_private/constants.py b/fed/_private/constants.py index 5f9f5c47..2dfc904f 100644 --- a/fed/_private/constants.py +++ b/fed/_private/constants.py @@ -13,11 +13,11 @@ # limitations under the License. -KEY_OF_CLUSTER_CONFIG = b"CLUSTER_CONFIG" +KEY_OF_CLUSTER_CONFIG = "CLUSTER_CONFIG" -KEY_OF_JOB_CONFIG = b"JOB_CONFIG" +KEY_OF_JOB_CONFIG = "JOB_CONFIG" -KEY_OF_GRPC_METADATA = b"GRPC_METADATA" +KEY_OF_GRPC_METADATA = "GRPC_METADATA" KEY_OF_CLUSTER_ADDRESSES = "CLUSTER_ADDRESSES" diff --git a/tests/test_internal_kv.py b/tests/test_internal_kv.py index 6cc9318b..bb048239 100644 --- a/tests/test_internal_kv.py +++ b/tests/test_internal_kv.py @@ -16,8 +16,8 @@ def run(party): assert compatible_utils.kv is None fed.init(addresses=addresses, party=party, job_name="test_job_name") assert compatible_utils.kv - assert not compatible_utils.kv.put(b"test_key", b"test_val") - assert compatible_utils.kv.get(b"test_key") == b"test_val" + assert not compatible_utils.kv.put("test_key", b"test_val") + assert compatible_utils.kv.get("test_key") == b"test_val" # Test that a prefix key name is added under the hood. assert ray_internal_kv._internal_kv_get(b"test_key") is None