From fe40d17167ec25b16b1d6d6f91b9bea53276d854 Mon Sep 17 00:00:00 2001 From: Chuck Date: Sun, 18 Aug 2024 19:38:24 -0400 Subject: [PATCH] Fix #2763: ConnectionMultiplexer.Subscription is not Thread-safe (#2769) ## Issue #2763 ## Solution Simply added a lock around `_handlers` in `ConnectionMultiplexer.Subscription`, like I was suggesting in the issue. ## Unit Test I added one that does exactly what the example code in #2763 was doing & testing for. I used the other tests as template/guide, let me know if something isn't up to spec. --------- Co-authored-by: Nick Craver --- docs/ReleaseNotes.md | 1 + src/StackExchange.Redis/RedisSubscriber.cs | 16 +++++-- .../Issues/Issue2763Tests.cs | 46 +++++++++++++++++++ 3 files changed, 60 insertions(+), 3 deletions(-) create mode 100644 tests/StackExchange.Redis.Tests/Issues/Issue2763Tests.cs diff --git a/docs/ReleaseNotes.md b/docs/ReleaseNotes.md index 02570d3e1..26d277442 100644 --- a/docs/ReleaseNotes.md +++ b/docs/ReleaseNotes.md @@ -10,6 +10,7 @@ Current package versions: - Add support for hash field expiration (see [#2715](https://github.com/StackExchange/StackExchange.Redis/issues/2715)) ([#2716 by atakavci](https://github.com/StackExchange/StackExchange.Redis/pull/2716])) - Add support for `HSCAN NOVALUES` (see [#2721](https://github.com/StackExchange/StackExchange.Redis/issues/2721)) ([#2722 by atakavci](https://github.com/StackExchange/StackExchange.Redis/pull/2722)) +- Fix [#2763](https://github.com/StackExchange/StackExchange.Redis/issues/2763): Make ConnectionMultiplexer.Subscription thread-safe ([#2769 by Chuck-EP](https://github.com/StackExchange/StackExchange.Redis/pull/2769)) ## 2.8.0 diff --git a/src/StackExchange.Redis/RedisSubscriber.cs b/src/StackExchange.Redis/RedisSubscriber.cs index 92c96ad6c..ee28f4c56 100644 --- a/src/StackExchange.Redis/RedisSubscriber.cs +++ b/src/StackExchange.Redis/RedisSubscriber.cs @@ -159,6 +159,7 @@ internal enum SubscriptionAction internal sealed class Subscription { private Action? _handlers; + private readonly object _handlersLock = new object(); private ChannelMessageQueue? _queues; private ServerEndPoint? CurrentServer; public CommandFlags Flags { get; } @@ -206,7 +207,10 @@ public void Add(Action? handler, ChannelMessageQueue? { if (handler != null) { - _handlers += handler; + lock (_handlersLock) + { + _handlers += handler; + } } if (queue != null) { @@ -218,7 +222,10 @@ public bool Remove(Action? handler, ChannelMessageQueu { if (handler != null) { - _handlers -= handler; + lock (_handlersLock) + { + _handlers -= handler; + } } if (queue != null) { @@ -236,7 +243,10 @@ public bool Remove(Action? handler, ChannelMessageQueu internal void MarkCompleted() { - _handlers = null; + lock (_handlersLock) + { + _handlers = null; + } ChannelMessageQueue.MarkAllCompleted(ref _queues); } diff --git a/tests/StackExchange.Redis.Tests/Issues/Issue2763Tests.cs b/tests/StackExchange.Redis.Tests/Issues/Issue2763Tests.cs new file mode 100644 index 000000000..4da997e7d --- /dev/null +++ b/tests/StackExchange.Redis.Tests/Issues/Issue2763Tests.cs @@ -0,0 +1,46 @@ +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Xunit; +using Xunit.Abstractions; + +namespace StackExchange.Redis.Tests.Issues +{ + public class Issue2763Tests : TestBase + { + public Issue2763Tests(ITestOutputHelper output) : base(output) { } + + [Fact] + public void Execute() + { + using var conn = Create(); + var subscriber = conn.GetSubscriber(); + + static void Handler(RedisChannel c, RedisValue v) { } + + const int COUNT = 1000; + RedisChannel channel = RedisChannel.Literal("CHANNEL:TEST"); + + List subscribes = new List(COUNT); + for (int i = 0; i < COUNT; i++) + subscribes.Add(() => subscriber.Subscribe(channel, Handler)); + Parallel.ForEach(subscribes, action => action()); + + Assert.Equal(COUNT, CountSubscriptionsForChannel(subscriber, channel)); + + List unsubscribes = new List(COUNT); + for (int i = 0; i < COUNT; i++) + unsubscribes.Add(() => subscriber.Unsubscribe(channel, Handler)); + Parallel.ForEach(unsubscribes, action => action()); + + Assert.Equal(0, CountSubscriptionsForChannel(subscriber, channel)); + } + + private static int CountSubscriptionsForChannel(ISubscriber subscriber, RedisChannel channel) + { + ConnectionMultiplexer connMultiplexer = (ConnectionMultiplexer)subscriber.Multiplexer; + connMultiplexer.GetSubscriberCounts(channel, out int handlers, out int _); + return handlers; + } + } +}