diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/Common/CosmosJsonDotNetSerializer.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/Common/CosmosJsonDotNetSerializer.cs index 9ed8056360..0496752de9 100644 --- a/Microsoft.Azure.Cosmos.Encryption.Custom/src/Common/CosmosJsonDotNetSerializer.cs +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/Common/CosmosJsonDotNetSerializer.cs @@ -7,6 +7,8 @@ namespace Microsoft.Azure.Cosmos.Encryption.Custom using System; using System.IO; using System.Text; + using System.Threading; + using System.Threading.Tasks; using Newtonsoft.Json; /// @@ -40,6 +42,18 @@ internal CosmosJsonDotNetSerializer(JsonSerializerSettings jsonSerializerSetting /// The object representing the deserialized stream public T FromStream(Stream stream) { + return this.FromStream(stream, false); + } + + /// + /// Convert a Stream to the passed in type. + /// + /// The type of object that should be deserialized + /// An open stream that is readable that contains JSON + /// True if input stream shouldn't be disposed + /// The object representing the deserialized stream + public T FromStream(Stream stream, bool leaveOpen) + { #if NET8_0_OR_GREATER ArgumentNullException.ThrowIfNull(stream); #else @@ -54,7 +68,7 @@ public T FromStream(Stream stream) return (T)(object)stream; } - using (StreamReader sr = new (stream)) + using (StreamReader sr = new (stream, Encoding.UTF8, true, 1024, leaveOpen)) using (JsonTextReader jsonTextReader = new (sr)) { jsonTextReader.ArrayPool = JsonArrayPool.Instance; @@ -72,7 +86,15 @@ public T FromStream(Stream stream) public MemoryStream ToStream(T input) { MemoryStream streamPayload = new (); - using (StreamWriter streamWriter = new (streamPayload, encoding: CosmosJsonDotNetSerializer.DefaultEncoding, bufferSize: 1024, leaveOpen: true)) +#pragma warning disable VSTHRD002 // Avoid problematic synchronous waits + this.ToStreamAsync(input, streamPayload, CancellationToken.None).GetAwaiter().GetResult(); +#pragma warning 
restore VSTHRD002 // Avoid problematic synchronous waits + return streamPayload; + } + + public async Task ToStreamAsync(T input, Stream output, CancellationToken cancellationToken) + { + using (StreamWriter streamWriter = new (output, encoding: CosmosJsonDotNetSerializer.DefaultEncoding, bufferSize: 1024, leaveOpen: true)) using (JsonTextWriter writer = new (streamWriter)) { writer.ArrayPool = JsonArrayPool.Instance; @@ -80,11 +102,14 @@ public MemoryStream ToStream(T input) JsonSerializer jsonSerializer = this.GetSerializer(); jsonSerializer.Serialize(writer, input); writer.Flush(); - streamWriter.Flush(); +#if NET8_0_OR_GREATER + await streamWriter.FlushAsync(cancellationToken); +#else + await streamWriter.FlushAsync(); +#endif } - streamPayload.Position = 0; - return streamPayload; + output.Position = 0; } /// diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/DataEncryptionKeyProperties.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/DataEncryptionKeyProperties.cs index ceca9e3db5..0f0c11dd17 100644 --- a/Microsoft.Azure.Cosmos.Encryption.Custom/src/DataEncryptionKeyProperties.cs +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/DataEncryptionKeyProperties.cs @@ -7,6 +7,7 @@ namespace Microsoft.Azure.Cosmos.Encryption.Custom using System; using System.Collections.Generic; using System.Linq; + using System.Text.Json.Serialization; using Newtonsoft.Json; /// @@ -73,31 +74,36 @@ internal DataEncryptionKeyProperties(DataEncryptionKeyProperties source) /// /// [JsonProperty(PropertyName = "id")] + [JsonPropertyName("id")] public string Id { get; internal set; } /// /// Gets the Encryption algorithm that will be used along with this data encryption key to encrypt/decrypt data. /// [JsonProperty(PropertyName = "encryptionAlgorithm", NullValueHandling = NullValueHandling.Ignore)] + [JsonPropertyName("encryptionAlgorithm")] public string EncryptionAlgorithm { get; internal set; } /// /// Gets wrapped form of the data encryption key. 
/// [JsonProperty(PropertyName = "wrappedDataEncryptionKey", NullValueHandling = NullValueHandling.Ignore)] + [JsonPropertyName("wrappedDataEncryptionKey")] public byte[] WrappedDataEncryptionKey { get; internal set; } /// /// Gets metadata for the wrapping provider that can be used to unwrap the wrapped data encryption key. /// [JsonProperty(PropertyName = "keyWrapMetadata", NullValueHandling = NullValueHandling.Ignore)] + [JsonPropertyName("keyWrapMetadata")] public EncryptionKeyWrapMetadata EncryptionKeyWrapMetadata { get; internal set; } /// /// Gets the creation time of the resource from the Azure Cosmos DB service. /// - [JsonConverter(typeof(UnixDateTimeConverter))] + [Newtonsoft.Json.JsonConverter(typeof(UnixDateTimeConverter))] [JsonProperty(PropertyName = "createTime", NullValueHandling = NullValueHandling.Ignore)] + [JsonPropertyName("createTime")] public DateTime? CreatedTime { get; internal set; } /// @@ -110,14 +116,16 @@ internal DataEncryptionKeyProperties(DataEncryptionKeyProperties source) /// ETags are used for concurrency checking when updating resources. /// [JsonProperty(PropertyName = "_etag", NullValueHandling = NullValueHandling.Ignore)] + [JsonPropertyName("_etag")] public string ETag { get; internal set; } /// /// Gets the last modified time stamp associated with the resource from the Azure Cosmos DB service. /// /// The last modified time stamp associated with the resource. - [JsonConverter(typeof(UnixDateTimeConverter))] + [Newtonsoft.Json.JsonConverter(typeof(UnixDateTimeConverter))] [JsonProperty(PropertyName = "_ts", NullValueHandling = NullValueHandling.Ignore)] + [JsonPropertyName("_ts")] public DateTime? LastModified { get; internal set; } /// @@ -129,6 +137,7 @@ internal DataEncryptionKeyProperties(DataEncryptionKeyProperties source) /// E.g. 
a self-link for a document could be dbs/db_resourceid/colls/coll_resourceid/documents/doc_resourceid /// [JsonProperty(PropertyName = "_self", NullValueHandling = NullValueHandling.Ignore)] + [JsonPropertyName("_self")] public virtual string SelfLink { get; internal set; } /// @@ -143,6 +152,7 @@ internal DataEncryptionKeyProperties(DataEncryptionKeyProperties source) /// These resource ids are used when building up SelfLinks, a static addressable Uri for each resource within a database account. /// [JsonProperty(PropertyName = "_rid", NullValueHandling = NullValueHandling.Ignore)] + [JsonPropertyName("_rid")] internal string ResourceId { get; set; } /// diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/DecryptableItem.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/DecryptableItem.cs index f9cf7c86ef..f0b1b2660e 100644 --- a/Microsoft.Azure.Cosmos.Encryption.Custom/src/DecryptableItem.cs +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/DecryptableItem.cs @@ -4,6 +4,8 @@ namespace Microsoft.Azure.Cosmos.Encryption.Custom { + using System; + using System.Threading; using System.Threading.Tasks; /// @@ -71,7 +73,7 @@ namespace Microsoft.Azure.Cosmos.Encryption.Custom /// ]]> /// /// - public abstract class DecryptableItem + public abstract class DecryptableItem : IDisposable { /// /// Decrypts and deserializes the content. @@ -79,5 +81,18 @@ public abstract class DecryptableItem /// The type of item to be returned. /// The requested item and the decryption related context. public abstract Task<(T, DecryptionContext)> GetItemAsync(); + + /// + /// Decrypts and deserializes the content. + /// + /// Cancellation token. + /// The type of item to be returned. + /// The requested item and the decryption related context. + public abstract Task<(T, DecryptionContext)> GetItemAsync(CancellationToken cancellationToken); + + /// + /// Dispose unmanaged resources. 
+ /// + public abstract void Dispose(); } } diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/DecryptableItemCore.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/DecryptableItemCore.cs index 4c1f9c4698..de247e8652 100644 --- a/Microsoft.Azure.Cosmos.Encryption.Custom/src/DecryptableItemCore.cs +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/DecryptableItemCore.cs @@ -5,6 +5,7 @@ namespace Microsoft.Azure.Cosmos.Encryption.Custom { using System; + using System.Threading; using System.Threading.Tasks; using Newtonsoft.Json.Linq; @@ -27,7 +28,16 @@ public DecryptableItemCore( this.cosmosSerializer = cosmosSerializer ?? throw new ArgumentNullException(nameof(cosmosSerializer)); } - public override async Task<(T, DecryptionContext)> GetItemAsync() + public override void Dispose() + { + } + + public override Task<(T, DecryptionContext)> GetItemAsync() + { + return this.GetItemAsync(CancellationToken.None); + } + + public override async Task<(T, DecryptionContext)> GetItemAsync(CancellationToken cancellationToken) { if (this.decryptableContent is not JObject document) { @@ -40,7 +50,7 @@ public DecryptableItemCore( document, this.encryptor, new CosmosDiagnosticsContext(), - cancellationToken: default); + cancellationToken: cancellationToken); return (this.cosmosSerializer.FromStream(EncryptionProcessor.BaseSerializer.ToStream(decryptedItem)), decryptionContext); } diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptableItem.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptableItem.cs index 5e6e173cf7..174cf504e8 100644 --- a/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptableItem.cs +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptableItem.cs @@ -4,13 +4,17 @@ namespace Microsoft.Azure.Cosmos.Encryption.Custom { + using System; using System.IO; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.Cosmos.Encryption.Custom.Transformation; using Newtonsoft.Json.Linq; /// /// Input type should 
implement this abstract class for lazy decryption and to retrieve the details in the write path. /// - public abstract class EncryptableItem + public abstract class EncryptableItem : IDisposable { /// /// Gets DecryptableItem @@ -24,6 +28,15 @@ public abstract class EncryptableItem /// Input payload in stream format protected internal abstract Stream ToStream(CosmosSerializer serializer); + /// + /// Gets the input payload in stream format. + /// + /// Cosmos Serializer + /// Output stream + /// CancellationToken + /// A representing the asynchronous operation. + protected internal abstract Task ToStreamAsync(CosmosSerializer serializer, Stream outputStream, CancellationToken cancellationToken); + /// /// Populates the DecryptableItem that can be used getting the decryption result. /// @@ -34,5 +47,27 @@ protected internal abstract void SetDecryptableItem( JToken decryptableContent, Encryptor encryptor, CosmosSerializer cosmosSerializer); + +#if ENCRYPTION_CUSTOM_PREVIEW && NET8_0_OR_GREATER + /// + /// Populates the DecryptableItem that can be used getting the decryption result. + /// + /// The encrypted content stream which is yet to be decrypted. + /// Encryptor instance which will be used for decryption. + /// Json processor for decryption. + /// Serializer instance which will be used for deserializing the content after decryption. + /// Stream manager providing output streams. 
+ protected internal abstract void SetDecryptableStream( + Stream decryptableStream, + Encryptor encryptor, + JsonProcessor jsonProcessor, + CosmosSerializer cosmosSerializer, + StreamManager streamManager); +#endif + + /// + /// Release unmanaged resources + /// + public abstract void Dispose(); } } diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptableItemStream.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptableItemStream.cs index f96f4172d0..5a794d4817 100644 --- a/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptableItemStream.cs +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptableItemStream.cs @@ -6,6 +6,8 @@ namespace Microsoft.Azure.Cosmos.Encryption.Custom { using System; using System.IO; + using System.Threading; + using System.Threading.Tasks; using Newtonsoft.Json.Linq; /// @@ -69,10 +71,19 @@ protected internal override void SetDecryptableItem( cosmosSerializer); } + private void Dispose(bool disposing) + { + if (disposing) + { + this.StreamPayload?.Dispose(); + this.DecryptableItem?.Dispose(); + } + } + /// - public void Dispose() + public override void Dispose() { - this.StreamPayload.Dispose(); + this.Dispose(true); } /// @@ -80,5 +91,25 @@ protected internal override Stream ToStream(CosmosSerializer serializer) { return this.StreamPayload; } + + /// + /// This solution is not performant with Newtonsoft.Json. + protected internal override async Task ToStreamAsync(CosmosSerializer serializer, Stream outputStream, CancellationToken cancellationToken) + { +#if NET8_0_OR_GREATER + await this.StreamPayload.CopyToAsync(outputStream, cancellationToken); +#else + await this.StreamPayload.CopyToAsync(outputStream, 81920, cancellationToken); +#endif + } + +#if NET8_0_OR_GREATER + /// + /// Direct stream based item is not supported with Newtonsoft.Json. 
+ protected internal override void SetDecryptableStream(Stream decryptableStream, Encryptor encryptor, JsonProcessor jsonProcessor, CosmosSerializer cosmosSerializer, StreamManager streamManager) + { + throw new NotImplementedException("Stream based item is only allowed for EncryptionContainerStream"); + } +#endif } } diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptableItem{T}.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptableItem{T}.cs index d6d6531e1c..7aef884c8a 100644 --- a/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptableItem{T}.cs +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptableItem{T}.cs @@ -6,6 +6,8 @@ namespace Microsoft.Azure.Cosmos.Encryption.Custom { using System; using System.IO; + using System.Threading; + using System.Threading.Tasks; using Newtonsoft.Json.Linq; /// @@ -85,5 +87,32 @@ protected internal override Stream ToStream(CosmosSerializer serializer) { return serializer.ToStream(this.Item); } + + /// + /// This solution is not performant with Newtonsoft.Json. + protected internal override async Task ToStreamAsync(CosmosSerializer serializer, Stream outputStream, CancellationToken cancellationToken) + { + Stream temp = serializer.ToStream(this.Item); +#if NET8_0_OR_GREATER + await temp.CopyToAsync(outputStream, cancellationToken); +#else + await temp.CopyToAsync(outputStream, 81920, cancellationToken); +#endif + } + +#if NET8_0_OR_GREATER + /// + /// Direct stream based item is not supported with Newtonsoft.Json. + protected internal override void SetDecryptableStream(Stream decryptableStream, Encryptor encryptor, JsonProcessor jsonProcessor, CosmosSerializer cosmosSerializer, StreamManager streamManager) + { + throw new NotImplementedException("Stream based item is only allowed for EncryptionContainerStream"); + } +#endif + + /// + /// Does nothing with Newtonsoft based EncryptableItem. 
+ public override void Dispose() + { + } } } diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptionContainerExtensions.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptionContainerExtensions.cs index fdde796298..8654a07133 100644 --- a/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptionContainerExtensions.cs +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptionContainerExtensions.cs @@ -23,11 +23,40 @@ public static Container WithEncryptor( this Container container, Encryptor encryptor) { +#if ENCRYPTION_CUSTOM_PREVIEW && NET8_0_OR_GREATER + if (container.Database.Client.ClientOptions.UseSystemTextJsonSerializerWithOptions is not null) + { + return new EncryptionContainerStream(container, encryptor); + } +#endif + return new EncryptionContainer( container, encryptor); } +#if ENCRYPTION_CUSTOM_PREVIEW && NET8_0_OR_GREATER + /// + /// Get container with for performing operations using client-side encryption. + /// + /// Regular cosmos container. + /// Provider that allows encrypting and decrypting data. + /// Json Processor used for the container. + /// Container to perform operations supporting client-side encryption / decryption. + public static Container WithEncryptor( + this Container container, + Encryptor encryptor, + JsonProcessor jsonProcessor) + { + return jsonProcessor switch + { + JsonProcessor.Stream => new EncryptionContainerStream(container, encryptor), + JsonProcessor.Newtonsoft => new EncryptionContainer(container, encryptor), + _ => throw new NotSupportedException($"Json Processor {jsonProcessor} is not supported.") + }; + } +#endif + /// /// This method gets the FeedIterator from LINQ IQueryable to execute query asynchronously. /// This will create the fresh new FeedIterator when called which will support decryption. 
@@ -50,14 +79,20 @@ public static FeedIterator ToEncryptionFeedIterator( this Container container, IQueryable query) { - if (container is not EncryptionContainer encryptionContainer) + return container switch { - throw new ArgumentOutOfRangeException(nameof(query), $"{nameof(ToEncryptionFeedIterator)} is only supported with {nameof(EncryptionContainer)}."); - } +#if ENCRYPTION_CUSTOM_PREVIEW && NET8_0_OR_GREATER + EncryptionContainerStream encryptionContainerStream => new EncryptionFeedIteratorStream( + (EncryptionFeedIteratorStream)encryptionContainerStream.ToEncryptionStreamIterator(query), + encryptionContainerStream.ResponseFactory), +#endif + EncryptionContainer encryptionContainer => new EncryptionFeedIterator( + (EncryptionFeedIterator)encryptionContainer.ToEncryptionStreamIterator(query), + encryptionContainer.ResponseFactory), + + _ => throw new ArgumentOutOfRangeException(nameof(container), $"Container type {container.GetType().Name} is not supported.") - return new EncryptionFeedIterator( - (EncryptionFeedIterator)encryptionContainer.ToEncryptionStreamIterator(query), - encryptionContainer.ResponseFactory); + }; } /// @@ -82,15 +117,21 @@ public static FeedIterator ToEncryptionStreamIterator( this Container container, IQueryable query) { - if (container is not EncryptionContainer encryptionContainer) + return container switch { - throw new ArgumentOutOfRangeException(nameof(query), $"{nameof(ToEncryptionStreamIterator)} is only supported with {nameof(EncryptionContainer)}."); - } - - return new EncryptionFeedIterator( - query.ToStreamIterator(), - encryptionContainer.Encryptor, - encryptionContainer.CosmosSerializer); +#if ENCRYPTION_CUSTOM_PREVIEW && NET8_0_OR_GREATER + EncryptionContainerStream encryptionContainerStream => new EncryptionFeedIteratorStream( + query.ToStreamIterator(), + encryptionContainerStream.Encryptor, + encryptionContainerStream.CosmosSerializer, + new MemoryStreamManager()), +#endif + EncryptionContainer encryptionContainer => new 
EncryptionFeedIterator( + query.ToStreamIterator(), + encryptionContainer.Encryptor, + encryptionContainer.CosmosSerializer), + _ => throw new ArgumentOutOfRangeException(nameof(container), $"Container type {container.GetType().Name} is not supported.") + }; } } } diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptionProcessor.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptionProcessor.cs index 7bf05fb930..9091a88530 100644 --- a/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptionProcessor.cs +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptionProcessor.cs @@ -30,6 +30,7 @@ internal static class EncryptionProcessor #if ENCRYPTION_CUSTOM_PREVIEW && NET8_0_OR_GREATER private static readonly StreamProcessor StreamProcessor = new (); + private static readonly ArrayStreamProcessor ArrayStreamProcessor = new (); #endif private static readonly MdeEncryptionProcessor MdeEncryptionProcessor = new (); @@ -87,20 +88,38 @@ public static async Task EncryptAsync( if (!encryptionOptions.PathsToEncrypt.Any()) { await input.CopyToAsync(output, cancellationToken); + output.Position = 0; return; } - if (encryptionOptions.EncryptionAlgorithm != CosmosEncryptionAlgorithm.MdeAeadAes256CbcHmac256Randomized) - { - throw new NotSupportedException($"Streaming mode is only allowed for {nameof(CosmosEncryptionAlgorithm.MdeAeadAes256CbcHmac256Randomized)}"); - } - - if (encryptionOptions.JsonProcessor != JsonProcessor.Stream) + switch (encryptionOptions.EncryptionAlgorithm) { - throw new NotSupportedException($"Streaming mode is only allowed for {nameof(JsonProcessor.Stream)}"); +#pragma warning disable CS0618 // Type or member is obsolete + case CosmosEncryptionAlgorithm.AEAes256CbcHmacSha256Randomized: + { + Stream result = await AeAesEncryptionProcessor.EncryptAsync(input, encryptor, encryptionOptions, cancellationToken); + await result.CopyToAsync(output, cancellationToken); + output.Position = 0; + return; + } +#pragma warning restore CS0618 // Type or 
member is obsolete + case CosmosEncryptionAlgorithm.MdeAeadAes256CbcHmac256Randomized: + switch (encryptionOptions.JsonProcessor) + { + case JsonProcessor.Stream: + await EncryptionProcessor.StreamProcessor.EncryptStreamAsync(input, output, encryptor, encryptionOptions, cancellationToken); + break; + case JsonProcessor.Newtonsoft: + await EncryptionProcessor.MdeEncryptionProcessor.EncryptStreamAsync(input, output, encryptor, encryptionOptions, cancellationToken); + break; + default: + throw new NotSupportedException($"Streaming mode is not supported for {encryptionOptions.JsonProcessor}"); + } + + break; + default: + throw new NotSupportedException($"Encryption algorithm {encryptionOptions.EncryptionAlgorithm} not supported."); } - - await EncryptionProcessor.StreamProcessor.EncryptStreamAsync(input, output, encryptor, encryptionOptions, cancellationToken); } #endif @@ -173,11 +192,6 @@ public static async Task DecryptAsync( return null; } - if (jsonProcessor != JsonProcessor.Stream) - { - throw new NotSupportedException($"Streaming mode is only allowed for {nameof(JsonProcessor.Stream)}"); - } - Debug.Assert(input.CanSeek); Debug.Assert(output.CanWrite); Debug.Assert(output.CanSeek); @@ -190,6 +204,7 @@ public static async Task DecryptAsync( if (properties?.EncryptionProperties == null) { await input.CopyToAsync(output, cancellationToken: cancellationToken); + output.Position = 0; return null; } @@ -197,7 +212,19 @@ public static async Task DecryptAsync( #pragma warning disable CS0618 // Type or member is obsolete if (properties.EncryptionProperties.EncryptionAlgorithm == CosmosEncryptionAlgorithm.MdeAeadAes256CbcHmac256Randomized) { - context = await StreamProcessor.DecryptStreamAsync(input, output, encryptor, properties.EncryptionProperties, diagnosticsContext, cancellationToken); + switch (jsonProcessor) + { + case JsonProcessor.Stream: + context = await StreamProcessor.DecryptStreamAsync(input, output, encryptor, properties.EncryptionProperties, 
diagnosticsContext, cancellationToken); + break; + case JsonProcessor.Newtonsoft: + (Stream ms, context) = await EncryptionProcessor.DecryptAsync(input, encryptor, diagnosticsContext, cancellationToken); + await ms.CopyToAsync(output, cancellationToken); + output.Position = 0; + break; + default: + throw new NotSupportedException($"Streaming mode is not supported for {jsonProcessor}"); + } } else if (properties.EncryptionProperties.EncryptionAlgorithm == CosmosEncryptionAlgorithm.AEAes256CbcHmacSha256Randomized) { @@ -420,5 +447,23 @@ await DecryptAsync( // and corresponding decrypted properties are added back in the documents. return BaseSerializer.ToStream(contentJObj); } + +#if ENCRYPTION_CUSTOM_PREVIEW && NET8_0_OR_GREATER + internal static async Task DeserializeAndDecryptResponseAsync( + Stream inputStream, + Stream outputStream, + Encryptor encryptor, + StreamManager streamManager, + CancellationToken cancellationToken) + { + await ArrayStreamProcessor.DeserializeAndDecryptCollectionAsync( + inputStream, + outputStream, + encryptor, + streamManager, + cancellationToken); + } +#endif + } } \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/MemoryStreamManager.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/MemoryStreamManager.cs new file mode 100644 index 0000000000..0edbdf52b2 --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/MemoryStreamManager.cs @@ -0,0 +1,41 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// ------------------------------------------------------------ + +#if NET8_0_OR_GREATER +namespace Microsoft.Azure.Cosmos.Encryption.Custom +{ + using System.IO; + using System.Threading.Tasks; + using Microsoft.Azure.Cosmos.Encryption.Custom.RecyclableMemoryStreamMirror; + + /// + /// Memory Stream manager + /// + /// Placeholder + internal class MemoryStreamManager : StreamManager + { + private readonly RecyclableMemoryStreamManager streamManager = new (); + + /// + /// Create stream + /// + /// Desired minimal capacity of stream. + /// Instance of stream. + public override Stream CreateStream(int hintSize = 0) + { + return new RecyclableMemoryStream(this.streamManager, null, hintSize); + } + + /// + /// Dispose of used Stream (return to pool) + /// + /// Stream to dispose. + /// ValueTask.CompletedTask + public async override ValueTask ReturnStreamAsync(Stream stream) + { + await stream.DisposeAsync(); + } + } +} +#endif \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/Microsoft.Azure.Cosmos.Encryption.Custom.csproj b/Microsoft.Azure.Cosmos.Encryption.Custom/src/Microsoft.Azure.Cosmos.Encryption.Custom.csproj index 3b94db54c2..95014ab6d6 100644 --- a/Microsoft.Azure.Cosmos.Encryption.Custom/src/Microsoft.Azure.Cosmos.Encryption.Custom.csproj +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/Microsoft.Azure.Cosmos.Encryption.Custom.csproj @@ -26,11 +26,11 @@ - + - + diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/RecyclableMemoryStreamMirror/README.md b/Microsoft.Azure.Cosmos.Encryption.Custom/src/RecyclableMemoryStreamMirror/README.md new file mode 100644 index 0000000000..eed315c783 --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/RecyclableMemoryStreamMirror/README.md @@ -0,0 +1,3 @@ +# Microsoft.IO.RecyclableMemoryStream 3.0.1 + +Mirrored from https://github.com/microsoft/Microsoft.IO.RecyclableMemoryStream/tree/master/src diff --git 
a/Microsoft.Azure.Cosmos.Encryption.Custom/src/RecyclableMemoryStreamMirror/RecyclableMemoryStream.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/RecyclableMemoryStreamMirror/RecyclableMemoryStream.cs new file mode 100644 index 0000000000..92b003d71e --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/RecyclableMemoryStreamMirror/RecyclableMemoryStream.cs @@ -0,0 +1,1585 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +// The MIT License (MIT) +// +// Copyright (c) 2015-2016 Microsoft +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+namespace Microsoft.Azure.Cosmos.Encryption.Custom.RecyclableMemoryStreamMirror +{ + using System; + using System.Buffers; + using System.Collections.Generic; + using System.Diagnostics; + using System.IO; + using System.Runtime.CompilerServices; + using System.Threading; + using System.Threading.Tasks; + + /// + /// MemoryStream implementation that deals with pooling and managing memory streams which use potentially large + /// buffers. + /// + /// + /// This class works in tandem with the to supply MemoryStream-derived + /// objects to callers, while avoiding these specific problems: + /// + /// + /// LOH allocations + /// Since all large buffers are pooled, they will never incur a Gen2 GC + /// + /// + /// Memory wasteA standard memory stream doubles its size when it runs out of room. This + /// leads to continual memory growth as each stream approaches the maximum allowed size. + /// + /// + /// Memory copying + /// Each time a MemoryStream grows, all the bytes are copied into new buffers. + /// This implementation only copies the bytes when is called. + /// + /// + /// Memory fragmentation + /// By using homogeneous buffer sizes, it ensures that blocks of memory + /// can be easily reused. + /// + /// + /// + /// + /// The stream is implemented on top of a series of uniformly-sized blocks. As the stream's length grows, + /// additional blocks are retrieved from the memory manager. It is these blocks that are pooled, not the stream + /// object itself. + /// + /// + /// The biggest wrinkle in this implementation is when is called. This requires a single + /// contiguous buffer. If only a single block is in use, then that block is returned. If multiple blocks + /// are in use, we retrieve a larger buffer from the memory manager. These large buffers are also pooled, + /// split by size--they are multiples/exponentials of a chunk size (1 MB by default). 
+ /// + /// + /// Once a large buffer is assigned to the stream the small blocks are NEVER again used for this stream. All operations take place on the + /// large buffer. The large buffer can be replaced by a larger buffer from the pool as needed. All blocks and large buffers + /// are maintained in the stream until the stream is disposed (unless AggressiveBufferReturn is enabled in the stream manager). + /// + /// + /// A further wrinkle is what happens when the stream is longer than the maximum allowable array length under .NET. This is allowed + /// when only blocks are in use, and only the Read/Write APIs are used. Once a stream grows to this size, any attempt to convert it + /// to a single buffer will result in an exception. Similarly, if a stream is already converted to use a single larger buffer, then + /// it cannot grow beyond the limits of the maximum allowable array size. + /// + /// + /// Any method that modifies the stream has the potential to throw an OutOfMemoryException, either because + /// the stream is beyond the limits set in RecyclableStreamManager, or it would result in a buffer larger than + /// the maximum array size supported by .NET. + /// + /// + public sealed class RecyclableMemoryStream : MemoryStream, IBufferWriter + { + /// + /// All of these blocks must be the same size. + /// + private readonly List blocks; + + private readonly Guid id; + + private readonly RecyclableMemoryStreamManager memoryManager; + + private readonly string tag; + + private readonly long creationTimestamp; + + /// + /// This list is used to store buffers once they're replaced by something larger. + /// This is for the cases where you have users of this class that may hold onto the buffers longer + /// than they should and you want to prevent race conditions which could corrupt the data. 
+ /// + private List dirtyBuffers; + + private bool disposed; + + /// + /// This is only set by GetBuffer() if the necessary buffer is larger than a single block size, or on + /// construction if the caller immediately requests a single large buffer. + /// + /// If this field is non-null, it contains the concatenation of the bytes found in the individual + /// blocks. Once it is created, this (or a larger) largeBuffer will be used for the life of the stream. + /// + private byte[] largeBuffer; + + /// + /// Gets unique identifier for this stream across its entire lifetime. + /// + /// Object has been disposed. + internal Guid Id + { + get + { + this.CheckDisposed(); + return this.id; + } + } + + /// + /// Gets a temporary identifier for the current usage of this stream. + /// + /// Object has been disposed. + internal string Tag + { + get + { + this.CheckDisposed(); + return this.tag; + } + } + + /// + /// Gets the memory manager being used by this stream. + /// + /// Object has been disposed. + internal RecyclableMemoryStreamManager MemoryManager + { + get + { + this.CheckDisposed(); + return this.memoryManager; + } + } + + /// + /// Gets call stack of the constructor. It is only set if is true, + /// which should only be in debugging situations. + /// + internal string AllocationStack { get; } + + /// + /// Gets call stack of the call. It is only set if is true, + /// which should only be in debugging situations. + /// + internal string DisposeStack { get; private set; } + + /// + /// Initializes a new instance of the class. + /// + /// The memory manager. + public RecyclableMemoryStream(RecyclableMemoryStreamManager memoryManager) + : this(memoryManager, Guid.NewGuid(), null, 0, null) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The memory manager. + /// A unique identifier which can be used to trace usages of the stream. 
+ public RecyclableMemoryStream(RecyclableMemoryStreamManager memoryManager, Guid id) + : this(memoryManager, id, null, 0, null) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The memory manager. + /// A string identifying this stream for logging and debugging purposes. + public RecyclableMemoryStream(RecyclableMemoryStreamManager memoryManager, string tag) + : this(memoryManager, Guid.NewGuid(), tag, 0, null) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The memory manager. + /// A unique identifier which can be used to trace usages of the stream. + /// A string identifying this stream for logging and debugging purposes. + public RecyclableMemoryStream(RecyclableMemoryStreamManager memoryManager, Guid id, string tag) + : this(memoryManager, id, tag, 0, null) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The memory manager. + /// A string identifying this stream for logging and debugging purposes. + /// The initial requested size to prevent future allocations. + public RecyclableMemoryStream(RecyclableMemoryStreamManager memoryManager, string tag, long requestedSize) + : this(memoryManager, Guid.NewGuid(), tag, requestedSize, null) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The memory manager + /// A unique identifier which can be used to trace usages of the stream. + /// A string identifying this stream for logging and debugging purposes. + /// The initial requested size to prevent future allocations. + public RecyclableMemoryStream(RecyclableMemoryStreamManager memoryManager, Guid id, string tag, long requestedSize) + : this(memoryManager, id, tag, requestedSize, null) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The memory manager. + /// A unique identifier which can be used to trace usages of the stream. + /// A string identifying this stream for logging and debugging purposes. 
+ /// The initial requested size to prevent future allocations. + /// An initial buffer to use. This buffer will be owned by the stream and returned to the memory manager upon Dispose. + internal RecyclableMemoryStream(RecyclableMemoryStreamManager memoryManager, Guid id, string tag, long requestedSize, byte[] initialLargeBuffer) + : base(Array.Empty()) + { + this.memoryManager = memoryManager; + this.id = id; + this.tag = tag; + this.blocks = new List(); + this.creationTimestamp = Stopwatch.GetTimestamp(); + + long actualRequestedSize = Math.Max(requestedSize, this.memoryManager.OptionsValue.BlockSize); + + if (initialLargeBuffer == null) + { + this.EnsureCapacity(actualRequestedSize); + } + else + { + this.largeBuffer = initialLargeBuffer; + } + + if (this.memoryManager.OptionsValue.GenerateCallStacks) + { + this.AllocationStack = Environment.StackTrace; + } + + this.memoryManager.ReportStreamCreated(this.id, this.tag, requestedSize, actualRequestedSize); + this.memoryManager.ReportUsageReport(); + } + + /// + /// Finalizes an instance of the class. + /// + /// Failing to dispose indicates a bug in the code using streams. Care should be taken to properly account for stream lifetime. + ~RecyclableMemoryStream() + { + this.Dispose(false); + } + + /// + /// Returns the memory used by this stream back to the pool. + /// + /// Whether we're disposing (true), or being called by the finalizer (false). 
+ protected override void Dispose(bool disposing) + { + if (this.disposed) + { + string doubleDisposeStack = null; + if (this.memoryManager.OptionsValue.GenerateCallStacks) + { + doubleDisposeStack = Environment.StackTrace; + } + + this.memoryManager.ReportStreamDoubleDisposed(this.id, this.tag, this.AllocationStack, this.DisposeStack, doubleDisposeStack); + return; + } + + this.disposed = true; + TimeSpan lifetime = TimeSpan.FromTicks((Stopwatch.GetTimestamp() - this.creationTimestamp) * TimeSpan.TicksPerSecond / Stopwatch.Frequency); + + if (this.memoryManager.OptionsValue.GenerateCallStacks) + { + this.DisposeStack = Environment.StackTrace; + } + + this.memoryManager.ReportStreamDisposed(this.id, this.tag, lifetime, this.AllocationStack, this.DisposeStack); + + if (disposing) + { + GC.SuppressFinalize(this); + } + else + { + // We're being finalized. + this.memoryManager.ReportStreamFinalized(this.id, this.tag, this.AllocationStack); + + if (AppDomain.CurrentDomain.IsFinalizingForUnload()) + { + // If we're being finalized because of a shutdown, don't go any further. + // We have no idea what's already been cleaned up. Triggering events may cause + // a crash. + base.Dispose(disposing); + return; + } + } + + this.memoryManager.ReportStreamLength(this.length); + + if (this.largeBuffer != null) + { + this.memoryManager.ReturnLargeBuffer(this.largeBuffer, this.id, this.tag); + } + + if (this.dirtyBuffers != null) + { + foreach (byte[] buffer in this.dirtyBuffers) + { + this.memoryManager.ReturnLargeBuffer(buffer, this.id, this.tag); + } + } + + this.memoryManager.ReturnBlocks(this.blocks, this.id, this.tag); + this.memoryManager.ReportUsageReport(); + this.blocks.Clear(); + + base.Dispose(disposing); + } + + /// + /// Equivalent to Dispose. + /// + public override void Close() + { + this.Dispose(true); + } + + /// + /// Gets or sets the capacity. 
+ /// + /// + /// + /// Capacity is always in multiples of the memory manager's block size, unless + /// the large buffer is in use. Capacity never decreases during a stream's lifetime. + /// Explicitly setting the capacity to a lower value than the current value will have no effect. + /// This is because the buffers are all pooled by chunks and there's little reason to + /// allow stream truncation. + /// + /// + /// Writing past the current capacity will cause to automatically increase, until MaximumStreamCapacity is reached. + /// + /// + /// If the capacity is larger than int.MaxValue, then InvalidOperationException will be thrown. If you anticipate using + /// larger streams, use the property instead. + /// + /// + /// Object has been disposed. + /// Capacity is larger than int.MaxValue. + public override int Capacity + { + get + { + this.CheckDisposed(); + if (this.largeBuffer != null) + { + return this.largeBuffer.Length; + } + + long size = (long)this.blocks.Count * this.memoryManager.OptionsValue.BlockSize; + if (size > int.MaxValue) + { + throw new InvalidOperationException($"{nameof(this.Capacity)} is larger than int.MaxValue. Use {nameof(this.Capacity64)} instead."); + } + + return (int)size; + } + + set => this.Capacity64 = value; + } + + /// + /// Gets or sets returns a 64-bit version of capacity, for streams larger than int.MaxValue in length. + /// + public long Capacity64 + { + get + { + this.CheckDisposed(); + if (this.largeBuffer != null) + { + return this.largeBuffer.Length; + } + + long size = (long)this.blocks.Count * this.memoryManager.OptionsValue.BlockSize; + return size; + } + + set + { + this.CheckDisposed(); + this.EnsureCapacity(value); + } + } + + private long length; + + /// + /// Gets the number of bytes written to this stream. + /// + /// Object has been disposed. + /// If the buffer has already been converted to a large buffer, then the maximum length is limited by the maximum allowed array length in .NET. 
+ public override long Length + { + get + { + this.CheckDisposed(); + return this.length; + } + } + + private long position; + + /// + /// Gets or sets the current position in the stream. + /// + /// Object has been disposed. + /// A negative value was passed. + /// Stream is in large-buffer mode, but an attempt was made to set the position past the maximum allowed array length. + /// If the buffer has already been converted to a large buffer, then the maximum length (and thus position) is limited by the maximum allowed array length in .NET. + public override long Position + { + get + { + this.CheckDisposed(); + return this.position; + } + + set + { + this.CheckDisposed(); + if (value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(value)} must be non-negative."); + } + + if (this.largeBuffer != null && value > RecyclableMemoryStreamManager.MaxArrayLength) + { + throw new InvalidOperationException($"Once the stream is converted to a single large buffer, position cannot be set past {RecyclableMemoryStreamManager.MaxArrayLength}."); + } + + this.position = value; + } + } + + /// + /// Gets a value indicating whether whether the stream can currently read. + /// + public override bool CanRead => !this.disposed; + + /// + /// Gets a value indicating whether whether the stream can currently seek. + /// + public override bool CanSeek => !this.disposed; + + /// + /// Gets a value indicating whether the steram can timeout. + /// + /// Always false + public override bool CanTimeout => false; + + /// + /// Gets a value indicating whether whether the stream can currently write. + /// + public override bool CanWrite => !this.disposed; + + /// + /// Returns a single buffer containing the contents of the stream. + /// The buffer may be longer than the stream length. + /// + /// A byte[] buffer. + /// IMPORTANT: Doing a after calling GetBuffer invalidates the buffer. 
The old buffer is held onto + /// until is called, but the next time GetBuffer is called, a new buffer from the pool will be required. + /// Object has been disposed. + /// stream is too large for a contiguous buffer. + public override byte[] GetBuffer() + { + this.CheckDisposed(); + + if (this.largeBuffer != null) + { + return this.largeBuffer; + } + + if (this.blocks.Count == 1) + { + return this.blocks[0]; + } + + // Buffer needs to reflect the capacity, not the length, because + // it's possible that people will manipulate the buffer directly + // and set the length afterward. Capacity sets the expectation + // for the size of the buffer. + byte[] newBuffer = this.memoryManager.GetLargeBuffer(this.Capacity64, this.id, this.tag); + + // InternalRead will check for existence of largeBuffer, so make sure we + // don't set it until after we've copied the data. + this.AssertLengthIsSmall(); + this.InternalRead(newBuffer, 0, (int)this.length, 0); + this.largeBuffer = newBuffer; + + if (this.blocks.Count > 0 && this.memoryManager.OptionsValue.AggressiveBufferReturn) + { + this.memoryManager.ReturnBlocks(this.blocks, this.id, this.tag); + this.blocks.Clear(); + } + + return this.largeBuffer; + } + +#if NET6_0_OR_GREATER || NETSTANDARD2_1_OR_GREATER + /// + public override void CopyTo(Stream destination, int bufferSize) + { + this.WriteTo(destination, this.position, this.length - this.position); + } +#endif + + /// Asynchronously reads all the bytes from the current position in this stream and writes them to another stream. + /// The stream to which the contents of the current stream will be copied. + /// This parameter is ignored. + /// The token to monitor for cancellation requests. + /// A task that represents the asynchronous copy operation. + /// + /// is . + /// Either the current stream or the destination stream is disposed. + /// The current stream does not support reading, or the destination stream does not support writing. 
+ /// Similarly to MemoryStream's behavior, CopyToAsync will adjust the source stream's position by the number of bytes written to the destination stream, as a Read would do. + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { +#if NET8_0_OR_GREATER + ArgumentNullException.ThrowIfNull(destination); +#else + if (destination == null) + { + throw new ArgumentNullException(nameof(destination)); + } +#endif + + this.CheckDisposed(); + + if (this.length == 0) + { + return Task.CompletedTask; + } + + long startPos = this.position; + long count = this.length - startPos; + this.position += count; + + if (destination is MemoryStream destinationRMS) + { + this.WriteTo(destinationRMS, startPos, count); + return Task.CompletedTask; + } + else + { + if (this.largeBuffer == null) + { + if (this.blocks.Count == 1) + { + this.AssertLengthIsSmall(); + return destination.WriteAsync(this.blocks[0], (int)startPos, (int)count, cancellationToken); + } + else + { + return CopyToAsyncImpl(destination, this.GetBlockAndRelativeOffset(startPos), count, this.blocks, cancellationToken); + } + } + else + { + this.AssertLengthIsSmall(); + return destination.WriteAsync(this.largeBuffer, (int)startPos, (int)count, cancellationToken); + } + } + + static async Task CopyToAsyncImpl(Stream destination, BlockAndOffset blockAndOffset, long count, List blocks, CancellationToken cancellationToken) + { + long bytesRemaining = count; + int currentBlock = blockAndOffset.Block; + int currentOffset = blockAndOffset.Offset; + while (bytesRemaining > 0) + { + byte[] block = blocks[currentBlock]; + int amountToCopy = (int)Math.Min(block.Length - currentOffset, bytesRemaining); +#if NET8_0_OR_GREATER + await destination.WriteAsync(block.AsMemory(currentOffset, amountToCopy), cancellationToken); +#else + await destination.WriteAsync(block, currentOffset, amountToCopy, cancellationToken); +#endif + bytesRemaining -= amountToCopy; + ++currentBlock; + 
currentOffset = 0; + } + } + } + + private byte[] bufferWriterTempBuffer; + + /// + /// Notifies the stream that bytes were written to the buffer returned by or . + /// Seeks forward by bytes. + /// + /// + /// You must request a new buffer after calling Advance to continue writing more data and cannot write to a previously acquired buffer. + /// + /// How many bytes to advance. + /// Object has been disposed. + /// is negative. + /// is larger than the size of the previously requested buffer. + public void Advance(int count) + { + this.CheckDisposed(); + if (count < 0) + { + throw new ArgumentOutOfRangeException(nameof(count), $"{nameof(count)} must be non-negative."); + } + + byte[] buffer = this.bufferWriterTempBuffer; + if (buffer != null) + { + if (count > buffer.Length) + { + throw new InvalidOperationException($"Cannot advance past the end of the buffer, which has a size of {buffer.Length}."); + } + + this.Write(buffer, 0, count); + this.ReturnTempBuffer(buffer); + this.bufferWriterTempBuffer = null; + } + else + { + long bufferSize = this.largeBuffer == null + ? this.memoryManager.OptionsValue.BlockSize - this.GetBlockAndRelativeOffset(this.position).Offset + : this.largeBuffer.Length - this.position; + + if (count > bufferSize) + { + throw new InvalidOperationException($"Cannot advance past the end of the buffer, which has a size of {bufferSize}."); + } + + this.position += count; + this.length = Math.Max(this.position, this.length); + } + } + + private void ReturnTempBuffer(byte[] buffer) + { + if (buffer.Length == this.memoryManager.OptionsValue.BlockSize) + { + this.memoryManager.ReturnBlock(buffer, this.id, this.tag); + } + else + { + this.memoryManager.ReturnLargeBuffer(buffer, this.id, this.tag); + } + } + + /// + /// + /// IMPORTANT: Calling Write(), GetBuffer(), TryGetBuffer(), Seek(), GetLength(), Advance(), + /// or setting Position after calling GetMemory() invalidates the memory. 
+ /// + public Memory GetMemory(int sizeHint = 0) + { + return this.GetWritableBuffer(sizeHint); + } + + /// + /// + /// IMPORTANT: Calling Write(), GetBuffer(), TryGetBuffer(), Seek(), GetLength(), Advance(), + /// or setting Position after calling GetSpan() invalidates the span. + /// + public Span GetSpan(int sizeHint = 0) + { + return this.GetWritableBuffer(sizeHint); + } + + /// + /// When callers to GetSpan() or GetMemory() request a buffer that is larger than the remaining size of the current block + /// this method return a temp buffer. When Advance() is called, that temp buffer is then copied into the stream. + /// + private ArraySegment GetWritableBuffer(int sizeHint) + { + this.CheckDisposed(); + if (sizeHint < 0) + { + throw new ArgumentOutOfRangeException(nameof(sizeHint), $"{nameof(sizeHint)} must be non-negative."); + } + + int minimumBufferSize = Math.Max(sizeHint, 1); + + this.EnsureCapacity(this.position + minimumBufferSize); + if (this.bufferWriterTempBuffer != null) + { + this.ReturnTempBuffer(this.bufferWriterTempBuffer); + this.bufferWriterTempBuffer = null; + } + + if (this.largeBuffer != null) + { + return new ArraySegment(this.largeBuffer, (int)this.position, this.largeBuffer.Length - (int)this.position); + } + + BlockAndOffset blockAndOffset = this.GetBlockAndRelativeOffset(this.position); + int remainingBytesInBlock = this.MemoryManager.OptionsValue.BlockSize - blockAndOffset.Offset; + if (remainingBytesInBlock >= minimumBufferSize) + { + return new ArraySegment(this.blocks[blockAndOffset.Block], blockAndOffset.Offset, this.MemoryManager.OptionsValue.BlockSize - blockAndOffset.Offset); + } + + this.bufferWriterTempBuffer = minimumBufferSize > this.memoryManager.OptionsValue.BlockSize ? + this.memoryManager.GetLargeBuffer(minimumBufferSize, this.id, this.tag) : + this.memoryManager.GetBlock(); + + return new ArraySegment(this.bufferWriterTempBuffer); + } + + /// + /// Returns a sequence containing the contents of the stream. 
+ /// + /// A ReadOnlySequence of bytes. + /// IMPORTANT: Calling Write(), GetMemory(), GetSpan(), Dispose(), or Close() after calling GetReadOnlySequence() invalidates the sequence. + /// Object has been disposed. + public ReadOnlySequence GetReadOnlySequence() + { + this.CheckDisposed(); + + if (this.largeBuffer != null) + { + this.AssertLengthIsSmall(); + return new ReadOnlySequence(this.largeBuffer, 0, (int)this.length); + } + + if (this.blocks.Count == 1) + { + this.AssertLengthIsSmall(); + return new ReadOnlySequence(this.blocks[0], 0, (int)this.length); + } + + BlockSegment first = new (this.blocks[0]); + BlockSegment last = first; + + for (int blockIdx = 1; last.RunningIndex + last.Memory.Length < this.length; blockIdx++) + { + last = last.Append(this.blocks[blockIdx]); + } + + return new ReadOnlySequence(first, 0, last, (int)(this.length - last.RunningIndex)); + } + + private sealed class BlockSegment : ReadOnlySequenceSegment + { + public BlockSegment(Memory memory) + { + this.Memory = memory; + } + + public BlockSegment Append(Memory memory) + { + BlockSegment nextSegment = new (memory) { RunningIndex = this.RunningIndex + this.Memory.Length }; + this.Next = nextSegment; + return nextSegment; + } + } + + /// + /// Returns an ArraySegment that wraps a single buffer containing the contents of the stream. + /// + /// An ArraySegment containing a reference to the underlying bytes. + /// Returns if a buffer can be returned; otherwise, . + public override bool TryGetBuffer(out ArraySegment buffer) + { + this.CheckDisposed(); + + try + { + if (this.length <= RecyclableMemoryStreamManager.MaxArrayLength) + { + buffer = new ArraySegment(this.GetBuffer(), 0, (int)this.Length); + return true; + } + } + catch (OutOfMemoryException) + { + } + +#if NET6_0_OR_GREATER || NETSTANDARD2_1_OR_GREATER + buffer = ArraySegment.Empty; +#else + buffer = default; +#endif + return false; + } + + /// + /// Returns a new array with a copy of the buffer's contents. 
You should almost certainly be using combined with the to + /// access the bytes in this stream. Calling ToArray will destroy the benefits of pooled buffers, but it is included + /// for the sake of completeness. + /// + /// Object has been disposed. + /// The current object disallows ToArray calls. + /// The length of the stream is too long for a contiguous array. + /// Array of bytes +#pragma warning disable CS0809 + [Obsolete("This method has degraded performance vs. GetBuffer and should be avoided.")] + public override byte[] ToArray() + { + this.CheckDisposed(); + + string stack = this.memoryManager.OptionsValue.GenerateCallStacks ? Environment.StackTrace : null; + this.memoryManager.ReportStreamToArray(this.id, this.tag, stack, this.length); + + if (this.memoryManager.OptionsValue.ThrowExceptionOnToArray) + { + throw new NotSupportedException("The underlying RecyclableMemoryStreamManager is configured to not allow calls to ToArray."); + } + + byte[] newBuffer = new byte[this.Length]; + + Debug.Assert(this.length <= int.MaxValue); + this.InternalRead(newBuffer, 0, (int)this.length, 0); + + return newBuffer; + } +#pragma warning restore CS0809 + + /// + /// Reads from the current position into the provided buffer. + /// + /// Destination buffer. + /// Offset into buffer at which to start placing the read bytes. + /// Number of bytes to read. + /// The number of bytes read. + /// buffer is null. + /// offset or count is less than 0. + /// offset subtracted from the buffer length is less than count. + /// Object has been disposed. + public override int Read(byte[] buffer, int offset, int count) + { + return this.SafeRead(buffer, offset, count, ref this.position); + } + + /// + /// Reads from the specified position into the provided buffer. + /// + /// Destination buffer. + /// Offset into buffer at which to start placing the read bytes. + /// Number of bytes to read. + /// Position in the stream to start reading from. + /// The number of bytes read. 
+ /// is null. + /// or is less than 0. + /// subtracted from the buffer length is less than . + /// Object has been disposed. + public int SafeRead(byte[] buffer, int offset, int count, ref long streamPosition) + { + this.CheckDisposed(); +#if NET8_0_OR_GREATER + ArgumentNullException.ThrowIfNull(buffer); +#else + if (buffer == null) + { + throw new ArgumentNullException(nameof(buffer)); + } +#endif + + if (offset < 0) + { + throw new ArgumentOutOfRangeException(nameof(offset), $"{nameof(offset)} cannot be negative."); + } + + if (count < 0) + { + throw new ArgumentOutOfRangeException(nameof(count), $"{nameof(count)} cannot be negative."); + } + + if (offset + count > buffer.Length) + { + throw new ArgumentException($"{nameof(buffer)} length must be at least {nameof(offset)} + {nameof(count)}."); + } + + int amountRead = this.InternalRead(buffer, offset, count, streamPosition); + streamPosition += amountRead; + return amountRead; + } + + /// + /// Reads from the current position into the provided buffer. + /// + /// Destination buffer. + /// The number of bytes read. + /// Object has been disposed. +#if NETSTANDARD2_0 + public int Read(Span buffer) +#else + public override int Read(Span buffer) +#endif + { + return this.SafeRead(buffer, ref this.position); + } + + /// + /// Reads from the specified position into the provided buffer. + /// + /// Destination buffer. + /// Position in the stream to start reading from. + /// The number of bytes read. + /// Object has been disposed. + public int SafeRead(Span buffer, ref long streamPosition) + { + this.CheckDisposed(); + + int amountRead = this.InternalRead(buffer, streamPosition); + streamPosition += amountRead; + return amountRead; + } + + /// + /// Writes the buffer to the stream. + /// + /// Source buffer. + /// Start position. + /// Number of bytes to write. + /// buffer is null. + /// offset or count is negative. + /// buffer.Length - offset is not less than count. + /// Object has been disposed. 
+ public override void Write(byte[] buffer, int offset, int count) + { + this.CheckDisposed(); +#if NET8_0_OR_GREATER + ArgumentNullException.ThrowIfNull(buffer); +#else + if (buffer == null) + { + throw new ArgumentNullException(nameof(buffer)); + } +#endif + + if (offset < 0) + { + throw new ArgumentOutOfRangeException( + nameof(offset), + offset, + $"{nameof(offset)} must be in the range of 0 - {nameof(buffer)}.{nameof(buffer.Length)}-1."); + } + + if (count < 0) + { + throw new ArgumentOutOfRangeException(nameof(count), count, $"{nameof(count)} must be non-negative."); + } + + if (count + offset > buffer.Length) + { + throw new ArgumentException($"{nameof(count)} must be greater than {nameof(buffer)}.{nameof(buffer.Length)} - {nameof(offset)}."); + } + + int blockSize = this.memoryManager.OptionsValue.BlockSize; + long end = this.position + count; + + this.EnsureCapacity(end); + + if (this.largeBuffer == null) + { + int bytesRemaining = count; + int bytesWritten = 0; + BlockAndOffset blockAndOffset = this.GetBlockAndRelativeOffset(this.position); + + while (bytesRemaining > 0) + { + byte[] currentBlock = this.blocks[blockAndOffset.Block]; + int remainingInBlock = blockSize - blockAndOffset.Offset; + int amountToWriteInBlock = Math.Min(remainingInBlock, bytesRemaining); + + Buffer.BlockCopy( + buffer, + offset + bytesWritten, + currentBlock, + blockAndOffset.Offset, + amountToWriteInBlock); + + bytesRemaining -= amountToWriteInBlock; + bytesWritten += amountToWriteInBlock; + + ++blockAndOffset.Block; + blockAndOffset.Offset = 0; + } + } + else + { + Buffer.BlockCopy(buffer, offset, this.largeBuffer, (int)this.position, count); + } + + this.position = end; + this.length = Math.Max(this.position, this.length); + } + + /// + /// Writes the buffer to the stream. + /// + /// Source buffer. + /// buffer is null. + /// Object has been disposed. 
+#if NETSTANDARD2_0 + public void Write(ReadOnlySpan source) +#else + public override void Write(ReadOnlySpan source) +#endif + { + this.CheckDisposed(); + + int blockSize = this.memoryManager.OptionsValue.BlockSize; + long end = this.position + source.Length; + + this.EnsureCapacity(end); + + if (this.largeBuffer == null) + { + BlockAndOffset blockAndOffset = this.GetBlockAndRelativeOffset(this.position); + + while (source.Length > 0) + { + byte[] currentBlock = this.blocks[blockAndOffset.Block]; + int remainingInBlock = blockSize - blockAndOffset.Offset; + int amountToWriteInBlock = Math.Min(remainingInBlock, source.Length); +#if NET8_0_OR_GREATER + source[..amountToWriteInBlock] + .CopyTo(currentBlock.AsSpan(blockAndOffset.Offset)); + + source = source[amountToWriteInBlock..]; +#else + source.Slice(0, amountToWriteInBlock) + .CopyTo(currentBlock.AsSpan(blockAndOffset.Offset)); + + source = source.Slice(amountToWriteInBlock); +#endif + + ++blockAndOffset.Block; + blockAndOffset.Offset = 0; + } + } + else + { + source.CopyTo(this.largeBuffer.AsSpan((int)this.position)); + } + + this.position = end; + this.length = Math.Max(this.position, this.length); + } + + /// + /// Returns a useful string for debugging. This should not normally be called in actual production code. + /// + /// String with debug data. + public override string ToString() + { + if (!this.disposed) + { + return $"Id = {this.Id}, Tag = {this.Tag}, Length = {this.Length:N0} bytes"; + } + else + { + // Avoid properties because of the dispose check, but the fields themselves are not cleared. + return $"Disposed: Id = {this.id}, Tag = {this.tag}, Final Length: {this.length:N0} bytes"; + } + } + + /// + /// Writes a single byte to the current position in the stream. + /// + /// byte value to write. + /// Object has been disposed. 
+ public override void WriteByte(byte value) + { + this.CheckDisposed(); + + long end = this.position + 1; + + if (this.largeBuffer == null) + { + int blockSize = this.memoryManager.OptionsValue.BlockSize; + + int block = (int)Math.DivRem(this.position, blockSize, out long index); + + if (block >= this.blocks.Count) + { + this.EnsureCapacity(end); + } + + this.blocks[block][index] = value; + } + else + { + if (this.position >= this.largeBuffer.Length) + { + this.EnsureCapacity(end); + } + + this.largeBuffer[this.position] = value; + } + + this.position = end; + + if (this.position > this.length) + { + this.length = this.position; + } + } + + /// + /// Reads a single byte from the current position in the stream. + /// + /// The byte at the current position, or -1 if the position is at the end of the stream. + /// Object has been disposed. + public override int ReadByte() + { + return this.SafeReadByte(ref this.position); + } + + /// + /// Reads a single byte from the specified position in the stream. + /// + /// The position in the stream to read from. + /// The byte at the current position, or -1 if the position is at the end of the stream. + /// Object has been disposed. + public int SafeReadByte(ref long streamPosition) + { + this.CheckDisposed(); + if (streamPosition == this.length) + { + return -1; + } + + byte value; + if (this.largeBuffer == null) + { + BlockAndOffset blockAndOffset = this.GetBlockAndRelativeOffset(streamPosition); + value = this.blocks[blockAndOffset.Block][blockAndOffset.Offset]; + } + else + { + value = this.largeBuffer[streamPosition]; + } + + streamPosition++; + return value; + } + + /// + /// Sets the length of the stream. + /// + /// length of the stream + /// value is negative or larger than . + /// Object has been disposed. 
+ public override void SetLength(long value) + { + this.CheckDisposed(); + if (value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(value)} must be non-negative."); + } + + this.EnsureCapacity(value); + + this.length = value; + if (this.position > value) + { + this.position = value; + } + } + + /// + /// Sets the position to the offset from the seek location. + /// + /// How many bytes to move. + /// From where. + /// The new position. + /// Object has been disposed. + /// is larger than . + /// Invalid seek origin. + /// Attempt to set negative position. + public override long Seek(long offset, SeekOrigin loc) + { + this.CheckDisposed(); + long newPosition = loc switch + { + SeekOrigin.Begin => offset, + SeekOrigin.Current => offset + this.position, + SeekOrigin.End => offset + this.length, + _ => throw new ArgumentException("Invalid seek origin.", nameof(loc)), + }; + if (newPosition < 0) + { + throw new IOException("Seek before beginning."); + } + + this.position = newPosition; + return this.position; + } + + /// + /// Synchronously writes this stream's bytes to the argument stream. + /// + /// Destination stream. + /// Important: This does a synchronous write, which may not be desired in some situations. + /// is null. + /// Object has been disposed. + public override void WriteTo(Stream stream) + { + this.WriteTo(stream, 0, this.length); + } + + /// + /// Synchronously writes this stream's bytes, starting at offset, for count bytes, to the argument stream. + /// + /// Destination stream. + /// Offset in source. + /// Number of bytes to write. + /// is null. + /// + /// is less than 0, or + is beyond this 's length. + /// + /// Object has been disposed. 
+ public void WriteTo(Stream stream, long offset, long count) + { + this.CheckDisposed(); +#if NET8_0_OR_GREATER + ArgumentNullException.ThrowIfNull(stream); +#else + if (stream == null) + { + throw new ArgumentNullException(nameof(stream)); + } +#endif + + if (offset < 0 || offset + count > this.length) + { + throw new ArgumentOutOfRangeException( + message: $"{nameof(offset)} must not be negative and {nameof(offset)} + {nameof(count)} must not exceed the length of the {nameof(stream)}.", + innerException: null); + } + + if (this.largeBuffer == null) + { + BlockAndOffset blockAndOffset = this.GetBlockAndRelativeOffset(offset); + long bytesRemaining = count; + int currentBlock = blockAndOffset.Block; + int currentOffset = blockAndOffset.Offset; + + while (bytesRemaining > 0) + { + byte[] block = this.blocks[currentBlock]; + int amountToCopy = (int)Math.Min((long)block.Length - currentOffset, bytesRemaining); + stream.Write(block, currentOffset, amountToCopy); + + bytesRemaining -= amountToCopy; + + ++currentBlock; + currentOffset = 0; + } + } + else + { + stream.Write(this.largeBuffer, (int)offset, (int)count); + } + } + + /// + /// Writes bytes from the current stream to a destination byte array. + /// + /// Target buffer. + /// The entire stream is written to the target array. + /// > is null. + /// Object has been disposed. + public void WriteTo(byte[] buffer) + { + this.WriteTo(buffer, 0, this.Length); + } + + /// + /// Writes bytes from the current stream to a destination byte array. + /// + /// Target buffer. + /// Offset in the source stream, from which to start. + /// Number of bytes to write. + /// > is null. + /// + /// is less than 0, or + is beyond this stream's length. + /// + /// Object has been disposed. + public void WriteTo(byte[] buffer, long offset, long count) + { + this.WriteTo(buffer, offset, count, 0); + } + + /// + /// Writes bytes from the current stream to a destination byte array. + /// + /// Target buffer. 
+ /// Offset in the source stream, from which to start. + /// Number of bytes to write. + /// Offset in the target byte array to start writing + /// buffer is null + /// + /// is less than 0, or + is beyond this stream's length. + /// + /// + /// is less than 0, or + is beyond the target 's length. + /// + /// Object has been disposed. + public void WriteTo(byte[] buffer, long offset, long count, int targetOffset) + { + this.CheckDisposed(); +#if NET8_0_OR_GREATER + ArgumentNullException.ThrowIfNull(buffer); +#else + if (buffer == null) + { + throw new ArgumentNullException(nameof(buffer)); + } +#endif + + if (offset < 0 || offset + count > this.length) + { + throw new ArgumentOutOfRangeException( + message: $"{nameof(offset)} must not be negative and {nameof(offset)} + {nameof(count)} must not exceed the length of the stream.", + innerException: null); + } + + if (targetOffset < 0 || count + targetOffset > buffer.Length) + { + throw new ArgumentOutOfRangeException( + message: $"{nameof(targetOffset)} must not be negative and {nameof(targetOffset)} + {nameof(count)} must not exceed the length of the target {nameof(buffer)}.", + innerException: null); + } + + if (this.largeBuffer == null) + { + BlockAndOffset blockAndOffset = this.GetBlockAndRelativeOffset(offset); + long bytesRemaining = count; + int currentBlock = blockAndOffset.Block; + int currentOffset = blockAndOffset.Offset; + int currentTargetOffset = targetOffset; + + while (bytesRemaining > 0) + { + byte[] block = this.blocks[currentBlock]; + int amountToCopy = (int)Math.Min((long)block.Length - currentOffset, bytesRemaining); + Buffer.BlockCopy(block, currentOffset, buffer, currentTargetOffset, amountToCopy); + + bytesRemaining -= amountToCopy; + + ++currentBlock; + currentOffset = 0; + currentTargetOffset += amountToCopy; + } + } + else + { + this.AssertLengthIsSmall(); + Buffer.BlockCopy(this.largeBuffer, (int)offset, buffer, targetOffset, (int)count); + } + } + + 
[MethodImpl(MethodImplOptions.AggressiveInlining)] + private void CheckDisposed() + { + if (this.disposed) + { + this.ThrowDisposedException(); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void ThrowDisposedException() + { + throw new ObjectDisposedException($"The stream with Id {this.id} and Tag {this.tag} is disposed."); + } + + private int InternalRead(byte[] buffer, int offset, int count, long fromPosition) + { + if (this.length - fromPosition <= 0) + { + return 0; + } + + int amountToCopy; + + if (this.largeBuffer == null) + { + BlockAndOffset blockAndOffset = this.GetBlockAndRelativeOffset(fromPosition); + int bytesWritten = 0; + int bytesRemaining = (int)Math.Min(count, this.length - fromPosition); + + while (bytesRemaining > 0) + { + byte[] block = this.blocks[blockAndOffset.Block]; + amountToCopy = Math.Min( + block.Length - blockAndOffset.Offset, + bytesRemaining); + Buffer.BlockCopy( + block, + blockAndOffset.Offset, + buffer, + bytesWritten + offset, + amountToCopy); + + bytesWritten += amountToCopy; + bytesRemaining -= amountToCopy; + + ++blockAndOffset.Block; + blockAndOffset.Offset = 0; + } + + return bytesWritten; + } + + amountToCopy = (int)Math.Min(count, this.length - fromPosition); + Buffer.BlockCopy(this.largeBuffer, (int)fromPosition, buffer, offset, amountToCopy); + return amountToCopy; + } + + private int InternalRead(Span buffer, long fromPosition) + { + if (this.length - fromPosition <= 0) + { + return 0; + } + + int amountToCopy; + + if (this.largeBuffer == null) + { + BlockAndOffset blockAndOffset = this.GetBlockAndRelativeOffset(fromPosition); + int bytesWritten = 0; + int bytesRemaining = (int)Math.Min(buffer.Length, this.length - fromPosition); + + while (bytesRemaining > 0) + { + byte[] block = this.blocks[blockAndOffset.Block]; + amountToCopy = Math.Min( + block.Length - blockAndOffset.Offset, + bytesRemaining); +#if NET8_0_OR_GREATER + block.AsSpan(blockAndOffset.Offset, amountToCopy) + 
.CopyTo(buffer[bytesWritten..]); +#else + block.AsSpan(blockAndOffset.Offset, amountToCopy) + .CopyTo(buffer.Slice(bytesWritten)); +#endif + + bytesWritten += amountToCopy; + bytesRemaining -= amountToCopy; + + ++blockAndOffset.Block; + blockAndOffset.Offset = 0; + } + + return bytesWritten; + } + + amountToCopy = (int)Math.Min(buffer.Length, this.length - fromPosition); + this.largeBuffer.AsSpan((int)fromPosition, amountToCopy).CopyTo(buffer); + return amountToCopy; + } + + private struct BlockAndOffset + { + public int Block; + public int Offset; + + public BlockAndOffset(int block, int offset) + { + this.Block = block; + this.Offset = offset; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private BlockAndOffset GetBlockAndRelativeOffset(long offset) + { + int blockSize = this.memoryManager.OptionsValue.BlockSize; + int blockIndex = (int)Math.DivRem(offset, blockSize, out long offsetIndex); + return new BlockAndOffset(blockIndex, (int)offsetIndex); + } + + private void EnsureCapacity(long newCapacity) + { + if (newCapacity > this.memoryManager.OptionsValue.MaximumStreamCapacity && this.memoryManager.OptionsValue.MaximumStreamCapacity > 0) + { + this.memoryManager.ReportStreamOverCapacity(this.id, this.tag, newCapacity, this.AllocationStack); + + throw new OutOfMemoryException($"Requested capacity is too large: {newCapacity}. 
Limit is {this.memoryManager.OptionsValue.MaximumStreamCapacity}."); + } + + if (this.largeBuffer != null) + { + if (newCapacity > this.largeBuffer.Length) + { + byte[] newBuffer = this.memoryManager.GetLargeBuffer(newCapacity, this.id, this.tag); + Debug.Assert(this.length <= int.MaxValue); + this.InternalRead(newBuffer, 0, (int)this.length, 0); + this.ReleaseLargeBuffer(); + this.largeBuffer = newBuffer; + } + } + else + { + // Let's save some re-allocation of the blocks list + long blocksRequired = (newCapacity / this.memoryManager.OptionsValue.BlockSize) + 1; + if (this.blocks.Capacity < blocksRequired) + { + this.blocks.Capacity = (int)blocksRequired; + } + + while (this.Capacity64 < newCapacity) + { + this.blocks.Add(this.memoryManager.GetBlock()); + } + } + } + + /// + /// Release the large buffer (either stores it for eventual release or returns it immediately). + /// + private void ReleaseLargeBuffer() + { + Debug.Assert(this.largeBuffer != null); + + if (this.memoryManager.OptionsValue.AggressiveBufferReturn) + { + this.memoryManager.ReturnLargeBuffer(this.largeBuffer!, this.id, this.tag); + } + else + { + // We most likely will only ever need space for one + this.dirtyBuffers ??= new List(1); + this.dirtyBuffers.Add(this.largeBuffer!); + } + + this.largeBuffer = null; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void AssertLengthIsSmall() + { + Debug.Assert(this.length <= int.MaxValue, "this.length was assumed to be <= Int32.MaxValue, but was larger."); + } + } +} \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/RecyclableMemoryStreamMirror/RecyclableMemoryStreamManager.EventArgs.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/RecyclableMemoryStreamMirror/RecyclableMemoryStreamManager.EventArgs.cs new file mode 100644 index 0000000000..7c450a418a --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/RecyclableMemoryStreamMirror/RecyclableMemoryStreamManager.EventArgs.cs @@ -0,0 
+1,456 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos.Encryption.Custom.RecyclableMemoryStreamMirror +{ + using System; + + /// + /// Wrapper for EventArgs + /// + public sealed partial class RecyclableMemoryStreamManager + { + /// + /// Arguments for the event. + /// + public sealed class StreamCreatedEventArgs : EventArgs + { + /// + /// Gets unique ID for the stream. + /// + public Guid Id { get; } + + /// + /// Gets optional Tag for the event. + /// + public string Tag { get; } + + /// + /// Gets requested stream size. + /// + public long RequestedSize { get; } + + /// + /// Gets actual stream size. + /// + public long ActualSize { get; } + + /// + /// Initializes a new instance of the class. + /// + /// Unique ID of the stream. + /// Tag of the stream. + /// The requested stream size. + /// The actual stream size. + public StreamCreatedEventArgs(Guid guid, string tag, long requestedSize, long actualSize) + { + this.Id = guid; + this.Tag = tag; + this.RequestedSize = requestedSize; + this.ActualSize = actualSize; + } + } + + /// + /// Arguments for the event. + /// + public sealed class StreamDisposedEventArgs : EventArgs + { + /// + /// Gets unique ID for the stream. + /// + public Guid Id { get; } + + /// + /// Gets optional Tag for the event. + /// + public string Tag { get; } + + /// + /// Gets stack where the stream was allocated. + /// + public string AllocationStack { get; } + + /// + /// Gets stack where stream was disposed. + /// + public string DisposeStack { get; } + + /// + /// Gets lifetime of the stream. + /// + public TimeSpan Lifetime { get; } + + /// + /// Initializes a new instance of the class. + /// + /// Unique ID of the stream. + /// Tag of the stream. + /// Lifetime of the stream + /// Stack of original allocation. + /// Dispose stack. 
+ public StreamDisposedEventArgs(Guid guid, string tag, TimeSpan lifetime, string allocationStack, string disposeStack) + { + this.Id = guid; + this.Tag = tag; + this.Lifetime = lifetime; + this.AllocationStack = allocationStack; + this.DisposeStack = disposeStack; + } + } + + /// + /// Arguments for the event. + /// + public sealed class StreamDoubleDisposedEventArgs : EventArgs + { + /// + /// Gets unique ID for the stream. + /// + public Guid Id { get; } + + /// + /// Gets optional Tag for the event. + /// + public string Tag { get; } + + /// + /// Gets stack where the stream was allocated. + /// + public string AllocationStack { get; } + + /// + /// Gets first dispose stack. + /// + public string DisposeStack1 { get; } + + /// + /// Gets second dispose stack. + /// + public string DisposeStack2 { get; } + + /// + /// Initializes a new instance of the class. + /// + /// Unique ID of the stream. + /// Tag of the stream. + /// Stack of original allocation. + /// First dispose stack. + /// Second dispose stack. + public StreamDoubleDisposedEventArgs(Guid guid, string tag, string allocationStack, string disposeStack1, string disposeStack2) + { + this.Id = guid; + this.Tag = tag; + this.AllocationStack = allocationStack; + this.DisposeStack1 = disposeStack1; + this.DisposeStack2 = disposeStack2; + } + } + + /// + /// Arguments for the event. + /// + public sealed class StreamFinalizedEventArgs : EventArgs + { + /// + /// Gets unique ID for the stream. + /// + public Guid Id { get; } + + /// + /// Gets optional Tag for the event. + /// + public string Tag { get; } + + /// + /// Gets stack where the stream was allocated. + /// + public string AllocationStack { get; } + + /// + /// Initializes a new instance of the class. + /// + /// Unique ID of the stream. + /// Tag of the stream. + /// Stack of original allocation. 
+ public StreamFinalizedEventArgs(Guid guid, string tag, string allocationStack) + { + this.Id = guid; + this.Tag = tag; + this.AllocationStack = allocationStack; + } + } + + /// + /// Arguments for the event. + /// + public sealed class StreamConvertedToArrayEventArgs : EventArgs + { + /// + /// Gets unique ID for the stream. + /// + public Guid Id { get; } + + /// + /// Gets optional Tag for the event. + /// + public string Tag { get; } + + /// + /// Gets stack where ToArray was called. + /// + public string Stack { get; } + + /// + /// Gets length of stack. + /// + public long Length { get; } + + /// + /// Initializes a new instance of the class. + /// + /// Unique ID of the stream. + /// Tag of the stream. + /// Stack of ToArray call. + /// Length of stream. + public StreamConvertedToArrayEventArgs(Guid guid, string tag, string stack, long length) + { + this.Id = guid; + this.Tag = tag; + this.Stack = stack; + this.Length = length; + } + } + + /// + /// Arguments for the event. + /// + public sealed class StreamOverCapacityEventArgs : EventArgs + { + /// + /// Gets unique ID for the stream. + /// + public Guid Id { get; } + + /// + /// Gets optional Tag for the event. + /// + public string Tag { get; } + + /// + /// Gets original allocation stack. + /// + public string AllocationStack { get; } + + /// + /// Gets requested capacity. + /// + public long RequestedCapacity { get; } + + /// + /// Gets maximum capacity. + /// + public long MaximumCapacity { get; } + + /// + /// Initializes a new instance of the class. + /// + /// Unique ID of the stream. + /// Tag of the stream. + /// Requested capacity. + /// Maximum stream capacity of the manager. + /// Original allocation stack. 
+ internal StreamOverCapacityEventArgs(Guid guid, string tag, long requestedCapacity, long maximumCapacity, string allocationStack) + { + this.Id = guid; + this.Tag = tag; + this.RequestedCapacity = requestedCapacity; + this.MaximumCapacity = maximumCapacity; + this.AllocationStack = allocationStack; + } + } + + /// + /// Arguments for the event. + /// + public sealed class BlockCreatedEventArgs : EventArgs + { + /// + /// Gets how many bytes are currently in use from the small pool. + /// + public long SmallPoolInUse { get; } + + /// + /// Initializes a new instance of the class. + /// + /// Number of bytes currently in use from the small pool. + internal BlockCreatedEventArgs(long smallPoolInUse) + { + this.SmallPoolInUse = smallPoolInUse; + } + } + + /// + /// Arguments for the events. + /// + public sealed class LargeBufferCreatedEventArgs : EventArgs + { + /// + /// Gets unique ID for the stream. + /// + public Guid Id { get; } + + /// + /// Gets optional Tag for the event. + /// + public string Tag { get; } + + /// + /// Gets a value indicating whether whether the buffer was satisfied from the pool or not. + /// + public bool Pooled { get; } + + /// + /// Gets required buffer size. + /// + public long RequiredSize { get; } + + /// + /// Gets how many bytes are in use from the large pool. + /// + public long LargePoolInUse { get; } + + /// + /// Gets if the buffer was not satisfied from the pool, and is turned on, then. + /// this will contain the call stack of the allocation request. + /// + public string CallStack { get; } + + /// + /// Initializes a new instance of the class. + /// + /// Unique ID of the stream. + /// Tag of the stream. + /// Required size of the new buffer. + /// How many bytes from the large pool are currently in use. + /// Whether the buffer was satisfied from the pool or not. + /// Call stack of the allocation, if it wasn't pooled. 
+ internal LargeBufferCreatedEventArgs(Guid guid, string tag, long requiredSize, long largePoolInUse, bool pooled, string callStack) + { + this.RequiredSize = requiredSize; + this.LargePoolInUse = largePoolInUse; + this.Pooled = pooled; + this.Id = guid; + this.Tag = tag; + this.CallStack = callStack; + } + } + + /// + /// Arguments for the event. + /// + public sealed class BufferDiscardedEventArgs : EventArgs + { + /// + /// Gets unique ID for the stream. + /// + public Guid Id { get; } + + /// + /// Gets optional Tag for the event. + /// + public string Tag { get; } + + /// + /// Gets type of the buffer. + /// + public Events.MemoryStreamBufferType BufferType { get; } + + /// + /// Gets the reason this buffer was discarded. + /// + public Events.MemoryStreamDiscardReason Reason { get; } + + /// + /// Initializes a new instance of the class. + /// + /// Unique ID of the stream. + /// Tag of the stream. + /// Type of buffer being discarded. + /// The reason for the discard. + internal BufferDiscardedEventArgs(Guid guid, string tag, Events.MemoryStreamBufferType bufferType, Events.MemoryStreamDiscardReason reason) + { + this.Id = guid; + this.Tag = tag; + this.BufferType = bufferType; + this.Reason = reason; + } + } + + /// + /// Arguments for the event. + /// + public sealed class StreamLengthEventArgs : EventArgs + { + /// + /// Gets length of the stream. + /// + public long Length { get; } + + /// + /// Initializes a new instance of the class. + /// + /// Length of the strength. + public StreamLengthEventArgs(long length) + { + this.Length = length; + } + } + + /// + /// Arguments for the event. + /// + public sealed class UsageReportEventArgs : EventArgs + { + /// + /// Gets bytes from the small pool currently in use. + /// + public long SmallPoolInUseBytes { get; } + + /// + /// Gets bytes from the small pool currently available. + /// + public long SmallPoolFreeBytes { get; } + + /// + /// Gets bytes from the large pool currently in use. 
+ /// + public long LargePoolInUseBytes { get; } + + /// + /// Gets bytes from the large pool currently available. + /// + public long LargePoolFreeBytes { get; } + + /// + /// Initializes a new instance of the class. + /// + /// Bytes from the small pool currently in use. + /// Bytes from the small pool currently available. + /// Bytes from the large pool currently in use. + /// Bytes from the large pool currently available. + public UsageReportEventArgs( + long smallPoolInUseBytes, + long smallPoolFreeBytes, + long largePoolInUseBytes, + long largePoolFreeBytes) + { + this.SmallPoolInUseBytes = smallPoolInUseBytes; + this.SmallPoolFreeBytes = smallPoolFreeBytes; + this.LargePoolInUseBytes = largePoolInUseBytes; + this.LargePoolFreeBytes = largePoolFreeBytes; + } + } + } +} \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/RecyclableMemoryStreamMirror/RecyclableMemoryStreamManager.Events.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/RecyclableMemoryStreamMirror/RecyclableMemoryStreamManager.Events.cs new file mode 100644 index 0000000000..81a24f9de8 --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/RecyclableMemoryStreamMirror/RecyclableMemoryStreamManager.Events.cs @@ -0,0 +1,288 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+//------------------------------------------------------------ + +// --------------------------------------------------------------------- +// Copyright (c) 2015 Microsoft +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. +// --------------------------------------------------------------------- +namespace Microsoft.Azure.Cosmos.Encryption.Custom.RecyclableMemoryStreamMirror +{ + using System; + using System.Diagnostics.Tracing; + + /// + /// Holder for Events + /// + public sealed partial class RecyclableMemoryStreamManager + { + /// + /// ETW events for RecyclableMemoryStream. + /// + [EventSource(Name = "Microsoft-IO-RecyclableMemoryStream", Guid = "{B80CD4E4-890E-468D-9CBA-90EB7C82DFC7}")] + public sealed class Events : EventSource + { + /// + /// Static log object, through which all events are written. 
+ /// +#pragma warning disable SA1401 // Fields should be private +#pragma warning disable CA2211 // Non-constant fields should not be visible + public static Events Writer = new (); +#pragma warning restore CA2211 // Non-constant fields should not be visible +#pragma warning restore SA1401 // Fields should be private + + /// + /// Type of buffer. + /// + public enum MemoryStreamBufferType + { + /// + /// Small block buffer. + /// + Small, + + /// + /// Large pool buffer. + /// + Large, + } + + /// + /// The possible reasons for discarding a buffer. + /// + public enum MemoryStreamDiscardReason + { + /// + /// Buffer was too large to be re-pooled. + /// + TooLarge, + + /// + /// There are enough free bytes in the pool. + /// + EnoughFree, + } + + /// + /// Logged when a stream object is created. + /// + /// A unique ID for this stream. + /// A temporary ID for this stream, usually indicates current usage. + /// Requested size of the stream. + /// Actual size given to the stream from the pool. + [Event(1, Level = EventLevel.Verbose, Version = 2)] + public void MemoryStreamCreated(Guid guid, string tag, long requestedSize, long actualSize) + { + if (this.IsEnabled(EventLevel.Verbose, EventKeywords.None)) + { + this.WriteEvent(1, guid, tag ?? string.Empty, requestedSize, actualSize); + } + } + + /// + /// Logged when the stream is disposed. + /// + /// A unique ID for this stream. + /// A temporary ID for this stream, usually indicates current usage. + /// Lifetime in milliseconds of the stream + /// Call stack of initial allocation. + /// Call stack of the dispose. + [Event(2, Level = EventLevel.Verbose, Version = 3)] + public void MemoryStreamDisposed(Guid guid, string tag, long lifetimeMs, string allocationStack, string disposeStack) + { + if (this.IsEnabled(EventLevel.Verbose, EventKeywords.None)) + { + this.WriteEvent(2, guid, tag ?? string.Empty, lifetimeMs, allocationStack ?? string.Empty, disposeStack ?? 
string.Empty); + } + } + + /// + /// Logged when the stream is disposed for the second time. + /// + /// A unique ID for this stream. + /// A temporary ID for this stream, usually indicates current usage. + /// Call stack of initial allocation. + /// Call stack of the first dispose. + /// Call stack of the second dispose. + /// Note: Stacks will only be populated if RecyclableMemoryStreamManager.GenerateCallStacks is true. + [Event(3, Level = EventLevel.Critical)] + public void MemoryStreamDoubleDispose( + Guid guid, + string tag, + string allocationStack, + string disposeStack1, + string disposeStack2) + { + if (this.IsEnabled()) + { + this.WriteEvent( + 3, + guid, + tag ?? string.Empty, + allocationStack ?? string.Empty, + disposeStack1 ?? string.Empty, + disposeStack2 ?? string.Empty); + } + } + + /// + /// Logged when a stream is finalized. + /// + /// A unique ID for this stream. + /// A temporary ID for this stream, usually indicates current usage. + /// Call stack of initial allocation. + /// Note: Stacks will only be populated if RecyclableMemoryStreamManager.GenerateCallStacks is true. + [Event(4, Level = EventLevel.Error)] + public void MemoryStreamFinalized(Guid guid, string tag, string allocationStack) + { + if (this.IsEnabled()) + { + this.WriteEvent(4, guid, tag ?? string.Empty, allocationStack ?? string.Empty); + } + } + + /// + /// Logged when ToArray is called on a stream. + /// + /// A unique ID for this stream. + /// A temporary ID for this stream, usually indicates current usage. + /// Call stack of the ToArray call. + /// Length of stream. + /// Note: Stacks will only be populated if RecyclableMemoryStreamManager.GenerateCallStacks is true. + [Event(5, Level = EventLevel.Verbose, Version = 2)] + public void MemoryStreamToArray(Guid guid, string tag, string stack, long size) + { + if (this.IsEnabled(EventLevel.Verbose, EventKeywords.None)) + { + this.WriteEvent(5, guid, tag ?? string.Empty, stack ?? 
string.Empty, size); + } + } + + /// + /// Logged when the RecyclableMemoryStreamManager is initialized. + /// + /// Size of blocks, in bytes. + /// Size of the large buffer multiple, in bytes. + /// Maximum buffer size, in bytes. + [Event(6, Level = EventLevel.Informational)] + public void MemoryStreamManagerInitialized(int blockSize, int largeBufferMultiple, int maximumBufferSize) + { + if (this.IsEnabled()) + { + this.WriteEvent(6, blockSize, largeBufferMultiple, maximumBufferSize); + } + } + + /// + /// Logged when a new block is created. + /// + /// Number of bytes in the small pool currently in use. + [Event(7, Level = EventLevel.Warning, Version = 2)] + public void MemoryStreamNewBlockCreated(long smallPoolInUseBytes) + { + if (this.IsEnabled(EventLevel.Warning, EventKeywords.None)) + { + this.WriteEvent(7, smallPoolInUseBytes); + } + } + + /// + /// Logged when a new large buffer is created. + /// + /// Requested size. + /// Number of bytes in the large pool in use. + [Event(8, Level = EventLevel.Warning, Version = 3)] + public void MemoryStreamNewLargeBufferCreated(long requiredSize, long largePoolInUseBytes) + { + if (this.IsEnabled(EventLevel.Warning, EventKeywords.None)) + { + this.WriteEvent(8, requiredSize, largePoolInUseBytes); + } + } + + /// + /// Logged when a buffer is created that is too large to pool. + /// + /// Unique stream ID. + /// A temporary ID for this stream, usually indicates current usage. + /// Size requested by the caller. + /// Call stack of the requested stream. + /// Note: Stacks will only be populated if RecyclableMemoryStreamManager.GenerateCallStacks is true. + [Event(9, Level = EventLevel.Verbose, Version = 3)] + public void MemoryStreamNonPooledLargeBufferCreated(Guid guid, string tag, long requiredSize, string allocationStack) + { + if (this.IsEnabled(EventLevel.Verbose, EventKeywords.None)) + { + this.WriteEvent(9, guid, tag ?? string.Empty, requiredSize, allocationStack ?? 
string.Empty); + } + } + + /// + /// Logged when a buffer is discarded (not put back in the pool, but given to GC to clean up). + /// + /// Unique stream ID. + /// A temporary ID for this stream, usually indicates current usage. + /// Type of the buffer being discarded. + /// Reason for the discard. + /// Number of free small pool blocks. + /// Bytes free in the small pool. + /// Bytes in use from the small pool. + /// Number of free large pool blocks. + /// Bytes free in the large pool. + /// Bytes in use from the large pool. + [Event(10, Level = EventLevel.Warning, Version = 2)] + public void MemoryStreamDiscardBuffer( + Guid guid, + string tag, + MemoryStreamBufferType bufferType, + MemoryStreamDiscardReason reason, + long smallBlocksFree, + long smallPoolBytesFree, + long smallPoolBytesInUse, + long largeBlocksFree, + long largePoolBytesFree, + long largePoolBytesInUse) + { + if (this.IsEnabled(EventLevel.Warning, EventKeywords.None)) + { + this.WriteEvent(10, guid, tag ?? string.Empty, bufferType, reason, smallBlocksFree, smallPoolBytesFree, smallPoolBytesInUse, largeBlocksFree, largePoolBytesFree, largePoolBytesInUse); + } + } + + /// + /// Logged when a stream grows beyond the maximum capacity. + /// + /// Unique stream ID + /// A temporary ID for this stream, usually indicates current usage. + /// The requested capacity. + /// Maximum capacity, as configured by RecyclableMemoryStreamManager. + /// Call stack for the capacity request. + /// Note: Stacks will only be populated if RecyclableMemoryStreamManager.GenerateCallStacks is true. + [Event(11, Level = EventLevel.Error, Version = 3)] + public void MemoryStreamOverCapacity(Guid guid, string tag, long requestedCapacity, long maxCapacity, string allocationStack) + { + if (this.IsEnabled()) + { + this.WriteEvent(11, guid, tag ?? string.Empty, requestedCapacity, maxCapacity, allocationStack ?? 
string.Empty); + } + } + } + } +} \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/RecyclableMemoryStreamMirror/RecyclableMemoryStreamManager.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/RecyclableMemoryStreamMirror/RecyclableMemoryStreamManager.cs new file mode 100644 index 0000000000..8348ccc809 --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/RecyclableMemoryStreamMirror/RecyclableMemoryStreamManager.cs @@ -0,0 +1,989 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +// --------------------------------------------------------------------- +// Copyright (c) 2015-2016 Microsoft +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+// --------------------------------------------------------------------- +namespace Microsoft.Azure.Cosmos.Encryption.Custom.RecyclableMemoryStreamMirror +{ + using System; + using System.Collections.Concurrent; + using System.Collections.Generic; + using System.Runtime.CompilerServices; + using System.Threading; + using Microsoft.Extensions.Logging; + + /// + /// Manages pools of objects. + /// + /// + /// + /// There are two pools managed in here. The small pool contains same-sized buffers that are handed to streams + /// as they write more data. + /// + /// + /// For scenarios that need to call , the large pool contains buffers of various sizes, all + /// multiples/exponentials of (1 MB by default). They are split by size to avoid overly-wasteful buffer + /// usage. There should be far fewer 8 MB buffers than 1 MB buffers, for example. + /// + /// + public partial class RecyclableMemoryStreamManager + { + /// + /// Maximum length of a single array. + /// + /// See documentation at https://docs.microsoft.com/dotnet/api/system.array?view=netcore-3.1 + /// + internal const int MaxArrayLength = 0X7FFFFFC7; + + /// + /// Default block size, in bytes. + /// + public const int DefaultBlockSize = 128 * 1024; + + /// + /// Default large buffer multiple, in bytes. + /// + public const int DefaultLargeBufferMultiple = 1024 * 1024; + + /// + /// Default maximum buffer size, in bytes. 
+ /// + public const int DefaultMaximumBufferSize = 128 * 1024 * 1024; + + // 0 to indicate unbounded + private const long DefaultMaxSmallPoolFreeBytes = 0L; + private const long DefaultMaxLargePoolFreeBytes = 0L; + + private readonly long[] largeBufferFreeSize; + private readonly long[] largeBufferInUseSize; + + private readonly ConcurrentStack[] largePools; + + private readonly ConcurrentStack smallPool; + +#pragma warning disable SA1401 // Fields should be private - performance reasons + internal readonly Options OptionsValue; +#pragma warning restore SA1401 // Fields should be private + + private long smallPoolFreeSize; + private long smallPoolInUseSize; + + /// + /// Gets settings for controlling the behavior of RecyclableMemoryStream + /// + public Options Settings => this.OptionsValue; + + /// + /// Gets number of bytes in small pool not currently in use. + /// + public long SmallPoolFreeSize => this.smallPoolFreeSize; + + /// + /// Gets number of bytes currently in use by stream from the small pool. + /// + public long SmallPoolInUseSize => this.smallPoolInUseSize; + + /// + /// Gets number of bytes in large pool not currently in use. + /// + public long LargePoolFreeSize + { + get + { + long sum = 0; + foreach (long freeSize in this.largeBufferFreeSize) + { + sum += freeSize; + } + + return sum; + } + } + + /// + /// Gets number of bytes currently in use by streams from the large pool. + /// + public long LargePoolInUseSize + { + get + { + long sum = 0; + foreach (long inUseSize in this.largeBufferInUseSize) + { + sum += inUseSize; + } + + return sum; + } + } + + /// + /// Gets how many blocks are in the small pool. + /// + public long SmallBlocksFree => this.smallPool.Count; + + /// + /// Gets how many buffers are in the large pool. 
+ /// + public long LargeBuffersFree + { + get + { + long free = 0; + foreach (ConcurrentStack pool in this.largePools) + { + free += pool.Count; + } + + return free; + } + } + + /// + /// Parameters for customizing the behavior of + /// + public class Options + { + /// + /// Gets or sets the size of the pooled blocks. This must be greater than 0. + /// + /// The default size 131,072 (128KB) + public int BlockSize { get; set; } = DefaultBlockSize; + + /// + /// Gets or sets each large buffer will be a multiple exponential of this value + /// + /// The default value is 1,048,576 (1MB) + public int LargeBufferMultiple { get; set; } = DefaultLargeBufferMultiple; + + /// + /// Gets or sets buffer beyond this length are not pooled. + /// + /// The default value is 134,217,728 (128MB) + public int MaximumBufferSize { get; set; } = DefaultMaximumBufferSize; + + /// + /// Gets or sets maximum number of bytes to keep available in the small pool. + /// + /// + /// Trying to return buffers to the pool beyond this limit will result in them being garbage collected. + /// The default value is 0, but all users should set a reasonable value depending on your application's memory requirements. + /// + public long MaximumSmallPoolFreeBytes { get; set; } + + /// + /// Gets or sets maximum number of bytes to keep available in the large pools. + /// + /// + /// Trying to return buffers to the pool beyond this limit will result in them being garbage collected. + /// The default value is 0, but all users should set a reasonable value depending on your application's memory requirements. + /// + public long MaximumLargePoolFreeBytes { get; set; } + + /// + /// Gets or sets a value indicating whether whether to use the exponential allocation strategy (see documentation). + /// + /// The default value is false. + public bool UseExponentialLargeBuffer { get; set; } = false; + + /// + /// Gets or sets maximum stream capacity in bytes. 
Attempts to set a larger capacity will + /// result in an exception. + /// + /// The default value of 0 indicates no limit. + public long MaximumStreamCapacity { get; set; } = 0; + + /// + /// Gets or sets a value indicating whether whether to save call stacks for stream allocations. This can help in debugging. + /// It should NEVER be turned on generally in production. + /// + public bool GenerateCallStacks { get; set; } = false; + + /// + /// Gets or sets a value indicating whether whether dirty buffers can be immediately returned to the buffer pool. + /// + /// + /// + /// When is called on a stream and creates a single large buffer, if this setting is enabled, the other blocks will be returned + /// to the buffer pool immediately. + /// + /// + /// Note when enabling this setting that the user is responsible for ensuring that any buffer previously + /// retrieved from a stream which is subsequently modified is not used after modification (as it may no longer + /// be valid). + /// + /// + public bool AggressiveBufferReturn { get; set; } = false; + + /// + /// Gets or sets a value indicating whether causes an exception to be thrown if is ever called. + /// + /// Calling defeats the purpose of a pooled buffer. Use this property to discover code that is calling . If this is + /// set and is called, a NotSupportedException will be thrown. + public bool ThrowExceptionOnToArray { get; set; } = false; + + /// + /// Gets or sets a value indicating whether zero out buffers on allocation and before returning them to the pool. + /// + /// Setting this to true causes a performance hit and should only be set if one wants to avoid accidental data leaks. + public bool ZeroOutBuffer { get; set; } = false; + + /// + /// Creates a new object. + /// + public Options() + { + } + + /// + /// Creates a new object with the most common options. + /// + /// Size of the blocks in the small pool. + /// Size of the large buffer multiple + /// Maximum poolable buffer size. 
+ /// Maximum bytes to hold in the small pool. + /// Maximum bytes to hold in each of the large pools. + public Options(int blockSize, int largeBufferMultiple, int maximumBufferSize, long maximumSmallPoolFreeBytes, long maximumLargePoolFreeBytes) + { + this.BlockSize = blockSize; + this.LargeBufferMultiple = largeBufferMultiple; + this.MaximumBufferSize = maximumBufferSize; + this.MaximumSmallPoolFreeBytes = maximumSmallPoolFreeBytes; + this.MaximumLargePoolFreeBytes = maximumLargePoolFreeBytes; + } + } + + /// + /// Initializes the memory manager with the default block/buffer specifications. This pool may have unbounded growth unless you modify . + /// + public RecyclableMemoryStreamManager() + : this(new Options()) + { + } + + /// + /// Initializes the memory manager with the given block requiredSize. + /// + /// Object specifying options for stream behavior. + /// + /// is not a positive number, + /// or is not a positive number, + /// or is less than options.BlockSize, + /// or is negative, + /// or is negative, + /// or is not a multiple/exponential of . 
+ /// + public RecyclableMemoryStreamManager(Options options) + { + if (options.BlockSize <= 0) + { + throw new InvalidOperationException($"{nameof(options.BlockSize)} must be a positive number"); + } + + if (options.LargeBufferMultiple <= 0) + { + throw new InvalidOperationException($"{nameof(options.LargeBufferMultiple)} must be a positive number"); + } + + if (options.MaximumBufferSize < options.BlockSize) + { + throw new InvalidOperationException($"{nameof(options.MaximumBufferSize)} must be at least {nameof(options.BlockSize)}"); + } + + if (options.MaximumSmallPoolFreeBytes < 0) + { + throw new InvalidOperationException($"{nameof(options.MaximumSmallPoolFreeBytes)} must be non-negative"); + } + + if (options.MaximumLargePoolFreeBytes < 0) + { + throw new InvalidOperationException($"{nameof(options.MaximumLargePoolFreeBytes)} must be non-negative"); + } + + this.OptionsValue = options; + + if (!this.IsLargeBufferSize(options.MaximumBufferSize)) + { + throw new InvalidOperationException( + $"{nameof(options.MaximumBufferSize)} is not {(options.UseExponentialLargeBuffer ? "an exponential" : "a multiple")} of {nameof(options.LargeBufferMultiple)}."); + } + + this.smallPool = new ConcurrentStack(); + int numLargePools = options.UseExponentialLargeBuffer + ? (int)Math.Log(options.MaximumBufferSize / options.LargeBufferMultiple, 2) + 1 + : options.MaximumBufferSize / options.LargeBufferMultiple; + + // +1 to store size of bytes in use that are too large to be pooled + this.largeBufferInUseSize = new long[numLargePools + 1]; + this.largeBufferFreeSize = new long[numLargePools]; + + this.largePools = new ConcurrentStack[numLargePools]; + + for (int i = 0; i < this.largePools.Length; ++i) + { + this.largePools[i] = new ConcurrentStack(); + } + + Events.Writer.MemoryStreamManagerInitialized(options.BlockSize, options.LargeBufferMultiple, options.MaximumBufferSize); + } + + /// + /// Removes and returns a single block from the pool. + /// + /// A byte[] array. 
+ internal byte[] GetBlock() + { + Interlocked.Add(ref this.smallPoolInUseSize, this.OptionsValue.BlockSize); + + if (!this.smallPool.TryPop(out byte[] block)) + { + // We'll add this back to the pool when the stream is disposed + // (unless our free pool is too large) +#if NET6_0_OR_GREATER + block = this.OptionsValue.ZeroOutBuffer ? GC.AllocateArray(this.OptionsValue.BlockSize) : GC.AllocateUninitializedArray(this.OptionsValue.BlockSize); +#else + block = new byte[this.OptionsValue.BlockSize]; +#endif + this.ReportBlockCreated(); + } + else + { + Interlocked.Add(ref this.smallPoolFreeSize, -this.OptionsValue.BlockSize); + } + + return block; + } + + /// + /// Returns a buffer of arbitrary size from the large buffer pool. This buffer + /// will be at least the requiredSize and always be a multiple/exponential of largeBufferMultiple. + /// + /// The minimum length of the buffer. + /// Unique ID for the stream. + /// The tag of the stream returning this buffer, for logging if necessary. + /// A buffer of at least the required size. + /// Requested array size is larger than the maximum allowed. + internal byte[] GetLargeBuffer(long requiredSize, Guid id, string tag) + { + requiredSize = this.RoundToLargeBufferSize(requiredSize); + + if (requiredSize > MaxArrayLength) + { + throw new OutOfMemoryException($"Required buffer size exceeds maximum array length of {MaxArrayLength}."); + } + + int poolIndex = this.GetPoolIndex(requiredSize); + + bool createdNew = false; + bool pooled = true; + string callStack = null; + + byte[] buffer; + if (poolIndex < this.largePools.Length) + { + if (!this.largePools[poolIndex].TryPop(out buffer)) + { + buffer = AllocateArray(requiredSize, this.OptionsValue.ZeroOutBuffer); + createdNew = true; + } + else + { + Interlocked.Add(ref this.largeBufferFreeSize[poolIndex], -buffer.Length); + } + } + else + { + // Buffer is too large to pool. They get a new buffer. 
+ + // We still want to track the size, though, and we've reserved a slot + // in the end of the in-use array for non-pooled bytes in use. + poolIndex = this.largeBufferInUseSize.Length - 1; + + // We still want to round up to reduce heap fragmentation. + buffer = AllocateArray(requiredSize, this.OptionsValue.ZeroOutBuffer); + if (this.OptionsValue.GenerateCallStacks) + { + // Grab the stack -- we want to know who requires such large buffers + callStack = Environment.StackTrace; + } + + createdNew = true; + pooled = false; + } + + Interlocked.Add(ref this.largeBufferInUseSize[poolIndex], buffer.Length); + if (createdNew) + { + this.ReportLargeBufferCreated(id, tag, requiredSize, pooled: pooled, callStack); + } + + return buffer; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static byte[] AllocateArray(long requiredSize, bool zeroInitializeArray) => +#if NET6_0_OR_GREATER + zeroInitializeArray ? GC.AllocateArray((int)requiredSize) : GC.AllocateUninitializedArray((int)requiredSize); +#else + new byte[requiredSize]; +#endif + } + + private long RoundToLargeBufferSize(long requiredSize) + { + if (this.OptionsValue.UseExponentialLargeBuffer) + { + long pow = 1; + while (this.OptionsValue.LargeBufferMultiple * pow < requiredSize) + { + pow <<= 1; + } + + return this.OptionsValue.LargeBufferMultiple * pow; + } + else + { + return (requiredSize + this.OptionsValue.LargeBufferMultiple - 1) / this.OptionsValue.LargeBufferMultiple * this.OptionsValue.LargeBufferMultiple; + } + } + + private bool IsLargeBufferSize(int value) + { + return value != 0 && (this.OptionsValue.UseExponentialLargeBuffer + ? 
value == this.RoundToLargeBufferSize(value) + : value % this.OptionsValue.LargeBufferMultiple == 0); + } + + private int GetPoolIndex(long length) + { + if (this.OptionsValue.UseExponentialLargeBuffer) + { + int index = 0; + while (this.OptionsValue.LargeBufferMultiple << index < length) + { + ++index; + } + + return index; + } + else + { + return (int)((length / this.OptionsValue.LargeBufferMultiple) - 1); + } + } + + /// + /// Returns the buffer to the large pool. + /// + /// The buffer to return. + /// Unique stream ID. + /// The tag of the stream returning this buffer, for logging if necessary. + /// is null. + /// buffer.Length is not a multiple/exponential of (it did not originate from this pool). + internal void ReturnLargeBuffer(byte[] buffer, Guid id, string tag) + { +#if NET8_0_OR_GREATER + ArgumentNullException.ThrowIfNull(buffer); +#else + if (buffer == null) + { + throw new ArgumentNullException(nameof(buffer)); + } +#endif + + if (!this.IsLargeBufferSize(buffer.Length)) + { + throw new ArgumentException($"{nameof(buffer)} did not originate from this memory manager. The size is not " + + $"{(this.OptionsValue.UseExponentialLargeBuffer ? "an exponential" : "a multiple")} of {this.OptionsValue.LargeBufferMultiple}."); + } + + this.ZeroOutMemoryIfEnabled(buffer); + int poolIndex = this.GetPoolIndex(buffer.Length); + if (poolIndex < this.largePools.Length) + { + if ((this.largePools[poolIndex].Count + 1) * buffer.Length <= this.OptionsValue.MaximumLargePoolFreeBytes || + this.OptionsValue.MaximumLargePoolFreeBytes == 0) + { + this.largePools[poolIndex].Push(buffer); + Interlocked.Add(ref this.largeBufferFreeSize[poolIndex], buffer.Length); + } + else + { + this.ReportBufferDiscarded(id, tag, Events.MemoryStreamBufferType.Large, Events.MemoryStreamDiscardReason.EnoughFree); + } + } + else + { + // This is a non-poolable buffer, but we still want to track its size for in-use + // analysis. We have space in the InUse array for this. 
+ poolIndex = this.largeBufferInUseSize.Length - 1; + this.ReportBufferDiscarded(id, tag, Events.MemoryStreamBufferType.Large, Events.MemoryStreamDiscardReason.TooLarge); + } + + Interlocked.Add(ref this.largeBufferInUseSize[poolIndex], -buffer.Length); + } + + /// + /// Returns the blocks to the pool. + /// + /// Collection of blocks to return to the pool. + /// Unique Stream ID. + /// The tag of the stream returning these blocks, for logging if necessary. + /// is null. + /// contains buffers that are the wrong size (or null) for this memory manager. + internal void ReturnBlocks(List blocks, Guid id, string tag) + { +#if NET8_0_OR_GREATER + ArgumentNullException.ThrowIfNull(blocks); +#else + if (blocks == null) + { + throw new ArgumentNullException(nameof(blocks)); + } +#endif + + long bytesToReturn = blocks.Count * (long)this.OptionsValue.BlockSize; + Interlocked.Add(ref this.smallPoolInUseSize, -bytesToReturn); + + foreach (byte[] block in blocks) + { + if (block == null || block.Length != this.OptionsValue.BlockSize) + { + throw new ArgumentException($"{nameof(blocks)} contains buffers that are not {nameof(this.OptionsValue.BlockSize)} in length.", nameof(blocks)); + } + } + + foreach (byte[] block in blocks) + { + this.ZeroOutMemoryIfEnabled(block); + if (this.OptionsValue.MaximumSmallPoolFreeBytes == 0 || this.SmallPoolFreeSize < this.OptionsValue.MaximumSmallPoolFreeBytes) + { + Interlocked.Add(ref this.smallPoolFreeSize, this.OptionsValue.BlockSize); + this.smallPool.Push(block); + } + else + { + this.ReportBufferDiscarded(id, tag, Events.MemoryStreamBufferType.Small, Events.MemoryStreamDiscardReason.EnoughFree); + break; + } + } + } + + /// + /// Returns a block to the pool. + /// + /// Block to return to the pool. + /// Unique Stream ID. + /// The tag of the stream returning this, for logging if necessary. + /// is null. + /// is the wrong size for this memory manager. 
+ internal void ReturnBlock(byte[] block, Guid id, string tag) + { + int bytesToReturn = this.OptionsValue.BlockSize; + Interlocked.Add(ref this.smallPoolInUseSize, -bytesToReturn); + +#if NET8_0_OR_GREATER + ArgumentNullException.ThrowIfNull(block); +#else + if (block == null) + { + throw new ArgumentNullException(nameof(block)); + } +#endif + + if (block.Length != this.OptionsValue.BlockSize) + { + throw new ArgumentException($"{nameof(block)} is not not {nameof(this.OptionsValue.BlockSize)} in length."); + } + + this.ZeroOutMemoryIfEnabled(block); + if (this.OptionsValue.MaximumSmallPoolFreeBytes == 0 || this.SmallPoolFreeSize < this.OptionsValue.MaximumSmallPoolFreeBytes) + { + Interlocked.Add(ref this.smallPoolFreeSize, this.OptionsValue.BlockSize); + this.smallPool.Push(block); + } + else + { + this.ReportBufferDiscarded(id, tag, Events.MemoryStreamBufferType.Small, Events.MemoryStreamDiscardReason.EnoughFree); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void ZeroOutMemoryIfEnabled(byte[] buffer) + { + if (this.OptionsValue.ZeroOutBuffer) + { +#if NET6_0_OR_GREATER + Array.Clear(buffer); +#else + Array.Clear(buffer, 0, buffer.Length); +#endif + } + } + + internal void ReportBlockCreated() + { + Events.Writer.MemoryStreamNewBlockCreated(this.smallPoolInUseSize); + this.BlockCreated?.Invoke(this, new BlockCreatedEventArgs(this.smallPoolInUseSize)); + } + + internal void ReportLargeBufferCreated(Guid id, string tag, long requiredSize, bool pooled, string callStack) + { + if (pooled) + { + Events.Writer.MemoryStreamNewLargeBufferCreated(requiredSize, this.LargePoolInUseSize); + } + else + { + Events.Writer.MemoryStreamNonPooledLargeBufferCreated(id, tag, requiredSize, callStack); + } + + this.LargeBufferCreated?.Invoke(this, new LargeBufferCreatedEventArgs(id, tag, requiredSize, this.LargePoolInUseSize, pooled, callStack)); + } + + internal void ReportBufferDiscarded(Guid id, string tag, Events.MemoryStreamBufferType bufferType, 
Events.MemoryStreamDiscardReason reason) + { + Events.Writer.MemoryStreamDiscardBuffer( + id, + tag, + bufferType, + reason, + this.SmallBlocksFree, + this.smallPoolFreeSize, + this.smallPoolInUseSize, + this.LargeBuffersFree, + this.LargePoolFreeSize, + this.LargePoolInUseSize); + this.BufferDiscarded?.Invoke(this, new BufferDiscardedEventArgs(id, tag, bufferType, reason)); + } + + internal void ReportStreamCreated(Guid id, string tag, long requestedSize, long actualSize) + { + Events.Writer.MemoryStreamCreated(id, tag, requestedSize, actualSize); + this.StreamCreated?.Invoke(this, new StreamCreatedEventArgs(id, tag, requestedSize, actualSize)); + } + + internal void ReportStreamDisposed(Guid id, string tag, TimeSpan lifetime, string allocationStack, string disposeStack) + { + Events.Writer.MemoryStreamDisposed(id, tag, (long)lifetime.TotalMilliseconds, allocationStack, disposeStack); + this.StreamDisposed?.Invoke(this, new StreamDisposedEventArgs(id, tag, lifetime, allocationStack, disposeStack)); + } + + internal void ReportStreamDoubleDisposed(Guid id, string tag, string allocationStack, string disposeStack1, string disposeStack2) + { + Events.Writer.MemoryStreamDoubleDispose(id, tag, allocationStack, disposeStack1, disposeStack2); + this.StreamDoubleDisposed?.Invoke(this, new StreamDoubleDisposedEventArgs(id, tag, allocationStack, disposeStack1, disposeStack2)); + } + + internal void ReportStreamFinalized(Guid id, string tag, string allocationStack) + { + Events.Writer.MemoryStreamFinalized(id, tag, allocationStack); + this.StreamFinalized?.Invoke(this, new StreamFinalizedEventArgs(id, tag, allocationStack)); + } + + internal void ReportStreamLength(long bytes) + { + this.StreamLength?.Invoke(this, new StreamLengthEventArgs(bytes)); + } + + internal void ReportStreamToArray(Guid id, string tag, string stack, long length) + { + Events.Writer.MemoryStreamToArray(id, tag, stack, length); + this.StreamConvertedToArray?.Invoke(this, new 
StreamConvertedToArrayEventArgs(id, tag, stack, length)); + } + + internal void ReportStreamOverCapacity(Guid id, string tag, long requestedCapacity, string allocationStack) + { + Events.Writer.MemoryStreamOverCapacity(id, tag, requestedCapacity, this.OptionsValue.MaximumStreamCapacity, allocationStack); + this.StreamOverCapacity?.Invoke(this, new StreamOverCapacityEventArgs(id, tag, requestedCapacity, this.OptionsValue.MaximumStreamCapacity, allocationStack)); + } + + internal void ReportUsageReport() + { + this.UsageReport?.Invoke(this, new UsageReportEventArgs(this.smallPoolInUseSize, this.smallPoolFreeSize, this.LargePoolInUseSize, this.LargePoolFreeSize)); + } + + /// + /// Retrieve a new object with no tag and a default initial capacity. + /// + /// A . + public RecyclableMemoryStream GetStream() + { + return new RecyclableMemoryStream(this); + } + + /// + /// Retrieve a new object with no tag and a default initial capacity. + /// + /// A unique identifier which can be used to trace usages of the stream. + /// A . + public RecyclableMemoryStream GetStream(Guid id) + { + return new RecyclableMemoryStream(this, id); + } + + /// + /// Retrieve a new object with the given tag and a default initial capacity. + /// + /// A tag which can be used to track the source of the stream. + /// A . + public RecyclableMemoryStream GetStream(string tag) + { + return new RecyclableMemoryStream(this, tag); + } + + /// + /// Retrieve a new object with the given tag and a default initial capacity. + /// + /// A unique identifier which can be used to trace usages of the stream. + /// A tag which can be used to track the source of the stream. + /// A . + public RecyclableMemoryStream GetStream(Guid id, string tag) + { + return new RecyclableMemoryStream(this, id, tag); + } + + /// + /// Retrieve a new object with the given tag and at least the given capacity. + /// + /// A tag which can be used to track the source of the stream. + /// The minimum desired capacity for the stream. 
+ /// A . + public RecyclableMemoryStream GetStream(string tag, long requiredSize) + { + return new RecyclableMemoryStream(this, tag, requiredSize); + } + + /// + /// Retrieve a new object with the given tag and at least the given capacity. + /// + /// A unique identifier which can be used to trace usages of the stream. + /// A tag which can be used to track the source of the stream. + /// The minimum desired capacity for the stream. + /// A . + public RecyclableMemoryStream GetStream(Guid id, string tag, long requiredSize) + { + return new RecyclableMemoryStream(this, id, tag, requiredSize); + } + + /// + /// Retrieve a new object with the given tag and at least the given capacity, possibly using + /// a single contiguous underlying buffer. + /// + /// Retrieving a which provides a single contiguous buffer can be useful in situations + /// where the initial size is known and it is desirable to avoid copying data between the smaller underlying + /// buffers to a single large one. This is most helpful when you know that you will always call + /// on the underlying stream. + /// A unique identifier which can be used to trace usages of the stream. + /// A tag which can be used to track the source of the stream. + /// The minimum desired capacity for the stream. + /// Whether to attempt to use a single contiguous buffer. + /// A . + public RecyclableMemoryStream GetStream(Guid id, string tag, long requiredSize, bool asContiguousBuffer) + { + if (!asContiguousBuffer || requiredSize <= this.OptionsValue.BlockSize) + { + return this.GetStream(id, tag, requiredSize); + } + + return new RecyclableMemoryStream(this, id, tag, requiredSize, this.GetLargeBuffer(requiredSize, id, tag)); + } + + /// + /// Retrieve a new object with the given tag and at least the given capacity, possibly using + /// a single contiguous underlying buffer. 
+ /// + /// Retrieving a which provides a single contiguous buffer can be useful in situations + /// where the initial size is known and it is desirable to avoid copying data between the smaller underlying + /// buffers to a single large one. This is most helpful when you know that you will always call + /// on the underlying stream. + /// A tag which can be used to track the source of the stream. + /// The minimum desired capacity for the stream. + /// Whether to attempt to use a single contiguous buffer. + /// A . + public RecyclableMemoryStream GetStream(string tag, long requiredSize, bool asContiguousBuffer) + { + return this.GetStream(Guid.NewGuid(), tag, requiredSize, asContiguousBuffer); + } + + /// + /// Retrieve a new object with the given tag and with contents copied from the provided + /// buffer. The provided buffer is not wrapped or used after construction. + /// + /// The new stream's position is set to the beginning of the stream when returned. + /// A unique identifier which can be used to trace usages of the stream. + /// A tag which can be used to track the source of the stream. + /// The byte buffer to copy data from. + /// The offset from the start of the buffer to copy from. + /// The number of bytes to copy from the buffer. + /// A . + public RecyclableMemoryStream GetStream(Guid id, string tag, byte[] buffer, int offset, int count) + { + RecyclableMemoryStream stream = null; + try + { + stream = new RecyclableMemoryStream(this, id, tag, count); + stream.Write(buffer, offset, count); + stream.Position = 0; + return stream; + } + catch + { + stream?.Dispose(); + throw; + } + } + + /// + /// Retrieve a new object with the contents copied from the provided + /// buffer. The provided buffer is not wrapped or used after construction. + /// + /// The new stream's position is set to the beginning of the stream when returned. + /// The byte buffer to copy data from. + /// A . 
+ public RecyclableMemoryStream GetStream(byte[] buffer) + { + return this.GetStream(null, buffer, 0, buffer.Length); + } + + /// + /// Retrieve a new object with the given tag and with contents copied from the provided + /// buffer. The provided buffer is not wrapped or used after construction. + /// + /// The new stream's position is set to the beginning of the stream when returned. + /// A tag which can be used to track the source of the stream. + /// The byte buffer to copy data from. + /// The offset from the start of the buffer to copy from. + /// The number of bytes to copy from the buffer. + /// A . + public RecyclableMemoryStream GetStream(string tag, byte[] buffer, int offset, int count) + { + return this.GetStream(Guid.NewGuid(), tag, buffer, offset, count); + } + + /// + /// Retrieve a new object with the given tag and with contents copied from the provided + /// buffer. The provided buffer is not wrapped or used after construction. + /// + /// The new stream's position is set to the beginning of the stream when returned. + /// A unique identifier which can be used to trace usages of the stream. + /// A tag which can be used to track the source of the stream. + /// The byte buffer to copy data from. + /// A . + public RecyclableMemoryStream GetStream(Guid id, string tag, ReadOnlySpan buffer) + { + RecyclableMemoryStream stream = null; + try + { + stream = new RecyclableMemoryStream(this, id, tag, buffer.Length); + stream.Write(buffer); + stream.Position = 0; + return stream; + } + catch + { + stream?.Dispose(); + throw; + } + } + + /// + /// Retrieve a new object with the contents copied from the provided + /// buffer. The provided buffer is not wrapped or used after construction. + /// + /// The new stream's position is set to the beginning of the stream when returned. + /// The byte buffer to copy data from. + /// A . 
+ public RecyclableMemoryStream GetStream(ReadOnlySpan buffer) + { + return this.GetStream(null, buffer); + } + + /// + /// Retrieve a new object with the given tag and with contents copied from the provided + /// buffer. The provided buffer is not wrapped or used after construction. + /// + /// The new stream's position is set to the beginning of the stream when returned. + /// A tag which can be used to track the source of the stream. + /// The byte buffer to copy data from. + /// A . + public RecyclableMemoryStream GetStream(string tag, ReadOnlySpan buffer) + { + return this.GetStream(Guid.NewGuid(), tag, buffer); + } + + /// + /// Triggered when a new block is created. + /// + public event EventHandler BlockCreated; + + /// + /// Triggered when a new large buffer is created. + /// + public event EventHandler LargeBufferCreated; + + /// + /// Triggered when a new stream is created. + /// + public event EventHandler StreamCreated; + + /// + /// Triggered when a stream is disposed. + /// + public event EventHandler StreamDisposed; + + /// + /// Triggered when a stream is disposed of twice (an error). + /// + public event EventHandler StreamDoubleDisposed; + + /// + /// Triggered when a stream is finalized. + /// + public event EventHandler StreamFinalized; + + /// + /// Triggered when a stream is disposed to report the stream's length. + /// + public event EventHandler StreamLength; + + /// + /// Triggered when a user converts a stream to array. + /// + public event EventHandler StreamConvertedToArray; + + /// + /// Triggered when a stream is requested to expand beyond the maximum length specified by the responsible RecyclableMemoryStreamManager. + /// + public event EventHandler StreamOverCapacity; + + /// + /// Triggered when a buffer of either type is discarded, along with the reason for the discard. + /// + public event EventHandler BufferDiscarded; + + /// + /// Periodically triggered to report usage statistics. 
+ /// + public event EventHandler UsageReport; + } +} \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/RentArrayBufferWriter.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/RentArrayBufferWriter.cs deleted file mode 100644 index 5b37ded4fb..0000000000 --- a/Microsoft.Azure.Cosmos.Encryption.Custom/src/RentArrayBufferWriter.cs +++ /dev/null @@ -1,211 +0,0 @@ -// ------------------------------------------------------------ -// Copyright (c) Microsoft Corporation. All rights reserved. -// ------------------------------------------------------------ - -namespace Microsoft.Azure.Cosmos.Encryption.Custom; - -#if NET8_0_OR_GREATER - -using System; -using System.Buffers; -using System.Diagnostics; -using System.IO; -using System.Threading; -using System.Threading.Tasks; - -/// -/// https://gist.github.com/ahsonkhan/c76a1cc4dc7107537c3fdc0079a68b35 -/// Standard ArrayBufferWriter is not using pooled memory -/// -internal class RentArrayBufferWriter : IBufferWriter, IDisposable -{ - private const int MinimumBufferSize = 256; - - private byte[] rentedBuffer; - private int written; - private long committed; - - public RentArrayBufferWriter(int initialCapacity = MinimumBufferSize) - { - if (initialCapacity <= 0) - { - throw new ArgumentException(null, nameof(initialCapacity)); - } - - this.rentedBuffer = ArrayPool.Shared.Rent(initialCapacity); - this.written = 0; - this.committed = 0; - } - - public (byte[], int) WrittenBuffer - { - get - { - this.CheckIfDisposed(); - - return (this.rentedBuffer, this.written); - } - } - - public Memory WrittenMemory - { - get - { - this.CheckIfDisposed(); - - return this.rentedBuffer.AsMemory(0, this.written); - } - } - - public Span WrittenSpan - { - get - { - this.CheckIfDisposed(); - - return this.rentedBuffer.AsSpan(0, this.written); - } - } - - public int BytesWritten - { - get - { - this.CheckIfDisposed(); - - return this.written; - } - } - - public long BytesCommitted - { - get - { - 
this.CheckIfDisposed(); - - return this.committed; - } - } - - public void Clear() - { - this.CheckIfDisposed(); - - this.ClearHelper(); - } - - private void ClearHelper() - { - this.rentedBuffer.AsSpan(0, this.written).Clear(); - this.written = 0; - } - - public async Task CopyToAsync(Stream stream, CancellationToken cancellationToken = default) - { - this.CheckIfDisposed(); - - ArgumentNullException.ThrowIfNull(stream); - - await stream.WriteAsync(new Memory(this.rentedBuffer, 0, this.written), cancellationToken).ConfigureAwait(false); - this.committed += this.written; - - this.ClearHelper(); - } - - public void CopyTo(Stream stream) - { - this.CheckIfDisposed(); - - ArgumentNullException.ThrowIfNull(stream); - - stream.Write(this.rentedBuffer, 0, this.written); - this.committed += this.written; - - this.ClearHelper(); - } - - public void Advance(int count) - { - this.CheckIfDisposed(); - - ArgumentOutOfRangeException.ThrowIfLessThan(count, 0); - - if (this.written > this.rentedBuffer.Length - count) - { - throw new InvalidOperationException("Cannot advance past the end of the buffer."); - } - - this.written += count; - } - - // Returns the rented buffer back to the pool - public void Dispose() - { - if (this.rentedBuffer == null) - { - return; - } - - ArrayPool.Shared.Return(this.rentedBuffer, clearArray: true); - this.rentedBuffer = null; - this.written = 0; - } - - private void CheckIfDisposed() - { - ObjectDisposedException.ThrowIf(this.rentedBuffer == null, this); - } - - public Memory GetMemory(int sizeHint = 0) - { - this.CheckIfDisposed(); - - ArgumentOutOfRangeException.ThrowIfLessThan(sizeHint, 0); - - this.CheckAndResizeBuffer(sizeHint); - return this.rentedBuffer.AsMemory(this.written); - } - - public Span GetSpan(int sizeHint = 0) - { - this.CheckIfDisposed(); - - ArgumentOutOfRangeException.ThrowIfLessThan(sizeHint, 0); - - this.CheckAndResizeBuffer(sizeHint); - return this.rentedBuffer.AsSpan(this.written); - } - - private void 
CheckAndResizeBuffer(int sizeHint) - { - Debug.Assert(sizeHint >= 0); - - if (sizeHint == 0) - { - sizeHint = MinimumBufferSize; - } - - int availableSpace = this.rentedBuffer.Length - this.written; - - if (sizeHint > availableSpace) - { - int growBy = sizeHint > this.rentedBuffer.Length ? sizeHint : this.rentedBuffer.Length; - - int newSize = checked(this.rentedBuffer.Length + growBy); - - byte[] oldBuffer = this.rentedBuffer; - - this.rentedBuffer = ArrayPool.Shared.Rent(newSize); - - Debug.Assert(oldBuffer.Length >= this.written); - Debug.Assert(this.rentedBuffer.Length >= this.written); - - oldBuffer.AsSpan(0, this.written).CopyTo(this.rentedBuffer); - ArrayPool.Shared.Return(oldBuffer, clearArray: true); - } - - Debug.Assert(this.rentedBuffer.Length - this.written > 0); - Debug.Assert(this.rentedBuffer.Length - this.written >= sizeHint); - } -} -#endif \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/StreamManager.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/StreamManager.cs new file mode 100644 index 0000000000..ed7831b783 --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/StreamManager.cs @@ -0,0 +1,31 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ------------------------------------------------------------ + +#if NET8_0_OR_GREATER +namespace Microsoft.Azure.Cosmos.Encryption.Custom +{ + using System.IO; + using System.Threading.Tasks; + + /// + /// Abstraction for pooling streams + /// + public abstract class StreamManager + { + /// + /// Create stream + /// + /// Desired minimal size of stream. + /// Instance of stream. + public abstract Stream CreateStream(int hintSize = 0); + + /// + /// Dispose of used Stream (return to pool) + /// + /// Stream to dispose. 
+ /// ValueTask.CompletedTask + public abstract ValueTask ReturnStreamAsync(Stream stream); + } +} +#endif \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/StreamProcessing/DecryptableItemStream.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/StreamProcessing/DecryptableItemStream.cs new file mode 100644 index 0000000000..9e5d44c17a --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/StreamProcessing/DecryptableItemStream.cs @@ -0,0 +1,109 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ------------------------------------------------------------ + +#if ENCRYPTION_CUSTOM_PREVIEW && NET8_0_OR_GREATER +namespace Microsoft.Azure.Cosmos.Encryption.Custom.StreamProcessing +{ + using System; + using System.IO; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.Cosmos.Encryption.Custom.RecyclableMemoryStreamMirror; + + internal sealed class DecryptableItemStream : DecryptableItem + { + private readonly Encryptor encryptor; + private readonly JsonProcessor jsonProcessor; + private readonly CosmosSerializer cosmosSerializer; + private readonly StreamManager streamManager; + + private Stream encryptedStream; // this stream should be recyclable + private Stream decryptedStream; // this stream should be recyclable + private DecryptionContext decryptionContext; + + private bool isDisposed; + + public DecryptableItemStream( + Stream encryptedStream, + Encryptor encryptor, + JsonProcessor processor, + CosmosSerializer cosmosSerializer, + StreamManager streamManager) + { + this.encryptedStream = encryptedStream; + this.encryptor = encryptor; + this.jsonProcessor = processor; + this.cosmosSerializer = cosmosSerializer; + this.streamManager = streamManager; + } + + public override Task<(T, DecryptionContext)> GetItemAsync() + { + return this.GetItemAsync(CancellationToken.None); + } + + public override async Task<(T, 
DecryptionContext)> GetItemAsync(CancellationToken cancellationToken) + { + ObjectDisposedException.ThrowIf(this.isDisposed, this); + + if (this.decryptedStream == null) + { + this.decryptedStream = this.streamManager.CreateStream(); + + this.decryptionContext = await EncryptionProcessor.DecryptAsync( + this.encryptedStream, + this.decryptedStream, + this.encryptor, + new CosmosDiagnosticsContext(), + this.jsonProcessor, + cancellationToken); + + await this.encryptedStream.DisposeAsync(); + this.encryptedStream = null; + } + + T selector = default; + switch (selector) + { + case Stream: // consumer doesn't need payload deserialized + MemoryStream ms = new ((int)this.decryptedStream.Length); + await this.decryptedStream.CopyToAsync(ms, cancellationToken); + ms.Position = 0; + return ((T)(object)ms, this.decryptionContext); + default: +#if SDKPROJECTREF + return (await this.cosmosSerializer.FromStreamAsync(this.decryptedStream, cancellationToken), this.decryptionContext); +#else + // this API is missing Async => should not be used + return (this.cosmosSerializer.FromStream(this.decryptedStream), this.decryptionContext); +#endif + + } + } + + private void Dispose(bool disposing) + { + if (!this.isDisposed) + { + if (disposing) + { + this.encryptedStream?.Dispose(); + this.decryptedStream?.Dispose(); + this.encryptedStream = null; + this.decryptedStream = null; + } + + this.isDisposed = true; + } + } + + public override void Dispose() + { + // Do not change this code. 
Put cleanup code in 'Dispose(bool disposing)' method + this.Dispose(disposing: true); + GC.SuppressFinalize(this); + } + } +} +#endif \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/StreamProcessing/EncryptableItemStream.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/StreamProcessing/EncryptableItemStream.cs new file mode 100644 index 0000000000..87af42bfcd --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/StreamProcessing/EncryptableItemStream.cs @@ -0,0 +1,118 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ------------------------------------------------------------ + +#if ENCRYPTION_CUSTOM_PREVIEW && NET8_0_OR_GREATER +namespace Microsoft.Azure.Cosmos.Encryption.Custom.StreamProcessing +{ + using System; + using System.IO; + using System.Text.Json; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.Cosmos; + using Newtonsoft.Json.Linq; + + /// + /// Input type that can be used to allow for lazy decryption in the write path. + /// + /// Type of item. + public sealed class EncryptableItemStream : EncryptableItem, IDisposable + { + private DecryptableItemStream decryptableItem = null; + private bool isDisposed; + + /// + /// Gets the input item + /// + public T Item { get; } + + /// + public override DecryptableItem DecryptableItem + { + get + { + ObjectDisposedException.ThrowIf(this.isDisposed, this); + return this.decryptableItem ?? throw new InvalidOperationException("Decryptable content is not initialized."); + } + } + + /// + /// Initializes a new instance of the class. + /// + /// Item to be written. + /// Thrown when input is null. + public EncryptableItemStream(T input) + { + this.Item = input ?? 
throw new ArgumentNullException(nameof(input)); + } + + /// + protected internal override void SetDecryptableItem(JToken decryptableContent, Encryptor encryptor, CosmosSerializer cosmosSerializer) + { + throw new NotImplementedException(); + } + + /// + protected internal override void SetDecryptableStream(Stream decryptableStream, Encryptor encryptor, JsonProcessor jsonProcessor, CosmosSerializer cosmosSerializer, StreamManager streamManager) + { + ArgumentNullException.ThrowIfNull(decryptableStream); + + this.decryptableItem = new DecryptableItemStream(decryptableStream, encryptor, jsonProcessor, cosmosSerializer, streamManager); + } + + /// + protected internal override Stream ToStream(CosmosSerializer serializer) + { + if (this.Item is Stream stream) + { + return stream; + } + + return serializer.ToStream(this.Item); + } + + /// + protected internal override async Task ToStreamAsync(CosmosSerializer serializer, Stream outputStream, CancellationToken cancellationToken) + { + if (this.Item is Stream stream) + { + await stream.CopyToAsync(outputStream, cancellationToken); + stream.Position = 0; + outputStream.Position = 0; + return; + } + +#if SDKPROJECTREF + await serializer.ToStreamAsync(this.Item, outputStream, cancellationToken); +#else + // TODO: CosmosSerializer is lacking suitable methods + Stream cosmosSerializerOutput = serializer.ToStream(this.Item); + await cosmosSerializerOutput.CopyToAsync(outputStream, cancellationToken); + outputStream.Position = 0; +#endif + } + + private void Dispose(bool disposing) + { + if (!this.isDisposed) + { + if (disposing) + { + this.DecryptableItem?.Dispose(); + } + + this.isDisposed = true; + } + } + + /// + public override void Dispose() + { + // Do not change this code. 
Put cleanup code in 'Dispose(bool disposing)' method + this.Dispose(disposing: true); + GC.SuppressFinalize(this); + } + } +} +#endif \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/StreamProcessing/EncryptionContainerStream.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/StreamProcessing/EncryptionContainerStream.cs new file mode 100644 index 0000000000..3f9bb6d3be --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/StreamProcessing/EncryptionContainerStream.cs @@ -0,0 +1,1110 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +#if ENCRYPTION_CUSTOM_PREVIEW && NET8_0_OR_GREATER + +namespace Microsoft.Azure.Cosmos.Encryption.Custom +{ + using System; + using System.Collections.Generic; + using System.IO; + using System.Linq; + using System.Threading; + using System.Threading.Tasks; + + using Microsoft.Azure.Cosmos.Encryption.Custom.StreamProcessing; + + using Newtonsoft.Json.Linq; + + internal sealed class EncryptionContainerStream : Container + { + private readonly Container container; + + public CosmosSerializer CosmosSerializer { get; } + + public Encryptor Encryptor { get; } + + public CosmosResponseFactory ResponseFactory { get; } + + private readonly StreamManager streamManager; + + /// + /// All the operations / requests for exercising client-side encryption functionality need to be made using this EncryptionContainer instance. + /// + /// Regular cosmos container. + /// Provider that allows encrypting and decrypting data. + public EncryptionContainerStream(Container container, Encryptor encryptor) + : this(container, encryptor, new MemoryStreamManager()) + { + } + + /// + /// All the operations / requests for exercising client-side encryption functionality need to be made using this EncryptionContainer instance. + /// + /// Regular cosmos container. 
+ /// Provider that allows encrypting and decrypting data. + /// Custom stream manager instance. + public EncryptionContainerStream( + Container container, + Encryptor encryptor, + StreamManager streamManager) + { + this.container = container ?? throw new ArgumentNullException(nameof(container)); + this.Encryptor = encryptor ?? throw new ArgumentNullException(nameof(encryptor)); + this.ResponseFactory = this.Database.Client.ResponseFactory; + this.CosmosSerializer = this.Database.Client.ClientOptions.Serializer; + this.streamManager = streamManager; + } + + public override string Id => this.container.Id; + + public override Conflicts Conflicts => this.container.Conflicts; + + public override Scripts.Scripts Scripts => this.container.Scripts; + + public override Database Database => this.container.Database; + + public override async Task> CreateItemAsync( + T item, + PartitionKey? partitionKey = null, + ItemRequestOptions requestOptions = null, + CancellationToken cancellationToken = default) + { + if (item == null) + { + throw new ArgumentNullException(nameof(item)); + } + + if (requestOptions is not EncryptionItemRequestOptions encryptionItemRequestOptions || + encryptionItemRequestOptions.EncryptionOptions == null) + { + return await this.container.CreateItemAsync( + item, + partitionKey, + requestOptions, + cancellationToken); + } + + if (partitionKey == null) + { + throw new NotSupportedException($"{nameof(partitionKey)} cannot be null for operations using {nameof(EncryptionContainer)}."); + } + + CosmosDiagnosticsContext diagnosticsContext = CosmosDiagnosticsContext.Create(requestOptions); + using (diagnosticsContext.CreateScope("CreateItem")) + { + ResponseMessage responseMessage; + + if (item is EncryptableItem encryptableItemStream) + { + using Stream rms = this.streamManager.CreateStream(); + await encryptableItemStream.ToStreamAsync(this.CosmosSerializer, rms, cancellationToken); + responseMessage = await this.CreateItemHelperAsync( + rms, + 
partitionKey.Value, + requestOptions, + decryptResponse: false, + diagnosticsContext, + cancellationToken); + + encryptableItemStream.SetDecryptableStream(responseMessage.Content, this.Encryptor, encryptionItemRequestOptions.EncryptionOptions.JsonProcessor, this.CosmosSerializer, this.streamManager); + + return new EncryptionItemResponse(responseMessage, item); + } + else + { + using Stream itemStream = this.CosmosSerializer.ToStream(item); + responseMessage = await this.CreateItemHelperAsync( + itemStream, + partitionKey.Value, + requestOptions, + decryptResponse: true, + diagnosticsContext, + cancellationToken); + + return this.ResponseFactory.CreateItemResponse(responseMessage); + } + } + } + + public override async Task CreateItemStreamAsync( + Stream streamPayload, + PartitionKey partitionKey, + ItemRequestOptions requestOptions = null, + CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(streamPayload); + + CosmosDiagnosticsContext diagnosticsContext = CosmosDiagnosticsContext.Create(requestOptions); + using (diagnosticsContext.CreateScope("CreateItemStream")) + { + return await this.CreateItemHelperAsync( + streamPayload, + partitionKey, + requestOptions, + decryptResponse: true, + diagnosticsContext, + cancellationToken); + } + } + + private async Task CreateItemHelperAsync( + Stream streamPayload, + PartitionKey partitionKey, + ItemRequestOptions requestOptions, + bool decryptResponse, + CosmosDiagnosticsContext diagnosticsContext, + CancellationToken cancellationToken) + { + if (requestOptions is not EncryptionItemRequestOptions encryptionItemRequestOptions || + encryptionItemRequestOptions.EncryptionOptions == null) + { + return await this.container.CreateItemStreamAsync( + streamPayload, + partitionKey, + requestOptions, + cancellationToken); + } + + using Stream encryptedStream = this.streamManager.CreateStream(); + await EncryptionProcessor.EncryptAsync( + streamPayload, + encryptedStream, + this.Encryptor, + 
encryptionItemRequestOptions.EncryptionOptions, + diagnosticsContext, + cancellationToken); + + ResponseMessage responseMessage = await this.container.CreateItemStreamAsync( + encryptedStream, + partitionKey, + requestOptions, + cancellationToken); + + if (decryptResponse) + { + Stream decryptedStream = this.streamManager.CreateStream(); + _ = await EncryptionProcessor.DecryptAsync( + responseMessage.Content, + decryptedStream, + this.Encryptor, + diagnosticsContext, + encryptionItemRequestOptions.EncryptionOptions.JsonProcessor, + cancellationToken); + responseMessage.Content = decryptedStream; + } + + return responseMessage; + } + + public override Task> DeleteItemAsync( + string id, + PartitionKey partitionKey, + ItemRequestOptions requestOptions = null, + CancellationToken cancellationToken = default) + { + return this.container.DeleteItemAsync( + id, + partitionKey, + requestOptions, + cancellationToken); + } + + public override Task DeleteItemStreamAsync( + string id, + PartitionKey partitionKey, + ItemRequestOptions requestOptions = null, + CancellationToken cancellationToken = default) + { + return this.container.DeleteItemStreamAsync( + id, + partitionKey, + requestOptions, + cancellationToken); + } + + public override async Task> ReadItemAsync( + string id, + PartitionKey partitionKey, + ItemRequestOptions requestOptions = null, + CancellationToken cancellationToken = default) + { + CosmosDiagnosticsContext diagnosticsContext = CosmosDiagnosticsContext.Create(requestOptions); + using (diagnosticsContext.CreateScope("ReadItem")) + { + ResponseMessage responseMessage; + + if (typeof(T) == typeof(DecryptableItem)) + { + responseMessage = await this.ReadItemHelperAsync( + id, + partitionKey, + requestOptions, + decryptResponse: false, + diagnosticsContext, + cancellationToken); + + EncryptionItemRequestOptions options = requestOptions as EncryptionItemRequestOptions; + DecryptableItem decryptableItem = new DecryptableItemStream( + responseMessage.Content, + 
this.Encryptor, + options?.EncryptionOptions?.JsonProcessor ?? JsonProcessor.Newtonsoft, + this.CosmosSerializer, + this.streamManager); + + return new EncryptionItemResponse( + responseMessage, + (T)(object)decryptableItem); + } + + responseMessage = await this.ReadItemHelperAsync( + id, + partitionKey, + requestOptions, + decryptResponse: true, + diagnosticsContext, + cancellationToken); + + return this.ResponseFactory.CreateItemResponse(responseMessage); + } + } + + public override async Task ReadItemStreamAsync( + string id, + PartitionKey partitionKey, + ItemRequestOptions requestOptions = null, + CancellationToken cancellationToken = default) + { + CosmosDiagnosticsContext diagnosticsContext = CosmosDiagnosticsContext.Create(requestOptions); + using (diagnosticsContext.CreateScope("ReadItemStream")) + { + return await this.ReadItemHelperAsync( + id, + partitionKey, + requestOptions, + decryptResponse: true, + diagnosticsContext, + cancellationToken); + } + } + + private async Task ReadItemHelperAsync( + string id, + PartitionKey partitionKey, + ItemRequestOptions requestOptions, + bool decryptResponse, + CosmosDiagnosticsContext diagnosticsContext, + CancellationToken cancellationToken) + { + ResponseMessage responseMessage = await this.container.ReadItemStreamAsync( + id, + partitionKey, + requestOptions, + cancellationToken); + + if (decryptResponse) + { + JsonProcessor processor = (requestOptions as EncryptionItemRequestOptions)?.EncryptionOptions?.JsonProcessor ?? JsonProcessor.Newtonsoft; + Stream rms = this.streamManager.CreateStream(); + _ = await EncryptionProcessor.DecryptAsync(responseMessage.Content, rms, this.Encryptor, diagnosticsContext, processor, cancellationToken); + responseMessage.Content = rms; + } + + return responseMessage; + } + + public override async Task> ReplaceItemAsync( + T item, + string id, + PartitionKey? 
partitionKey = null, + ItemRequestOptions requestOptions = null, + CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(id); + ArgumentNullException.ThrowIfNull(item); + + if (requestOptions is not EncryptionItemRequestOptions encryptionItemRequestOptions || + encryptionItemRequestOptions.EncryptionOptions == null) + { + return await this.container.ReplaceItemAsync( + item, + id, + partitionKey, + requestOptions, + cancellationToken); + } + + if (partitionKey == null) + { + throw new NotSupportedException($"{nameof(partitionKey)} cannot be null for operations using {nameof(EncryptionContainer)}."); + } + + CosmosDiagnosticsContext diagnosticsContext = CosmosDiagnosticsContext.Create(requestOptions); + using (diagnosticsContext.CreateScope("ReplaceItem")) + { + ResponseMessage responseMessage; + + if (item is EncryptableItem encryptableItemStream) + { + using Stream rms = this.streamManager.CreateStream(); + await encryptableItemStream.ToStreamAsync(this.CosmosSerializer, rms, cancellationToken); + responseMessage = await this.ReplaceItemHelperAsync( + rms, + id, + partitionKey.Value, + requestOptions, + decryptResponse: false, + diagnosticsContext, + cancellationToken); + + encryptableItemStream.SetDecryptableStream(responseMessage.Content, this.Encryptor, encryptionItemRequestOptions.EncryptionOptions.JsonProcessor, this.CosmosSerializer, this.streamManager); + + return new EncryptionItemResponse(responseMessage, item); + } + else + { + using Stream itemStream = this.CosmosSerializer.ToStream(item); + responseMessage = await this.ReplaceItemHelperAsync( + itemStream, + id, + partitionKey.Value, + requestOptions, + decryptResponse: true, + diagnosticsContext, + cancellationToken); + + return this.ResponseFactory.CreateItemResponse(responseMessage); + } + } + } + + public override async Task ReplaceItemStreamAsync( + Stream streamPayload, + string id, + PartitionKey partitionKey, + ItemRequestOptions requestOptions = null, + 
CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(id); + ArgumentNullException.ThrowIfNull(streamPayload); + + CosmosDiagnosticsContext diagnosticsContext = CosmosDiagnosticsContext.Create(requestOptions); + using (diagnosticsContext.CreateScope("ReplaceItemStream")) + { + return await this.ReplaceItemHelperAsync( + streamPayload, + id, + partitionKey, + requestOptions, + decryptResponse: true, + diagnosticsContext, + cancellationToken); + } + } + + private async Task ReplaceItemHelperAsync( + Stream streamPayload, + string id, + PartitionKey partitionKey, + ItemRequestOptions requestOptions, + bool decryptResponse, + CosmosDiagnosticsContext diagnosticsContext, + CancellationToken cancellationToken) + { + if (requestOptions is not EncryptionItemRequestOptions encryptionItemRequestOptions || + encryptionItemRequestOptions.EncryptionOptions == null) + { + return await this.container.ReplaceItemStreamAsync( + streamPayload, + id, + partitionKey, + requestOptions, + cancellationToken); + } + + using Stream encryptedStream = this.streamManager.CreateStream(); + await EncryptionProcessor.EncryptAsync( + streamPayload, + encryptedStream, + this.Encryptor, + encryptionItemRequestOptions.EncryptionOptions, + diagnosticsContext, + cancellationToken); + + ResponseMessage responseMessage = await this.container.ReplaceItemStreamAsync( + encryptedStream, + id, + partitionKey, + requestOptions, + cancellationToken); + + if (decryptResponse) + { + Stream decryptedStream = this.streamManager.CreateStream(); + _ = await EncryptionProcessor.DecryptAsync( + responseMessage.Content, + decryptedStream, + this.Encryptor, + diagnosticsContext, + encryptionItemRequestOptions.EncryptionOptions.JsonProcessor, + cancellationToken); + responseMessage.Content = decryptedStream; + } + + return responseMessage; + } + + public override async Task> UpsertItemAsync( + T item, + PartitionKey? 
partitionKey = null, + ItemRequestOptions requestOptions = null, + CancellationToken cancellationToken = default) + { + if (item == null) + { + throw new ArgumentNullException(nameof(item)); + } + + if (requestOptions is not EncryptionItemRequestOptions encryptionItemRequestOptions || + encryptionItemRequestOptions.EncryptionOptions == null) + { + return await this.container.UpsertItemAsync( + item, + partitionKey, + requestOptions, + cancellationToken); + } + + if (partitionKey == null) + { + throw new NotSupportedException($"{nameof(partitionKey)} cannot be null for operations using {nameof(EncryptionContainer)}."); + } + + CosmosDiagnosticsContext diagnosticsContext = CosmosDiagnosticsContext.Create(requestOptions); + using (diagnosticsContext.CreateScope("UpsertItem")) + { + ResponseMessage responseMessage; + + if (item is EncryptableItem encryptableItemStream) + { + using Stream rms = this.streamManager.CreateStream(); + await encryptableItemStream.ToStreamAsync(this.CosmosSerializer, rms, cancellationToken); + responseMessage = await this.UpsertItemHelperAsync( + rms, + partitionKey.Value, + requestOptions, + decryptResponse: false, + diagnosticsContext, + cancellationToken); + + encryptableItemStream.SetDecryptableStream(responseMessage.Content, this.Encryptor, encryptionItemRequestOptions.EncryptionOptions.JsonProcessor, this.CosmosSerializer, this.streamManager); + + return new EncryptionItemResponse(responseMessage, item); + } + else + { + using (Stream itemStream = this.CosmosSerializer.ToStream(item)) + { + responseMessage = await this.UpsertItemHelperAsync( + itemStream, + partitionKey.Value, + requestOptions, + decryptResponse: true, + diagnosticsContext, + cancellationToken); + } + + return this.ResponseFactory.CreateItemResponse(responseMessage); + } + } + } + + public override async Task UpsertItemStreamAsync( + Stream streamPayload, + PartitionKey partitionKey, + ItemRequestOptions requestOptions = null, + CancellationToken cancellationToken = 
default) + { + ArgumentNullException.ThrowIfNull(streamPayload); + + CosmosDiagnosticsContext diagnosticsContext = CosmosDiagnosticsContext.Create(requestOptions); + using (diagnosticsContext.CreateScope("UpsertItemStream")) + { + return await this.UpsertItemHelperAsync( + streamPayload, + partitionKey, + requestOptions, + decryptResponse: true, + diagnosticsContext, + cancellationToken); + } + } + + private async Task UpsertItemHelperAsync( + Stream streamPayload, + PartitionKey partitionKey, + ItemRequestOptions requestOptions, + bool decryptResponse, + CosmosDiagnosticsContext diagnosticsContext, + CancellationToken cancellationToken) + { + if (requestOptions is not EncryptionItemRequestOptions encryptionItemRequestOptions || + encryptionItemRequestOptions.EncryptionOptions == null) + { + return await this.container.UpsertItemStreamAsync( + streamPayload, + partitionKey, + requestOptions, + cancellationToken); + } + + Stream encryptedStream = this.streamManager.CreateStream(); + await EncryptionProcessor.EncryptAsync( + streamPayload, + encryptedStream, + this.Encryptor, + encryptionItemRequestOptions.EncryptionOptions, + diagnosticsContext, + cancellationToken); + + ResponseMessage responseMessage = await this.container.UpsertItemStreamAsync( + encryptedStream, + partitionKey, + requestOptions, + cancellationToken); + + if (decryptResponse) + { + Stream decryptStream = this.streamManager.CreateStream(); + _ = await EncryptionProcessor.DecryptAsync( + responseMessage.Content, + decryptStream, + this.Encryptor, + diagnosticsContext, + encryptionItemRequestOptions.EncryptionOptions.JsonProcessor, + cancellationToken); + responseMessage.Content = decryptStream; + } + + return responseMessage; + } + + public override TransactionalBatch CreateTransactionalBatch( + PartitionKey partitionKey) + { + return new EncryptionTransactionalBatchStream( + this.container.CreateTransactionalBatch(partitionKey), + this.Encryptor, + this.CosmosSerializer, + this.streamManager); + } 
+ + public override Task DeleteContainerAsync( + ContainerRequestOptions requestOptions = null, + CancellationToken cancellationToken = default) + { + return this.container.DeleteContainerAsync( + requestOptions, + cancellationToken); + } + + public override Task DeleteContainerStreamAsync( + ContainerRequestOptions requestOptions = null, + CancellationToken cancellationToken = default) + { + return this.container.DeleteContainerStreamAsync( + requestOptions, + cancellationToken); + } + + public override ChangeFeedProcessorBuilder GetChangeFeedEstimatorBuilder( + string processorName, + ChangesEstimationHandler estimationDelegate, + TimeSpan? estimationPeriod = null) + { + return this.container.GetChangeFeedEstimatorBuilder( + processorName, + estimationDelegate, + estimationPeriod); + } + + public override IOrderedQueryable GetItemLinqQueryable( + bool allowSynchronousQueryExecution = false, + string continuationToken = null, + QueryRequestOptions requestOptions = null, + CosmosLinqSerializerOptions linqSerializerOptions = null) + { + return this.container.GetItemLinqQueryable( + allowSynchronousQueryExecution, + continuationToken, + requestOptions, + linqSerializerOptions); + } + + public override FeedIterator GetItemQueryIterator( + QueryDefinition queryDefinition, + string continuationToken = null, + QueryRequestOptions requestOptions = null) + { + return new EncryptionFeedIteratorStream( + (EncryptionFeedIteratorStream)this.GetItemQueryStreamIterator( + queryDefinition, + continuationToken, + requestOptions), + this.ResponseFactory); + } + + public override FeedIterator GetItemQueryIterator( + string queryText = null, + string continuationToken = null, + QueryRequestOptions requestOptions = null) + { + return new EncryptionFeedIteratorStream( + (EncryptionFeedIteratorStream)this.GetItemQueryStreamIterator( + queryText, + continuationToken, + requestOptions), + this.ResponseFactory); + } + + public override Task ReadContainerAsync( + ContainerRequestOptions 
requestOptions = null, + CancellationToken cancellationToken = default) + { + return this.container.ReadContainerAsync( + requestOptions, + cancellationToken); + } + + public override Task ReadContainerStreamAsync( + ContainerRequestOptions requestOptions = null, + CancellationToken cancellationToken = default) + { + return this.container.ReadContainerStreamAsync( + requestOptions, + cancellationToken); + } + + public override Task ReadThroughputAsync( + CancellationToken cancellationToken = default) + { + return this.container.ReadThroughputAsync(cancellationToken); + } + + public override Task ReadThroughputAsync( + RequestOptions requestOptions, + CancellationToken cancellationToken = default) + { + return this.container.ReadThroughputAsync( + requestOptions, + cancellationToken); + } + + public override Task ReplaceContainerAsync( + ContainerProperties containerProperties, + ContainerRequestOptions requestOptions = null, + CancellationToken cancellationToken = default) + { + return this.container.ReplaceContainerAsync( + containerProperties, + requestOptions, + cancellationToken); + } + + public override Task ReplaceContainerStreamAsync( + ContainerProperties containerProperties, + ContainerRequestOptions requestOptions = null, + CancellationToken cancellationToken = default) + { + return this.container.ReplaceContainerStreamAsync( + containerProperties, + requestOptions, + cancellationToken); + } + + public override Task ReplaceThroughputAsync( + int throughput, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default) + { + return this.container.ReplaceThroughputAsync( + throughput, + requestOptions, + cancellationToken); + } + + public override FeedIterator GetItemQueryStreamIterator( + QueryDefinition queryDefinition, + string continuationToken = null, + QueryRequestOptions requestOptions = null) + { + return new EncryptionFeedIteratorStream( + this.container.GetItemQueryStreamIterator( + queryDefinition, + continuationToken, 
+ requestOptions), + this.Encryptor, + this.CosmosSerializer, + this.streamManager); + } + + public override FeedIterator GetItemQueryStreamIterator( + string queryText = null, + string continuationToken = null, + QueryRequestOptions requestOptions = null) + { + return new EncryptionFeedIteratorStream( + this.container.GetItemQueryStreamIterator( + queryText, + continuationToken, + requestOptions), + this.Encryptor, + this.CosmosSerializer, + this.streamManager); + } + + public override Task ReplaceThroughputAsync( + ThroughputProperties throughputProperties, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default) + { + return this.container.ReplaceThroughputAsync( + throughputProperties, + requestOptions, + cancellationToken); + } + + public override Task> GetFeedRangesAsync( + CancellationToken cancellationToken = default) + { + return this.container.GetFeedRangesAsync(cancellationToken); + } + + public override FeedIterator GetItemQueryStreamIterator( + FeedRange feedRange, + QueryDefinition queryDefinition, + string continuationToken, + QueryRequestOptions requestOptions = null) + { + return new EncryptionFeedIteratorStream( + this.container.GetItemQueryStreamIterator( + feedRange, + queryDefinition, + continuationToken, + requestOptions), + this.Encryptor, + this.CosmosSerializer, + this.streamManager); + } + + public override FeedIterator GetItemQueryIterator( + FeedRange feedRange, + QueryDefinition queryDefinition, + string continuationToken = null, + QueryRequestOptions requestOptions = null) + { + return new EncryptionFeedIteratorStream( + (EncryptionFeedIteratorStream)this.GetItemQueryStreamIterator( + feedRange, + queryDefinition, + continuationToken, + requestOptions), + this.ResponseFactory); + } + + public override ChangeFeedEstimator GetChangeFeedEstimator( + string processorName, + Container leaseContainer) + { + return this.container.GetChangeFeedEstimator(processorName, leaseContainer); + } + + public override 
FeedIterator GetChangeFeedStreamIterator( + ChangeFeedStartFrom changeFeedStartFrom, + ChangeFeedMode changeFeedMode, + ChangeFeedRequestOptions changeFeedRequestOptions = null) + { + return new EncryptionFeedIteratorStream( + this.container.GetChangeFeedStreamIterator( + changeFeedStartFrom, + changeFeedMode, + changeFeedRequestOptions), + this.Encryptor, + this.CosmosSerializer, + this.streamManager); + } + + public override FeedIterator GetChangeFeedIterator( + ChangeFeedStartFrom changeFeedStartFrom, + ChangeFeedMode changeFeedMode, + ChangeFeedRequestOptions changeFeedRequestOptions = null) + { + return new EncryptionFeedIteratorStream( + (EncryptionFeedIteratorStream)this.GetChangeFeedStreamIterator( + changeFeedStartFrom, + changeFeedMode, + changeFeedRequestOptions), + this.ResponseFactory); + } + + public override Task> PatchItemAsync( + string id, + PartitionKey partitionKey, + IReadOnlyList patchOperations, + PatchItemRequestOptions requestOptions = null, + CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override Task PatchItemStreamAsync( + string id, + PartitionKey partitionKey, + IReadOnlyList patchOperations, + PatchItemRequestOptions requestOptions = null, + CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override ChangeFeedProcessorBuilder GetChangeFeedProcessorBuilder( + string processorName, + ChangesHandler onChangesDelegate) + { + return this.container.GetChangeFeedProcessorBuilder( + processorName, + async ( + IReadOnlyCollection documents, + CancellationToken cancellationToken) => + { + List decryptItems = await this.DecryptChangeFeedDocumentsAsync( + documents, + cancellationToken); + + // Call the original passed in delegate + await onChangesDelegate(decryptItems, cancellationToken); + }); + } + + public override ChangeFeedProcessorBuilder GetChangeFeedProcessorBuilder( + string processorName, + ChangeFeedHandler 
onChangesDelegate) + { + return this.container.GetChangeFeedProcessorBuilder( + processorName, + async ( + ChangeFeedProcessorContext context, + IReadOnlyCollection documents, + CancellationToken cancellationToken) => + { + List decryptItems = await this.DecryptChangeFeedDocumentsAsync( + documents, + cancellationToken); + + // Call the original passed in delegate + await onChangesDelegate(context, decryptItems, cancellationToken); + }); + } + + public override ChangeFeedProcessorBuilder GetChangeFeedProcessorBuilderWithManualCheckpoint( + string processorName, + ChangeFeedHandlerWithManualCheckpoint onChangesDelegate) + { + return this.container.GetChangeFeedProcessorBuilderWithManualCheckpoint( + processorName, + async ( + ChangeFeedProcessorContext context, + IReadOnlyCollection documents, + Func tryCheckpointAsync, + CancellationToken cancellationToken) => + { + List decryptItems = await this.DecryptChangeFeedDocumentsAsync( + documents, + cancellationToken); + + // Call the original passed in delegate + await onChangesDelegate(context, decryptItems, tryCheckpointAsync, cancellationToken); + }); + } + + public override ChangeFeedProcessorBuilder GetChangeFeedProcessorBuilder( + string processorName, + ChangeFeedStreamHandler onChangesDelegate) + { + return this.container.GetChangeFeedProcessorBuilder( + processorName, + async ( + ChangeFeedProcessorContext context, + Stream changes, + CancellationToken cancellationToken) => + { + using Stream decryptedChanges = this.streamManager.CreateStream(); + await EncryptionProcessor.DeserializeAndDecryptResponseAsync( + changes, + decryptedChanges, + this.Encryptor, + this.streamManager, + cancellationToken); + + // Call the original passed in delegate + await onChangesDelegate(context, decryptedChanges, cancellationToken); + }); + } + + public override ChangeFeedProcessorBuilder GetChangeFeedProcessorBuilderWithManualCheckpoint( + string processorName, + ChangeFeedStreamHandlerWithManualCheckpoint onChangesDelegate) + { 
+ return this.container.GetChangeFeedProcessorBuilderWithManualCheckpoint( + processorName, + async ( + ChangeFeedProcessorContext context, + Stream changes, + Func tryCheckpointAsync, + CancellationToken cancellationToken) => + { + using Stream decryptedChanges = this.streamManager.CreateStream(); + await EncryptionProcessor.DeserializeAndDecryptResponseAsync( + changes, + decryptedChanges, + this.Encryptor, + this.streamManager, + cancellationToken); + + // Call the original passed in delegate + await onChangesDelegate(context, decryptedChanges, tryCheckpointAsync, cancellationToken); + }); + } + + public override Task ReadManyItemsStreamAsync( + IReadOnlyList<(string id, PartitionKey partitionKey)> items, + ReadManyRequestOptions readManyRequestOptions = null, + CancellationToken cancellationToken = default) + { + return this.ReadManyItemsHelperAsync( + items, + readManyRequestOptions, + cancellationToken); + } + + public override async Task> ReadManyItemsAsync( + IReadOnlyList<(string id, PartitionKey partitionKey)> items, + ReadManyRequestOptions readManyRequestOptions = null, + CancellationToken cancellationToken = default) + { + ResponseMessage responseMessage = await this.ReadManyItemsHelperAsync( + items, + readManyRequestOptions, + cancellationToken); + + return this.ResponseFactory.CreateItemFeedResponse(responseMessage); + } + +#if ENCRYPTIONPREVIEW + public override Task> GetPartitionKeyRangesAsync( + FeedRange feedRange, + CancellationToken cancellationToken = default) + { + return this.container.GetPartitionKeyRangesAsync(feedRange, cancellationToken); + } + + public override Task DeleteAllItemsByPartitionKeyStreamAsync( + Cosmos.PartitionKey partitionKey, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default) + { + return this.container.DeleteAllItemsByPartitionKeyStreamAsync( + partitionKey, + requestOptions, + cancellationToken); + } + + public override ChangeFeedProcessorBuilder 
GetChangeFeedProcessorBuilderWithAllVersionsAndDeletes(string processorName, ChangeFeedHandler> onChangesDelegate) + { + return this.container.GetChangeFeedProcessorBuilderWithAllVersionsAndDeletes( + processorName, + onChangesDelegate); + } +#endif + +#if SDKPROJECTREF + public override Task IsFeedRangePartOfAsync( + Cosmos.FeedRange x, + Cosmos.FeedRange y, + CancellationToken cancellationToken = default) + { + return this.container.IsFeedRangePartOfAsync( + x, + y, + cancellationToken); + } +#endif + + private async Task ReadManyItemsHelperAsync( + IReadOnlyList<(string id, PartitionKey partitionKey)> items, + ReadManyRequestOptions readManyRequestOptions = null, + CancellationToken cancellationToken = default) + { + ResponseMessage responseMessage = await this.container.ReadManyItemsStreamAsync( + items, + readManyRequestOptions, + cancellationToken); + + Stream decryptedStream = this.streamManager.CreateStream(); + await EncryptionProcessor.DeserializeAndDecryptResponseAsync(responseMessage.Content, decryptedStream, this.Encryptor, this.streamManager, cancellationToken); + + return new DecryptedResponseMessage(responseMessage, decryptedStream); + } + + private async Task> DecryptChangeFeedDocumentsAsync( + IReadOnlyCollection documents, + CancellationToken cancellationToken) + { + List decryptItems = new (documents.Count); + if (typeof(T) == typeof(DecryptableItem)) + { + foreach (Stream documentStream in documents) + { + DecryptableItemStream item = new ( + documentStream, + this.Encryptor, + JsonProcessor.Stream, + this.CosmosSerializer, + this.streamManager); + + decryptItems.Add((T)(object)item); + } + } + else + { + foreach (Stream document in documents) + { + CosmosDiagnosticsContext diagnosticsContext = CosmosDiagnosticsContext.Create(null); + using (diagnosticsContext.CreateScope("DecryptChangeFeedDocumentsAsync<")) + { + using Stream decryptedStream = this.streamManager.CreateStream(); + _ = await EncryptionProcessor.DecryptAsync( + document, + 
decryptedStream, + this.Encryptor, + diagnosticsContext, + JsonProcessor.Stream, + cancellationToken); + +#if SDKPROJECTREF + decryptItems.Add(await this.CosmosSerializer.FromStreamAsync(decryptedStream, cancellationToken)); +#else + decryptItems.Add(this.CosmosSerializer.FromStream(decryptedStream)); +#endif + + await decryptedStream.DisposeAsync(); + } + } + } + + return decryptItems; + } + } +} +#endif \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/StreamProcessing/EncryptionFeedIteratorStream.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/StreamProcessing/EncryptionFeedIteratorStream.cs new file mode 100644 index 0000000000..720d753995 --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/StreamProcessing/EncryptionFeedIteratorStream.cs @@ -0,0 +1,106 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +#if ENCRYPTION_CUSTOM_PREVIEW && NET8_0_OR_GREATER + +namespace Microsoft.Azure.Cosmos.Encryption.Custom +{ + using System.Collections.Generic; + using System.IO; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.Cosmos.Encryption.Custom.StreamProcessing; + using Microsoft.Azure.Cosmos.Encryption.Custom.Transformation; + + internal sealed class EncryptionFeedIteratorStream : FeedIterator + { + private readonly FeedIterator feedIterator; + private readonly Encryptor encryptor; + private readonly CosmosSerializer cosmosSerializer; + private readonly StreamManager streamManager; + + private static readonly ArrayStreamSplitter StreamSplitter = new (); + + public EncryptionFeedIteratorStream( + FeedIterator feedIterator, + Encryptor encryptor, + CosmosSerializer cosmosSerializer, + StreamManager streamManager) + { + this.feedIterator = feedIterator; + this.encryptor = encryptor; + this.cosmosSerializer = cosmosSerializer; + this.streamManager 
= streamManager; + } + + public override bool HasMoreResults => this.feedIterator.HasMoreResults; + + public override async Task ReadNextAsync(CancellationToken cancellationToken = default) + { + CosmosDiagnosticsContext diagnosticsContext = CosmosDiagnosticsContext.Create(options: null); + using (diagnosticsContext.CreateScope("FeedIterator.ReadNext")) + { + ResponseMessage responseMessage = await this.feedIterator.ReadNextAsync(cancellationToken); + + if (responseMessage.IsSuccessStatusCode && responseMessage.Content != null) + { + Stream decryptedContent = this.streamManager.CreateStream(); + await EncryptionProcessor.DeserializeAndDecryptResponseAsync( + responseMessage.Content, + decryptedContent, + this.encryptor, + this.streamManager, + cancellationToken); + + return new DecryptedResponseMessage(responseMessage, decryptedContent); + } + + return responseMessage; + } + } + + public async Task<(ResponseMessage, List)> ReadNextWithoutDecryptionAsync(CancellationToken cancellationToken = default) + { + CosmosDiagnosticsContext diagnosticsContext = CosmosDiagnosticsContext.Create(options: null); + using (diagnosticsContext.CreateScope("FeedIterator.ReadNextWithoutDecryption")) + { + ResponseMessage responseMessage = await this.feedIterator.ReadNextAsync(cancellationToken); + List decryptableContent = null; + + if (responseMessage.IsSuccessStatusCode && responseMessage.Content != null) + { + decryptableContent = await this.ConvertResponseToDecryptableItemsAsync( + responseMessage.Content, + cancellationToken); + + return (responseMessage, decryptableContent); + } + + return (responseMessage, decryptableContent); + } + } + + private async Task> ConvertResponseToDecryptableItemsAsync( + Stream content, + CancellationToken token) + { + List decryptableStreams = await StreamSplitter.SplitCollectionAsync(content, this.streamManager, token); + List decryptableItems = new (); + + foreach (Stream item in decryptableStreams) + { + decryptableItems.Add( + (T)(object)new 
DecryptableItemStream( + item, + this.encryptor, + JsonProcessor.Stream, + this.cosmosSerializer, + this.streamManager)); + } + + return decryptableItems; + } + } +} +#endif \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/StreamProcessing/EncryptionFeedIteratorStream{T}.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/StreamProcessing/EncryptionFeedIteratorStream{T}.cs new file mode 100644 index 0000000000..a2b050e529 --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/StreamProcessing/EncryptionFeedIteratorStream{T}.cs @@ -0,0 +1,50 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +#if ENCRYPTION_CUSTOM_PREVIEW && NET8_0_OR_GREATER +namespace Microsoft.Azure.Cosmos.Encryption.Custom +{ + using System; + using System.Collections.Generic; + using System.Threading; + using System.Threading.Tasks; + + internal sealed class EncryptionFeedIteratorStream : FeedIterator + { + private readonly EncryptionFeedIteratorStream feedIterator; + private readonly CosmosResponseFactory responseFactory; + + public EncryptionFeedIteratorStream( + EncryptionFeedIteratorStream feedIterator, + CosmosResponseFactory responseFactory) + { + this.feedIterator = feedIterator ?? throw new ArgumentNullException(nameof(feedIterator)); + this.responseFactory = responseFactory ?? 
throw new ArgumentNullException(nameof(responseFactory)); + } + + public override bool HasMoreResults => this.feedIterator.HasMoreResults; + + public override async Task> ReadNextAsync(CancellationToken cancellationToken = default) + { + ResponseMessage responseMessage; + + if (typeof(T) == typeof(DecryptableItem)) + { + IReadOnlyCollection resource; + (responseMessage, resource) = await this.feedIterator.ReadNextWithoutDecryptionAsync(cancellationToken); + + return DecryptableFeedResponse.CreateResponse( + responseMessage, + resource); + } + else + { + responseMessage = await this.feedIterator.ReadNextAsync(cancellationToken); + } + + return this.responseFactory.CreateItemFeedResponse(responseMessage); + } + } +} +#endif \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/StreamProcessing/EncryptionTransactionalBatchStream.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/StreamProcessing/EncryptionTransactionalBatchStream.cs new file mode 100644 index 0000000000..0f6df9cb0c --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/StreamProcessing/EncryptionTransactionalBatchStream.cs @@ -0,0 +1,293 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+//------------------------------------------------------------ + +#if ENCRYPTION_CUSTOM_PREVIEW && NET8_0_OR_GREATER +namespace Microsoft.Azure.Cosmos.Encryption.Custom +{ + using System; + using System.Collections.Generic; + using System.IO; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.Cosmos; + + [System.Diagnostics.CodeAnalysis.SuppressMessage("Usage", "VSTHRD002:Avoid problematic synchronous waits", Justification = "To be fixed, tracked in issue #1575")] + internal sealed class EncryptionTransactionalBatchStream : TransactionalBatch + { + private readonly Encryptor encryptor; + private readonly CosmosSerializer cosmosSerializer; + private readonly StreamManager streamManager; + + private TransactionalBatch transactionalBatch; + + public EncryptionTransactionalBatchStream( + TransactionalBatch transactionalBatch, + Encryptor encryptor, + CosmosSerializer cosmosSerializer, + StreamManager streamManager) + { + this.transactionalBatch = transactionalBatch ?? throw new ArgumentNullException(nameof(transactionalBatch)); + this.encryptor = encryptor ?? throw new ArgumentNullException(nameof(encryptor)); + this.cosmosSerializer = cosmosSerializer ?? throw new ArgumentNullException(nameof(cosmosSerializer)); + this.streamManager = streamManager ?? 
throw new ArgumentNullException(nameof(streamManager)); + } + + public override TransactionalBatch CreateItem( + T item, + TransactionalBatchItemRequestOptions requestOptions = null) + { + if (requestOptions is not EncryptionTransactionalBatchItemRequestOptions encryptionItemRequestOptions || + encryptionItemRequestOptions.EncryptionOptions == null) + { + this.transactionalBatch = this.transactionalBatch.CreateItem( + item, + requestOptions); + + return this; + } + +#if SDKPROJECTREF + using Stream itemStream = this.streamManager.CreateStream(); + this.cosmosSerializer.ToStreamAsync(item, itemStream, CancellationToken.None).GetAwaiter().GetResult(); +#else + Stream itemStream = this.cosmosSerializer.ToStream(item); +#endif + return this.CreateItemStream( + itemStream, + requestOptions); + } + + public override TransactionalBatch CreateItemStream( + Stream streamPayload, + TransactionalBatchItemRequestOptions requestOptions = null) + { + if (requestOptions is EncryptionTransactionalBatchItemRequestOptions encryptionItemRequestOptions && + encryptionItemRequestOptions.EncryptionOptions != null) + { + CosmosDiagnosticsContext diagnosticsContext = CosmosDiagnosticsContext.Create(requestOptions); + using (diagnosticsContext.CreateScope("EncryptItemStream")) + { + Stream temp = this.streamManager.CreateStream(); + EncryptionProcessor.EncryptAsync( + streamPayload, + temp, + this.encryptor, + encryptionItemRequestOptions.EncryptionOptions, + diagnosticsContext, + cancellationToken: default).GetAwaiter().GetResult(); + streamPayload = temp; + } + } + + this.transactionalBatch = this.transactionalBatch.CreateItemStream( + streamPayload, + requestOptions); + + return this; + } + + public override TransactionalBatch DeleteItem( + string id, + TransactionalBatchItemRequestOptions requestOptions = null) + { + this.transactionalBatch = this.transactionalBatch.DeleteItem( + id, + requestOptions); + + return this; + } + + public override TransactionalBatch ReadItem( + string id, + 
TransactionalBatchItemRequestOptions requestOptions = null) + { + this.transactionalBatch = this.transactionalBatch.ReadItem( + id, + requestOptions); + + return this; + } + + public override TransactionalBatch ReplaceItem( + string id, + T item, + TransactionalBatchItemRequestOptions requestOptions = null) + { + if (requestOptions is not EncryptionTransactionalBatchItemRequestOptions encryptionItemRequestOptions || + encryptionItemRequestOptions.EncryptionOptions == null) + { + this.transactionalBatch = this.transactionalBatch.ReplaceItem( + id, + item, + requestOptions); + + return this; + } +#if SDKPROJECTREF + using Stream itemStream = this.streamManager.CreateStream(); + this.cosmosSerializer.ToStreamAsync(item, itemStream, CancellationToken.None).GetAwaiter().GetResult(); +#else + Stream itemStream = this.cosmosSerializer.ToStream(item); +#endif + return this.ReplaceItemStream( + id, + itemStream, + requestOptions); + } + + public override TransactionalBatch ReplaceItemStream( + string id, + Stream streamPayload, + TransactionalBatchItemRequestOptions requestOptions = null) + { + if (requestOptions is EncryptionTransactionalBatchItemRequestOptions encryptionItemRequestOptions && + encryptionItemRequestOptions.EncryptionOptions != null) + { + CosmosDiagnosticsContext diagnosticsContext = CosmosDiagnosticsContext.Create(requestOptions); + Stream temp = this.streamManager.CreateStream(); + EncryptionProcessor.EncryptAsync( + streamPayload, + temp, + this.encryptor, + encryptionItemRequestOptions.EncryptionOptions, + diagnosticsContext, + cancellationToken: default).GetAwaiter().GetResult(); + streamPayload = temp; + } + + this.transactionalBatch = this.transactionalBatch.ReplaceItemStream( + id, + streamPayload, + requestOptions); + + return this; + } + + public override TransactionalBatch UpsertItem( + T item, + TransactionalBatchItemRequestOptions requestOptions = null) + { + if (requestOptions is not EncryptionTransactionalBatchItemRequestOptions 
encryptionItemRequestOptions || + encryptionItemRequestOptions.EncryptionOptions == null) + { + this.transactionalBatch = this.transactionalBatch.UpsertItem( + item, + requestOptions); + + return this; + } + +#if SDKPROJECTREF + using Stream itemStream = this.streamManager.CreateStream(); + this.cosmosSerializer.ToStreamAsync(item, itemStream, CancellationToken.None).GetAwaiter().GetResult(); +#else + Stream itemStream = this.cosmosSerializer.ToStream(item); +#endif + return this.UpsertItemStream( + itemStream, + requestOptions); + } + + public override TransactionalBatch UpsertItemStream( + Stream streamPayload, + TransactionalBatchItemRequestOptions requestOptions = null) + { + if (requestOptions is EncryptionTransactionalBatchItemRequestOptions encryptionItemRequestOptions && + encryptionItemRequestOptions.EncryptionOptions != null) + { + CosmosDiagnosticsContext diagnosticsContext = CosmosDiagnosticsContext.Create(requestOptions); + using (diagnosticsContext.CreateScope("EncryptItemStream")) + { + Stream temp = this.streamManager.CreateStream(); + EncryptionProcessor.EncryptAsync( + streamPayload, + temp, + this.encryptor, + encryptionItemRequestOptions.EncryptionOptions, + diagnosticsContext, + cancellationToken: default).GetAwaiter().GetResult(); + streamPayload = temp; + } + } + + this.transactionalBatch = this.transactionalBatch.UpsertItemStream( + streamPayload, + requestOptions); + + return this; + } + + public override async Task ExecuteAsync( + CancellationToken cancellationToken = default) + { + CosmosDiagnosticsContext diagnosticsContext = CosmosDiagnosticsContext.Create(options: null); + using (diagnosticsContext.CreateScope("TransactionalBatch.ExecuteAsync")) + { + TransactionalBatchResponse response = await this.transactionalBatch.ExecuteAsync(cancellationToken); + return await this.DecryptTransactionalBatchResponseAsync( + response, + diagnosticsContext, + cancellationToken); + } + } + + public override async Task ExecuteAsync( + 
TransactionalBatchRequestOptions requestOptions, + CancellationToken cancellationToken = default) + { + CosmosDiagnosticsContext diagnosticsContext = CosmosDiagnosticsContext.Create(options: null); + using (diagnosticsContext.CreateScope("TransactionalBatch.ExecuteAsync.WithRequestOptions")) + { + TransactionalBatchResponse response = await this.transactionalBatch.ExecuteAsync(requestOptions, cancellationToken); + return await this.DecryptTransactionalBatchResponseAsync( + response, + diagnosticsContext, + cancellationToken); + } + } + + private async Task DecryptTransactionalBatchResponseAsync( + TransactionalBatchResponse response, + CosmosDiagnosticsContext diagnosticsContext, + CancellationToken cancellationToken) + { + List decryptedTransactionalBatchOperationResults = new (); + + foreach (TransactionalBatchOperationResult result in response) + { + if (response.IsSuccessStatusCode && result.ResourceStream != null) + { + Stream decryptedStream = this.streamManager.CreateStream(); + _ = await EncryptionProcessor.DecryptAsync( + result.ResourceStream, + decryptedStream, + this.encryptor, + diagnosticsContext, + JsonProcessor.Stream, + cancellationToken); + + decryptedTransactionalBatchOperationResults.Add(new EncryptionTransactionalBatchOperationResult(result, decryptedStream)); + } + else + { + decryptedTransactionalBatchOperationResults.Add(result); + } + } + + return new EncryptionTransactionalBatchResponse( + decryptedTransactionalBatchOperationResults, + response, + this.cosmosSerializer); + } + + public override TransactionalBatch PatchItem( + string id, + IReadOnlyList patchOperations, + TransactionalBatchPatchItemRequestOptions requestOptions = null) + { + throw new NotImplementedException(); + } + } +} +#endif \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/ArrayStreamProcessor.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/ArrayStreamProcessor.cs new file mode 100644 index 
0000000000..8c017ed0ce --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/ArrayStreamProcessor.cs @@ -0,0 +1,220 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ------------------------------------------------------------ + +#if ENCRYPTION_CUSTOM_PREVIEW && NET8_0_OR_GREATER +namespace Microsoft.Azure.Cosmos.Encryption.Custom.Transformation +{ + using System; + using System.Buffers; + using System.IO; + using System.Text; + using System.Text.Json; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.Cosmos.Encryption.Custom.RecyclableMemoryStreamMirror; + + internal class ArrayStreamProcessor + { + internal int InitialBufferSize { get; set; } = 16384; + + private static readonly JsonReaderOptions JsonReaderOptions = new () { AllowTrailingCommas = true, CommentHandling = JsonCommentHandling.Skip }; + + private static readonly ReadOnlyMemory DocumentsPropertyUtf8Bytes; + + static ArrayStreamProcessor() + { + DocumentsPropertyUtf8Bytes = new Memory(Encoding.UTF8.GetBytes(Constants.DocumentsResourcePropertyName)); + } + + internal async Task DeserializeAndDecryptCollectionAsync( + Stream input, + Stream output, + Encryptor encryptor, + StreamManager manager, + CancellationToken cancellationToken) + { + Stream readStream = input; + if (!input.CanSeek) + { + Stream temp = manager.CreateStream(); + await input.CopyToAsync(temp, cancellationToken); + temp.Position = 0; + readStream = temp; + } + + using ArrayPoolManager arrayPoolManager = new (); + using Utf8JsonWriter writer = new (output); + + byte[] buffer = arrayPoolManager.Rent(this.InitialBufferSize); + + Utf8JsonWriter chunkWriter = null; + + int leftOver = 0; + bool isFinalBlock = false; + bool isDocumentsArray = false; + RecyclableMemoryStream bufferWriter = null; + bool isDocumentsProperty = false; + + RecyclableMemoryStreamManager recyclableMemoryStreamManager = new (); 
+ + JsonReaderState state = new (ArrayStreamProcessor.JsonReaderOptions); + + while (!isFinalBlock) + { + int dataLength = await readStream.ReadAsync(buffer.AsMemory(leftOver, buffer.Length - leftOver), cancellationToken); + int dataSize = dataLength + leftOver; + isFinalBlock = dataSize == 0; + + if (isFinalBlock) + { + break; + } + + long bytesConsumed = 0; + + bytesConsumed = this.TransformBuffer( + buffer.AsSpan(0, dataSize), + isFinalBlock, + writer, + ref bufferWriter, + ref chunkWriter, + ref state, + ref isDocumentsProperty, + ref isDocumentsArray, + arrayPoolManager, + encryptor, + manager, + recyclableMemoryStreamManager); + + leftOver = dataSize - (int)bytesConsumed; + + // we need to scale out buffer + if (leftOver == dataSize) + { + byte[] newBuffer = arrayPoolManager.Rent(buffer.Length * 2); + buffer.AsSpan().CopyTo(newBuffer); + buffer = newBuffer; + } + else if (leftOver != 0) + { + buffer.AsSpan(dataSize - leftOver, leftOver).CopyTo(buffer); + } + } + + writer.Flush(); + + await readStream.DisposeAsync(); + output.Position = 0; + } + + private long TransformBuffer(Span buffer, bool isFinalBlock, Utf8JsonWriter writer, ref RecyclableMemoryStream bufferWriter, ref Utf8JsonWriter chunkWriter, ref JsonReaderState state, ref bool isDocumentsProperty, ref bool isDocumentsArray, ArrayPoolManager arrayPoolManager, Encryptor encryptor, StreamManager streamManager, RecyclableMemoryStreamManager manager) + { + Utf8JsonReader reader = new Utf8JsonReader(buffer, isFinalBlock, state); + + while (reader.Read()) + { + Utf8JsonWriter currentWriter = chunkWriter ?? 
writer; + + JsonTokenType tokenType = reader.TokenType; + + switch (tokenType) + { + case JsonTokenType.None: + break; + case JsonTokenType.StartObject: + if (isDocumentsArray && chunkWriter == null) + { + bufferWriter = new RecyclableMemoryStream(manager); + chunkWriter = new Utf8JsonWriter((IBufferWriter)bufferWriter); + chunkWriter.WriteStartObject(); + } + else + { + currentWriter.WriteStartObject(); + } + + break; + case JsonTokenType.EndObject: + currentWriter.WriteEndObject(); + if (reader.CurrentDepth == 2 && chunkWriter != null) + { + currentWriter.Flush(); + Stream transformStream = streamManager.CreateStream(); + bufferWriter.Position = 0; +#pragma warning disable VSTHRD002 // Avoid problematic synchronous waits - we cannot make this call async + _ = EncryptionProcessor.DecryptAsync(bufferWriter, transformStream, encryptor, new CosmosDiagnosticsContext(), JsonProcessor.Stream, CancellationToken.None).GetAwaiter().GetResult(); +#pragma warning restore VSTHRD002 // Avoid problematic synchronous waits + + byte[] copyBuffer = arrayPoolManager.Rent(16384); + Span copySpan = copyBuffer.AsSpan(); + int readBytes = 16384; + while (readBytes == 16384) + { + readBytes = transformStream.Read(copySpan); + + if (readBytes > 0) + { + writer.WriteRawValue(copySpan[..readBytes], false); + } + } + + transformStream.Dispose(); + chunkWriter.Dispose(); + bufferWriter.Dispose(); + chunkWriter = null; + } + + break; + case JsonTokenType.StartArray: + if (isDocumentsProperty && reader.CurrentDepth == 1) + { + isDocumentsArray = true; + } + + currentWriter.WriteStartArray(); + break; + + case JsonTokenType.EndArray: + currentWriter.WriteEndArray(); + if (isDocumentsArray && reader.CurrentDepth == 1) + { + isDocumentsArray = false; + isDocumentsProperty = false; + } + + break; + + case JsonTokenType.PropertyName: + if (chunkWriter == null && reader.ValueTextEquals(DocumentsPropertyUtf8Bytes.Span)) + { + isDocumentsProperty = true; + } + + 
currentWriter.WritePropertyName(reader.ValueSpan); + break; + case JsonTokenType.String: + if (!reader.ValueIsEscaped) + { + currentWriter.WriteStringValue(reader.ValueSpan); + } + else + { + byte[] temp = arrayPoolManager.Rent(reader.ValueSpan.Length); + int tempBytes = reader.CopyString(temp); + currentWriter.WriteStringValue(temp.AsSpan(0, tempBytes)); + } + + break; + default: + currentWriter.WriteRawValue(reader.ValueSpan, true); + break; + } + } + + state = reader.CurrentState; + return reader.BytesConsumed; + } + } +} +#endif \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/ArrayStreamSplitter.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/ArrayStreamSplitter.cs new file mode 100644 index 0000000000..adc697f9d2 --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/ArrayStreamSplitter.cs @@ -0,0 +1,197 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// ------------------------------------------------------------ + +#if ENCRYPTION_CUSTOM_PREVIEW && NET8_0_OR_GREATER +namespace Microsoft.Azure.Cosmos.Encryption.Custom.Transformation +{ + using System; + using System.Buffers; + using System.Collections.Generic; + using System.IO; + using System.Text; + using System.Text.Json; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.Cosmos.Encryption.Custom.RecyclableMemoryStreamMirror; + + internal class ArrayStreamSplitter + { + internal int InitialBufferSize { get; set; } = 16384; + + private static readonly JsonReaderOptions JsonReaderOptions = new () { AllowTrailingCommas = true, CommentHandling = JsonCommentHandling.Skip }; + + private static readonly ReadOnlyMemory<byte> DocumentsPropertyUtf8Bytes; + + static ArrayStreamSplitter() + { + DocumentsPropertyUtf8Bytes = new Memory<byte>(Encoding.UTF8.GetBytes(Constants.DocumentsResourcePropertyName)); + } + + internal async Task<List<RecyclableMemoryStream>> SplitCollectionAsync( + Stream input, + StreamManager manager, + CancellationToken cancellationToken) + { + Stream readStream = input; + if (!input.CanSeek) + { + Stream temp = manager.CreateStream(); + await input.CopyToAsync(temp, cancellationToken); + temp.Position = 0; + readStream = temp; + } + + using ArrayPoolManager arrayPoolManager = new (); + + byte[] buffer = arrayPoolManager.Rent(this.InitialBufferSize); + + Utf8JsonWriter chunkWriter = null; + + int leftOver = 0; + bool isFinalBlock = false; + bool isDocumentsArray = false; + RecyclableMemoryStream bufferWriter = null; + bool isDocumentsProperty = false; + + RecyclableMemoryStreamManager recyclableMemoryStreamManager = new (); + + JsonReaderState state = new (ArrayStreamSplitter.JsonReaderOptions); + List<RecyclableMemoryStream> outputList = new List<RecyclableMemoryStream>(); + + while (!isFinalBlock) + { + int dataLength = await readStream.ReadAsync(buffer.AsMemory(leftOver, buffer.Length - leftOver), cancellationToken); + int dataSize = dataLength + leftOver; + isFinalBlock = dataSize == 0; + if 
(isFinalBlock) + { + break; + } + + long bytesConsumed = 0; + + bytesConsumed = this.TransformBuffer( + buffer.AsSpan(0, dataSize), + outputList, + isFinalBlock, + ref bufferWriter, + ref chunkWriter, + ref state, + ref isDocumentsProperty, + ref isDocumentsArray, + recyclableMemoryStreamManager, + arrayPoolManager); + + leftOver = dataSize - (int)bytesConsumed; + + // we need to scale out buffer + if (leftOver == dataSize) + { + byte[] newBuffer = arrayPoolManager.Rent(buffer.Length * 2); + buffer.AsSpan().CopyTo(newBuffer); + buffer = newBuffer; + } + else if (leftOver != 0) + { + buffer.AsSpan(dataSize - leftOver, leftOver).CopyTo(buffer); + } + } + + await readStream.DisposeAsync(); + + return outputList; + } + + private long TransformBuffer(Span<byte> buffer, List<RecyclableMemoryStream> outputList, bool isFinalBlock, ref RecyclableMemoryStream bufferWriter, ref Utf8JsonWriter chunkWriter, ref JsonReaderState state, ref bool isDocumentsProperty, ref bool isDocumentsArray, RecyclableMemoryStreamManager manager, ArrayPoolManager arrayPoolManager) + { + Utf8JsonReader reader = new Utf8JsonReader(buffer, isFinalBlock, state); + + while (reader.Read()) + { + JsonTokenType tokenType = reader.TokenType; + + switch (tokenType) + { + case JsonTokenType.None: + break; + case JsonTokenType.StartObject: + if (isDocumentsArray && chunkWriter == null) + { + bufferWriter = new RecyclableMemoryStream(manager); + chunkWriter = new Utf8JsonWriter((IBufferWriter<byte>)bufferWriter); + } + + chunkWriter?.WriteStartObject(); + + break; + case JsonTokenType.EndObject: + chunkWriter?.WriteEndObject(); + if (reader.CurrentDepth == 2 && chunkWriter != null) + { + chunkWriter.Flush(); + bufferWriter.Position = 0; + outputList.Add(bufferWriter); + + bufferWriter = null; + + chunkWriter.Dispose(); + chunkWriter = null; + } + + break; + case JsonTokenType.StartArray: + if (isDocumentsProperty && reader.CurrentDepth == 1) + { + isDocumentsArray = true; + } + + chunkWriter?.WriteStartArray(); + break; + + case 
JsonTokenType.EndArray: + chunkWriter?.WriteEndArray(); + if (isDocumentsArray && reader.CurrentDepth == 1) + { + isDocumentsArray = false; + isDocumentsProperty = false; + } + + break; + + case JsonTokenType.PropertyName: + if (reader.ValueTextEquals(DocumentsPropertyUtf8Bytes.Span)) + { + isDocumentsProperty = true; + } + else + { + chunkWriter?.WritePropertyName(reader.ValueSpan); + } + + break; + case JsonTokenType.String: + if (!reader.ValueIsEscaped) + { + chunkWriter?.WriteStringValue(reader.ValueSpan); + } + else + { + byte[] temp = arrayPoolManager.Rent(reader.ValueSpan.Length); + int tempBytes = reader.CopyString(temp); + chunkWriter?.WriteStringValue(temp.AsSpan(0, tempBytes)); + } + + break; + default: + chunkWriter?.WriteRawValue(reader.ValueSpan, true); + break; + } + } + + state = reader.CurrentState; + return reader.BytesConsumed; + } + } +} +#endif \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/MdeEncryptionProcessor.Preview.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/MdeEncryptionProcessor.Preview.cs index 389a709207..ef4041dde1 100644 --- a/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/MdeEncryptionProcessor.Preview.cs +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/MdeEncryptionProcessor.Preview.cs @@ -51,6 +51,28 @@ public async Task EncryptAsync( #endif } +#if ENCRYPTION_CUSTOM_PREVIEW && NET8_0_OR_GREATER + public async Task EncryptStreamAsync( + Stream input, + Stream output, + Encryptor encryptor, + EncryptionOptions encryptionOptions, + CancellationToken token) + { + switch (encryptionOptions.JsonProcessor) + { + case JsonProcessor.Newtonsoft: + await this.JObjectEncryptionProcessor.EncryptStreamAsync(input, output, encryptor, encryptionOptions, token); + break; + case JsonProcessor.Stream: + await this.StreamProcessor.EncryptStreamAsync(input, output, encryptor, encryptionOptions, token); + break; + default: + throw new 
InvalidOperationException($"Unsupported JsonProcessor {encryptionOptions.JsonProcessor}"); + } + } +#endif + internal async Task DecryptObjectAsync( JObject document, Encryptor encryptor, diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/MdeJObjectEncryptionProcessor.Preview.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/MdeJObjectEncryptionProcessor.Preview.cs index a3b60dee57..6ae80136eb 100644 --- a/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/MdeJObjectEncryptionProcessor.Preview.cs +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/MdeJObjectEncryptionProcessor.Preview.cs @@ -26,9 +26,8 @@ public async Task EncryptAsync( EncryptionOptions encryptionOptions, CancellationToken token) { - JObject itemJObj = EncryptionProcessor.BaseSerializer.FromStream(input); - - Stream result = await this.EncryptAsync(itemJObj, encryptor, encryptionOptions, token); + MemoryStream result = new (); + await this.EncryptStreamAsync(input, result, encryptor, encryptionOptions, token); #if NET8_0_OR_GREATER await input.DisposeAsync(); @@ -39,8 +38,21 @@ public async Task EncryptAsync( return result; } - public async Task EncryptAsync( + public async Task EncryptStreamAsync( + Stream input, + Stream output, + Encryptor encryptor, + EncryptionOptions encryptionOptions, + CancellationToken token) + { + JObject itemJObj = EncryptionProcessor.BaseSerializer.FromStream(input, leaveOpen: true); + + await this.EncryptAsync(itemJObj, output, encryptor, encryptionOptions, token); + } + + public async Task EncryptAsync( JObject input, + Stream output, Encryptor encryptor, EncryptionOptions encryptionOptions, CancellationToken token) @@ -115,7 +127,7 @@ public async Task EncryptAsync( input.Add(Constants.EncryptedInfo, JObject.FromObject(encryptionProperties)); - return EncryptionProcessor.BaseSerializer.ToStream(input); + await EncryptionProcessor.BaseSerializer.ToStreamAsync(input, output, token); } internal async Task 
DecryptObjectAsync( diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/StreamProcessor.Decryptor.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/StreamProcessor.Decryptor.cs index d4ca1ccfdf..82d5e713b2 100644 --- a/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/StreamProcessor.Decryptor.cs +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/StreamProcessor.Decryptor.cs @@ -77,6 +77,12 @@ internal async Task DecryptStreamAsync( int dataLength = await inputStream.ReadAsync(buffer.AsMemory(leftOver, buffer.Length - leftOver), cancellationToken); int dataSize = dataLength + leftOver; isFinalBlock = dataSize == 0; + + if (isFinalBlock) + { + break; + } + long bytesConsumed = 0; // processing itself here @@ -125,7 +131,16 @@ long TransformDecryptBuffer(ReadOnlySpan buffer) case JsonTokenType.String: if (decryptPropertyName == null) { - writer.WriteStringValue(reader.ValueSpan); + if (!reader.ValueIsEscaped) + { + writer.WriteStringValue(reader.ValueSpan); + } + else + { + byte[] temp = arrayPoolManager.Rent(reader.ValueSpan.Length); + int tempBytes = reader.CopyString(temp); + writer.WriteStringValue(temp.AsSpan(0, tempBytes)); + } } else { diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/StreamProcessor.Encryptor.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/StreamProcessor.Encryptor.cs index a3980ae56a..1643c26ed0 100644 --- a/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/StreamProcessor.Encryptor.cs +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/src/Transformation/StreamProcessor.Encryptor.cs @@ -6,17 +6,23 @@ namespace Microsoft.Azure.Cosmos.Encryption.Custom.Transformation { using System; + using System.Buffers; using System.Collections.Generic; using System.IO; using System.Text; using System.Text.Json; using System.Threading; using System.Threading.Tasks; + using Microsoft.Azure.Cosmos.Encryption.Custom.RecyclableMemoryStreamMirror; internal 
partial class StreamProcessor { private readonly byte[] encryptionPropertiesNameBytes = Encoding.UTF8.GetBytes(Constants.EncryptedInfo); + private static ReadOnlySpan Utf8Bom => new byte[] { 0xEF, 0xBB, 0xBF }; + + private readonly RecyclableMemoryStreamManager streamManager = new (); + internal async Task EncryptStreamAsync( Stream inputStream, Stream outputStream, @@ -51,18 +57,32 @@ internal async Task EncryptStreamAsync( Utf8JsonWriter encryptionPayloadWriter = null; string encryptPropertyName = null; - RentArrayBufferWriter bufferWriter = null; + + RecyclableMemoryStream bufferWriter = null; + bool firstRead = true; while (!isFinalBlock) { + int offset = 0; int dataLength = await inputStream.ReadAsync(buffer.AsMemory(leftOver, buffer.Length - leftOver), cancellationToken); + if (firstRead && buffer.AsSpan(0, Utf8Bom.Length).StartsWith(Utf8Bom)) + { + offset = Utf8Bom.Length; + } + int dataSize = dataLength + leftOver; isFinalBlock = dataSize == 0; + + if (isFinalBlock) + { + break; + } + long bytesConsumed = 0; - bytesConsumed = TransformEncryptBuffer(buffer.AsSpan(0, dataSize)); + bytesConsumed = TransformEncryptBuffer(buffer.AsSpan(0 + offset, dataSize - offset)); - leftOver = dataSize - (int)bytesConsumed; + leftOver = dataSize - ((int)bytesConsumed + offset); // we need to scale out buffer if (leftOver == dataSize) @@ -74,11 +94,10 @@ internal async Task EncryptStreamAsync( else if (leftOver != 0) { buffer.AsSpan(dataSize - leftOver, leftOver).CopyTo(buffer); + firstRead = false; } } - await inputStream.DisposeAsync(); - EncryptionProperties encryptionProperties = new ( encryptionFormatVersion: compressionEnabled ? 
4 : 3, encryptionOptions.EncryptionAlgorithm, @@ -93,6 +112,7 @@ internal async Task EncryptStreamAsync( writer.WriteEndObject(); writer.Flush(); + inputStream.Position = 0; outputStream.Position = 0; long TransformEncryptBuffer(ReadOnlySpan buffer) @@ -112,8 +132,8 @@ long TransformEncryptBuffer(ReadOnlySpan buffer) case JsonTokenType.StartObject: if (encryptPropertyName != null && encryptionPayloadWriter == null) { - bufferWriter = new RentArrayBufferWriter(); - encryptionPayloadWriter = new Utf8JsonWriter(bufferWriter); + bufferWriter = new RecyclableMemoryStream(this.streamManager); + encryptionPayloadWriter = new Utf8JsonWriter((IBufferWriter)bufferWriter); encryptionPayloadWriter.WriteStartObject(); } else @@ -132,16 +152,17 @@ long TransformEncryptBuffer(ReadOnlySpan buffer) if (reader.CurrentDepth == 1 && encryptionPayloadWriter != null) { currentWriter.Flush(); - (byte[] bytes, int length) = bufferWriter.WrittenBuffer; + byte[] bytes = bufferWriter.GetBuffer(); + int length = (int)bufferWriter.Length; ReadOnlySpan encryptedBytes = TransformEncryptPayload(bytes, length, TypeMarker.Object); writer.WriteBase64StringValue(encryptedBytes); encryptPropertyName = null; #pragma warning disable VSTHRD103 // Call async methods when in an async method - this method cannot be async, Utf8JsonReader is ref struct encryptionPayloadWriter.Dispose(); + bufferWriter.Dispose(); #pragma warning restore VSTHRD103 // Call async methods when in an async method encryptionPayloadWriter = null; - bufferWriter.Dispose(); bufferWriter = null; } @@ -149,8 +170,8 @@ long TransformEncryptBuffer(ReadOnlySpan buffer) case JsonTokenType.StartArray: if (encryptPropertyName != null && encryptionPayloadWriter == null) { - bufferWriter = new RentArrayBufferWriter(); - encryptionPayloadWriter = new Utf8JsonWriter(bufferWriter); + bufferWriter = new RecyclableMemoryStream(this.streamManager); + encryptionPayloadWriter = new Utf8JsonWriter((IBufferWriter)bufferWriter); 
encryptionPayloadWriter.WriteStartArray(); } else @@ -164,16 +185,17 @@ long TransformEncryptBuffer(ReadOnlySpan buffer) if (reader.CurrentDepth == 1 && encryptionPayloadWriter != null) { currentWriter.Flush(); - (byte[] bytes, int length) = bufferWriter.WrittenBuffer; + byte[] bytes = bufferWriter.GetBuffer(); + int length = (int)bufferWriter.Length; ReadOnlySpan encryptedBytes = TransformEncryptPayload(bytes, length, TypeMarker.Array); writer.WriteBase64StringValue(encryptedBytes); encryptPropertyName = null; #pragma warning disable VSTHRD103 // Call async methods when in an async method - this method cannot be async, Utf8JsonReader is ref struct encryptionPayloadWriter.Dispose(); + bufferWriter.Dispose(); #pragma warning restore VSTHRD103 // Call async methods when in an async method encryptionPayloadWriter = null; - bufferWriter.Dispose(); bufferWriter = null; } diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/tests/EmulatorTests/MdeCustomEncryptionTests.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/tests/EmulatorTests/MdeCustomEncryptionTests.cs index 14b1915abb..e6b1b69966 100644 --- a/Microsoft.Azure.Cosmos.Encryption.Custom/tests/EmulatorTests/MdeCustomEncryptionTests.cs +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/tests/EmulatorTests/MdeCustomEncryptionTests.cs @@ -2177,14 +2177,7 @@ public TestEncryptionKeyStoreProvider() public override byte[] UnwrapKey(string masterKeyPath, KeyEncryptionKeyAlgorithm encryptionAlgorithm, byte[] encryptedKey) { - if (!this.UnWrapKeyCallsCount.ContainsKey(masterKeyPath)) - { - this.UnWrapKeyCallsCount[masterKeyPath] = 1; - } - else - { - this.UnWrapKeyCallsCount[masterKeyPath]++; - } + this.UnWrapKeyCallsCount[masterKeyPath] = !this.UnWrapKeyCallsCount.TryGetValue(masterKeyPath, out int value) ? 
1 : ++value; this.keyinfo.TryGetValue(masterKeyPath, out int moveBy); byte[] plainkey = encryptedKey.Select(b => (byte)(b - moveBy)).ToArray(); @@ -2193,14 +2186,7 @@ public override byte[] UnwrapKey(string masterKeyPath, KeyEncryptionKeyAlgorithm public override byte[] WrapKey(string masterKeyPath, KeyEncryptionKeyAlgorithm encryptionAlgorithm, byte[] key) { - if (!this.WrapKeyCallsCount.ContainsKey(masterKeyPath)) - { - this.WrapKeyCallsCount[masterKeyPath] = 1; - } - else - { - this.WrapKeyCallsCount[masterKeyPath]++; - } + this.WrapKeyCallsCount[masterKeyPath] = !this.WrapKeyCallsCount.TryGetValue(masterKeyPath, out int value) ? 1 : ++value; this.keyinfo.TryGetValue(masterKeyPath, out int moveBy); byte[] encryptedkey = key.Select(b => (byte)(b + moveBy)).ToArray(); @@ -2648,14 +2634,7 @@ public override Task UnwrapKeyAsync(byte[] wrappedKey public override Task WrapKeyAsync(byte[] key, EncryptionKeyWrapMetadata metadata, CancellationToken cancellationToken) { - if (!this.WrapKeyCallsCount.ContainsKey(metadata.Value)) - { - this.WrapKeyCallsCount[metadata.Value] = 1; - } - else - { - this.WrapKeyCallsCount[metadata.Value]++; - } + this.WrapKeyCallsCount[metadata.Value] = !this.WrapKeyCallsCount.TryGetValue(metadata.Value, out int value) ? 1 : ++value; EncryptionKeyWrapMetadata responseMetadata = new(metadata.Value + metadataUpdateSuffix); int moveBy = metadata.Value == metadata1.Value ? 1 : 2; diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/tests/EmulatorTests/MdeCustomEncryptionTestsWithSystemText.cs b/Microsoft.Azure.Cosmos.Encryption.Custom/tests/EmulatorTests/MdeCustomEncryptionTestsWithSystemText.cs new file mode 100644 index 0000000000..0b5e0cf4d9 --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/tests/EmulatorTests/MdeCustomEncryptionTestsWithSystemText.cs @@ -0,0 +1,2724 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+//------------------------------------------------------------ + +#if ENCRYPTION_CUSTOM_PREVIEW && NET8_0_OR_GREATER + +namespace Microsoft.Azure.Cosmos.Encryption.Custom.EmulatorTests +{ + using System; + using System.Collections.Generic; + using System.IO; + using System.Linq; + using System.Net; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.Cosmos; + using Microsoft.Azure.Cosmos.Encryption.Custom; + using Microsoft.Data.Encryption.Cryptography; + using Microsoft.VisualStudio.TestTools.UnitTesting; + using Newtonsoft.Json; + using static Microsoft.Azure.Cosmos.Encryption.Custom.EmulatorTests.LegacyEncryptionTests; + using EncryptionKeyWrapMetadata = EncryptionKeyWrapMetadata; + using DataEncryptionKey = DataEncryptionKey; + using Newtonsoft.Json.Linq; + using Microsoft.Azure.Cosmos.Encryption.Custom.EmulatorTests.Utils; + using System.Text.Json.Serialization; + using Microsoft.Azure.Cosmos.Encryption.Custom.StreamProcessing; + + [TestClass] + public class MdeCustomEncryptionTestsWithSystemText + { + private static readonly Uri masterKeyUri1 = new("https://demo.keyvault.net/keys/samplekey1/03ded886623sss09bzc60351e536a111"); + private static readonly Uri masterKeyUri2 = new("https://demo.keyvault.net/keys/samplekey2/47d306aeaaeyyyaabs9467235460dc22"); + private static readonly EncryptionKeyWrapMetadata metadata1 = new(name: "metadata1", value: masterKeyUri1.ToString()); + private static readonly EncryptionKeyWrapMetadata metadata2 = new(name: "metadata2", value: masterKeyUri2.ToString()); + private const string dekId = "mydek"; + private const string legacydekId = "mylegacydek"; + private static CosmosClient client; + private static Database database; + private static DataEncryptionKeyProperties dekProperties; + private static Container itemContainer; + private static Container encryptionContainer; + private static Container itemContainerForChangeFeed; + private static Container encryptionContainerForChangeFeed; + private 
static Container keyContainer; + private static TestEncryptionKeyStoreProvider testKeyStoreProvider; + private static CosmosDataEncryptionKeyProvider dekProvider; + private static TestEncryptor encryptor; + + + private static TestKeyWrapProvider legacytestKeyWrapProvider; + private static CosmosDataEncryptionKeyProvider dualDekProvider; + private const string metadataUpdateSuffix = "updated"; + private static readonly TimeSpan cacheTTL = TimeSpan.FromDays(1); + private static TestEncryptor encryptorWithDualWrapProvider; + + + [ClassInitialize] + public static async Task ClassInitialize(TestContext context) + { + _ = context; + + client = TestCommon.CreateCosmosClient(builder => builder.WithSystemTextJsonSerializerOptions(new System.Text.Json.JsonSerializerOptions())); + database = await client.CreateDatabaseAsync(Guid.NewGuid().ToString()); + keyContainer = await database.CreateContainerAsync(Guid.NewGuid().ToString(), "/id", 400); + itemContainer = await database.CreateContainerAsync(Guid.NewGuid().ToString(), "/PK", 400); + itemContainerForChangeFeed = await database.CreateContainerAsync(Guid.NewGuid().ToString(), "/PK", 400); + + testKeyStoreProvider = new TestEncryptionKeyStoreProvider(); + await LegacyClassInitializeAsync(); + + MdeCustomEncryptionTestsWithSystemText.encryptor = new TestEncryptor(MdeCustomEncryptionTestsWithSystemText.dekProvider); + MdeCustomEncryptionTestsWithSystemText.encryptionContainer = MdeCustomEncryptionTestsWithSystemText.itemContainer.WithEncryptor(encryptor); + MdeCustomEncryptionTestsWithSystemText.encryptionContainerForChangeFeed = MdeCustomEncryptionTestsWithSystemText.itemContainerForChangeFeed.WithEncryptor(encryptor); + + await MdeCustomEncryptionTestsWithSystemText.dekProvider.InitializeAsync(MdeCustomEncryptionTestsWithSystemText.database, MdeCustomEncryptionTestsWithSystemText.keyContainer.Id); + MdeCustomEncryptionTestsWithSystemText.dekProperties = await 
MdeCustomEncryptionTestsWithSystemText.CreateDekAsync(MdeCustomEncryptionTestsWithSystemText.dekProvider, MdeCustomEncryptionTestsWithSystemText.dekId); + } + + [ClassCleanup] + public static async Task ClassCleanup() + { + if (database != null) + { + using (await database.DeleteStreamAsync()) { } + } + + client?.Dispose(); + } + + [TestMethod] + public async Task EncryptionCreateDek() + { + string dekId = "anotherDek"; + DataEncryptionKeyProperties dekProperties = await CreateDekAsync(MdeCustomEncryptionTestsWithSystemText.dekProvider, dekId); + Assert.AreEqual( + new EncryptionKeyWrapMetadata(name: "metadata1", value: metadata1.Value), + dekProperties.EncryptionKeyWrapMetadata); + + // Use different DEK provider to avoid (unintentional) cache impact + CosmosDataEncryptionKeyProvider dekProvider = new(new TestEncryptionKeyStoreProvider()); + await dekProvider.InitializeAsync(database, keyContainer.Id); + DataEncryptionKeyProperties readProperties = await dekProvider.DataEncryptionKeyContainer.ReadDataEncryptionKeyAsync(dekId); + Assert.AreEqual(dekProperties, readProperties); + } + + [TestMethod] + public async Task FetchDataEncryptionKeyWithRawKey() + { + CosmosDataEncryptionKeyProvider dekProvider = new(new TestEncryptionKeyStoreProvider()); + await dekProvider.InitializeAsync(database, keyContainer.Id); + DataEncryptionKey k = await dekProvider.FetchDataEncryptionKeyAsync(dekProperties.Id, dekProperties.EncryptionAlgorithm, CancellationToken.None); + Assert.IsNotNull(k.RawKey); + } + + [TestMethod] + public async Task FetchDataEncryptionKeyWithoutRawKey() + { + CosmosDataEncryptionKeyProvider dekProvider = new(new TestEncryptionKeyStoreProvider()); + await dekProvider.InitializeAsync(database, keyContainer.Id); + DataEncryptionKey k = await dekProvider.FetchDataEncryptionKeyWithoutRawKeyAsync(dekProperties.Id, dekProperties.EncryptionAlgorithm, CancellationToken.None); + Assert.IsNull(k.RawKey); + } + + [TestMethod] + [Obsolete("Obsoleted algorithm")] + public 
async Task FetchDataEncryptionKeyMdeDEKAndLegacyBasedAlgorithm() + { + CosmosDataEncryptionKeyProvider dekProvider = new(new TestEncryptionKeyStoreProvider()); + await dekProvider.InitializeAsync(database, keyContainer.Id); + DataEncryptionKey k = await dekProvider.FetchDataEncryptionKeyAsync(dekProperties.Id, CosmosEncryptionAlgorithm.AEAes256CbcHmacSha256Randomized, CancellationToken.None); + Assert.IsNotNull(k.RawKey); + } + + [TestMethod] + [Obsolete("Obsoleted algorithm")] + public async Task FetchDataEncryptionKeyLegacyDEKAndMdeBasedAlgorithm() + { + string dekId = "legacyDEK"; + DataEncryptionKeyProperties dekProperties = await CreateDekAsync(MdeCustomEncryptionTestsWithSystemText.dekProvider, dekId, CosmosEncryptionAlgorithm.AEAes256CbcHmacSha256Randomized); + // Use different DEK provider to avoid (unintentional) cache impact + CosmosDataEncryptionKeyProvider dekProvider = new(new TestKeyWrapProvider(), new TestEncryptionKeyStoreProvider()); + await dekProvider.InitializeAsync(database, keyContainer.Id); + DataEncryptionKey k = await dekProvider.FetchDataEncryptionKeyAsync(dekProperties.Id, CosmosEncryptionAlgorithm.MdeAeadAes256CbcHmac256Randomized, CancellationToken.None); + Assert.IsNotNull(k.RawKey); + } + + [TestMethod] + public async Task EncryptionRewrapDek() + { + string dekId = "randomDek"; + + DataEncryptionKeyProperties dekProperties = await CreateDekAsync(MdeCustomEncryptionTestsWithSystemText.dekProvider, dekId); + Assert.AreEqual( + metadata1, + dekProperties.EncryptionKeyWrapMetadata); + + ItemResponse dekResponse = await MdeCustomEncryptionTestsWithSystemText.dekProvider.DataEncryptionKeyContainer.RewrapDataEncryptionKeyAsync( + dekId, + metadata2); + + Assert.AreEqual(HttpStatusCode.OK, dekResponse.StatusCode); + dekProperties = VerifyDekResponse( + dekResponse, + dekId); + Assert.AreEqual( + metadata2, + dekProperties.EncryptionKeyWrapMetadata); + + // Use different DEK provider to avoid (unintentional) cache impact + 
CosmosDataEncryptionKeyProvider dekProvider = new(new TestEncryptionKeyStoreProvider()); + await dekProvider.InitializeAsync(database, keyContainer.Id); + DataEncryptionKeyProperties readProperties = await dekProvider.DataEncryptionKeyContainer.ReadDataEncryptionKeyAsync(dekId); + Assert.AreEqual(dekProperties, readProperties); + } + + [TestMethod] + public async Task EncryptionRewrapDekEtagMismatch() + { + string dekId = "dummyDek"; + EncryptionKeyWrapMetadata newMetadata = new(name: "newMetadata", value: "newMetadataValue"); + + DataEncryptionKeyProperties dekProperties = await CreateDekAsync(MdeCustomEncryptionTestsWithSystemText.dekProvider, dekId); + Assert.AreEqual( + metadata1, + dekProperties.EncryptionKeyWrapMetadata); + + // modify dekProperties directly, which would lead to etag change + DataEncryptionKeyProperties updatedDekProperties = new( + dekProperties.Id, + dekProperties.EncryptionAlgorithm, + dekProperties.WrappedDataEncryptionKey, + dekProperties.EncryptionKeyWrapMetadata, + DateTime.UtcNow); + await keyContainer.ReplaceItemAsync( + updatedDekProperties, + dekProperties.Id, + new PartitionKey(dekProperties.Id)); + + // rewrap should succeed, despite difference in cached value + ItemResponse dekResponse = await MdeCustomEncryptionTestsWithSystemText.dekProvider.DataEncryptionKeyContainer.RewrapDataEncryptionKeyAsync( + dekId, + newMetadata); + + Assert.AreEqual(HttpStatusCode.OK, dekResponse.StatusCode); + dekProperties = VerifyDekResponse( + dekResponse, + dekId); + Assert.AreEqual( + newMetadata, + dekProperties.EncryptionKeyWrapMetadata); + + Assert.AreEqual(2, testKeyStoreProvider.WrapKeyCallsCount[newMetadata.Value]); + + // Use different DEK provider to avoid (unintentional) cache impact + CosmosDataEncryptionKeyProvider dekProvider = new(new TestEncryptionKeyStoreProvider()); + await dekProvider.InitializeAsync(database, keyContainer.Id); + DataEncryptionKeyProperties readProperties = await 
dekProvider.DataEncryptionKeyContainer.ReadDataEncryptionKeyAsync(dekId); + Assert.AreEqual(dekProperties, readProperties); + } + + [TestMethod] + public async Task EncryptionDekReadFeed() + { + Container newKeyContainer = await database.CreateContainerAsync(Guid.NewGuid().ToString(), "/id", 400); + try + { + CosmosDataEncryptionKeyProvider dekProvider = new(new TestEncryptionKeyStoreProvider()); + await dekProvider.InitializeAsync(database, newKeyContainer.Id); + + string contosoV1 = "Contoso_v001"; + string contosoV2 = "Contoso_v002"; + string fabrikamV1 = "Fabrikam_v001"; + string fabrikamV2 = "Fabrikam_v002"; + + await CreateDekAsync(dekProvider, contosoV1); + await CreateDekAsync(dekProvider, contosoV2); + await CreateDekAsync(dekProvider, fabrikamV1); + await CreateDekAsync(dekProvider, fabrikamV2); + + // Test getting all keys + await IterateDekFeedAsync( + dekProvider, + new List { contosoV1, contosoV2, fabrikamV1, fabrikamV2 }, + isExpectedDeksCompleteSetForRequest: true, + isResultOrderExpected: false, + "SELECT * from c"); + + // Test getting specific subset of keys + await IterateDekFeedAsync( + dekProvider, + new List { contosoV2 }, + isExpectedDeksCompleteSetForRequest: false, + isResultOrderExpected: true, + "SELECT TOP 1 * from c where c.id >= 'Contoso_v000' and c.id <= 'Contoso_v999' ORDER BY c.id DESC"); + + // Ensure only required results are returned + await IterateDekFeedAsync( + dekProvider, + new List { contosoV1, contosoV2 }, + isExpectedDeksCompleteSetForRequest: true, + isResultOrderExpected: true, + "SELECT * from c where c.id >= 'Contoso_v000' and c.id <= 'Contoso_v999' ORDER BY c.id ASC"); + + // Test pagination + await IterateDekFeedAsync( + dekProvider, + new List { contosoV1, contosoV2, fabrikamV1, fabrikamV2 }, + isExpectedDeksCompleteSetForRequest: true, + isResultOrderExpected: false, + "SELECT * from c", + itemCountInPage: 3); + } + finally + { + await newKeyContainer.DeleteContainerStreamAsync(); + } + } + + [TestMethod] + 
public async Task EncryptionCreateItemWithoutEncryptionOptions() + { + TestDoc testDoc = TestDoc.Create(); + ItemResponse createResponse = await encryptionContainer.CreateItemAsync( + testDoc, + new PartitionKey(testDoc.PK)); + Assert.AreEqual(HttpStatusCode.Created, createResponse.StatusCode); + VerifyExpectedDocResponse(testDoc, createResponse.Resource); + } + + [TestMethod] + public async Task EncryptionCreateItemWithNullEncryptionOptions() + { + TestDoc testDoc = TestDoc.Create(); + ItemResponse createResponse = await encryptionContainer.CreateItemAsync( + testDoc, + new PartitionKey(testDoc.PK), + new EncryptionItemRequestOptions() { }); + Assert.AreEqual(HttpStatusCode.Created, createResponse.StatusCode); + VerifyExpectedDocResponse(testDoc, createResponse.Resource); + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionCreateItemWithoutPartitionKey(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + TestDoc testDoc = TestDoc.Create(); + try + { + await encryptionContainer.CreateItemAsync( + testDoc, + requestOptions: GetRequestOptions(dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm)); + Assert.Fail("CreateItem should've failed because PartitionKey was not provided."); + } + catch (NotSupportedException ex) + { + Assert.AreEqual("partitionKey cannot be null for operations using EncryptionContainer.", ex.Message); + } + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionFailsWithUnknownDek(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + string unknownDek = "unknownDek"; + + try + { + await CreateItemAsync(encryptionContainer, unknownDek, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + } + catch (ArgumentException ex) + { + Assert.AreEqual($"Failed to retrieve Data Encryption Key with id: '{unknownDek}'.", ex.Message); + 
Assert.IsTrue(ex.InnerException is CosmosException); + } + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task ValidateCachingOfProtectedDataEncryptionKey(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + TestEncryptionKeyStoreProvider testEncryptionKeyStoreProvider = new() + { + DataEncryptionKeyCacheTimeToLive = TimeSpan.FromMinutes(30) + }; + + string dekId = "pDekCache"; + DataEncryptionKeyProperties dekProperties = await CreateDekAsync(dualDekProvider, dekId); + Assert.AreEqual( + new EncryptionKeyWrapMetadata(name: "metadata1", value: metadata1.Value), + dekProperties.EncryptionKeyWrapMetadata); + + // Caching for 30 min. + CosmosDataEncryptionKeyProvider dekProvider = new(testEncryptionKeyStoreProvider); + await dekProvider.InitializeAsync(database, keyContainer.Id); + + TestEncryptor encryptor = new(dekProvider); + Container encryptionContainer = itemContainer.WithEncryptor(encryptor); + for (int i = 0; i < 2; i++) + await CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + + testEncryptionKeyStoreProvider.UnWrapKeyCallsCount.TryGetValue(masterKeyUri1.ToString(), out int unwrapcount); + Assert.AreEqual(1, unwrapcount); + + testEncryptionKeyStoreProvider = new TestEncryptionKeyStoreProvider + { + DataEncryptionKeyCacheTimeToLive = TimeSpan.Zero + }; + + // No caching + dekProvider = new CosmosDataEncryptionKeyProvider(testEncryptionKeyStoreProvider); + await dekProvider.InitializeAsync(database, keyContainer.Id); + + encryptor = new TestEncryptor(dekProvider); + encryptionContainer = itemContainer.WithEncryptor(encryptor); + for (int i = 0; i < 2; i++) + await CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + + testEncryptionKeyStoreProvider.UnWrapKeyCallsCount.TryGetValue(masterKeyUri1.ToString(), out unwrapcount); + Assert.AreEqual(4, unwrapcount); + + 
// 2 hours default + testEncryptionKeyStoreProvider = new TestEncryptionKeyStoreProvider(); + + dekProvider = new CosmosDataEncryptionKeyProvider(testEncryptionKeyStoreProvider); + await dekProvider.InitializeAsync(database, keyContainer.Id); + + encryptor = new TestEncryptor(dekProvider); + encryptionContainer = itemContainer.WithEncryptor(encryptor); + for (int i = 0; i < 2; i++) + await CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + + testEncryptionKeyStoreProvider.UnWrapKeyCallsCount.TryGetValue(masterKeyUri1.ToString(), out unwrapcount); + Assert.AreEqual(1, unwrapcount); + + await dekProvider.Container.DeleteItemAsync< DataEncryptionKeyProperties>(dekId, new PartitionKey(dekId)); + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionReadManyItemAsync(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + TestDoc testDoc = await CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + + TestDoc testDoc2 = await CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + + List<(string, PartitionKey)> itemList = new() + { + (testDoc.Id, new PartitionKey(testDoc.PK)), + (testDoc2.Id, new PartitionKey(testDoc2.PK)) + }; + + FeedResponse response = await encryptionContainer.ReadManyItemsAsync(itemList); + + Assert.AreEqual(HttpStatusCode.OK, response.StatusCode); + Assert.AreEqual(2, response.Count); + VerifyExpectedDocResponse(testDoc, response.Resource.ElementAt(0)); + VerifyExpectedDocResponse(testDoc2, response.Resource.ElementAt(1)); + + // stream test. 
+ ResponseMessage responseStream = await encryptionContainer.ReadManyItemsStreamAsync(itemList); + + Assert.IsTrue(responseStream.IsSuccessStatusCode); + Assert.AreEqual(HttpStatusCode.OK, response.StatusCode); + + JObject contentJObjects = TestCommon.FromStream(responseStream.Content); + + if (contentJObjects.SelectToken(Constants.DocumentsResourcePropertyName) is JArray documents) + { + VerifyExpectedDocResponse(testDoc, documents.ElementAt(0).ToObject()); + VerifyExpectedDocResponse(testDoc2, documents.ElementAt(1).ToObject()); + } + else + { + Assert.Fail("ResponseMessage from ReadManyItemsStreamAsync did not have a valid response. "); + } + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionCreateItem(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + TestDoc testDoc = await CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + + await VerifyItemByReadAsync(encryptionContainer, testDoc); + + await VerifyItemByReadStreamAsync(encryptionContainer, testDoc); + + TestDoc expectedDoc = new(testDoc); + + // Read feed (null query) + await MdeCustomEncryptionTestsWithSystemText.ValidateQueryResultsAsync( + MdeCustomEncryptionTestsWithSystemText.encryptionContainer, + query: null, + expectedDoc); + + await ValidateQueryResultsAsync( + encryptionContainer, + "SELECT * FROM c", + expectedDoc); + + await ValidateQueryResultsAsync( + encryptionContainer, + string.Format( + "SELECT * FROM c where c.PK = '{0}' and c.id = '{1}' and c.NonSensitive = '{2}'", + expectedDoc.PK, + expectedDoc.Id, + expectedDoc.NonSensitive), + expectedDoc); + + await ValidateQueryResultsAsync( + encryptionContainer, + string.Format("SELECT * FROM c where c.Sensitive_IntFormat = '{0}'", testDoc.Sensitive_IntFormat), + expectedDoc: null); + + await ValidateQueryResultsAsync( + encryptionContainer, + queryDefinition: new QueryDefinition( + 
"select * from c where c.id = @theId and c.PK = @thePK") + .WithParameter("@theId", expectedDoc.Id) + .WithParameter("@thePK", expectedDoc.PK), + expectedDoc: expectedDoc); + + expectedDoc.Sensitive_NestedObjectFormatL1 = null; + expectedDoc.Sensitive_ArrayFormat = null; + expectedDoc.Sensitive_DecimalFormat = 0; + expectedDoc.Sensitive_IntFormat = 0; + expectedDoc.Sensitive_FloatFormat = 0; + expectedDoc.Sensitive_BoolFormat = false; + expectedDoc.Sensitive_StringFormat = null; + expectedDoc.Sensitive_DateFormat = new DateTime(); + + await ValidateQueryResultsAsync( + encryptionContainer, + "SELECT c.id, c.PK, c.NonSensitive FROM c", + expectedDoc); + } + + [TestMethod] + [ExpectedException(typeof(InvalidOperationException), "Decryptable content is not initialized.")] + public void ValidateDecryptableContent() + { + TestDoc testDoc = TestDoc.Create(); + EncryptableItem encryptableItem = new(testDoc); + encryptableItem.DecryptableItem.GetItemAsync(); + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionCreateItemWithLazyDecryption(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + TestDoc testDoc = TestDoc.Create(); + ItemResponse> createResponse = await encryptionContainer.CreateItemAsync( + new EncryptableItemStream(testDoc), + new PartitionKey(testDoc.PK), + GetRequestOptions(dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm)); + + Assert.AreEqual(HttpStatusCode.Created, createResponse.StatusCode); + Assert.IsNotNull(createResponse.Resource); + + await ValidateDecryptableItem(createResponse.Resource.DecryptableItem, testDoc); + + // stream + TestDoc testDoc1 = TestDoc.Create(); + ItemResponse> createResponseStream = await encryptionContainer.CreateItemAsync( + new EncryptableItemStream(TestCommon.ToStream(testDoc1)), + new PartitionKey(testDoc1.PK), + GetRequestOptions(dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm)); + + 
Assert.AreEqual(HttpStatusCode.Created, createResponseStream.StatusCode); + Assert.IsNotNull(createResponseStream.Resource); + + await ValidateDecryptableItem(createResponseStream.Resource.DecryptableItem, testDoc1); + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionChangeFeedDecryptionSuccessful(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + string dek2 = "dek2ForChangeFeed"; + await CreateDekAsync(dekProvider, dek2); + + TestDoc testDoc1 = await CreateItemAsync(encryptionContainerForChangeFeed, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + TestDoc testDoc2 = await CreateItemAsync(encryptionContainerForChangeFeed, dek2, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + + // change feed iterator + await ValidateChangeFeedIteratorResponse(encryptionContainerForChangeFeed, testDoc1, testDoc2); + + // change feed processor + await ValidateChangeFeedProcessorResponse(encryptionContainerForChangeFeed, testDoc1, testDoc2); + + // change feed processor with feed handler + await ValidateChangeFeedProcessorWithFeedHandlerResponse(encryptionContainerForChangeFeed, testDoc1, testDoc2); + + // change feed processor with manual checkpoint + await ValidateChangeFeedProcessorWithManualCheckpointResponse(encryptionContainerForChangeFeed, testDoc1, testDoc2); + + // change feed processor with feed stream handler + await ValidateChangeFeedProcessorWithFeedStreamHandlerResponse(encryptionContainerForChangeFeed, testDoc1, testDoc2); + + // change feed processor manual checkpoint with feed stream handler + await ValidateChangeFeedProcessorStreamWithManualCheckpointResponse(encryptionContainerForChangeFeed, testDoc1, testDoc2); + + await dekProvider.Container.DeleteItemAsync(dek2, new PartitionKey(dek2)); + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task 
EncryptionHandleDecryptionFailure(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + string dek2 = "failDek"; + await CreateDekAsync(dekProvider, dek2); + + TestDoc testDoc1 = await CreateItemAsync(encryptionContainer, dek2, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + TestDoc testDoc2 = await CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + + string query = $"SELECT * FROM c WHERE c.PK in ('{testDoc1.PK}', '{testDoc2.PK}')"; + + // success + await ValidateQueryResultsMultipleDocumentsAsync(encryptionContainer, testDoc1, testDoc2, query); + + // induce failure + encryptor.FailDecryption = true; + + FeedIterator queryResponseIterator = encryptionContainer.GetItemQueryIterator(query); + FeedResponse readDocsLazily = await queryResponseIterator.ReadNextAsync(); + await ValidateLazyDecryptionResponse(readDocsLazily.GetEnumerator(), dek2); + + // validate changeFeed handling + FeedIterator changeIterator = encryptionContainer.GetChangeFeedIterator( + ChangeFeedStartFrom.Beginning(), + ChangeFeedMode.Incremental); + + while (changeIterator.HasMoreResults) + { + readDocsLazily = await changeIterator.ReadNextAsync(); + if (readDocsLazily.StatusCode == HttpStatusCode.NotModified) + { + break; + } + + if (readDocsLazily.Resource != null) + { + await ValidateLazyDecryptionResponse(readDocsLazily.GetEnumerator(), dek2); + } + } + + // validate changeFeedProcessor handling + Container leaseContainer = await database.CreateContainerIfNotExistsAsync( + new ContainerProperties(id: "leasesContainer", partitionKeyPath: "/id")); + + List changeFeedReturnedDocs = new(); + ChangeFeedProcessor cfp = encryptionContainer.GetChangeFeedProcessorBuilder( + "testCFPFailure", + (IReadOnlyCollection changes, CancellationToken cancellationToken) => + { + changeFeedReturnedDocs.AddRange(changes); + return Task.CompletedTask; + }) + .WithInstanceName("dummy") + 
.WithLeaseContainer(leaseContainer) + .WithStartTime(DateTime.MinValue.ToUniversalTime()) + .Build(); + + await cfp.StartAsync(); + await Task.Delay(2000); + await cfp.StopAsync(); + + Assert.IsTrue(changeFeedReturnedDocs.Count >= 2); + await ValidateLazyDecryptionResponse(changeFeedReturnedDocs.GetEnumerator(), dek2); + + encryptor.FailDecryption = false; + + await dekProvider.Container.DeleteItemAsync(dek2, new PartitionKey(dek2)); + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionDecryptQueryResultMultipleDocs(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + TestDoc testDoc1 = await CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + TestDoc testDoc2 = await CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + + // test GetItemLinqQueryable + await ValidateQueryResultsMultipleDocumentsAsync(encryptionContainer, testDoc1, testDoc2, null); + + string query = $"SELECT * FROM c WHERE c.PK in ('{testDoc1.PK}', '{testDoc2.PK}')"; + await ValidateQueryResultsMultipleDocumentsAsync(encryptionContainer, testDoc1, testDoc2, query); + + // ORDER BY query + query += " ORDER BY c._ts"; + await ValidateQueryResultsMultipleDocumentsAsync(encryptionContainer, testDoc1, testDoc2, query); + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionDecryptQueryResultMultipleEncryptedProperties(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + List pathsEncrypted = new() { "/Sensitive_StringFormat", "/NonSensitive" }; + TestDoc testDoc = await CreateItemAsync( + encryptionContainer, + dekId, + pathsEncrypted, + jsonProcessor, + compressionAlgorithm); + + TestDoc expectedDoc = new(testDoc); + + await ValidateQueryResultsAsync( + encryptionContainer, + "SELECT * FROM c", 
+ expectedDoc, + pathsEncrypted: pathsEncrypted); + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionDecryptQueryValueResponse(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + await CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + string query = "SELECT VALUE COUNT(1) FROM c"; + + await ValidateQueryResponseAsync(encryptionContainer, query); + await ValidateQueryResponseWithLazyDecryptionAsync(encryptionContainer, query); + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionDecryptGroupByQueryResultTest(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + string partitionKey = Guid.NewGuid().ToString(); + + await CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm, partitionKey); + await CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm, partitionKey); + + string query = $"SELECT COUNT(c.Id), c.PK " + + $"FROM c WHERE c.PK = '{partitionKey}' " + + $"GROUP BY c.PK "; + + await ValidateQueryResponseAsync(encryptionContainer, query); + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionStreamIteratorValidation(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + await CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + await CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + + // test GetItemLinqQueryable with ToEncryptionStreamIterator extension + await ValidateQueryResponseAsync(encryptionContainer); + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + 
public async Task EncryptionRudItem(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + TestDoc testDoc = await UpsertItemAsync( + encryptionContainer, + TestDoc.Create(), + dekId, + TestDoc.PathsToEncrypt, + HttpStatusCode.Created, + jsonProcessor, + compressionAlgorithm); + + await VerifyItemByReadAsync(encryptionContainer, testDoc); + + testDoc.NonSensitive = Guid.NewGuid().ToString(); + testDoc.Sensitive_StringFormat = Guid.NewGuid().ToString(); + + ItemResponse upsertResponse = await UpsertItemAsync( + encryptionContainer, + testDoc, + dekId, + TestDoc.PathsToEncrypt, + HttpStatusCode.OK, + jsonProcessor, + compressionAlgorithm); + TestDoc updatedDoc = upsertResponse.Resource; + + await VerifyItemByReadAsync(encryptionContainer, updatedDoc); + + updatedDoc.NonSensitive = Guid.NewGuid().ToString(); + updatedDoc.Sensitive_StringFormat = Guid.NewGuid().ToString(); + + TestDoc replacedDoc = await ReplaceItemAsync( + encryptionContainer, + updatedDoc, + dekId, + TestDoc.PathsToEncrypt, + jsonProcessor, + compressionAlgorithm, + upsertResponse.ETag); + + await VerifyItemByReadAsync(encryptionContainer, replacedDoc); + + await DeleteItemAsync(encryptionContainer, replacedDoc); + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionRudItemLazyDecryption(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + TestDoc testDoc = TestDoc.Create(); + // Upsert (item doesn't exist) + ItemResponse> upsertResponse = await encryptionContainer.UpsertItemAsync( + new EncryptableItemStream(testDoc), + new PartitionKey(testDoc.PK), + GetRequestOptions(dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm)); + + Assert.AreEqual(HttpStatusCode.Created, upsertResponse.StatusCode); + Assert.IsNotNull(upsertResponse.Resource); + + await ValidateDecryptableItem(upsertResponse.Resource.DecryptableItem, testDoc); + await 
VerifyItemByReadAsync(encryptionContainer, testDoc); + + // Upsert with stream (item exists) + testDoc.NonSensitive = Guid.NewGuid().ToString(); + testDoc.Sensitive_StringFormat = Guid.NewGuid().ToString(); + + ItemResponse> upsertResponseStream = await encryptionContainer.UpsertItemAsync( + new EncryptableItemStream(TestCommon.ToStream(testDoc)), + new PartitionKey(testDoc.PK), + GetRequestOptions(dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm)); + + Assert.AreEqual(HttpStatusCode.OK, upsertResponseStream.StatusCode); + Assert.IsNotNull(upsertResponseStream.Resource); + + await ValidateDecryptableItem(upsertResponseStream.Resource.DecryptableItem, testDoc); + await VerifyItemByReadAsync(encryptionContainer, testDoc); + + // replace + testDoc.NonSensitive = Guid.NewGuid().ToString(); + testDoc.Sensitive_StringFormat = Guid.NewGuid().ToString(); + + ItemResponse> replaceResponseStream = await encryptionContainer.ReplaceItemAsync( + new EncryptableItemStream(TestCommon.ToStream(testDoc)), + testDoc.Id, + new PartitionKey(testDoc.PK), + GetRequestOptions(dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm, upsertResponseStream.ETag)); + + Assert.AreEqual(HttpStatusCode.OK, replaceResponseStream.StatusCode); + Assert.IsNotNull(replaceResponseStream.Resource); + + await ValidateDecryptableItem(replaceResponseStream.Resource.DecryptableItem, testDoc); + await VerifyItemByReadAsync(encryptionContainer, testDoc); + + await DeleteItemAsync(encryptionContainer, testDoc); + } + + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionResourceTokenAuthRestricted(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + TestDoc testDoc = await CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + + User restrictedUser = database.GetUser(Guid.NewGuid().ToString()); + await 
database.CreateUserAsync(restrictedUser.Id); + + PermissionProperties restrictedUserPermission = await restrictedUser.CreatePermissionAsync( + new PermissionProperties(Guid.NewGuid().ToString(), PermissionMode.All, itemContainer)); + + CosmosDataEncryptionKeyProvider dekProvider = new(new TestEncryptionKeyStoreProvider()); + TestEncryptor encryptor = new(dekProvider); + + CosmosClient clientForRestrictedUser = TestCommon.CreateCosmosClient( + restrictedUserPermission.Token); + + Database databaseForRestrictedUser = clientForRestrictedUser.GetDatabase(database.Id); + Container containerForRestrictedUser = databaseForRestrictedUser.GetContainer(itemContainer.Id); + + Container encryptionContainerForRestrictedUser = containerForRestrictedUser.WithEncryptor(encryptor); + + await PerformForbiddenOperationAsync(() => + dekProvider.InitializeAsync(databaseForRestrictedUser, keyContainer.Id), "CosmosDekProvider.InitializeAsync"); + + await PerformOperationOnUninitializedDekProviderAsync(() => + dekProvider.DataEncryptionKeyContainer.ReadDataEncryptionKeyAsync(dekId), "DEK.ReadAsync"); + + try + { + await encryptionContainerForRestrictedUser.ReadItemAsync(testDoc.Id, new PartitionKey(testDoc.PK)); + } + catch (InvalidOperationException ex) + { + Assert.AreEqual(ex.Message, "The CosmosDataEncryptionKeyProvider was not initialized."); + } + + try + { + await encryptionContainerForRestrictedUser.ReadItemStreamAsync(testDoc.Id, new PartitionKey(testDoc.PK)); + } + catch (InvalidOperationException ex) + { + Assert.AreEqual(ex.Message, "The CosmosDataEncryptionKeyProvider was not initialized."); + } + } + + [TestMethod] + public async Task EncryptionResourceTokenAuthAllowed() + { + User keyManagerUser = database.GetUser(Guid.NewGuid().ToString()); + await database.CreateUserAsync(keyManagerUser.Id); + + PermissionProperties keyManagerUserPermission = await keyManagerUser.CreatePermissionAsync( + new PermissionProperties(Guid.NewGuid().ToString(), PermissionMode.All, 
keyContainer)); + + CosmosDataEncryptionKeyProvider dekProvider = new(new TestEncryptionKeyStoreProvider()); + TestEncryptor encryptor = new(dekProvider); + CosmosClient clientForKeyManagerUser = TestCommon.CreateCosmosClient(keyManagerUserPermission.Token); + + Database databaseForKeyManagerUser = clientForKeyManagerUser.GetDatabase(database.Id); + + await dekProvider.InitializeAsync(databaseForKeyManagerUser, keyContainer.Id); + + DataEncryptionKeyProperties readDekProperties = await dekProvider.DataEncryptionKeyContainer.ReadDataEncryptionKeyAsync(dekId); + Assert.AreEqual(dekProperties, readDekProperties); + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionRestrictedProperties(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + try + { + await CreateItemAsync(encryptionContainer, dekId, new List() { "/id" }, jsonProcessor, compressionAlgorithm); + Assert.Fail("Expected item creation with id specified to be encrypted to fail."); + } + catch (InvalidOperationException ex) + { + Assert.AreEqual("PathsToEncrypt includes a invalid path: '/id'.", ex.Message); + } + + try + { + await CreateItemAsync(encryptionContainer, dekId, new List() { "/PK" }, jsonProcessor, compressionAlgorithm); + Assert.Fail("Expected item creation with PK specified to be encrypted to fail."); + } + catch (CosmosException ex) when (ex.StatusCode == HttpStatusCode.BadRequest) + { + } + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionBulkCrud(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + TestDoc docToReplace = await CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + docToReplace.NonSensitive = Guid.NewGuid().ToString(); + docToReplace.Sensitive_StringFormat = Guid.NewGuid().ToString(); + + TestDoc docToUpsert = await 
CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + docToUpsert.NonSensitive = Guid.NewGuid().ToString(); + docToUpsert.Sensitive_StringFormat = Guid.NewGuid().ToString(); + + TestDoc docToDelete = await CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + + CosmosClient clientWithBulk = TestCommon.CreateCosmosClient(builder => builder + .WithBulkExecution(true) + .Build()); + + Database databaseWithBulk = clientWithBulk.GetDatabase(database.Id); + Container containerWithBulk = databaseWithBulk.GetContainer(itemContainer.Id); + Container encryptionContainerWithBulk = containerWithBulk.WithEncryptor(encryptor); + + List tasks = new() + { + CreateItemAsync(encryptionContainerWithBulk, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm), + UpsertItemAsync(encryptionContainerWithBulk, TestDoc.Create(), dekId, TestDoc.PathsToEncrypt, HttpStatusCode.Created, jsonProcessor, compressionAlgorithm), + ReplaceItemAsync(encryptionContainerWithBulk, docToReplace, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm), + UpsertItemAsync(encryptionContainerWithBulk, docToUpsert, dekId, TestDoc.PathsToEncrypt, HttpStatusCode.OK, jsonProcessor, compressionAlgorithm), + DeleteItemAsync(encryptionContainerWithBulk, docToDelete) + }; + + await Task.WhenAll(tasks); + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionTransactionBatchCrud(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + string partitionKey = "thePK"; + string dek1 = dekId; + string dek2 = "dek2Forbatch"; + await CreateDekAsync(dekProvider, dek2); + + TestDoc doc1ToCreate = TestDoc.Create(partitionKey); + TestDoc doc2ToCreate = TestDoc.Create(partitionKey); + TestDoc doc3ToCreate = TestDoc.Create(partitionKey); + TestDoc doc4ToCreate = TestDoc.Create(partitionKey); + + 
ItemResponse doc1ToReplaceCreateResponse = await CreateItemAsync(encryptionContainer, dek1, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm, partitionKey); + TestDoc doc1ToReplace = doc1ToReplaceCreateResponse.Resource; + doc1ToReplace.NonSensitive = Guid.NewGuid().ToString(); + doc1ToReplace.Sensitive_StringFormat = Guid.NewGuid().ToString(); + + TestDoc doc2ToReplace = await CreateItemAsync(encryptionContainer, dek2, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm, partitionKey); + doc2ToReplace.NonSensitive = Guid.NewGuid().ToString(); + doc2ToReplace.Sensitive_StringFormat = Guid.NewGuid().ToString(); + + TestDoc doc1ToUpsert = await CreateItemAsync(encryptionContainer, dek2, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm, partitionKey); + doc1ToUpsert.NonSensitive = Guid.NewGuid().ToString(); + doc1ToUpsert.Sensitive_StringFormat = Guid.NewGuid().ToString(); + + TestDoc doc2ToUpsert = await CreateItemAsync(encryptionContainer, dek1, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm, partitionKey); + doc2ToUpsert.NonSensitive = Guid.NewGuid().ToString(); + doc2ToUpsert.Sensitive_StringFormat = Guid.NewGuid().ToString(); + + TestDoc docToDelete = await CreateItemAsync(encryptionContainer, dek1, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm, partitionKey); + + TransactionalBatchResponse batchResponse = await encryptionContainer.CreateTransactionalBatch(new PartitionKey(partitionKey)) + .CreateItem(doc1ToCreate, GetBatchItemRequestOptions(dek1, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm)) + .CreateItemStream(doc2ToCreate.ToStream(), GetBatchItemRequestOptions(dek2, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm)) + .ReplaceItem(doc1ToReplace.Id, doc1ToReplace, GetBatchItemRequestOptions(dek2, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm, doc1ToReplaceCreateResponse.ETag)) + .CreateItem(doc3ToCreate) + .CreateItem(doc4ToCreate, GetBatchItemRequestOptions(dek1, 
new List(), jsonProcessor, compressionAlgorithm)) // empty PathsToEncrypt list + .ReplaceItemStream(doc2ToReplace.Id, doc2ToReplace.ToStream(), GetBatchItemRequestOptions(dek2, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm)) + .UpsertItem(doc1ToUpsert, GetBatchItemRequestOptions(dek1, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm)) + .DeleteItem(docToDelete.Id) + .UpsertItemStream(doc2ToUpsert.ToStream(), GetBatchItemRequestOptions(dek2, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm)) + .ExecuteAsync(); + + Assert.AreEqual(HttpStatusCode.OK, batchResponse.StatusCode); + + TransactionalBatchOperationResult doc1 = batchResponse.GetOperationResultAtIndex(0); + VerifyExpectedDocResponse(doc1ToCreate, doc1.Resource); + + TransactionalBatchOperationResult doc2 = batchResponse.GetOperationResultAtIndex(1); + VerifyExpectedDocResponse(doc2ToCreate, doc2.Resource); + + TransactionalBatchOperationResult doc3 = batchResponse.GetOperationResultAtIndex(2); + VerifyExpectedDocResponse(doc1ToReplace, doc3.Resource); + + TransactionalBatchOperationResult doc4 = batchResponse.GetOperationResultAtIndex(3); + VerifyExpectedDocResponse(doc3ToCreate, doc4.Resource); + + TransactionalBatchOperationResult doc5 = batchResponse.GetOperationResultAtIndex(4); + VerifyExpectedDocResponse(doc4ToCreate, doc5.Resource); + + TransactionalBatchOperationResult doc6 = batchResponse.GetOperationResultAtIndex(5); + VerifyExpectedDocResponse(doc2ToReplace, doc6.Resource); + + TransactionalBatchOperationResult doc7 = batchResponse.GetOperationResultAtIndex(6); + VerifyExpectedDocResponse(doc1ToUpsert, doc7.Resource); + + TransactionalBatchOperationResult doc8 = batchResponse.GetOperationResultAtIndex(8); + VerifyExpectedDocResponse(doc2ToUpsert, doc8.Resource); + + await VerifyItemByReadAsync(encryptionContainer, doc1ToCreate); + await VerifyItemByReadAsync(encryptionContainer, doc2ToCreate, dekId: dek2); + await VerifyItemByReadAsync(encryptionContainer, 
doc3ToCreate, isDocDecrypted: false); + await VerifyItemByReadAsync(encryptionContainer, doc4ToCreate, isDocDecrypted: false); + await VerifyItemByReadAsync(encryptionContainer, doc1ToReplace, dekId: dek2); + await VerifyItemByReadAsync(encryptionContainer, doc2ToReplace, dekId: dek2); + await VerifyItemByReadAsync(encryptionContainer, doc1ToUpsert); + await VerifyItemByReadAsync(encryptionContainer, doc2ToUpsert, dekId: dek2); + + ResponseMessage readResponseMessage = await encryptionContainer.ReadItemStreamAsync(docToDelete.Id, new PartitionKey(docToDelete.PK)); + Assert.AreEqual(HttpStatusCode.NotFound, readResponseMessage.StatusCode); + + // doc3ToCreate, doc4ToCreate wasn't encrypted + await VerifyItemByReadAsync(itemContainer, doc3ToCreate); + await VerifyItemByReadAsync(itemContainer, doc4ToCreate); + + await dekProvider.Container.DeleteItemAsync(dek2, new PartitionKey(dek2)); + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionTransactionalBatchWithCustomSerializer(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + CustomSerializer customSerializer = new(); + CosmosClient clientWithCustomSerializer = TestCommon.CreateCosmosClient(builder => builder + .WithCustomSerializer(customSerializer) + .Build()); + + Database databaseWithCustomSerializer = clientWithCustomSerializer.GetDatabase(database.Id); + Container containerWithCustomSerializer = databaseWithCustomSerializer.GetContainer(itemContainer.Id); + Container encryptionContainerWithCustomSerializer = containerWithCustomSerializer.WithEncryptor(encryptor, JsonProcessor.Stream); + + string partitionKey = "thePK"; + string dek1 = dekId; + + TestDoc doc1ToCreate = TestDoc.Create(partitionKey); + + ItemResponse doc1ToReplaceCreateResponse = await CreateItemAsync(encryptionContainerWithCustomSerializer, dek1, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm, partitionKey); + TestDoc doc1ToReplace 
= doc1ToReplaceCreateResponse.Resource; + doc1ToReplace.NonSensitive = Guid.NewGuid().ToString(); + doc1ToReplace.Sensitive_StringFormat = Guid.NewGuid().ToString(); + + TransactionalBatchResponse batchResponse = await encryptionContainerWithCustomSerializer.CreateTransactionalBatch(new PartitionKey(partitionKey)) + .CreateItem(doc1ToCreate, GetBatchItemRequestOptions(dek1, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm)) + .ReplaceItem(doc1ToReplace.Id, doc1ToReplace, GetBatchItemRequestOptions(dek1, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm, doc1ToReplaceCreateResponse.ETag)) + .ExecuteAsync(); + + Assert.AreEqual(HttpStatusCode.OK, batchResponse.StatusCode); + // FromStream is called as part of CreateItem request + Assert.AreEqual(1, customSerializer.FromStreamCalled); + + TransactionalBatchOperationResult doc1 = batchResponse.GetOperationResultAtIndex(0); + VerifyExpectedDocResponse(doc1ToCreate, doc1.Resource); + Assert.AreEqual(2, customSerializer.FromStreamCalled); + + TransactionalBatchOperationResult doc2 = batchResponse.GetOperationResultAtIndex(1); + VerifyExpectedDocResponse(doc1ToReplace, doc2.Resource); + Assert.AreEqual(3, customSerializer.FromStreamCalled); + + await VerifyItemByReadAsync(encryptionContainerWithCustomSerializer, doc1ToCreate); + await VerifyItemByReadAsync(encryptionContainerWithCustomSerializer, doc1ToReplace); + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task VerifyDekOperationWithSystemTextSerializer(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + System.Text.Json.JsonSerializerOptions jsonSerializerOptions = new() + { + DefaultIgnoreCondition = System.Text.Json.Serialization.JsonIgnoreCondition.WhenWritingNull + }; + + CosmosSystemTextJsonSerializer cosmosSystemTextJsonSerializer = new(jsonSerializerOptions); + + CosmosClient clientWithCosmosSystemTextJsonSerializer = 
TestCommon.CreateCosmosClient(builder => builder + .WithCustomSerializer(cosmosSystemTextJsonSerializer) + .Build()); + + // get database and container + Database databaseWithCosmosSystemTextJsonSerializer = clientWithCosmosSystemTextJsonSerializer.GetDatabase(database.Id); + Container containerWithCosmosSystemTextJsonSerializer = databaseWithCosmosSystemTextJsonSerializer.GetContainer(itemContainer.Id); + + // create the Dek container + Container dekContainerWithCosmosSystemTextJsonSerializer = await databaseWithCosmosSystemTextJsonSerializer.CreateContainerAsync(Guid.NewGuid().ToString(), "/id", 400); + + CosmosDataEncryptionKeyProvider dekProviderWithCosmosSystemTextJsonSerializer = new(new TestEncryptionKeyStoreProvider()); + await dekProviderWithCosmosSystemTextJsonSerializer.InitializeAsync(databaseWithCosmosSystemTextJsonSerializer, dekContainerWithCosmosSystemTextJsonSerializer.Id); + + TestEncryptor encryptorWithCosmosSystemTextJsonSerializer = new(dekProviderWithCosmosSystemTextJsonSerializer); + + // enable encryption on container + Container encryptionContainerWithCosmosSystemTextJsonSerializer = containerWithCosmosSystemTextJsonSerializer.WithEncryptor(encryptorWithCosmosSystemTextJsonSerializer); + + string dekId = "dekWithSystemTextJson"; + DataEncryptionKeyProperties dekProperties = await CreateDekAsync(dekProviderWithCosmosSystemTextJsonSerializer, dekId); + Assert.AreEqual( + new EncryptionKeyWrapMetadata(name: "metadata1", value: metadata1.Value), + dekProperties.EncryptionKeyWrapMetadata); + + // Use different DEK provider to avoid (unintentional) cache impact + CosmosDataEncryptionKeyProvider dekProvider = new(new TestEncryptionKeyStoreProvider()); + await dekProvider.InitializeAsync(databaseWithCosmosSystemTextJsonSerializer, dekContainerWithCosmosSystemTextJsonSerializer.Id); + DataEncryptionKeyProperties readProperties = await dekProviderWithCosmosSystemTextJsonSerializer.DataEncryptionKeyContainer.ReadDataEncryptionKeyAsync(dekId); + 
Assert.AreEqual(dekProperties, readProperties); + + // rewrap + ItemResponse dekResponse = await dekProviderWithCosmosSystemTextJsonSerializer.DataEncryptionKeyContainer.RewrapDataEncryptionKeyAsync( + dekId, + metadata2); + + Assert.AreEqual(HttpStatusCode.OK, dekResponse.StatusCode); + dekProperties = VerifyDekResponse( + dekResponse, + dekId); + Assert.AreEqual( + metadata2, + dekProperties.EncryptionKeyWrapMetadata); + + readProperties = await dekProviderWithCosmosSystemTextJsonSerializer.DataEncryptionKeyContainer.ReadDataEncryptionKeyAsync(dekId); + Assert.AreEqual(dekProperties, readProperties); + + TestDocSystemText testDocSystemText = new() + { + Id = Guid.NewGuid().ToString(), + ActivityId = Guid.NewGuid().ToString(), + PartitionKey = "myPartitionKey", + Status = "Active" + }; + + // Create items that use System.Text.Json serialization attributes + ItemResponse createTestDoc = await encryptionContainerWithCosmosSystemTextJsonSerializer.CreateItemAsync( + testDocSystemText, + new PartitionKey(testDocSystemText.PartitionKey), + GetRequestOptions(dekId, new List() { "/status" }, jsonProcessor, compressionAlgorithm, legacyAlgo: false)); + + Assert.AreEqual(HttpStatusCode.Created, createTestDoc.StatusCode); + + string contosoV1 = "Contoso_v001"; + string contosoV2 = "Contoso_v002"; + string fabrikamV1 = "Fabrikam_v001"; + string fabrikamV2 = "Fabrikam_v002"; + + await CreateDekAsync(dekProviderWithCosmosSystemTextJsonSerializer, contosoV1); + await CreateDekAsync(dekProviderWithCosmosSystemTextJsonSerializer, contosoV2); + await CreateDekAsync(dekProviderWithCosmosSystemTextJsonSerializer, fabrikamV1); + await CreateDekAsync(dekProviderWithCosmosSystemTextJsonSerializer, fabrikamV2); + + // Test getting all keys + await IterateDekFeedAsync( + dekProviderWithCosmosSystemTextJsonSerializer, + new List { dekId, contosoV1, contosoV2, fabrikamV1, fabrikamV2 }, + isExpectedDeksCompleteSetForRequest: true, + isResultOrderExpected: false, + "SELECT * from c"); + + // 
Test getting specific subset of keys + await IterateDekFeedAsync( + dekProviderWithCosmosSystemTextJsonSerializer, + new List { contosoV2 }, + isExpectedDeksCompleteSetForRequest: false, + isResultOrderExpected: true, + "SELECT TOP 1 * from c where c.id >= 'Contoso_v000' and c.id <= 'Contoso_v999' ORDER BY c.id DESC"); + + // Ensure only required results are returned + await IterateDekFeedAsync( + dekProviderWithCosmosSystemTextJsonSerializer, + new List { contosoV1, contosoV2 }, + isExpectedDeksCompleteSetForRequest: true, + isResultOrderExpected: true, + "SELECT * from c where c.id >= 'Contoso_v000' and c.id <= 'Contoso_v999' ORDER BY c.id ASC"); + + // Test pagination + await IterateDekFeedAsync( + dekProviderWithCosmosSystemTextJsonSerializer, + new List { dekId, contosoV1, contosoV2, fabrikamV1, fabrikamV2 }, + isExpectedDeksCompleteSetForRequest: true, + isResultOrderExpected: false, + "SELECT * from c", + itemCountInPage: 3); + + // cleanup + FeedIterator iterator = containerWithCosmosSystemTextJsonSerializer.GetItemQueryIterator(); + + while (iterator.HasMoreResults) + { + FeedResponse feedResponse = await iterator.ReadNextAsync(); + foreach (TestDocSystemText testDoc in feedResponse) + { + if (testDoc.Id == null) + { + continue; + } + await containerWithCosmosSystemTextJsonSerializer.DeleteItemAsync(testDoc.Id, new PartitionKey(testDoc.PartitionKey)); + } + } + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionTransactionalBatchConflictResponse(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + string partitionKey = "thePK"; + string dek1 = dekId; + + ItemResponse doc1CreatedResponse = await CreateItemAsync(encryptionContainer, dek1, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm, partitionKey); + TestDoc doc1ToCreateAgain = doc1CreatedResponse.Resource; + doc1ToCreateAgain.NonSensitive = Guid.NewGuid().ToString(); + 
doc1ToCreateAgain.Sensitive_StringFormat = Guid.NewGuid().ToString(); + + TransactionalBatchResponse batchResponse = await encryptionContainer.CreateTransactionalBatch(new PartitionKey(partitionKey)) + .CreateItem(doc1ToCreateAgain, GetBatchItemRequestOptions(dek1, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm)) + .ExecuteAsync(); + + Assert.AreEqual(HttpStatusCode.Conflict, batchResponse.StatusCode); + Assert.AreEqual(1, batchResponse.Count); + } + + // One of query or queryDefinition is to be passed in non-null + private static async Task ValidateQueryResultsAsync( + Container container, + string query = null, + TestDoc expectedDoc = null, + QueryDefinition queryDefinition = null, + List pathsEncrypted = null, + bool legacyAlgo = false) + { + QueryRequestOptions requestOptions = expectedDoc != null + ? new QueryRequestOptions() + { + PartitionKey = new PartitionKey(expectedDoc.PK), + } + : null; + + FeedIterator queryResponseIterator; + FeedIterator queryResponseIteratorForLazyDecryption; + if (query != null) + { + queryResponseIterator = container.GetItemQueryIterator(query, requestOptions: requestOptions); + queryResponseIteratorForLazyDecryption = container.GetItemQueryIterator(query, requestOptions: requestOptions); + } + else + { + queryResponseIterator = container.GetItemQueryIterator(queryDefinition, requestOptions: requestOptions); + queryResponseIteratorForLazyDecryption = container.GetItemQueryIterator(queryDefinition, requestOptions: requestOptions); + } + FeedResponse readDocs = await queryResponseIterator.ReadNextAsync(); + Assert.AreEqual(null, readDocs.ContinuationToken); + + FeedResponse readDocsLazily = await queryResponseIteratorForLazyDecryption.ReadNextAsync(); + Assert.AreEqual(null, readDocsLazily.ContinuationToken); + + if (expectedDoc != null) + { + Assert.AreEqual(1, readDocs.Count); + TestDoc readDoc = readDocs.Single(); + VerifyExpectedDocResponse(expectedDoc, readDoc); + + Assert.AreEqual(1, readDocsLazily.Count); + if 
(!legacyAlgo) + { + await ValidateDecryptableItem(readDocsLazily.First(), expectedDoc, pathsEncrypted: pathsEncrypted); + } + else + { + await ValidateDecryptableItem(readDocsLazily.First(), expectedDoc, dekId: legacydekId, pathsEncrypted: pathsEncrypted); + } + } + else + { + Assert.AreEqual(0, readDocs.Count); + } + } + + private static async Task ValidateQueryResultsMultipleDocumentsAsync( + Container container, + TestDoc testDoc1, + TestDoc testDoc2, + string query, + bool compareEncryptedProperty = true) + { + FeedIterator queryResponseIterator; + FeedIterator queryResponseIteratorForLazyDecryption; + + if (query == null) + { + IOrderedQueryable linqQueryable = container.GetItemLinqQueryable(); + queryResponseIterator = container.ToEncryptionFeedIterator(linqQueryable); + + IOrderedQueryable linqQueryableDecryptableItem = container.GetItemLinqQueryable(); + queryResponseIteratorForLazyDecryption = container.ToEncryptionFeedIterator(linqQueryableDecryptableItem); + } + else + { + queryResponseIterator = container.GetItemQueryIterator(query); + queryResponseIteratorForLazyDecryption = container.GetItemQueryIterator(query); + } + + FeedResponse readDocs = await queryResponseIterator.ReadNextAsync(); + Assert.AreEqual(null, readDocs.ContinuationToken); + + FeedResponse readDocsLazily = await queryResponseIteratorForLazyDecryption.ReadNextAsync(); + Assert.AreEqual(null, readDocsLazily.ContinuationToken); + + if (query == null) + { + Assert.IsTrue(readDocs.Count >= 2); + Assert.IsTrue(readDocsLazily.Count >= 2); + } + else + { + Assert.AreEqual(2, readDocs.Count); + Assert.AreEqual(2, readDocsLazily.Count); + } + + for (int index = 0; index < readDocs.Count; index++) + { + if (readDocs.ElementAt(index).Id.Equals(testDoc1.Id)) + { + if (compareEncryptedProperty) + { + VerifyExpectedDocResponse(readDocs.ElementAt(index), testDoc1); + } + else + { + testDoc1.EqualsExceptEncryptedProperty(readDocs.ElementAt(index)); + } + } + else if 
(readDocs.ElementAt(index).Id.Equals(testDoc2.Id)) + { + if (compareEncryptedProperty) + { + VerifyExpectedDocResponse(readDocs.ElementAt(index), testDoc2); + } + else + { + testDoc2.EqualsExceptEncryptedProperty(readDocs.ElementAt(index)); + } + } + } + } + + private static async Task ValidateQueryResponseAsync( + Container container, + string query = null) + { + FeedIterator feedIterator; + if (query == null) + { + IOrderedQueryable linqQueryable = container.GetItemLinqQueryable(); + feedIterator = container.ToEncryptionStreamIterator(linqQueryable); + } + else + { + feedIterator = container.GetItemQueryStreamIterator(query); + } + + while (feedIterator.HasMoreResults) + { + ResponseMessage response = await feedIterator.ReadNextAsync(); + Assert.IsTrue(response.IsSuccessStatusCode); + Assert.IsNull(response.ErrorMessage); + } + } + + private static async Task ValidateQueryResponseWithLazyDecryptionAsync(Container container, + string query = null) + { + FeedIterator queryResponseIteratorForLazyDecryption = container.GetItemQueryIterator(query); + FeedResponse readDocsLazily = await queryResponseIteratorForLazyDecryption.ReadNextAsync(); + Assert.AreEqual(null, readDocsLazily.ContinuationToken); + Assert.AreEqual(1, readDocsLazily.Count); + (dynamic readDoc, DecryptionContext decryptionContext) = await readDocsLazily.First().GetItemAsync(); + Assert.IsTrue((long)readDoc >= 1); + Assert.IsNull(decryptionContext); + } + + private static async Task ValidateChangeFeedIteratorResponse( + Container container, + TestDoc testDoc1, + TestDoc testDoc2) + { + FeedIterator changeIterator = container.GetChangeFeedIterator( + ChangeFeedStartFrom.Beginning(), + ChangeFeedMode.Incremental); + + while (changeIterator.HasMoreResults) + { + FeedResponse testDocs = await changeIterator.ReadNextAsync(); + if (testDocs.StatusCode == HttpStatusCode.NotModified) + { + break; + } + + Assert.AreEqual(testDocs.Count, 2); + + VerifyExpectedDocResponse(testDoc1, 
testDocs.Resource.ElementAt(0)); + VerifyExpectedDocResponse(testDoc2, testDocs.Resource.ElementAt(1)); + } + } + + private static async Task ValidateChangeFeedProcessorResponse( + Container container, + TestDoc testDoc1, + TestDoc testDoc2) + { + Database leaseDatabase = await client.CreateDatabaseAsync(Guid.NewGuid().ToString()); + Container leaseContainer = await leaseDatabase.CreateContainerIfNotExistsAsync( + new ContainerProperties(id: "leases", partitionKeyPath: "/id")); + ManualResetEvent allDocsProcessed = new(false); + int processedDocCount = 0; + + List changeFeedReturnedDocs = new(); + ChangeFeedProcessor cfp = container.GetChangeFeedProcessorBuilder( + "testCFP", + (IReadOnlyCollection changes, CancellationToken cancellationToken) => + { + changeFeedReturnedDocs.AddRange(changes); + processedDocCount += changes.Count; + if (processedDocCount == 2) + { + allDocsProcessed.Set(); + } + + return Task.CompletedTask; + }) + .WithInstanceName("random") + .WithLeaseContainer(leaseContainer) + .WithStartTime(DateTime.MinValue.ToUniversalTime()) + .Build(); + + await cfp.StartAsync(); + bool isStartOk = allDocsProcessed.WaitOne(60000); + await cfp.StopAsync(); + + Assert.AreEqual(2, changeFeedReturnedDocs.Count); + + VerifyExpectedDocResponse(testDoc1, changeFeedReturnedDocs[^2]); + VerifyExpectedDocResponse(testDoc2, changeFeedReturnedDocs[^1]); + + if (leaseDatabase != null) + { + using (await leaseDatabase.DeleteStreamAsync()) { } + } + } + + private static async Task ValidateChangeFeedProcessorWithFeedHandlerResponse( + Container container, + TestDoc testDoc1, + TestDoc testDoc2) + { + Database leaseDatabase = await client.CreateDatabaseAsync(Guid.NewGuid().ToString()); + Container leaseContainer = await leaseDatabase.CreateContainerIfNotExistsAsync( + new ContainerProperties(id: "leases", partitionKeyPath: "/id")); + ManualResetEvent allDocsProcessed = new(false); + int processedDocCount = 0; + + List changeFeedReturnedDocs = new(); + ChangeFeedProcessor 
cfp = container.GetChangeFeedProcessorBuilder( + "testCFPWithFeedHandler", + ( + ChangeFeedProcessorContext context, + IReadOnlyCollection changes, + CancellationToken cancellationToken) => + { + changeFeedReturnedDocs.AddRange(changes); + processedDocCount += changes.Count; + if (processedDocCount == 2) + { + allDocsProcessed.Set(); + } + + return Task.CompletedTask; + }) + .WithInstanceName("random") + .WithLeaseContainer(leaseContainer) + .WithStartTime(DateTime.MinValue.ToUniversalTime()) + .Build(); + + await cfp.StartAsync(); + bool isStartOk = allDocsProcessed.WaitOne(60000); + await cfp.StopAsync(); + + Assert.AreEqual(changeFeedReturnedDocs.Count, 2); + + VerifyExpectedDocResponse(testDoc1, changeFeedReturnedDocs[^2]); + VerifyExpectedDocResponse(testDoc2, changeFeedReturnedDocs[^1]); + + if (leaseDatabase != null) + { + using (await leaseDatabase.DeleteStreamAsync()) { } + } + } + + private static async Task ValidateChangeFeedProcessorWithManualCheckpointResponse( + Container container, + TestDoc testDoc1, + TestDoc testDoc2) + { + Database leaseDatabase = await client.CreateDatabaseAsync(Guid.NewGuid().ToString()); + Container leaseContainer = await leaseDatabase.CreateContainerIfNotExistsAsync( + new ContainerProperties(id: "leases", partitionKeyPath: "/id")); + ManualResetEvent allDocsProcessed = new(false); + int processedDocCount = 0; + + List changeFeedReturnedDocs = new(); + ChangeFeedProcessor cfp = container.GetChangeFeedProcessorBuilderWithManualCheckpoint( + "testCFPWithManualCheckpoint", + ( + ChangeFeedProcessorContext context, + IReadOnlyCollection changes, + Func tryCheckpointAsync, + CancellationToken cancellationToken) => + { + changeFeedReturnedDocs.AddRange(changes); + processedDocCount += changes.Count; + if (processedDocCount == 2) + { + allDocsProcessed.Set(); + } + + return Task.CompletedTask; + }) + .WithInstanceName("random") + .WithLeaseContainer(leaseContainer) + .WithStartTime(DateTime.MinValue.ToUniversalTime()) + .Build(); + 
+ await cfp.StartAsync(); + bool isStartOk = allDocsProcessed.WaitOne(60000); + await cfp.StopAsync(); + + Assert.AreEqual(changeFeedReturnedDocs.Count, 2); + + VerifyExpectedDocResponse(testDoc1, changeFeedReturnedDocs[^2]); + VerifyExpectedDocResponse(testDoc2, changeFeedReturnedDocs[^1]); + + if (leaseDatabase != null) + { + using (await leaseDatabase.DeleteStreamAsync()) { } + } + } + + private static async Task ValidateChangeFeedProcessorWithFeedStreamHandlerResponse( + Container container, + TestDoc testDoc1, + TestDoc testDoc2) + { + Database leaseDatabase = await client.CreateDatabaseAsync(Guid.NewGuid().ToString()); + Container leaseContainer = await leaseDatabase.CreateContainerIfNotExistsAsync( + new ContainerProperties(id: "leases", partitionKeyPath: "/id")); + ManualResetEvent allDocsProcessed = new(false); + int processedDocCount = 0; + + ChangeFeedProcessor cfp = container.GetChangeFeedProcessorBuilder( + "testCFPWithFeedStreamHandler", + ( +context, +changes, +cancellationToken) => + { + string changeFeed = string.Empty; + using (StreamReader streamReader = new(changes)) + { + changeFeed = streamReader.ReadToEnd(); + } + + if (changeFeed.Contains(testDoc1.Id)) + { + processedDocCount++; + } + + if (changeFeed.Contains(testDoc2.Id)) + { + processedDocCount++; + } + + if (processedDocCount == 2) + { + allDocsProcessed.Set(); + } + + return Task.CompletedTask; + }) + .WithInstanceName("random") + .WithLeaseContainer(leaseContainer) + .WithStartTime(DateTime.MinValue.ToUniversalTime()) + .Build(); + + await cfp.StartAsync(); + bool isStartOk = allDocsProcessed.WaitOne(60000); + await cfp.StopAsync(); + + if (leaseDatabase != null) + { + using (await leaseDatabase.DeleteStreamAsync()) { } + } + } + + private static async Task ValidateChangeFeedProcessorStreamWithManualCheckpointResponse( + Container container, + TestDoc testDoc1, + TestDoc testDoc2) + { + Database leaseDatabase = await client.CreateDatabaseAsync(Guid.NewGuid().ToString()); + Container 
leaseContainer = await leaseDatabase.CreateContainerIfNotExistsAsync( + new ContainerProperties(id: "leases", partitionKeyPath: "/id")); + ManualResetEvent allDocsProcessed = new(false); + int processedDocCount = 0; + + ChangeFeedProcessor cfp = container.GetChangeFeedProcessorBuilderWithManualCheckpoint( + "testCFPStreamWithManualCheckpoint", + ( +context, +changes, +tryCheckpointAsync, +cancellationToken) => + { + string changeFeed = string.Empty; + using (StreamReader streamReader = new(changes)) + { + changeFeed = streamReader.ReadToEnd(); + } + + if (changeFeed.Contains(testDoc1.Id)) + { + processedDocCount++; + } + + if (changeFeed.Contains(testDoc2.Id)) + { + processedDocCount++; + } + + if (processedDocCount == 2) + { + allDocsProcessed.Set(); + } + + return Task.CompletedTask; + }) + .WithInstanceName("random") + .WithLeaseContainer(leaseContainer) + .WithStartTime(DateTime.MinValue.ToUniversalTime()) + .Build(); + + await cfp.StartAsync(); + bool isStartOk = allDocsProcessed.WaitOne(60000); + await cfp.StopAsync(); + + if (leaseDatabase != null) + { + using (await leaseDatabase.DeleteStreamAsync()) { } + } + } + + private static async Task ValidateLazyDecryptionResponse( + IEnumerator readDocsLazily, + string failureDek) + { + int decryptedDoc = 0; + int failedDoc = 0; + + while (readDocsLazily.MoveNext()) + { + try + { + (_, _) = await readDocsLazily.Current.GetItemAsync(); + decryptedDoc++; + } + catch (EncryptionException encryptionException) + { + failedDoc++; + ValidateEncryptionException(encryptionException, failureDek); + } + } + + Assert.IsTrue(decryptedDoc >= 1); + Assert.AreEqual(1, failedDoc); + } + + private static void ValidateEncryptionException( + EncryptionException encryptionException, + string failureDek) + { + Assert.AreEqual(failureDek, encryptionException.DataEncryptionKeyId); + Assert.IsNotNull(encryptionException.EncryptedContent); + Assert.IsNotNull(encryptionException.InnerException); + 
Assert.IsTrue(encryptionException.InnerException is InvalidOperationException); + Assert.AreEqual(encryptionException.InnerException.Message, "Null DataEncryptionKey returned."); + } + + private static async Task IterateDekFeedAsync( + CosmosDataEncryptionKeyProvider dekProvider, + List expectedDekIds, + bool isExpectedDeksCompleteSetForRequest, + bool isResultOrderExpected, + string query, + int? itemCountInPage = null, + QueryDefinition queryDefinition = null) + { + int remainingItemCount = expectedDekIds.Count; + QueryRequestOptions requestOptions = null; + if (itemCountInPage.HasValue) + { + requestOptions = new QueryRequestOptions() + { + MaxItemCount = itemCountInPage + }; + } + + FeedIterator dekIterator = queryDefinition != null + ? dekProvider.DataEncryptionKeyContainer.GetDataEncryptionKeyQueryIterator( + queryDefinition, + requestOptions: requestOptions) + : dekProvider.DataEncryptionKeyContainer.GetDataEncryptionKeyQueryIterator( + query, + requestOptions: requestOptions); + + Assert.IsTrue(dekIterator.HasMoreResults); + + List readDekIds = new(); + while (remainingItemCount > 0) + { + FeedResponse page = await dekIterator.ReadNextAsync(); + if (itemCountInPage.HasValue) + { + // last page + if (remainingItemCount < itemCountInPage.Value) + { + Assert.AreEqual(remainingItemCount, page.Count); + } + else + { + Assert.AreEqual(itemCountInPage.Value, page.Count); + } + } + else + { + Assert.AreEqual(expectedDekIds.Count, page.Count); + } + + remainingItemCount -= page.Count; + if (isExpectedDeksCompleteSetForRequest) + { + Assert.AreEqual(remainingItemCount > 0, dekIterator.HasMoreResults); + } + + foreach (DataEncryptionKeyProperties dek in page.Resource) + { + readDekIds.Add(dek.Id); + } + } + + if (isResultOrderExpected) + { + Assert.IsTrue(expectedDekIds.SequenceEqual(readDekIds)); + } + else + { + Assert.IsTrue(expectedDekIds.ToHashSet().SetEquals(readDekIds)); + } + } + + + private static async Task> UpsertItemAsync( + Container container, + TestDoc 
testDoc, + string dekId, + List pathsToEncrypt, + HttpStatusCode expectedStatusCode, + JsonProcessor jsonProcessor, + CompressionOptions.CompressionAlgorithm compressionAlgorithm + ) + { + ItemResponse upsertResponse = await container.UpsertItemAsync( + testDoc, + new PartitionKey(testDoc.PK), + GetRequestOptions(dekId, pathsToEncrypt, jsonProcessor, compressionAlgorithm)); + Assert.AreEqual(expectedStatusCode, upsertResponse.StatusCode); + VerifyExpectedDocResponse(testDoc, upsertResponse.Resource); + return upsertResponse; + } + + private static async Task> CreateItemAsync( + Container container, + string dekId, + List pathsToEncrypt, + JsonProcessor jsonProcessor, + CompressionOptions.CompressionAlgorithm compressionAlgorithm, + string partitionKey = null, + bool legacyAlgo = false) + { + TestDoc testDoc = TestDoc.Create(partitionKey); + ItemResponse createResponse = await container.CreateItemAsync( + testDoc, + new PartitionKey(testDoc.PK), + GetRequestOptions(dekId, pathsToEncrypt, jsonProcessor, compressionAlgorithm, legacyAlgo: legacyAlgo)); + Assert.AreEqual(HttpStatusCode.Created, createResponse.StatusCode); + VerifyExpectedDocResponse(testDoc, createResponse.Resource); + return createResponse; + } + + private static async Task> ReplaceItemAsync( + Container encryptedContainer, + TestDoc testDoc, + string dekId, + List pathsToEncrypt, + JsonProcessor jsonProcessor, + CompressionOptions.CompressionAlgorithm compressionAlgorithm, + string etag = null) + { + ItemResponse replaceResponse = await encryptedContainer.ReplaceItemAsync( + testDoc, + testDoc.Id, + new PartitionKey(testDoc.PK), + GetRequestOptions(dekId, pathsToEncrypt, jsonProcessor, compressionAlgorithm, etag)); + + Assert.AreEqual(HttpStatusCode.OK, replaceResponse.StatusCode); + + VerifyExpectedDocResponse(testDoc, replaceResponse.Resource); + + return replaceResponse; + } + + private static async Task> DeleteItemAsync( + Container encryptedContainer, + TestDoc testDoc) + { + ItemResponse 
deleteResponse = await encryptedContainer.DeleteItemAsync( + testDoc.Id, + new PartitionKey(testDoc.PK)); + + Assert.AreEqual(HttpStatusCode.NoContent, deleteResponse.StatusCode); + Assert.IsNull(deleteResponse.Resource); + return deleteResponse; + } + + private static EncryptionItemRequestOptions GetRequestOptions( + string dekId, + List pathsToEncrypt, + JsonProcessor jsonProcessor, + CompressionOptions.CompressionAlgorithm compressionAlgorithm, + string ifMatchEtag = null, + bool legacyAlgo = false) + { + if (!legacyAlgo) + { + return new EncryptionItemRequestOptions + { + EncryptionOptions = GetEncryptionOptions(dekId, pathsToEncrypt, jsonProcessor, compressionAlgorithm), + IfMatchEtag = ifMatchEtag + }; + } + else + { + return new EncryptionItemRequestOptions + { + EncryptionOptions = GetLegacyEncryptionOptions(dekId, pathsToEncrypt, jsonProcessor, compressionAlgorithm), + IfMatchEtag = ifMatchEtag + }; + } + } + + private static EncryptionTransactionalBatchItemRequestOptions GetBatchItemRequestOptions( + string dekId, + List pathsToEncrypt, + JsonProcessor jsonProcessor, + CompressionOptions.CompressionAlgorithm compressionAlgorithm, + string ifMatchEtag = null) + { + return new EncryptionTransactionalBatchItemRequestOptions + { + EncryptionOptions = GetEncryptionOptions(dekId, pathsToEncrypt, jsonProcessor, compressionAlgorithm), + IfMatchEtag = ifMatchEtag + }; + } + + private static EncryptionOptions GetEncryptionOptions( + string dekId, + List pathsToEncrypt, + JsonProcessor jsonProcessor, + CompressionOptions.CompressionAlgorithm compressionAlgorithm + ) + { + return new EncryptionOptions() + { + DataEncryptionKeyId = dekId, + EncryptionAlgorithm = CosmosEncryptionAlgorithm.MdeAeadAes256CbcHmac256Randomized, + PathsToEncrypt = pathsToEncrypt, + JsonProcessor = jsonProcessor, + CompressionOptions = new CompressionOptions { Algorithm = compressionAlgorithm} + }; + } + + private static async Task ValidateDecryptableItem( + DecryptableItem decryptableItem, + 
TestDoc testDoc, + string dekId = null, + List pathsEncrypted = null, + bool isDocDecrypted = true) + { + (TestDoc readDoc, DecryptionContext decryptionContext) = await decryptableItem.GetItemAsync(); + VerifyExpectedDocResponse(testDoc, readDoc); + + if (isDocDecrypted && testDoc.Sensitive_StringFormat != null) + { + ValidateDecryptionContext(decryptionContext, dekId, pathsEncrypted); + } + else + { + Assert.IsNull(decryptionContext); + } + } + + private static void ValidateDecryptionContext( + DecryptionContext decryptionContext, + string dekId = null, + List pathsEncrypted = null) + { + Assert.IsNotNull(decryptionContext.DecryptionInfoList); + Assert.AreEqual(1, decryptionContext.DecryptionInfoList.Count); + DecryptionInfo decryptionInfo = decryptionContext.DecryptionInfoList[0]; + Assert.AreEqual(dekId ?? MdeCustomEncryptionTestsWithSystemText.dekId, decryptionInfo.DataEncryptionKeyId); + + pathsEncrypted ??= TestDoc.PathsToEncrypt; + + Assert.AreEqual(pathsEncrypted.Count, decryptionInfo.PathsDecrypted.Count); + Assert.IsFalse(pathsEncrypted.Exists(path => !decryptionInfo.PathsDecrypted.Contains(path))); + } + + + private static async Task VerifyItemByReadStreamAsync(Container container, TestDoc testDoc, ItemRequestOptions requestOptions = null, bool compareEncryptedProperty = true) + { + ResponseMessage readResponseMessage = await container.ReadItemStreamAsync(testDoc.Id, new PartitionKey(testDoc.PK), requestOptions); + Assert.AreEqual(HttpStatusCode.OK, readResponseMessage.StatusCode); + Assert.IsNotNull(readResponseMessage.Content); + TestDoc readDoc = TestCommon.FromStream(readResponseMessage.Content); + if (compareEncryptedProperty) + { + VerifyExpectedDocResponse(testDoc, readDoc); + } + else + { + testDoc.EqualsExceptEncryptedProperty(readDoc); + } + } + + private static async Task VerifyItemByReadAsync(Container container, TestDoc testDoc, ItemRequestOptions requestOptions = null, string dekId = null, bool isDocDecrypted = true, bool 
compareEncryptedProperty = true)
{
    ItemResponse readResponse = await container.ReadItemAsync(testDoc.Id, new PartitionKey(testDoc.PK), requestOptions);
    Assert.AreEqual(HttpStatusCode.OK, readResponse.StatusCode);
    if (compareEncryptedProperty)
    {
        VerifyExpectedDocResponse(testDoc, readResponse.Resource);
    }
    else
    {
        // BUG FIX: the boolean result of EqualsExceptEncryptedProperty was previously
        // discarded, so a mismatch could never fail the test. Assert it explicitly.
        Assert.IsTrue(testDoc.EqualsExceptEncryptedProperty(readResponse.Resource));
    }

    // ignore for reads via regular container..
    if (container == encryptionContainer)
    {
        ItemResponse readResponseDecryptableItem = await container.ReadItemAsync(testDoc.Id, new PartitionKey(testDoc.PK), requestOptions);
        // BUG FIX: assert the status of the *second* read; the original re-asserted
        // readResponse.StatusCode, leaving this response unchecked.
        Assert.AreEqual(HttpStatusCode.OK, readResponseDecryptableItem.StatusCode);
        await ValidateDecryptableItem(readResponseDecryptableItem.Resource, testDoc, dekId, isDocDecrypted: isDocDecrypted);
    }
}

// Creates a data encryption key with the given id (defaulting the algorithm) and
// verifies the Created response via VerifyDekResponse.
private static async Task CreateDekAsync(CosmosDataEncryptionKeyProvider dekProvider, string dekId, string algorithm = null)
{
    ItemResponse dekResponse = await dekProvider.DataEncryptionKeyContainer.CreateDataEncryptionKeyAsync(
        dekId,
        algorithm ??
CosmosEncryptionAlgorithm.MdeAeadAes256CbcHmac256Randomized, + metadata1); + + Assert.AreEqual(HttpStatusCode.Created, dekResponse.StatusCode); + + return VerifyDekResponse(dekResponse, + dekId); + } + + + private static DataEncryptionKeyProperties VerifyDekResponse( + ItemResponse dekResponse, + string dekId) + { + Assert.IsTrue(dekResponse.RequestCharge > 0); + Assert.IsNotNull(dekResponse.ETag); + + DataEncryptionKeyProperties dekProperties = dekResponse.Resource; + Assert.IsNotNull(dekProperties); + Assert.AreEqual(dekResponse.ETag, dekProperties.ETag); + Assert.AreEqual(dekId, dekProperties.Id); + Assert.IsNotNull(dekProperties.SelfLink); + Assert.IsNotNull(dekProperties.CreatedTime); + Assert.IsNotNull(dekProperties.LastModified); + + return dekProperties; + } + + private static async Task PerformForbiddenOperationAsync(Func func, string operationName) + { + try + { + await func(); + Assert.Fail($"Expected resource token based client to not be able to perform {operationName}"); + } + catch (CosmosException ex) when (ex.StatusCode == HttpStatusCode.Forbidden) + { + } + } + + private static async Task PerformOperationOnUninitializedDekProviderAsync(Func func, string operationName) + { + try + { + await func(); + Assert.Fail($"Expected {operationName} to not work on uninitialized CosmosDataEncryptionKeyProvider."); + } + catch (InvalidOperationException ex) + { + Assert.IsTrue(ex.Message.Contains("The CosmosDataEncryptionKeyProvider was not initialized.")); + } + } + + private static void VerifyExpectedDocResponse(TestDoc expectedDoc, TestDoc verifyDoc) + { + Assert.AreEqual(expectedDoc.Id, verifyDoc.Id); + Assert.AreEqual(expectedDoc.Sensitive_StringFormat, verifyDoc.Sensitive_StringFormat); + if (expectedDoc.Sensitive_ArrayFormat != null) + { + Assert.AreEqual(expectedDoc.Sensitive_ArrayFormat[0].Sensitive_ArrayDecimalFormat, verifyDoc.Sensitive_ArrayFormat[0].Sensitive_ArrayDecimalFormat); + 
Assert.AreEqual(expectedDoc.Sensitive_ArrayFormat[0].Sensitive_ArrayIntFormat, verifyDoc.Sensitive_ArrayFormat[0].Sensitive_ArrayIntFormat); + Assert.AreEqual(expectedDoc.Sensitive_NestedObjectFormatL1.Sensitive_IntFormatL1, verifyDoc.Sensitive_NestedObjectFormatL1.Sensitive_IntFormatL1); + Assert.AreEqual( + expectedDoc.Sensitive_NestedObjectFormatL1.Sensitive_NestedObjectFormatL2.Sensitive_IntFormatL2, + verifyDoc.Sensitive_NestedObjectFormatL1.Sensitive_NestedObjectFormatL2.Sensitive_IntFormatL2); + } + else + { + Assert.AreEqual(expectedDoc.Sensitive_ArrayFormat, verifyDoc.Sensitive_ArrayFormat); + Assert.AreEqual(expectedDoc.Sensitive_NestedObjectFormatL1, verifyDoc.Sensitive_NestedObjectFormatL1); + } + Assert.AreEqual(expectedDoc.Sensitive_DateFormat, verifyDoc.Sensitive_DateFormat); + Assert.AreEqual(expectedDoc.Sensitive_DecimalFormat, verifyDoc.Sensitive_DecimalFormat); + Assert.AreEqual(expectedDoc.Sensitive_IntFormat, verifyDoc.Sensitive_IntFormat); + Assert.AreEqual(expectedDoc.Sensitive_FloatFormat, verifyDoc.Sensitive_FloatFormat); + Assert.AreEqual(expectedDoc.Sensitive_BoolFormat, verifyDoc.Sensitive_BoolFormat); + Assert.AreEqual(expectedDoc.NonSensitive, verifyDoc.NonSensitive); + } + + public class TestDoc + { + public static List PathsToEncrypt { get; } = + new List() { + "/Sensitive_StringFormat", + "/Sensitive_ArrayFormat", + "/Sensitive_DecimalFormat", + "/Sensitive_IntFormat", + "/Sensitive_DateFormat", + "/Sensitive_BoolFormat", + "/Sensitive_FloatFormat", + "/Sensitive_NestedObjectFormatL1" + }; + + [JsonProperty("id")] + [JsonPropertyName("id")] + public string Id { get; set; } + + public string PK { get; set; } + + public string NonSensitive { get; set; } + + public string Sensitive_StringFormat { get; set; } + + public DateTime Sensitive_DateFormat { get; set; } + + public decimal Sensitive_DecimalFormat { get; set; } + + public bool Sensitive_BoolFormat { get; set; } + + public int Sensitive_IntFormat { get; set; } + + public float 
Sensitive_FloatFormat { get; set; } + + public Sensitive_ArrayData[] Sensitive_ArrayFormat { get; set; } + + public Sensitive_NestedObjectL1 Sensitive_NestedObjectFormatL1 { get; set; } + + public TestDoc() + { + } + + public class Sensitive_ArrayData + { + public int Sensitive_ArrayIntFormat { get; set; } + public decimal Sensitive_ArrayDecimalFormat { get; set; } + } + + public class Sensitive_NestedObjectL1 + { + public int Sensitive_IntFormatL1 { get; set; } + public Sensitive_NestedObjectL2 Sensitive_NestedObjectFormatL2 { get; set; } + } + + public class Sensitive_NestedObjectL2 + { + public int Sensitive_IntFormatL2 { get; set; } + } + + public TestDoc(TestDoc other) + { + this.Id = other.Id; + this.PK = other.PK; + this.NonSensitive = other.NonSensitive; + this.Sensitive_StringFormat = other.Sensitive_StringFormat; + this.Sensitive_DateFormat = other.Sensitive_DateFormat; + this.Sensitive_DecimalFormat = other.Sensitive_DecimalFormat; + this.Sensitive_IntFormat = other.Sensitive_IntFormat; + this.Sensitive_ArrayFormat = other.Sensitive_ArrayFormat; + this.Sensitive_BoolFormat = other.Sensitive_BoolFormat; + this.Sensitive_FloatFormat = other.Sensitive_FloatFormat; + this.Sensitive_NestedObjectFormatL1 = other.Sensitive_NestedObjectFormatL1; + } + + public override bool Equals(object obj) + { + return obj is TestDoc doc + && this.Id == doc.Id + && this.PK == doc.PK + && this.NonSensitive == doc.NonSensitive + && this.Sensitive_StringFormat == doc.Sensitive_StringFormat + && this.Sensitive_DateFormat == doc.Sensitive_DateFormat + && this.Sensitive_DecimalFormat == doc.Sensitive_DecimalFormat + && this.Sensitive_IntFormat == doc.Sensitive_IntFormat + && this.Sensitive_ArrayFormat == doc.Sensitive_ArrayFormat + && this.Sensitive_BoolFormat == doc.Sensitive_BoolFormat + && this.Sensitive_FloatFormat == doc.Sensitive_FloatFormat + && this.Sensitive_NestedObjectFormatL1 != doc.Sensitive_NestedObjectFormatL1; + } + + public bool EqualsExceptEncryptedProperty(object 
obj) + { + return obj is TestDoc doc + && this.Id == doc.Id + && this.PK == doc.PK + && this.NonSensitive == doc.NonSensitive + && this.Sensitive_StringFormat != doc.Sensitive_StringFormat; + } + + public override int GetHashCode() + { + int hashCode = 1652434776; + hashCode = (hashCode * -1521134295) + EqualityComparer<string>.Default.GetHashCode(this.Id); + hashCode = (hashCode * -1521134295) + EqualityComparer<string>.Default.GetHashCode(this.PK); + hashCode = (hashCode * -1521134295) + EqualityComparer<string>.Default.GetHashCode(this.NonSensitive); + hashCode = (hashCode * -1521134295) + EqualityComparer<string>.Default.GetHashCode(this.Sensitive_StringFormat); + hashCode = (hashCode * -1521134295) + EqualityComparer<DateTime>.Default.GetHashCode(this.Sensitive_DateFormat); + hashCode = (hashCode * -1521134295) + EqualityComparer<decimal>.Default.GetHashCode(this.Sensitive_DecimalFormat); + hashCode = (hashCode * -1521134295) + EqualityComparer<int>.Default.GetHashCode(this.Sensitive_IntFormat); + hashCode = (hashCode * -1521134295) + EqualityComparer<Sensitive_ArrayData[]>.Default.GetHashCode(this.Sensitive_ArrayFormat); + hashCode = (hashCode * -1521134295) + EqualityComparer<bool>.Default.GetHashCode(this.Sensitive_BoolFormat); + hashCode = (hashCode * -1521134295) + EqualityComparer<float>.Default.GetHashCode(this.Sensitive_FloatFormat); + hashCode = (hashCode * -1521134295) + EqualityComparer<Sensitive_NestedObjectL1>.Default.GetHashCode(this.Sensitive_NestedObjectFormatL1); + return hashCode; + } + + public static TestDoc Create(string partitionKey = null) + { + return new TestDoc() + { + Id = Guid.NewGuid().ToString(), + PK = partitionKey ?? 
Guid.NewGuid().ToString(), + NonSensitive = Guid.NewGuid().ToString(), + Sensitive_StringFormat = Guid.NewGuid().ToString(), + Sensitive_DateFormat = new DateTime(1987, 12, 25), + Sensitive_DecimalFormat = 472.3108m, + Sensitive_IntFormat = 1965, + Sensitive_BoolFormat = true, + Sensitive_FloatFormat = 8923.124f, + Sensitive_ArrayFormat = new Sensitive_ArrayData[] + { + new() { + Sensitive_ArrayIntFormat = 1999, + Sensitive_ArrayDecimalFormat = 472.3199m + } + }, + Sensitive_NestedObjectFormatL1 = new Sensitive_NestedObjectL1() + { + Sensitive_IntFormatL1 = 1999, + Sensitive_NestedObjectFormatL2 = new Sensitive_NestedObjectL2() + { + Sensitive_IntFormatL2 = 2000, + } + } + }; + } + + public Stream ToStream() + { + return TestCommon.ToStream(this); + } + } + + private class TestEncryptionKeyStoreProvider : EncryptionKeyStoreProvider + { + readonly Dictionary<string, int> keyinfo = new() + { + {masterKeyUri1.ToString(), 1}, + {masterKeyUri2.ToString(), 2}, + }; + + public Dictionary<string, int> WrapKeyCallsCount { get; set; } + public Dictionary<string, int> UnWrapKeyCallsCount { get; set; } + + public TestEncryptionKeyStoreProvider() + { + this.WrapKeyCallsCount = new Dictionary<string, int>(); + this.UnWrapKeyCallsCount = new Dictionary<string, int>(); + } + + public override string ProviderName => "TESTKEYSTORE_VAULT"; + + public override byte[] UnwrapKey(string masterKeyPath, KeyEncryptionKeyAlgorithm encryptionAlgorithm, byte[] encryptedKey) + { + this.UnWrapKeyCallsCount[masterKeyPath] = !this.UnWrapKeyCallsCount.TryGetValue(masterKeyPath, out int value) ? 1 : ++value; + + this.keyinfo.TryGetValue(masterKeyPath, out int moveBy); + byte[] plainkey = encryptedKey.Select(b => (byte)(b - moveBy)).ToArray(); + return plainkey; + } + + public override byte[] WrapKey(string masterKeyPath, KeyEncryptionKeyAlgorithm encryptionAlgorithm, byte[] key) + { + this.WrapKeyCallsCount[masterKeyPath] = !this.WrapKeyCallsCount.TryGetValue(masterKeyPath, out int value) ? 
1 : ++value; + + this.keyinfo.TryGetValue(masterKeyPath, out int moveBy); + byte[] encryptedkey = key.Select(b => (byte)(b + moveBy)).ToArray(); + return encryptedkey; + } + + public override byte[] Sign(string masterKeyPath, bool allowEnclaveComputations) + { + byte[] rawKey = new byte[32]; + SecurityUtility.GenerateRandomBytes(rawKey); + return rawKey; + } + + public override bool Verify(string masterKeyPath, bool allowEnclaveComputations, byte[] signature) + { + return true; + } + } + + // This class is same as CosmosEncryptor but copied so as to induce decryption failure easily for testing. + private class TestEncryptor : Encryptor + { + public DataEncryptionKeyProvider DataEncryptionKeyProvider { get; } + public bool FailDecryption { get; set; } + + private readonly CosmosEncryptor encryptor; + + public TestEncryptor(DataEncryptionKeyProvider dataEncryptionKeyProvider) + { + this.encryptor = new CosmosEncryptor(dataEncryptionKeyProvider); + this.FailDecryption = false; + } + + private void ThrowIfFail(string dataEncryptionKeyId) + { + if (this.FailDecryption && dataEncryptionKeyId.Equals("failDek")) + { + throw new InvalidOperationException($"Null {nameof(DataEncryptionKey)} returned."); + } + } + + public override async Task DecryptAsync( + byte[] cipherText, + string dataEncryptionKeyId, + string encryptionAlgorithm, + CancellationToken cancellationToken = default) + { + this.ThrowIfFail(dataEncryptionKeyId); + return await this.encryptor.DecryptAsync(cipherText, dataEncryptionKeyId, encryptionAlgorithm, cancellationToken); + } + + public override async Task EncryptAsync( + byte[] plainText, + string dataEncryptionKeyId, + string encryptionAlgorithm, + CancellationToken cancellationToken = default) + { + this.ThrowIfFail(dataEncryptionKeyId); + return await this.encryptor.EncryptAsync(plainText, dataEncryptionKeyId, encryptionAlgorithm, cancellationToken); + } + + public override async Task GetEncryptionKeyAsync(string dataEncryptionKeyId, string 
encryptionAlgorithm, CancellationToken cancellationToken = default) + { + this.ThrowIfFail(dataEncryptionKeyId); + return await this.encryptor.GetEncryptionKeyAsync(dataEncryptionKeyId, encryptionAlgorithm, cancellationToken); + } + } + + public static IEnumerable ProcessorAndCompressorCombinations => new[] { + new object[] { JsonProcessor.Newtonsoft, CompressionOptions.CompressionAlgorithm.None }, + new object[] { JsonProcessor.Stream, CompressionOptions.CompressionAlgorithm.None }, + new object[] { JsonProcessor.Newtonsoft, CompressionOptions.CompressionAlgorithm.Brotli }, + new object[] { JsonProcessor.Stream, CompressionOptions.CompressionAlgorithm.Brotli }, + }; + + #region Legacy +#pragma warning disable CS0618 // Type or member is obsolete + [TestMethod] + public async Task EncryptionCreateDekWithDualDekProvider() + { + string dekId = "dekWithDualDekProviderNewAlgo"; + DataEncryptionKeyProperties dekProperties = await CreateDekAsync(dualDekProvider, dekId); + Assert.AreEqual( + new EncryptionKeyWrapMetadata(name: "metadata1", value: metadata1.Value), + dekProperties.EncryptionKeyWrapMetadata); + + // Use different DEK provider to avoid (unintentional) cache impact + CosmosDataEncryptionKeyProvider dekProvider = new(new TestKeyWrapProvider(), new TestEncryptionKeyStoreProvider(), TimeSpan.FromMinutes(30)); + await dekProvider.InitializeAsync(database, keyContainer.Id); + DataEncryptionKeyProperties readProperties = await dekProvider.DataEncryptionKeyContainer.ReadDataEncryptionKeyAsync(dekId); + Assert.AreEqual(dekProperties, readProperties); + + dekId = "dekWithDualDekProviderLegacyAlgo"; + dekProperties = await CreateLegacyDekAsync(dualDekProvider, dekId); + Assert.AreEqual( + new EncryptionKeyWrapMetadata(metadata1.Value + metadataUpdateSuffix), + dekProperties.EncryptionKeyWrapMetadata); + + readProperties = await dekProvider.DataEncryptionKeyContainer.ReadDataEncryptionKeyAsync(dekId); + Assert.AreEqual(dekProperties, readProperties); + } + + 
[TestMethod] + public async Task EncryptionCreateDekWithNonMdeAlgorithmFails() + { + string dekId = "oldDek"; + TestEncryptionKeyStoreProvider testKeyStoreProvider = new() + { + DataEncryptionKeyCacheTimeToLive = TimeSpan.FromSeconds(3600) + }; + + CosmosDataEncryptionKeyProvider dekProvider = new(testKeyStoreProvider); + try + { + await CreateDekAsync(dekProvider, dekId, CosmosEncryptionAlgorithm.AEAes256CbcHmacSha256Randomized); + Assert.Fail("CreateDataEncryptionKeyAsync should not have succeeded. "); + } + catch (InvalidOperationException ex) + { + Assert.AreEqual("For use of 'AEAes256CbcHmacSha256Randomized' algorithm, Encryptor or CosmosDataEncryptionKeyProvider needs to be initialized with EncryptionKeyWrapProvider.", ex.Message); + } + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionCreateItemWithIncompatibleWrapProvider(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + Container legacyEncryptionContainer; + CosmosDataEncryptionKeyProvider legacydekProvider = new(new TestKeyWrapProvider()); + await legacydekProvider.InitializeAsync(database, keyContainer.Id); + TestEncryptor legacyEncryptor = new(legacydekProvider); + legacyEncryptionContainer = itemContainer.WithEncryptor(legacyEncryptor); + TestDoc testDoc = TestDoc.Create(null); + + try + { + ItemResponse createResponse = await legacyEncryptionContainer.CreateItemAsync( + testDoc, + new PartitionKey(testDoc.PK), + GetRequestOptions(dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm, legacyAlgo: true)); + Assert.Fail("CreateItemAsync should not have succeeded. 
"); + } + catch (InvalidOperationException ex) + { + Assert.AreEqual("For use of 'MdeAeadAes256CbcHmac256Randomized' algorithm based DEK, Encryptor or CosmosDataEncryptionKeyProvider needs to be initialized with EncryptionKeyStoreProvider.", ex.Message); + } + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionCreateItemUsingLegacyAlgoWithMdeDek(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + TestDoc testDoc = await CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm, legacyAlgo: true); + await VerifyItemByReadAsync(encryptionContainer, testDoc, dekId: dekId); + } + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionCreateItemUsingMDEAlgoWithLegacyDek(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + CosmosDataEncryptionKeyProvider legacydekProvider = new(new TestKeyWrapProvider()); + await legacydekProvider.InitializeAsync(database, keyContainer.Id); + + TestDoc testDoc = TestDoc.Create(null); + + ItemResponse createResponse = await encryptionContainer.CreateItemAsync( + testDoc, + new PartitionKey(testDoc.PK), + GetRequestOptions(legacydekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm, legacyAlgo: false)); + + VerifyExpectedDocResponse(testDoc, createResponse); + + await VerifyItemByReadAsync(encryptionContainer, testDoc, dekId: legacydekId); + } + + + [TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task EncryptionRewrapLegacyDekToMdeWrap(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + string dekId = "rewrapLegacyAlgoDektoMdeAlgoDek"; + DataEncryptionKeyProperties dataEncryptionKeyProperties; + + dataEncryptionKeyProperties = await CreateLegacyDekAsync(dualDekProvider, dekId); + + Assert.AreEqual( + 
metadata1.Value + metadataUpdateSuffix, + dataEncryptionKeyProperties.EncryptionKeyWrapMetadata.Value); + + Assert.AreEqual(CosmosEncryptionAlgorithm.AEAes256CbcHmacSha256Randomized, dataEncryptionKeyProperties.EncryptionAlgorithm); + + // use it to create item with Legacy Algo + TestDoc testDoc = await CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm, legacyAlgo: true); + + await VerifyItemByReadAsync(encryptionContainer, testDoc, dekId: dekId); + + // validate key with new Algo + testDoc = await CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + + await VerifyItemByReadAsync(encryptionContainer, testDoc, dekId: dekId); + + ItemResponse dekResponse = await MdeCustomEncryptionTestsWithSystemText.dekProvider.DataEncryptionKeyContainer.RewrapDataEncryptionKeyAsync( + dekId, + metadata2, + CosmosEncryptionAlgorithm.MdeAeadAes256CbcHmac256Randomized); + + Assert.AreEqual(HttpStatusCode.OK, dekResponse.StatusCode); + + dataEncryptionKeyProperties = VerifyDekResponse( + dekResponse, + dekId); + + Assert.AreEqual(CosmosEncryptionAlgorithm.MdeAeadAes256CbcHmac256Randomized, dataEncryptionKeyProperties.EncryptionAlgorithm); + + Assert.AreEqual( + metadata2, + dataEncryptionKeyProperties.EncryptionKeyWrapMetadata); + + // Use different DEK provider to avoid (unintentional) cache impact + CosmosDataEncryptionKeyProvider dekProvider = new(new TestEncryptionKeyStoreProvider()); + await dekProvider.InitializeAsync(database, keyContainer.Id); + DataEncryptionKeyProperties readProperties = await dekProvider.DataEncryptionKeyContainer.ReadDataEncryptionKeyAsync(dekId); + Assert.AreEqual(dataEncryptionKeyProperties, readProperties); + + // validate key + testDoc = await CreateItemAsync(encryptionContainer, dekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + + await VerifyItemByReadAsync(encryptionContainer, testDoc, dekId: dekId); + + await 
dekProvider.Container.DeleteItemAsync(dekId, new PartitionKey(dekId)); + + // rewrap from Mde Algo to Legacy algo should fail + dekId = "rewrapMdeAlgoDekToLegacyAlgoDek"; + + DataEncryptionKeyProperties dekProperties = await CreateDekAsync(MdeCustomEncryptionTestsWithSystemText.dekProvider, dekId); + Assert.AreEqual( + metadata1, + dekProperties.EncryptionKeyWrapMetadata); + + try + { + await MdeCustomEncryptionTestsWithSystemText.dekProvider.DataEncryptionKeyContainer.RewrapDataEncryptionKeyAsync( + dekId, + metadata2, + CosmosEncryptionAlgorithm.AEAes256CbcHmacSha256Randomized); + + Assert.Fail("RewrapDataEncryptionKeyAsync should not have succeeded. "); + } + catch (InvalidOperationException ex) + { + Assert.AreEqual("Rewrap operation with EncryptionAlgorithm 'AEAes256CbcHmacSha256Randomized' is not supported on Data Encryption Keys which are configured with 'MdeAeadAes256CbcHmac256Randomized'. ", ex.Message); + } + + await dekProvider.Container.DeleteItemAsync(dekId, new PartitionKey(dekId)); + // rewrap Mde to Mde with Option + + // rewrap from Mde Algo to Legacy algo should fail + dekId = "rewrapMdeAlgoDekToMdeAlgoDek"; + + dekProperties = await CreateDekAsync(MdeCustomEncryptionTestsWithSystemText.dekProvider, dekId); + Assert.AreEqual( + metadata1, + dekProperties.EncryptionKeyWrapMetadata); + + dekResponse = await MdeCustomEncryptionTestsWithSystemText.dekProvider.DataEncryptionKeyContainer.RewrapDataEncryptionKeyAsync( + dekId, + metadata2, + CosmosEncryptionAlgorithm.MdeAeadAes256CbcHmac256Randomized); + + Assert.AreEqual(HttpStatusCode.OK, dekResponse.StatusCode); + + dataEncryptionKeyProperties = VerifyDekResponse( + dekResponse, + dekId); + + Assert.AreEqual(CosmosEncryptionAlgorithm.MdeAeadAes256CbcHmac256Randomized, dataEncryptionKeyProperties.EncryptionAlgorithm); + + Assert.AreEqual( + metadata2, + dataEncryptionKeyProperties.EncryptionKeyWrapMetadata); + + await dekProvider.Container.DeleteItemAsync(dekId, new PartitionKey(dekId)); + } + + + 
[TestMethod] + [DynamicData(nameof(ProcessorAndCompressorCombinations))] + public async Task ReadLegacyEncryptedDataWithMdeProcessor(JsonProcessor jsonProcessor, CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + // Setup the Container with a Dual Wrap Provider Container. + encryptionContainer = itemContainer.WithEncryptor(encryptorWithDualWrapProvider); + + TestDoc testDoc = await CreateItemAsyncUsingLegacyAlgorithm(encryptionContainer, legacydekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + + await VerifyItemByReadAsync(encryptionContainer, testDoc, dekId: legacydekId); + + await VerifyItemByReadStreamAsync(encryptionContainer, testDoc); + + TestDoc expectedDoc = new(testDoc); + + // Read feed (null query) + await MdeCustomEncryptionTestsWithSystemText.ValidateQueryResultsAsync( + MdeCustomEncryptionTestsWithSystemText.encryptionContainer, + query: null, + expectedDoc, + legacyAlgo: true); + + await ValidateQueryResultsAsync( + encryptionContainer, + "SELECT * FROM c", + expectedDoc, + legacyAlgo: true); + + await ValidateQueryResultsAsync( + encryptionContainer, + string.Format( + "SELECT * FROM c where c.PK = '{0}' and c.id = '{1}' and c.NonSensitive = '{2}'", + expectedDoc.PK, + expectedDoc.Id, + expectedDoc.NonSensitive), + expectedDoc, + legacyAlgo: true); + + await ValidateQueryResultsAsync( + encryptionContainer, + string.Format("SELECT * FROM c where c.Sensitive_IntFormat = '{0}'", testDoc.Sensitive_StringFormat), + expectedDoc: null, + legacyAlgo: true); + + await ValidateQueryResultsAsync( + encryptionContainer, + queryDefinition: new QueryDefinition( + "select * from c where c.id = @theId and c.PK = @thePK") + .WithParameter("@theId", expectedDoc.Id) + .WithParameter("@thePK", expectedDoc.PK), + expectedDoc: expectedDoc, + legacyAlgo: true); + + expectedDoc.Sensitive_NestedObjectFormatL1 = null; + expectedDoc.Sensitive_ArrayFormat = null; + expectedDoc.Sensitive_DecimalFormat = 0; + expectedDoc.Sensitive_IntFormat 
= 0; + expectedDoc.Sensitive_FloatFormat = 0; + expectedDoc.Sensitive_BoolFormat = false; + expectedDoc.Sensitive_StringFormat = null; + expectedDoc.Sensitive_DateFormat = new DateTime(); + + await ValidateQueryResultsAsync( + encryptionContainer, + "SELECT c.id, c.PK, c.NonSensitive FROM c", + expectedDoc); + + // create Items with New Algorithm + await this.EncryptionCreateItem(jsonProcessor, compressionAlgorithm); + + // read back Data Items encrypted with Old Algorithm + await VerifyItemByReadAsync(encryptionContainer, testDoc, dekId: legacydekId); + + await VerifyItemByReadStreamAsync(encryptionContainer, testDoc); + + // Create and read back Data Items encrypted with Old Algorithm + TestDoc testDoc2 = await CreateItemAsyncUsingLegacyAlgorithm(encryptionContainer, legacydekId, TestDoc.PathsToEncrypt, jsonProcessor, compressionAlgorithm); + + await VerifyItemByReadAsync(encryptionContainer, testDoc2, dekId: legacydekId); + + await VerifyItemByReadStreamAsync(encryptionContainer, testDoc2); + + // create Items with New Algorithm + await this.EncryptionCreateItem(jsonProcessor, compressionAlgorithm); + + // read back Data Items encrypted with Old Algorithm + await VerifyItemByReadAsync(encryptionContainer, testDoc2, dekId: legacydekId); + + await VerifyItemByReadStreamAsync(encryptionContainer, testDoc2); + + // Reset the Container for Other Tests to be carried on regular Encryptor with Single Dek Provider. 
+ encryptionContainer = itemContainer.WithEncryptor(encryptor); + } + + + private static async Task> CreateItemAsyncUsingLegacyAlgorithm( + Container container, + string dekId, + List pathsToEncrypt, + JsonProcessor jsonProcessor, + CompressionOptions.CompressionAlgorithm compressionAlgorithm, + string partitionKey = null) + { + TestDoc testDoc = TestDoc.Create(partitionKey); + ItemResponse createResponse = await container.CreateItemAsync( + testDoc, + new PartitionKey(testDoc.PK), + GetRequestOptions(dekId, pathsToEncrypt, jsonProcessor, compressionAlgorithm, legacyAlgo: true)); + Assert.AreEqual(HttpStatusCode.Created, createResponse.StatusCode); + + VerifyExpectedDocResponse(testDoc, createResponse.Resource); + + return createResponse; + } + + private static async Task LegacyClassInitializeAsync() + { + MdeCustomEncryptionTestsWithSystemText.testKeyStoreProvider.DataEncryptionKeyCacheTimeToLive = TimeSpan.FromSeconds(3600); + + dekProvider = new CosmosDataEncryptionKeyProvider(new TestKeyWrapProvider(), MdeCustomEncryptionTestsWithSystemText.testKeyStoreProvider); + legacytestKeyWrapProvider = new TestKeyWrapProvider(); + + TestEncryptionKeyStoreProvider testKeyStoreProvider = new() + { + DataEncryptionKeyCacheTimeToLive = TimeSpan.Zero + }; + dualDekProvider = new CosmosDataEncryptionKeyProvider(legacytestKeyWrapProvider, testKeyStoreProvider); + await dualDekProvider.InitializeAsync(database, keyContainer.Id); + + _ = await CreateLegacyDekAsync(MdeCustomEncryptionTestsWithSystemText.dualDekProvider, MdeCustomEncryptionTestsWithSystemText.legacydekId); + encryptorWithDualWrapProvider = new TestEncryptor(dualDekProvider); + } + + private static EncryptionOptions GetLegacyEncryptionOptions( + string dekId, + List pathsToEncrypt, + JsonProcessor jsonProcessor, + CompressionOptions.CompressionAlgorithm compressionAlgorithm) + { + return new EncryptionOptions() + { + DataEncryptionKeyId = dekId, + EncryptionAlgorithm = 
CosmosEncryptionAlgorithm.AEAes256CbcHmacSha256Randomized, + PathsToEncrypt = pathsToEncrypt, + JsonProcessor = jsonProcessor, + CompressionOptions = new CompressionOptions() { Algorithm = compressionAlgorithm } + }; + } + + private static async Task CreateLegacyDekAsync(CosmosDataEncryptionKeyProvider dekProvider, string dekId, string algorithm = null) + { + ItemResponse dekResponse = await dekProvider.DataEncryptionKeyContainer.CreateDataEncryptionKeyAsync( + dekId, + algorithm ?? CosmosEncryptionAlgorithm.AEAes256CbcHmacSha256Randomized, + metadata1); + + Assert.AreEqual(HttpStatusCode.Created, dekResponse.StatusCode); + + return VerifyDekResponse(dekResponse, + dekId); + } + + + private class TestKeyWrapProvider : EncryptionKeyWrapProvider + { + public Dictionary WrapKeyCallsCount { get; private set; } + + public TestKeyWrapProvider() + { + this.WrapKeyCallsCount = new Dictionary(); + } + + public override Task UnwrapKeyAsync(byte[] wrappedKey, EncryptionKeyWrapMetadata metadata, CancellationToken cancellationToken) + { + int moveBy = metadata.Value == metadata1.Value + metadataUpdateSuffix ? 1 : 2; + return Task.FromResult(new EncryptionKeyUnwrapResult(wrappedKey.Select(b => (byte)(b - moveBy)).ToArray(), cacheTTL)); + } + + public override Task WrapKeyAsync(byte[] key, EncryptionKeyWrapMetadata metadata, CancellationToken cancellationToken) + { + this.WrapKeyCallsCount[metadata.Value] = !this.WrapKeyCallsCount.TryGetValue(metadata.Value, out int value) ? 1 : ++value; + + EncryptionKeyWrapMetadata responseMetadata = new(metadata.Value + metadataUpdateSuffix); + int moveBy = metadata.Value == metadata1.Value ? 
1 : 2; + return Task.FromResult(new EncryptionKeyWrapResult(key.Select(b => (byte)(b + moveBy)).ToArray(), responseMetadata)); + } + } + +#pragma warning restore CS0618 // Type or member is obsolete + #endregion + } +} +#endif \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/tests/EmulatorTests/Microsoft.Azure.Cosmos.Encryption.Custom.EmulatorTests.csproj b/Microsoft.Azure.Cosmos.Encryption.Custom/tests/EmulatorTests/Microsoft.Azure.Cosmos.Encryption.Custom.EmulatorTests.csproj index bab1697c57..c2e5ee53b5 100644 --- a/Microsoft.Azure.Cosmos.Encryption.Custom/tests/EmulatorTests/Microsoft.Azure.Cosmos.Encryption.Custom.EmulatorTests.csproj +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/tests/EmulatorTests/Microsoft.Azure.Cosmos.Encryption.Custom.EmulatorTests.csproj @@ -12,6 +12,7 @@ master True $(LangVersion) + $(DefineConstants);ENCRYPTION_CUSTOM_PREVIEW diff --git a/Microsoft.Azure.Cosmos.Encryption.Custom/tests/Microsoft.Azure.Cosmos.Encryption.Custom.Performance.Tests/Readme.md b/Microsoft.Azure.Cosmos.Encryption.Custom/tests/Microsoft.Azure.Cosmos.Encryption.Custom.Performance.Tests/Readme.md index 691dbf0a5d..a944d8cfa5 100644 --- a/Microsoft.Azure.Cosmos.Encryption.Custom/tests/Microsoft.Azure.Cosmos.Encryption.Custom.Performance.Tests/Readme.md +++ b/Microsoft.Azure.Cosmos.Encryption.Custom/tests/Microsoft.Azure.Cosmos.Encryption.Custom.Performance.Tests/Readme.md @@ -9,53 +9,53 @@ Job=MediumRun Toolchain=InProcessEmitToolchain IterationCount=15 LaunchCount=2 WarmupCount=10 ``` -| Method | DocumentSizeInKb | CompressionAlgorithm | JsonProcessor | Mean | Error | StdDev | Median | Gen0 | Gen1 | Gen2 | Allocated | -|------------------------ |----------------- |--------------------- |--------------- |------------:|----------:|----------:|------------:|--------:|--------:|--------:|----------:| -| **Encrypt** | **1** | **None** | **Newtonsoft** | **22.53 μs** | **0.511 μs** | **0.733 μs** | **22.29 μs** | **0.1526** | **0.0305** 
| **-** | **41784 B** | -| EncryptToProvidedStream | 1 | None | Newtonsoft | NA | NA | NA | NA | - | - | - | - | -| Decrypt | 1 | None | Newtonsoft | 26.31 μs | 0.224 μs | 0.322 μs | 26.23 μs | 0.1526 | 0.0305 | - | 41440 B | -| DecryptToProvidedStream | 1 | None | Newtonsoft | NA | NA | NA | NA | - | - | - | - | -| **Encrypt** | **1** | **None** | **Stream** | **12.85 μs** | **0.095 μs** | **0.143 μs** | **12.84 μs** | **0.0610** | **0.0153** | **-** | **17528 B** | -| EncryptToProvidedStream | 1 | None | Stream | 13.00 μs | 0.096 μs | 0.141 μs | 12.98 μs | 0.0458 | 0.0153 | - | 11392 B | -| Decrypt | 1 | None | Stream | 13.01 μs | 0.152 μs | 0.228 μs | 13.05 μs | 0.0458 | 0.0153 | - | 12672 B | -| DecryptToProvidedStream | 1 | None | Stream | 13.48 μs | 0.132 μs | 0.197 μs | 13.45 μs | 0.0458 | 0.0153 | - | 11504 B | -| **Encrypt** | **1** | **Brotli** | **Newtonsoft** | **27.94 μs** | **0.226 μs** | **0.338 μs** | **27.96 μs** | **0.1526** | **0.0305** | **-** | **38064 B** | -| EncryptToProvidedStream | 1 | Brotli | Newtonsoft | NA | NA | NA | NA | - | - | - | - | -| Decrypt | 1 | Brotli | Newtonsoft | 33.49 μs | 0.910 μs | 1.335 μs | 33.99 μs | 0.1221 | - | - | 41064 B | -| DecryptToProvidedStream | 1 | Brotli | Newtonsoft | NA | NA | NA | NA | - | - | - | - | -| **Encrypt** | **1** | **Brotli** | **Stream** | **21.15 μs** | **1.037 μs** | **1.521 μs** | **20.52 μs** | **0.0610** | **0.0305** | **-** | **16584 B** | -| EncryptToProvidedStream | 1 | Brotli | Stream | 20.57 μs | 0.213 μs | 0.292 μs | 20.57 μs | 0.0305 | - | - | 11672 B | -| Decrypt | 1 | Brotli | Stream | 21.14 μs | 2.212 μs | 3.311 μs | 19.46 μs | 0.0305 | - | - | 13216 B | -| DecryptToProvidedStream | 1 | Brotli | Stream | 19.60 μs | 0.439 μs | 0.600 μs | 19.52 μs | 0.0305 | - | - | 12048 B | -| **Encrypt** | **10** | **None** | **Newtonsoft** | **84.82 μs** | **3.002 μs** | **4.208 μs** | **83.32 μs** | **0.6104** | **0.1221** | **-** | **170993 B** | -| EncryptToProvidedStream | 10 | None | 
Newtonsoft | NA | NA | NA | NA | - | - | - | - | -| Decrypt | 10 | None | Newtonsoft | 112.98 μs | 15.294 μs | 21.934 μs | 100.38 μs | 0.6104 | 0.1221 | - | 157425 B | -| DecryptToProvidedStream | 10 | None | Newtonsoft | NA | NA | NA | NA | - | - | - | - | -| **Encrypt** | **10** | **None** | **Stream** | **39.63 μs** | **0.658 μs** | **0.923 μs** | **39.41 μs** | **0.3052** | **0.0610** | **-** | **82928 B** | -| EncryptToProvidedStream | 10 | None | Stream | 36.59 μs | 0.272 μs | 0.399 μs | 36.57 μs | 0.1221 | - | - | 37048 B | -| Decrypt | 10 | None | Stream | 28.64 μs | 0.378 μs | 0.517 μs | 28.59 μs | 0.1221 | 0.0305 | - | 29520 B | -| DecryptToProvidedStream | 10 | None | Stream | 27.61 μs | 0.237 μs | 0.332 μs | 27.64 μs | 0.0610 | 0.0305 | - | 18416 B | -| **Encrypt** | **10** | **Brotli** | **Newtonsoft** | **115.28 μs** | **3.336 μs** | **4.677 μs** | **113.71 μs** | **0.6104** | **0.1221** | **-** | **168065 B** | -| EncryptToProvidedStream | 10 | Brotli | Newtonsoft | NA | NA | NA | NA | - | - | - | - | -| Decrypt | 10 | Brotli | Newtonsoft | 118.98 μs | 1.530 μs | 2.195 μs | 118.76 μs | 0.4883 | - | - | 144849 B | -| DecryptToProvidedStream | 10 | Brotli | Newtonsoft | NA | NA | NA | NA | - | - | - | - | -| **Encrypt** | **10** | **Brotli** | **Stream** | **90.10 μs** | **3.136 μs** | **4.693 μs** | **88.92 μs** | **0.2441** | **-** | **-** | **63809 B** | -| EncryptToProvidedStream | 10 | Brotli | Stream | 97.27 μs | 1.885 μs | 2.703 μs | 97.35 μs | 0.1221 | - | - | 32465 B | -| Decrypt | 10 | Brotli | Stream | 58.48 μs | 0.956 μs | 1.372 μs | 58.59 μs | 0.1221 | 0.0610 | - | 30064 B | -| DecryptToProvidedStream | 10 | Brotli | Stream | 59.12 μs | 1.160 μs | 1.664 μs | 59.14 μs | 0.0610 | - | - | 18960 B | -| **Encrypt** | **100** | **None** | **Newtonsoft** | **1,199.74 μs** | **42.805 μs** | **64.069 μs** | **1,206.48 μs** | **23.4375** | **21.4844** | **21.4844** | **1677978 B** | -| EncryptToProvidedStream | 100 | None | Newtonsoft | NA | NA | NA 
| NA | - | - | - | - | -| Decrypt | 100 | None | Newtonsoft | 1,177.48 μs | 25.746 μs | 38.535 μs | 1,172.04 μs | 17.5781 | 15.6250 | 15.6250 | 1260228 B | -| DecryptToProvidedStream | 100 | None | Newtonsoft | NA | NA | NA | NA | - | - | - | - | -| **Encrypt** | **100** | **None** | **Stream** | **636.72 μs** | **31.468 μs** | **47.099 μs** | **630.15 μs** | **16.6016** | **16.6016** | **16.6016** | **678066 B** | -| EncryptToProvidedStream | 100 | None | Stream | 383.33 μs | 7.441 μs | 10.671 μs | 384.69 μs | 4.3945 | 4.3945 | 4.3945 | 230133 B | -| Decrypt | 100 | None | Stream | 384.93 μs | 12.519 μs | 18.738 μs | 383.59 μs | 5.8594 | 5.8594 | 5.8594 | 230753 B | -| DecryptToProvidedStream | 100 | None | Stream | 295.19 μs | 7.094 μs | 10.618 μs | 296.11 μs | 3.4180 | 3.4180 | 3.4180 | 119116 B | -| **Encrypt** | **100** | **Brotli** | **Newtonsoft** | **1,178.06 μs** | **63.246 μs** | **94.664 μs** | **1,152.03 μs** | **13.6719** | **11.7188** | **9.7656** | **1379183 B** | -| EncryptToProvidedStream | 100 | Brotli | Newtonsoft | NA | NA | NA | NA | - | - | - | - | -| Decrypt | 100 | Brotli | Newtonsoft | 1,175.01 μs | 41.917 μs | 61.441 μs | 1,156.01 μs | 11.7188 | 9.7656 | 9.7656 | 1124274 B | -| DecryptToProvidedStream | 100 | Brotli | Newtonsoft | NA | NA | NA | NA | - | - | - | - | -| **Encrypt** | **100** | **Brotli** | **Stream** | **757.11 μs** | **19.549 μs** | **29.260 μs** | **754.55 μs** | **10.7422** | **10.7422** | **10.7422** | **479493 B** | -| EncryptToProvidedStream | 100 | Brotli | Stream | 563.46 μs | 9.960 μs | 14.284 μs | 561.60 μs | 2.9297 | 2.9297 | 2.9297 | 180637 B | -| Decrypt | 100 | Brotli | Stream | 542.34 μs | 14.514 μs | 21.724 μs | 542.04 μs | 6.8359 | 6.8359 | 6.8359 | 231162 B | -| DecryptToProvidedStream | 100 | Brotli | Stream | 463.69 μs | 9.130 μs | 12.800 μs | 460.71 μs | 3.4180 | 3.4180 | 3.4180 | 119506 B | +| Method | DocumentSizeInKb | CompressionAlgorithm | JsonProcessor | Mean | Error | StdDev | Gen0 | Gen1 | Gen2 
| Allocated | +|------------------------ |----------------- |--------------------- |-------------- |------------:|----------:|----------:|--------:|--------:|--------:|-----------:| +| **Encrypt** | **1** | **None** | **Newtonsoft** | **22.97 μs** | **0.545 μs** | **0.816 μs** | **0.1526** | **0.0305** | **-** | **41.2 KB** | +| EncryptToProvidedStream | 1 | None | Newtonsoft | 21.46 μs | 0.105 μs | 0.153 μs | 0.1221 | 0.0305 | - | 34.13 KB | +| Decrypt | 1 | None | Newtonsoft | 27.28 μs | 0.149 μs | 0.219 μs | 0.1526 | 0.0305 | - | 40.84 KB | +| DecryptToProvidedStream | 1 | None | Newtonsoft | 34.38 μs | 1.292 μs | 1.934 μs | 0.1221 | - | - | 42.3 KB | +| **Encrypt** | **1** | **None** | **Stream** | **12.10 μs** | **0.081 μs** | **0.116 μs** | **0.0610** | **0.0153** | **-** | **17.35 KB** | +| EncryptToProvidedStream | 1 | None | Stream | 12.72 μs | 0.829 μs | 1.215 μs | 0.0458 | 0.0153 | - | 11.36 KB | +| Decrypt | 1 | None | Stream | 12.21 μs | 0.180 μs | 0.264 μs | 0.0458 | 0.0153 | - | 12.38 KB | +| DecryptToProvidedStream | 1 | None | Stream | 12.55 μs | 0.152 μs | 0.213 μs | 0.0458 | 0.0153 | - | 11.23 KB | +| **Encrypt** | **1** | **Brotli** | **Newtonsoft** | **28.73 μs** | **0.579 μs** | **0.830 μs** | **0.1526** | **0.0305** | **-** | **37.59 KB** | +| EncryptToProvidedStream | 1 | Brotli | Newtonsoft | 28.58 μs | 0.293 μs | 0.411 μs | 0.1221 | 0.0305 | - | 34.54 KB | +| Decrypt | 1 | Brotli | Newtonsoft | 35.35 μs | 0.894 μs | 1.337 μs | 0.1221 | - | - | 40.49 KB | +| DecryptToProvidedStream | 1 | Brotli | Newtonsoft | 38.17 μs | 0.409 μs | 0.574 μs | 0.1221 | - | - | 42.14 KB | +| **Encrypt** | **1** | **Brotli** | **Stream** | **19.77 μs** | **0.275 μs** | **0.395 μs** | **0.0610** | **0.0305** | **-** | **16.43 KB** | +| EncryptToProvidedStream | 1 | Brotli | Stream | 19.40 μs | 0.188 μs | 0.264 μs | 0.0305 | - | - | 11.63 KB | +| Decrypt | 1 | Brotli | Stream | 17.73 μs | 0.138 μs | 0.206 μs | 0.0305 | - | - | 12.65 KB | +| 
DecryptToProvidedStream | 1 | Brotli | Stream | 18.05 μs | 0.120 μs | 0.180 μs | 0.0305 | - | - | 11.51 KB | +| **Encrypt** | **10** | **None** | **Newtonsoft** | **84.60 μs** | **0.488 μs** | **0.699 μs** | **0.6104** | **0.1221** | **-** | **168.82 KB** | +| EncryptToProvidedStream | 10 | None | Newtonsoft | 82.21 μs | 0.199 μs | 0.272 μs | 0.4883 | - | - | 137.7 KB | +| Decrypt | 10 | None | Newtonsoft | 101.88 μs | 0.452 μs | 0.676 μs | 0.6104 | 0.1221 | - | 155.55 KB | +| DecryptToProvidedStream | 10 | None | Newtonsoft | 107.81 μs | 0.595 μs | 0.890 μs | 0.6104 | 0.1221 | - | 157.01 KB | +| **Encrypt** | **10** | **None** | **Stream** | **37.80 μs** | **0.181 μs** | **0.266 μs** | **0.3052** | **0.0610** | **-** | **81.22 KB** | +| EncryptToProvidedStream | 10 | None | Stream | 34.84 μs | 0.326 μs | 0.488 μs | 0.1221 | - | - | 36.41 KB | +| Decrypt | 10 | None | Stream | 26.40 μs | 0.164 μs | 0.245 μs | 0.1221 | 0.0305 | - | 28.83 KB | +| DecryptToProvidedStream | 10 | None | Stream | 25.85 μs | 0.175 μs | 0.262 μs | 0.0610 | 0.0305 | - | 17.98 KB | +| **Encrypt** | **10** | **Brotli** | **Newtonsoft** | **113.23 μs** | **0.688 μs** | **0.986 μs** | **0.6104** | **0.1221** | **-** | **165.98 KB** | +| EncryptToProvidedStream | 10 | Brotli | Newtonsoft | 111.05 μs | 0.535 μs | 0.801 μs | 0.4883 | - | - | 134.86 KB | +| Decrypt | 10 | Brotli | Newtonsoft | 122.44 μs | 1.023 μs | 1.499 μs | 0.4883 | - | - | 143.28 KB | +| DecryptToProvidedStream | 10 | Brotli | Newtonsoft | 127.32 μs | 0.892 μs | 1.308 μs | 0.4883 | - | - | 144.93 KB | +| **Encrypt** | **10** | **Brotli** | **Stream** | **84.20 μs** | **2.861 μs** | **4.193 μs** | **0.2441** | **-** | **-** | **62.55 KB** | +| EncryptToProvidedStream | 10 | Brotli | Stream | 92.70 μs | 1.253 μs | 1.876 μs | 0.1221 | - | - | 31.94 KB | +| Decrypt | 10 | Brotli | Stream | 54.23 μs | 0.528 μs | 0.775 μs | 0.1221 | - | - | 29.1 KB | +| DecryptToProvidedStream | 10 | Brotli | Stream | 54.34 μs | 0.505 μs | 0.756 μs | 
0.0610 | - | - | 18.26 KB | +| **Encrypt** | **100** | **None** | **Newtonsoft** | **1,074.94 μs** | **17.781 μs** | **26.614 μs** | **21.4844** | **19.5313** | **19.5313** | **1654.89 KB** | +| EncryptToProvidedStream | 100 | None | Newtonsoft | 908.83 μs | 44.365 μs | 62.193 μs | 11.7188 | 9.7656 | 9.7656 | 1143.65 KB | +| Decrypt | 100 | None | Newtonsoft | 1,126.75 μs | 21.460 μs | 32.120 μs | 17.5781 | 15.6250 | 15.6250 | 1246.93 KB | +| DecryptToProvidedStream | 100 | None | Newtonsoft | 1,183.13 μs | 19.585 μs | 29.314 μs | 15.6250 | 13.6719 | 13.6719 | 1248.4 KB | +| **Encrypt** | **100** | **None** | **Stream** | **513.68 μs** | **11.309 μs** | **16.927 μs** | **16.6016** | **16.6016** | **16.6016** | **662.42 KB** | +| EncryptToProvidedStream | 100 | None | Stream | 335.51 μs | 7.015 μs | 10.500 μs | 4.3945 | 4.3945 | 4.3945 | 224.97 KB | +| Decrypt | 100 | None | Stream | 310.34 μs | 5.028 μs | 7.525 μs | 6.3477 | 6.3477 | 6.3477 | 225.35 KB | +| DecryptToProvidedStream | 100 | None | Stream | 264.40 μs | 3.169 μs | 4.545 μs | 3.4180 | 3.4180 | 3.4180 | 116.32 KB | +| **Encrypt** | **100** | **Brotli** | **Newtonsoft** | **1,098.17 μs** | **10.860 μs** | **16.255 μs** | **13.6719** | **9.7656** | **9.7656** | **1363.12 KB** | +| EncryptToProvidedStream | 100 | Brotli | Newtonsoft | 1,012.03 μs | 9.265 μs | 13.581 μs | 7.8125 | 5.8594 | 5.8594 | 1107.87 KB | +| Decrypt | 100 | Brotli | Newtonsoft | 1,137.56 μs | 8.877 μs | 13.012 μs | 11.7188 | 9.7656 | 9.7656 | 1114.15 KB | +| DecryptToProvidedStream | 100 | Brotli | Newtonsoft | 1,160.69 μs | 9.399 μs | 13.777 μs | 11.7188 | 9.7656 | 9.7656 | 1115.79 KB | +| **Encrypt** | **100** | **Brotli** | **Stream** | **726.91 μs** | **10.086 μs** | **15.097 μs** | **11.7188** | **11.7188** | **11.7188** | **468.53 KB** | +| EncryptToProvidedStream | 100 | Brotli | Stream | 551.89 μs | 6.359 μs | 9.518 μs | 2.9297 | 2.9297 | 2.9297 | 176.64 KB | +| Decrypt | 100 | Brotli | Stream | 517.81 μs | 7.945 μs | 11.891 μs 
| 6.3477 | 6.3477 | 6.3477 | 225.62 KB | +| DecryptToProvidedStream | 100 | Brotli | Stream | 440.63 μs | 4.781 μs | 7.007 μs | 3.4180 | 3.4180 | 3.4180 | 116.6 KB | diff --git a/Microsoft.Azure.Cosmos/src/Serializer/CosmosSerializer.cs b/Microsoft.Azure.Cosmos/src/Serializer/CosmosSerializer.cs index e6bd34d1f5..01494e4b4d 100644 --- a/Microsoft.Azure.Cosmos/src/Serializer/CosmosSerializer.cs +++ b/Microsoft.Azure.Cosmos/src/Serializer/CosmosSerializer.cs @@ -5,6 +5,8 @@ namespace Microsoft.Azure.Cosmos { using System.IO; + using System.Threading; + using System.Threading.Tasks; /// /// This abstract class can be implemented to allow a custom serializer to be used by the CosmosClient. @@ -32,5 +34,37 @@ public abstract class CosmosSerializer /// Any type passed to . /// A readable Stream containing JSON of the serialized object. public abstract Stream ToStream(T input); + + /// + /// Convert a Stream of JSON to an object. + /// The implementation is responsible for Disposing of the stream, + /// including when an exception is thrown, to avoid memory leaks. + /// + /// Any type passed to . + /// The Stream response containing JSON from Cosmos DB. + /// Cancellation token. + /// The object deserialized from the stream. + public virtual Task FromStreamAsync(Stream stream, CancellationToken cancellationToken) + { + _ = cancellationToken; + return Task.FromResult(this.FromStream(stream)); + } + + /// + /// Convert the object to a Stream. + /// The caller provides the Stream and has full control over it. + /// Stream.CanRead must be true. + /// + /// Any type passed to . + /// Output stream. + /// Cancellation token. + /// A readable Stream containing JSON of the serialized object. 
+ public virtual async Task ToStreamAsync(T input, Stream output, CancellationToken cancellationToken) + { + Stream temp = this.ToStream(input); + + //80kiB is default value + await temp.CopyToAsync(output, 81920, cancellationToken); + } } } diff --git a/Microsoft.Azure.Cosmos/src/Serializer/CosmosSystemTextJsonSerializer.cs b/Microsoft.Azure.Cosmos/src/Serializer/CosmosSystemTextJsonSerializer.cs index 6fcf5ee2c9..419c152a5d 100644 --- a/Microsoft.Azure.Cosmos/src/Serializer/CosmosSystemTextJsonSerializer.cs +++ b/Microsoft.Azure.Cosmos/src/Serializer/CosmosSystemTextJsonSerializer.cs @@ -9,6 +9,8 @@ namespace Microsoft.Azure.Cosmos using System.Reflection; using System.Text.Json; using System.Text.Json.Serialization; + using System.Threading; + using System.Threading.Tasks; using Microsoft.Azure.Cosmos.CosmosElements; using Microsoft.Azure.Cosmos.Json; using Microsoft.Azure.Cosmos.Serializer; @@ -77,6 +79,28 @@ public override T FromStream(Stream stream) } } + /// + public override async Task FromStreamAsync(Stream stream, CancellationToken cancellationToken) + { + if (stream == null) + throw new ArgumentNullException(nameof(stream)); + + if (typeof(Stream).IsAssignableFrom(typeof(T))) + { + return (T)(object)stream; + } + + if (stream.CanSeek && stream.Length == 0) + { + return default; + } + + using (stream) + { + return await System.Text.Json.JsonSerializer.DeserializeAsync(stream, this.jsonSerializerOptions, cancellationToken); + } + } + /// public override Stream ToStream(T input) { @@ -89,6 +113,13 @@ public override Stream ToStream(T input) return streamPayload; } + /// + public override async Task ToStreamAsync(T input, Stream output, CancellationToken cancellationToken) + { + await System.Text.Json.JsonSerializer.SerializeAsync(output, input, this.jsonSerializerOptions, cancellationToken); + output.Position = 0; + } + /// /// Convert a MemberInfo to a string for use in LINQ query translation. ///