Skip to content

Fix intrinsic lookup with namespaces #7599

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions tools/clang/include/clang/Sema/ExternalSemaSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,9 @@ class ExternalSemaSource : public ExternalASTSource {
// add call candidates to the given expression. It returns 'true'
// if standard overload search should be suppressed; false otherwise.
virtual bool AddOverloadedCallCandidates(UnresolvedLookupExpr *ULE,
ArrayRef<Expr *> Args,
OverloadCandidateSet &CandidateSet,
bool PartialOverloading)
{
ArrayRef<Expr *> Args,
OverloadCandidateSet &CandidateSet,
Scope *S, bool PartialOverloading) {
return false;
}

Expand Down
5 changes: 5 additions & 0 deletions tools/clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -2495,9 +2495,14 @@ class Sema {
DeclAccessPair FoundDecl,
FunctionDecl *Fn);

// HLSL Change Begin
void CollectNamespaceContexts(Scope *,
SmallVectorImpl<const DeclContext *> &);
// HLSL Change End
void AddOverloadedCallCandidates(UnresolvedLookupExpr *ULE,
ArrayRef<Expr *> Args,
OverloadCandidateSet &CandidateSet,
Scope *S, // HLSL Change
bool PartialOverloading = false);

// An enum used to represent the different possible results of building a
Expand Down
2 changes: 1 addition & 1 deletion tools/clang/lib/Sema/SemaCodeComplete.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4020,7 +4020,7 @@ void Sema::CodeCompleteCall(Scope *S, Expr *Fn, ArrayRef<Expr *> Args) {

Expr *NakedFn = Fn->IgnoreParenCasts();
if (auto ULE = dyn_cast<UnresolvedLookupExpr>(NakedFn))
AddOverloadedCallCandidates(ULE, Args, CandidateSet,
AddOverloadedCallCandidates(ULE, Args, CandidateSet, S, // HLSL Change
/*PartialOverloading=*/true);
else if (auto UME = dyn_cast<UnresolvedMemberExpr>(NakedFn)) {
TemplateArgumentListInfo TemplateArgsBuffer, *TemplateArgs = nullptr;
Expand Down
185 changes: 114 additions & 71 deletions tools/clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4152,6 +4152,7 @@ class HLSLExternalSource : public ExternalSemaSource {
SourceLocation(), &context.Idents.get("dx"),
/*PrevDecl*/ nullptr);
m_dxNSDecl->setImplicit();
m_dxNSDecl->setHasExternalLexicalStorage(true);
context.getTranslationUnitDecl()->addDecl(m_dxNSDecl);

#ifdef ENABLE_SPIRV_CODEGEN
Expand Down Expand Up @@ -5169,7 +5170,7 @@ class HLSLExternalSource : public ExternalSemaSource {

bool AddOverloadedCallCandidates(UnresolvedLookupExpr *ULE,
ArrayRef<Expr *> Args,
OverloadCandidateSet &CandidateSet,
OverloadCandidateSet &CandidateSet, Scope *S,
bool PartialOverloading) override {
DXASSERT_NOMSG(ULE != nullptr);

Expand All @@ -5194,6 +5195,8 @@ class HLSLExternalSource : public ExternalSemaSource {
// Exceptions:
// - Vulkan-specific intrinsics live in the 'vk::' namespace.
// - DirectX-specific intrinsics live in the 'dx::' namespace.
// - Global namespaces could just mean we have a `using` declaration... so
// it can be anywhere!
if (isQualified && !isGlobalNamespace && !isVkNamespace && !isDxNamespace)
return false;

Expand All @@ -5204,81 +5207,121 @@ class HLSLExternalSource : public ExternalSemaSource {
}

StringRef nameIdentifier = idInfo->getName();
const HLSL_INTRINSIC *table = g_Intrinsics;
auto tableCount = _countof(g_Intrinsics);
if (isDxNamespace) {
table = g_DxIntrinsics;
tableCount = _countof(g_DxIntrinsics);
}
using IntrinsicArray = llvm::ArrayRef<const HLSL_INTRINSIC>;
struct IntrinsicTableEntry {
IntrinsicArray Table;
NamespaceDecl *NS;
};

llvm::SmallVector<IntrinsicTableEntry, 3> SearchTables;

if (isDxNamespace)
SearchTables.push_back(
IntrinsicTableEntry{IntrinsicArray(g_DxIntrinsics), m_dxNSDecl});
#ifdef ENABLE_SPIRV_CODEGEN
else if (isVkNamespace)
SearchTables.push_back(
IntrinsicTableEntry{IntrinsicArray(g_VkIntrinsics), m_vkNSDecl});
#endif
else if (isGlobalNamespace)
SearchTables.push_back(
IntrinsicTableEntry{IntrinsicArray(g_Intrinsics), m_hlslNSDecl});
else if (!isQualified) {
// If the name isn't qualified, we need to search all scopes that are
// accessible without qualification. This starts with the global scope and
// extends into any scopes that are referred to by using declarations.
SearchTables.push_back(
IntrinsicTableEntry{IntrinsicArray(g_Intrinsics), m_hlslNSDecl});

// If we have a scope chain, walk it to get using declarations.
if (S) {
SmallVector<const DeclContext *, 4> NSContexts;
m_sema->CollectNamespaceContexts(S, NSContexts);
bool DXFound = false;
bool VKFound = false;
for (const auto &UD : NSContexts) {
if (static_cast<DeclContext *>(m_dxNSDecl) == UD)
DXFound = true;
else if (static_cast<DeclContext *>(m_vkNSDecl) == UD)
VKFound = true;
}
if (DXFound)
SearchTables.push_back(
IntrinsicTableEntry{IntrinsicArray(g_DxIntrinsics), m_dxNSDecl});
#ifdef ENABLE_SPIRV_CODEGEN
if (isVkNamespace) {
table = g_VkIntrinsics;
tableCount = _countof(g_VkIntrinsics);
if (VKFound)
SearchTables.push_back(
IntrinsicTableEntry{IntrinsicArray(g_VkIntrinsics), m_vkNSDecl});
#endif
}
}
#endif // ENABLE_SPIRV_CODEGEN

IntrinsicDefIter cursor = FindIntrinsicByNameAndArgCount(
table, tableCount, StringRef(), nameIdentifier, Args.size());
IntrinsicDefIter end = IntrinsicDefIter::CreateEnd(
table, tableCount, IntrinsicTableDefIter::CreateEnd(m_intrinsicTables));

for (; cursor != end; ++cursor) {
// If this is the intrinsic we're interested in, build up a representation
// of the types we need.
const HLSL_INTRINSIC *pIntrinsic = *cursor;
LPCSTR tableName = cursor.GetTableName();
LPCSTR lowering = cursor.GetLoweringStrategy();
DXASSERT(pIntrinsic->uNumArgs <= g_MaxIntrinsicParamCount + 1,
"otherwise g_MaxIntrinsicParamCount needs to be updated for "
"wider signatures");

std::vector<QualType> functionArgTypes;
size_t badArgIdx;
bool argsMatch =
MatchArguments(cursor, QualType(), QualType(), QualType(), Args,
&functionArgTypes, badArgIdx);
if (!functionArgTypes.size())
return false;
assert(!SearchTables.empty() && "Must have at least one search table!");

for (const auto &T : SearchTables) {

IntrinsicDefIter cursor = FindIntrinsicByNameAndArgCount(
T.Table.data(), T.Table.size(), StringRef(), nameIdentifier,
Args.size());
IntrinsicDefIter end = IntrinsicDefIter::CreateEnd(
T.Table.data(), T.Table.size(),
IntrinsicTableDefIter::CreateEnd(m_intrinsicTables));

for (; cursor != end; ++cursor) {
// If this is the intrinsic we're interested in, build up a
// representation of the types we need.
const HLSL_INTRINSIC *pIntrinsic = *cursor;
LPCSTR tableName = cursor.GetTableName();
LPCSTR lowering = cursor.GetLoweringStrategy();
DXASSERT(pIntrinsic->uNumArgs <= g_MaxIntrinsicParamCount + 1,
"otherwise g_MaxIntrinsicParamCount needs to be updated for "
"wider signatures");

std::vector<QualType> functionArgTypes;
size_t badArgIdx;
bool argsMatch =
MatchArguments(cursor, QualType(), QualType(), QualType(), Args,
&functionArgTypes, badArgIdx);
if (!functionArgTypes.size())
return false;

// Get or create the overload we're interested in.
FunctionDecl *intrinsicFuncDecl = nullptr;
std::pair<UsedIntrinsicStore::iterator, bool> insertResult =
m_usedIntrinsics.insert(UsedIntrinsic(pIntrinsic, functionArgTypes));
bool insertedNewValue = insertResult.second;
if (insertedNewValue) {
NamespaceDecl *nsDecl = m_hlslNSDecl;
if (isVkNamespace)
nsDecl = m_vkNSDecl;
else if (isDxNamespace)
nsDecl = m_dxNSDecl;
DXASSERT(tableName,
"otherwise IDxcIntrinsicTable::GetTableName() failed");
intrinsicFuncDecl =
AddHLSLIntrinsicFunction(*m_context, nsDecl, tableName, lowering,
pIntrinsic, &functionArgTypes);
insertResult.first->setFunctionDecl(intrinsicFuncDecl);
} else {
intrinsicFuncDecl = (*insertResult.first).getFunctionDecl();
}
// Get or create the overload we're interested in.
FunctionDecl *intrinsicFuncDecl = nullptr;
std::pair<UsedIntrinsicStore::iterator, bool> insertResult =
m_usedIntrinsics.insert(
UsedIntrinsic(pIntrinsic, functionArgTypes));
bool insertedNewValue = insertResult.second;
if (insertedNewValue) {
DXASSERT(tableName,
"otherwise IDxcIntrinsicTable::GetTableName() failed");
intrinsicFuncDecl =
AddHLSLIntrinsicFunction(*m_context, T.NS, tableName, lowering,
pIntrinsic, &functionArgTypes);
insertResult.first->setFunctionDecl(intrinsicFuncDecl);
} else {
intrinsicFuncDecl = (*insertResult.first).getFunctionDecl();
}

OverloadCandidate &candidate = CandidateSet.addCandidate(Args.size());
candidate.Function = intrinsicFuncDecl;
candidate.FoundDecl.setDecl(intrinsicFuncDecl);
candidate.Viable = argsMatch;
CandidateSet.isNewCandidate(intrinsicFuncDecl); // used to insert into set
if (argsMatch)
return true;
if (badArgIdx) {
candidate.FailureKind = ovl_fail_bad_conversion;
QualType ParamType =
intrinsicFuncDecl->getParamDecl(badArgIdx - 1)->getType();
candidate.Conversions[badArgIdx - 1].setBad(
BadConversionSequence::no_conversion, Args[badArgIdx - 1],
ParamType);
} else {
// A less informative error. Needed when the failure relates to the
// return type
candidate.FailureKind = ovl_fail_bad_final_conversion;
OverloadCandidate &candidate = CandidateSet.addCandidate(Args.size());
candidate.Function = intrinsicFuncDecl;
candidate.FoundDecl.setDecl(intrinsicFuncDecl);
candidate.Viable = argsMatch;
CandidateSet.isNewCandidate(
intrinsicFuncDecl); // used to insert into set
if (argsMatch)
return true;
if (badArgIdx) {
candidate.FailureKind = ovl_fail_bad_conversion;
QualType ParamType =
intrinsicFuncDecl->getParamDecl(badArgIdx - 1)->getType();
candidate.Conversions[badArgIdx - 1].setBad(
BadConversionSequence::no_conversion, Args[badArgIdx - 1],
ParamType);
} else {
// A less informative error. Needed when the failure relates to the
// return type
candidate.FailureKind = ovl_fail_bad_final_conversion;
}
}
}

Expand Down
36 changes: 35 additions & 1 deletion tools/clang/lib/Sema/SemaLookup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
using namespace clang;
using namespace sema;

// HLSL Note: This set of utilities copied to SemaHLSL.cpp.
namespace {
class UnqualUsingEntry {
const DeclContext *Nominated;
Expand Down Expand Up @@ -4809,9 +4810,12 @@ void Sema::diagnoseTypo(const TypoCorrection &Correction,

NamedDecl *ChosenDecl =
Correction.isKeyword() ? nullptr : Correction.getCorrectionDecl();
if (PrevNote.getDiagID() && ChosenDecl)
// HLSL Change begin: don't put notes on invalid source locations.
if (PrevNote.getDiagID() && ChosenDecl &&
!ChosenDecl->getLocation().isInvalid())
Diag(ChosenDecl->getLocation(), PrevNote)
<< CorrectedQuotedStr << (ErrorRecovery ? FixItHint() : FixTypo);
// HLSL Change end
}

TypoExpr *Sema::createDelayedTypo(std::unique_ptr<TypoCorrectionConsumer> TCC,
Expand All @@ -4836,3 +4840,33 @@ const Sema::TypoExprState &Sema::getTypoExprState(TypoExpr *TE) const {
void Sema::clearDelayedTypo(TypoExpr *TE) {
DelayedTypos.erase(TE);
}

// HLSL Change Begin
void Sema::CollectNamespaceContexts(Scope *S,
SmallVectorImpl<const DeclContext *> &NSs) {
UnqualUsingDirectiveSet UDirs;

// Add using directives from this context up to the top level. This
// handles cases where the current declaration is in a context that has
// a using directive but might be in a scope chain that doesn't reach
// the using directive (i.e. a using inside a namespace or class
// declaration but the function definition is outside).
DeclContext *Ctx = S->getEntity();
for (DeclContext *UCtx = Ctx; UCtx; UCtx = UCtx->getParent()) {
if (UCtx->isTransparentContext())
continue;

UDirs.visit(UCtx, UCtx);
}
// Find the first namespace or translation-unit scope.
Scope *Innermost = S;
while (Innermost && !isNamespaceOrTranslationUnitScope(Innermost))
Innermost = Innermost->getParent();

UDirs.visitScopeChain(S, Innermost);
UDirs.done();

for (auto &UD : UDirs)
NSs.push_back(UD.getNominatedNamespace());
}
// HLSL Change End
7 changes: 4 additions & 3 deletions tools/clang/lib/Sema/SemaOverload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10627,6 +10627,7 @@ static void AddOverloadedCallCandidate(Sema &S,
void Sema::AddOverloadedCallCandidates(UnresolvedLookupExpr *ULE,
ArrayRef<Expr *> Args,
OverloadCandidateSet &CandidateSet,
Scope *S, // HLSL Change
bool PartialOverloading) {

#ifndef NDEBUG
Expand Down Expand Up @@ -10659,8 +10660,8 @@ void Sema::AddOverloadedCallCandidates(UnresolvedLookupExpr *ULE,
#endif

// HLSL Change - allow ExternalSource the ability to add the overloads for a call.
if (ExternalSource &&
ExternalSource->AddOverloadedCallCandidates(ULE, Args, CandidateSet, PartialOverloading)) {
if (ExternalSource && ExternalSource->AddOverloadedCallCandidates(
ULE, Args, CandidateSet, S, PartialOverloading)) {
return;
}

Expand Down Expand Up @@ -10970,7 +10971,7 @@ bool Sema::buildOverloadedCallSet(Scope *S, Expr *Fn,

// Add the functions denoted by the callee to the set of candidate
// functions, including those from argument-dependent lookup.
AddOverloadedCallCandidates(ULE, Args, *CandidateSet);
AddOverloadedCallCandidates(ULE, Args, *CandidateSet, S); // HLSL Change

if (getLangOpts().MSVCCompat &&
CurContext->isDependentContext() && !isSFINAEContext() &&
Expand Down
2 changes: 0 additions & 2 deletions tools/clang/test/SemaHLSL/effects-syntax.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,10 @@ static const PixelShader ps1 { state=foo; }; /* expected-warning
/*verify-ast
No matching AST found for line!
*/
// expected-note@? {{'PixelShader' declared here}}
PixelShadeR ps < int foo=1;> = ps1; // Case insensitive! /* expected-error {{unknown type name 'PixelShadeR'; did you mean 'PixelShader'?}} expected-warning {{effect object ignored - effect syntax is deprecated}} expected-warning {{possible effect annotation ignored - effect syntax is deprecated}} fxc-pass {{}} */
/*verify-ast
No matching AST found for line!
*/
// expected-note@? {{'VertexShader' declared here}}
VertexShadeR vs; // Case insensitive! /* expected-error {{unknown type name 'VertexShadeR'; did you mean 'VertexShader'?}} expected-warning {{effect object ignored - effect syntax is deprecated}} fxc-pass {{}} */

// Case sensitive
Expand Down
4 changes: 2 additions & 2 deletions tools/clang/test/SemaHLSL/raytracings.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ void run() {
RAY_FLAG_CULL_OPAQUE +
RAY_FLAG_CULL_NON_OPAQUE;

rayFlags += RAY_FLAG_INVALID; /* expected-note@? {{'RAY_FLAG_NONE' declared here}} expected-error {{use of undeclared identifier 'RAY_FLAG_INVALID'; did you mean 'RAY_FLAG_NONE'?}} */
rayFlags += RAY_FLAG_INVALID; /* expected-error {{use of undeclared identifier 'RAY_FLAG_INVALID'; did you mean 'RAY_FLAG_NONE'?}} */

int intFlag = RAY_FLAG_CULL_OPAQUE;

int hitKindFlag =
HIT_KIND_TRIANGLE_FRONT_FACE + HIT_KIND_TRIANGLE_BACK_FACE;

hitKindFlag += HIT_KIND_INVALID; /* expected-note@? {{'HIT_KIND_NONE' declared here}} expected-error {{use of undeclared identifier 'HIT_KIND_INVALID'; did you mean 'HIT_KIND_NONE'?}} */
hitKindFlag += HIT_KIND_INVALID; /* expected-error {{use of undeclared identifier 'HIT_KIND_INVALID'; did you mean 'HIT_KIND_NONE'?}} */


BuiltInTriangleIntersectionAttributes attr;
Expand Down
Loading