Skip to content

[HLSL][DXIL] Implement refract intrinsic #147342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Jul 16, 2025
Merged

Conversation

raoanag
Copy link
Contributor

@raoanag raoanag commented Jul 7, 2025

  • Implement refract using HLSL source in hlsl_intrinsics.h
  • Implement the refract SPIR-V target built-in in clang/include/clang/Basic/BuiltinsSPIRV.td
  • Add sema checks for refract to CheckSPIRVBuiltinFunctionCall in clang/lib/Sema/SemaSPIRV.cpp
  • Add codegen for spv refract to EmitSPIRVBuiltinExpr in CGBuiltin.cpp
  • Add codegen tests to clang/test/CodeGenHLSL/builtins/refract.hlsl
  • Add spv codegen test to clang/test/CodeGenSPIRV/Builtins/refract.c
  • Add sema tests to clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl
  • Add spv sema tests to clang/test/SemaSPIRV/BuiltIns/refract-errors.c
  • Create the int_spv_refract intrinsic in IntrinsicsSPIRV.td
  • In SPIRVInstructionSelector.cpp create the refract lowering and map it to int_spv_refract in SPIRVInstructionSelector::selectIntrinsic.
  • Create SPIR-V backend test case in llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
  • Check for what OpenCL support is needed.

Resolves #99153

Copy link

github-actions bot commented Jul 7, 2025

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@raoanag raoanag marked this pull request as ready for review July 7, 2025 16:27
Copy link

github-actions bot commented Jul 7, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@llvmbot
Copy link
Member

llvmbot commented Jul 7, 2025

@llvm/pr-subscribers-hlsl
@llvm/pr-subscribers-backend-directx

@llvm/pr-subscribers-backend-spir-v

Author: None (raoanag)

Changes
  • Implement refract using HLSL source in hlsl_intrinsics.h
  • Implement the refract SPIR-V target built-in in clang/include/clang/Basic/BuiltinsSPIRV.td
  • Add sema checks for refract to CheckSPIRVBuiltinFunctionCall in clang/lib/Sema/SemaSPIRV.cpp
  • Add codegen for spv refract to EmitSPIRVBuiltinExpr in CGBuiltin.cpp
  • Add codegen tests to clang/test/CodeGenHLSL/builtins/refract.hlsl
  • Add spv codegen test to clang/test/CodeGenSPIRV/Builtins/refract.c
  • Add sema tests to clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl
  • Add spv sema tests to clang/test/SemaSPIRV/BuiltIns/refract-errors.c
  • Create the int_spv_refract intrinsic in IntrinsicsSPIRV.td
  • In SPIRVInstructionSelector.cpp create the refract lowering and map it to int_spv_refract in SPIRVInstructionSelector::selectIntrinsic.
  • Create SPIR-V backend test case in llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
  • Check for what OpenCL support is needed.

Resolves #99153


Patch is 57.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147342.diff

19 Files Affected:

  • (modified) clang/include/clang/Basic/BuiltinsSPIRVVK.td (+1)
  • (modified) clang/include/clang/Sema/Sema.h (+24)
  • (modified) clang/lib/CodeGen/TargetBuiltins/SPIR.cpp (+15)
  • (modified) clang/lib/Headers/hlsl/hlsl_detail.h (+8)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h (+36)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+59)
  • (modified) clang/lib/Sema/SemaChecking.cpp (+105)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+64-9)
  • (modified) clang/lib/Sema/SemaSPIRV.cpp (+36-56)
  • (modified) clang/test/CodeGenHLSL/builtins/reflect.hlsl (+1-1)
  • (added) clang/test/CodeGenHLSL/builtins/refract.hlsl (+271)
  • (added) clang/test/CodeGenSPIRV/Builtins/refract.c (+29)
  • (added) clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl (+74)
  • (added) clang/test/SemaSPIRV/BuiltIns/refract-errors.c (+23)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1)
  • (modified) llvm/lib/IR/IRBuilder.cpp (+1-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+2)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll (+36)
  • (added) llvm/test/CodeGen/SPIRV/opencl/refract-error.ll (+12)
diff --git a/clang/include/clang/Basic/BuiltinsSPIRVVK.td b/clang/include/clang/Basic/BuiltinsSPIRVVK.td
index 61cc0343c415e..5dc3c7588cd2a 100644
--- a/clang/include/clang/Basic/BuiltinsSPIRVVK.td
+++ b/clang/include/clang/Basic/BuiltinsSPIRVVK.td
@@ -11,3 +11,4 @@ include "clang/Basic/BuiltinsSPIRVBase.td"
 
 def reflect : SPIRVBuiltin<"void(...)", [NoThrow, Const]>;
 def faceforward : SPIRVBuiltin<"void(...)", [NoThrow, Const, CustomTypeChecking]>;
+def refract : SPIRVBuiltin<"void(...)", [NoThrow, Const, CustomTypeChecking]>;
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 3fe26f950ad51..105ab804fffd0 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -2791,6 +2791,30 @@ class Sema final : public SemaBase {
 
   void CheckConstrainedAuto(const AutoType *AutoT, SourceLocation Loc);
 
+  /// CheckVectorArgs - Check that the arguments of a vector function call
+  bool CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck);
+
+  bool CheckVectorArgs(CallExpr *TheCall);
+
+  bool CheckAllArgTypesAreCorrect(
+      Sema *S, CallExpr *TheCall,
+      llvm::ArrayRef<
+          llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
+          Checks);
+  bool CheckAllArgTypesAreCorrect(
+      Sema *S, CallExpr *TheCall,
+      llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check);
+
+  static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
+                                            int ArgOrdinal,
+                                            clang::QualType PassedType);
+  static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
+                                             int ArgOrdinal,
+                                             clang::QualType PassedType);
+
+  static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
+                                                int ArgOrdinal,
+                                                clang::QualType PassedType);
   /// BuiltinConstantArg - Handle a check if argument ArgNum of CallExpr
   /// TheCall is a constant expression.
   bool BuiltinConstantArg(CallExpr *TheCall, int ArgNum, llvm::APSInt &Result);
diff --git a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
index 0687485cd3f80..1c63e04f757c7 100644
--- a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
@@ -58,6 +58,21 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
         /*ReturnType=*/I->getType(), Intrinsic::spv_reflect,
         ArrayRef<Value *>{I, N}, nullptr, "spv.reflect");
   }
+  case SPIRV::BI__builtin_spirv_refract: {
+    Value *I = EmitScalarExpr(E->getArg(0));
+    Value *N = EmitScalarExpr(E->getArg(1));
+    Value *eta = EmitScalarExpr(E->getArg(2));
+    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+           E->getArg(1)->getType()->hasFloatingRepresentation() &&
+           E->getArg(2)->getType()->isFloatingType() &&
+           "refract operands must have a float representation");
+    assert(E->getArg(0)->getType()->isVectorType() &&
+           E->getArg(1)->getType()->isVectorType() &&
+           "refract I and N operands must be a vector");
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/I->getType(), Intrinsic::spv_refract,
+        ArrayRef<Value *>{I, N, eta}, nullptr, "spv.refract");
+  }
   case SPIRV::BI__builtin_spirv_smoothstep: {
     Value *Min = EmitScalarExpr(E->getArg(0));
     Value *Max = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/Headers/hlsl/hlsl_detail.h b/clang/lib/Headers/hlsl/hlsl_detail.h
index 80c4900121dfb..96e101a1e3aa8 100644
--- a/clang/lib/Headers/hlsl/hlsl_detail.h
+++ b/clang/lib/Headers/hlsl/hlsl_detail.h
@@ -45,6 +45,14 @@ template <typename T> struct is_arithmetic {
   static const bool Value = __is_arithmetic(T);
 };
 
+template <typename T> struct is_vector {
+  static const bool value = false;
+};
+
+template <typename T, int N> struct is_vector<vector<T, N>> {
+  static const bool value = true;
+};
+
 template <typename T, int N>
 using HLSL_FIXED_VECTOR =
     vector<__detail::enable_if_t<(N > 1 && N <= 4), T>, N>;
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 4eb7b8f45c85a..f6acb1cea2594 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -71,6 +71,42 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
 #endif
 }
 
+template <typename T> constexpr T refract_impl(T I, T N, T Eta) {
+  T Mul = N * I;
+  T K = 1 - Eta * Eta * (1 - (Mul * Mul));
+  T Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
+  return select<T>(K < 0, static_cast<T>(0), Result);
+}
+
+template <typename T, typename U>
+constexpr T refract_vec_impl(T I, T N, U Eta) {
+#if (__has_builtin(__builtin_spirv_refract))
+  if (is_vector<T>::value) {
+    return __builtin_spirv_refract(I, N, Eta);
+  }
+#else
+  T Mul = dot(N, I);
+  T K = 1 - Eta * Eta * (1 - Mul * Mul);
+  T Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
+  return select<T>(K < 0, static_cast<T>(0), Result);
+#endif
+}
+
+/*
+template <typename T, int L>
+constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T Eta) {
+#if (__has_builtin(__builtin_spirv_refract) && is_vector<T>))
+  return __builtin_spirv_refract(I, N, Eta);
+#else
+  T Mul = dot(N, I);
+  vector<T, L> K = 1 - Eta * Eta * (1 - Mul * Mul);
+  vector<T, L> Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
+  return select<vector<T, L>>(K < 0, vector<T, L>(0), Result);
+#endif
+}
+
+*/
+
 template <typename T> constexpr T fmod_impl(T X, T Y) {
 #if !defined(__DIRECTX__)
   return __builtin_elementwise_fmod(X, Y);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index ea880105fac3b..8c262ffce25f1 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -475,6 +475,65 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
   return __detail::reflect_vec_impl(I, N);
 }
 
+//===----------------------------------------------------------------------===//
+// refract builtin
+//===----------------------------------------------------------------------===//
+
+/// \fn T refract(T I, T N, T eta)
+/// \brief Returns a refraction using an entering ray, \a I, a surface
+/// normal, \a N and refraction index \a eta
+/// \param I The entering ray.
+/// \param N The surface normal.
+/// \param eta The refraction index.
+///
+/// The return value is a floating-point vector that represents the refraction
+/// using the refraction index, \a eta, for the direction of the entering ray,
+/// \a I, off a surface with the normal \a N.
+///
+/// This function calculates the refraction vector using the following formulas:
+/// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
+/// if k < 0.0 the result is 0.0
+/// otherwise, the result is eta * I - (eta * dot(N, I) + sqrt(k)) * N
+///
+/// I and N must already be normalized in order to achieve the desired result.
+///
+/// I and N must be a scalar or vector whose component type is
+/// floating-point.
+///
+/// eta must be a 16-bit or 32-bit floating-point scalar.
+///
+/// Result type, the type of I, and the type of N must all be the same type.
+
+template <typename T>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
+                                       __detail::is_same<half, T>::value,
+                                   T> refract(T I, T N, T eta) {
+  return __detail::refract_impl(I, N, eta);
+}
+
+template <typename T>
+const inline __detail::enable_if_t<
+    __detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
+refract(T I, T N, T eta) {
+  return __detail::refract_impl(I, N, eta);
+}
+
+template <int L>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::HLSL_FIXED_VECTOR<half, L> refract(
+    __detail::HLSL_FIXED_VECTOR<half, L> I,
+    __detail::HLSL_FIXED_VECTOR<half, L> N, half eta) {
+  return __detail::refract_vec_impl(I, N, eta);
+}
+
+template <int L>
+const inline __detail::HLSL_FIXED_VECTOR<float, L>
+refract(__detail::HLSL_FIXED_VECTOR<float, L> I,
+        __detail::HLSL_FIXED_VECTOR<float, L> N, float eta) {
+  return __detail::refract_vec_impl(I, N, eta);
+}
+
 //===----------------------------------------------------------------------===//
 // smoothstep builtin
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index dd5b710d7e1d4..98bca59f14ecd 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -16151,3 +16151,108 @@ void Sema::CheckTCBEnforcement(const SourceLocation CallExprLoc,
     }
   }
 }
+
+bool Sema::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) {
+  for (unsigned i = 0; i < NumArgsToCheck; ++i) {
+    ExprResult Arg = TheCall->getArg(i);
+    QualType ArgTy = Arg.get()->getType();
+    auto *VTy = ArgTy->getAs<VectorType>();
+    if (VTy == nullptr) {
+      SemaRef.Diag(Arg.get()->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTy
+          << SemaRef.Context.getVectorType(ArgTy, 2, VectorKind::Generic) << 1
+          << 0 << 0;
+      return true;
+    }
+  }
+  return false;
+}
+
+bool Sema::CheckVectorArgs(CallExpr *TheCall) {
+  return CheckVectorArgs(TheCall, TheCall->getNumArgs());
+}
+
+
+bool Sema::CheckAllArgTypesAreCorrect(
+    Sema *S, CallExpr *TheCall,
+    llvm::ArrayRef<
+        llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
+        Checks) {
+  unsigned NumArgs = TheCall->getNumArgs();
+  if (Checks.size() == 1) {
+    // Apply the single check to all arguments
+    for (unsigned I = 0; I < NumArgs; ++I) {
+      Expr *Arg = TheCall->getArg(I);
+      if (Checks[0](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+        return true;
+    }
+    return false;
+  } else if (Checks.size() == NumArgs) {
+    // Apply each check to the corresponding argument
+    for (unsigned I = 0; I < NumArgs; ++I) {
+      Expr *Arg = TheCall->getArg(I);
+      if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+        return true;
+    }
+    return false;
+  } else {
+    // Mismatch: error or fallback
+    S->Diag(TheCall->getBeginLoc(), diag::err_builtin_invalid_arg_type)
+        << NumArgs << Checks.size();
+    return true;
+  }
+}
+
+bool Sema::CheckAllArgTypesAreCorrect(
+    Sema *S, CallExpr *TheCall,
+    llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
+  return CheckAllArgTypesAreCorrect(S, TheCall, llvm::ArrayRef{Check});
+}
+
+bool Sema::CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
+                                           int ArgOrdinal,
+                                           clang::QualType PassedType) {
+  clang::QualType BaseType =
+      PassedType->isVectorType()
+          ? PassedType->castAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
+}
+
+bool Sema::CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
+                                                  int ArgOrdinal,
+                                                  clang::QualType PassedType) {
+  const auto *VecTy = PassedType->getAs<VectorType>();
+
+  clang::QualType BaseType =
+      PassedType->isVectorType()
+          ? PassedType->castAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (!VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
+}
+
+bool Sema::CheckFloatOrHalfScalarRepresentation(
+    Sema *S, SourceLocation Loc,
+                                                 int ArgOrdinal,
+                                                 clang::QualType PassedType) {
+  const auto *VecTy = PassedType->getAs<VectorType>();
+
+  clang::QualType BaseType =
+      PassedType->isVectorType()
+          ? PassedType->castAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
+}
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index bad357b50929b..991d330edfb6f 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2401,17 +2401,40 @@ static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
   return false;
 }
 
-static bool CheckAllArgTypesAreCorrect(
+bool CheckAllArgTypesAreCorrect(
     Sema *S, CallExpr *TheCall,
-    llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
-                            clang::QualType PassedType)>
-        Check) {
-  for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
-    Expr *Arg = TheCall->getArg(I);
-    if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
-      return true;
+    llvm::ArrayRef<
+        llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
+        Checks) {
+  unsigned NumArgs = TheCall->getNumArgs();
+  if (Checks.size() == 1) {
+    // Apply the single check to all arguments
+    for (unsigned I = 0; I < NumArgs; ++I) {
+      Expr *Arg = TheCall->getArg(I);
+      if (Checks[0](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+        return true;
+    }
+    return false;
+  } else if (Checks.size() == NumArgs) {
+    // Apply each check to the corresponding argument
+    for (unsigned I = 0; I < NumArgs; ++I) {
+      Expr *Arg = TheCall->getArg(I);
+      if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+        return true;
+    }
+    return false;
+  } else {
+    // Mismatch: error or fallback
+    S->Diag(TheCall->getBeginLoc(), diag::err_builtin_invalid_arg_type)
+        << NumArgs << Checks.size();
+    return true;
   }
-  return false;
+}
+
+bool CheckAllArgTypesAreCorrect(
+    Sema *S, CallExpr *TheCall,
+    llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
+  return CheckAllArgTypesAreCorrect(S, TheCall, llvm::ArrayRef{Check});
 }
 
 static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
@@ -2428,6 +2451,38 @@ static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
   return false;
 }
 
+static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
+                                           int ArgOrdinal,
+                                           clang::QualType PassedType) {
+  const auto *VecTy = PassedType->getAs<VectorType>();
+
+  clang::QualType BaseType = 
+      PassedType->isVectorType()
+        ? PassedType->castAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (!VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
+}
+
+static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
+                                                 int ArgOrdinal,
+                                                 clang::QualType PassedType) {
+  const auto *VecTy = PassedType->getAs<VectorType>();
+
+  clang::QualType BaseType =
+      PassedType->isVectorType()
+          ? PassedType->castAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
+}
+
 static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
                                   unsigned ArgIndex) {
   auto *Arg = TheCall->getArg(ArgIndex);
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index c27d3fed2b990..1b4093065a63a 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -157,81 +157,61 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
 
-    ExprResult A = TheCall->getArg(0);
-    QualType ArgTyA = A.get()->getType();
-    auto *VTyA = ArgTyA->getAs<VectorType>();
-    if (VTyA == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyA
-          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
-          << 0 << 0;
+    // Use the helper function to check both arguments
+    if (SemaRef.CheckVectorArgs(TheCall))
       return true;
-    }
 
-    ExprResult B = TheCall->getArg(1);
-    QualType ArgTyB = B.get()->getType();
-    auto *VTyB = ArgTyB->getAs<VectorType>();
-    if (VTyB == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyB
-          << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
-          << 0 << 0;
-      return true;
-    }
-
-    QualType RetTy = VTyA->getElementType();
+    QualType RetTy =
+        TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
     TheCall->setType(RetTy);
     break;
   }
   case SPIRV::BI__builtin_spirv_length: {
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
-    ExprResult A = TheCall->getArg(0);
-    QualType ArgTyA = A.get()->getType();
-    auto *VTy = ArgTyA->getAs<VectorType>();
-    if (VTy == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyA
-          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
-          << 0 << 0;
+
+    // Use the helper function to check the argument
+    if (SemaRef.CheckVectorArgs(TheCall))
       return true;
-    }
-    QualType RetTy = VTy->getElementType();
+
+    QualType RetTy =
+        TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
     TheCall->setType(RetTy);
     break;
   }
-  case SPIRV::BI__builtin_spirv_reflect: {
-    if (SemaRef.checkArgCount(TheCall, 2))
+  case SPIRV::BI__builtin_spirv_refract: {
+    if (SemaRef.checkArgCount(TheCall, 3))
       return true;
 
-    ExprResult A = TheCall->getArg(0);
-    QualType ArgTyA = A.get()->getType();
-    auto *VTyA = ArgTyA->getAs<VectorType>();
-    if (VTyA == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyA
-          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
-          << 0 << 0;
+    llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>
+        ChecksArr[] = {Sema::CheckFloatOrHalfVectorsRepresentation,
+                       Sema::CheckFloatOrHalfVectorsRepresentation,
+                       Sema::CheckFloatOrHalfScalarRepresentation};
+    if (SemaRef.CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                           llvm::ArrayRef(ChecksArr)))
       return true;
-    }
 
-    ExprResult B = TheCall->getArg(1);
-    QualType ArgTyB = B.get()->getType();
-    auto *VTyB = ArgTyB->getAs<VectorType>();
-    if (VTyB == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyB
-          << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
-          << 0 << 0;
+    ExprResult C = TheCall->getArg(2);
+    QualType ArgTyC = C.get()->getType();
+    if (!ArgTyC->isFloatingType...
[truncated]

@raoanag raoanag force-pushed the user/raoanag/refract branch from 515ecda to 729fbf3 Compare July 8, 2025 23:03
return true;

llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>
ChecksArr[] = {CheckFloatOrHalfRepresentation,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit confused on if this is meant to handle scalar values as well as vectors? Looking at the code gen, we are asserting that the first two arguments are vectors, but here we allow them to be scalars. @farzonl Does this handle only the case where the first two arguments are vectors?

If that is the case 'CheckFloatOrHalfRepresentation' should be updated to only check for vectors of half or float and should probably be renamed to 'CheckFloatOrHalfVecRepresentation'.

Copy link
Contributor Author

@raoanag raoanag Jul 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to support vector of size 1, which is implicitly converted to scalar.
Also HLSL_FIXED_VECTOR only supports Vector of N > 1.

Hence, even though first 2 args are described as vector for N = 1 they are seen as scalar

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think your explanation is slightly incorrect, but it does seem the __builtin_spirv_refract is reachable with a scalar value. In this case the codegen assertions are wrong and I will leave a comment there about updating them.

Copy link
Member

@farzonl farzonl Jul 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’ll review this tomorrow, but supporting scalars here seems wrong. I’m almost 100% sure that spirv via dxc only supports vectors and that our semantics should match that.

Copy link
Contributor

@spall spall Jul 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Refract spirv op says both scalar and vector are supported. https://registry.khronos.org/SPIR-V/specs/unified1/GLSL.std.450.pdf (search for Refract).
But it is up to us what we want to allow.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

referring the khornos doc, SPIR asserts for vectorness for all the intrinsics - distance, length, reflect, smoothstep, faceforward would need to be updated since they all mention operands must all be a scalar or vector whose component type is floating-point.. SPIR.cpp

Looking into how E->getArg(0)->getType()->isVectorType()

isVectorType() checks for vector-ness, not size.
• Vectors of size 1 are technically allowed

Just sharing observations here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we had some issues with spirv-val trying to do scalars with those other spirv opcodes. Also i’m pretty sure dxc doesn’t call the glsl instruction in many of the scalar cases. We have been using dxcs behavior as our default spec since we don’t have one.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will need to open bug tickets for distance, length, & reflect. smoothstep and faceforward are generating the spirv glsl ext op for scalar cases: https://godbolt.org/z/fYf6zbGc5.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

spirv doesn't support vectors of size 1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Support spirv glsl ext op with scalar args for distance, length, reflect, smoothstep & faceforward

I am not able to add relevant flags on this issue, but created one for tracking

typedef float float3 __attribute__((ext_vector_type(3)));
typedef float float4 __attribute__((ext_vector_type(4)));

// CHECK-LABEL: define spir_func <2 x float> @test_refract_float2(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@spall should we add f16 tests to this file?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah thats probably a good idea to ensure the _Float16 is handled properly in codegen.

Copy link
Contributor

@spall spall left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@spall spall merged commit 056f0a1 into llvm:main Jul 16, 2025
10 checks passed
Copy link

@raoanag Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

@llvm-ci
Copy link
Collaborator

llvm-ci commented Jul 16, 2025

LLVM Buildbot has detected a new failure on builder arc-builder running on arc-worker while building clang,llvm at step 6 "test-build-unified-tree-check-all".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/3/builds/19138

Here is the relevant piece of the build log for the reference
Step 6 (test-build-unified-tree-check-all) failure: test (failure)
******************** TEST 'LLVM :: CodeGen/X86/sse2-intrinsics-fast-isel.ll' FAILED ********************
Exit Code: 1

Command Output (stderr):
--
/buildbot/worker/arc-folder/build/bin/llc < /buildbot/worker/arc-folder/llvm-project/llvm/test/CodeGen/X86/sse2-intrinsics-fast-isel.ll -show-mc-encoding -fast-isel -mtriple=i386-unknown-unknown -mattr=+sse2 | /buildbot/worker/arc-folder/build/bin/FileCheck /buildbot/worker/arc-folder/llvm-project/llvm/test/CodeGen/X86/sse2-intrinsics-fast-isel.ll --check-prefixes=CHECK,X86,SSE,X86-SSE # RUN: at line 2
+ /buildbot/worker/arc-folder/build/bin/FileCheck /buildbot/worker/arc-folder/llvm-project/llvm/test/CodeGen/X86/sse2-intrinsics-fast-isel.ll --check-prefixes=CHECK,X86,SSE,X86-SSE
+ /buildbot/worker/arc-folder/build/bin/llc -show-mc-encoding -fast-isel -mtriple=i386-unknown-unknown -mattr=+sse2
LLVM ERROR: Cannot select: intrinsic %llvm.x86.sse2.clflush
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.	Program arguments: /buildbot/worker/arc-folder/build/bin/llc -show-mc-encoding -fast-isel -mtriple=i386-unknown-unknown -mattr=+sse2
1.	Running pass 'Function Pass Manager' on module '<stdin>'.
2.	Running pass 'X86 DAG->DAG Instruction Selection' on function '@test_mm_clflush'
 #0 0x00000000023228c8 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/buildbot/worker/arc-folder/build/bin/llc+0x23228c8)
 #1 0x000000000231f7d5 SignalHandler(int, siginfo_t*, void*) Signals.cpp:0:0
 #2 0x00007f8cf3f3b630 __restore_rt sigaction.c:0:0
 #3 0x00007f8cf2c8b3d7 raise (/usr/lib64/libc.so.6+0x363d7)
 #4 0x00007f8cf2c8cac8 abort (/usr/lib64/libc.so.6+0x37ac8)
 #5 0x000000000071ae9b llvm::json::operator==(llvm::json::Value const&, llvm::json::Value const&) (.cold) JSON.cpp:0:0
 #6 0x00000000020b2739 llvm::SelectionDAGISel::CannotYetSelect(llvm::SDNode*) (/buildbot/worker/arc-folder/build/bin/llc+0x20b2739)
 #7 0x00000000020b730a llvm::SelectionDAGISel::SelectCodeCommon(llvm::SDNode*, unsigned char const*, unsigned int) (/buildbot/worker/arc-folder/build/bin/llc+0x20b730a)
 #8 0x00000000009590a7 (anonymous namespace)::X86DAGToDAGISel::Select(llvm::SDNode*) X86ISelDAGToDAG.cpp:0:0
 #9 0x00000000020adfff llvm::SelectionDAGISel::DoInstructionSelection() (/buildbot/worker/arc-folder/build/bin/llc+0x20adfff)
#10 0x00000000020bddf8 llvm::SelectionDAGISel::CodeGenAndEmitDAG() (/buildbot/worker/arc-folder/build/bin/llc+0x20bddf8)
#11 0x00000000020c1a6a llvm::SelectionDAGISel::SelectAllBasicBlocks(llvm::Function const&) (/buildbot/worker/arc-folder/build/bin/llc+0x20c1a6a)
#12 0x00000000020c26c5 llvm::SelectionDAGISel::runOnMachineFunction(llvm::MachineFunction&) (/buildbot/worker/arc-folder/build/bin/llc+0x20c26c5)
#13 0x00000000020ad81f llvm::SelectionDAGISelLegacy::runOnMachineFunction(llvm::MachineFunction&) (/buildbot/worker/arc-folder/build/bin/llc+0x20ad81f)
#14 0x00000000012016b7 llvm::MachineFunctionPass::runOnFunction(llvm::Function&) (.part.0) MachineFunctionPass.cpp:0:0
#15 0x000000000185e852 llvm::FPPassManager::runOnFunction(llvm::Function&) (/buildbot/worker/arc-folder/build/bin/llc+0x185e852)
#16 0x000000000185ebf1 llvm::FPPassManager::runOnModule(llvm::Module&) (/buildbot/worker/arc-folder/build/bin/llc+0x185ebf1)
#17 0x000000000185f807 llvm::legacy::PassManagerImpl::run(llvm::Module&) (/buildbot/worker/arc-folder/build/bin/llc+0x185f807)
#18 0x00000000007f7d82 compileModule(char**, llvm::LLVMContext&) llc.cpp:0:0
#19 0x0000000000723396 main (/buildbot/worker/arc-folder/build/bin/llc+0x723396)
#20 0x00007f8cf2c77555 __libc_start_main (/usr/lib64/libc.so.6+0x22555)
#21 0x00000000007edfc6 _start (/buildbot/worker/arc-folder/build/bin/llc+0x7edfc6)
/buildbot/worker/arc-folder/llvm-project/llvm/test/CodeGen/X86/sse2-intrinsics-fast-isel.ll:399:14: error: SSE-LABEL: expected string not found in input
; SSE-LABEL: test_mm_bsrli_si128:
             ^
<stdin>:170:21: note: scanning from here
test_mm_bslli_si128: # @test_mm_bslli_si128
                    ^
<stdin>:178:9: note: possible intended match here
 .globl test_mm_bsrli_si128 # 
        ^

Input file: <stdin>
Check file: /buildbot/worker/arc-folder/llvm-project/llvm/test/CodeGen/X86/sse2-intrinsics-fast-isel.ll

-dump-input=help explains the following input dump.
...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement the refract HLSL Function
6 participants