Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix missing messages_max_size_in_bytes config #182

Merged
merged 1 commit into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,33 @@ def init(
'carol': '127.0.0.1:10003',
}
party: optional; self party.
config: optional; a dict describes general job configurations. Currently the
supported configurations are [`cross_silo_comm`, 'barrier_on_initializing'].
* `cross_silo_comm`: optional; a dict describes the cross-silo common
configs, the supported configs can be referred to
`fed.config.CrossSiloMessageConfig` and
`fed.config.GrpcCrossSiloMessageConfig`. Note that, the
`cross_silo_comm.messages_max_size_in_bytes` will be overrided
if `cross_silo_comm.grpc_channel_options` is provided and contains
`grpc.max_send_message_length` or `grpc.max_receive_message_length`.
* `barrier_on_initializing`: optional; a bool value indicates whether to
wait for all parties to be ready before starting the job. If set
to True, the job will be started after all parties are ready,
otherwise, the job will be started immediately after the current
party is ready.

Example:

.. code:: python
{
"cross_silo_comm": {
"messages_max_size_in_bytes": 500*1024,
"timeout_in_ms": 1000,
"exit_on_sending_failure": True,
"expose_error_trace": True,
},
"barrier_on_initializing": True,
}
tls_config: optional; a dict describes the tls config. E.g.
For alice,

Expand Down
3 changes: 1 addition & 2 deletions fed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ class CrossSiloMessageConfig:
cross-silo sending. If True, a SIGTERM will be signaled to self
if failed to sending cross-silo data.
messages_max_size_in_bytes: The maximum length in bytes of
cross-silo messages.
If None, the default value of 500 MB is specified.
cross-silo messages. If None, the default value of 500 MB is specified.
timeout_in_ms: The timeout in mili-seconds of a cross-silo RPC call.
It's 60000 by default.
http_header: The HTTP header, e.g. metadata in grpc, sent with the RPC request.
Expand Down
17 changes: 15 additions & 2 deletions fed/proxy/grpc/grpc_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,21 @@ def parse_grpc_options(proxy_config: CrossSiloMessageConfig):
dict: A dictionary containing the gRPC channel options.
"""
grpc_channel_options = {}
if proxy_config is not None and isinstance(
proxy_config, GrpcCrossSiloMessageConfig):
if proxy_config is not None:
# NOTE(NKcqx): `messages_max_size_in_bytes` is a common cross-silo
# config that should be extracted and filled into proper grpc's
# channel options.
# However, `GrpcCrossSiloMessageConfig` provides a more flexible way
# to configure grpc channel options, i.e. the `grpc_channel_options`
# field, which may override the `messages_max_size_in_bytes` field.
if (isinstance(proxy_config, CrossSiloMessageConfig)):
if (proxy_config.messages_max_size_in_bytes is not None):
grpc_channel_options.update({
'grpc.max_send_message_length':
proxy_config.messages_max_size_in_bytes,
'grpc.max_receive_message_length':
proxy_config.messages_max_size_in_bytes,
})
if isinstance(proxy_config, GrpcCrossSiloMessageConfig):
if proxy_config.grpc_channel_options is not None:
grpc_channel_options.update(proxy_config.grpc_channel_options)
Expand Down
97 changes: 96 additions & 1 deletion fed/tests/test_grpc_options_on_proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _assert_on_proxy(proxy_actor):
ray.shutdown()


def test_grpc_max_size():
def test_grpc_max_size_by_channel_options():
p_alice = multiprocessing.Process(target=run, args=('alice',))
p_bob = multiprocessing.Process(target=run, args=('bob',))
p_alice.start()
Expand All @@ -71,6 +71,101 @@ def test_grpc_max_size():
assert p_alice.exitcode == 0 and p_bob.exitcode == 0


def run2(party):
compatible_utils.init_ray(address='local')
addresses = {
'alice': '127.0.0.1:11019',
'bob': '127.0.0.1:11018',
}
fed.init(
addresses=addresses,
party=party,
config={
"cross_silo_comm": {
"messages_max_size_in_bytes": 100,
},
},
)

def _assert_on_proxy(proxy_actor):
config = ray.get(proxy_actor._get_proxy_config.remote())
options = config['grpc_options']
assert ("grpc.max_send_message_length", 100) in options
assert ("grpc.max_receive_message_length", 100) in options
assert ('grpc.so_reuseport', 0) in options

sender_proxy = ray.get_actor(sender_proxy_actor_name())
receiver_proxy = ray.get_actor(receiver_proxy_actor_name())
_assert_on_proxy(sender_proxy)
_assert_on_proxy(receiver_proxy)

a = dummpy.party('alice').remote()
b = dummpy.party('bob').remote()
fed.get([a, b])

fed.shutdown()
ray.shutdown()


def test_grpc_max_size_by_common_config():
p_alice = multiprocessing.Process(target=run2, args=('alice',))
p_bob = multiprocessing.Process(target=run2, args=('bob',))
p_alice.start()
p_bob.start()
p_alice.join()
p_bob.join()
assert p_alice.exitcode == 0 and p_bob.exitcode == 0


def run3(party):
compatible_utils.init_ray(address='local')
addresses = {
'alice': '127.0.0.1:11019',
'bob': '127.0.0.1:11018',
}
fed.init(
addresses=addresses,
party=party,
config={
"cross_silo_comm": {
"messages_max_size_in_bytes": 100,
"grpc_channel_options": [
('grpc.max_send_message_length', 200),
],
},
},
)

def _assert_on_proxy(proxy_actor):
config = ray.get(proxy_actor._get_proxy_config.remote())
options = config['grpc_options']
assert ("grpc.max_send_message_length", 200) in options
assert ("grpc.max_receive_message_length", 100) in options
assert ('grpc.so_reuseport', 0) in options

sender_proxy = ray.get_actor(sender_proxy_actor_name())
receiver_proxy = ray.get_actor(receiver_proxy_actor_name())
_assert_on_proxy(sender_proxy)
_assert_on_proxy(receiver_proxy)

a = dummpy.party('alice').remote()
b = dummpy.party('bob').remote()
fed.get([a, b])

fed.shutdown()
ray.shutdown()


def test_grpc_max_size_by_both_config():
p_alice = multiprocessing.Process(target=run3, args=('alice',))
p_bob = multiprocessing.Process(target=run3, args=('bob',))
p_alice.start()
p_bob.start()
p_alice.join()
p_bob.join()
assert p_alice.exitcode == 0 and p_bob.exitcode == 0


if __name__ == "__main__":
import sys

Expand Down
Loading