diff --git a/OpenMcdf.Tests/RootStorageTests.cs b/OpenMcdf.Tests/RootStorageTests.cs index 9a328cd..87dafc0 100644 --- a/OpenMcdf.Tests/RootStorageTests.cs +++ b/OpenMcdf.Tests/RootStorageTests.cs @@ -34,4 +34,42 @@ public void SwitchStream(Version version, int subStorageCount) rootStorage.OpenStorage($"Test{i}"); } } + + [TestMethod] + [DataRow(Version.V3, 0)] + [DataRow(Version.V3, 1)] + [DataRow(Version.V3, 2)] + [DataRow(Version.V3, 4)] // Required 2 sectors including root + [DataRow(Version.V4, 0)] + [DataRow(Version.V4, 1)] + [DataRow(Version.V4, 2)] + [DataRow(Version.V4, 32)] // Required 2 sectors including root + public void SwitchFile(Version version, int subStorageCount) + { + string fileName = Path.GetTempFileName(); + + try + { + using (var rootStorage = RootStorage.CreateInMemory(version)) + { + for (int i = 0; i < subStorageCount; i++) + rootStorage.CreateStorage($"Test{i}"); + + rootStorage.SwitchTo(fileName); + } + + using (var rootStorage = RootStorage.OpenRead(fileName)) + { + IEnumerable entries = rootStorage.EnumerateEntries(); + Assert.AreEqual(subStorageCount, entries.Count()); + + for (int i = 0; i < subStorageCount; i++) + rootStorage.OpenStorage($"Test{i}"); + } + } + finally + { + File.Delete(fileName); + } + } } diff --git a/OpenMcdf/RootStorage.cs b/OpenMcdf/RootStorage.cs index 85c713d..ab12773 100644 --- a/OpenMcdf/RootStorage.cs +++ b/OpenMcdf/RootStorage.cs @@ -158,7 +158,7 @@ public void Revert() Context.Revert(); } - public void SwitchTo(Stream stream) + private void SwitchToCore(Stream stream, bool allowLeaveOpen) { Flush(); @@ -173,9 +173,22 @@ public void SwitchTo(Stream stream) Context.Dispose(); IOContextFlags contextFlags = ToIOContextFlags(storageModeFlags); + if (!allowLeaveOpen) + contextFlags &= ~IOContextFlags.LeaveOpen; _ = new RootContext(ContextSite, stream, Version.Unknown, contextFlags); } + public void SwitchTo(Stream stream) + { + SwitchToCore(stream, true); + } + + public void SwitchTo(string fileName) + { + FileStream stream = File.Create(fileName); + SwitchToCore(stream, false); + } + [ExcludeFromCodeCoverage] internal void Trace(TextWriter writer) {