Skip to content

Commit

Permalink
Get rid of copy memory when receive and send
Browse files Browse the repository at this point in the history
Signed-off-by: Sharpner6 <[email protected]>
  • Loading branch information
sharpener6 committed Oct 12, 2024
1 parent 7c2faf9 commit c030838
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
<a href="https://pypi.org/project/scaler">
<img alt="PyPI - Version" src="https://img.shields.io/pypi/v/scaler?colorA=0f1632&colorB=255be3">
</a>
<img src="https://api.securityscorecards.dev/projects/github.com/citi/scaler/badge">
<img src="https://api.securityscorecards.dev/projects/github.com/Citi/scaler/badge">
</p>
</div>

Expand Down
2 changes: 1 addition & 1 deletion scaler/about.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.8.7"
__version__ = "1.8.8"
11 changes: 6 additions & 5 deletions scaler/io/async_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Awaitable, Callable, List, Optional, Dict

import zmq.asyncio
from zmq import Frame

from scaler.io.utility import deserialize, serialize
from scaler.protocol.python.mixins import Message
Expand Down Expand Up @@ -38,18 +39,18 @@ def register(self, callback: Callable[[bytes, Message], Awaitable[None]]):
self._callback = callback

async def routine(self):
frames = await self._socket.recv_multipart()
frames: List[Frame] = await self._socket.recv_multipart(copy=False)
if not self.__is_valid_message(frames):
return

source, payload = frames
message: Optional[Message] = deserialize(payload)
message: Optional[Message] = deserialize(payload.bytes)
if message is None:
logging.error(f"received unknown message from {source!r}: {payload!r}")
logging.error(f"received unknown message from {source.bytes!r}: {payload!r}")
return

self.__count_received(message.__class__.__name__)
await self._callback(source, message)
await self._callback(source.bytes, message)

async def send(self, to: bytes, message: Message):
self.__count_sent(message.__class__.__name__)
Expand All @@ -63,7 +64,7 @@ def __set_socket_options(self):
self._socket.setsockopt(zmq.SNDHWM, 0)
self._socket.setsockopt(zmq.RCVHWM, 0)

def __is_valid_message(self, frames: List[bytes]) -> bool:
def __is_valid_message(self, frames: List[Frame]) -> bool:
if len(frames) < 2:
logging.error(f"{self.__get_prefix()} received unexpected frames {frames}")
return False
Expand Down
6 changes: 3 additions & 3 deletions scaler/io/async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ async def receive(self) -> Optional[Message]:
if self._socket.closed:
return None

payload = await self._socket.recv()
result: Optional[Message] = deserialize(payload)
payload = await self._socket.recv(copy=False)
result: Optional[Message] = deserialize(payload.bytes)
if result is None:
logging.error(f"received unknown message: {payload!r}")
logging.error(f"received unknown message: {payload.bytes!r}")
return None

return result
Expand Down
6 changes: 3 additions & 3 deletions scaler/io/sync_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ def identity(self) -> bytes:

def send(self, message: Message):
with self._lock:
self._socket.send(serialize(message))
self._socket.send(serialize(message), copy=False)

def receive(self) -> Optional[Message]:
with self._lock:
payload = self._socket.recv()
payload = self._socket.recv(copy=False)

return self.__compose_message(payload)
return self.__compose_message(payload.bytes)

def __compose_message(self, payload: bytes) -> Optional[Message]:
result: Optional[Message] = deserialize(payload)
Expand Down
2 changes: 1 addition & 1 deletion scaler/io/sync_subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __initialize(self):

def __routine_polling(self):
try:
self.__routine_receive(self._socket.recv())
self.__routine_receive(self._socket.recv(copy=False).bytes)
except zmq.Again:
raise TimeoutError(f"Cannot connect to {self._address.to_address()} in {self._timeout_seconds} seconds")

Expand Down

0 comments on commit c030838

Please sign in to comment.