Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Multi-Job][3/N] Distinguish cross_silo msg with the job name #172

Merged
merged 17 commits into from
Aug 15, 2023
28 changes: 21 additions & 7 deletions fed/_private/compatible_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@ def _get_gcs_address_from_ray_worker():
return ray.worker._global_node.gcs_address


def wrap_kv_key(job_name, key: str):
"""Add an prefix to the key to avoid conflict with other jobs.
"""
assert isinstance(key, str), \
f"The key of KV data must be `str` type, got {type(key)}."

return constants.RAYFED_JOB_KV_DATA_KEY_FMT.format(
job_name, key)


class AbstractInternalKv(abc.ABC):
""" An abstract class that represents for bridging Ray internal kv in
both Ray client mode and non Ray client mode.
Expand Down Expand Up @@ -97,8 +107,9 @@ def reset(self):
class InternalKv(AbstractInternalKv):
"""The internal kv class for non Ray client mode.
"""
def __init__(self) -> None:
def __init__(self, job_name: str) -> None:
super().__init__()
self._job_name = job_name

def initialize(self):
try:
Expand All @@ -114,13 +125,16 @@ def initialize(self):
return ray_internal_kv._initialize_internal_kv(gcs_client)

def put(self, k, v):
return ray_internal_kv._internal_kv_put(k, v)
return ray_internal_kv._internal_kv_put(
wrap_kv_key(self._job_name, k), v)

def get(self, k):
return ray_internal_kv._internal_kv_get(k)
return ray_internal_kv._internal_kv_get(
wrap_kv_key(self._job_name, k))

def delete(self, k):
return ray_internal_kv._internal_kv_del(k)
return ray_internal_kv._internal_kv_del(
wrap_kv_key(self._job_name, k))

def reset(self):
return ray_internal_kv._internal_kv_reset()
Expand Down Expand Up @@ -157,17 +171,17 @@ def reset(self):
return ray.get(o)


def _init_internal_kv():
def _init_internal_kv(job_name):
"""An internal API that initialize the internal kv object."""
global kv
if kv is None:
from ray._private.client_mode_hook import is_client_mode_enabled
if is_client_mode_enabled:
kv_actor = ray.remote(InternalKv).options(
name="_INTERNAL_KV_ACTOR").remote()
name="_INTERNAL_KV_ACTOR").remote(job_name)
response = kv_actor._ping.remote()
ray.get(response)
kv = ClientModeInternalKv() if is_client_mode_enabled else InternalKv()
kv = ClientModeInternalKv() if is_client_mode_enabled else InternalKv(job_name)
kv.initialize()


Expand Down
12 changes: 8 additions & 4 deletions fed/_private/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
# limitations under the License.


KEY_OF_CLUSTER_CONFIG = b"CLUSTER_CONFIG"
KEY_OF_CLUSTER_CONFIG = "CLUSTER_CONFIG"

KEY_OF_JOB_CONFIG = b"JOB_CONFIG"
KEY_OF_JOB_CONFIG = "JOB_CONFIG"

KEY_OF_GRPC_METADATA = b"GRPC_METADATA"
KEY_OF_GRPC_METADATA = "GRPC_METADATA"

KEY_OF_CLUSTER_ADDRESSES = "CLUSTER_ADDRESSES"

Expand All @@ -27,8 +27,12 @@

KEY_OF_CROSS_SILO_COMM_CONFIG_DICT = "CROSS_SILO_COMM_CONFIG_DICT"

RAYFED_LOG_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- %(message)s" # noqa
RAYFED_LOG_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- [%(jobname)s] %(message)s" # noqa

RAYFED_DATE_FMT = "%Y-%m-%d %H:%M:%S"

RAY_VERSION_2_0_0_STR = "2.0.0"

RAYFED_DEFAULT_JOB_NAME = "Anonymous"

RAYFED_JOB_KV_DATA_KEY_FMT = "RAYFED#{}#{}"
5 changes: 4 additions & 1 deletion fed/_private/fed_call_holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ def __init__(
submit_ray_task_func,
options={},
) -> None:
self._party = fed_config.get_cluster_config().current_party
# Note(NKcqx): FedCallHolder will only be created in driver process, where
# the GlobalContext must has been initialized.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good note.

job_name = get_global_context().job_name()
self._party = fed_config.get_cluster_config(job_name).current_party
self._node_party = node_party
self._options = options
self._submit_ray_task_func = submit_ray_task_func
Expand Down
12 changes: 11 additions & 1 deletion fed/_private/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@


class GlobalContext:
def __init__(self) -> None:
def __init__(self, job_name: str) -> None:
self._job_name = job_name
self._seq_count = 0
self._cleanup_manager = CleanupManager()

Expand All @@ -27,10 +28,19 @@ def next_seq_id(self) -> int:
def get_cleanup_manager(self) -> CleanupManager:
return self._cleanup_manager

def job_name(self) -> str:
return self._job_name


_global_context = None


def init_global_context(job_name: str) -> None:
global _global_context
if _global_context is None:
_global_context = GlobalContext(job_name)


def get_global_context():
global _global_context
if _global_context is None:
Expand Down
46 changes: 30 additions & 16 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
from fed._private import constants
from fed._private.fed_actor import FedActorHandle
from fed._private.fed_call_holder import FedCallHolder
from fed._private.global_context import get_global_context, clear_global_context
from fed._private.global_context import (
init_global_context,
get_global_context,
clear_global_context
)
from fed.proxy.barriers import (
ping_others,
recv,
Expand Down Expand Up @@ -54,6 +58,7 @@ def init(
sender_proxy_cls: SenderProxy = None,
receiver_proxy_cls: ReceiverProxy = None,
receiver_sender_proxy_cls: SenderReceiverProxy = None,
job_name: str = constants.RAYFED_DEFAULT_JOB_NAME
):
"""
Initialize a RayFed client.
Expand Down Expand Up @@ -91,7 +96,12 @@ def init(
}
logging_level: optional; the logging level, could be `debug`, `info`,
`warning`, `error`, `critical`, not case sensititive.

job_name: optional; the job name of the current job. Note that, the job name
must be identical in all parties, otherwise, messages will be ignored
because of the job name mismatch. If the job name is not provided, an
default fixed name will be assigned, therefore messages of all anonymous
jobs will be mixed together, which should only be used in the single job
scenario or test mode.
Examples:
>>> import fed
>>> import ray
Expand All @@ -109,7 +119,7 @@ def init(
assert party in addresses, f"Party {party} is not in the addresses {addresses}."

fed_utils.validate_addresses(addresses)

init_global_context(job_name=job_name)
tls_config = {} if tls_config is None else tls_config
if tls_config:
assert (
Expand All @@ -118,7 +128,7 @@ def init(

cross_silo_comm_dict = config.get("cross_silo_comm", {})
# A Ray private accessing, should be replaced in public API.
compatible_utils._init_internal_kv()
compatible_utils._init_internal_kv(job_name)

cluster_config = {
constants.KEY_OF_CLUSTER_ADDRESSES: addresses,
Expand All @@ -141,7 +151,8 @@ def init(
logging_level=logging_level,
logging_format=constants.RAYFED_LOG_FMT,
date_format=constants.RAYFED_DATE_FMT,
party_val=_get_party(),
party_val=_get_party(job_name),
job_name=job_name
)

logger.info(f'Started rayfed with {cluster_config}')
Expand Down Expand Up @@ -215,25 +226,25 @@ def shutdown():
logger.info('Shutdowned rayfed.')


def _get_addresses():
def _get_addresses(job_name: str = None):
"""
Get the RayFed addresses configration.
"""
return fed_config.get_cluster_config().cluster_addresses
return fed_config.get_cluster_config(job_name).cluster_addresses


def _get_party():
def _get_party(job_name: str = None):
"""
A private util function to get the current party name.
"""
return fed_config.get_cluster_config().current_party
return fed_config.get_cluster_config(job_name).current_party


def _get_tls():
def _get_tls(job_name: str = None):
"""
Get the tls configurations on this party.
"""
return fed_config.get_cluster_config().tls_config
return fed_config.get_cluster_config(job_name).tls_config


class FedRemoteFunction:
Expand Down Expand Up @@ -287,11 +298,12 @@ def options(self, **options):

def remote(self, *cls_args, **cls_kwargs):
fed_class_task_id = get_global_context().next_seq_id()
job_name = get_global_context().job_name()
fed_actor_handle = FedActorHandle(
fed_class_task_id,
_get_addresses(),
_get_addresses(job_name),
self._cls,
_get_party(),
_get_party(job_name),
self._party,
self._options,
)
Expand Down Expand Up @@ -340,8 +352,9 @@ def get(
# A fake fed_task_id for a `fed.get()` operator. This is useful
# to help contruct the whole DAG within `fed.get`.
fake_fed_task_id = get_global_context().next_seq_id()
addresses = _get_addresses()
current_party = _get_party()
job_name = get_global_context().job_name()
addresses = _get_addresses(job_name)
current_party = _get_party(job_name)
is_individual_id = isinstance(fed_objects, FedObject)
if is_individual_id:
fed_objects = [fed_objects]
Expand Down Expand Up @@ -396,7 +409,8 @@ def get(


def kill(actor: FedActorHandle, *, no_restart=True):
current_party = _get_party()
job_name = get_global_context().job_name()
current_party = _get_party(job_name)
if actor._node_party == current_party:
handler = actor._actor_handle
ray.kill(handler, no_restart=no_restart)
14 changes: 8 additions & 6 deletions fed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,25 @@ def cross_silo_comm_config_dict(self) -> Dict:
_job_config = None


def get_cluster_config():
def get_cluster_config(job_name: str = None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hum, so this method name should be changed to get_job_config or get_config?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both "cluster_config" and "job_config" are part of the JOB, see #156. So I think it's fine to retrieve a "cluster config" by "job name" though it's definitely hard to understand 🤣

I think in the near future, we can merge these two configs.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. Then let us leave a TODO comment on that target?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure

"""This function is not thread safe to use."""
global _cluster_config
if _cluster_config is None:
compatible_utils._init_internal_kv()
compatible_utils.kv.initialize()
assert job_name is not None, \
"Initializing internal kv need to provide job_name."
compatible_utils._init_internal_kv(job_name)
raw_dict = compatible_utils.kv.get(fed_constants.KEY_OF_CLUSTER_CONFIG)
_cluster_config = ClusterConfig(raw_dict)
return _cluster_config


def get_job_config():
def get_job_config(job_name: str = None):
"""This config still acts like cluster config for now"""
global _job_config
if _job_config is None:
compatible_utils._init_internal_kv()
compatible_utils.kv.initialize()
assert job_name is not None, \
"Initializing internal kv need to provide job_name."
compatible_utils._init_internal_kv(job_name)
raw_dict = compatible_utils.kv.get(fed_constants.KEY_OF_JOB_CONFIG)
_job_config = JobConfig(raw_dict)
return _job_config
Expand Down
1 change: 1 addition & 0 deletions fed/grpc/fed.proto
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ message SendDataRequest {
bytes data = 1;
string upstream_seq_id = 2;
string downstream_seq_id = 3;
string job_name = 4;
};

message SendDataResponse {
Expand Down
13 changes: 13 additions & 0 deletions fed/grpc/pb3/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2023 The RayFed Team
NKcqx marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading