diff --git a/docs/Configuration.md b/docs/Configuration.md index 96e4b5bae..bc696915b 100644 --- a/docs/Configuration.md +++ b/docs/Configuration.md @@ -100,6 +100,7 @@ The `ConfigurationOptions` object has a wide range of properties, all of which a | setlib={bool} | `SetClientLibrary` | `true` | Whether to attempt to use `CLIENT SETINFO` to set the library name/version on the connection | | protocol={string} | `Protocol` | `null` | Redis protocol to use; see section below | | highIntegrity={bool} | `HighIntegrity` | `false` | High integrity (incurs overhead) sequence checking on every command; see section below | +| waitForAuth={bool} | `WaitForAuth` | `false` | Wait before the result of the `AUTH` command is returned before trying to send any other commands to the server | Additional code-only options: - LoggerFactory (`ILoggerFactory`) - Default: `null` diff --git a/src/StackExchange.Redis/Configuration/DefaultOptionsProvider.cs b/src/StackExchange.Redis/Configuration/DefaultOptionsProvider.cs index 703adbcac..d9bd8f857 100644 --- a/src/StackExchange.Redis/Configuration/DefaultOptionsProvider.cs +++ b/src/StackExchange.Redis/Configuration/DefaultOptionsProvider.cs @@ -115,6 +115,17 @@ public static DefaultOptionsProvider GetProvider(EndPoint endpoint) /// public virtual bool HighIntegrity => false; + /// + /// A Boolean value that specifies whether the client should wait for the server to return + /// response for the initial AUTH command before trying any further commands. + /// + /// + /// This is especially useful when connecting to Envoy proxies with external authentication + /// providers. + /// The default and recommended value is false. + /// + public virtual bool WaitForAuth => false; + /// /// The number of times to repeat the initial connect cycle if no servers respond promptly. /// diff --git a/src/StackExchange.Redis/ConfigurationOptions.cs b/src/StackExchange.Redis/ConfigurationOptions.cs index e972962b2..3f1bdc5ec 100644 --- a/src/StackExchange.Redis/ConfigurationOptions.cs +++ b/src/StackExchange.Redis/ConfigurationOptions.cs @@ -111,7 +111,8 @@ internal const string Tunnel = "tunnel", SetClientLibrary = "setlib", Protocol = "protocol", - HighIntegrity = "highIntegrity"; + HighIntegrity = "highIntegrity", + WaitForAuth = "waitForAuth"; private static readonly Dictionary normalizedOptions = new[] { @@ -143,6 +144,7 @@ internal const string CheckCertificateRevocation, Protocol, HighIntegrity, + WaitForAuth, }.ToDictionary(x => x, StringComparer.OrdinalIgnoreCase); public static string TryNormalize(string value) @@ -158,7 +160,7 @@ public static string TryNormalize(string value) private DefaultOptionsProvider? defaultOptions; private bool? allowAdmin, abortOnConnectFail, resolveDns, ssl, checkCertificateRevocation, heartbeatConsistencyChecks, - includeDetailInExceptions, includePerformanceCountersInExceptions, setClientLibrary, highIntegrity; + includeDetailInExceptions, includePerformanceCountersInExceptions, setClientLibrary, highIntegrity, waitForAuth; private string? tieBreaker, sslHost, configChannel, user, password; @@ -295,6 +297,21 @@ public bool HighIntegrity set => highIntegrity = value; } + /// + /// A Boolean value that specifies whether the client should wait for the server to return + /// response for the initial AUTH command before trying any further commands. + /// + /// + /// This is especially useful when connecting to Envoy proxies with external authentication + /// providers. + /// The default and recommended value is false. + /// + public bool WaitForAuth + { + get => waitForAuth ?? Defaults.WaitForAuth; + set => waitForAuth = value; + } + /// /// Create a certificate validation check that checks against the supplied issuer even when not known by the machine. /// @@ -786,6 +803,7 @@ public static ConfigurationOptions Parse(string configuration, bool ignoreUnknow heartbeatInterval = heartbeatInterval, heartbeatConsistencyChecks = heartbeatConsistencyChecks, highIntegrity = highIntegrity, + waitForAuth = waitForAuth, }; /// @@ -867,6 +885,7 @@ public string ToString(bool includePassword) Append(sb, OptionKeys.DefaultDatabase, DefaultDatabase); Append(sb, OptionKeys.SetClientLibrary, setClientLibrary); Append(sb, OptionKeys.HighIntegrity, highIntegrity); + Append(sb, OptionKeys.WaitForAuth, waitForAuth); Append(sb, OptionKeys.Protocol, FormatProtocol(Protocol)); if (Tunnel is { IsInbuilt: true } tunnel) { @@ -912,7 +931,7 @@ private void Clear() { ClientName = ServiceName = user = password = tieBreaker = sslHost = configChannel = null; keepAlive = syncTimeout = asyncTimeout = connectTimeout = connectRetry = configCheckSeconds = DefaultDatabase = null; - allowAdmin = abortOnConnectFail = resolveDns = ssl = setClientLibrary = highIntegrity = null; + allowAdmin = abortOnConnectFail = resolveDns = ssl = setClientLibrary = highIntegrity = waitForAuth = null; SslProtocols = null; defaultVersion = null; EndPoints.Clear(); @@ -1034,6 +1053,9 @@ private ConfigurationOptions DoParse(string configuration, bool ignoreUnknown) case OptionKeys.HighIntegrity: HighIntegrity = OptionKeys.ParseBoolean(key, value); break; + case OptionKeys.WaitForAuth: + WaitForAuth = OptionKeys.ParseBoolean(key, value); + break; case OptionKeys.Tunnel: if (value.IsNullOrWhiteSpace()) { diff --git a/src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt b/src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt index a24333c8e..c4c0a0113 100644 --- a/src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt +++ b/src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt @@ -254,6 +254,8 @@ StackExchange.Redis.ConfigurationOptions.HeartbeatInterval.get -> System.TimeSpa StackExchange.Redis.ConfigurationOptions.HeartbeatInterval.set -> void StackExchange.Redis.ConfigurationOptions.HighIntegrity.get -> bool StackExchange.Redis.ConfigurationOptions.HighIntegrity.set -> void +StackExchange.Redis.ConfigurationOptions.WaitForAuth.get -> bool +StackExchange.Redis.ConfigurationOptions.WaitForAuth.set -> void StackExchange.Redis.ConfigurationOptions.HighPrioritySocketThreads.get -> bool StackExchange.Redis.ConfigurationOptions.HighPrioritySocketThreads.set -> void StackExchange.Redis.ConfigurationOptions.IncludeDetailInExceptions.get -> bool @@ -1846,6 +1848,7 @@ virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.GetSslHostFromE virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.HeartbeatConsistencyChecks.get -> bool virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.HeartbeatInterval.get -> System.TimeSpan virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.HighIntegrity.get -> bool +virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.WaitForAuth.get -> bool virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.IncludeDetailInExceptions.get -> bool virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.IncludePerformanceCountersInExceptions.get -> bool virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.IsMatch(System.Net.EndPoint! endpoint) -> bool diff --git a/src/StackExchange.Redis/ServerEndPoint.cs b/src/StackExchange.Redis/ServerEndPoint.cs index 8b099afd2..2c874993c 100644 --- a/src/StackExchange.Redis/ServerEndPoint.cs +++ b/src/StackExchange.Redis/ServerEndPoint.cs @@ -990,16 +990,14 @@ private async Task HandshakeAsync(PhysicalConnection connection, ILogger? log) if (!string.IsNullOrWhiteSpace(user) && Multiplexer.CommandMap.IsAvailable(RedisCommand.AUTH)) { log?.LogInformation($"{Format.ToString(this)}: Authenticating (user/password)"); - msg = Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.AUTH, (RedisValue)user, (RedisValue)password); - msg.SetInternalCall(); - await WriteDirectOrQueueFireAndForgetAsync(connection, msg, ResultProcessor.DemandOK).ForAwait(); + msg = Message.Create(-1, CommandFlags.None, RedisCommand.AUTH, (RedisValue)user, (RedisValue)password); + await SendAuthMessageAsync(connection, msg, ResultProcessor.DemandOK).ForAwait(); } else if (!string.IsNullOrWhiteSpace(password) && Multiplexer.CommandMap.IsAvailable(RedisCommand.AUTH)) { log?.LogInformation($"{Format.ToString(this)}: Authenticating (password)"); - msg = Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.AUTH, (RedisValue)password); - msg.SetInternalCall(); - await WriteDirectOrQueueFireAndForgetAsync(connection, msg, ResultProcessor.DemandOK).ForAwait(); + msg = Message.Create(-1, CommandFlags.None, RedisCommand.AUTH, (RedisValue)password); + await SendAuthMessageAsync(connection, msg, ResultProcessor.DemandOK).ForAwait(); } if (Multiplexer.CommandMap.IsAvailable(RedisCommand.CLIENT)) @@ -1073,6 +1071,20 @@ private async Task HandshakeAsync(PhysicalConnection connection, ILogger? log) await connection.FlushAsync().ForAwait(); } + private ValueTask SendAuthMessageAsync(PhysicalConnection connection, Message msg, ResultProcessor demandOK) + { + if (Multiplexer.RawConfig.WaitForAuth) + { + return new ValueTask(WriteDirectAsync(msg, ResultProcessor.DemandOK)); + } + else + { + msg.Flags = CommandFlags.FireAndForget; + msg.SetInternalCall(); + return WriteDirectOrQueueFireAndForgetAsync(connection, msg, ResultProcessor.DemandOK); + } + } + private void SetConfig(ref T field, T value, [CallerMemberName] string? caller = null) { if (!EqualityComparer.Default.Equals(field, value)) diff --git a/tests/StackExchange.Redis.Tests/ConfigTests.cs b/tests/StackExchange.Redis.Tests/ConfigTests.cs index 2a4b2bc75..b4a53e7b8 100644 --- a/tests/StackExchange.Redis.Tests/ConfigTests.cs +++ b/tests/StackExchange.Redis.Tests/ConfigTests.cs @@ -85,6 +85,7 @@ public void ExpectedFields() "tieBreaker", "Tunnel", "user", + "waitForAuth", }, fields); } @@ -811,4 +812,24 @@ public void CheckHighIntegrity(bool? assigned, bool expected, string cs) var parsed = ConfigurationOptions.Parse(cs); Assert.Equal(expected, options.HighIntegrity); } + + [Theory] + [InlineData(null, false, "dummy")] + [InlineData(false, false, "dummy,waitForAuth=False")] + [InlineData(true, true, "dummy,waitForAuth=True")] + public void CheckWaitForAuth(bool? assigned, bool expected, string cs) + { + var options = ConfigurationOptions.Parse("dummy"); + if (assigned.HasValue) options.WaitForAuth = assigned.Value; + + Assert.Equal(expected, options.WaitForAuth); + Assert.Equal(cs, options.ToString()); + + var clone = options.Clone(); + Assert.Equal(expected, clone.WaitForAuth); + Assert.Equal(cs, clone.ToString()); + + var parsed = ConfigurationOptions.Parse(cs); + Assert.Equal(expected, options.WaitForAuth); + } } diff --git a/tests/StackExchange.Redis.Tests/TestBase.cs b/tests/StackExchange.Redis.Tests/TestBase.cs index 230438d07..a4232ad15 100644 --- a/tests/StackExchange.Redis.Tests/TestBase.cs +++ b/tests/StackExchange.Redis.Tests/TestBase.cs @@ -262,6 +262,7 @@ internal virtual IInternalConnectionMultiplexer Create( BacklogPolicy? backlogPolicy = null, Version? require = null, RedisProtocol? protocol = null, + bool? waitForAuth = null, [CallerMemberName] string caller = "") { if (Output == null) @@ -314,6 +315,7 @@ internal virtual IInternalConnectionMultiplexer Create( backlogPolicy, protocol, highIntegrity, + waitForAuth, caller); ThrowIfIncorrectProtocol(conn, protocol); @@ -409,6 +411,7 @@ public static ConnectionMultiplexer CreateDefault( BacklogPolicy? backlogPolicy = null, RedisProtocol? protocol = null, bool highIntegrity = false, + bool? waitForAuth = null, [CallerMemberName] string caller = "") { StringWriter? localLog = null; @@ -445,6 +448,7 @@ public static ConnectionMultiplexer CreateDefault( if (backlogPolicy is not null) config.BacklogPolicy = backlogPolicy; if (protocol is not null) config.Protocol = protocol; if (highIntegrity) config.HighIntegrity = highIntegrity; + if (waitForAuth is not null) config.WaitForAuth = waitForAuth.Value; var watch = Stopwatch.StartNew(); var task = ConnectionMultiplexer.ConnectAsync(config, log); if (!task.Wait(config.ConnectTimeout >= (int.MaxValue / 2) ? int.MaxValue : config.ConnectTimeout * 2))