Skip to content

Commit 357813d

Browse files
committed
[spirv] Allows spec constants as attribute arguments (for selected attributes).
1 parent 8a8b29f commit 357813d

File tree

18 files changed

+459
-211
lines changed

18 files changed

+459
-211
lines changed

tools/clang/include/clang/AST/Expr.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,9 @@ class Expr : public Stmt {
531531
bool isConstantInitializer(ASTContext &Ctx, bool ForRef,
532532
const Expr **Culprit = nullptr) const;
533533

534+
bool isVulkanSpecConstantExpr(const ASTContext &Ctx,
535+
APValue *Result = nullptr) const;
536+
534537
/// EvalStatus is a struct with detailed info about an evaluation in progress.
535538
struct EvalStatus {
536539
/// HasSideEffects - Whether the evaluated expression has side effects.

tools/clang/include/clang/Basic/Attr.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,7 @@ def HLSLMaxTessFactor: InheritableAttr {
668668
}
669669
def HLSLNumThreads: InheritableAttr {
670670
let Spellings = [CXX11<"", "numthreads", 2015>];
671-
let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">];
671+
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
672672
let Documentation = [Undocumented];
673673
}
674674
def HLSLRootSignature: InheritableAttr {
@@ -1004,7 +1004,7 @@ def HLSLNodeIsProgramEntry : InheritableAttr {
10041004

10051005
def HLSLNodeId : InheritableAttr {
10061006
let Spellings = [CXX11<"", "nodeid", 2017>];
1007-
let Args = [StringArgument<"Name">,DefaultIntArgument<"ArrayIndex", 0>];
1007+
let Args = [StringArgument<"Name">, ExprArgument<"ArrayIndex", 1>];
10081008
let Documentation = [Undocumented];
10091009
}
10101010

@@ -1016,25 +1016,25 @@ def HLSLNodeLocalRootArgumentsTableIndex : InheritableAttr {
10161016

10171017
def HLSLNodeShareInputOf : InheritableAttr {
10181018
let Spellings = [CXX11<"", "nodeshareinputof", 2017>];
1019-
let Args = [StringArgument<"Name">,UnsignedArgument<"ArrayIndex", 1>];
1019+
let Args = [StringArgument<"Name">,ExprArgument<"ArrayIndex", 1>];
10201020
let Documentation = [Undocumented];
10211021
}
10221022

10231023
def HLSLNodeDispatchGrid: InheritableAttr {
10241024
let Spellings = [CXX11<"", "nodedispatchgrid", 2015>];
1025-
let Args = [UnsignedArgument<"X">, UnsignedArgument<"Y">, UnsignedArgument<"Z">];
1025+
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
10261026
let Documentation = [Undocumented];
10271027
}
10281028

10291029
def HLSLNodeMaxDispatchGrid: InheritableAttr {
10301030
let Spellings = [CXX11<"", "nodemaxdispatchgrid", 2015>];
1031-
let Args = [UnsignedArgument<"X">, UnsignedArgument<"Y">, UnsignedArgument<"Z">];
1031+
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
10321032
let Documentation = [Undocumented];
10331033
}
10341034

10351035
def HLSLNodeMaxRecursionDepth : InheritableAttr {
10361036
let Spellings = [CXX11<"", "nodemaxrecursiondepth", 2017>];
1037-
let Args = [UnsignedArgument<"Count">];
1037+
let Args = [ExprArgument<"Count">];
10381038
let Documentation = [Undocumented];
10391039
}
10401040

@@ -1182,7 +1182,7 @@ def HLSLHitObject : InheritableAttr {
11821182

11831183
def HLSLMaxRecords : InheritableAttr {
11841184
let Spellings = [CXX11<"", "MaxRecords", 2015>];
1185-
let Args = [IntArgument<"maxCount">];
1185+
let Args = [ExprArgument<"maxCount">];
11861186
let Documentation = [Undocumented];
11871187
}
11881188

tools/clang/include/clang/SPIRV/SpirvContext.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,15 @@ class SpirvContext {
456456
instructionsWithLoweredType.end();
457457
}
458458

459+
SpirvInstruction *getSpecConstant(const VarDecl *decl) {
460+
return specConstants[decl];
461+
}
462+
463+
void registerSpecConstant(const VarDecl *decl,
464+
SpirvInstruction *specConstant) {
465+
specConstants[decl] = specConstant;
466+
}
467+
459468
void registerDispatchGridIndex(const RecordDecl *decl, unsigned index) {
460469
auto iter = dispatchGridIndices.find(decl);
461470
if (iter == dispatchGridIndices.end()) {
@@ -536,6 +545,7 @@ class SpirvContext {
536545
llvm::DenseSet<FunctionType *, FunctionTypeMapInfo> functionTypes;
537546
llvm::DenseMap<unsigned, SpirvIntrinsicType *> spirvIntrinsicTypesById;
538547
llvm::SmallVector<const SpirvIntrinsicType *, 8> spirvIntrinsicTypes;
548+
llvm::MapVector<const VarDecl *, SpirvInstruction *> specConstants;
539549
const AccelerationStructureTypeNV *accelerationStructureTypeNV;
540550
const RayQueryTypeKHR *rayQueryTypeKHR;
541551

tools/clang/include/clang/Sema/SemaHLSL.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,6 @@ unsigned CaculateInitListArraySizeForHLSL(clang::Sema *sema,
160160
const clang::InitListExpr *InitList,
161161
const clang::QualType EltTy);
162162

163-
bool ContainsLongVector(clang::QualType);
164-
165163
bool IsConversionToLessOrEqualElements(clang::Sema *self,
166164
const clang::ExprResult &sourceExpr,
167165
const clang::QualType &targetType,

tools/clang/lib/AST/ExprConstant.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9448,6 +9448,22 @@ bool Expr::isIntegerConstantExpr(llvm::APSInt &Value, const ASTContext &Ctx,
94489448
return true;
94499449
}
94509450

9451+
bool Expr::isVulkanSpecConstantExpr(const ASTContext &Ctx,
9452+
APValue *Result) const {
9453+
if (auto *D = dyn_cast<DeclRefExpr>(this)) {
9454+
if (auto *V = dyn_cast<VarDecl>(D->getDecl())) {
9455+
if (V->hasAttr<VKConstantIdAttr>()) {
9456+
if (const Expr *I = V->getAnyInitializer()) {
9457+
if (!I->isCXX11ConstantExpr(Ctx, Result))
9458+
return false;
9459+
}
9460+
return true;
9461+
}
9462+
}
9463+
}
9464+
return false;
9465+
}
9466+
94519467
bool Expr::isCXX98IntegralConstantExpr(const ASTContext &Ctx) const {
94529468
return CheckICE(this, Ctx).Kind == IK_ICE;
94539469
}

tools/clang/lib/CodeGen/CGHLSLMS.cpp

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,19 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
328328
};
329329
} // namespace
330330

331+
static uint32_t GetIntConstAttrArg(ASTContext &astContext, const Expr *expr,
332+
uint32_t defaultVal = 0) {
333+
if (expr) {
334+
llvm::APSInt apsInt;
335+
APValue apValue;
336+
if (expr->isIntegerConstantExpr(apsInt, astContext))
337+
return (uint32_t)apsInt.getSExtValue();
338+
if (expr->isVulkanSpecConstantExpr(astContext, &apValue) && apValue.isInt())
339+
return (uint32_t)apValue.getInt().getSExtValue();
340+
}
341+
return defaultVal;
342+
}
343+
331344
//------------------------------------------------------------------------------
332345
//
333346
// CGMSHLSLRuntime methods.
@@ -1422,6 +1435,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
14221435
}
14231436

14241437
DiagnosticsEngine &Diags = CGM.getDiags();
1438+
ASTContext &astContext = CGM.getTypes().getContext();
14251439

14261440
std::unique_ptr<DxilFunctionProps> funcProps =
14271441
llvm::make_unique<DxilFunctionProps>();
@@ -1632,10 +1646,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
16321646

16331647
// Populate numThreads
16341648
if (const HLSLNumThreadsAttr *Attr = FD->getAttr<HLSLNumThreadsAttr>()) {
1635-
1636-
funcProps->numThreads[0] = Attr->getX();
1637-
funcProps->numThreads[1] = Attr->getY();
1638-
funcProps->numThreads[2] = Attr->getZ();
1649+
funcProps->numThreads[0] = GetIntConstAttrArg(astContext, Attr->getX(), 1);
1650+
funcProps->numThreads[1] = GetIntConstAttrArg(astContext, Attr->getY(), 1);
1651+
funcProps->numThreads[2] = GetIntConstAttrArg(astContext, Attr->getZ(), 1);
16391652

16401653
if (isEntry && !SM->IsCS() && !SM->IsMS() && !SM->IsAS()) {
16411654
unsigned DiagID = Diags.getCustomDiagID(
@@ -1808,7 +1821,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
18081821

18091822
if (const auto *pAttr = FD->getAttr<HLSLNodeIdAttr>()) {
18101823
funcProps->NodeShaderID.Name = pAttr->getName().str();
1811-
funcProps->NodeShaderID.Index = pAttr->getArrayIndex();
1824+
funcProps->NodeShaderID.Index =
1825+
GetIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0);
18121826
} else {
18131827
funcProps->NodeShaderID.Name = FD->getName().str();
18141828
funcProps->NodeShaderID.Index = 0;
@@ -1819,20 +1833,28 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
18191833
}
18201834
if (const auto *pAttr = FD->getAttr<HLSLNodeShareInputOfAttr>()) {
18211835
funcProps->NodeShaderSharedInput.Name = pAttr->getName().str();
1822-
funcProps->NodeShaderSharedInput.Index = pAttr->getArrayIndex();
1836+
funcProps->NodeShaderSharedInput.Index =
1837+
GetIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0);
18231838
}
18241839
if (const auto *pAttr = FD->getAttr<HLSLNodeDispatchGridAttr>()) {
1825-
funcProps->Node.DispatchGrid[0] = pAttr->getX();
1826-
funcProps->Node.DispatchGrid[1] = pAttr->getY();
1827-
funcProps->Node.DispatchGrid[2] = pAttr->getZ();
1840+
funcProps->Node.DispatchGrid[0] =
1841+
GetIntConstAttrArg(astContext, pAttr->getX(), 1);
1842+
funcProps->Node.DispatchGrid[1] =
1843+
GetIntConstAttrArg(astContext, pAttr->getY(), 1);
1844+
funcProps->Node.DispatchGrid[2] =
1845+
GetIntConstAttrArg(astContext, pAttr->getZ(), 1);
18281846
}
18291847
if (const auto *pAttr = FD->getAttr<HLSLNodeMaxDispatchGridAttr>()) {
1830-
funcProps->Node.MaxDispatchGrid[0] = pAttr->getX();
1831-
funcProps->Node.MaxDispatchGrid[1] = pAttr->getY();
1832-
funcProps->Node.MaxDispatchGrid[2] = pAttr->getZ();
1848+
funcProps->Node.MaxDispatchGrid[0] =
1849+
GetIntConstAttrArg(astContext, pAttr->getX(), 1);
1850+
funcProps->Node.MaxDispatchGrid[1] =
1851+
GetIntConstAttrArg(astContext, pAttr->getY(), 1);
1852+
funcProps->Node.MaxDispatchGrid[2] =
1853+
GetIntConstAttrArg(astContext, pAttr->getZ(), 1);
18331854
}
18341855
if (const auto *pAttr = FD->getAttr<HLSLNodeMaxRecursionDepthAttr>()) {
1835-
funcProps->Node.MaxRecursionDepth = pAttr->getCount();
1856+
funcProps->Node.MaxRecursionDepth =
1857+
GetIntConstAttrArg(astContext, pAttr->getCount(), 0);
18361858
}
18371859
if (!FD->getAttr<HLSLNumThreadsAttr>()) {
18381860
// NumThreads wasn't specified.
@@ -2346,8 +2368,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
23462368
NodeInputRecordParams[ArgIt].MetadataIdx = NodeInputParamIdx++;
23472369

23482370
if (parmDecl->hasAttr<HLSLMaxRecordsAttr>()) {
2349-
node.MaxRecords =
2350-
parmDecl->getAttr<HLSLMaxRecordsAttr>()->getMaxCount();
2371+
node.MaxRecords = GetIntConstAttrArg(
2372+
astContext,
2373+
parmDecl->getAttr<HLSLMaxRecordsAttr>()->getMaxCount(), 1);
23512374
}
23522375
if (parmDecl->hasAttr<HLSLGloballyCoherentAttr>())
23532376
node.Flags.SetGloballyCoherent();
@@ -2378,7 +2401,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
23782401
// OutputID from attribute
23792402
if (const auto *Attr = parmDecl->getAttr<HLSLNodeIdAttr>()) {
23802403
node.OutputID.Name = Attr->getName().str();
2381-
node.OutputID.Index = Attr->getArrayIndex();
2404+
node.OutputID.Index =
2405+
GetIntConstAttrArg(astContext, Attr->getArrayIndex(), 0);
23822406
} else {
23832407
node.OutputID.Name = parmDecl->getName().str();
23842408
node.OutputID.Index = 0;
@@ -2437,7 +2461,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
24372461
node.MaxRecordsSharedWith = ix;
24382462
}
24392463
if (const auto *Attr = parmDecl->getAttr<HLSLMaxRecordsAttr>())
2440-
node.MaxRecords = Attr->getMaxCount();
2464+
node.MaxRecords = GetIntConstAttrArg(astContext, Attr->getMaxCount(), 0);
24412465
}
24422466

24432467
if (inputPatchCount > 1) {

tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1815,6 +1815,7 @@ DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) {
18151815
void DeclResultIdMapper::registerSpecConstant(const VarDecl *decl,
18161816
SpirvInstruction *specConstant) {
18171817
specConstant->setRValue();
1818+
spvContext.registerSpecConstant(decl, specConstant);
18181819
registerVariableForDecl(decl, createDeclSpirvInfo(specConstant));
18191820
}
18201821

tools/clang/lib/SPIRV/EmitVisitor.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2523,6 +2523,24 @@ isFieldMergeWithPrevious(const StructType::FieldInfo &previous,
25232523
return previous.fieldIndex == field.fieldIndex;
25242524
}
25252525

2526+
uint32_t EmitTypeHandler::getAttrArgInstr(ASTContext &astContext,
2527+
const Expr *expr,
2528+
uint32_t defaultVal) {
2529+
if (expr) {
2530+
llvm::APSInt apsInt;
2531+
APValue apValue;
2532+
if (expr->isIntegerConstantExpr(apsInt, astContext))
2533+
return getOrCreateConstantInt(apsInt, context.getUIntType(32), false);
2534+
if (expr->isVulkanSpecConstantExpr(astContext, &apValue) &&
2535+
apValue.isInt()) {
2536+
auto *declRefExpr = dyn_cast<DeclRefExpr>(expr);
2537+
auto *decl = dyn_cast<const VarDecl>(declRefExpr->getDecl());
2538+
return getOrAssignResultId(context.getSpecConstant(decl));
2539+
}
2540+
}
2541+
return defaultVal;
2542+
}
2543+
25262544
uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
25272545
// First get the decorations that would apply to this type.
25282546
bool alreadyExists = false;
@@ -2655,27 +2673,24 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
26552673
if (hlsl::IsHLSLNodeOutputType(nodeDecl->getType())) {
26562674
StringRef name = nodeDecl->getName();
26572675
unsigned index = 0;
2658-
if (auto nodeID = nodeDecl->getAttr<HLSLNodeIdAttr>()) {
2676+
if (auto *nodeID = nodeDecl->getAttr<HLSLNodeIdAttr>()) {
26592677
name = nodeID->getName();
2660-
index = nodeID->getArrayIndex();
2678+
index = getAttrArgInstr(astContext, nodeID->getArrayIndex());
26612679
}
26622680

26632681
auto *str = new (context) SpirvConstantString(name);
26642682
uint32_t nodeName = getOrCreateConstantString(str);
26652683
emitDecoration(id, spv::Decoration::PayloadNodeNameAMDX, {nodeName},
26662684
llvm::None, true);
26672685
if (index) {
2668-
uint32_t baseIndex = getOrCreateConstantInt(
2669-
llvm::APInt(32, index), context.getUIntType(32), false);
2670-
emitDecoration(id, spv::Decoration::PayloadNodeBaseIndexAMDX,
2671-
{baseIndex}, llvm::None, true);
2686+
emitDecoration(id, spv::Decoration::PayloadNodeBaseIndexAMDX, {index},
2687+
llvm::None, true);
26722688
}
26732689
}
26742690

26752691
uint32_t maxRecords;
26762692
if (const auto *attr = nodeDecl->getAttr<HLSLMaxRecordsAttr>()) {
2677-
maxRecords = getOrCreateConstantInt(llvm::APInt(32, attr->getMaxCount()),
2678-
context.getUIntType(32), false);
2693+
maxRecords = getAttrArgInstr(astContext, attr->getMaxCount(), 1);
26792694
} else {
26802695
maxRecords = getOrCreateConstantInt(llvm::APInt(32, 1),
26812696
context.getUIntType(32), false);

tools/clang/lib/SPIRV/EmitVisitor.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ class EmitTypeHandler {
6767
EmitTypeHandler(const EmitTypeHandler &) = delete;
6868
EmitTypeHandler &operator=(const EmitTypeHandler &) = delete;
6969

70+
uint32_t getAttrArgInstr(ASTContext &astContext, const Expr *expr,
71+
uint32_t defaultVal = 0);
72+
7073
// Emits the instruction for the given type into the typeConstantBinary and
7174
// returns the result-id for the type. If the type has already been emitted,
7275
// it only returns its result-id.

0 commit comments

Comments
 (0)