Skip to content

Commit

Permalink
[Multi-Job][3/N] Distinguish cross_silo msg with the job name (#172)
Browse files Browse the repository at this point in the history
* split kv data of jobs

* send/recvmsg with job_id

---------

Signed-off-by: paer <[email protected]>
Signed-off-by: NKcqx <[email protected]>
Co-authored-by: paer <[email protected]>
  • Loading branch information
NKcqx and paer authored Aug 15, 2023
1 parent e548b01 commit f0ccfb3
Show file tree
Hide file tree
Showing 19 changed files with 334 additions and 200 deletions.
28 changes: 21 additions & 7 deletions fed/_private/compatible_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@ def _get_gcs_address_from_ray_worker():
return ray.worker._global_node.gcs_address


def wrap_kv_key(job_name, key: str):
"""Add an prefix to the key to avoid conflict with other jobs.
"""
assert isinstance(key, str), \
f"The key of KV data must be `str` type, got {type(key)}."

return constants.RAYFED_JOB_KV_DATA_KEY_FMT.format(
job_name, key)


class AbstractInternalKv(abc.ABC):
""" An abstract class that represents for bridging Ray internal kv in
both Ray client mode and non Ray client mode.
Expand Down Expand Up @@ -97,8 +107,9 @@ def reset(self):
class InternalKv(AbstractInternalKv):
"""The internal kv class for non Ray client mode.
"""
def __init__(self) -> None:
def __init__(self, job_name: str) -> None:
super().__init__()
self._job_name = job_name

def initialize(self):
try:
Expand All @@ -114,13 +125,16 @@ def initialize(self):
return ray_internal_kv._initialize_internal_kv(gcs_client)

def put(self, k, v):
return ray_internal_kv._internal_kv_put(k, v)
return ray_internal_kv._internal_kv_put(
wrap_kv_key(self._job_name, k), v)

def get(self, k):
return ray_internal_kv._internal_kv_get(k)
return ray_internal_kv._internal_kv_get(
wrap_kv_key(self._job_name, k))

def delete(self, k):
return ray_internal_kv._internal_kv_del(k)
return ray_internal_kv._internal_kv_del(
wrap_kv_key(self._job_name, k))

def reset(self):
return ray_internal_kv._internal_kv_reset()
Expand Down Expand Up @@ -157,17 +171,17 @@ def reset(self):
return ray.get(o)


def _init_internal_kv():
def _init_internal_kv(job_name):
"""An internal API that initialize the internal kv object."""
global kv
if kv is None:
from ray._private.client_mode_hook import is_client_mode_enabled
if is_client_mode_enabled:
kv_actor = ray.remote(InternalKv).options(
name="_INTERNAL_KV_ACTOR").remote()
name="_INTERNAL_KV_ACTOR").remote(job_name)
response = kv_actor._ping.remote()
ray.get(response)
kv = ClientModeInternalKv() if is_client_mode_enabled else InternalKv()
kv = ClientModeInternalKv() if is_client_mode_enabled else InternalKv(job_name)
kv.initialize()


Expand Down
12 changes: 8 additions & 4 deletions fed/_private/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
# limitations under the License.


KEY_OF_CLUSTER_CONFIG = b"CLUSTER_CONFIG"
KEY_OF_CLUSTER_CONFIG = "CLUSTER_CONFIG"

KEY_OF_JOB_CONFIG = b"JOB_CONFIG"
KEY_OF_JOB_CONFIG = "JOB_CONFIG"

KEY_OF_GRPC_METADATA = b"GRPC_METADATA"
KEY_OF_GRPC_METADATA = "GRPC_METADATA"

KEY_OF_CLUSTER_ADDRESSES = "CLUSTER_ADDRESSES"

Expand All @@ -27,8 +27,12 @@

KEY_OF_CROSS_SILO_COMM_CONFIG_DICT = "CROSS_SILO_COMM_CONFIG_DICT"

RAYFED_LOG_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- %(message)s" # noqa
RAYFED_LOG_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- [%(jobname)s] %(message)s" # noqa

RAYFED_DATE_FMT = "%Y-%m-%d %H:%M:%S"

RAY_VERSION_2_0_0_STR = "2.0.0"

RAYFED_DEFAULT_JOB_NAME = "Anonymous"

RAYFED_JOB_KV_DATA_KEY_FMT = "RAYFED#{}#{}"
5 changes: 4 additions & 1 deletion fed/_private/fed_call_holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ def __init__(
submit_ray_task_func,
options={},
) -> None:
self._party = fed_config.get_cluster_config().current_party
# Note(NKcqx): FedCallHolder will only be created in driver process, where
# the GlobalContext must has been initialized.
job_name = get_global_context().job_name()
self._party = fed_config.get_cluster_config(job_name).current_party
self._node_party = node_party
self._options = options
self._submit_ray_task_func = submit_ray_task_func
Expand Down
12 changes: 11 additions & 1 deletion fed/_private/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@


class GlobalContext:
def __init__(self) -> None:
def __init__(self, job_name: str) -> None:
self._job_name = job_name
self._seq_count = 0
self._cleanup_manager = CleanupManager()

Expand All @@ -27,10 +28,19 @@ def next_seq_id(self) -> int:
def get_cleanup_manager(self) -> CleanupManager:
return self._cleanup_manager

def job_name(self) -> str:
return self._job_name


_global_context = None


def init_global_context(job_name: str) -> None:
global _global_context
if _global_context is None:
_global_context = GlobalContext(job_name)


def get_global_context():
global _global_context
if _global_context is None:
Expand Down
46 changes: 30 additions & 16 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
from fed._private import constants
from fed._private.fed_actor import FedActorHandle
from fed._private.fed_call_holder import FedCallHolder
from fed._private.global_context import get_global_context, clear_global_context
from fed._private.global_context import (
init_global_context,
get_global_context,
clear_global_context
)
from fed.proxy.barriers import (
ping_others,
recv,
Expand Down Expand Up @@ -54,6 +58,7 @@ def init(
sender_proxy_cls: SenderProxy = None,
receiver_proxy_cls: ReceiverProxy = None,
receiver_sender_proxy_cls: SenderReceiverProxy = None,
job_name: str = constants.RAYFED_DEFAULT_JOB_NAME
):
"""
Initialize a RayFed client.
Expand Down Expand Up @@ -91,7 +96,12 @@ def init(
}
logging_level: optional; the logging level, could be `debug`, `info`,
`warning`, `error`, `critical`, not case sensititive.
job_name: optional; the job name of the current job. Note that, the job name
must be identical in all parties, otherwise, messages will be ignored
because of the job name mismatch. If the job name is not provided, an
default fixed name will be assigned, therefore messages of all anonymous
jobs will be mixed together, which should only be used in the single job
scenario or test mode.
Examples:
>>> import fed
>>> import ray
Expand All @@ -109,7 +119,7 @@ def init(
assert party in addresses, f"Party {party} is not in the addresses {addresses}."

fed_utils.validate_addresses(addresses)

init_global_context(job_name=job_name)
tls_config = {} if tls_config is None else tls_config
if tls_config:
assert (
Expand All @@ -118,7 +128,7 @@ def init(

cross_silo_comm_dict = config.get("cross_silo_comm", {})
# A Ray private accessing, should be replaced in public API.
compatible_utils._init_internal_kv()
compatible_utils._init_internal_kv(job_name)

cluster_config = {
constants.KEY_OF_CLUSTER_ADDRESSES: addresses,
Expand All @@ -141,7 +151,8 @@ def init(
logging_level=logging_level,
logging_format=constants.RAYFED_LOG_FMT,
date_format=constants.RAYFED_DATE_FMT,
party_val=_get_party(),
party_val=_get_party(job_name),
job_name=job_name
)

logger.info(f'Started rayfed with {cluster_config}')
Expand Down Expand Up @@ -215,25 +226,25 @@ def shutdown():
logger.info('Shutdowned rayfed.')


def _get_addresses():
def _get_addresses(job_name: str = None):
"""
Get the RayFed addresses configration.
"""
return fed_config.get_cluster_config().cluster_addresses
return fed_config.get_cluster_config(job_name).cluster_addresses


def _get_party():
def _get_party(job_name: str = None):
"""
A private util function to get the current party name.
"""
return fed_config.get_cluster_config().current_party
return fed_config.get_cluster_config(job_name).current_party


def _get_tls():
def _get_tls(job_name: str = None):
"""
Get the tls configurations on this party.
"""
return fed_config.get_cluster_config().tls_config
return fed_config.get_cluster_config(job_name).tls_config


class FedRemoteFunction:
Expand Down Expand Up @@ -287,11 +298,12 @@ def options(self, **options):

def remote(self, *cls_args, **cls_kwargs):
fed_class_task_id = get_global_context().next_seq_id()
job_name = get_global_context().job_name()
fed_actor_handle = FedActorHandle(
fed_class_task_id,
_get_addresses(),
_get_addresses(job_name),
self._cls,
_get_party(),
_get_party(job_name),
self._party,
self._options,
)
Expand Down Expand Up @@ -340,8 +352,9 @@ def get(
# A fake fed_task_id for a `fed.get()` operator. This is useful
# to help contruct the whole DAG within `fed.get`.
fake_fed_task_id = get_global_context().next_seq_id()
addresses = _get_addresses()
current_party = _get_party()
job_name = get_global_context().job_name()
addresses = _get_addresses(job_name)
current_party = _get_party(job_name)
is_individual_id = isinstance(fed_objects, FedObject)
if is_individual_id:
fed_objects = [fed_objects]
Expand Down Expand Up @@ -396,7 +409,8 @@ def get(


def kill(actor: FedActorHandle, *, no_restart=True):
current_party = _get_party()
job_name = get_global_context().job_name()
current_party = _get_party(job_name)
if actor._node_party == current_party:
handler = actor._actor_handle
ray.kill(handler, no_restart=no_restart)
14 changes: 8 additions & 6 deletions fed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,25 @@ def cross_silo_comm_config_dict(self) -> Dict:
_job_config = None


def get_cluster_config():
def get_cluster_config(job_name: str = None):
"""This function is not thread safe to use."""
global _cluster_config
if _cluster_config is None:
compatible_utils._init_internal_kv()
compatible_utils.kv.initialize()
assert job_name is not None, \
"Initializing internal kv need to provide job_name."
compatible_utils._init_internal_kv(job_name)
raw_dict = compatible_utils.kv.get(fed_constants.KEY_OF_CLUSTER_CONFIG)
_cluster_config = ClusterConfig(raw_dict)
return _cluster_config


def get_job_config():
def get_job_config(job_name: str = None):
"""This config still acts like cluster config for now"""
global _job_config
if _job_config is None:
compatible_utils._init_internal_kv()
compatible_utils.kv.initialize()
assert job_name is not None, \
"Initializing internal kv need to provide job_name."
compatible_utils._init_internal_kv(job_name)
raw_dict = compatible_utils.kv.get(fed_constants.KEY_OF_JOB_CONFIG)
_job_config = JobConfig(raw_dict)
return _job_config
Expand Down
1 change: 1 addition & 0 deletions fed/grpc/fed.proto
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ message SendDataRequest {
bytes data = 1;
string upstream_seq_id = 2;
string downstream_seq_id = 3;
string job_name = 4;
};

message SendDataResponse {
Expand Down
Loading

0 comments on commit f0ccfb3

Please sign in to comment.