Skip to content

Commit

Permalink
gpu2: shaders: implement initial values for cs
Browse files Browse the repository at this point in the history
  • Loading branch information
DHrpcs3 committed Oct 1, 2024
1 parent 59946fe commit d099439
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 14 deletions.
9 changes: 9 additions & 0 deletions rpcsx-gpu2/lib/gcn-shader/include/shader/dialect/amdgpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ enum Op {
PS_INPUT_VGPR,
PS_COMP_SWAP,
VS_GET_INDEX,
CS_INPUT_SGPR,
CS_SET_INITIAL_EXEC,
CS_SET_THREAD_ID,
RESOURCE_PHI,

OpCount,
Expand Down Expand Up @@ -49,6 +52,12 @@ inline const char *getInstructionName(unsigned op) {
return "ps_comp_swap";
case VS_GET_INDEX:
return "vs_get_index";
case CS_INPUT_SGPR:
return "cs_input_sgpr";
case CS_SET_INITIAL_EXEC:
return "cs_set_initial_exec";
case CS_SET_THREAD_ID:
return "cs_set_thread_id";
case RESOURCE_PHI:
return "resource_phi";
}
Expand Down
54 changes: 54 additions & 0 deletions rpcsx-gpu2/lib/gcn-shader/shaders/rdna.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,60 @@ float32_t ps_input_vgpr(int32_t index, f32vec4 fragCoord, bool frontFace) {
return 0;
}

// Returns one component of the given 3-component id, selected by `index`:
// 0 -> x, 1 -> y, 2 -> z. Any other index yields 0.
//
// NOTE(review): the parameter is named `localInvocationId`, but the converter
// code that emits CS_INPUT_SGPR passes the loaded WorkgroupId built-in here;
// confirm which id is intended and consider renaming.
uint32_t cs_input_sgpr(int32_t index, u32vec3 localInvocationId) {
    switch (index) {
    case 0:
        return localInvocationId.x;
    case 1:
        return localInvocationId.y;
    case 2:
        return localInvocationId.z;
    default:
        break;
    }

    // Unknown input slot: report zero.
    return 0;
}

// Initializes the global `exec` mask for a compute invocation.
//
// The workgroup is split into waves of 64 consecutive invocations, ordered by
// the linear index x + y * sizeX + z * sizeX * sizeY. Every wave gets a full
// 64-bit mask except the last one, which only enables the lanes that are
// actually covered by the workgroup when its size is not a multiple of 64.
//
//   localInvocationId - id of this invocation inside its workgroup
//   workgroupSize     - workgroup dimensions (x, y, z)
void cs_set_initial_exec(u32vec3 localInvocationId, u32vec3 workgroupSize) {
    uint32_t totalWorkgroupSize = workgroupSize.x * workgroupSize.y * workgroupSize.z;

    // Exactly one full wave.
    if (totalWorkgroupSize == 64) {
        exec = ~uint64_t(0);
        return;
    }

    // A single partial wave: enable only the low `totalWorkgroupSize` lanes.
    if (totalWorkgroupSize < 64) {
        exec = (uint64_t(1) << totalWorkgroupSize) - 1;
        return;
    }

    // Round up: a trailing partial wave still counts as a wave. Rounding down
    // here would make every wave of e.g. a 100-invocation workgroup look like
    // the last one and wrongly truncate the mask of the full waves.
    uint32_t waveCount = (totalWorkgroupSize + 63) / 64;

    uint32_t totalInvocationIndex = localInvocationId.x +
        localInvocationId.y * workgroupSize.x +
        localInvocationId.z * workgroupSize.x * workgroupSize.y;

    // Round down: invocations 0..63 belong to wave 0, 64..127 to wave 1, ...
    uint32_t waveIndex = totalInvocationIndex / 64;

    // Every wave before the last one is fully populated.
    if (waveIndex + 1 < waveCount) {
        exec = ~uint64_t(0);
        return;
    }

    // Last wave: full when the size is a multiple of 64, otherwise partial.
    uint32_t lastWaveLen = totalWorkgroupSize % 64;
    exec = lastWaveLen == 0 ? ~uint64_t(0) : ((uint64_t(1) << lastWaveLen) - 1);
}

// Stores this invocation's lane number into the global `thread_id`.
//
// The lane is the linear index of the invocation inside the workgroup
// (x + y * sizeX + z * sizeX * sizeY) reduced modulo 64.
void cs_set_thread_id(u32vec3 localInvocationId, u32vec3 workgroupSize) {
    uint32_t rowStride = workgroupSize.x;
    uint32_t sliceStride = rowStride * workgroupSize.y;

    uint32_t linearIndex = localInvocationId.x;
    linearIndex += localInvocationId.y * rowStride;
    linearIndex += localInvocationId.z * sliceStride;

    thread_id = linearIndex % 64;
}

const uint32_t kPrimTypeQuadList = 0x13;
const uint32_t kPrimTypeQuadStrip = 0x14;

Expand Down
105 changes: 91 additions & 14 deletions rpcsx-gpu2/lib/gcn-shader/src/GcnConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1231,7 +1231,7 @@ static void instructionsToSpv(GcnConverter &converter, gcn::Import &importer,
auto memorySSA = buildMemorySSA(cfg, &moduleInfo);
spv::Import resourceImporter;

memorySSA.print(std::cerr, body, context.ns);
// memorySSA.print(std::cerr, body, context.ns);

ResourcesBuilder resourcesBuilder;
std::map<ir::Value, std::int32_t> resourceConfigSlots;
Expand Down Expand Up @@ -1324,8 +1324,8 @@ static void instructionsToSpv(GcnConverter &converter, gcn::Import &importer,
}
}

static void createEntryPoint(gcn::Context &context, gcn::Stage stage,
ir::Region &&body) {
static void createEntryPoint(gcn::Context &context, const gcn::Environment &env,
gcn::Stage stage, ir::Region &&body) {
auto executionModel = ir::spv::ExecutionModel::GLCompute;

switch (stage) {
Expand Down Expand Up @@ -1408,6 +1408,17 @@ static void createEntryPoint(gcn::Context &context, gcn::Stage stage,
mainFn.getLocation(), mainFn,
ir::spv::ExecutionMode::OriginUpperLeft());
}

if (executionModel == ir::spv::ExecutionModel::GLCompute) {
auto executionModes = gcn::Builder::createAppend(
context, context.layout.getOrCreateExecutionModes(context));

executionModes.createSpvExecutionMode(
mainFn.getLocation(), mainFn,
ir::spv::ExecutionMode::LocalSize(env.numThreadX, env.numThreadY,
env.numThreadZ));
}

entryPoints.createSpvEntryPoint(mainFn.getLocation(), executionModel, mainFn,
"main", interfaceList);
}
Expand Down Expand Up @@ -1457,13 +1468,9 @@ static void createInitialValues(GcnConverter &converter,
} else if (stage == gcn::Stage::Ps) {
auto boolT = context.getTypeBool();
auto f32T = context.getTypeFloat32();
auto s32T = context.getTypeSInt32();
auto f32x3 = context.getTypeVector(f32T, 3);
auto f32x4 = context.getTypeVector(f32T, 4);

auto boolPT = context.getTypePointer(ir::spv::StorageClass::Input, boolT);
auto s32PT = context.getTypePointer(ir::spv::StorageClass::Input, s32T);
auto f32x3PT = context.getTypePointer(ir::spv::StorageClass::Input, f32x3);
auto f32x4PT = context.getTypePointer(ir::spv::StorageClass::Input, f32x4);

auto globals = gcn::Builder::createAppend(
Expand Down Expand Up @@ -1517,16 +1524,86 @@ static void createInitialValues(GcnConverter &converter,
builder.createSpvBitcast(loc, context.getTypeSInt32(), runtimeIndex));

auto vgprValue = builder.createValue(loc, ir::amdgpu::PS_INPUT_VGPR,
std::span<const ir::Operand>{{
context.getTypeFloat32(),
indexLocal,
fragCoord,
frontFace,
}});
context.getTypeFloat32(), indexLocal,
fragCoord, frontFace);
context.writeReg(loc, builder, gcn::RegId::Vgpr, i, vgprValue);
}
}

if (stage == gcn::Stage::Cs) {
auto uintT = context.getTypeUInt32();
auto uvec3T = context.getTypeVector(uintT, 3);
auto pInputUVec3T =
context.getTypePointer(ir::spv::StorageClass::Input, uvec3T);

auto globals = gcn::Builder::createAppend(
context, context.layout.getOrCreateGlobals(context));
auto annotations = gcn::Builder::createAppend(
context, context.layout.getOrCreateAnnotations(context));

auto workGroupIdVar = globals.createSpvVariable(
loc, pInputUVec3T, ir::spv::StorageClass::Input);
annotations.createSpvDecorate(
loc, workGroupIdVar,
ir::spv::Decoration::BuiltIn(ir::spv::BuiltIn::WorkgroupId));

auto localInvocationIdVar = globals.createSpvVariable(
loc, pInputUVec3T, ir::spv::StorageClass::Input);
annotations.createSpvDecorate(
loc, localInvocationIdVar,
ir::spv::Decoration::BuiltIn(ir::spv::BuiltIn::LocalInvocationId));

auto workGroupId = builder.createSpvLoad(loc, uvec3T, workGroupIdVar);
auto workGroupIdLocalVar =
converter.createLocalVariable(builder, loc, workGroupId);
auto localInvocationId =
builder.createSpvLoad(loc, uvec3T, localInvocationIdVar);
auto localInvocationIdLocVar =
converter.createLocalVariable(builder, loc, localInvocationId);

{
auto indexLocal =
converter.createLocalVariable(builder, loc, context.simm32(0));
int end = env.sgprCount;
end = std::min<int>(end, env.userSgprs.size() +
static_cast<int>(gcn::CsSGprInput::Count));

for (int i = env.userSgprs.size(); i < end; ++i) {
std::uint32_t slot =
info.create(gcn::ConfigType::CsInputSGpr, i - env.userSgprs.size());
auto runtimeIndex = converter.createReadConfig(stage, builder, slot);
builder.createSpvStore(loc, indexLocal,
builder.createSpvBitcast(
loc, context.getTypeSInt32(), runtimeIndex));

auto sgprValue = builder.createValue(loc, ir::amdgpu::CS_INPUT_SGPR,
context.getTypeUInt32(),
indexLocal, workGroupIdLocalVar);
context.writeReg(loc, builder, gcn::RegId::Sgpr, i, sgprValue);
}
}

for (std::int32_t i = 0; i < 3; ++i) {
auto value = builder.createSpvCompositeExtract(loc, uintT,
localInvocationId, {{i}});
context.writeReg(loc, builder, gcn::RegId::Vgpr, i, value);
}

auto workgroupSize = builder.createSpvCompositeConstruct(
loc, uvec3T,
{{context.imm32(env.numThreadX), context.imm32(env.numThreadY),
context.imm32(env.numThreadZ)}});
auto workgroupSizeLocVar =
converter.createLocalVariable(builder, loc, workgroupSize);

builder.createValue(loc, ir::amdgpu::CS_SET_INITIAL_EXEC,
context.getTypeVoid(), localInvocationIdLocVar,
workgroupSizeLocVar);
builder.createValue(loc, ir::amdgpu::CS_SET_THREAD_ID,
context.getTypeVoid(), localInvocationIdLocVar,
workgroupSizeLocVar);
}

context.writeReg(loc, builder, gcn::RegId::Vcc, 0, context.imm64(0));

for (int word = 0; word < 2; ++word) {
Expand Down Expand Up @@ -1561,7 +1638,7 @@ gcn::convertToSpv(Context &context, ir::Region body,
context.imm32(0));
}

createEntryPoint(context, stage, std::move(body));
createEntryPoint(context, env, stage, std::move(body));

for (int userSgpr = std::countr_zero(context.requiredUserSgprs);
userSgpr < 32;
Expand Down

0 comments on commit d099439

Please sign in to comment.