diff --git a/OpenMcdf.Tests/RootStorageTests.cs b/OpenMcdf.Tests/RootStorageTests.cs index 190c570..521eca8 100644 --- a/OpenMcdf.Tests/RootStorageTests.cs +++ b/OpenMcdf.Tests/RootStorageTests.cs @@ -207,4 +207,28 @@ public void SwitchTransactedStream(Version version, int subStorageCount) rootStorage.OpenStorage($"Test{i}"); } } + + [TestMethod] + [DataRow(false)] + [DataRow(true)] + public void DeleteTrimsBaseStream(bool consolidate) + { + using var rootStorage = RootStorage.CreateInMemory(Version.V3); + using (CfbStream stream = rootStorage.CreateStream("Test")) + { + byte[] buffer = TestData.CreateByteArray(4096); + stream.Write(buffer, 0, buffer.Length); + } + + rootStorage.Flush(consolidate); + + long originalLength = rootStorage.BaseStream.Length; + + rootStorage.Delete("Test"); + rootStorage.Flush(consolidate); + + long newLength = rootStorage.BaseStream.Length; + + Assert.IsTrue(originalLength > newLength); + } } diff --git a/OpenMcdf/Fat.cs b/OpenMcdf/Fat.cs index 1450920..92c8a94 100644 --- a/OpenMcdf/Fat.cs +++ b/OpenMcdf/Fat.cs @@ -141,6 +141,18 @@ public uint Add(FatEnumerator fatEnumerator, uint startIndex) return entry.Index; } + public Sector GetLastUsedSector() + { + uint lastUsedSectorIndex = uint.MaxValue; + foreach (FatEntry entry in this) + { + if (!entry.IsFree) + lastUsedSectorIndex = entry.Index; + } + + return new(lastUsedSectorIndex, Context.SectorSize); + } + public IEnumerator GetEnumerator() => new FatEnumerator(Context.Fat); IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); diff --git a/OpenMcdf/RootContext.cs b/OpenMcdf/RootContext.cs index 0e88f63..b066534 100644 --- a/OpenMcdf/RootContext.cs +++ b/OpenMcdf/RootContext.cs @@ -87,8 +87,6 @@ public FatStream MiniStream public uint SectorCount => (uint)Math.Max(0, (Length - SectorSize) / SectorSize); // TODO: Check - bool isDirty; - public RootContext(RootContextSite rootContextSite, Stream stream, Version version, IOContextFlags contextFlags = IOContextFlags.None) : base(rootContextSite) { @@ -178,12 +176,10 @@ public void Flush() { Fat.Flush(); - if (isDirty && writer is not null && transactedStream is null) + if (writer is not null && transactedStream is null) { - // Ensure the stream is as long as expected - BaseStream.SetLength(Length); + TrimBaseStream(); WriteHeader(); - isDirty = false; } } @@ -191,7 +187,16 @@ public void ExtendStreamLength(long length) { if (Length < length) Length = length; - isDirty = true; + } + + void TrimBaseStream() + { + Sector lastUsedSector = Fat.GetLastUsedSector(); + if (!lastUsedSector.IsValid) + throw new FileFormatException("Last used sector is invalid"); + + Length = lastUsedSector.EndPosition; + BaseStream.SetLength(Length); } public void WriteHeader()