Skip to content

Commit

Permalink
Optimize BucketingInputSource for performance
Browse files Browse the repository at this point in the history
The BucketingInputSource has a bytePositionToIndicies function that
returns a tuple containing the bucket index and index within that bucket
to find a given byte position. This function should be inlined and in
theory the tuple allocation could be optimized out since we immediately
take it apart into separate variables, but that doesn't seem to be the
case, and leads to noticeable Overhead.

This removes the tuple allocation by replacing the one function with two
separate functions. This means there is now an extra function call but
it avoids the tuple allocation, which appears to be the main overhead.

This also is more careful about which variables are Int's and Long's to
minimize the number of toInt calls. This is unlikely to make a
performance difference, but does make the code cleaner. This also
switches from integer/modular division to shifts and masks which should
also be more efficient. This does now require the bucket size to be
specified as a power of two.

In basic testing, these changes reduced the overhead of the
BucketingInputSource compared to the ByteBufferInputSource from about
15% to 5%.

DAFFODIL-2920
  • Loading branch information
stevedlawrence committed Aug 5, 2024
1 parent a13501a commit 8735ed1
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 37 deletions.
80 changes: 50 additions & 30 deletions daffodil-io/src/main/scala/org/apache/daffodil/io/InputSource.scala
Original file line number Diff line number Diff line change
Expand Up @@ -212,15 +212,18 @@ abstract class InputSource {
* can be freed when the reference count goes to zero.
*
* @param inputStream the java.io.Inputstream to read data from
* @param bucketSize the size of each individual bucket
* @param bucketSizeExponent the exponent used to calculate the size of a single bucket, defined as 2^buckeSizeExponent bytes
* @param maxCacheSizeInBytes the max memory allowed to be used for bucket storage (num buckets * bucketSize)
*/
class BucketingInputSource(
inputStream: java.io.InputStream,
bucketSize: Int = 1 << 13,
bucketSizeExponent: Int = 13,
maxCacheSizeInBytes: Int = 256 * (1 << 20)
) extends InputSource {

private final val bucketSize: Int = 1 << bucketSizeExponent
private final val bucketMask: Int = bucketSize - 1

private class Bucket {
var refCount = 0
val bytes = new Array[Byte](bucketSize)
Expand All @@ -235,6 +238,10 @@ class BucketingInputSource(
* size of the buckets array to grow beyond this number. So this does not
* represent the maximum size of the buckets array, but instead represents the
* maximum number of non-null elements in the buckets array.
*
* We alway require at least two non-null buckets even if the max cache size is
* quite small--if the byte position is at the very beginning of a bucket, we
* always want the ability to backtrack to at least the previous bucket
*/
private val maxNumberOfNonNullBuckets = Math.max(maxCacheSizeInBytes / bucketSize, 2)

Expand All @@ -254,7 +261,7 @@ class BucketingInputSource(
* quickly know which buckets have something in them. All indices in the
* buckets ArrayBuffer less than this index should have a value of null
*/
private var oldestBucketIndex = 0
private var oldestBucketIndex: Int = 0

/**
* Stores the current byte position. When a call to get is made, the
Expand Down Expand Up @@ -298,7 +305,7 @@ class BucketingInputSource(
*
* @return false if EOF was hit before filling in the necessary bytes. true otherwise.
*/
private def fillBucketsToIndex(goalBucketIndex: Long, bytesNeededInBucket: Long): Boolean = {
private def fillBucketsToIndex(goalBucketIndex: Int, bytesNeededInBucket: Int): Boolean = {
var lastBucketIndex = buckets.length - 1

var needsMoreData =
Expand Down Expand Up @@ -344,7 +351,7 @@ class BucketingInputSource(
lastBucketIndex += 1
if ((lastBucketIndex - oldestBucketIndex) >= maxNumberOfNonNullBuckets) {
// This frees the oldest bucket, allowing it to be garbage collected.
buckets(oldestBucketIndex.toInt) = null
buckets(oldestBucketIndex) = null
oldestBucketIndex += 1
}
}
Expand Down Expand Up @@ -373,12 +380,20 @@ class BucketingInputSource(
!needsMoreData
}

/**
* Get the bucket index where bytePos0b is stored
*/
@inline
final private def bytePositionToIndicies(bytePos0b: Long): (Long, Long) = {
val offsetBytePosition0b = bytePos0b - headBucketBytePosition0b
val bucketIndex = offsetBytePosition0b / bucketSize
val byteIndex = offsetBytePosition0b % bucketSize
(bucketIndex, byteIndex)
final private def getBucketIndex(bytePos0b: Long): Int = {
((bytePos0b - headBucketBytePosition0b) >>> bucketSizeExponent).toInt
}

/**
* Get the index in a bucket where bytePos0b is stored
*/
@inline
final private def getByteIndex(bytePos0b: Long): Int = {
((bytePos0b - headBucketBytePosition0b) & bucketMask).toInt
}

/**
Expand All @@ -392,7 +407,8 @@ class BucketingInputSource(
if (finalBytePosition0b <= totalBytesBucketed) {
true
} else {
val (bucketIndex, byteIndex) = bytePositionToIndicies(finalBytePosition0b)
val bucketIndex = getBucketIndex(finalBytePosition0b)
val byteIndex = getByteIndex(finalBytePosition0b)
Assert.invariant(bucketIndex >= oldestBucketIndex)
val filled = fillBucketsToIndex(bucketIndex, byteIndex)
filled
Expand All @@ -405,7 +421,8 @@ class BucketingInputSource(
*/
def knownBytesAvailable(): Long = {
var available = 0L
val (curBucketIndex, curByteIndex) = bytePositionToIndicies(curBytePosition0b)
val curBucketIndex = getBucketIndex(curBytePosition0b)
val curByteIndex = getByteIndex(curBytePosition0b)

var i = curBucketIndex
while (i < buckets.length) {
Expand All @@ -429,12 +446,13 @@ class BucketingInputSource(
if (!hasByte) {
-1
} else {
val (bucketIndex, byteIndex) = bytePositionToIndicies(curBytePosition0b)
val bucketIndex = getBucketIndex(curBytePosition0b)
val byteIndex = getByteIndex(curBytePosition0b)

if ((bucketIndex < 0) || (buckets(bucketIndex.toInt) == null))
if ((bucketIndex < 0) || (buckets(bucketIndex) == null))
throw new BacktrackingException(curBytePosition0b, maxCacheSizeInBytes)

val byte = buckets(bucketIndex.toInt).bytes(byteIndex.toInt)
val byte = buckets(bucketIndex).bytes(byteIndex)
curBytePosition0b += 1
byte & 0xff
}
Expand All @@ -452,19 +470,20 @@ class BucketingInputSource(
if (!hasBytes) {
false
} else {
var (bucketIndex, byteIndex) = bytePositionToIndicies(curBytePosition0b)
var bucketIndex = getBucketIndex(curBytePosition0b)
var byteIndex = getByteIndex(curBytePosition0b)
var bytesStillToGet = len
var destOffset = off
while (bytesStillToGet > 0) {
val bytesToGetFromCurrentBucket =
Math.min(bucketSize - byteIndex, bytesStillToGet).toInt
Math.min(bucketSize - byteIndex, bytesStillToGet)

if ((bucketIndex < 0) || (buckets(bucketIndex.toInt) == null))
if ((bucketIndex < 0) || (buckets(bucketIndex) == null))
throw new BacktrackingException(curBytePosition0b, maxCacheSizeInBytes)

Array.copy(
buckets(bucketIndex.toInt).bytes,
byteIndex.toInt,
buckets(bucketIndex).bytes,
byteIndex,
dest,
destOffset,
bytesToGetFromCurrentBucket
Expand All @@ -484,27 +503,27 @@ class BucketingInputSource(
def position(): Long = curBytePosition0b

def position(bytePos0b: Long): Unit = {
val (bucketIndex, _) = bytePositionToIndicies(bytePos0b)
val bucketIndex = getBucketIndex(bytePos0b)
Assert.invariant(bucketIndex < buckets.length)
curBytePosition0b = bytePos0b
}

def lockPosition(bytePos0b: Long): Unit = {
val (bucketIndex, _) = bytePositionToIndicies(bytePos0b)
val bucketIndex = getBucketIndex(bytePos0b)
Assert.invariant(bucketIndex < buckets.length)
if (buckets(bucketIndex.toInt) != null)
buckets(bucketIndex.toInt).refCount += 1
if (buckets(bucketIndex) != null)
buckets(bucketIndex).refCount += 1
}

def releasePosition(bytePos0b: Long): Unit = {
val (bucketIndex, _) = bytePositionToIndicies(bytePos0b)
val bucketIndex = getBucketIndex(bytePos0b)

if (buckets(bucketIndex.toInt) != null) {
if (buckets(bucketIndex) != null) {
Assert.invariant(bucketIndex >= oldestBucketIndex && bucketIndex < buckets.length)
buckets(bucketIndex.toInt).refCount -= 1
buckets(bucketIndex).refCount -= 1
}

if (buckets(oldestBucketIndex.toInt).refCount == 0) {
if (buckets(oldestBucketIndex).refCount == 0) {
// We just freed the last reference to the oldest bucket (or the oldest
// bucket happened to have no references). So try to release as many
// buckets as possible. Note that this might still not release anything
Expand All @@ -521,7 +540,8 @@ class BucketingInputSource(
// refCount is zero), set them to null so they are garbage collected. Make
// sure not to remove whatever bucket holds the current byte position,
// even if there are no marks--we need to still read from that bucket.
val (curBucketIndex, _) = bytePositionToIndicies(curBytePosition0b)
val curBucketIndex = getBucketIndex(curBytePosition0b)

while (oldestBucketIndex < curBucketIndex && buckets(oldestBucketIndex).refCount == 0) {
buckets(oldestBucketIndex) = null
oldestBucketIndex += 1
Expand All @@ -539,7 +559,7 @@ class BucketingInputSource(
*/
def compact(): Unit = {
releaseBuckets()
buckets.remove(0, oldestBucketIndex.toInt)
buckets.remove(0, oldestBucketIndex)
val bytesRemoved = oldestBucketIndex * bucketSize
headBucketBytePosition0b += bytesRemoved
oldestBucketIndex = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class TestBucketingInputSource {

@Test def testBucketingInputSource1(): Unit = {
val tis = new TestInputStream
val bis = new BucketingInputSource(tis, 17)
val bis = new BucketingInputSource(tis, 4)
var i = 0
while (i < 100) {
assertEquals(i, bis.position())
Expand All @@ -145,7 +145,7 @@ class TestBucketingInputSource {

@Test def testBucketingInputSource2(): Unit = {
val tis = new TestInputStream
val bis = new BucketingInputSource(tis, 7)
val bis = new BucketingInputSource(tis, 3)
val b = new Array[Byte](10)
assertEquals(true, bis.get(b, 0, 10))
var i = 0
Expand All @@ -164,7 +164,7 @@ class TestBucketingInputSource {

@Test def testBucketingInputSource3(): Unit = {
val tis = new TestInputStream
val bis = new BucketingInputSource(tis, 7)
val bis = new BucketingInputSource(tis, 3)
val b = new Array[Byte](10)
assertEquals(0, bis.get)
assertEquals(1, bis.get)
Expand All @@ -181,7 +181,7 @@ class TestBucketingInputSource {

@Test def testBucketingInputSource4(): Unit = {
val tis = new TestInputStream
val bis = new BucketingInputSource(tis, 7)
val bis = new BucketingInputSource(tis, 3)
tis.setEOF(4)
assertEquals(0, bis.get)
assertEquals(1, bis.get)
Expand All @@ -192,7 +192,7 @@ class TestBucketingInputSource {

@Test def testBucketingInputSource5(): Unit = {
val tis = new TestInputStream
val bis = new BucketingInputSource(tis, 7)
val bis = new BucketingInputSource(tis, 3)
val b = new Array[Byte](10)
tis.setEOF(4)
assertEquals(false, bis.get(b, 0, 10))
Expand All @@ -207,7 +207,7 @@ class TestBucketingInputSource {

@Test def testBucketingInputSource6(): Unit = {
val tis = new TestInputStream
val bis = new BucketingInputSource(tis, 7)
val bis = new BucketingInputSource(tis, 3)
tis.setEOF(17)
var i = 0
while (i < 8) {
Expand All @@ -226,7 +226,7 @@ class TestBucketingInputSource {

@Test def testBucketingInputSource7(): Unit = {
val tis = new TestInputStream
val bis = new BucketingInputSource(tis, 7)
val bis = new BucketingInputSource(tis, 3)
tis.setEOF(17)
var i = 0
while (i < 2) {
Expand Down

0 comments on commit 8735ed1

Please sign in to comment.