Skip to content

Commit

Permalink
Updated EventListener instantiation inside of class
Browse files Browse the repository at this point in the history
  • Loading branch information
vladvildanov committed Dec 20, 2024
1 parent fcfdcb8 commit 063f0d5
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 24 deletions.
14 changes: 10 additions & 4 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def __init__(
redis_connect_func=None,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
event_dispatcher: Optional[EventDispatcher] = EventDispatcher(),
event_dispatcher: Optional[EventDispatcher] = None,
):
"""
Initialize a new Redis client.
Expand All @@ -250,6 +250,10 @@ def __init__(
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
"""
kwargs: Dict[str, Any]
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
# auto_close_connection_pool only has an effect if connection_pool is
# None. It is assumed that if connection_pool is not None, the user
# wants to manage the connection pool themselves.
Expand Down Expand Up @@ -343,7 +347,6 @@ def __init__(
)

self.connection_pool = connection_pool
self._event_dispatcher = event_dispatcher
self.single_connection_client = single_connection_client
self.connection: Optional[Connection] = None

Expand Down Expand Up @@ -786,8 +789,12 @@ def __init__(
ignore_subscribe_messages: bool = False,
encoder=None,
push_handler_func: Optional[Callable] = None,
event_dispatcher: Optional["EventDispatcher"] = EventDispatcher(),
event_dispatcher: Optional["EventDispatcher"] = None,
):
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
self.connection_pool = connection_pool
self.shard_hint = shard_hint
self.ignore_subscribe_messages = ignore_subscribe_messages
Expand All @@ -814,7 +821,6 @@ def __init__(
self.pending_unsubscribe_channels = set()
self.patterns = {}
self.pending_unsubscribe_patterns = set()
self._event_dispatcher = event_dispatcher
self._lock = asyncio.Lock()

async def __aenter__(self):
Expand Down
16 changes: 12 additions & 4 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def __init__(
ssl_ciphers: Optional[str] = None,
protocol: Optional[int] = 2,
address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
event_dispatcher: Optional[EventDispatcher] = EventDispatcher(),
event_dispatcher: Optional[EventDispatcher] = None,
) -> None:
if db:
raise RedisClusterException(
Expand Down Expand Up @@ -370,12 +370,17 @@ def __init__(
if host and port:
startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))

if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher

self.nodes_manager = NodesManager(
startup_nodes,
require_full_coverage,
kwargs,
address_remap=address_remap,
event_dispatcher=event_dispatcher,
event_dispatcher=self._event_dispatcher,
)
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self.read_from_replicas = read_from_replicas
Expand Down Expand Up @@ -1140,7 +1145,7 @@ def __init__(
require_full_coverage: bool,
connection_kwargs: Dict[str, Any],
address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
event_dispatcher: Optional[EventDispatcher] = EventDispatcher(),
event_dispatcher: Optional[EventDispatcher] = None,
) -> None:
self.startup_nodes = {node.name: node for node in startup_nodes}
self.require_full_coverage = require_full_coverage
Expand All @@ -1152,7 +1157,10 @@ def __init__(
self.slots_cache: Dict[int, List["ClusterNode"]] = {}
self.read_load_balancer = LoadBalancer()
self._moved_exception: MovedError = None
self._event_dispatcher = event_dispatcher
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher

def get_node(
self,
Expand Down
7 changes: 5 additions & 2 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def __init__(
encoder_class: Type[Encoder] = Encoder,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
event_dispatcher: Optional[EventDispatcher] = EventDispatcher(),
event_dispatcher: Optional[EventDispatcher] = None,
):
if (username or password) and credential_provider is not None:
raise DataError(
Expand All @@ -159,6 +159,10 @@ def __init__(
"1. 'password' and (optional) 'username'\n"
"2. 'credential_provider'"
)
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
self.db = db
self.client_name = client_name
self.lib_name = lib_name
Expand Down Expand Up @@ -198,7 +202,6 @@ def __init__(
self.set_parser(parser_class)
self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
self._buffer_cutoff = 6000
self._event_dispatcher = event_dispatcher
self._re_auth_token: Optional[TokenInterface] = None

try:
Expand Down
18 changes: 12 additions & 6 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def __init__(
protocol: Optional[int] = 2,
cache: Optional[CacheInterface] = None,
cache_config: Optional[CacheConfig] = None,
event_dispatcher: Optional[EventDispatcher] = EventDispatcher(),
event_dispatcher: Optional[EventDispatcher] = None,
) -> None:
"""
Initialize a new Redis client.
Expand All @@ -235,6 +235,10 @@ def __init__(
if `True`, connection pool is not used. In that case `Redis`
instance use is not thread safe.
"""
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
if not connection_pool:
if charset is not None:
warnings.warn(
Expand Down Expand Up @@ -321,22 +325,21 @@ def __init__(
}
)
connection_pool = ConnectionPool(**kwargs)
event_dispatcher.dispatch(
self._event_dispatcher.dispatch(
AfterPooledConnectionsInstantiationEvent(
[connection_pool], ClientType.SYNC, credential_provider
)
)
self.auto_close_connection_pool = True
else:
self.auto_close_connection_pool = False
event_dispatcher.dispatch(
self._event_dispatcher.dispatch(
AfterPooledConnectionsInstantiationEvent(
[connection_pool], ClientType.SYNC, credential_provider
)
)

self.connection_pool = connection_pool
self._event_dispatcher = event_dispatcher

if (cache_config or cache) and self.connection_pool.get_protocol() not in [
3,
Expand Down Expand Up @@ -724,7 +727,7 @@ def __init__(
ignore_subscribe_messages: bool = False,
encoder: Optional["Encoder"] = None,
push_handler_func: Union[None, Callable[[str], None]] = None,
event_dispatcher: Optional["EventDispatcher"] = EventDispatcher(),
event_dispatcher: Optional["EventDispatcher"] = None,
):
self.connection_pool = connection_pool
self.shard_hint = shard_hint
Expand All @@ -735,7 +738,10 @@ def __init__(
# to lookup channel and pattern names for callback handlers.
self.encoder = encoder
self.push_handler_func = push_handler_func
self.event_dispatcher = event_dispatcher
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
self._lock = threading.Lock()
if self.encoder is None:
self.encoder = self.connection_pool.get_encoder()
Expand Down
23 changes: 17 additions & 6 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def __init__(
address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
cache: Optional[CacheInterface] = None,
cache_config: Optional[CacheConfig] = None,
event_dispatcher: Optional[EventDispatcher] = EventDispatcher(),
event_dispatcher: Optional[EventDispatcher] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -645,6 +645,10 @@ def __init__(
self.read_from_replicas = read_from_replicas
self.reinitialize_counter = 0
self.reinitialize_steps = reinitialize_steps
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
self.nodes_manager = NodesManager(
startup_nodes=startup_nodes,
from_url=from_url,
Expand All @@ -653,7 +657,7 @@ def __init__(
address_remap=address_remap,
cache=cache,
cache_config=cache_config,
event_dispatcher=event_dispatcher,
event_dispatcher=self._event_dispatcher,
**kwargs,
)

Expand Down Expand Up @@ -1340,7 +1344,7 @@ def __init__(
cache: Optional[CacheInterface] = None,
cache_config: Optional[CacheConfig] = None,
cache_factory: Optional[CacheFactoryInterface] = None,
event_dispatcher: Optional[EventDispatcher] = EventDispatcher(),
event_dispatcher: Optional[EventDispatcher] = None,
**kwargs,
):
self.nodes_cache = {}
Expand All @@ -1362,7 +1366,10 @@ def __init__(
if lock is None:
lock = threading.Lock()
self._lock = lock
self._event_dispatcher = event_dispatcher
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
self._credential_provider = self.connection_kwargs.get(
"credential_provider", None
)
Expand Down Expand Up @@ -1719,7 +1726,7 @@ def __init__(
host=None,
port=None,
push_handler_func=None,
event_dispatcher: Optional["EventDispatcher"] = EventDispatcher(),
event_dispatcher: Optional["EventDispatcher"] = None,
**kwargs,
):
"""
Expand All @@ -1745,11 +1752,15 @@ def __init__(
self.cluster = redis_cluster
self.node_pubsub_mapping = {}
self._pubsubs_generator = self._pubsubs_generator()
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
super().__init__(
connection_pool=connection_pool,
encoder=redis_cluster.encoder,
push_handler_func=push_handler_func,
event_dispatcher=event_dispatcher,
event_dispatcher=self._event_dispatcher,
**kwargs,
)

Expand Down
7 changes: 5 additions & 2 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def __init__(
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
command_packer: Optional[Callable[[], None]] = None,
event_dispatcher: Optional[EventDispatcher] = EventDispatcher(),
event_dispatcher: Optional[EventDispatcher] = None,
):
"""
Initialize a new Connection.
Expand All @@ -259,6 +259,10 @@ def __init__(
"1. 'password' and (optional) 'username'\n"
"2. 'credential_provider'"
)
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
self.pid = os.getpid()
self.db = db
self.client_name = client_name
Expand Down Expand Up @@ -298,7 +302,6 @@ def __init__(
self.set_parser(parser_class)
self._connect_callbacks = []
self._buffer_cutoff = 6000
self._event_dispatcher = event_dispatcher
self._re_auth_token: Optional[TokenInterface] = None
try:
p = int(protocol)
Expand Down

0 comments on commit 063f0d5

Please sign in to comment.