diff --git a/fed/_private/compatible_utils.py b/fed/_private/compatible_utils.py
index 04fe553..83d2a09 100644
--- a/fed/_private/compatible_utils.py
+++ b/fed/_private/compatible_utils.py
@@ -66,6 +66,16 @@ def _get_gcs_address_from_ray_worker():
     return ray.worker._global_node.gcs_address
 
 
+def wrap_kv_key(job_name, key: str):
+    """Add a prefix to the key to avoid conflicts with other jobs.
+    """
+    assert isinstance(key, str), \
+        f"The key of KV data must be `str` type, got {type(key)}."
+
+    return constants.RAYFED_JOB_KV_DATA_KEY_FMT.format(
+        job_name, key)
+
+
 class AbstractInternalKv(abc.ABC):
     """ An abstract class that represents for bridging Ray internal kv in
     both Ray client mode and non Ray client mode.
     """
@@ -97,8 +107,9 @@ def reset(self):
 
 class InternalKv(AbstractInternalKv):
     """The internal kv class for non Ray client mode.
     """
-    def __init__(self) -> None:
+    def __init__(self, job_name: str) -> None:
         super().__init__()
+        self._job_name = job_name
 
     def initialize(self):
         try:
@@ -114,13 +125,16 @@ def initialize(self):
         return ray_internal_kv._initialize_internal_kv(gcs_client)
 
     def put(self, k, v):
-        return ray_internal_kv._internal_kv_put(k, v)
+        return ray_internal_kv._internal_kv_put(
+            wrap_kv_key(self._job_name, k), v)
 
     def get(self, k):
-        return ray_internal_kv._internal_kv_get(k)
+        return ray_internal_kv._internal_kv_get(
+            wrap_kv_key(self._job_name, k))
 
     def delete(self, k):
-        return ray_internal_kv._internal_kv_del(k)
+        return ray_internal_kv._internal_kv_del(
+            wrap_kv_key(self._job_name, k))
 
     def reset(self):
         return ray_internal_kv._internal_kv_reset()
@@ -157,17 +171,17 @@ def reset(self):
         return ray.get(o)
 
 
-def _init_internal_kv():
+def _init_internal_kv(job_name):
     """An internal API that initialize the internal kv object."""
     global kv
     if kv is None:
         from ray._private.client_mode_hook import is_client_mode_enabled
         if is_client_mode_enabled:
             kv_actor = ray.remote(InternalKv).options(
-                name="_INTERNAL_KV_ACTOR").remote()
+                name="_INTERNAL_KV_ACTOR").remote(job_name)
             response = kv_actor._ping.remote()
             ray.get(response)
-        kv = ClientModeInternalKv() if is_client_mode_enabled else InternalKv()
+        kv = ClientModeInternalKv() if is_client_mode_enabled else InternalKv(job_name)
     kv.initialize()
 
 
diff --git a/fed/_private/constants.py b/fed/_private/constants.py
index 72ebb58..2dfc904 100644
--- a/fed/_private/constants.py
+++ b/fed/_private/constants.py
@@ -13,11 +13,11 @@
 # limitations under the License.
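Note: `wrap_kv_key` above namespaces every internal-KV entry with the job name, using the `RAYFED_JOB_KV_DATA_KEY_FMT` constant introduced in the constants.py hunk below, so two RayFed jobs sharing one Ray cluster no longer overwrite each other's `CLUSTER_CONFIG`/`JOB_CONFIG` entries. A minimal sketch of the resulting key layout (not part of the patch; the job names are illustrative):

```python
# Minimal sketch, not part of the patch: how job-scoped KV keys are laid out,
# assuming RAYFED_JOB_KV_DATA_KEY_FMT = "RAYFED#{}#{}". Job names are illustrative.
RAYFED_JOB_KV_DATA_KEY_FMT = "RAYFED#{}#{}"


def wrap_kv_key(job_name, key: str):
    assert isinstance(key, str), \
        f"The key of KV data must be `str` type, got {type(key)}."
    return RAYFED_JOB_KV_DATA_KEY_FMT.format(job_name, key)


# Two jobs sharing one Ray cluster no longer collide on the same logical key:
assert wrap_kv_key("job1", "CLUSTER_CONFIG") == "RAYFED#job1#CLUSTER_CONFIG"
assert wrap_kv_key("job2", "CLUSTER_CONFIG") == "RAYFED#job2#CLUSTER_CONFIG"
```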
-KEY_OF_CLUSTER_CONFIG = b"CLUSTER_CONFIG"
+KEY_OF_CLUSTER_CONFIG = "CLUSTER_CONFIG"
 
-KEY_OF_JOB_CONFIG = b"JOB_CONFIG"
+KEY_OF_JOB_CONFIG = "JOB_CONFIG"
 
-KEY_OF_GRPC_METADATA = b"GRPC_METADATA"
+KEY_OF_GRPC_METADATA = "GRPC_METADATA"
 
 KEY_OF_CLUSTER_ADDRESSES = "CLUSTER_ADDRESSES"
 
@@ -27,8 +27,12 @@
 
 KEY_OF_CROSS_SILO_COMM_CONFIG_DICT = "CROSS_SILO_COMM_CONFIG_DICT"
 
-RAYFED_LOG_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- %(message)s" # noqa
+RAYFED_LOG_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- [%(jobname)s] %(message)s" # noqa
 
 RAYFED_DATE_FMT = "%Y-%m-%d %H:%M:%S"
 
 RAY_VERSION_2_0_0_STR = "2.0.0"
+
+RAYFED_DEFAULT_JOB_NAME = "Anonymous"
+
+RAYFED_JOB_KV_DATA_KEY_FMT = "RAYFED#{}#{}"
diff --git a/fed/_private/fed_call_holder.py b/fed/_private/fed_call_holder.py
index 1fc339e..abb2be4 100644
--- a/fed/_private/fed_call_holder.py
+++ b/fed/_private/fed_call_holder.py
@@ -46,7 +46,10 @@ def __init__(
         submit_ray_task_func,
         options={},
     ) -> None:
-        self._party = fed_config.get_cluster_config().current_party
+        # Note(NKcqx): FedCallHolder will only be created in the driver process,
+        # where the GlobalContext must have been initialized.
+        job_name = get_global_context().job_name()
+        self._party = fed_config.get_cluster_config(job_name).current_party
         self._node_party = node_party
         self._options = options
         self._submit_ray_task_func = submit_ray_task_func
diff --git a/fed/_private/global_context.py b/fed/_private/global_context.py
index b0367b1..673bd00 100644
--- a/fed/_private/global_context.py
+++ b/fed/_private/global_context.py
@@ -16,7 +16,8 @@
 
 
 class GlobalContext:
-    def __init__(self) -> None:
+    def __init__(self, job_name: str) -> None:
+        self._job_name = job_name
         self._seq_count = 0
         self._cleanup_manager = CleanupManager()
 
@@ -27,10 +28,19 @@ def next_seq_id(self) -> int:
     def get_cleanup_manager(self) -> CleanupManager:
         return self._cleanup_manager
 
+    def job_name(self) -> str:
+        return self._job_name
+
 
 _global_context = None
 
 
+def init_global_context(job_name: str) -> None:
+    global _global_context
+    if _global_context is None:
+        _global_context = GlobalContext(job_name)
+
+
 def get_global_context():
     global _global_context
     if _global_context is None:
diff --git a/fed/api.py b/fed/api.py
index 19950db..4d0bab2 100644
--- a/fed/api.py
+++ b/fed/api.py
@@ -26,7 +26,11 @@
 from fed._private import constants
 from fed._private.fed_actor import FedActorHandle
 from fed._private.fed_call_holder import FedCallHolder
-from fed._private.global_context import get_global_context, clear_global_context
+from fed._private.global_context import (
+    init_global_context,
+    get_global_context,
+    clear_global_context
+)
 from fed.proxy.barriers import (
     ping_others,
     recv,
@@ -54,6 +58,7 @@ def init(
     sender_proxy_cls: SenderProxy = None,
     receiver_proxy_cls: ReceiverProxy = None,
     receiver_sender_proxy_cls: SenderReceiverProxy = None,
+    job_name: str = constants.RAYFED_DEFAULT_JOB_NAME
 ):
     """
     Initialize a RayFed client.
@@ -91,7 +96,12 @@ def init(
             }
         logging_level: optional; the logging level, could be `debug`, `info`,
             `warning`, `error`, `critical`, not case sensititive.
-
+        job_name: optional; the name of the current job. Note that the job name
+            must be identical across all parties, otherwise messages will be
+            ignored because of the job name mismatch. If the job name is not
+            provided, a default fixed name will be assigned, so the messages of
+            all anonymous jobs are mixed together; this should only be used in
+            single-job or test scenarios.
Examples: >>> import fed >>> import ray @@ -109,7 +119,7 @@ def init( assert party in addresses, f"Party {party} is not in the addresses {addresses}." fed_utils.validate_addresses(addresses) - + init_global_context(job_name=job_name) tls_config = {} if tls_config is None else tls_config if tls_config: assert ( @@ -118,7 +128,7 @@ def init( cross_silo_comm_dict = config.get("cross_silo_comm", {}) # A Ray private accessing, should be replaced in public API. - compatible_utils._init_internal_kv() + compatible_utils._init_internal_kv(job_name) cluster_config = { constants.KEY_OF_CLUSTER_ADDRESSES: addresses, @@ -141,7 +151,8 @@ def init( logging_level=logging_level, logging_format=constants.RAYFED_LOG_FMT, date_format=constants.RAYFED_DATE_FMT, - party_val=_get_party(), + party_val=_get_party(job_name), + job_name=job_name ) logger.info(f'Started rayfed with {cluster_config}') @@ -215,25 +226,25 @@ def shutdown(): logger.info('Shutdowned rayfed.') -def _get_addresses(): +def _get_addresses(job_name: str = None): """ Get the RayFed addresses configration. """ - return fed_config.get_cluster_config().cluster_addresses + return fed_config.get_cluster_config(job_name).cluster_addresses -def _get_party(): +def _get_party(job_name: str = None): """ A private util function to get the current party name. """ - return fed_config.get_cluster_config().current_party + return fed_config.get_cluster_config(job_name).current_party -def _get_tls(): +def _get_tls(job_name: str = None): """ Get the tls configurations on this party. """ - return fed_config.get_cluster_config().tls_config + return fed_config.get_cluster_config(job_name).tls_config class FedRemoteFunction: @@ -287,11 +298,12 @@ def options(self, **options): def remote(self, *cls_args, **cls_kwargs): fed_class_task_id = get_global_context().next_seq_id() + job_name = get_global_context().job_name() fed_actor_handle = FedActorHandle( fed_class_task_id, - _get_addresses(), + _get_addresses(job_name), self._cls, - _get_party(), + _get_party(job_name), self._party, self._options, ) @@ -340,8 +352,9 @@ def get( # A fake fed_task_id for a `fed.get()` operator. This is useful # to help contruct the whole DAG within `fed.get`. fake_fed_task_id = get_global_context().next_seq_id() - addresses = _get_addresses() - current_party = _get_party() + job_name = get_global_context().job_name() + addresses = _get_addresses(job_name) + current_party = _get_party(job_name) is_individual_id = isinstance(fed_objects, FedObject) if is_individual_id: fed_objects = [fed_objects] @@ -396,7 +409,8 @@ def get( def kill(actor: FedActorHandle, *, no_restart=True): - current_party = _get_party() + job_name = get_global_context().job_name() + current_party = _get_party(job_name) if actor._node_party == current_party: handler = actor._actor_handle ray.kill(handler, no_restart=no_restart) diff --git a/fed/config.py b/fed/config.py index bb92e11..c3edded 100644 --- a/fed/config.py +++ b/fed/config.py @@ -48,23 +48,25 @@ def cross_silo_comm_config_dict(self) -> Dict: _job_config = None -def get_cluster_config(): +def get_cluster_config(job_name: str = None): """This function is not thread safe to use.""" global _cluster_config if _cluster_config is None: - compatible_utils._init_internal_kv() - compatible_utils.kv.initialize() + assert job_name is not None, \ + "Initializing internal kv need to provide job_name." 
+ compatible_utils._init_internal_kv(job_name) raw_dict = compatible_utils.kv.get(fed_constants.KEY_OF_CLUSTER_CONFIG) _cluster_config = ClusterConfig(raw_dict) return _cluster_config -def get_job_config(): +def get_job_config(job_name: str = None): """This config still acts like cluster config for now""" global _job_config if _job_config is None: - compatible_utils._init_internal_kv() - compatible_utils.kv.initialize() + assert job_name is not None, \ + "Initializing internal kv need to provide job_name." + compatible_utils._init_internal_kv(job_name) raw_dict = compatible_utils.kv.get(fed_constants.KEY_OF_JOB_CONFIG) _job_config = JobConfig(raw_dict) return _job_config diff --git a/fed/grpc/fed.proto b/fed/grpc/fed.proto index f19a8ca..d0e2ee6 100644 --- a/fed/grpc/fed.proto +++ b/fed/grpc/fed.proto @@ -10,6 +10,7 @@ message SendDataRequest { bytes data = 1; string upstream_seq_id = 2; string downstream_seq_id = 3; + string job_name = 4; }; message SendDataResponse { diff --git a/fed/grpc/pb3/fed_pb2.py b/fed/grpc/pb3/fed_pb2.py index ffc235e..8f7d992 100644 --- a/fed/grpc/pb3/fed_pb2.py +++ b/fed/grpc/pb3/fed_pb2.py @@ -17,6 +17,7 @@ # source: fed.proto """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database @@ -27,99 +28,12 @@ -DESCRIPTOR = _descriptor.FileDescriptor( - name='fed.proto', - package='', - syntax='proto3', - serialized_options=b'\200\001\001', - create_key=_descriptor._internal_create_key, - serialized_pb=b'\n\tfed.proto\"S\n\x0fSendDataRequest\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x17\n\x0fupstream_seq_id\x18\x02 \x01(\t\x12\x19\n\x11\x64ownstream_seq_id\x18\x03 \x01(\t\"\"\n\x10SendDataResponse\x12\x0e\n\x06result\x18\x01 \x01(\t2@\n\x0bGrpcService\x12\x31\n\x08SendData\x12\x10.SendDataRequest\x1a\x11.SendDataResponse\"\x00\x42\x03\x80\x01\x01\x62\x06proto3' -) +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tfed.proto\"e\n\x0fSendDataRequest\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x17\n\x0fupstream_seq_id\x18\x02 \x01(\t\x12\x19\n\x11\x64ownstream_seq_id\x18\x03 \x01(\t\x12\x10\n\x08job_name\x18\x04 \x01(\t\"\"\n\x10SendDataResponse\x12\x0e\n\x06result\x18\x01 \x01(\t2@\n\x0bGrpcService\x12\x31\n\x08SendData\x12\x10.SendDataRequest\x1a\x11.SendDataResponse\"\x00\x42\x03\x80\x01\x01\x62\x06proto3') - -_SENDDATAREQUEST = _descriptor.Descriptor( - name='SendDataRequest', - full_name='SendDataRequest', - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name='data', full_name='SendDataRequest.data', index=0, - number=1, type=12, cpp_type=9, label=1, - has_default_value=False, default_value=b"", - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='upstream_seq_id', full_name='SendDataRequest.upstream_seq_id', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - 
_descriptor.FieldDescriptor( - name='downstream_seq_id', full_name='SendDataRequest.downstream_seq_id', index=2, - number=3, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=13, - serialized_end=96, -) - - -_SENDDATARESPONSE = _descriptor.Descriptor( - name='SendDataResponse', - full_name='SendDataResponse', - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name='result', full_name='SendDataResponse.result', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=98, - serialized_end=132, -) - -DESCRIPTOR.message_types_by_name['SendDataRequest'] = _SENDDATAREQUEST -DESCRIPTOR.message_types_by_name['SendDataResponse'] = _SENDDATARESPONSE -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - +_SENDDATAREQUEST = DESCRIPTOR.message_types_by_name['SendDataRequest'] +_SENDDATARESPONSE = DESCRIPTOR.message_types_by_name['SendDataResponse'] SendDataRequest = _reflection.GeneratedProtocolMessageType('SendDataRequest', (_message.Message,), { 'DESCRIPTOR' : _SENDDATAREQUEST, '__module__' : 'fed_pb2' @@ -134,32 +48,15 @@ }) _sym_db.RegisterMessage(SendDataResponse) - -DESCRIPTOR._options = None - -_GRPCSERVICE = _descriptor.ServiceDescriptor( - name='GrpcService', - full_name='GrpcService', - file=DESCRIPTOR, - index=0, - serialized_options=None, - create_key=_descriptor._internal_create_key, - serialized_start=134, - serialized_end=198, - methods=[ - _descriptor.MethodDescriptor( - name='SendData', - full_name='GrpcService.SendData', - index=0, - containing_service=None, - input_type=_SENDDATAREQUEST, - output_type=_SENDDATARESPONSE, - serialized_options=None, - create_key=_descriptor._internal_create_key, - ), -]) -_sym_db.RegisterServiceDescriptor(_GRPCSERVICE) - -DESCRIPTOR.services_by_name['GrpcService'] = _GRPCSERVICE - +_GRPCSERVICE = DESCRIPTOR.services_by_name['GrpcService'] +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\200\001\001' + _SENDDATAREQUEST._serialized_start=13 + _SENDDATAREQUEST._serialized_end=114 + _SENDDATARESPONSE._serialized_start=116 + _SENDDATARESPONSE._serialized_end=150 + _GRPCSERVICE._serialized_start=152 + _GRPCSERVICE._serialized_end=216 # @@protoc_insertion_point(module_scope) diff --git a/fed/grpc/pb4/fed_pb2.py b/fed/grpc/pb4/fed_pb2.py index 3ea5f51..bd69d26 100644 --- a/fed/grpc/pb4/fed_pb2.py +++ b/fed/grpc/pb4/fed_pb2.py @@ -14,7 +14,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
-# source: fed_4.proto +# source: fed.proto """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -27,19 +27,19 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0b\x66\x65\x64_4.proto\"S\n\x0fSendDataRequest\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x17\n\x0fupstream_seq_id\x18\x02 \x01(\t\x12\x19\n\x11\x64ownstream_seq_id\x18\x03 \x01(\t\"\"\n\x10SendDataResponse\x12\x0e\n\x06result\x18\x01 \x01(\t2@\n\x0bGrpcService\x12\x31\n\x08SendData\x12\x10.SendDataRequest\x1a\x11.SendDataResponse\"\x00\x42\x03\x80\x01\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tfed.proto\"e\n\x0fSendDataRequest\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x17\n\x0fupstream_seq_id\x18\x02 \x01(\t\x12\x19\n\x11\x64ownstream_seq_id\x18\x03 \x01(\t\x12\x10\n\x08job_name\x18\x04 \x01(\t\"\"\n\x10SendDataResponse\x12\x0e\n\x06result\x18\x01 \x01(\t2@\n\x0bGrpcService\x12\x31\n\x08SendData\x12\x10.SendDataRequest\x1a\x11.SendDataResponse\"\x00\x42\x03\x80\x01\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'fed_4_pb2', _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'fed_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b'\200\001\001' - _globals['_SENDDATAREQUEST']._serialized_start=15 - _globals['_SENDDATAREQUEST']._serialized_end=98 - _globals['_SENDDATARESPONSE']._serialized_start=100 - _globals['_SENDDATARESPONSE']._serialized_end=134 - _globals['_GRPCSERVICE']._serialized_start=136 - _globals['_GRPCSERVICE']._serialized_end=200 + _globals['_SENDDATAREQUEST']._serialized_start=13 + _globals['_SENDDATAREQUEST']._serialized_end=114 + _globals['_SENDDATARESPONSE']._serialized_start=116 + _globals['_SENDDATARESPONSE']._serialized_end=150 + _globals['_GRPCSERVICE']._serialized_start=152 + _globals['_GRPCSERVICE']._serialized_end=216 # @@protoc_insertion_point(module_scope) diff --git a/fed/grpc/pb4/fed_pb2_grpc.py b/fed/grpc/pb4/fed_pb2_grpc.py index a8cfbff..a76c956 100644 --- a/fed/grpc/pb4/fed_pb2_grpc.py +++ b/fed/grpc/pb4/fed_pb2_grpc.py @@ -16,7 +16,7 @@ """Client and server classes corresponding to protobuf-defined services.""" import grpc -import fed.grpc.pb4.fed_pb2 as fed__4__pb2 +import fed.grpc.pb4.fed_pb2 as fed__pb2 class GrpcServiceStub(object): @@ -30,8 +30,8 @@ def __init__(self, channel): """ self.SendData = channel.unary_unary( '/GrpcService/SendData', - request_serializer=fed__4__pb2.SendDataRequest.SerializeToString, - response_deserializer=fed__4__pb2.SendDataResponse.FromString, + request_serializer=fed__pb2.SendDataRequest.SerializeToString, + response_deserializer=fed__pb2.SendDataResponse.FromString, ) @@ -49,8 +49,8 @@ def add_GrpcServiceServicer_to_server(servicer, server): rpc_method_handlers = { 'SendData': grpc.unary_unary_rpc_method_handler( servicer.SendData, - request_deserializer=fed__4__pb2.SendDataRequest.FromString, - response_serializer=fed__4__pb2.SendDataResponse.SerializeToString, + request_deserializer=fed__pb2.SendDataRequest.FromString, + response_serializer=fed__pb2.SendDataResponse.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -74,7 +74,7 @@ def SendData(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/GrpcService/SendData', - 
fed__4__pb2.SendDataRequest.SerializeToString, - fed__4__pb2.SendDataResponse.FromString, + fed__pb2.SendDataRequest.SerializeToString, + fed__pb2.SendDataResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index c478317..6551033 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -82,6 +82,7 @@ def __init__( self, addresses: Dict, party: str, + job_name: str, tls_config: Dict = None, logging_level: str = None, proxy_cls=None, @@ -91,17 +92,18 @@ def __init__( logging_format=constants.RAYFED_LOG_FMT, date_format=constants.RAYFED_DATE_FMT, party_val=party, + job_name=job_name, ) self._stats = {"send_op_count": 0} self._addresses = addresses self._party = party + self._job_name = job_name self._tls_config = tls_config - job_config = fed_config.get_job_config() + job_config = fed_config.get_job_config(job_name) cross_silo_comm_config = job_config.cross_silo_comm_config_dict self._proxy_instance: SenderProxy = proxy_cls( - addresses, party, tls_config, cross_silo_comm_config - ) + addresses, party, job_name, tls_config, cross_silo_comm_config) async def is_ready(self): res = await self._proxy_instance.is_ready() @@ -152,6 +154,7 @@ def __init__( self, listening_address: str, party: str, + job_name: str, logging_level: str, tls_config=None, proxy_cls=None, @@ -161,16 +164,17 @@ def __init__( logging_format=constants.RAYFED_LOG_FMT, date_format=constants.RAYFED_DATE_FMT, party_val=party, + job_name=job_name, ) self._stats = {"receive_op_count": 0} self._listening_address = listening_address self._party = party + self._job_name = job_name self._tls_config = tls_config - job_config = fed_config.get_job_config() + job_config = fed_config.get_job_config(job_name) cross_silo_comm_config = job_config.cross_silo_comm_config_dict self._proxy_instance: ReceiverProxy = proxy_cls( - listening_address, party, tls_config, cross_silo_comm_config - ) + listening_address, party, job_name, tls_config, cross_silo_comm_config) async def start(self): await self._proxy_instance.start() @@ -223,6 +227,7 @@ def _start_receiver_proxy( ).remote( listening_address=addresses[party], party=party, + job_name=get_global_context().job_name(), tls_config=tls_config, logging_level=logging_level, proxy_cls=proxy_cls, @@ -273,9 +278,11 @@ def _start_sender_proxy( name=_SENDER_PROXY_ACTOR_NAME, **actor_options ) + job_name = get_global_context().job_name() _SENDER_PROXY_ACTOR = _SENDER_PROXY_ACTOR.remote( addresses=addresses, party=party, + job_name=job_name, tls_config=tls_config, logging_level=logging_level, proxy_cls=proxy_cls, @@ -296,6 +303,7 @@ def __init__( self, addresses: Dict, party: str, + job_name: str, tls_config: Dict = None, logging_level: str = None, proxy_cls: SenderReceiverProxy = None, @@ -305,6 +313,7 @@ def __init__( logging_format=constants.RAYFED_LOG_FMT, date_format=constants.RAYFED_DATE_FMT, party_val=party, + job_name=job_name, ) self._stats = {'send_op_count': 0, 'receive_op_count': 0} @@ -389,6 +398,7 @@ def _start_sender_receiver_proxy( logger.debug(f"Starting ReceiverProxyActor with options: {actor_options}") + job_name = get_global_context().job_name() global _SENDER_RECEIVER_PROXY_ACTOR global _RECEIVER_PROXY_ACTOR_NAME _SENDER_RECEIVER_PROXY_ACTOR = SenderReceiverProxyActor.options( @@ -396,6 +406,7 @@ def _start_sender_receiver_proxy( ).remote( addresses=addresses, party=party, + job_name=job_name, tls_config=tls_config, logging_level=logging_level, 
proxy_cls=proxy_cls, diff --git a/fed/proxy/base_proxy.py b/fed/proxy/base_proxy.py index 2c5f92c..b2eba26 100644 --- a/fed/proxy/base_proxy.py +++ b/fed/proxy/base_proxy.py @@ -23,6 +23,7 @@ def __init__( self, addresses: Dict, party: str, + job_name: str, tls_config: Dict, proxy_config: CrossSiloMessageConfig = None, ) -> None: @@ -30,6 +31,7 @@ def __init__( self._party = party self._tls_config = tls_config self._proxy_config = proxy_config + self._job_name = job_name @abc.abstractmethod async def send(self, dest_party, data, upstream_seq_id, downstream_seq_id): @@ -47,13 +49,15 @@ def __init__( self, listen_addr: str, party: str, + job_name: str, tls_config: Dict, - proxy_config: CrossSiloMessageConfig = None, + proxy_config: CrossSiloMessageConfig = None ) -> None: self._listen_addr = listen_addr self._party = party self._tls_config = tls_config self._proxy_config = proxy_config + self._job_name = job_name @abc.abstractmethod def start(self): diff --git a/fed/proxy/grpc/grpc_proxy.py b/fed/proxy/grpc/grpc_proxy.py index 4feb1ca..743fe29 100644 --- a/fed/proxy/grpc/grpc_proxy.py +++ b/fed/proxy/grpc/grpc_proxy.py @@ -21,7 +21,6 @@ import json from typing import Dict - import fed.utils as fed_utils from fed.config import CrossSiloMessageConfig, GrpcCrossSiloMessageConfig @@ -89,11 +88,12 @@ def __init__( self, cluster: Dict, party: str, + job_name: str, tls_config: Dict, proxy_config: Dict = None ) -> None: proxy_config = GrpcCrossSiloMessageConfig.from_dict(proxy_config) - super().__init__(cluster, party, tls_config, proxy_config) + super().__init__(cluster, party, job_name, tls_config, proxy_config) self._grpc_metadata = proxy_config.http_header or {} self._grpc_options = copy.deepcopy(_DEFAULT_GRPC_CHANNEL_OPTIONS) self._grpc_options.update(parse_grpc_options(self._proxy_config)) @@ -132,6 +132,7 @@ async def send( stub=self._stubs[dest_party], upstream_seq_id=upstream_seq_id, downstream_seq_id=downstream_seq_id, + job_name=self._job_name, timeout=timeout, metadata=grpc_metadata, ) @@ -173,6 +174,7 @@ async def send_data_grpc( upstream_seq_id, downstream_seq_id, timeout, + job_name, metadata=None, ): data = cloudpickle.dumps(data) @@ -180,6 +182,7 @@ async def send_data_grpc( data=data, upstream_seq_id=str(upstream_seq_id), downstream_seq_id=str(downstream_seq_id), + job_name=job_name, ) # Waiting for the reply from downstream. 
response = await stub.SendData( @@ -199,11 +202,12 @@ def __init__( self, listen_addr: str, party: str, + job_name: str, tls_config: Dict, proxy_config: Dict ) -> None: proxy_config = GrpcCrossSiloMessageConfig.from_dict(proxy_config) - super().__init__(listen_addr, party, tls_config, proxy_config) + super().__init__(listen_addr, party, job_name, tls_config, proxy_config) self._grpc_options = copy.deepcopy(_DEFAULT_GRPC_CHANNEL_OPTIONS) self._grpc_options.update(parse_grpc_options(self._proxy_config)) @@ -224,6 +228,7 @@ async def start(self): self._all_data, self._party, self._lock, + self._job_name, self._server_ready_future, self._tls_config, fed_utils.dict2tuple(self._grpc_options), @@ -268,13 +273,21 @@ async def get_proxy_config(self): class SendDataService(fed_pb2_grpc.GrpcServiceServicer): - def __init__(self, all_events, all_data, party, lock): + def __init__(self, all_events, all_data, party, lock, job_name): self._events = all_events self._all_data = all_data self._party = party self._lock = lock + self._job_name = job_name async def SendData(self, request, context): + job_name = request.job_name + if job_name != self._job_name: + logger.warning(f"Receive data from job {job_name}, ignore it. " + f"The reason may be that the ReceiverProxy is listening " + f"on the same address with that job.") + return fed_pb2.SendDataResponse( + result=f"JobName mis-match, expected {self._job_name}, got {job_name}.") upstream_seq_id = request.upstream_seq_id downstream_seq_id = request.downstream_seq_id logger.debug( @@ -300,13 +313,13 @@ async def SendData(self, request, context): async def _run_grpc_server( - port, event, all_data, party, lock, + port, event, all_data, party, lock, job_name, server_ready_future, tls_config=None, grpc_options=None ): logger.info(f"ReceiveProxy binding port {port}, options: {grpc_options}...") server = grpc.aio.server(options=grpc_options) fed_pb2_grpc.add_GrpcServiceServicer_to_server( - SendDataService(event, all_data, party, lock), server + SendDataService(event, all_data, party, lock, job_name), server ) tls_enabled = fed_utils.tls_enabled(tls_config) diff --git a/fed/utils.py b/fed/utils.py index 481fcfd..0a82bc0 100644 --- a/fed/utils.py +++ b/fed/utils.py @@ -101,6 +101,7 @@ def setup_logger( date_format, log_dir=None, party_val=None, + job_name=None, ): class PartyRecordFilter(logging.Filter): def __init__(self, party_val=None) -> None: @@ -112,6 +113,16 @@ def filter(self, record) -> bool: record.party = self._party_val return True + class JobNameRecordFilter(logging.Filter): + def __init__(self, job_name=None) -> None: + self._job_name = job_name + super().__init__("JobNameRecordFilter") + + def filter(self, record) -> bool: + if not hasattr(record, "jobname"): + record.jobname = self._job_name + return True + logger = logging.getLogger() # Remove default handlers otherwise a msg will be printed twice. 
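Note: the `JobNameRecordFilter` defined above, together with the `%(jobname)s` field added to `RAYFED_LOG_FMT`, stamps every log line with the job name next to the party; the handler wiring follows in the next hunk. A minimal self-contained sketch of the effect (not the patch's `setup_logger`; party and job values are illustrative):

```python
# Minimal, self-contained sketch of the jobname log field; mirrors the filters
# added in fed/utils.py but does not import them. Party/job values are illustrative.
import logging

FMT = ("%(asctime)s %(levelname)s %(filename)s:%(lineno)s "
       "[%(party)s] -- [%(jobname)s] %(message)s")


class _StaticFieldFilter(logging.Filter):
    """Injects a fixed attribute into every record, like the patch's filters."""

    def __init__(self, field, value):
        super().__init__()
        self._field, self._value = field, value

    def filter(self, record):
        if not hasattr(record, self._field):
            setattr(record, self._field, self._value)
        return True


handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(FMT, datefmt="%Y-%m-%d %H:%M:%S"))
handler.addFilter(_StaticFieldFilter("party", "alice"))
handler.addFilter(_StaticFieldFilter("jobname", "job1"))

log = logging.getLogger("rayfed-demo")
log.addHandler(handler)
log.warning("sending data")
# -> <timestamp> WARNING <file>:<line> [alice] -- [job1] sending data
```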
@@ -123,11 +134,13 @@ def filter(self, record) -> bool: logger.setLevel(logging_level) _formatter = logging.Formatter(fmt=logging_format, datefmt=date_format) - _filter = PartyRecordFilter(party_val=party_val) + _party_filter = PartyRecordFilter(party_val=party_val) + _job_name_fitler = JobNameRecordFilter(job_name=job_name) _customed_handler = logging.StreamHandler() _customed_handler.setFormatter(_formatter) - _customed_handler.addFilter(_filter) + _customed_handler.addFilter(_party_filter) + _customed_handler.addFilter(_job_name_fitler) logger.addHandler(_customed_handler) diff --git a/tests/multi-jobs/test_ignore_other_job_msg.py b/tests/multi-jobs/test_ignore_other_job_msg.py new file mode 100644 index 0000000..cbb45ac --- /dev/null +++ b/tests/multi-jobs/test_ignore_other_job_msg.py @@ -0,0 +1,121 @@ +# Copyright 2023 The RayFed Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing +import fed +import ray +import grpc +import pytest +import fed.utils as fed_utils +import fed._private.compatible_utils as compatible_utils +from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy, send_data_grpc +if compatible_utils._compare_version_strings( + fed_utils.get_package_version('protobuf'), '4.0.0'): + from fed.grpc.pb4 import fed_pb2_grpc as fed_pb2_grpc +else: + from fed.grpc.pb3 import fed_pb2_grpc as fed_pb2_grpc + + +class TestGrpcSenderProxy(GrpcSenderProxy): + async def send( + self, + dest_party, + data, + upstream_seq_id, + downstream_seq_id): + dest_addr = self._addresses[dest_party] + grpc_metadata, grpc_channel_options = self.get_grpc_config_by_party(dest_party) + if dest_party not in self._stubs: + channel = grpc.aio.insecure_channel( + dest_addr, options=grpc_channel_options) + stub = fed_pb2_grpc.GrpcServiceStub(channel) + self._stubs[dest_party] = stub + + timeout = self._proxy_config.timeout_in_ms / 1000 + response: str = await send_data_grpc( + data=data, + stub=self._stubs[dest_party], + upstream_seq_id=upstream_seq_id, + downstream_seq_id=downstream_seq_id, + job_name=self._job_name, + timeout=timeout, + metadata=grpc_metadata, + ) + assert "JobName mis-match" in response + # So that process can exit + raise RuntimeError() + + +@fed.remote +class MyActor: + def __init__(self, party, data): + self.__data = data + self._party = party + + def f(self): + return f"f({self._party}, ip is {ray.util.get_node_ip_address()})" + + +@fed.remote +def agg_fn(obj1, obj2): + return f"agg-{obj1}-{obj2}" + + +addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', +} + + +def run(party, job_name): + ray.init(address='local') + fed.init(addresses=addresses, + party=party, + job_name=job_name, + sender_proxy_cls=TestGrpcSenderProxy, + config={ + 'cross_silo_comm': { + 'exit_on_sending_failure': True, + }}) + # 'bob' only needs to start the proxy actors + if party == 'alice': + ds1, ds2 = [123, 789] + actor_alice = MyActor.party("alice").remote(party, ds1) + actor_bob = MyActor.party("bob").remote(party, ds2) + + obj_alice_f = actor_alice.f.remote() + 
obj_bob_f = actor_bob.f.remote() + + obj = agg_fn.party("bob").remote(obj_alice_f, obj_bob_f) + fed.get(obj) + fed.shutdown() + ray.shutdown() + import time + # Wait for SIGTERM as failure on sending. + time.sleep(86400) + + +def test_ignore_other_job_msg(): + p_alice = multiprocessing.Process(target=run, args=('alice', 'job1')) + p_bob = multiprocessing.Process(target=run, args=('bob', 'job2')) + p_alice.start() + p_bob.start() + p_alice.join() + p_bob.join() + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-sv", __file__])) diff --git a/tests/test_internal_kv.py b/tests/test_internal_kv.py index f2b5372..bb04823 100644 --- a/tests/test_internal_kv.py +++ b/tests/test_internal_kv.py @@ -4,6 +4,7 @@ import fed import time import fed._private.compatible_utils as compatible_utils +import ray.experimental.internal_kv as ray_internal_kv def run(party): @@ -13,10 +14,15 @@ def run(party): 'bob': '127.0.0.1:11011', } assert compatible_utils.kv is None - fed.init(addresses=addresses, party=party) + fed.init(addresses=addresses, party=party, job_name="test_job_name") assert compatible_utils.kv - assert not compatible_utils.kv.put(b"test_key", b"test_val") - assert compatible_utils.kv.get(b"test_key") == b"test_val" + assert not compatible_utils.kv.put("test_key", b"test_val") + assert compatible_utils.kv.get("test_key") == b"test_val" + + # Test that a prefix key name is added under the hood. + assert ray_internal_kv._internal_kv_get(b"test_key") is None + assert ray_internal_kv._internal_kv_get( + b"RAYFED#test_job_name#test_key") == b"test_val" time.sleep(5) fed.shutdown() diff --git a/tests/test_retry_policy.py b/tests/test_retry_policy.py index b6ee9b5..574ce9f 100644 --- a/tests/test_retry_policy.py +++ b/tests/test_retry_policy.py @@ -50,6 +50,7 @@ def run(): "backoffMultiplier": 1, "retryableStatusCodes": ["UNAVAILABLE"], } + test_job_name = 'test_retry_policy' fed.init( addresses=addresses, party='alice', @@ -60,7 +61,7 @@ def run(): }, ) - job_config = config.get_job_config() + job_config = config.get_job_config(test_job_name) cross_silo_comm_config = job_config.cross_silo_comm_config_dict TestCase().assertDictEqual( cross_silo_comm_config['grpc_retry_policy'], retry_policy diff --git a/tests/test_transport_proxy.py b/tests/test_transport_proxy.py index 7df2528..368223a 100644 --- a/tests/test_transport_proxy.py +++ b/tests/test_transport_proxy.py @@ -46,14 +46,15 @@ def test_n_to_1_transport(): N receivers to `get_data` from receiver proxy at that time. 
""" compatible_utils.init_ray(address='local') - + test_job_name = 'test_n_to_1_transport' + global_context.init_global_context(test_job_name) global_context.get_global_context().get_cleanup_manager().start() cluster_config = { constants.KEY_OF_CLUSTER_ADDRESSES: "", constants.KEY_OF_CURRENT_PARTY_NAME: "", constants.KEY_OF_TLS_CONFIG: "", } - compatible_utils._init_internal_kv() + compatible_utils._init_internal_kv(test_job_name) compatible_utils.kv.put( constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config) ) @@ -93,17 +94,23 @@ def test_n_to_1_transport(): global_context.get_global_context().get_cleanup_manager().graceful_stop() global_context.clear_global_context() + compatible_utils._clear_internal_kv() ray.shutdown() class TestSendDataService(fed_pb2_grpc.GrpcServiceServicer): - def __init__(self, all_events, all_data, party, lock, expected_metadata): + def __init__(self, all_events, all_data, party, lock, + expected_metadata, expected_jobname): self.expected_metadata = expected_metadata or {} + self._expected_jobname = expected_jobname or "" async def SendData(self, request, context): + job_name = request.job_name + assert self._expected_jobname == job_name metadata = dict(context.invocation_metadata()) for k, v in self.expected_metadata.items(): - assert k in metadata + assert k in metadata, \ + f"The expected key {k} is not in the metadata keys: {metadata.keys()}." assert v == metadata[k] event = asyncio.Event() event.set() @@ -118,10 +125,13 @@ async def _test_run_grpc_server( lock, grpc_options=None, expected_metadata=None, + expected_jobname=None, ): server = grpc.aio.server(options=grpc_options) fed_pb2_grpc.add_GrpcServiceServicer_to_server( - TestSendDataService(event, all_data, party, lock, expected_metadata), server + TestSendDataService(event, all_data, party, lock, + expected_metadata, expected_jobname), + server ) server.add_insecure_port(f'[::]:{port}') await server.start() @@ -135,19 +145,22 @@ def __init__( listen_addr: str, party: str, expected_metadata: dict, + expected_jobname: str, ): self._listen_addr = listen_addr self._party = party self._expected_metadata = expected_metadata + self._expected_jobname = expected_jobname async def run_grpc_server(self): return await _test_run_grpc_server( - self._listen_addr[self._listen_addr.index(':') + 1 :], + self._listen_addr[self._listen_addr.index(':') + 1:], None, None, self._party, None, expected_metadata=self._expected_metadata, + expected_jobname=self._expected_jobname ) async def is_ready(self): @@ -158,13 +171,16 @@ def _test_start_receiver_proxy( addresses: str, party: str, expected_metadata: dict, + expected_jobname: str, ): # Create RecevrProxyActor # Not that this is now a threaded actor. 
address = addresses[party] receiver_proxy_actor = TestReceiverProxyActor.options( name=receiver_proxy_actor_name(), max_concurrency=1000 - ).remote(listen_addr=address, party=party, expected_metadata=expected_metadata) + ).remote(listen_addr=address, party=party, + expected_metadata=expected_metadata, + expected_jobname=expected_jobname) receiver_proxy_actor.run_grpc_server.remote() assert ray.get(receiver_proxy_actor.is_ready.remote()) @@ -181,7 +197,9 @@ def test_send_grpc_with_meta(): job_config = { constants.KEY_OF_CROSS_SILO_COMM_CONFIG_DICT: config, } - compatible_utils._init_internal_kv() + test_job_name = 'test_send_grpc_with_meta' + global_context.init_global_context(test_job_name) + compatible_utils._init_internal_kv(test_job_name) compatible_utils.kv.put( constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config) ) @@ -195,13 +213,14 @@ def test_send_grpc_with_meta(): addresses, party_name, expected_metadata=metadata, + expected_jobname=test_job_name ) _start_sender_proxy( addresses, party_name, logging_level='info', proxy_cls=GrpcSenderProxy, - proxy_config={}, + proxy_config=config, ) sent_objs = [] sent_obj = send(party_name, "data", 0, 1) diff --git a/tests/test_transport_proxy_tls.py b/tests/test_transport_proxy_tls.py index 1ea525e..e8c1052 100644 --- a/tests/test_transport_proxy_tls.py +++ b/tests/test_transport_proxy_tls.py @@ -35,7 +35,7 @@ def test_n_to_1_transport(): N receivers to `get_data` from receiver proxy at that time. """ compatible_utils.init_ray(address='local') - + test_job_name = 'test_n_to_1_transport' cert_dir = os.path.join( os.path.dirname(os.path.abspath(__file__)), "/tmp/rayfed/test-certs/" ) @@ -50,9 +50,9 @@ def test_n_to_1_transport(): constants.KEY_OF_CURRENT_PARTY_NAME: "", constants.KEY_OF_TLS_CONFIG: tls_config, } - + global_context.init_global_context(test_job_name) global_context.get_global_context().get_cleanup_manager().start() - compatible_utils._init_internal_kv() + compatible_utils._init_internal_kv(test_job_name) compatible_utils.kv.put( constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config) ) @@ -99,6 +99,7 @@ def test_n_to_1_transport(): global_context.get_global_context().get_cleanup_manager().graceful_stop() global_context.clear_global_context() + compatible_utils._clear_internal_kv() ray.shutdown()
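Note: the new multi-jobs test exercises the mismatch path; the intended happy path is simply that every party passes the same `job_name` to `fed.init()`, so each ReceiverProxy accepts the peer's messages instead of answering with "JobName mis-match". A minimal sketch of that happy path (addresses and the job name are illustrative, mirroring the new test's structure):

```python
# Minimal sketch of the happy path: both parties use the *same* job_name, so
# their proxies accept each other's messages. Addresses/job name are illustrative.
import multiprocessing

import fed
import ray

addresses = {
    'alice': '127.0.0.1:11012',
    'bob': '127.0.0.1:11011',
}


def run(party: str):
    ray.init(address='local')
    # job_name must be identical on every party; a mismatch makes the peer's
    # SendDataService reply with "JobName mis-match" and drop the payload.
    fed.init(addresses=addresses, party=party, job_name='my_tagged_job')
    # ... submit fed.remote tasks and fed.get results as usual ...
    fed.shutdown()
    ray.shutdown()


if __name__ == "__main__":
    procs = [multiprocessing.Process(target=run, args=(p,)) for p in addresses]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```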