Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 96 additions & 1 deletion vortex-cuda/kernels/src/dynamic_dispatch.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@

#include "bit_unpack.cuh"
#include "dynamic_dispatch.h"
#include "patches.cuh"
#include "types.cuh"

// ═══════════════════════════════════════════════════════════════════════════
Expand Down Expand Up @@ -183,6 +184,79 @@ __device__ inline void bitunpack(const T *__restrict packed,
}
}

// ═══════════════════════════════════════════════════════════════════════════
// Source-patch helpers
// ═══════════════════════════════════════════════════════════════════════════

/// Reconstruct a GPUPatches struct from a packed device pointer.
///
/// Packed layout: [lane_offsets (uint32 × lo_count)] [indices (uint16 × num_patches)]
/// [values (T × num_patches, aligned to sizeof(T))].
///
/// Preconditions:
///   - `patches_ptr` must be a valid, non-zero device pointer: this function
///     dereferences it unconditionally, so callers must guard with
///     `patches_ptr != 0` before calling (see the call sites).
///   - sizeof(T) must be a power of two (static_assert below) so the values
///     section can be aligned with a simple mask.
template <typename T>
__device__ inline GPUPatches
unpack_source_patches(uint64_t patches_ptr, uint32_t stage_len, uint32_t element_offset) {
  static_assert((sizeof(T) & (sizeof(T) - 1)) == 0, "sizeof(T) must be a power of two");
  // FastLanes geometry: 1024-element chunks; 32 lanes for sub-8-byte types, 16 for 8-byte.
  constexpr uint32_t FL_CHUNK = 1024;
  constexpr uint32_t N_LANES = (sizeof(T) < 8) ? 32 : 16;
  uint8_t *base = reinterpret_cast<uint8_t *>(patches_ptr);
  // Chunk count accounts for the partial first chunk introduced by element_offset.
  const uint32_t n_chunks = (stage_len + (element_offset % FL_CHUNK) + FL_CHUNK - 1) / FL_CHUNK;
  // One offset per (chunk, lane) pair, plus a trailing entry holding num_patches.
  const uint32_t lo_count = n_chunks * N_LANES + 1;
  uint32_t *lane_offsets = reinterpret_cast<uint32_t *>(base);
  const uint32_t num_patches = lane_offsets[lo_count - 1];
  const uint32_t indices_byte_start = lo_count * sizeof(uint32_t);
  uint16_t *indices = reinterpret_cast<uint16_t *>(base + indices_byte_start);
  // Round the values section up to sizeof(T) alignment (valid since sizeof(T) is pow2).
  uint32_t values_byte_start = indices_byte_start + num_patches * sizeof(uint16_t);
  values_byte_start = (values_byte_start + sizeof(T) - 1) & ~(sizeof(T) - 1);
  void *values = base + values_byte_start;
  return {lane_offsets, indices, values};
}

/// Apply source patches for a single FL chunk (used in the output stage).
///
/// Cooperatively overwrites patched positions in `scratch` (shared memory for
/// one FL chunk) and issues __syncthreads() so every thread observes the
/// patched values before reading. Caller must ensure patches_ptr != 0.
template <typename T>
__device__ inline void apply_source_patches_chunk(uint64_t patches_ptr,
                                                  T *__restrict scratch,
                                                  uint32_t stage_len,
                                                  uint32_t element_offset,
                                                  uint32_t fl_chunk) {
  // Name the cursor's end sentinel instead of a bare 1024: the cursor reports
  // exhaustion with index == FL_CHUNK (one past the last valid chunk position).
  constexpr uint32_t FL_CHUNK = 1024;
  constexpr uint32_t N_LANES = (sizeof(T) < 8) ? 32 : 16;
  const GPUPatches patches = unpack_source_patches<T>(patches_ptr, stage_len, element_offset);
  // Each participating thread drains the patch list of one FastLanes lane.
  for (uint32_t lane = threadIdx.x; lane < N_LANES; lane += blockDim.x) {
    PatchesCursor<T> cursor(patches, fl_chunk, lane, N_LANES);
    for (auto p = cursor.next(); p.index != FL_CHUNK; p = cursor.next()) {
      scratch[p.index] = p.value;
    }
  }
  // Barrier is outside the divergent loop: all threads of the block reach it.
  __syncthreads();
}

/// Apply source patches for all FL chunks (used in the input stage).
///
/// Cooperatively overwrites patched positions in `smem_out` for every chunk
/// this stage decodes, then issues __syncthreads() so all threads see the
/// patched data. Caller must ensure patches_ptr != 0.
template <typename T>
__device__ inline void apply_source_patches_all(uint64_t patches_ptr,
                                                T *__restrict smem_out,
                                                uint32_t stage_len,
                                                uint32_t element_offset) {
  constexpr uint32_t FL_CHUNK = 1024;
  constexpr uint32_t N_LANES = (sizeof(T) < 8) ? 32 : 16;
  const GPUPatches patches = unpack_source_patches<T>(patches_ptr, stage_len, element_offset);
  // Global index of the first chunk this stage touches, plus the number of
  // chunks covered (including a possibly partial head chunk).
  const uint32_t first_chunk = element_offset / FL_CHUNK;
  const uint32_t head_elems = element_offset % FL_CHUNK;
  const uint32_t n_chunks = (stage_len + head_elems + FL_CHUNK - 1) / FL_CHUNK;
  for (uint32_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
    T *const chunk_dst = smem_out + chunk_idx * FL_CHUNK;
    // Threads split the work per FastLanes lane; index == 1024 ends a lane's list.
    for (uint32_t lane = threadIdx.x; lane < N_LANES; lane += blockDim.x) {
      PatchesCursor<T> cursor(patches, first_chunk + chunk_idx, lane, N_LANES);
      for (auto patch = cursor.next(); patch.index != FL_CHUNK; patch = cursor.next()) {
        chunk_dst[patch.index] = patch.value;
      }
    }
  }
  __syncthreads();
}

/// Read N values from a source op into `out`.
///
/// Dispatches on `src.op_code` to handle each encoding:
Expand Down Expand Up @@ -317,6 +391,17 @@ __device__ void execute_output_stage(T *__restrict output,
smem_src = scratch + align;
// Write barrier: all threads finished bitunpack, safe to read from scratch.
__syncthreads();

// Merge source patches for this FL chunk into smem scratch.
if (stage.patches_ptr != 0) {
const uint32_t fl_chunk = static_cast<uint32_t>(
(block_start + elem_idx + src.params.bitunpack.element_offset) / 1024);
apply_source_patches_chunk<T>(stage.patches_ptr,
scratch,
stage.len,
src.params.bitunpack.element_offset,
fl_chunk);
}
} else {
chunk_len = block_len;
}
Expand Down Expand Up @@ -391,12 +476,22 @@ __device__ void execute_input_stage(const Stage &stage, char *__restrict smem) {
const auto &src = stage.source;

if (src.op_code == SourceOp::BITUNPACK) {
T *raw_smem = smem_out;
bitunpack<T>(reinterpret_cast<const T *>(stage.input_ptr), smem_out, 0, stage.len, src);
smem_out += src.params.bitunpack.element_offset % SMEM_TILE_SIZE;
// Write barrier: cooperative bitunpack finished, safe to read
// decoded elements in the scalar-op loop below.
__syncthreads();

// Merge source patches into the decoded smem region.
if (stage.patches_ptr != 0) {
apply_source_patches_all<T>(stage.patches_ptr,
raw_smem,
stage.len,
src.params.bitunpack.element_offset);
}

smem_out += src.params.bitunpack.element_offset % SMEM_TILE_SIZE;

if (stage.num_scalar_ops > 0) {
for (uint32_t i = threadIdx.x; i < stage.len; i += blockDim.x) {
T val = smem_out[i];
Expand Down
40 changes: 21 additions & 19 deletions vortex-cuda/kernels/src/dynamic_dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ union SourceParams {
/// Unpack FastLanes bit-packed data.
struct BitunpackParams {
uint8_t bit_width;
uint32_t element_offset; // Sub-byte offset
uint16_t element_offset; // Sub-block offset (0..1023)
} bitunpack;

/// Copy from global to shared memory.
Expand All @@ -120,10 +120,10 @@ union SourceParams {
/// The smem offsets are byte offsets so that ends and values can have
/// different element widths.
struct RunEndParams {
uint32_t ends_smem_byte_offset; // byte offset to decoded ends in smem
uint32_t values_smem_byte_offset; // byte offset to decoded values in smem
uint64_t num_runs;
uint64_t offset; // slice offset into the run-end encoded array
uint16_t ends_smem_byte_offset; // byte offset to decoded ends in smem
uint16_t values_smem_byte_offset; // byte offset to decoded values in smem
uint32_t num_runs;
uint32_t offset; // slice offset into the run-end encoded array
} runend;

/// Generate a linear sequence: `value[i] = base + i * multiplier`.
Expand All @@ -134,8 +134,8 @@ union SourceParams {
};

struct SourceOp {
enum SourceOpCode { BITUNPACK, LOAD, RUNEND, SEQUENCE } op_code;
union SourceParams params;
enum SourceOpCode : uint8_t { BITUNPACK, LOAD, RUNEND, SEQUENCE } op_code;
};

/// Scalar ops: element-wise transforms in registers.
Expand Down Expand Up @@ -166,17 +166,17 @@ union ScalarParams {
/// `output_ptype` (on the enclosing ScalarOp) to determine the values'
/// element type.
struct DictParams {
uint32_t values_smem_byte_offset; // byte offset to decoded dict values in smem
uint16_t values_smem_byte_offset; // byte offset to decoded dict values in smem
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Current max shared memory with guard is 48KB.

} dict;
};

struct ScalarOp {
enum ScalarOpCode { FOR, ZIGZAG, ALP, DICT } op_code;
union ScalarParams params;
enum ScalarOpCode : uint8_t { FOR, ZIGZAG, ALP, DICT } op_code;
/// The PType this op produces. For type-preserving ops (FOR, ZIGZAG)
/// this equals the input PType. For type-changing ops (ALP, DICT) this
/// is the new output PType.
enum PTypeTag output_ptype;
union ScalarParams params;
};

/// Packed stage header, followed by `num_scalar_ops` inline ScalarOps.
Expand All @@ -188,11 +188,11 @@ struct ScalarOp {
/// `smem_byte_offset` is a byte offset into the dynamic shared memory
/// pool so that stages with different element widths can coexist.
struct PackedStage {
uint64_t input_ptr; // global memory pointer to this stage's encoded input
uint32_t smem_byte_offset; // byte offset within dynamic shared memory for output
uint32_t len; // number of elements this stage produces

uint64_t input_ptr; // global memory pointer to this stage's encoded input
uint64_t patches_ptr; // device ptr to packed source patches (0 = none)
struct SourceOp source;
uint32_t len; // number of elements this stage produces
Copy link
Copy Markdown
Contributor Author

@0ax1 0ax1 Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to expand this again to u64, but for now let's assume and guard we stay within u32. If we go past u32, we should also try micro/macro-benchmarking in that range.

uint16_t smem_byte_offset; // byte offset within dynamic shared memory for output
uint8_t num_scalar_ops;
enum PTypeTag source_ptype; // PType produced by the source op
};
Expand Down Expand Up @@ -221,12 +221,13 @@ struct __attribute__((aligned(8))) PlanHeader {
/// `output_ptype` (or `source_ptype` if there are no scalar ops).
struct Stage {
uint64_t input_ptr; // encoded input in global memory
uint32_t smem_byte_offset; // byte offset within dynamic shared memory
const struct ScalarOp *scalar_ops; // pointer into packed plan buffer
uint64_t patches_ptr; // device ptr to packed source patches (0 = none)
uint32_t len; // elements produced
uint16_t smem_byte_offset; // byte offset within dynamic shared memory
enum PTypeTag source_ptype; // PType produced by the source op
struct SourceOp source; // source decode op
uint8_t num_scalar_ops; // number of scalar ops
const struct ScalarOp *scalar_ops; // scalar decode ops
struct SourceOp source; // source decode op
};

/// Parse a single stage from the packed plan byte buffer and advance the cursor.
Expand All @@ -243,12 +244,13 @@ __device__ inline Stage parse_stage(const uint8_t *&cursor) {

return Stage {
.input_ptr = packed_stage->input_ptr,
.smem_byte_offset = packed_stage->smem_byte_offset,
.scalar_ops = ops,
.patches_ptr = packed_stage->patches_ptr,
.len = packed_stage->len,
.smem_byte_offset = packed_stage->smem_byte_offset,
.source_ptype = packed_stage->source_ptype,
.source = packed_stage->source,
.num_scalar_ops = packed_stage->num_scalar_ops,
.scalar_ops = ops,
.source = packed_stage->source,
};
}

Expand Down
Loading
Loading