-
Notifications
You must be signed in to change notification settings - Fork 66
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Optimize validity buffer concat. #2626
base: branch-25.02
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,50 +16,36 @@ | |
|
||
package com.nvidia.spark.rapids.jni.kudo; | ||
|
||
import static com.nvidia.spark.rapids.jni.Preconditions.ensure; | ||
import static com.nvidia.spark.rapids.jni.kudo.ColumnOffsetInfo.INVALID_OFFSET; | ||
import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.getValidityLengthInBytes; | ||
import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padFor64byteAlignment; | ||
import static java.lang.Math.min; | ||
import static java.lang.Math.toIntExact; | ||
import static java.util.Objects.requireNonNull; | ||
|
||
import ai.rapids.cudf.HostMemoryBuffer; | ||
import ai.rapids.cudf.Schema; | ||
import com.nvidia.spark.rapids.jni.Arms; | ||
import com.nvidia.spark.rapids.jni.schema.Visitors; | ||
|
||
import java.nio.ByteBuffer; | ||
import java.nio.ByteOrder; | ||
import java.nio.IntBuffer; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.OptionalInt; | ||
|
||
import static com.nvidia.spark.rapids.jni.Preconditions.ensure; | ||
import static com.nvidia.spark.rapids.jni.kudo.ColumnOffsetInfo.INVALID_OFFSET; | ||
import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.getValidityLengthInBytes; | ||
import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padFor64byteAlignment; | ||
import static java.lang.Math.min; | ||
import static java.lang.Math.toIntExact; | ||
import static java.util.Objects.requireNonNull; | ||
|
||
/** | ||
* This class is used to merge multiple KudoTables into a single contiguous buffer, e.g. {@link KudoHostMergeResult}, | ||
* which could be easily converted to a {@link ai.rapids.cudf.ContiguousTable}. | ||
*/ | ||
class KudoTableMerger extends MultiKudoTableVisitor<Void, Void, KudoHostMergeResult> { | ||
// Number of 1s in a byte | ||
private static final int[] NUMBER_OF_ONES = new int[256]; | ||
|
||
static { | ||
for (int i = 0; i < NUMBER_OF_ONES.length; i += 1) { | ||
int count = 0; | ||
for (int j = 0; j < 8; j += 1) { | ||
if ((i & (1 << j)) != 0) { | ||
count += 1; | ||
} | ||
} | ||
NUMBER_OF_ONES[i] = count; | ||
} | ||
} | ||
|
||
private final List<ColumnOffsetInfo> columnOffsets; | ||
private final HostMemoryBuffer buffer; | ||
private final List<ColumnViewInfo> colViewInfoList; | ||
|
||
public KudoTableMerger(List<KudoTable> tables, HostMemoryBuffer buffer, List<ColumnOffsetInfo> columnOffsets) { | ||
public KudoTableMerger(List<KudoTable> tables, HostMemoryBuffer buffer, | ||
List<ColumnOffsetInfo> columnOffsets) { | ||
super(tables); | ||
requireNonNull(buffer, "buffer can't be null!"); | ||
ensure(columnOffsets != null, "column offsets cannot be null"); | ||
|
@@ -155,80 +141,64 @@ private static int copyValidityBuffer(HostMemoryBuffer dest, int startBit, | |
HostMemoryBuffer src, int srcOffset, | ||
SliceInfo sliceInfo) { | ||
int nullCount = 0; | ||
int totalRowCount = sliceInfo.getRowCount(); | ||
int curIdx = 0; | ||
int curSrcByteIdx = srcOffset; | ||
int curSrcBitIdx = sliceInfo.getValidityBufferInfo().getBeginBit(); | ||
int curDestByteIdx = startBit / 8; | ||
int curDestBitIdx = startBit % 8; | ||
|
||
while (curIdx < totalRowCount) { | ||
int leftRowCount = totalRowCount - curIdx; | ||
int appendCount; | ||
if (curDestBitIdx == 0) { | ||
appendCount = min(8, leftRowCount); | ||
} else { | ||
appendCount = min(8 - curDestBitIdx, leftRowCount); | ||
} | ||
|
||
int leftBitsInCurSrcByte = 8 - curSrcBitIdx; | ||
byte srcByte = src.getByte(curSrcByteIdx); | ||
if (leftBitsInCurSrcByte >= appendCount) { | ||
// Extract appendCount bits from srcByte, starting from curSrcBitIdx | ||
byte mask = (byte) (((1 << appendCount) - 1) & 0xFF); | ||
srcByte = (byte) ((srcByte >>> curSrcBitIdx) & mask); | ||
int totalRowCount = toIntExact(sliceInfo.getRowCount() + sliceInfo.getValidityBufferInfo().getBeginBit()); | ||
int curSrcIdx = sliceInfo.getValidityBufferInfo().getBeginBit(); | ||
int curDestIdx = startBit; | ||
|
||
nullCount += (appendCount - NUMBER_OF_ONES[srcByte & 0xFF]); | ||
|
||
// Sets the bits in destination buffer starting from curDestBitIdx to 0 | ||
byte destByte = dest.getByte(curDestByteIdx); | ||
destByte = (byte) (destByte & ((1 << curDestBitIdx) - 1) & 0xFF); | ||
while (curSrcIdx < totalRowCount) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So does this concatenate the validity buffers, one word (32 bits) at a time, in a serial manner, in Java? Can't we use C++/CUDA for accelerating this? |
||
int leftRowCount = totalRowCount - curSrcIdx; | ||
|
||
// Update destination byte with the bits from source byte | ||
destByte = (byte) ((destByte | (srcByte << curDestBitIdx)) & 0xFF); | ||
dest.setByte(curDestByteIdx, destByte); | ||
int curDestOffset = (curDestIdx / 32) * Integer.BYTES; | ||
int curDestBitIdx = curDestIdx % 32; | ||
|
||
curSrcBitIdx += appendCount; | ||
if (curSrcBitIdx == 8) { | ||
curSrcBitIdx = 0; | ||
curSrcByteIdx += 1; | ||
} | ||
} else { | ||
// Extract appendCount bits from srcByte, starting from curSrcBitIdx | ||
byte mask = (byte) (((1 << leftBitsInCurSrcByte) - 1) & 0xFF); | ||
srcByte = (byte) ((srcByte >>> curSrcBitIdx) & mask); | ||
|
||
byte nextSrcByte = src.getByte(curSrcByteIdx + 1); | ||
byte nextSrcByteMask = (byte) ((1 << (appendCount - leftBitsInCurSrcByte)) - 1); | ||
nextSrcByte = (byte) (nextSrcByte & nextSrcByteMask); | ||
nextSrcByte = (byte) (nextSrcByte << leftBitsInCurSrcByte); | ||
srcByte = (byte) (srcByte | nextSrcByte); | ||
int curSrcOffset = srcOffset + (curSrcIdx / 32) * Integer.BYTES; | ||
int curSrcBitIdx = curSrcIdx % 32; | ||
|
||
nullCount += (appendCount - NUMBER_OF_ONES[srcByte & 0xFF]); | ||
// This is safe since we always have validity buffer 4 bytes padded | ||
int srcInt = src.getInt(curSrcOffset); | ||
srcInt = srcInt >>> curSrcBitIdx; | ||
|
||
// Sets the bits in destination buffer starting from curDestBitIdx to 0 | ||
byte destByte = dest.getByte(curDestByteIdx); | ||
destByte = (byte) (destByte & ((1 << curDestBitIdx) - 1)); | ||
if (dest.getLength() >= (curDestOffset + Integer.BYTES)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Conditionals should not be in the body of the while loop. The while loop should be as simple as possible, since that's expected to be the hotspot. IMO the code should be structured into three parts similar to the following:
|
||
// We have enough room to get an int | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't we always have enough space to get an integer from the destination buffer because it's at least 4-byte padded? |
||
int destInt = dest.getInt(curDestOffset); | ||
destInt &= (1 << curDestBitIdx) - 1; | ||
destInt |= srcInt << curDestBitIdx; | ||
dest.setInt(curDestOffset, destInt); | ||
|
||
// Update destination byte with the bits from source byte | ||
destByte = (byte) (destByte | (srcByte << curDestBitIdx)); | ||
dest.setByte(curDestByteIdx, destByte); | ||
|
||
// Update the source byte index and bit index | ||
curSrcByteIdx += 1; | ||
curSrcBitIdx = appendCount - leftBitsInCurSrcByte; | ||
} | ||
int appendCount = min(leftRowCount, 32 - Math.max(curSrcBitIdx, curDestBitIdx)); | ||
|
||
curIdx += appendCount; | ||
|
||
// Update the destination byte index and bit index | ||
curDestBitIdx += appendCount; | ||
if (curDestBitIdx == 8) { | ||
curDestBitIdx = 0; | ||
curDestByteIdx += 1; | ||
curDestIdx += appendCount; | ||
curSrcIdx += appendCount; | ||
if (appendCount == 32) { | ||
nullCount += 32 - Integer.bitCount(srcInt); | ||
} else { | ||
int mask = (1 << appendCount) - 1; | ||
nullCount += (appendCount - Integer.bitCount(srcInt & mask)); | ||
} | ||
} else { | ||
int destBufRemBytes = toIntExact(dest.getLength() - curDestOffset); | ||
byte[] destBytes = new byte[4]; | ||
dest.getBytes(destBytes, 0, curDestOffset, destBufRemBytes); | ||
int destInt = ByteBuffer.wrap(destBytes).order(ByteOrder.LITTLE_ENDIAN).getInt(); | ||
Comment on lines
+181
to
+183
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Curious why we're doing byte-at-time here and endian stuff when we don't up above? This is still grabbing 4 bytes like the above code. Byte-at-a-time is usually a lot slower, might be faster to read the int from the buffer (we're loading 4 bytes anyway) and call Integer.reverseBytes (a HotSpot intrinsic candidate) if ByteOrder.nativeOrder == ByteOrder.LITTLE_ENDIAN. |
||
destInt &= (1 << curDestBitIdx) - 1; | ||
destInt |= srcInt << curDestBitIdx; | ||
|
||
ByteBuffer.wrap(destBytes).order(ByteOrder.LITTLE_ENDIAN).putInt(destInt); | ||
dest.setBytes(curDestOffset, destBytes, 0, destBufRemBytes); | ||
Comment on lines
+187
to
+188
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to above, can call setInt here and leverage ByteOrder to determine when we need to swap bytes or not. |
||
|
||
int appendCount = min(leftRowCount, destBufRemBytes * 8 - Math.max(curSrcBitIdx, curDestBitIdx)); | ||
|
||
curDestIdx += appendCount; | ||
curSrcIdx += appendCount; | ||
int mask = (1 << appendCount) - 1; | ||
nullCount += (appendCount - Integer.bitCount(srcInt & mask)); | ||
} | ||
} | ||
|
||
int srcIdx = curSrcIdx; | ||
ensure(curSrcIdx == totalRowCount, () -> "Did not copy all of the validity buffer, total row count: " + totalRowCount + | ||
" current src idx: " + srcIdx); | ||
return nullCount; | ||
} | ||
|
||
|
@@ -325,7 +295,8 @@ static KudoHostMergeResult merge(Schema schema, MergedInfoCalc mergedInfo) { | |
List<KudoTable> serializedTables = mergedInfo.getTables(); | ||
return Arms.closeIfException(HostMemoryBuffer.allocate(mergedInfo.getTotalDataLen()), | ||
buffer -> { | ||
KudoTableMerger merger = new KudoTableMerger(serializedTables, buffer, mergedInfo.getColumnOffsets()); | ||
KudoTableMerger merger = | ||
new KudoTableMerger(serializedTables, buffer, mergedInfo.getColumnOffsets()); | ||
return Visitors.visitSchema(schema, merger); | ||
}); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The name of this variable makes reading this confusing. It's not a total row count as the name implies. It's the end index or ending row, IIUC.