Skip to content

Commit

Permalink
Fix #2763: ConnectionMultiplexer.Subscription is not Thread-safe (#2769)
Browse files Browse the repository at this point in the history
## Issue
#2763

## Solution
Simply added a lock around `_handlers` in `ConnectionMultiplexer.Subscription`, like I was suggesting in the issue.

## Unit Test
I added one that does exactly what the example code in #2763 was doing & testing for. I used the other tests as template/guide, let me know if something isn't up to spec.

---------

Co-authored-by: Nick Craver <[email protected]>
  • Loading branch information
Chuck-EP and NickCraver authored Aug 18, 2024
1 parent c0bb4eb commit fe40d17
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/ReleaseNotes.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Current package versions:

- Add support for hash field expiration (see [#2715](https://github.com/StackExchange/StackExchange.Redis/issues/2715)) ([#2716 by atakavci](https://github.com/StackExchange/StackExchange.Redis/pull/2716]))
- Add support for `HSCAN NOVALUES` (see [#2721](https://github.com/StackExchange/StackExchange.Redis/issues/2721)) ([#2722 by atakavci](https://github.com/StackExchange/StackExchange.Redis/pull/2722))
- Fix [#2763](https://github.com/StackExchange/StackExchange.Redis/issues/2763): Make ConnectionMultiplexer.Subscription thread-safe ([#2769 by Chuck-EP](https://github.com/StackExchange/StackExchange.Redis/pull/2769))

## 2.8.0

Expand Down
16 changes: 13 additions & 3 deletions src/StackExchange.Redis/RedisSubscriber.cs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ internal enum SubscriptionAction
internal sealed class Subscription
{
private Action<RedisChannel, RedisValue>? _handlers;
private readonly object _handlersLock = new object();
private ChannelMessageQueue? _queues;
private ServerEndPoint? CurrentServer;
public CommandFlags Flags { get; }
Expand Down Expand Up @@ -206,7 +207,10 @@ public void Add(Action<RedisChannel, RedisValue>? handler, ChannelMessageQueue?
{
if (handler != null)
{
_handlers += handler;
lock (_handlersLock)
{
_handlers += handler;
}
}
if (queue != null)
{
Expand All @@ -218,7 +222,10 @@ public bool Remove(Action<RedisChannel, RedisValue>? handler, ChannelMessageQueu
{
if (handler != null)
{
_handlers -= handler;
lock (_handlersLock)
{
_handlers -= handler;
}
}
if (queue != null)
{
Expand All @@ -236,7 +243,10 @@ public bool Remove(Action<RedisChannel, RedisValue>? handler, ChannelMessageQueu

internal void MarkCompleted()
{
_handlers = null;
lock (_handlersLock)
{
_handlers = null;
}
ChannelMessageQueue.MarkAllCompleted(ref _queues);
}

Expand Down
46 changes: 46 additions & 0 deletions tests/StackExchange.Redis.Tests/Issues/Issue2763Tests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Xunit;
using Xunit.Abstractions;

namespace StackExchange.Redis.Tests.Issues
{
public class Issue2763Tests : TestBase
{
public Issue2763Tests(ITestOutputHelper output) : base(output) { }

[Fact]
public void Execute()
{
using var conn = Create();
var subscriber = conn.GetSubscriber();

static void Handler(RedisChannel c, RedisValue v) { }

const int COUNT = 1000;
RedisChannel channel = RedisChannel.Literal("CHANNEL:TEST");

List<Action> subscribes = new List<Action>(COUNT);
for (int i = 0; i < COUNT; i++)
subscribes.Add(() => subscriber.Subscribe(channel, Handler));
Parallel.ForEach(subscribes, action => action());

Assert.Equal(COUNT, CountSubscriptionsForChannel(subscriber, channel));

List<Action> unsubscribes = new List<Action>(COUNT);
for (int i = 0; i < COUNT; i++)
unsubscribes.Add(() => subscriber.Unsubscribe(channel, Handler));
Parallel.ForEach(unsubscribes, action => action());

Assert.Equal(0, CountSubscriptionsForChannel(subscriber, channel));
}

private static int CountSubscriptionsForChannel(ISubscriber subscriber, RedisChannel channel)
{
ConnectionMultiplexer connMultiplexer = (ConnectionMultiplexer)subscriber.Multiplexer;
connMultiplexer.GetSubscriberCounts(channel, out int handlers, out int _);
return handlers;
}
}
}

0 comments on commit fe40d17

Please sign in to comment.