diff --git a/OpenMcdf3/CfbBinaryReader.cs b/OpenMcdf3/CfbBinaryReader.cs index 7dc58662..ad365038 100644 --- a/OpenMcdf3/CfbBinaryReader.cs +++ b/OpenMcdf3/CfbBinaryReader.cs @@ -7,6 +7,7 @@ namespace OpenMcdf3; /// internal sealed class CfbBinaryReader : BinaryReader { + readonly byte[] guidBuffer = new byte[16]; readonly byte[] buffer = new byte[DirectoryEntry.NameFieldLength]; public CfbBinaryReader(Stream input) @@ -14,7 +15,19 @@ public CfbBinaryReader(Stream input) { } - public Guid ReadGuid() => new(ReadBytes(16)); + public Guid ReadGuid() + { + int bytesRead = 0; + do + { + int n = Read(guidBuffer, bytesRead, guidBuffer.Length - bytesRead); + if (n == 0) + throw new EndOfStreamException(); + bytesRead += n; + } while (bytesRead < guidBuffer.Length); + + return new Guid(guidBuffer); + } public DateTime ReadFileTime() { @@ -22,8 +35,6 @@ public DateTime ReadFileTime() return DateTime.FromFileTimeUtc(fileTime); } - private void ReadBytes(byte[] buffer) => Read(buffer, 0, buffer.Length); - public Header ReadHeader() { Header header = new(); @@ -90,8 +101,9 @@ public DirectoryEntry ReadDirectoryEntry(Version version) DirectoryEntry entry = new(); Read(buffer, 0, DirectoryEntry.NameFieldLength); - int nameLength = Math.Max(0, ReadUInt16() - 2); - entry.Name = Encoding.Unicode.GetString(buffer, 0, nameLength); + ushort nameLength = ReadUInt16(); + int clampedNameLength = Math.Max(0, Math.Min(ushort.MaxValue, nameLength - 2)); + entry.Name = Encoding.Unicode.GetString(buffer, 0, clampedNameLength); entry.Type = ReadStorageType(); entry.Color = ReadColor(); entry.LeftSiblingId = ReadUInt32(); diff --git a/OpenMcdf3/FatSectorChainEnumerator.cs b/OpenMcdf3/FatSectorChainEnumerator.cs index f06757a1..09117b00 100644 --- a/OpenMcdf3/FatSectorChainEnumerator.cs +++ b/OpenMcdf3/FatSectorChainEnumerator.cs @@ -56,7 +56,7 @@ public bool MoveNext() } else if (!current.IsEndOfChain) { - uint sectorId = fatEnumerator.GetNextFatSectorId(current.Id); + uint sectorId = GetNextFatSectorId(current.Id); current = new(sectorId, ioContext.Header.SectorSize); Index++; } @@ -98,4 +98,23 @@ public void Reset() current = Sector.EndOfChain; Index = uint.MaxValue; } + + /// + /// Gets the next sector ID in the FAT chain. + /// + uint GetNextFatSectorId(uint id) + { + if (id > SectorType.Maximum) + throw new ArgumentException("Invalid sector ID", nameof(id)); + + int elementCount = ioContext.Header.SectorSize / sizeof(uint); + uint sectorId = (uint)Math.DivRem(id, elementCount, out long sectorOffset); + if (!fatEnumerator.MoveTo(sectorId)) + throw new ArgumentException("Invalid sector ID", nameof(id)); + + long position = fatEnumerator.Current.Position + sectorOffset * sizeof(uint); + ioContext.Reader.Seek(position); + uint nextId = ioContext.Reader.ReadUInt32(); + return nextId; + } } diff --git a/OpenMcdf3/FatSectorEnumerator.cs b/OpenMcdf3/FatSectorEnumerator.cs index e98e8d12..21bbf81b 100644 --- a/OpenMcdf3/FatSectorEnumerator.cs +++ b/OpenMcdf3/FatSectorEnumerator.cs @@ -83,7 +83,9 @@ public bool MoveNext() return true; } - /// + /// + /// Moves the enumerator to the specified sector. + /// public bool MoveTo(uint sectorId) { if (sectorId < id) @@ -107,21 +109,4 @@ public void Reset() difatSectorElementIndex = 0; current = Sector.EndOfChain; } - - /// - public uint GetNextFatSectorId(uint id) - { - if (id > SectorType.Maximum) - throw new ArgumentException("Invalid sector ID", nameof(id)); - - int elementCount = ioContext.Header.SectorSize / sizeof(uint); - uint sectorId = (uint)Math.DivRem(id, elementCount, out long sectorOffset); - if (!MoveTo(sectorId)) - throw new ArgumentException("Invalid sector ID", nameof(id)); - - long position = Current.Position + sectorOffset * sizeof(uint); - ioContext.Reader.Seek(position); - uint nextId = ioContext.Reader.ReadUInt32(); - return nextId; - } } diff --git a/OpenMcdf3/FatStream.cs b/OpenMcdf3/FatStream.cs index f0270898..7c86cda8 100644 --- a/OpenMcdf3/FatStream.cs +++ b/OpenMcdf3/FatStream.cs @@ -73,14 +73,14 @@ public override int Read(byte[] buffer, int offset, int count) if (count == 0) return 0; - uint chainIndex = (uint)Math.DivRem(position, ioContext.Header.SectorSize, out long sectorOffset); - if (!chain.MoveTo(chainIndex)) - return 0; - int maxCount = (int)Math.Min(Math.Max(length - position, 0), int.MaxValue); if (maxCount == 0) return 0; + uint chainIndex = (uint)Math.DivRem(position, ioContext.Header.SectorSize, out long sectorOffset); + if (!chain.MoveTo(chainIndex)) + return 0; + int realCount = Math.Min(count, maxCount); int readCount = 0; do diff --git a/OpenMcdf3/MiniFatStream.cs b/OpenMcdf3/MiniFatStream.cs index d51dd5c5..5044e953 100644 --- a/OpenMcdf3/MiniFatStream.cs +++ b/OpenMcdf3/MiniFatStream.cs @@ -71,14 +71,14 @@ public override int Read(byte[] buffer, int offset, int count) if (count == 0) return 0; - uint chainIndex = (uint)Math.DivRem(position, ioContext.Header.SectorSize, out long sectorOffset); - if (!chain.MoveTo(chainIndex)) - return 0; - int maxCount = (int)Math.Min(Math.Max(length - position, 0), int.MaxValue); if (maxCount == 0) return 0; + uint chainIndex = (uint)Math.DivRem(position, ioContext.Header.SectorSize, out long sectorOffset); + if (!chain.MoveTo(chainIndex)) + return 0; + int realCount = Math.Min(count, maxCount); int readCount = 0; do