-
Notifications
You must be signed in to change notification settings - Fork 372
Refactor WarpExchangeShfl
#8183
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
bernhardmgruber
wants to merge
2
commits into
NVIDIA:main
Choose a base branch
from
bernhardmgruber:ref_WarpExchangeShfl
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -34,212 +34,169 @@ class WarpExchangeShfl | |||||||
|
|
||||||||
| static constexpr bool IS_ARCH_WARP = LOGICAL_WARP_THREADS == warp_threads; | ||||||||
|
|
||||||||
| // concrete recursion class | ||||||||
| template <typename OutputT, int IDX, int SIZE> | ||||||||
| class CompileTimeArray : protected CompileTimeArray<OutputT, IDX + 1, SIZE> | ||||||||
| template <int NUM_ENTRIES, int IDX> | ||||||||
| _CCCL_DEVICE _CCCL_FORCEINLINE void | ||||||||
| transpose_foreach(InputT (&vals)[ITEMS_PER_THREAD], const bool xor_bit_set, const unsigned mask) | ||||||||
| { | ||||||||
| protected: | ||||||||
| InputT val; | ||||||||
|
|
||||||||
| template <int NUM_ENTRIES> | ||||||||
| _CCCL_DEVICE void Foreach(const bool xor_bit_set, const unsigned mask) | ||||||||
| { | ||||||||
| // The implementation here is a recursive divide-and-conquer approach | ||||||||
| // that takes inspiration from: | ||||||||
| // https://forums.developer.nvidia.com/t/transposing-register-held-matrices-with-warp-shuffles-need-help/38652/2 | ||||||||
| // | ||||||||
| // At its core, the problem can be boiled down to transposing the matrix | ||||||||
| // | ||||||||
| // A B | ||||||||
| // C D | ||||||||
| // | ||||||||
| // by swapping the off-diagonal elements/sub-matrices B and C recursively. | ||||||||
| // | ||||||||
| // This implementation requires power-of-two matrices. In order to avoid | ||||||||
| // the use of local or shared memory, all index computation has to occur | ||||||||
| // at compile-time, since registers cannot be indexed dynamically. | ||||||||
| // Furthermore, using recursive templates reduces the mental load on the | ||||||||
| // optimizer, since lowering for-loops into registers oftentimes requires | ||||||||
| // finagling them with #pragma unroll, which leads to brittle code. | ||||||||
| // | ||||||||
| // To illustrate this algorithm, let's pretend we have warpSize = 8, | ||||||||
| // where t0, ..., t7 denote the 8 threads, and thread i has an array of | ||||||||
| // size 8 with data = [Ai, Bi, ..., Hi] (the columns in the schematics). | ||||||||
| // | ||||||||
| // In the first round, we exchange the largest 4x4 off-diagonal | ||||||||
| // submatrix. Boxes illustrate the submatrices to be exchanged. | ||||||||
| // | ||||||||
| // ROUND 1 | ||||||||
| // ======= | ||||||||
| // t0 t1 t2 t3 t4 t5 t6 t7 | ||||||||
| // ┌──────────────┐ | ||||||||
| // A0 A1 A2 A3 │A4 A5 A6 A7│ NUM_ENTRIES == 4 tells us how many | ||||||||
| // │ │ entries we have in a submatrix, | ||||||||
| // │ │ in this case 4 and the size of | ||||||||
| // B0 B1 B2 B3 │B4 B5 B6 B7│ the jumps between submatrices. | ||||||||
| // │ │ | ||||||||
| // │ │ 1. t[0,1,2,3] data[4] swap with t[4,5,6,7]'s data[0] | ||||||||
| // C0 C1 C2 C3 │C4 C5 C6 C7│ 2. t[0,1,2,3] data[5] swap with t[4,5,6,7]'s data[1] | ||||||||
| // │ │ 3. t[0,1,2,3] data[6] swap with t[4,5,6,7]'s data[2] | ||||||||
| // │ │ 4. t[0,1,2,3] data[7] swap with t[4,5,6,7]'s data[3] | ||||||||
| // D0 D1 D2 D3 │D4 D5 D6 D7│ | ||||||||
| // └──────────────┘ | ||||||||
| // ┌──────────────┐ | ||||||||
| // │E0 E1 E2 E3│ E4 E5 E6 E7 | ||||||||
| // │ │ | ||||||||
| // │ │ | ||||||||
| // │F0 F1 F2 F3│ F4 F5 F6 F7 | ||||||||
| // │ │ | ||||||||
| // │ │ | ||||||||
| // │G0 G1 G2 G3│ G4 G5 G6 G7 | ||||||||
| // │ │ | ||||||||
| // │ │ | ||||||||
| // │H0 H1 H2 H3│ H4 H5 H6 H7 | ||||||||
| // └──────────────┘ | ||||||||
| // | ||||||||
| // ROUND 2 | ||||||||
| // ======= | ||||||||
| // t0 t1 t2 t3 t4 t5 t6 t7 | ||||||||
| // ┌──────┐ ┌──────┐ | ||||||||
| // A0 A1 │A2 A3│ E0 E1 │E2 E3│ NUM_ENTRIES == 2 so we have 2 | ||||||||
| // │ │ │ │ submatrices per thread and there | ||||||||
| // │ │ │ │ are 2 elements between these | ||||||||
| // B0 B1 │B2 B3│ F0 F1 │F2 F3│ submatrices. | ||||||||
| // └──────┘ └──────┘ | ||||||||
| // ┌──────┐ ┌──────┐ 1. t[0,1,4,5] data[2] swap with t[2,3,6,7]'s data[0] | ||||||||
| // │C0 C1│ C2 C3 │G0 G1│ G2 G3 2. t[0,1,4,5] data[3] swap with t[2,3,6,7]'s data[1] | ||||||||
| // │ │ │ │ 3. t[0,1,4,5] data[6] swap with t[2,3,6,7]'s data[4] | ||||||||
| // │ │ │ │ 4. t[0,1,4,5] data[7] swap with t[2,3,6,7]'s data[5] | ||||||||
| // │D0 D1│ D2 D3 │H0 H1│ H2 H3 | ||||||||
| // └──────┘ └──────┘ | ||||||||
| // ┌──────┐ ┌──────┐ | ||||||||
| // A4 A5 │A6 A7│ E4 E5 │E6 E7│ | ||||||||
| // │ │ │ │ | ||||||||
| // │ │ │ │ | ||||||||
| // B4 B5 │B6 B7│ F4 F5 │F6 F7│ | ||||||||
| // └──────┘ └──────┘ | ||||||||
| // ┌──────┐ ┌──────┐ | ||||||||
| // │C4 C5│ C6 C7 │G4 G5│ G6 G7 | ||||||||
| // │ │ │ │ | ||||||||
| // │ │ │ │ | ||||||||
| // │D4 D5│ D6 D7 │H4 H5│ H6 H7 | ||||||||
| // └──────┘ └──────┘ | ||||||||
| // | ||||||||
| // ROUND 3 | ||||||||
| // ======= | ||||||||
| // t0 t1 t2 t3 t4 t5 t6 t7 | ||||||||
| // ┌──┐ ┌──┐ ┌──┐ ┌──┐ | ||||||||
| // A0 │A1│ C0 │C1│ E0 │E1│ G0 │G1│ NUM_ENTRIES == 1 so we have 4 | ||||||||
| // └──┘ └──┘ └──┘ └──┘ submatrices per thread and there | ||||||||
| // ┌──┐ ┌──┐ ┌──┐ ┌──┐ is 1 element between these | ||||||||
| // │B0│ B1 │D0│ D1 │F0│ F1 │H0│ H1 submatrices. | ||||||||
| // └──┘ └──┘ └──┘ └──┘ | ||||||||
| // ┌──┐ ┌──┐ ┌──┐ ┌──┐ 1. t[0,2,4,6] data[1] swap with t[1,3,5,7]'s data[0] | ||||||||
| // A2 │A3│ C2 │C3│ E2 │E3│ G2 │G3│ 2. t[0,2,4,6] data[3] swap with t[1,3,5,7]'s data[2] | ||||||||
| // └──┘ └──┘ └──┘ └──┘ 3. t[0,2,4,6] data[5] swap with t[1,3,5,7]'s data[4] | ||||||||
| // ┌──┐ ┌──┐ ┌──┐ ┌──┐ 4. t[0,2,4,6] data[7] swap with t[1,3,5,7]'s data[6] | ||||||||
| // │B2│ B3 │D2│ D3 │F2│ F3 │H2│ H3 | ||||||||
| // └──┘ └──┘ └──┘ └──┘ | ||||||||
| // ┌──┐ ┌──┐ ┌──┐ ┌──┐ | ||||||||
| // A4 │A5│ C4 │C5│ E4 │E5│ G4 │G5│ | ||||||||
| // └──┘ └──┘ └──┘ └──┘ | ||||||||
| // ┌──┐ ┌──┐ ┌──┐ ┌──┐ | ||||||||
| // │B4│ B5 │D4│ D5 │F4│ F5 │H4│ H5 | ||||||||
| // └──┘ └──┘ └──┘ └──┘ | ||||||||
| // ┌──┐ ┌──┐ ┌──┐ ┌──┐ | ||||||||
| // A6 │A7│ C6 │C7│ E6 │E7│ G6 │G7│ | ||||||||
| // └──┘ └──┘ └──┘ └──┘ | ||||||||
| // ┌──┐ ┌──┐ ┌──┐ ┌──┐ | ||||||||
| // │B6│ B7 │D6│ D7 │F6│ F7 │H6│ H7 | ||||||||
| // └──┘ └──┘ └──┘ └──┘ | ||||||||
| // | ||||||||
| // RESULT | ||||||||
| // ====== | ||||||||
| // t0 t1 t2 t3 t4 t5 t6 t7 | ||||||||
| // | ||||||||
| // A0 B0 C0 D0 E0 F0 G0 H0 | ||||||||
| // | ||||||||
| // | ||||||||
| // A1 B1 C1 D1 E1 F1 G1 H1 | ||||||||
| // | ||||||||
| // | ||||||||
| // A2 B2 C2 D2 E2 F2 G2 H2 | ||||||||
| // | ||||||||
| // | ||||||||
| // A3 B3 C3 D3 E3 F3 G3 H3 | ||||||||
| // | ||||||||
| // | ||||||||
| // A4 B4 C4 D4 E4 F4 G4 H4 | ||||||||
| // | ||||||||
| // | ||||||||
| // A5 B5 C5 D5 E5 F5 G5 H5 | ||||||||
| // | ||||||||
| // | ||||||||
| // A6 B6 C6 D6 E6 F6 G6 H6 | ||||||||
| // | ||||||||
| // | ||||||||
| // A7 B7 C7 D7 E7 F7 G7 H7 | ||||||||
| // | ||||||||
| // The implementation here is a recursive divide-and-conquer approach | ||||||||
| // that takes inspiration from: | ||||||||
| // https://forums.developer.nvidia.com/t/transposing-register-held-matrices-with-warp-shuffles-need-help/38652/2 | ||||||||
| // | ||||||||
| // At its core, the problem can be boiled down to transposing the matrix | ||||||||
| // | ||||||||
| // A B | ||||||||
| // C D | ||||||||
| // | ||||||||
| // by swapping the off-diagonal elements/sub-matrices B and C recursively. | ||||||||
| // | ||||||||
| // This implementation requires power-of-two matrices. In order to avoid | ||||||||
| // the use of local or shared memory, all index computation has to occur | ||||||||
| // at compile-time, since registers cannot be indexed dynamically. | ||||||||
| // Furthermore, using recursive templates reduces the mental load on the | ||||||||
| // optimizer, since lowering for-loops into registers oftentimes requires | ||||||||
| // finagling them with #pragma unroll, which leads to brittle code. | ||||||||
| // | ||||||||
| // To illustrate this algorithm, let's pretend we have warpSize = 8, | ||||||||
| // where t0, ..., t7 denote the 8 threads, and thread i has an array of | ||||||||
| // size 8 with data = [Ai, Bi, ..., Hi] (the columns in the schematics). | ||||||||
| // | ||||||||
| // In the first round, we exchange the largest 4x4 off-diagonal | ||||||||
| // submatrix. Boxes illustrate the submatrices to be exchanged. | ||||||||
| // | ||||||||
| // ROUND 1 | ||||||||
| // ======= | ||||||||
| // t0 t1 t2 t3 t4 t5 t6 t7 | ||||||||
| // ┌──────────────┐ | ||||||||
| // A0 A1 A2 A3 │A4 A5 A6 A7│ NUM_ENTRIES == 4 tells us how many | ||||||||
| // │ │ entries we have in a submatrix, | ||||||||
| // │ │ in this case 4 and the size of | ||||||||
| // B0 B1 B2 B3 │B4 B5 B6 B7│ the jumps between submatrices. | ||||||||
| // │ │ | ||||||||
| // │ │ 1. t[0,1,2,3] data[4] swap with t[4,5,6,7]'s data[0] | ||||||||
| // C0 C1 C2 C3 │C4 C5 C6 C7│ 2. t[0,1,2,3] data[5] swap with t[4,5,6,7]'s data[1] | ||||||||
| // │ │ 3. t[0,1,2,3] data[6] swap with t[4,5,6,7]'s data[2] | ||||||||
| // │ │ 4. t[0,1,2,3] data[7] swap with t[4,5,6,7]'s data[3] | ||||||||
| // D0 D1 D2 D3 │D4 D5 D6 D7│ | ||||||||
| // └──────────────┘ | ||||||||
| // ┌──────────────┐ | ||||||||
| // │E0 E1 E2 E3│ E4 E5 E6 E7 | ||||||||
| // │ │ | ||||||||
| // │ │ | ||||||||
| // │F0 F1 F2 F3│ F4 F5 F6 F7 | ||||||||
| // │ │ | ||||||||
| // │ │ | ||||||||
| // │G0 G1 G2 G3│ G4 G5 G6 G7 | ||||||||
| // │ │ | ||||||||
| // │ │ | ||||||||
| // │H0 H1 H2 H3│ H4 H5 H6 H7 | ||||||||
| // └──────────────┘ | ||||||||
| // | ||||||||
| // ROUND 2 | ||||||||
| // ======= | ||||||||
| // t0 t1 t2 t3 t4 t5 t6 t7 | ||||||||
| // ┌──────┐ ┌──────┐ | ||||||||
| // A0 A1 │A2 A3│ E0 E1 │E2 E3│ NUM_ENTRIES == 2 so we have 2 | ||||||||
| // │ │ │ │ submatrices per thread and there | ||||||||
| // │ │ │ │ are 2 elements between these | ||||||||
| // B0 B1 │B2 B3│ F0 F1 │F2 F3│ submatrices. | ||||||||
| // └──────┘ └──────┘ | ||||||||
| // ┌──────┐ ┌──────┐ 1. t[0,1,4,5] data[2] swap with t[2,3,6,7]'s data[0] | ||||||||
| // │C0 C1│ C2 C3 │G0 G1│ G2 G3 2. t[0,1,4,5] data[3] swap with t[2,3,6,7]'s data[1] | ||||||||
| // │ │ │ │ 3. t[0,1,4,5] data[6] swap with t[2,3,6,7]'s data[4] | ||||||||
| // │ │ │ │ 4. t[0,1,4,5] data[7] swap with t[2,3,6,7]'s data[5] | ||||||||
| // │D0 D1│ D2 D3 │H0 H1│ H2 H3 | ||||||||
| // └──────┘ └──────┘ | ||||||||
| // ┌──────┐ ┌──────┐ | ||||||||
| // A4 A5 │A6 A7│ E4 E5 │E6 E7│ | ||||||||
| // │ │ │ │ | ||||||||
| // │ │ │ │ | ||||||||
| // B4 B5 │B6 B7│ F4 F5 │F6 F7│ | ||||||||
| // └──────┘ └──────┘ | ||||||||
| // ┌──────┐ ┌──────┐ | ||||||||
| // │C4 C5│ C6 C7 │G4 G5│ G6 G7 | ||||||||
| // │ │ │ │ | ||||||||
| // │ │ │ │ | ||||||||
| // │D4 D5│ D6 D7 │H4 H5│ H6 H7 | ||||||||
| // └──────┘ └──────┘ | ||||||||
| // | ||||||||
| // ROUND 3 | ||||||||
| // ======= | ||||||||
| // t0 t1 t2 t3 t4 t5 t6 t7 | ||||||||
| // ┌──┐ ┌──┐ ┌──┐ ┌──┐ | ||||||||
| // A0 │A1│ C0 │C1│ E0 │E1│ G0 │G1│ NUM_ENTRIES == 1 so we have 4 | ||||||||
| // └──┘ └──┘ └──┘ └──┘ submatrices per thread and there | ||||||||
| // ┌──┐ ┌──┐ ┌──┐ ┌──┐ is 1 element between these | ||||||||
| // │B0│ B1 │D0│ D1 │F0│ F1 │H0│ H1 submatrices. | ||||||||
| // └──┘ └──┘ └──┘ └──┘ | ||||||||
| // ┌──┐ ┌──┐ ┌──┐ ┌──┐ 1. t[0,2,4,6] data[1] swap with t[1,3,5,7]'s data[0] | ||||||||
| // A2 │A3│ C2 │C3│ E2 │E3│ G2 │G3│ 2. t[0,2,4,6] data[3] swap with t[1,3,5,7]'s data[2] | ||||||||
| // └──┘ └──┘ └──┘ └──┘ 3. t[0,2,4,6] data[5] swap with t[1,3,5,7]'s data[4] | ||||||||
| // ┌──┐ ┌──┐ ┌──┐ ┌──┐ 4. t[0,2,4,6] data[7] swap with t[1,3,5,7]'s data[6] | ||||||||
| // │B2│ B3 │D2│ D3 │F2│ F3 │H2│ H3 | ||||||||
| // └──┘ └──┘ └──┘ └──┘ | ||||||||
| // ┌──┐ ┌──┐ ┌──┐ ┌──┐ | ||||||||
| // A4 │A5│ C4 │C5│ E4 │E5│ G4 │G5│ | ||||||||
| // └──┘ └──┘ └──┘ └──┘ | ||||||||
| // ┌──┐ ┌──┐ ┌──┐ ┌──┐ | ||||||||
| // │B4│ B5 │D4│ D5 │F4│ F5 │H4│ H5 | ||||||||
| // └──┘ └──┘ └──┘ └──┘ | ||||||||
| // ┌──┐ ┌──┐ ┌──┐ ┌──┐ | ||||||||
| // A6 │A7│ C6 │C7│ E6 │E7│ G6 │G7│ | ||||||||
| // └──┘ └──┘ └──┘ └──┘ | ||||||||
| // ┌──┐ ┌──┐ ┌──┐ ┌──┐ | ||||||||
| // │B6│ B7 │D6│ D7 │F6│ F7 │H6│ H7 | ||||||||
| // └──┘ └──┘ └──┘ └──┘ | ||||||||
| // | ||||||||
| // RESULT | ||||||||
| // ====== | ||||||||
| // t0 t1 t2 t3 t4 t5 t6 t7 | ||||||||
| // | ||||||||
| // A0 B0 C0 D0 E0 F0 G0 H0 | ||||||||
| // | ||||||||
| // | ||||||||
| // A1 B1 C1 D1 E1 F1 G1 H1 | ||||||||
| // | ||||||||
| // | ||||||||
| // A2 B2 C2 D2 E2 F2 G2 H2 | ||||||||
| // | ||||||||
| // | ||||||||
| // A3 B3 C3 D3 E3 F3 G3 H3 | ||||||||
| // | ||||||||
| // | ||||||||
| // A4 B4 C4 D4 E4 F4 G4 H4 | ||||||||
| // | ||||||||
| // | ||||||||
| // A5 B5 C5 D5 E5 F5 G5 H5 | ||||||||
| // | ||||||||
| // | ||||||||
| // A6 B6 C6 D6 E6 F6 G6 H6 | ||||||||
| // | ||||||||
| // | ||||||||
| // A7 B7 C7 D7 E7 F7 G7 H7 | ||||||||
| // | ||||||||
|
|
||||||||
| // NOTE: Do *NOT* try to refactor this code to use a reference, since nvcc | ||||||||
| // tends to choke on it and then drop everything into local memory. | ||||||||
| const InputT send_val = (xor_bit_set ? CompileTimeArray<OutputT, IDX, SIZE>::val | ||||||||
| : CompileTimeArray<OutputT, IDX + NUM_ENTRIES, SIZE>::val); | ||||||||
| const InputT recv_val = __shfl_xor_sync(mask, send_val, NUM_ENTRIES, LOGICAL_WARP_THREADS); | ||||||||
| (xor_bit_set ? CompileTimeArray<OutputT, IDX, SIZE>::val | ||||||||
| : CompileTimeArray<OutputT, IDX + NUM_ENTRIES, SIZE>::val) = recv_val; | ||||||||
| InputT& v = xor_bit_set ? vals[IDX] : vals[IDX + NUM_ENTRIES]; | ||||||||
| v = __shfl_xor_sync(mask, v, NUM_ENTRIES, LOGICAL_WARP_THREADS); | ||||||||
|
|
||||||||
| constexpr int next_idx = IDX + 1 + ((IDX + 1) % NUM_ENTRIES == 0) * NUM_ENTRIES; | ||||||||
| CompileTimeArray<OutputT, next_idx, SIZE>::template Foreach<NUM_ENTRIES>(xor_bit_set, mask); | ||||||||
| } | ||||||||
|
|
||||||||
| template <int NUM_ENTRIES> | ||||||||
| _CCCL_DEVICE void TransposeImpl(const unsigned int lane_id, const unsigned int mask, constant_t<NUM_ENTRIES>) | ||||||||
| constexpr int next_idx = IDX + 1 + ((IDX + 1) % NUM_ENTRIES == 0) * NUM_ENTRIES; | ||||||||
| if constexpr (next_idx < NUM_ENTRIES) | ||||||||
| { | ||||||||
| if constexpr (NUM_ENTRIES != 0) | ||||||||
| { | ||||||||
| const bool xor_bit_set = lane_id & NUM_ENTRIES; | ||||||||
| Foreach<NUM_ENTRIES>(xor_bit_set, mask); | ||||||||
|
|
||||||||
| TransposeImpl(lane_id, mask, constant_v<NUM_ENTRIES / 2>); | ||||||||
| } | ||||||||
| transpose_foreach<NUM_ENTRIES, next_idx>(vals, xor_bit_set, mask); | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| public: | ||||||||
| _CCCL_DEVICE | ||||||||
| CompileTimeArray(const InputT (&input_items)[ITEMS_PER_THREAD], OutputT (&output_items)[ITEMS_PER_THREAD]) | ||||||||
| : CompileTimeArray<OutputT, IDX + 1, SIZE>{input_items, output_items} | ||||||||
| , val{input_items[IDX]} | ||||||||
| {} | ||||||||
|
|
||||||||
| _CCCL_DEVICE ~CompileTimeArray() | ||||||||
| template <int NUM_ENTRIES> | ||||||||
| _CCCL_DEVICE _CCCL_FORCEINLINE void | ||||||||
| transpose(InputT (&vals)[ITEMS_PER_THREAD], const unsigned int lane_id, const unsigned int mask) | ||||||||
| { | ||||||||
| if constexpr (NUM_ENTRIES != 0) | ||||||||
| { | ||||||||
| this->output_items[IDX] = val; | ||||||||
| } | ||||||||
| const bool xor_bit_set = lane_id & NUM_ENTRIES; | ||||||||
| transpose_foreach<NUM_ENTRIES, 0>(vals, xor_bit_set, mask); | ||||||||
|
|
||||||||
| _CCCL_DEVICE void Transpose(const unsigned int lane_id, const unsigned int mask) | ||||||||
| { | ||||||||
| TransposeImpl(lane_id, mask, constant_v<ITEMS_PER_THREAD / 2>); | ||||||||
| transpose<NUM_ENTRIES / 2>(vals, lane_id, mask); | ||||||||
| } | ||||||||
| }; | ||||||||
|
|
||||||||
| // terminating partial specialization | ||||||||
| template <typename OutputT, int SIZE> | ||||||||
| class CompileTimeArray<OutputT, SIZE, SIZE> | ||||||||
| { | ||||||||
| protected: | ||||||||
| // used for dumping back the individual values after transposing | ||||||||
| InputT (&output_items)[ITEMS_PER_THREAD]; | ||||||||
|
|
||||||||
| template <int> | ||||||||
| _CCCL_DEVICE void Foreach(bool, unsigned) | ||||||||
| {} | ||||||||
|
|
||||||||
| public: | ||||||||
| _CCCL_DEVICE CompileTimeArray(const InputT (&)[ITEMS_PER_THREAD], OutputT (&output_items)[ITEMS_PER_THREAD]) | ||||||||
| : output_items{output_items} | ||||||||
| {} | ||||||||
| }; | ||||||||
| } | ||||||||
|
|
||||||||
| const unsigned int lane_id; | ||||||||
| const unsigned int warp_id; | ||||||||
|
|
@@ -260,8 +217,21 @@ public: | |||||||
| _CCCL_DEVICE _CCCL_FORCEINLINE void | ||||||||
| BlockedToStriped(const InputT (&input_items)[ITEMS_PER_THREAD], OutputT (&output_items)[ITEMS_PER_THREAD]) | ||||||||
| { | ||||||||
| CompileTimeArray<OutputT, 0, ITEMS_PER_THREAD> arr{input_items, output_items}; | ||||||||
| arr.Transpose(lane_id, member_mask); | ||||||||
| InputT vals[ITEMS_PER_THREAD]; | ||||||||
|
|
||||||||
| _CCCL_PRAGMA_UNROLL_FULL() | ||||||||
| for (int i = 0; i < ITEMS_PER_THREAD; i++) | ||||||||
| { | ||||||||
| vals[i] = input_items[i]; | ||||||||
| } | ||||||||
|
|
||||||||
| transpose<ITEMS_PER_THREAD / 2>(lane_id, member_mask); | ||||||||
|
|
||||||||
| _CCCL_PRAGMA_UNROLL_FULL() | ||||||||
| for (int i = 0; i < ITEMS_PER_THREAD; i++) | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| { | ||||||||
| output_items[i] = vals[i]; | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| template <typename OutputT> | ||||||||
|
|
||||||||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are commonly adding