Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize validity buffer concat. #2626

Open
wants to merge 1 commit into
base: branch-25.02
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 60 additions & 89 deletions src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableMerger.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,50 +16,36 @@

package com.nvidia.spark.rapids.jni.kudo;

import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
import static com.nvidia.spark.rapids.jni.kudo.ColumnOffsetInfo.INVALID_OFFSET;
import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.getValidityLengthInBytes;
import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padFor64byteAlignment;
import static java.lang.Math.min;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

import ai.rapids.cudf.HostMemoryBuffer;
import ai.rapids.cudf.Schema;
import com.nvidia.spark.rapids.jni.Arms;
import com.nvidia.spark.rapids.jni.schema.Visitors;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.IntBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.OptionalInt;

import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
import static com.nvidia.spark.rapids.jni.kudo.ColumnOffsetInfo.INVALID_OFFSET;
import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.getValidityLengthInBytes;
import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padFor64byteAlignment;
import static java.lang.Math.min;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

/**
* This class is used to merge multiple KudoTables into a single contiguous buffer, e.g. {@link KudoHostMergeResult},
* which could be easily converted to a {@link ai.rapids.cudf.ContiguousTable}.
*/
class KudoTableMerger extends MultiKudoTableVisitor<Void, Void, KudoHostMergeResult> {
// Number of 1s in a byte
private static final int[] NUMBER_OF_ONES = new int[256];

static {
for (int i = 0; i < NUMBER_OF_ONES.length; i += 1) {
int count = 0;
for (int j = 0; j < 8; j += 1) {
if ((i & (1 << j)) != 0) {
count += 1;
}
}
NUMBER_OF_ONES[i] = count;
}
}

private final List<ColumnOffsetInfo> columnOffsets;
private final HostMemoryBuffer buffer;
private final List<ColumnViewInfo> colViewInfoList;

public KudoTableMerger(List<KudoTable> tables, HostMemoryBuffer buffer, List<ColumnOffsetInfo> columnOffsets) {
public KudoTableMerger(List<KudoTable> tables, HostMemoryBuffer buffer,
List<ColumnOffsetInfo> columnOffsets) {
super(tables);
requireNonNull(buffer, "buffer can't be null!");
ensure(columnOffsets != null, "column offsets cannot be null");
Expand Down Expand Up @@ -155,80 +141,64 @@ private static int copyValidityBuffer(HostMemoryBuffer dest, int startBit,
HostMemoryBuffer src, int srcOffset,
SliceInfo sliceInfo) {
int nullCount = 0;
int totalRowCount = sliceInfo.getRowCount();
int curIdx = 0;
int curSrcByteIdx = srcOffset;
int curSrcBitIdx = sliceInfo.getValidityBufferInfo().getBeginBit();
int curDestByteIdx = startBit / 8;
int curDestBitIdx = startBit % 8;

while (curIdx < totalRowCount) {
int leftRowCount = totalRowCount - curIdx;
int appendCount;
if (curDestBitIdx == 0) {
appendCount = min(8, leftRowCount);
} else {
appendCount = min(8 - curDestBitIdx, leftRowCount);
}

int leftBitsInCurSrcByte = 8 - curSrcBitIdx;
byte srcByte = src.getByte(curSrcByteIdx);
if (leftBitsInCurSrcByte >= appendCount) {
// Extract appendCount bits from srcByte, starting from curSrcBitIdx
byte mask = (byte) (((1 << appendCount) - 1) & 0xFF);
srcByte = (byte) ((srcByte >>> curSrcBitIdx) & mask);
int totalRowCount = toIntExact(sliceInfo.getRowCount() + sliceInfo.getValidityBufferInfo().getBeginBit());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name of this variable makes reading this confusing. It's not a total row count as the name implies. It's the end index or ending row, IIUC.

int curSrcIdx = sliceInfo.getValidityBufferInfo().getBeginBit();
int curDestIdx = startBit;

nullCount += (appendCount - NUMBER_OF_ONES[srcByte & 0xFF]);

// Sets the bits in destination buffer starting from curDestBitIdx to 0
byte destByte = dest.getByte(curDestByteIdx);
destByte = (byte) (destByte & ((1 << curDestBitIdx) - 1) & 0xFF);
while (curSrcIdx < totalRowCount) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So does this concatenate the validity buffers, one word (32 bits) at a time, in a serial manner, in Java?

Can't we use C++/CUDA for accelerating this?

int leftRowCount = totalRowCount - curSrcIdx;

// Update destination byte with the bits from source byte
destByte = (byte) ((destByte | (srcByte << curDestBitIdx)) & 0xFF);
dest.setByte(curDestByteIdx, destByte);
int curDestOffset = (curDestIdx / 32) * Integer.BYTES;
int curDestBitIdx = curDestIdx % 32;

curSrcBitIdx += appendCount;
if (curSrcBitIdx == 8) {
curSrcBitIdx = 0;
curSrcByteIdx += 1;
}
} else {
// Extract appendCount bits from srcByte, starting from curSrcBitIdx
byte mask = (byte) (((1 << leftBitsInCurSrcByte) - 1) & 0xFF);
srcByte = (byte) ((srcByte >>> curSrcBitIdx) & mask);

byte nextSrcByte = src.getByte(curSrcByteIdx + 1);
byte nextSrcByteMask = (byte) ((1 << (appendCount - leftBitsInCurSrcByte)) - 1);
nextSrcByte = (byte) (nextSrcByte & nextSrcByteMask);
nextSrcByte = (byte) (nextSrcByte << leftBitsInCurSrcByte);
srcByte = (byte) (srcByte | nextSrcByte);
int curSrcOffset = srcOffset + (curSrcIdx / 32) * Integer.BYTES;
int curSrcBitIdx = curSrcIdx % 32;

nullCount += (appendCount - NUMBER_OF_ONES[srcByte & 0xFF]);
// This is safe since we always have validity buffer 4 bytes padded
int srcInt = src.getInt(curSrcOffset);
srcInt = srcInt >>> curSrcBitIdx;

// Sets the bits in destination buffer starting from curDestBitIdx to 0
byte destByte = dest.getByte(curDestByteIdx);
destByte = (byte) (destByte & ((1 << curDestBitIdx) - 1));
if (dest.getLength() >= (curDestOffset + Integer.BYTES)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Conditionals should not be in the body of the while loop. The while loop should be as simple as possible, since that's expected to be the hotspot. IMO the code should be structured into three parts similar to the following:

if (curSrcIdx % 8 != 0) {
  // read an int from the buffer
  // mask off the unused bits
  // count the bits
  // shift and store the bits
}
while (whole_ints_left_in_buffer) {
  // read int from buffer
  // count bits
  // shift and store the bits
}
if (leftover bits) {
  // read an int from the buffer (leverage padded buffer here)
  // mask off the unused bits
  // count the bits
  // shift and store the bits
}

// We have enough room to get an int
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we always have enough space to get an integer from the destination buffer because it's at least 4-byte padded?

int destInt = dest.getInt(curDestOffset);
destInt &= (1 << curDestBitIdx) - 1;
destInt |= srcInt << curDestBitIdx;
dest.setInt(curDestOffset, destInt);

// Update destination byte with the bits from source byte
destByte = (byte) (destByte | (srcByte << curDestBitIdx));
dest.setByte(curDestByteIdx, destByte);

// Update the source byte index and bit index
curSrcByteIdx += 1;
curSrcBitIdx = appendCount - leftBitsInCurSrcByte;
}
int appendCount = min(leftRowCount, 32 - Math.max(curSrcBitIdx, curDestBitIdx));

curIdx += appendCount;

// Update the destination byte index and bit index
curDestBitIdx += appendCount;
if (curDestBitIdx == 8) {
curDestBitIdx = 0;
curDestByteIdx += 1;
curDestIdx += appendCount;
curSrcIdx += appendCount;
if (appendCount == 32) {
nullCount += 32 - Integer.bitCount(srcInt);
} else {
int mask = (1 << appendCount) - 1;
nullCount += (appendCount - Integer.bitCount(srcInt & mask));
}
} else {
int destBufRemBytes = toIntExact(dest.getLength() - curDestOffset);
byte[] destBytes = new byte[4];
dest.getBytes(destBytes, 0, curDestOffset, destBufRemBytes);
int destInt = ByteBuffer.wrap(destBytes).order(ByteOrder.LITTLE_ENDIAN).getInt();
Comment on lines +181 to +183
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why we're doing byte-at-time here and endian stuff when we don't up above? This is still grabbing 4 bytes like the above code. Byte-at-a-time is usually a lot slower, might be faster to read the int from the buffer (we're loading 4 bytes anyway) and call Integer.reverseBytes (a HotSpot intrinsic candidate) if ByteOrder.nativeOrder == ByteOrder.LITTLE_ENDIAN.

destInt &= (1 << curDestBitIdx) - 1;
destInt |= srcInt << curDestBitIdx;

ByteBuffer.wrap(destBytes).order(ByteOrder.LITTLE_ENDIAN).putInt(destInt);
dest.setBytes(curDestOffset, destBytes, 0, destBufRemBytes);
Comment on lines +187 to +188
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to above, can call setInt here and leverage ByteOrder to determine when we need to swap bytes or not.


int appendCount = min(leftRowCount, destBufRemBytes * 8 - Math.max(curSrcBitIdx, curDestBitIdx));

curDestIdx += appendCount;
curSrcIdx += appendCount;
int mask = (1 << appendCount) - 1;
nullCount += (appendCount - Integer.bitCount(srcInt & mask));
}
}

int srcIdx = curSrcIdx;
ensure(curSrcIdx == totalRowCount, () -> "Did not copy all of the validity buffer, total row count: " + totalRowCount +
" current src idx: " + srcIdx);
return nullCount;
}

Expand Down Expand Up @@ -325,7 +295,8 @@ static KudoHostMergeResult merge(Schema schema, MergedInfoCalc mergedInfo) {
List<KudoTable> serializedTables = mergedInfo.getTables();
return Arms.closeIfException(HostMemoryBuffer.allocate(mergedInfo.getTotalDataLen()),
buffer -> {
KudoTableMerger merger = new KudoTableMerger(serializedTables, buffer, mergedInfo.getColumnOffsets());
KudoTableMerger merger =
new KudoTableMerger(serializedTables, buffer, mergedInfo.getColumnOffsets());
return Visitors.visitSchema(schema, merger);
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ public class KudoSerializerTest {
public void testSerializeAndDeserializeTable() {
try(Table expected = buildTestTable()) {
int rowCount = toIntExact(expected.getRowCount());
for (int sliceSize = 1; sliceSize <= rowCount; sliceSize++) {
IntStream sliceSizes = IntStream.range(1, rowCount + 1);
for (int sliceSize: sliceSizes.toArray()) {
List<TableSlice> tableSlices = new ArrayList<>();
for (int startRow = 0; startRow < rowCount; startRow += sliceSize) {
tableSlices.add(new TableSlice(startRow, Math.min(sliceSize, rowCount - startRow), expected));
Expand Down