Skip to content

Commit

Permalink
Enabled 3D_RTRTRT for cases where lengths are not aligned to 64
Browse files Browse the repository at this point in the history
  • Loading branch information
feizheng10 authored Aug 12, 2020
1 parent 8503e7f commit 0f7e9ba
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 0 deletions.
9 changes: 9 additions & 0 deletions library/src/include/tree_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ enum ComputeScheme
CS_KERNEL_2D_SINGLE,

CS_3D_STRAIGHT,
CS_3D_RTRTRT,
CS_3D_RTRT,
CS_3D_RC,
CS_KERNEL_3D_STOCKHAM_BLOCK_CC,
Expand Down Expand Up @@ -259,7 +260,10 @@ class TreeNode
void build_CS_2D_RC();

// 3D node builder:
// 3D 4 node builder, R: 2D FFTs, T: transpose XY_Z, R: row FFTs, T: transpose Z_XY
void build_CS_3D_RTRT();
// 3D 6 node builder, R: row FFTs, T: transpose XY_Z, R: row FFTs, T: transpose XY_Z, R: row FFTs, T: transpose XY_Z
void build_CS_3D_RTRTRT();

// State maintained while traversing the tree.
//
Expand Down Expand Up @@ -329,6 +333,10 @@ class TreeNode
OperatingBuffer& flipIn,
OperatingBuffer& flipOut,
OperatingBuffer& obOutBuf);
void assign_buffers_CS_3D_RTRTRT(TraverseState& state,
OperatingBuffer& flipIn,
OperatingBuffer& flipOut,
OperatingBuffer& obOutBuf);

// Set placement variable and in/out array types
void TraverseTreeAssignPlacementsLogicA(rocfft_array_type rootIn, rocfft_array_type rootOut);
Expand All @@ -347,6 +355,7 @@ class TreeNode
void assign_params_CS_2D_RTRT();
void assign_params_CS_2D_RC_STRAIGHT();
void assign_params_CS_3D_RTRT();
void assign_params_CS_3D_RTRTRT();
void assign_params_CS_3D_RC_STRAIGHT();

// Determine work memory requirements:
Expand Down
166 changes: 166 additions & 0 deletions library/src/plan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ std::string PrintScheme(ComputeScheme cs)
{ENUMSTR(CS_KERNEL_2D_SINGLE)},

{ENUMSTR(CS_3D_STRAIGHT)},
{ENUMSTR(CS_3D_RTRTRT)},
{ENUMSTR(CS_3D_RTRT)},
{ENUMSTR(CS_3D_RC)},
{ENUMSTR(CS_KERNEL_3D_STOCKHAM_BLOCK_CC)},
Expand Down Expand Up @@ -825,6 +826,24 @@ void TreeNode::RecursiveBuildTree()
else
{
scheme = CS_3D_RTRT;

// NB:
// Build the 1st child on a trial basis, without adding it to the tree, and
// switch to CS_3D_RTRTRT if that child turns out to be CS_2D_RTRT. (Any
// better idea?) This is enabled only when none of the three lengths is a
// multiple of 64, because of a performance issue with 64-aligned lengths.
// See more comments in assign_params_CS_3D_RTRTRT().
TreeNode* child0 = TreeNode::CreateNode(this);
child0->length = length;
child0->dimension = 2;
child0->RecursiveBuildTree();
if((child0->scheme == CS_2D_RTRT) && (length[0] % 64) && (length[1] % 64)
&& (length[2] % 64))
{
scheme = CS_3D_RTRTRT;
}

DeleteNode(child0);
}

switch(scheme)
Expand All @@ -834,6 +853,11 @@ void TreeNode::RecursiveBuildTree()
build_CS_3D_RTRT();
}
break;
case CS_3D_RTRTRT:
{
build_CS_3D_RTRTRT();
}
break;
case CS_3D_RC:
{
// 2d fft
Expand Down Expand Up @@ -1927,6 +1951,32 @@ void TreeNode::build_CS_3D_RTRT()
childNodes.push_back(trans2Plan);
}

// Build the six children of a CS_3D_RTRTRT node: three (row FFT, XY_Z
// transpose) pairs appended to childNodes in execution order.
void TreeNode::build_CS_3D_RTRTRT()
{
    scheme = CS_3D_RTRTRT;

    // Lengths as seen by the current stage; starts out as this node's lengths
    // and is rotated after every transpose.
    std::vector<size_t> stage_length = length;

    for(int stage = 0; stage < 3; ++stage)
    {
        // R: 1D row FFTs; the sub-tree is built recursively before the node
        // is attached.
        auto fft_node       = TreeNode::CreateNode(this);
        fft_node->dimension = 1;
        fft_node->length    = stage_length;
        fft_node->RecursiveBuildTree();
        childNodes.push_back(fft_node);

        // T: XY_Z transpose kernel operating on the same lengths.
        auto transpose_node       = TreeNode::CreateNode(this);
        transpose_node->scheme    = CS_KERNEL_TRANSPOSE_XY_Z;
        transpose_node->dimension = 2;
        transpose_node->length    = stage_length;
        childNodes.push_back(transpose_node);

        // Rotate the three lengths for the next stage:
        // [l0, l1, l2] -> [l2, l0, l1].
        const size_t last_length = stage_length[2];
        stage_length[2]          = stage_length[1];
        stage_length[1]          = stage_length[0];
        stage_length[0]          = last_length;
    }
}

struct TreeNode::TraverseState
{
TraverseState(const ExecPlan& execPlan)
Expand Down Expand Up @@ -2086,6 +2136,9 @@ void TreeNode::TraverseTreeAssignBuffersLogicA(TraverseState& state,
case CS_3D_RC:
assign_buffers_CS_RC(state, flipIn, flipOut, obOutBuf);
break;
case CS_3D_RTRTRT:
assign_buffers_CS_3D_RTRTRT(state, flipIn, flipOut, obOutBuf);
break;
default:
if(parent == nullptr)
{
Expand Down Expand Up @@ -2767,6 +2820,67 @@ void TreeNode::assign_buffers_CS_RC(TraverseState& state,
}
}

// Assign operating buffers and in/out array types to the six children
// (R, T, R, T, R, T) of a CS_3D_RTRTRT node.
//
// state    - traversal state, forwarded to SetInputBuffer() on every child and
//            to the recursive buffer assignment of the row-FFT children.
// flipIn   - taken by reference; overwritten here and handed on to the row-FFT
//            children.
// flipOut  - taken by reference; overwritten here and handed on to the row-FFT
//            children.
// obOutBuf - taken by reference; the buffer this node writes its result to.
//            It is copied into obOut on entry and reset to obOut after the
//            first child has been traversed.
//
// The statement order matters: flipIn, flipOut and obOutBuf are shared by
// reference with the recursive calls below.
void TreeNode::assign_buffers_CS_3D_RTRTRT(TraverseState& state,
OperatingBuffer& flipIn,
OperatingBuffer& flipOut,
OperatingBuffer& obOutBuf)
{
assert(scheme == CS_3D_RTRTRT);
assert(childNodes.size() == 6);

// This node's output buffer is whatever the caller asked for.
obOut = obOutBuf;

// TODO: adjust buffer assignment for padding

// Flip buffers offered to the first row FFT: this node's input and the
// temporary buffer.
flipIn = obIn;
flipOut = OB_TEMP;

// R (child 0): row FFTs writing to obOutBuf, with this node's own in/out
// array types.
childNodes[0]->SetInputBuffer(state);
childNodes[0]->obOut = obOutBuf;
childNodes[0]->inArrayType = inArrayType;
childNodes[0]->outArrayType = outArrayType;
childNodes[0]->TraverseTreeAssignBuffersLogicA(state, flipIn, flipOut, obOutBuf);

// Re-establish the flip buffers and output buffer after the child traversal.
// NOTE(review): presumably needed because the child receives these by
// reference and may have changed them - confirm against
// TraverseTreeAssignBuffersLogicA.
flipIn = OB_TEMP;
flipOut = obOut;
obOutBuf = obOut;

// T (child 1): transpose writing to the temporary buffer as interleaved
// complex; its input type is whatever child 0 produces.
childNodes[1]->SetInputBuffer(state);
childNodes[1]->obOut = OB_TEMP;
childNodes[1]->inArrayType = childNodes[0]->outArrayType;
childNodes[1]->outArrayType = rocfft_array_type_complex_interleaved;

// R (child 2): row FFTs, interleaved complex in and out, writing to the
// temporary buffer.
childNodes[2]->inArrayType = rocfft_array_type_complex_interleaved;
childNodes[2]->outArrayType = rocfft_array_type_complex_interleaved;
childNodes[2]->SetInputBuffer(state);
childNodes[2]->obOut = OB_TEMP;
flipIn = OB_TEMP;
flipOut = obOutBuf;
childNodes[2]->TraverseTreeAssignBuffersLogicA(state, flipIn, flipOut, obOutBuf);

// T (child 3): transpose writing to obOutBuf using this node's output array
// type.
childNodes[3]->SetInputBuffer(state);
childNodes[3]->obOut = obOutBuf;
childNodes[3]->inArrayType = rocfft_array_type_complex_interleaved;
childNodes[3]->outArrayType = outArrayType;

// R (child 4): row FFTs writing to the current value of flipIn, i.e. whatever
// the traversal of child 2 left there.
// NOTE(review): unlike children 0 and 2, the array types are assigned AFTER
// the recursive traversal - confirm this ordering is intentional.
childNodes[4]->SetInputBuffer(state);
childNodes[4]->obOut = flipIn;
childNodes[4]->TraverseTreeAssignBuffersLogicA(state, flipIn, flipOut, obOutBuf);
childNodes[4]->inArrayType = childNodes[3]->outArrayType;
childNodes[4]->outArrayType = rocfft_array_type_complex_interleaved;

// T (child 5): final transpose, writing to obOutBuf with this node's output
// array type.
childNodes[5]->SetInputBuffer(state);
childNodes[5]->inArrayType = rocfft_array_type_complex_interleaved;
childNodes[5]->outArrayType = outArrayType;
childNodes[5]->obOut = obOutBuf;
}

///////////////////////////////////////////////////////////////////////////////
/// Set placement variable and in/out array types, if not already set.
void TreeNode::TraverseTreeAssignPlacementsLogicA(const rocfft_array_type rootIn,
Expand Down Expand Up @@ -2933,6 +3047,9 @@ void TreeNode::TraverseTreeAssignParamsLogicA()
case CS_3D_RTRT:
assign_params_CS_3D_RTRT();
break;
case CS_3D_RTRTRT:
assign_params_CS_3D_RTRTRT();
break;
case CS_3D_RC:
case CS_3D_STRAIGHT:
assign_params_CS_3D_RC_STRAIGHT();
Expand Down Expand Up @@ -3959,6 +4076,55 @@ void TreeNode::assign_params_CS_3D_RTRT()
trans2Plan->oDist = oDist;
}

// Assign strides and distances to the six children (R, T, R, T, R, T) of a
// CS_3D_RTRTRT node. Every transpose output is written unpadded (see TODO).
void TreeNode::assign_params_CS_3D_RTRTRT()
{
    assert(scheme == CS_3D_RTRTRT);
    assert(childNodes.size() == 6);

    // TODO:
    //   Regular transpose padding is needed to improve performance when the
    //   lengths are aligned to 64 (e.g. 512x512x512). A few potential issues
    //   have to be fixed first:
    //   (1) The performance of the current transpose_kernel2_scheme for the
    //       512x512x512 case needs to be improved.
    //   (2) For in-place transforms, the user buffer is not big enough to hold
    //       padded output.
    //   (3) We should be able to pad the output of the 1st and 2nd transpose,
    //       and therefore the input of the 3rd transpose. However, the current
    //       transpose_kernel2_scheme and transpose_tile_device do not work for
    //       the 2nd padded transpose.
    //   Alternatively, the new diagonal transpose may perform well enough that
    //   padding is no longer needed.

    for(size_t stage = 0; stage < 3; ++stage)
    {
        // R: row FFTs of this stage.
        auto fft_node = childNodes[2 * stage];
        if(stage > 0)
        {
            // Later stages consume the previous transpose's output layout and
            // write back with that same layout.
            auto prev_transpose = childNodes[2 * stage - 1];
            fft_node->inStride  = prev_transpose->outStride;
            fft_node->iDist     = prev_transpose->oDist;
            fft_node->outStride = fft_node->inStride;
            fft_node->oDist     = fft_node->iDist;
        }
        else
        {
            // The first stage uses this node's own input/output layout.
            fft_node->inStride  = inStride;
            fft_node->iDist     = iDist;
            fft_node->outStride = outStride;
            fft_node->oDist     = oDist;
        }
        fft_node->TraverseTreeAssignParamsLogicA();

        // T: the transpose reads what the row FFTs just wrote ...
        auto transpose_node      = childNodes[2 * stage + 1];
        transpose_node->inStride = fft_node->outStride;
        transpose_node->iDist    = fft_node->oDist;

        // ... and appends output strides built from unit stride, then
        // length[2], then length[0].
        transpose_node->outStride.push_back(1);
        transpose_node->outStride.push_back(transpose_node->outStride[0]
                                            * transpose_node->length[2]);
        transpose_node->outStride.push_back(transpose_node->outStride[1]
                                            * transpose_node->length[0]);
        transpose_node->oDist = transpose_node->iDist;
    }
}

void TreeNode::assign_params_CS_3D_RC_STRAIGHT()
{
auto xyPlan = childNodes[0];
Expand Down

0 comments on commit 0f7e9ba

Please sign in to comment.