diff --git a/OpenMcdf.Tests/RootStorageTests.cs b/OpenMcdf.Tests/RootStorageTests.cs index 4599542..c729f58 100644 --- a/OpenMcdf.Tests/RootStorageTests.cs +++ b/OpenMcdf.Tests/RootStorageTests.cs @@ -28,6 +28,83 @@ public void Open(string fileName) Assert.ThrowsException(() => rootStorage.StateBits = 0); } + [TestMethod] + [DataRow(Version.V3)] + [DataRow(Version.V4)] + public void ConsolidateMemoryStream(Version version) + { + byte[] buffer = new byte[4096]; + + using MemoryStream memoryStream = new(); + using (var rootStorage = RootStorage.Create(memoryStream, version, StorageModeFlags.LeaveOpen)) + { + using (CfbStream stream = rootStorage.CreateStream("Test")) + stream.Write(buffer, 0, buffer.Length); + + Assert.AreEqual(1, rootStorage.EnumerateEntries().Count()); + + rootStorage.Flush(true); + + int originalMemoryStreamLength = (int)memoryStream.Length; + + rootStorage.Delete("Test"); + + rootStorage.Flush(true); + + Assert.IsTrue(originalMemoryStreamLength > memoryStream.Length); + } + + using (var rootStorage = RootStorage.Create(memoryStream, version, StorageModeFlags.LeaveOpen)) + { + Assert.AreEqual(0, rootStorage.EnumerateEntries().Count()); + } + } + + [TestMethod] + [DataRow(Version.V3, StorageModeFlags.None)] + [DataRow(Version.V4, StorageModeFlags.Transacted)] + public void ConsolidateFile(Version version, StorageModeFlags flags) + { + byte[] buffer = new byte[4096]; + + string fileName = Path.GetTempFileName(); + + try + { + using (var rootStorage = RootStorage.Create(fileName, version, flags)) + { + using (CfbStream stream = rootStorage.CreateStream("Test")) + stream.Write(buffer, 0, buffer.Length); + + Assert.AreEqual(1, rootStorage.EnumerateEntries().Count()); + + if (flags.HasFlag(StorageModeFlags.Transacted)) + rootStorage.Commit(); + rootStorage.Flush(true); + + long originalLength = new FileInfo(fileName).Length; + + rootStorage.Delete("Test"); + + if (flags.HasFlag(StorageModeFlags.Transacted)) + rootStorage.Commit(); + rootStorage.Flush(true); + + long consolidatedLength = new FileInfo(fileName).Length; + Assert.IsTrue(originalLength > consolidatedLength); + } + + using (var rootStorage = RootStorage.OpenRead(fileName)) + { + Assert.AreEqual(0, rootStorage.EnumerateEntries().Count()); + } + } + finally + { + File.Delete(fileName); + } + } + [TestMethod] [DataRow(Version.V3, 0)] [DataRow(Version.V3, 1)] diff --git a/OpenMcdf.Tests/StorageTests.cs b/OpenMcdf.Tests/StorageTests.cs index c617d59..e5e3d6d 100644 --- a/OpenMcdf.Tests/StorageTests.cs +++ b/OpenMcdf.Tests/StorageTests.cs @@ -274,79 +274,6 @@ public void DeleteStream(Version version) } } - [TestMethod] - [DataRow(Version.V3)] - [DataRow(Version.V4)] - public void ConsolidateMemoryStream(Version version) - { - byte[] buffer = new byte[4096]; - - using MemoryStream memoryStream = new(); - using (var rootStorage = RootStorage.Create(memoryStream, version, StorageModeFlags.LeaveOpen)) - { - using (CfbStream stream = rootStorage.CreateStream("Test")) - stream.Write(buffer, 0, buffer.Length); - - Assert.AreEqual(1, rootStorage.EnumerateEntries().Count()); - - rootStorage.Flush(true); - - int originalMemoryStreamLength = (int)memoryStream.Length; - - rootStorage.Delete("Test"); - - rootStorage.Flush(true); - - Assert.IsTrue(originalMemoryStreamLength > memoryStream.Length); - } - - using (var rootStorage = RootStorage.Create(memoryStream, version, StorageModeFlags.LeaveOpen)) - { - Assert.AreEqual(0, rootStorage.EnumerateEntries().Count()); - } - } - - [TestMethod] - [DataRow(Version.V3)] - [DataRow(Version.V4)] - public void ConsolidateFile(Version version) - { - byte[] buffer = new byte[4096]; - - string fileName = Path.GetTempFileName(); - - try - { - using (var rootStorage = RootStorage.Create(fileName, version)) - { - using (CfbStream stream = rootStorage.CreateStream("Test")) - stream.Write(buffer, 0, buffer.Length); - - Assert.AreEqual(1, rootStorage.EnumerateEntries().Count()); - - rootStorage.Flush(true); - - long originalLength = new FileInfo(fileName).Length; - - rootStorage.Delete("Test"); - - rootStorage.Flush(true); - - long consolidatedLength = new FileInfo(fileName).Length; - Assert.IsTrue(originalLength > consolidatedLength); - } - - using (var rootStorage = RootStorage.OpenRead(fileName)) - { - Assert.AreEqual(0, rootStorage.EnumerateEntries().Count()); - } - } - finally - { - File.Delete(fileName); - } - } - [TestMethod] [DataRow(Version.V3)] [DataRow(Version.V4)] diff --git a/OpenMcdf/RootContext.cs b/OpenMcdf/RootContext.cs index e9508cf..b19732a 100644 --- a/OpenMcdf/RootContext.cs +++ b/OpenMcdf/RootContext.cs @@ -2,6 +2,7 @@ namespace OpenMcdf; +[Flags] enum IOContextFlags { None = 0, @@ -188,12 +189,6 @@ public void ExtendStreamLength(long length) isDirty = true; } - public void Consolidate(long length) - { - BaseStream.SetLength(length); - Length = length; - } - public void WriteHeader() { CfbBinaryWriter writer = Writer; diff --git a/OpenMcdf/RootStorage.cs b/OpenMcdf/RootStorage.cs index 15229be..498189b 100644 --- a/OpenMcdf/RootStorage.cs +++ b/OpenMcdf/RootStorage.cs @@ -154,25 +154,25 @@ void Consolidate() { // TODO: Consolidate by defragmentation instead of copy + Stream baseStream = Context.BaseStream; Stream? destinationStream = null; try { - if (Context.BaseStream is MemoryStream) - destinationStream = new MemoryStream((int)Context.BaseStream.Length); - else if (Context.BaseStream is FileStream) + if (baseStream is MemoryStream) + destinationStream = new MemoryStream((int)baseStream.Length); + else if (baseStream is FileStream) destinationStream = File.Create(Path.GetTempFileName()); else throw new NotSupportedException("Unsupported stream type for consolidation."); - using (RootStorage destinationStorage = Create(destinationStream, Context.Version, storageModeFlags | StorageModeFlags.LeaveOpen)) + using (RootStorage destinationStorage = Create(destinationStream, Context.Version, StorageModeFlags.LeaveOpen)) CopyTo(destinationStorage); - Context.BaseStream.Position = 0; - destinationStream.Position = 0; + destinationStream.CopyAllTo(baseStream); - destinationStream.CopyTo(Context.BaseStream); - Context.Consolidate(destinationStream.Length); + IOContextFlags contextFlags = ToIOContextFlags(storageModeFlags); + _ = new RootContext(ContextSite, baseStream, Version.Unknown, contextFlags); } catch {