From 516b0497659761e09b354572f4bfaed260f77917 Mon Sep 17 00:00:00 2001 From: dbarbosapn Date: Wed, 4 Sep 2024 22:37:29 +0200 Subject: [PATCH 1/4] Initial proposal for fixing #2785 --- docs/Configuration.md | 1 + .../Configuration/DefaultOptionsProvider.cs | 11 +++++++ .../ConfigurationOptions.cs | 28 ++++++++++++++++-- .../PublicAPI/PublicAPI.Shipped.txt | 3 ++ src/StackExchange.Redis/ServerEndPoint.cs | 29 +++++++++++++------ .../StackExchange.Redis.Tests/ConfigTests.cs | 21 ++++++++++++++ tests/StackExchange.Redis.Tests/TestBase.cs | 4 +++ 7 files changed, 85 insertions(+), 12 deletions(-) diff --git a/docs/Configuration.md b/docs/Configuration.md index 96e4b5bae..bc696915b 100644 --- a/docs/Configuration.md +++ b/docs/Configuration.md @@ -100,6 +100,7 @@ The `ConfigurationOptions` object has a wide range of properties, all of which a | setlib={bool} | `SetClientLibrary` | `true` | Whether to attempt to use `CLIENT SETINFO` to set the library name/version on the connection | | protocol={string} | `Protocol` | `null` | Redis protocol to use; see section below | | highIntegrity={bool} | `HighIntegrity` | `false` | High integrity (incurs overhead) sequence checking on every command; see section below | +| waitForAuth={bool} | `WaitForAuth` | `false` | Wait before the result of the `AUTH` command is returned before trying to send any other commands to the server | Additional code-only options: - LoggerFactory (`ILoggerFactory`) - Default: `null` diff --git a/src/StackExchange.Redis/Configuration/DefaultOptionsProvider.cs b/src/StackExchange.Redis/Configuration/DefaultOptionsProvider.cs index 703adbcac..d9bd8f857 100644 --- a/src/StackExchange.Redis/Configuration/DefaultOptionsProvider.cs +++ b/src/StackExchange.Redis/Configuration/DefaultOptionsProvider.cs @@ -115,6 +115,17 @@ public static DefaultOptionsProvider GetProvider(EndPoint endpoint) /// public virtual bool HighIntegrity => false; + /// + /// A Boolean value that specifies whether the client should wait for the server to return + /// response for the initial AUTH command before trying any further commands. + /// + /// + /// This is especially useful when connecting to Envoy proxies with external authentication + /// providers. + /// The default and recommended value is false. + /// + public virtual bool WaitForAuth => false; + /// /// The number of times to repeat the initial connect cycle if no servers respond promptly. /// diff --git a/src/StackExchange.Redis/ConfigurationOptions.cs b/src/StackExchange.Redis/ConfigurationOptions.cs index e972962b2..3ca74fd20 100644 --- a/src/StackExchange.Redis/ConfigurationOptions.cs +++ b/src/StackExchange.Redis/ConfigurationOptions.cs @@ -111,7 +111,8 @@ internal const string Tunnel = "tunnel", SetClientLibrary = "setlib", Protocol = "protocol", - HighIntegrity = "highIntegrity"; + HighIntegrity = "highIntegrity", + WaitForAuth = "true"; private static readonly Dictionary normalizedOptions = new[] { @@ -143,6 +144,7 @@ internal const string CheckCertificateRevocation, Protocol, HighIntegrity, + WaitForAuth, }.ToDictionary(x => x, StringComparer.OrdinalIgnoreCase); public static string TryNormalize(string value) @@ -158,7 +160,7 @@ public static string TryNormalize(string value) private DefaultOptionsProvider? defaultOptions; private bool? allowAdmin, abortOnConnectFail, resolveDns, ssl, checkCertificateRevocation, heartbeatConsistencyChecks, - includeDetailInExceptions, includePerformanceCountersInExceptions, setClientLibrary, highIntegrity; + includeDetailInExceptions, includePerformanceCountersInExceptions, setClientLibrary, highIntegrity, waitForAuth; private string? tieBreaker, sslHost, configChannel, user, password; @@ -295,6 +297,21 @@ public bool HighIntegrity set => highIntegrity = value; } + /// + /// A Boolean value that specifies whether the client should wait for the server to return + /// response for the initial AUTH command before trying any further commands. + /// + /// + /// This is especially useful when connecting to Envoy proxies with external authentication + /// providers. + /// The default and recommended value is false. + /// + public bool WaitForAuth + { + get => waitForAuth ?? Defaults.WaitForAuth; + set => waitForAuth = value; + } + /// /// Create a certificate validation check that checks against the supplied issuer even when not known by the machine. /// @@ -786,6 +803,7 @@ public static ConfigurationOptions Parse(string configuration, bool ignoreUnknow heartbeatInterval = heartbeatInterval, heartbeatConsistencyChecks = heartbeatConsistencyChecks, highIntegrity = highIntegrity, + waitForAuth = waitForAuth, }; /// @@ -867,6 +885,7 @@ public string ToString(bool includePassword) Append(sb, OptionKeys.DefaultDatabase, DefaultDatabase); Append(sb, OptionKeys.SetClientLibrary, setClientLibrary); Append(sb, OptionKeys.HighIntegrity, highIntegrity); + Append(sb, OptionKeys.WaitForAuth, waitForAuth); Append(sb, OptionKeys.Protocol, FormatProtocol(Protocol)); if (Tunnel is { IsInbuilt: true } tunnel) { @@ -912,7 +931,7 @@ private void Clear() { ClientName = ServiceName = user = password = tieBreaker = sslHost = configChannel = null; keepAlive = syncTimeout = asyncTimeout = connectTimeout = connectRetry = configCheckSeconds = DefaultDatabase = null; - allowAdmin = abortOnConnectFail = resolveDns = ssl = setClientLibrary = highIntegrity = null; + allowAdmin = abortOnConnectFail = resolveDns = ssl = setClientLibrary = highIntegrity = waitForAuth = null; SslProtocols = null; defaultVersion = null; EndPoints.Clear(); @@ -1034,6 +1053,9 @@ private ConfigurationOptions DoParse(string configuration, bool ignoreUnknown) case OptionKeys.HighIntegrity: HighIntegrity = OptionKeys.ParseBoolean(key, value); break; + case OptionKeys.WaitForAuth: + WaitForAuth = OptionKeys.ParseBoolean(key, value); + break; case OptionKeys.Tunnel: if (value.IsNullOrWhiteSpace()) { diff --git a/src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt b/src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt index a24333c8e..c4c0a0113 100644 --- a/src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt +++ b/src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt @@ -254,6 +254,8 @@ StackExchange.Redis.ConfigurationOptions.HeartbeatInterval.get -> System.TimeSpa StackExchange.Redis.ConfigurationOptions.HeartbeatInterval.set -> void StackExchange.Redis.ConfigurationOptions.HighIntegrity.get -> bool StackExchange.Redis.ConfigurationOptions.HighIntegrity.set -> void +StackExchange.Redis.ConfigurationOptions.WaitForAuth.get -> bool +StackExchange.Redis.ConfigurationOptions.WaitForAuth.set -> void StackExchange.Redis.ConfigurationOptions.HighPrioritySocketThreads.get -> bool StackExchange.Redis.ConfigurationOptions.HighPrioritySocketThreads.set -> void StackExchange.Redis.ConfigurationOptions.IncludeDetailInExceptions.get -> bool @@ -1846,6 +1848,7 @@ virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.GetSslHostFromE virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.HeartbeatConsistencyChecks.get -> bool virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.HeartbeatInterval.get -> System.TimeSpan virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.HighIntegrity.get -> bool +virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.WaitForAuth.get -> bool virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.IncludeDetailInExceptions.get -> bool virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.IncludePerformanceCountersInExceptions.get -> bool virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.IsMatch(System.Net.EndPoint! endpoint) -> bool diff --git a/src/StackExchange.Redis/ServerEndPoint.cs b/src/StackExchange.Redis/ServerEndPoint.cs index 8b099afd2..da0fb9b18 100644 --- a/src/StackExchange.Redis/ServerEndPoint.cs +++ b/src/StackExchange.Redis/ServerEndPoint.cs @@ -972,9 +972,8 @@ private async Task HandshakeAsync(PhysicalConnection connection, ILogger? log) if (Multiplexer.RawConfig.TryResp3()) // note this includes an availability check on HELLO { log?.LogInformation($"{Format.ToString(this)}: Authenticating via HELLO"); - var hello = Message.CreateHello(3, user, password, clientName, CommandFlags.FireAndForget); - hello.SetInternalCall(); - await WriteDirectOrQueueFireAndForgetAsync(connection, hello, autoConfig ??= ResultProcessor.AutoConfigureProcessor.Create(log)).ForAwait(); + var hello = Message.CreateHello(3, user, password, clientName, CommandFlags.None); + await SendAuthMessageAsync(connection, hello, autoConfig ??= ResultProcessor.AutoConfigureProcessor.Create(log)).ForAwait(); // note that the server can reject RESP3 via either an -ERR response (HELLO not understood), or by simply saying "nope", // so we don't set the actual .Protocol until we process the result of the HELLO request @@ -990,16 +989,14 @@ private async Task HandshakeAsync(PhysicalConnection connection, ILogger? log) if (!string.IsNullOrWhiteSpace(user) && Multiplexer.CommandMap.IsAvailable(RedisCommand.AUTH)) { log?.LogInformation($"{Format.ToString(this)}: Authenticating (user/password)"); - msg = Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.AUTH, (RedisValue)user, (RedisValue)password); - msg.SetInternalCall(); - await WriteDirectOrQueueFireAndForgetAsync(connection, msg, ResultProcessor.DemandOK).ForAwait(); + msg = Message.Create(-1, CommandFlags.None, RedisCommand.AUTH, (RedisValue)user, (RedisValue)password); + await SendAuthMessageAsync(connection, msg, ResultProcessor.DemandOK).ForAwait(); } else if (!string.IsNullOrWhiteSpace(password) && Multiplexer.CommandMap.IsAvailable(RedisCommand.AUTH)) { log?.LogInformation($"{Format.ToString(this)}: Authenticating (password)"); - msg = Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.AUTH, (RedisValue)password); - msg.SetInternalCall(); - await WriteDirectOrQueueFireAndForgetAsync(connection, msg, ResultProcessor.DemandOK).ForAwait(); + msg = Message.Create(-1, CommandFlags.None, RedisCommand.AUTH, (RedisValue)password); + await SendAuthMessageAsync(connection, msg, ResultProcessor.DemandOK).ForAwait(); } if (Multiplexer.CommandMap.IsAvailable(RedisCommand.CLIENT)) @@ -1073,6 +1070,20 @@ private async Task HandshakeAsync(PhysicalConnection connection, ILogger? log) await connection.FlushAsync().ForAwait(); } + private async Task SendAuthMessageAsync(PhysicalConnection connection, Message msg, ResultProcessor demandOK) + { + if (Multiplexer.RawConfig.WaitForAuth) + { + await WriteDirectAsync(msg, ResultProcessor.DemandOK).ForAwait(); + } + else + { + msg.Flags = CommandFlags.FireAndForget; + msg.SetInternalCall(); + await WriteDirectOrQueueFireAndForgetAsync(connection, msg, ResultProcessor.DemandOK).ForAwait(); + } + } + private void SetConfig(ref T field, T value, [CallerMemberName] string? caller = null) { if (!EqualityComparer.Default.Equals(field, value)) diff --git a/tests/StackExchange.Redis.Tests/ConfigTests.cs b/tests/StackExchange.Redis.Tests/ConfigTests.cs index 2a4b2bc75..b4a53e7b8 100644 --- a/tests/StackExchange.Redis.Tests/ConfigTests.cs +++ b/tests/StackExchange.Redis.Tests/ConfigTests.cs @@ -85,6 +85,7 @@ public void ExpectedFields() "tieBreaker", "Tunnel", "user", + "waitForAuth", }, fields); } @@ -811,4 +812,24 @@ public void CheckHighIntegrity(bool? assigned, bool expected, string cs) var parsed = ConfigurationOptions.Parse(cs); Assert.Equal(expected, options.HighIntegrity); } + + [Theory] + [InlineData(null, false, "dummy")] + [InlineData(false, false, "dummy,waitForAuth=False")] + [InlineData(true, true, "dummy,waitForAuth=True")] + public void CheckWaitForAuth(bool? assigned, bool expected, string cs) + { + var options = ConfigurationOptions.Parse("dummy"); + if (assigned.HasValue) options.WaitForAuth = assigned.Value; + + Assert.Equal(expected, options.WaitForAuth); + Assert.Equal(cs, options.ToString()); + + var clone = options.Clone(); + Assert.Equal(expected, clone.WaitForAuth); + Assert.Equal(cs, clone.ToString()); + + var parsed = ConfigurationOptions.Parse(cs); + Assert.Equal(expected, options.WaitForAuth); + } } diff --git a/tests/StackExchange.Redis.Tests/TestBase.cs b/tests/StackExchange.Redis.Tests/TestBase.cs index 230438d07..a4232ad15 100644 --- a/tests/StackExchange.Redis.Tests/TestBase.cs +++ b/tests/StackExchange.Redis.Tests/TestBase.cs @@ -262,6 +262,7 @@ internal virtual IInternalConnectionMultiplexer Create( BacklogPolicy? backlogPolicy = null, Version? require = null, RedisProtocol? protocol = null, + bool? waitForAuth = null, [CallerMemberName] string caller = "") { if (Output == null) @@ -314,6 +315,7 @@ internal virtual IInternalConnectionMultiplexer Create( backlogPolicy, protocol, highIntegrity, + waitForAuth, caller); ThrowIfIncorrectProtocol(conn, protocol); @@ -409,6 +411,7 @@ public static ConnectionMultiplexer CreateDefault( BacklogPolicy? backlogPolicy = null, RedisProtocol? protocol = null, bool highIntegrity = false, + bool? waitForAuth = null, [CallerMemberName] string caller = "") { StringWriter? localLog = null; @@ -445,6 +448,7 @@ public static ConnectionMultiplexer CreateDefault( if (backlogPolicy is not null) config.BacklogPolicy = backlogPolicy; if (protocol is not null) config.Protocol = protocol; if (highIntegrity) config.HighIntegrity = highIntegrity; + if (waitForAuth is not null) config.WaitForAuth = waitForAuth.Value; var watch = Stopwatch.StartNew(); var task = ConnectionMultiplexer.ConnectAsync(config, log); if (!task.Wait(config.ConnectTimeout >= (int.MaxValue / 2) ? int.MaxValue : config.ConnectTimeout * 2)) From 2b6c84d2052ee10e42c90e815b448a5a67c79110 Mon Sep 17 00:00:00 2001 From: dbarbosapn Date: Wed, 4 Sep 2024 22:47:55 +0200 Subject: [PATCH 2/4] Fix toString --- src/StackExchange.Redis/ConfigurationOptions.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/StackExchange.Redis/ConfigurationOptions.cs b/src/StackExchange.Redis/ConfigurationOptions.cs index 3ca74fd20..3f1bdc5ec 100644 --- a/src/StackExchange.Redis/ConfigurationOptions.cs +++ b/src/StackExchange.Redis/ConfigurationOptions.cs @@ -112,7 +112,7 @@ internal const string SetClientLibrary = "setlib", Protocol = "protocol", HighIntegrity = "highIntegrity", - WaitForAuth = "true"; + WaitForAuth = "waitForAuth"; private static readonly Dictionary normalizedOptions = new[] { From a363abec5605a93085a15569603f111869f94ae0 Mon Sep 17 00:00:00 2001 From: dbarbosapn Date: Wed, 4 Sep 2024 23:06:34 +0200 Subject: [PATCH 3/4] Hello can be pipelined --- src/StackExchange.Redis/ServerEndPoint.cs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/StackExchange.Redis/ServerEndPoint.cs b/src/StackExchange.Redis/ServerEndPoint.cs index da0fb9b18..d59e6b090 100644 --- a/src/StackExchange.Redis/ServerEndPoint.cs +++ b/src/StackExchange.Redis/ServerEndPoint.cs @@ -972,8 +972,9 @@ private async Task HandshakeAsync(PhysicalConnection connection, ILogger? log) if (Multiplexer.RawConfig.TryResp3()) // note this includes an availability check on HELLO { log?.LogInformation($"{Format.ToString(this)}: Authenticating via HELLO"); - var hello = Message.CreateHello(3, user, password, clientName, CommandFlags.None); - await SendAuthMessageAsync(connection, hello, autoConfig ??= ResultProcessor.AutoConfigureProcessor.Create(log)).ForAwait(); + var hello = Message.CreateHello(3, user, password, clientName, CommandFlags.FireAndForget); + hello.SetInternalCall(); + await WriteDirectOrQueueFireAndForgetAsync(connection, hello, autoConfig ??= ResultProcessor.AutoConfigureProcessor.Create(log)).ForAwait(); // note that the server can reject RESP3 via either an -ERR response (HELLO not understood), or by simply saying "nope", // so we don't set the actual .Protocol until we process the result of the HELLO request From 1e79a19587a0b070675b038306d01135b8c56be5 Mon Sep 17 00:00:00 2001 From: dbarbosapn Date: Wed, 4 Sep 2024 23:21:38 +0200 Subject: [PATCH 4/4] Remove double await and return task directly --- src/StackExchange.Redis/ServerEndPoint.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/StackExchange.Redis/ServerEndPoint.cs b/src/StackExchange.Redis/ServerEndPoint.cs index d59e6b090..2c874993c 100644 --- a/src/StackExchange.Redis/ServerEndPoint.cs +++ b/src/StackExchange.Redis/ServerEndPoint.cs @@ -1071,17 +1071,17 @@ private async Task HandshakeAsync(PhysicalConnection connection, ILogger? log) await connection.FlushAsync().ForAwait(); } - private async Task SendAuthMessageAsync(PhysicalConnection connection, Message msg, ResultProcessor demandOK) + private ValueTask SendAuthMessageAsync(PhysicalConnection connection, Message msg, ResultProcessor demandOK) { if (Multiplexer.RawConfig.WaitForAuth) { - await WriteDirectAsync(msg, ResultProcessor.DemandOK).ForAwait(); + return new ValueTask(WriteDirectAsync(msg, ResultProcessor.DemandOK)); } else { msg.Flags = CommandFlags.FireAndForget; msg.SetInternalCall(); - await WriteDirectOrQueueFireAndForgetAsync(connection, msg, ResultProcessor.DemandOK).ForAwait(); + return WriteDirectOrQueueFireAndForgetAsync(connection, msg, ResultProcessor.DemandOK); } }