diff --git a/OpenMcdf.Tests/RootStorageTests.cs b/OpenMcdf.Tests/RootStorageTests.cs index c8a3c81..190c570 100644 --- a/OpenMcdf.Tests/RootStorageTests.cs +++ b/OpenMcdf.Tests/RootStorageTests.cs @@ -174,4 +174,37 @@ public void SwitchFile(Version version, int subStorageCount) try { File.Delete(fileName); } catch { } } } + + [TestMethod] + [DataRow(Version.V3, 0)] + [DataRow(Version.V3, 1)] + [DataRow(Version.V3, 2)] + [DataRow(Version.V3, 4)] // Required 2 sectors including root + [DataRow(Version.V4, 0)] + [DataRow(Version.V4, 1)] + [DataRow(Version.V4, 2)] + [DataRow(Version.V4, 32)] // Required 2 sectors including root + public void SwitchTransactedStream(Version version, int subStorageCount) + { + using MemoryStream originalMemoryStream = new(); + using MemoryStream switchedMemoryStream = new(); + + using (var rootStorage = RootStorage.Create(originalMemoryStream, version, StorageModeFlags.Transacted | StorageModeFlags.LeaveOpen)) + { + for (int i = 0; i < subStorageCount; i++) + rootStorage.CreateStorage($"Test{i}"); + + rootStorage.SwitchTo(switchedMemoryStream); + rootStorage.Commit(); + } + + using (var rootStorage = RootStorage.Open(switchedMemoryStream)) + { + IEnumerable entries = rootStorage.EnumerateEntries(); + Assert.AreEqual(subStorageCount, entries.Count()); + + for (int i = 0; i < subStorageCount; i++) + rootStorage.OpenStorage($"Test{i}"); + } + } } diff --git a/OpenMcdf/RootContext.cs b/OpenMcdf/RootContext.cs index b19732a..0e88f63 100644 --- a/OpenMcdf/RootContext.cs +++ b/OpenMcdf/RootContext.cs @@ -26,6 +26,8 @@ internal sealed class RootContext : ContextBase, IDisposable public Stream BaseStream { get; } + public Stream Stream { get; } + public CfbBinaryReader Reader { get; } public CfbBinaryWriter Writer @@ -104,17 +106,20 @@ public RootContext(RootContextSite rootContextSite, Stream stream, Version versi DirectoryEntriesPerSector = SectorSize / DirectoryEntry.Length; Length = stream.Length; - Stream actualStream = stream; if (contextFlags.HasFlag(IOContextFlags.Transacted)) { Stream overlayStream = stream is MemoryStream ? new MemoryStream() : File.Create(Path.GetTempFileName()); transactedStream = new TransactedStream(ContextSite, stream, overlayStream); - actualStream = new BufferedStream(transactedStream, SectorSize); + Stream = new BufferedStream(transactedStream, SectorSize); + } + else + { + Stream = stream; } - Reader = new(actualStream); + Reader = new(Stream); if (stream.CanWrite) - writer = new(actualStream); + writer = new(Stream); Fat = new(ContextSite); DirectoryEntries = new(ContextSite); diff --git a/OpenMcdf/RootStorage.cs b/OpenMcdf/RootStorage.cs index 498189b..8ded764 100644 --- a/OpenMcdf/RootStorage.cs +++ b/OpenMcdf/RootStorage.cs @@ -78,7 +78,7 @@ public static RootStorage Create(Stream stream, Version version = Version.V3, St return new RootStorage(rootContextSite, flags); } - public static RootStorage CreateInMemory(Version version = Version.V3) => Create(new MemoryStream(), version); + public static RootStorage CreateInMemory(Version version = Version.V3, StorageModeFlags flags = StorageModeFlags.None) => Create(new MemoryStream(), version, flags); public static RootStorage Open(string fileName, FileMode mode, StorageModeFlags flags = StorageModeFlags.None) { @@ -205,15 +205,7 @@ public void Revert() private void SwitchToCore(Stream stream, bool allowLeaveOpen) { Flush(); - - stream.SetLength(Context.BaseStream.Length); - - stream.Position = 0; - Context.BaseStream.Position = 0; - - Context.BaseStream.CopyTo(stream); - stream.Position = 0; - + Context.Stream.CopyAllTo(stream); Context.Dispose(); IOContextFlags contextFlags = ToIOContextFlags(storageModeFlags);