Skip to content

Commit

Permalink
Format codes.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouaihui committed Dec 13, 2023
1 parent 09019eb commit 8a7cda0
Show file tree
Hide file tree
Showing 47 changed files with 422 additions and 356 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ jobs:
- name: Lint
run: |
. py3/bin/activate
black --check --diff .
black -S --check --diff . --exclude='fed/grpc|py3'
3 changes: 1 addition & 2 deletions .isort.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,5 @@ use_parentheses=True
float_to_top=True
filter_files=True

known_local_folder=ray
known_third_party=grpc
known_local_folder=fed
sections=FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
10 changes: 6 additions & 4 deletions benchmarks/many_tiny_tasks_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import ray
import time
import sys
import time

import ray

import fed


Expand Down Expand Up @@ -53,8 +55,8 @@ def main(party):
if i % 100 == 0:
print(f"Running {i}th call")
print(f"num calls: {num_calls}")
print("total time (ms) = ", (time.time() - start)*1000)
print("per task overhead (ms) =", (time.time() - start)*1000/num_calls)
print("total time (ms) = ", (time.time() - start) * 1000)
print("per task overhead (ms) =", (time.time() - start) * 1000 / num_calls)

fed.shutdown()
ray.shutdown()
Expand Down
9 changes: 4 additions & 5 deletions fed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from fed.api import (get, init, kill, remote,
shutdown)
from fed.proxy.barriers import recv, send
from fed.fed_object import FedObject
from fed.api import get, init, kill, remote, shutdown
from fed.exceptions import FedRemoteError
from fed.fed_object import FedObject
from fed.proxy.barriers import recv, send

__all__ = [
"get",
Expand All @@ -27,5 +26,5 @@
"recv",
"send",
"FedObject",
"FedRemoteError"
"FedRemoteError",
]
61 changes: 31 additions & 30 deletions fed/_private/compatible_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
# limitations under the License.

import abc
import ray
import fed._private.constants as fed_constants

import ray
import ray.experimental.internal_kv as ray_internal_kv

import fed._private.constants as fed_constants
from fed._private import constants


Expand All @@ -41,15 +42,14 @@ def _compare_version_strings(version1, version2):


def _ray_version_less_than_2_0_0():
""" Whther the current ray version is less 2.0.0.
"""
"""Whther the current ray version is less 2.0.0."""
return _compare_version_strings(
fed_constants.RAY_VERSION_2_0_0_STR, ray.__version__)
fed_constants.RAY_VERSION_2_0_0_STR, ray.__version__
)


def init_ray(address: str = None, **kwargs):
"""A compatible API to init Ray.
"""
"""A compatible API to init Ray."""
if address == 'local' and _ray_version_less_than_2_0_0():
# Ignore the `local` when ray < 2.0.0
ray.init(**kwargs)
Expand All @@ -58,28 +58,27 @@ def init_ray(address: str = None, **kwargs):


def _get_gcs_address_from_ray_worker():
"""A compatible API to get the gcs address from the ray worker module.
"""
"""A compatible API to get the gcs address from the ray worker module."""
try:
return ray._private.worker._global_node.gcs_address
except AttributeError:
return ray.worker._global_node.gcs_address


def wrap_kv_key(job_name, key: str):
"""Add an prefix to the key to avoid conflict with other jobs.
"""
assert isinstance(key, str), \
f"The key of KV data must be `str` type, got {type(key)}."
"""Add an prefix to the key to avoid conflict with other jobs."""
assert isinstance(
key, str
), f"The key of KV data must be `str` type, got {type(key)}."

return constants.RAYFED_JOB_KV_DATA_KEY_FMT.format(
job_name, key)
return constants.RAYFED_JOB_KV_DATA_KEY_FMT.format(job_name, key)


class AbstractInternalKv(abc.ABC):
""" An abstract class that represents for bridging Ray internal kv in
"""An abstract class that represents for bridging Ray internal kv in
both Ray client mode and non Ray client mode.
"""

def __init__(self) -> None:
pass

Expand All @@ -105,8 +104,8 @@ def reset(self):


class InternalKv(AbstractInternalKv):
"""The internal kv class for non Ray client mode.
"""
"""The internal kv class for non Ray client mode."""

def __init__(self, job_name: str) -> None:
super().__init__()
self._job_name = job_name
Expand All @@ -120,21 +119,18 @@ def initialize(self):
from ray._raylet import GcsClient

gcs_client = GcsClient(
address=_get_gcs_address_from_ray_worker(),
nums_reconnect_retry=10)
address=_get_gcs_address_from_ray_worker(), nums_reconnect_retry=10
)
return ray_internal_kv._initialize_internal_kv(gcs_client)

def put(self, k, v):
return ray_internal_kv._internal_kv_put(
wrap_kv_key(self._job_name, k), v)
return ray_internal_kv._internal_kv_put(wrap_kv_key(self._job_name, k), v)

def get(self, k):
return ray_internal_kv._internal_kv_get(
wrap_kv_key(self._job_name, k))
return ray_internal_kv._internal_kv_get(wrap_kv_key(self._job_name, k))

def delete(self, k):
return ray_internal_kv._internal_kv_del(
wrap_kv_key(self._job_name, k))
return ray_internal_kv._internal_kv_del(wrap_kv_key(self._job_name, k))

def reset(self):
return ray_internal_kv._internal_kv_reset()
Expand All @@ -144,8 +140,8 @@ def _ping(self):


class ClientModeInternalKv(AbstractInternalKv):
"""The internal kv class for Ray client mode.
"""
"""The internal kv class for Ray client mode."""

def __init__(self) -> None:
super().__init__()
self._internal_kv_actor = ray.get_actor("_INTERNAL_KV_ACTOR")
Expand Down Expand Up @@ -176,9 +172,13 @@ def _init_internal_kv(job_name):
global kv
if kv is None:
from ray._private.client_mode_hook import is_client_mode_enabled

if is_client_mode_enabled:
kv_actor = ray.remote(InternalKv).options(
name="_INTERNAL_KV_ACTOR").remote(job_name)
kv_actor = (
ray.remote(InternalKv)
.options(name="_INTERNAL_KV_ACTOR")
.remote(job_name)
)
response = kv_actor._ping.remote()
ray.get(response)
kv = ClientModeInternalKv() if is_client_mode_enabled else InternalKv(job_name)
Expand All @@ -192,6 +192,7 @@ def _clear_internal_kv():
kv.delete(constants.KEY_OF_JOB_CONFIG)
kv.reset()
from ray._private.client_mode_hook import is_client_mode_enabled

if is_client_mode_enabled:
_internal_kv_actor = ray.get_actor("_INTERNAL_KV_ACTOR")
ray.kill(_internal_kv_actor)
Expand Down
2 changes: 1 addition & 1 deletion fed/_private/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

KEY_OF_CROSS_SILO_COMM_CONFIG_DICT = "CROSS_SILO_COMM_CONFIG_DICT"

RAYFED_LOG_FMT = "%(asctime)s.%(msecs)03d %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- [%(jobname)s] %(message)s" # noqa
RAYFED_LOG_FMT = "%(asctime)s.%(msecs)03d %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- [%(jobname)s] %(message)s" # noqa

RAYFED_DATE_FMT = "%Y-%m-%d %H:%M:%S"

Expand Down
17 changes: 8 additions & 9 deletions fed/_private/fed_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import ray
from ray.util.client.common import ClientActorHandle

from fed._private.fed_call_holder import FedCallHolder
from fed.fed_object import FedObject

Expand Down Expand Up @@ -90,19 +91,17 @@ def _execute_impl(self, cls_args, cls_kwargs):
)

def _execute_remote_method(
self,
method_name,
options,
_ray_wrappered_method,
args,
kwargs,
self,
method_name,
options,
_ray_wrappered_method,
args,
kwargs,
):
num_returns = 1
if options and 'num_returns' in options:
num_returns = options['num_returns']
logger.debug(
f"Actor method call: {method_name}, num_returns: {num_returns}"
)
logger.debug(f"Actor method call: {method_name}, num_returns: {num_returns}")

return _ray_wrappered_method.options(
name='',
Expand Down
13 changes: 7 additions & 6 deletions fed/_private/fed_call_holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@

import logging

# Set config in the very beginning to avoid being overwritten by other packages.
logging.basicConfig(level=logging.INFO)

import fed.config as fed_config
from fed._private.global_context import get_global_context
from fed.proxy.barriers import send
from fed.fed_object import FedObject
from fed.utils import resolve_dependencies
from fed.proxy.barriers import send
from fed.tree_util import tree_flatten
import fed.config as fed_config
from fed.utils import resolve_dependencies

# Set config in the very beginning to avoid being overwritten by other packages.
logging.basicConfig(level=logging.INFO)


logger = logging.getLogger(__name__)

Expand Down
20 changes: 11 additions & 9 deletions fed/_private/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from fed.cleanup import CleanupManager
from typing import Callable
import threading
from typing import Callable

from fed.cleanup import CleanupManager


class GlobalContext:
def __init__(self, job_name: str,
current_party: str,
failure_handler: Callable[[], None]) -> None:
def __init__(
self, job_name: str, current_party: str, failure_handler: Callable[[], None]
) -> None:
self._job_name = job_name
self._seq_count = 0
self._failure_handler = failure_handler
self._atomic_shutdown_flag_lock = threading.Lock()
self._atomic_shutdown_flag = True
self._cleanup_manager = CleanupManager(
current_party, self.acquire_shutdown_flag)
current_party, self.acquire_shutdown_flag
)

def next_seq_id(self) -> int:
self._seq_count += 1
Expand Down Expand Up @@ -65,9 +67,9 @@ def acquire_shutdown_flag(self) -> bool:
_global_context = None


def init_global_context(current_party: str,
job_name: str,
failure_handler: Callable[[], None] = None) -> None:
def init_global_context(
current_party: str, job_name: str, failure_handler: Callable[[], None] = None
) -> None:
global _global_context
if _global_context is None:
_global_context = GlobalContext(job_name, current_party, failure_handler)
Expand Down
16 changes: 9 additions & 7 deletions fed/_private/message_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import threading
from collections import deque
import time
import logging

from collections import deque

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -55,7 +54,8 @@ def _loop():

if self._thread is None or not self._thread.is_alive():
logger.debug(
f"Starting new thread[{self._thread_name}] for message polling.")
f"Starting new thread[{self._thread_name}] for message polling."
)
self._queue = deque()
self._thread = threading.Thread(target=_loop, name=self._thread_name)
self._thread.start()
Expand All @@ -79,9 +79,11 @@ def stop(self):
If False: forcelly kill the for-loop sub-thread.
"""
if threading.current_thread() == self._thread:
logger.error(f"Can't stop the message queue in the message "
f"polling thread[{self._thread_name}]. Ignore it as this"
f"could bring unknown time sequence problems.")
logger.error(
f"Can't stop the message queue in the message "
f"polling thread[{self._thread_name}]. Ignore it as this"
f"could bring unknown time sequence problems."
)
raise RuntimeError("Thread can't kill itself")

# TODO(NKcqx): Force kill sub-thread by calling `._stop()` will
Expand Down
2 changes: 1 addition & 1 deletion fed/_private/serialization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
# limitations under the License.

import io

import cloudpickle

import fed.config as fed_config


_pickle_whitelist = None


Expand Down
Loading

0 comments on commit 8a7cda0

Please sign in to comment.