Skip to content

Commit 4efa3dc

Browse files
authored
Merge HLSLHalf_t and HLSLBool_t (#7630)
This PR merges some more long vector exec test code from staging-sm6.9 into main. Specifically, we bring over the helper classes that define data types for half and bool. Halfs are only available in newer c++ versions so a simple class was needed to implement the proper logic using existing DX helpers that were added for this same reason. The bool class is used as the size of a bool in c++ differs from that in HLSL. Also brings in some tests cases using these data types. Test cases were verified locally by running against WARP. Addresses #7546
1 parent a11702e commit 4efa3dc

File tree

5 files changed

+383
-13
lines changed

5 files changed

+383
-13
lines changed

tools/clang/unittests/HLSLExec/LongVectorOpTable.xml

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,23 @@
1212
<ParameterType Name="DataType">String</ParameterType>
1313
<ParameterType Name="OpTypeEnum">String</ParameterType>
1414
</ParameterTypes>
15+
<!-- LongVectorBinaryOpTypeTable DataType: bool -->
16+
<Row Name="ScalarAdd_bool">
17+
<Parameter Name="OpTypeEnum">BinaryOpType_ScalarAdd</Parameter>
18+
<Parameter Name="DataType">bool</Parameter>
19+
</Row>
20+
<Row Name="Add_bool">
21+
<Parameter Name="OpTypeEnum">BinaryOpType_Add</Parameter>
22+
<Parameter Name="DataType">bool</Parameter>
23+
</Row>
24+
<Row Name="ScalarSubtract_bool">
25+
<Parameter Name="OpTypeEnum">BinaryOpType_ScalarSubtract</Parameter>
26+
<Parameter Name="DataType">bool</Parameter>
27+
</Row>
28+
<Row Name="Subtract_bool">
29+
<Parameter Name="OpTypeEnum">BinaryOpType_Subtract</Parameter>
30+
<Parameter Name="DataType">bool</Parameter>
31+
</Row>
1532
<!-- LongVectorBinaryOpTypeTable DataType: int16 -->
1633
<Row Name="ScalarAdd_int16">
1734
<Parameter Name="OpTypeEnum">BinaryOpType_ScalarAdd</Parameter>
@@ -354,6 +371,63 @@
354371
<Parameter Name="OpTypeEnum">BinaryOpType_Max</Parameter>
355372
<Parameter Name="DataType">uint64</Parameter>
356373
</Row>
374+
<!-- LongVectorBinaryOpTypeTable DataType: float16 -->
375+
<Row Name="ScalarAdd_float16">
376+
<Parameter Name="OpTypeEnum">BinaryOpType_ScalarAdd</Parameter>
377+
<Parameter Name="DataType">float16</Parameter>
378+
</Row>
379+
<Row Name="Add_float16">
380+
<Parameter Name="OpTypeEnum">BinaryOpType_Add</Parameter>
381+
<Parameter Name="DataType">float16</Parameter>
382+
</Row>
383+
<Row Name="ScalarSubtract_float16">
384+
<Parameter Name="OpTypeEnum">BinaryOpType_ScalarSubtract</Parameter>
385+
<Parameter Name="DataType">float16</Parameter>
386+
</Row>
387+
<Row Name="Subtract_float16">
388+
<Parameter Name="OpTypeEnum">BinaryOpType_Subtract</Parameter>
389+
<Parameter Name="DataType">float16</Parameter>
390+
</Row>
391+
<Row Name="ScalarMultiply_float16">
392+
<Parameter Name="OpTypeEnum">BinaryOpType_ScalarMultiply</Parameter>
393+
<Parameter Name="DataType">float16</Parameter>
394+
</Row>
395+
<Row Name="Multiply_float16">
396+
<Parameter Name="OpTypeEnum">BinaryOpType_Multiply</Parameter>
397+
<Parameter Name="DataType">float16</Parameter>
398+
</Row>
399+
<Row Name="ScalarDivide_float16">
400+
<Parameter Name="OpTypeEnum">BinaryOpType_ScalarDivide</Parameter>
401+
<Parameter Name="DataType">float16</Parameter>
402+
</Row>
403+
<Row Name="Divide_float16">
404+
<Parameter Name="OpTypeEnum">BinaryOpType_Divide</Parameter>
405+
<Parameter Name="DataType">float16</Parameter>
406+
</Row>
407+
<Row Name="ScalarModulus_float16">
408+
<Parameter Name="OpTypeEnum">BinaryOpType_ScalarModulus</Parameter>
409+
<Parameter Name="DataType">float16</Parameter>
410+
</Row>
411+
<Row Name="Modulus_float16">
412+
<Parameter Name="OpTypeEnum">BinaryOpType_Modulus</Parameter>
413+
<Parameter Name="DataType">float16</Parameter>
414+
</Row>
415+
<Row Name="ScalarMin_float16">
416+
<Parameter Name="OpTypeEnum">BinaryOpType_ScalarMin</Parameter>
417+
<Parameter Name="DataType">float16</Parameter>
418+
</Row>
419+
<Row Name="Min_float16">
420+
<Parameter Name="OpTypeEnum">BinaryOpType_Min</Parameter>
421+
<Parameter Name="DataType">float16</Parameter>
422+
</Row>
423+
<Row Name="ScalarMax_float16">
424+
<Parameter Name="OpTypeEnum">BinaryOpType_ScalarMax</Parameter>
425+
<Parameter Name="DataType">float16</Parameter>
426+
</Row>
427+
<Row Name="Max_float16">
428+
<Parameter Name="OpTypeEnum">BinaryOpType_Max</Parameter>
429+
<Parameter Name="DataType">float16</Parameter>
430+
</Row>
357431
<!-- LongVectorBinaryOpTypeTable DataType: float32 -->
358432
<Row Name="ScalarAdd_float32">
359433
<Parameter Name="OpTypeEnum">BinaryOpType_ScalarAdd</Parameter>
@@ -471,6 +545,11 @@
471545
<ParameterType Name="DataType">String</ParameterType>
472546
<ParameterType Name="OpTypeEnum">String</ParameterType>
473547
</ParameterTypes>
548+
<!-- LongVectorUnaryOpTypeTable DataType: bool -->
549+
<Row Name="Initialize_bool">
550+
<Parameter Name="OpTypeEnum">UnaryOpType_Initialize</Parameter>
551+
<Parameter Name="DataType">bool</Parameter>
552+
</Row>
474553
<!-- LongVectorUnaryOpTypeTable DataType: int16 -->
475554
<Row Name="Initialize_int16">
476555
<Parameter Name="OpTypeEnum">UnaryOpType_Initialize</Parameter>
@@ -501,6 +580,11 @@
501580
<Parameter Name="OpTypeEnum">UnaryOpType_Initialize</Parameter>
502581
<Parameter Name="DataType">uint64</Parameter>
503582
</Row>
583+
<!-- LongVectorUnaryOpTypeTable DataType: float16 -->
584+
<Row Name="Initialize_float16">
585+
<Parameter Name="OpTypeEnum">UnaryOpType_Initialize</Parameter>
586+
<Parameter Name="DataType">float16</Parameter>
587+
</Row>
504588
<!-- LongVectorUnaryOpTypeTable DataType: float32 -->
505589
<Row Name="Initialize_float32">
506590
<Parameter Name="OpTypeEnum">UnaryOpType_Initialize</Parameter>

tools/clang/unittests/HLSLExec/LongVectorTestData.h

Lines changed: 225 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,204 @@
77
#include <string>
88
#include <vector>
99

10+
// A helper struct because C++ bools are 1 byte and HLSL bools are 4 bytes.
11+
// Take int32_t as a constuctor argument and convert it to bool when needed.
12+
// Comparisons cast to a bool because we only care if the bool representation is
13+
// true or false.
14+
struct HLSLBool_t {
15+
HLSLBool_t() : Val(0) {}
16+
HLSLBool_t(int32_t Val) : Val(Val) {}
17+
HLSLBool_t(bool Val) : Val(Val) {}
18+
HLSLBool_t(const HLSLBool_t &Other) : Val(Other.Val) {}
19+
20+
bool operator==(const HLSLBool_t &Other) const {
21+
return static_cast<bool>(Val) == static_cast<bool>(Other.Val);
22+
}
23+
24+
bool operator!=(const HLSLBool_t &Other) const {
25+
return static_cast<bool>(Val) != static_cast<bool>(Other.Val);
26+
}
27+
28+
bool operator<(const HLSLBool_t &Other) const { return Val < Other.Val; }
29+
30+
bool operator>(const HLSLBool_t &Other) const { return Val > Other.Val; }
31+
32+
bool operator<=(const HLSLBool_t &Other) const { return Val <= Other.Val; }
33+
34+
bool operator>=(const HLSLBool_t &Other) const { return Val >= Other.Val; }
35+
36+
HLSLBool_t operator*(const HLSLBool_t &Other) const {
37+
return HLSLBool_t(Val * Other.Val);
38+
}
39+
40+
HLSLBool_t operator+(const HLSLBool_t &Other) const {
41+
return HLSLBool_t(Val + Other.Val);
42+
}
43+
44+
HLSLBool_t operator-(const HLSLBool_t &Other) const {
45+
return HLSLBool_t(Val - Other.Val);
46+
}
47+
48+
HLSLBool_t operator/(const HLSLBool_t &Other) const {
49+
return HLSLBool_t(Val / Other.Val);
50+
}
51+
52+
HLSLBool_t operator%(const HLSLBool_t &Other) const {
53+
return HLSLBool_t(Val % Other.Val);
54+
}
55+
56+
// So we can construct std::wstrings using std::wostream
57+
friend std::wostream &operator<<(std::wostream &Os, const HLSLBool_t &Obj) {
58+
Os << static_cast<bool>(Obj.Val);
59+
return Os;
60+
}
61+
62+
// So we can construct std::strings using std::ostream
63+
friend std::ostream &operator<<(std::ostream &Os, const HLSLBool_t &Obj) {
64+
Os << static_cast<bool>(Obj.Val);
65+
return Os;
66+
}
67+
68+
int32_t Val = 0;
69+
};
70+
71+
// No native float16 type in C++ until C++23 . So we use uint16_t to represent
72+
// it. Simple little wrapping struct to help handle the right behavior.
73+
struct HLSLHalf_t {
74+
HLSLHalf_t() : Val(0) {}
75+
HLSLHalf_t(DirectX::PackedVector::HALF Val) : Val(Val) {}
76+
HLSLHalf_t(const HLSLHalf_t &Other) : Val(Other.Val) {}
77+
HLSLHalf_t(const float F) {
78+
Val = DirectX::PackedVector::XMConvertFloatToHalf(F);
79+
}
80+
HLSLHalf_t(const double D) {
81+
float F = 0.0f;
82+
// We wrap '::max' in () to prevent it from being expanded as a
83+
// macro by the Windows SDK.
84+
if (D >= (std::numeric_limits<double>::max)())
85+
F = (std::numeric_limits<float>::max)();
86+
else if (D <= std::numeric_limits<double>::lowest())
87+
F = std::numeric_limits<float>::lowest();
88+
else
89+
F = static_cast<float>(D);
90+
91+
Val = DirectX::PackedVector::XMConvertFloatToHalf(F);
92+
}
93+
HLSLHalf_t(const int I) {
94+
VERIFY_IS_TRUE(I == 0, L"HLSLHalf_t constructor with int override only "
95+
L"meant for cases when initializing to 0.");
96+
const float F = static_cast<float>(I);
97+
Val = DirectX::PackedVector::XMConvertFloatToHalf(F);
98+
}
99+
100+
// Implicit conversion to float for use with things like std::acos, std::tan,
101+
// etc
102+
operator float() const {
103+
return DirectX::PackedVector::XMConvertHalfToFloat(Val);
104+
}
105+
106+
bool operator==(const HLSLHalf_t &Other) const {
107+
// Convert to floats to properly handle the '0 == -0' case which must
108+
// compare to true but have different uint16_t values.
109+
// That is, 0 == -0 is true. We store Val as a uint16_t.
110+
const float A = DirectX::PackedVector::XMConvertHalfToFloat(Val);
111+
const float B = DirectX::PackedVector::XMConvertHalfToFloat(Other.Val);
112+
return A == B;
113+
}
114+
115+
bool operator<(const HLSLHalf_t &Other) const {
116+
return DirectX::PackedVector::XMConvertHalfToFloat(Val) <
117+
DirectX::PackedVector::XMConvertHalfToFloat(Other.Val);
118+
}
119+
120+
bool operator>(const HLSLHalf_t &Other) const {
121+
return DirectX::PackedVector::XMConvertHalfToFloat(Val) >
122+
DirectX::PackedVector::XMConvertHalfToFloat(Other.Val);
123+
}
124+
125+
// Used by tolerance checks in the tests.
126+
bool operator>(float F) const {
127+
const float A = DirectX::PackedVector::XMConvertHalfToFloat(Val);
128+
return A > F;
129+
}
130+
131+
bool operator<(float F) const {
132+
const float A = DirectX::PackedVector::XMConvertHalfToFloat(Val);
133+
return A < F;
134+
}
135+
136+
bool operator<=(const HLSLHalf_t &Other) const {
137+
return DirectX::PackedVector::XMConvertHalfToFloat(Val) <=
138+
DirectX::PackedVector::XMConvertHalfToFloat(Other.Val);
139+
}
140+
141+
bool operator>=(const HLSLHalf_t &Other) const {
142+
return DirectX::PackedVector::XMConvertHalfToFloat(Val) >=
143+
DirectX::PackedVector::XMConvertHalfToFloat(Other.Val);
144+
}
145+
146+
bool operator!=(const HLSLHalf_t &Other) const { return Val != Other.Val; }
147+
148+
HLSLHalf_t operator*(const HLSLHalf_t &Other) const {
149+
const float A = DirectX::PackedVector::XMConvertHalfToFloat(Val);
150+
const float B = DirectX::PackedVector::XMConvertHalfToFloat(Other.Val);
151+
return HLSLHalf_t(DirectX::PackedVector::XMConvertFloatToHalf(A * B));
152+
}
153+
154+
HLSLHalf_t operator+(const HLSLHalf_t &Other) const {
155+
const float A = DirectX::PackedVector::XMConvertHalfToFloat(Val);
156+
const float B = DirectX::PackedVector::XMConvertHalfToFloat(Other.Val);
157+
return HLSLHalf_t(DirectX::PackedVector::XMConvertFloatToHalf(A + B));
158+
}
159+
160+
HLSLHalf_t operator-(const HLSLHalf_t &Other) const {
161+
const float A = DirectX::PackedVector::XMConvertHalfToFloat(Val);
162+
const float B = DirectX::PackedVector::XMConvertHalfToFloat(Other.Val);
163+
return HLSLHalf_t(DirectX::PackedVector::XMConvertFloatToHalf(A - B));
164+
}
165+
166+
HLSLHalf_t operator/(const HLSLHalf_t &Other) const {
167+
const float A = DirectX::PackedVector::XMConvertHalfToFloat(Val);
168+
const float B = DirectX::PackedVector::XMConvertHalfToFloat(Other.Val);
169+
return HLSLHalf_t(DirectX::PackedVector::XMConvertFloatToHalf(A / B));
170+
}
171+
172+
HLSLHalf_t operator%(const HLSLHalf_t &Other) const {
173+
const float A = DirectX::PackedVector::XMConvertHalfToFloat(Val);
174+
const float B = DirectX::PackedVector::XMConvertHalfToFloat(Other.Val);
175+
const float C = std::fmod(A, B);
176+
return HLSLHalf_t(DirectX::PackedVector::XMConvertFloatToHalf(C));
177+
}
178+
179+
// So we can construct std::wstrings using std::wostream
180+
friend std::wostream &operator<<(std::wostream &Os, const HLSLHalf_t &Obj) {
181+
Os << DirectX::PackedVector::XMConvertHalfToFloat(Obj.Val);
182+
return Os;
183+
}
184+
185+
// So we can construct std::wstrings using std::wostream
186+
friend std::ostream &operator<<(std::ostream &Os, const HLSLHalf_t &Obj) {
187+
Os << DirectX::PackedVector::XMConvertHalfToFloat(Obj.Val);
188+
return Os;
189+
}
190+
191+
// HALF is an alias to uint16_t
192+
DirectX::PackedVector::HALF Val = 0;
193+
};
194+
10195
template <typename T> struct LongVectorTestData {
11196
static const std::map<std::wstring, std::vector<T>> Data;
12197
};
13198

199+
template <> struct LongVectorTestData<HLSLBool_t> {
200+
inline static const std::map<std::wstring, std::vector<HLSLBool_t>> Data = {
201+
{L"DefaultInputValueSet1",
202+
{false, true, false, false, false, false, true, true, true, true}},
203+
{L"DefaultInputValueSet2",
204+
{true, false, false, false, false, true, true, true, false, false}},
205+
};
206+
};
207+
14208
template <> struct LongVectorTestData<int16_t> {
15209
inline static const std::map<std::wstring, std::vector<int16_t>> Data = {
16210
{L"DefaultInputValueSet1", {-6, 1, 7, 3, 8, 4, -3, 8, 8, -2}},
@@ -53,12 +247,36 @@ template <> struct LongVectorTestData<uint64_t> {
53247
};
54248
};
55249

250+
template <> struct LongVectorTestData<HLSLHalf_t> {
251+
inline static const std::map<std::wstring, std::vector<HLSLHalf_t>> Data = {
252+
{L"DefaultInputValueSet1",
253+
{-1.0, -1.0, 1.0, -0.01, 1.0, -0.01, 1.0, -0.01, 1.0, -0.01}},
254+
{L"DefaultInputValueSet2",
255+
{1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0}},
256+
{L"DefaultClampArgs", {-1.0, 1.0}}, // Min, Max values for clamp
257+
// Range [ -pi/2, pi/2]
258+
{L"TrigonometricInputValueSet_RangeHalfPi",
259+
{-1.073, 0.044, -1.047, 0.313, 1.447, -0.865, 1.364, -0.715, -0.800,
260+
0.541}},
261+
{L"TrigonometricInputValueSet_RangeOne",
262+
{0.331, 0.727, -0.957, 0.677, -0.025, 0.495, 0.855, -0.673, -0.678,
263+
-0.905}},
264+
};
265+
};
266+
56267
template <> struct LongVectorTestData<float> {
57268
inline static const std::map<std::wstring, std::vector<float>> Data = {
58269
{L"DefaultInputValueSet1",
59270
{1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0}},
60271
{L"DefaultInputValueSet2",
61272
{1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0}},
273+
// Range [ -pi/2, pi/2]
274+
{L"TrigonometricInputValueSet_RangeHalfPi",
275+
{0.315f, -0.316f, 1.409f, -0.09f, -1.569f, 1.302f, -0.326f, 0.781f,
276+
-1.235f, 0.623f}},
277+
{L"TrigonometricInputValueSet_RangeOne",
278+
{0.727f, 0.331f, -0.957f, 0.677f, -0.025f, 0.495f, 0.855f, -0.673f,
279+
-0.678f, -0.905f}},
62280
};
63281
};
64282

@@ -68,7 +286,13 @@ template <> struct LongVectorTestData<double> {
68286
{1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0}},
69287
{L"DefaultInputValueSet2",
70288
{1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0}},
71-
};
289+
// Range [ -pi/2, pi/2]
290+
{L"TrigonometricInputValueSet_RangeHalfPi",
291+
{0.807, 0.605, 1.317, 0.188, 1.566, -1.507, 0.67, -1.553, 0.194,
292+
-0.883}},
293+
{L"TrigonometricInputValueSet_RangeOne",
294+
{0.331, 0.277, -0.957, 0.677, -0.025, 0.495, 0.855, -0.673, -0.678,
295+
-0.905}}};
72296
};
73297

74298
#endif // LONGVECTORTESTDATA_H

tools/clang/unittests/HLSLExec/LongVectors.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,9 @@ void LongVector::OpTest::dispatchTestByDataType(
110110
TableParameterHandler &Handler) {
111111
using namespace WEX::Common;
112112

113-
if (DataType == L"int16")
113+
if (DataType == L"bool")
114+
dispatchTestByVectorSize<HLSLBool_t>(OpType, Handler);
115+
else if (DataType == L"int16")
114116
dispatchTestByVectorSize<int16_t>(OpType, Handler);
115117
else if (DataType == L"int32")
116118
dispatchTestByVectorSize<int32_t>(OpType, Handler);
@@ -122,6 +124,8 @@ void LongVector::OpTest::dispatchTestByDataType(
122124
dispatchTestByVectorSize<uint32_t>(OpType, Handler);
123125
else if (DataType == L"uint64")
124126
dispatchTestByVectorSize<uint64_t>(OpType, Handler);
127+
else if (DataType == L"float16")
128+
dispatchTestByVectorSize<HLSLHalf_t>(OpType, Handler);
125129
else if (DataType == L"float32")
126130
dispatchTestByVectorSize<float>(OpType, Handler);
127131
else if (DataType == L"float64")

0 commit comments

Comments
 (0)