Skip to content

Commit

Permalink
Fix consolidation with transaction
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremy-visionaid committed Nov 29, 2024
1 parent d55b7e5 commit 08c0844
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 87 deletions.
77 changes: 77 additions & 0 deletions OpenMcdf.Tests/RootStorageTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,83 @@ public void Open(string fileName)
Assert.ThrowsException<NotSupportedException>(() => rootStorage.StateBits = 0);
}

[TestMethod]
[DataRow(Version.V3)]
[DataRow(Version.V4)]
public void ConsolidateMemoryStream(Version version)
{
byte[] buffer = new byte[4096];

using MemoryStream memoryStream = new();
using (var rootStorage = RootStorage.Create(memoryStream, version, StorageModeFlags.LeaveOpen))
{
using (CfbStream stream = rootStorage.CreateStream("Test"))
stream.Write(buffer, 0, buffer.Length);

Assert.AreEqual(1, rootStorage.EnumerateEntries().Count());

rootStorage.Flush(true);

int originalMemoryStreamLength = (int)memoryStream.Length;

rootStorage.Delete("Test");

rootStorage.Flush(true);

Assert.IsTrue(originalMemoryStreamLength > memoryStream.Length);
}

using (var rootStorage = RootStorage.Create(memoryStream, version, StorageModeFlags.LeaveOpen))
{
Assert.AreEqual(0, rootStorage.EnumerateEntries().Count());
}
}

[TestMethod]
[DataRow(Version.V3, StorageModeFlags.None)]
[DataRow(Version.V4, StorageModeFlags.Transacted)]
public void ConsolidateFile(Version version, StorageModeFlags flags)
{
byte[] buffer = new byte[4096];

string fileName = Path.GetTempFileName();

try
{
using (var rootStorage = RootStorage.Create(fileName, version, flags))
{
using (CfbStream stream = rootStorage.CreateStream("Test"))
stream.Write(buffer, 0, buffer.Length);

Assert.AreEqual(1, rootStorage.EnumerateEntries().Count());

if (flags.HasFlag(StorageModeFlags.Transacted))
rootStorage.Commit();
rootStorage.Flush(true);

long originalLength = new FileInfo(fileName).Length;

rootStorage.Delete("Test");

if (flags.HasFlag(StorageModeFlags.Transacted))
rootStorage.Commit();
rootStorage.Flush(true);

long consolidatedLength = new FileInfo(fileName).Length;
Assert.IsTrue(originalLength > consolidatedLength);
}

using (var rootStorage = RootStorage.OpenRead(fileName))
{
Assert.AreEqual(0, rootStorage.EnumerateEntries().Count());
}
}
finally
{
File.Delete(fileName);
}
}

[TestMethod]
[DataRow(Version.V3, 0)]
[DataRow(Version.V3, 1)]
Expand Down
73 changes: 0 additions & 73 deletions OpenMcdf.Tests/StorageTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -274,79 +274,6 @@ public void DeleteStream(Version version)
}
}

[TestMethod]
[DataRow(Version.V3)]
[DataRow(Version.V4)]
public void ConsolidateMemoryStream(Version version)
{
byte[] buffer = new byte[4096];

using MemoryStream memoryStream = new();
using (var rootStorage = RootStorage.Create(memoryStream, version, StorageModeFlags.LeaveOpen))
{
using (CfbStream stream = rootStorage.CreateStream("Test"))
stream.Write(buffer, 0, buffer.Length);

Assert.AreEqual(1, rootStorage.EnumerateEntries().Count());

rootStorage.Flush(true);

int originalMemoryStreamLength = (int)memoryStream.Length;

rootStorage.Delete("Test");

rootStorage.Flush(true);

Assert.IsTrue(originalMemoryStreamLength > memoryStream.Length);
}

using (var rootStorage = RootStorage.Create(memoryStream, version, StorageModeFlags.LeaveOpen))
{
Assert.AreEqual(0, rootStorage.EnumerateEntries().Count());
}
}

[TestMethod]
[DataRow(Version.V3)]
[DataRow(Version.V4)]
public void ConsolidateFile(Version version)
{
byte[] buffer = new byte[4096];

string fileName = Path.GetTempFileName();

try
{
using (var rootStorage = RootStorage.Create(fileName, version))
{
using (CfbStream stream = rootStorage.CreateStream("Test"))
stream.Write(buffer, 0, buffer.Length);

Assert.AreEqual(1, rootStorage.EnumerateEntries().Count());

rootStorage.Flush(true);

long originalLength = new FileInfo(fileName).Length;

rootStorage.Delete("Test");

rootStorage.Flush(true);

long consolidatedLength = new FileInfo(fileName).Length;
Assert.IsTrue(originalLength > consolidatedLength);
}

using (var rootStorage = RootStorage.OpenRead(fileName))
{
Assert.AreEqual(0, rootStorage.EnumerateEntries().Count());
}
}
finally
{
File.Delete(fileName);
}
}

[TestMethod]
[DataRow(Version.V3)]
[DataRow(Version.V4)]
Expand Down
7 changes: 1 addition & 6 deletions OpenMcdf/RootContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

namespace OpenMcdf;

[Flags]
enum IOContextFlags
{
None = 0,
Expand Down Expand Up @@ -188,12 +189,6 @@ public void ExtendStreamLength(long length)
isDirty = true;
}

public void Consolidate(long length)
{
BaseStream.SetLength(length);
Length = length;
}

public void WriteHeader()
{
CfbBinaryWriter writer = Writer;
Expand Down
16 changes: 8 additions & 8 deletions OpenMcdf/RootStorage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -154,25 +154,25 @@ void Consolidate()
{
// TODO: Consolidate by defragmentation instead of copy

Stream baseStream = Context.BaseStream;
Stream? destinationStream = null;

try
{
if (Context.BaseStream is MemoryStream)
destinationStream = new MemoryStream((int)Context.BaseStream.Length);
else if (Context.BaseStream is FileStream)
if (baseStream is MemoryStream)
destinationStream = new MemoryStream((int)baseStream.Length);
else if (baseStream is FileStream)
destinationStream = File.Create(Path.GetTempFileName());
else
throw new NotSupportedException("Unsupported stream type for consolidation.");

using (RootStorage destinationStorage = Create(destinationStream, Context.Version, storageModeFlags | StorageModeFlags.LeaveOpen))
using (RootStorage destinationStorage = Create(destinationStream, Context.Version, StorageModeFlags.LeaveOpen))
CopyTo(destinationStorage);

Context.BaseStream.Position = 0;
destinationStream.Position = 0;
destinationStream.CopyAllTo(baseStream);

destinationStream.CopyTo(Context.BaseStream);
Context.Consolidate(destinationStream.Length);
IOContextFlags contextFlags = ToIOContextFlags(storageModeFlags);
_ = new RootContext(ContextSite, baseStream, Version.Unknown, contextFlags);
}
catch
{
Expand Down

0 comments on commit 08c0844

Please sign in to comment.