forked from NVIDIA/cutlass
-
Notifications
You must be signed in to change notification settings - Fork 21
/
bfloat16.h
690 lines (576 loc) · 19 KB
/
bfloat16.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Defines a proxy class for storing non-standard 16-bit floating point values with
8 bits of exponent and 7 bits of mantissa.
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include "cutlass/floating_point_nvrtc.h"
#else
#include <cmath>
#include <limits>
#include <cstdint>
#include <cstring>
#endif
#if !defined(CUTLASS_ENABLE_SYCL)
#include <cuda_bf16.h>
#endif
#include "cutlass/cutlass.h"
#include "cutlass/platform/platform.h"
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Floating-point type with 8 bits of exponent and 7 bits of mantissa.
///
/// The 16-bit encoding held in `storage` is decoded by the accessors below as: bit 15 = sign,
/// bits 14..7 = biased exponent (bias 127), bits 6..0 = mantissa. Widening to float places these
/// 16 bits in the upper half of a 32-bit float pattern (see `operator float`).
struct alignas(2) bfloat16_t {

  //
  // Data members
  //

  /// Storage type
  uint16_t storage;

  //
  // Methods
  //

  /// Constructs from an unsigned short; `x` is used verbatim as the raw encoding (no conversion).
  CUTLASS_HOST_DEVICE
  static bfloat16_t bitcast(uint16_t x) {
    bfloat16_t h;
    h.storage = x;
    return h;
  }

private:

  /// Tag type selecting the shared 32-bit integer conversion path.
  struct from_32_bit_integer_t {};
  static constexpr from_32_bit_integer_t from_32_bit_integer{};

  /// Shared implementation of the `int` / `uint32_t` constructors.
  ///
  /// The integer is converted to float and then handed to the float constructor so that the
  /// result is rounded to nearest, as the public integer constructors document. (Taking the
  /// upper 16 bits of the float pattern directly would truncate the discarded mantissa bits
  /// instead of rounding them.)
  template<class T>
  CUTLASS_HOST_DEVICE
  explicit bfloat16_t(from_32_bit_integer_t, T x)
    : bfloat16_t(static_cast<float>(x)) {
    static_assert(cutlass::platform::is_integral<T>::value && sizeof(T) == 4, "Requires 32-bit integer");
  }

public:

  /// Default constructor
  bfloat16_t() = default;

#if !defined(CUTLASS_ENABLE_SYCL)
  /// Reinterpret cast from CUDA's __nv_bfloat16 type
  CUTLASS_HOST_DEVICE
  explicit bfloat16_t(__nv_bfloat16 const & x) {
    #if defined(__CUDA_ARCH__)
    storage = reinterpret_cast<uint16_t const &>(x);
    #else
    // Host path: go through the raw struct and memcpy its 16-bit payload.
    __nv_bfloat16_raw raw(x);
    std::memcpy(&storage, &raw.x, sizeof(storage));
    #endif
  }
#endif

  /// Floating-point conversion - round toward nearest
  ///
  /// Device code for __CUDA_ARCH__ >= 800 built with CUDA 11 or newer uses the
  /// `cvt.rn.bf16.f32` PTX instruction; SYCL builds use the oneAPI bfloat16 type; every other
  /// build uses the software path below.
  CUTLASS_HOST_DEVICE
  explicit bfloat16_t(float x) {

    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)

    asm("cvt.rn.bf16.f32 %0, %1;\n" : "=h"(storage) : "f"(x));

    #elif defined(CUTLASS_ENABLE_SYCL)

    storage = sycl::ext::oneapi::detail::bfloat16ToBits(sycl::ext::oneapi::bfloat16(x));

    #else
    uint32_t bits;

    #if defined(__CUDA_ARCH__)
    bits = reinterpret_cast<uint32_t &>(x);
    #else
    std::memcpy(&bits, &x, sizeof(bits));
    #endif

    if ((bits & 0x7f800000) != 0x7f800000) {
      // Exponent field is not all ones: round to nearest, ties to even.
      // Bit 16 is the lowest bit that is kept, bit 15 is the first discarded bit, and bits 14..0
      // are the remaining discarded bits.
      bool mantissa_bit = ((bits & (1 << 16)) != 0);
      bool round_bit = ((bits & (1 << 15)) != 0);
      bool sticky_bit = ((bits & ((1 << 15) - 1)) != 0);

      if ((round_bit && sticky_bit) || (round_bit && mantissa_bit)) {
        bits += uint32_t(1 << 16);
      }
    }
    else if (bits & ~0xff800000) {
      // Exponent all ones with a non-zero mantissa (NaN): replace with 0x7fffffff, whose upper
      // half is 0x7fff. An all-ones exponent with a zero mantissa falls through unchanged.
      bits = 0x7fffffff;
    }

    storage = uint16_t((bits >> 16) & 0xffff);
    #endif
  }

  /// Floating-point conversion - round toward nearest
  ///
  /// The double is first narrowed to float and then converted by the float constructor.
  CUTLASS_HOST_DEVICE
  explicit bfloat16_t(double x): bfloat16_t(float(x)) {

  }

  /// Integer conversion - round toward nearest
  CUTLASS_HOST_DEVICE
  explicit bfloat16_t(int x) : bfloat16_t(from_32_bit_integer, x) {}

  /// Unsigned integer conversion - round toward nearest
  CUTLASS_HOST_DEVICE
  explicit bfloat16_t(uint32_t x) : bfloat16_t(from_32_bit_integer, x) {}

  /// Converts to float
  ///
  /// The 16 stored bits become the upper half of the float pattern; the lower half is zero.
  CUTLASS_HOST_DEVICE
  operator float() const {
    unsigned bits = (unsigned(storage) << 16);
    #if defined(__CUDA_ARCH__)
    return reinterpret_cast<float const &>(bits);
    #else
    float flt;
    std::memcpy(&flt, &bits, sizeof(flt));
    return flt;
    #endif
  }

  /// Converts to double
  CUTLASS_HOST_DEVICE
  explicit operator double() const {
    return double(float(*this));
  }

  /// Converts to int
  CUTLASS_HOST_DEVICE
  explicit operator int() const {
    return int(float(*this));
  }

  /// Casts to bool; true when the value, widened to float, compares unequal to 0.0f
  CUTLASS_HOST_DEVICE
  explicit operator bool() const {
    return (float(*this) != 0.0f);
  }

#if !defined(CUTLASS_ENABLE_SYCL)
  /// Bitcasts to CUDA's bf16 type
  CUTLASS_DEVICE
  __nv_bfloat16 to_nv_bfloat16() const {
    return reinterpret_cast<__nv_bfloat16 const &>(storage);
  }
#endif

  /// Obtains raw bits
  CUTLASS_HOST_DEVICE
  uint16_t raw() const {
    return storage;
  }

  /// Returns the sign bit
  CUTLASS_HOST_DEVICE
  bool signbit() const {
    return ((raw() & 0x8000) != 0);
  }

  /// Returns the biased exponent
  CUTLASS_HOST_DEVICE
  int exponent_biased() const {
    return int((raw() >> 7) & 0x0ff);
  }

  /// Returns the unbiased exponent
  CUTLASS_HOST_DEVICE
  int exponent() const {
    return exponent_biased() - 127;
  }

  /// Returns the mantissa
  CUTLASS_HOST_DEVICE
  int mantissa() const {
    return int(raw() & 0x7f);
  }
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Reports whether bit 15 (the sign bit) of the raw encoding of `h` is set.
CUTLASS_HOST_DEVICE
bool signbit(cutlass::bfloat16_t const& h) {
  uint16_t const sign_mask = 0x8000;
  return (h.raw() & sign_mask) != 0;
}
/// Returns `h` with its sign bit cleared; the remaining 15 bits are passed through unchanged.
CUTLASS_HOST_DEVICE
cutlass::bfloat16_t abs(cutlass::bfloat16_t const& h) {
  uint16_t const magnitude_bits = static_cast<uint16_t>(h.raw() & 0x7fff);
  return cutlass::bfloat16_t::bitcast(magnitude_bits);
}
/// True when the exponent field of `h` is all ones and its mantissa is non-zero.
CUTLASS_HOST_DEVICE
bool isnan(cutlass::bfloat16_t const& h) {
  if (h.exponent_biased() != 0x0ff) {
    return false;
  }
  return h.mantissa() != 0;
}
/// True when the exponent field of `h` is anything other than all ones.
CUTLASS_HOST_DEVICE
bool isfinite(cutlass::bfloat16_t const& h) {
  bool const exponent_all_ones = (h.exponent_biased() == 0x0ff);
  return !exponent_all_ones;
}
/// Returns a NaN value. The string argument is accepted but not inspected.
CUTLASS_HOST_DEVICE
cutlass::bfloat16_t nan_bf16(const char*) {
  // NVIDIA canonical NaN
  uint16_t const canonical_nan_bits = 0x7fff;
  return cutlass::bfloat16_t::bitcast(canonical_nan_bits);
}
/// True when the exponent field of `h` is all ones and its mantissa is zero.
CUTLASS_HOST_DEVICE
bool isinf(cutlass::bfloat16_t const& h) {
  if (h.exponent_biased() != 0x0ff) {
    return false;
  }
  return h.mantissa() == 0;
}
/// True when the exponent field of `h` is neither zero nor all ones.
CUTLASS_HOST_DEVICE
bool isnormal(cutlass::bfloat16_t const& h) {
  int const biased = h.exponent_biased();
  return (biased != 0) && (biased != 0x0ff);
}
/// Classifies `h` from its exponent and mantissa fields, returning one of the FP_* constants:
///   exponent all ones : FP_NAN if the mantissa is non-zero, else FP_INFINITE
///   exponent zero     : FP_SUBNORMAL if the mantissa is non-zero, else FP_ZERO
///   anything else     : FP_NORMAL
CUTLASS_HOST_DEVICE
int fpclassify(cutlass::bfloat16_t const& h) {
  int const biased_exponent = h.exponent_biased();
  bool const mantissa_nonzero = (h.mantissa() != 0);

  if (biased_exponent == 0x0ff) {
    return mantissa_nonzero ? FP_NAN : FP_INFINITE;
  }

  if (biased_exponent == 0) {
    return mantissa_nonzero ? FP_SUBNORMAL : FP_ZERO;
  }

  return FP_NORMAL;
}
/// Square root: widens `h` to float, takes the root in float, and converts the result back.
/// NVRTC builds call `sqrtf`; all other builds call `std::sqrt`.
CUTLASS_HOST_DEVICE
cutlass::bfloat16_t sqrt(cutlass::bfloat16_t const& h) {
  float const operand = float(h);
#if defined(__CUDACC_RTC__)
  float const root = sqrtf(operand);
#else
  float const root = std::sqrt(operand);
#endif
  return cutlass::bfloat16_t(root);
}
/// Returns a value whose low 15 bits come from `a` and whose sign bit comes from `b`.
CUTLASS_HOST_DEVICE
bfloat16_t copysign(bfloat16_t const& a, bfloat16_t const& b) {
  // raw() exposes the 16-bit storage word, which is the object's only data member.
  uint16_t const magnitude = static_cast<uint16_t>(a.raw() & 0x7fff);
  uint16_t const sign = static_cast<uint16_t>(b.raw() & 0x8000);
  return bfloat16_t::bitcast(static_cast<uint16_t>(magnitude | sign));
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Standard Library operations and definitions
//
///////////////////////////////////////////////////////////////////////////////////////////////////
#if !defined(__CUDACC_RTC__)
namespace std {
/// Numeric limits for cutlass::bfloat16_t (specialization of std::numeric_limits).
///
/// Each function returns a fixed bit pattern; the field breakdown quoted in the comments follows
/// the layout sign(1) | exponent(8) | mantissa(7).
///
/// NOTE(review): min(), epsilon() and digits do not appear to follow the usual
/// std::numeric_limits definitions for floating-point types (smallest positive normal value,
/// gap between 1 and the next representable value, and significand digits including the
/// implicit bit). Confirm whether callers rely on the current values before changing them.
template <>
struct numeric_limits<cutlass::bfloat16_t> {
  static constexpr bool is_specialized = true;
  static constexpr bool is_signed = true;
  static constexpr bool is_integer = false;
  static constexpr bool is_exact = false;
  static constexpr bool has_infinity = true;
  static constexpr bool has_quiet_NaN = true;
  static constexpr bool has_signaling_NaN = false;
  static constexpr std::float_denorm_style has_denorm = std::denorm_present;
  static constexpr bool has_denorm_loss = true;
  static constexpr std::float_round_style round_style = std::round_to_nearest;
  static constexpr bool is_iec559 = false;
  static constexpr bool is_bounded = true;
  static constexpr bool is_modulo = false;
  static constexpr int digits = 7;

  /// Least positive value: pattern 0x0001 (exponent field 0, mantissa 1), the same pattern
  /// returned by denorm_min().
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t min() {
    return cutlass::bfloat16_t::bitcast(0x0001);
  }

  /// Minimum finite value: pattern 0xff7f (sign set, exponent field 0xfe, mantissa 0x7f).
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t lowest() {
    return cutlass::bfloat16_t::bitcast(0xff7f);
  }

  /// Maximum finite value: pattern 0x7f7f (sign clear, exponent field 0xfe, mantissa 0x7f).
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t max() {
    return cutlass::bfloat16_t::bitcast(0x7f7f);
  }

  /// Returns pattern 0x1000 (exponent field 0x20, mantissa 0).
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t epsilon() {
    return cutlass::bfloat16_t::bitcast(0x1000);
  }

  /// Returns 0.5 converted to bfloat16_t.
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t round_error() {
    return cutlass::bfloat16_t(0.5f);
  }

  /// Positive infinity: pattern 0x7f80 (exponent field all ones, mantissa 0).
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t infinity() {
    return cutlass::bfloat16_t::bitcast(0x7f80);
  }

  /// Quiet NaN: pattern 0x7fff (exponent field all ones, mantissa all ones).
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t quiet_NaN() {
    return cutlass::bfloat16_t::bitcast(0x7fff);
  }

  /// Returns the same pattern as quiet_NaN() (0x7fff); has_signaling_NaN is false.
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t signaling_NaN() {
    return cutlass::bfloat16_t::bitcast(0x7fff);
  }

  /// Smallest positive subnormal value: pattern 0x0001 (exponent field 0, mantissa 1).
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t denorm_min() {
    return cutlass::bfloat16_t::bitcast(0x0001);
  }
};
} // namespace std
#endif
namespace cutlass {
namespace platform {
/// Forward Declaration
template <class T>
struct numeric_limits;
/// Numeric limits for cutlass::bfloat16_t (specialization of cutlass::platform::numeric_limits).
///
/// Each function returns a fixed bit pattern; the field breakdown quoted in the comments follows
/// the layout sign(1) | exponent(8) | mantissa(7). The two members typed with std:: enums are
/// omitted when compiling with NVRTC.
///
/// NOTE(review): min(), epsilon() and digits do not appear to follow the usual
/// std::numeric_limits definitions for floating-point types (smallest positive normal value,
/// gap between 1 and the next representable value, and significand digits including the
/// implicit bit). Confirm whether callers rely on the current values before changing them.
template <>
struct numeric_limits<cutlass::bfloat16_t> {
  static constexpr bool is_specialized = true;
  static constexpr bool is_signed = true;
  static constexpr bool is_integer = false;
  static constexpr bool is_exact = false;
  static constexpr bool has_infinity = true;
  static constexpr bool has_quiet_NaN = true;
  static constexpr bool has_signaling_NaN = false;
#if !defined(__CUDACC_RTC__)
  static constexpr std::float_denorm_style has_denorm = std::denorm_present;
#endif
  static constexpr bool has_denorm_loss = true;
#if !defined(__CUDACC_RTC__)
  static constexpr std::float_round_style round_style = std::round_to_nearest;
#endif
  static constexpr bool is_iec559 = false;
  static constexpr bool is_bounded = true;
  static constexpr bool is_modulo = false;
  static constexpr int digits = 7;

  /// Least positive value: pattern 0x0001 (exponent field 0, mantissa 1), the same pattern
  /// returned by denorm_min().
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t min() {
    return cutlass::bfloat16_t::bitcast(0x0001);
  }

  /// Minimum finite value: pattern 0xff7f (sign set, exponent field 0xfe, mantissa 0x7f).
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t lowest() {
    return cutlass::bfloat16_t::bitcast(0xff7f);
  }

  /// Maximum finite value: pattern 0x7f7f (sign clear, exponent field 0xfe, mantissa 0x7f).
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t max() {
    return cutlass::bfloat16_t::bitcast(0x7f7f);
  }

  /// Returns pattern 0x1000 (exponent field 0x20, mantissa 0).
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t epsilon() {
    return cutlass::bfloat16_t::bitcast(0x1000);
  }

  /// Returns 0.5 converted to bfloat16_t.
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t round_error() {
    return cutlass::bfloat16_t(0.5f);
  }

  /// Positive infinity: pattern 0x7f80 (exponent field all ones, mantissa 0).
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t infinity() {
    return cutlass::bfloat16_t::bitcast(0x7f80);
  }

  /// Quiet NaN: pattern 0x7fff (exponent field all ones, mantissa all ones).
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t quiet_NaN() {
    return cutlass::bfloat16_t::bitcast(0x7fff);
  }

  /// Returns the same pattern as quiet_NaN() (0x7fff); has_signaling_NaN is false.
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t signaling_NaN() {
    return cutlass::bfloat16_t::bitcast(0x7fff);
  }

  /// Smallest positive subnormal value: pattern 0x0001 (exponent field 0, mantissa 1).
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t denorm_min() {
    return cutlass::bfloat16_t::bitcast(0x0001);
  }
};
} // namespace platform
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Arithmetic operators
//
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Equality. Device code for __CUDA_ARCH__ >= 800 uses the `__heq` intrinsic; every other build
/// compares the operands after widening them to float.
CUTLASS_HOST_DEVICE
bool operator==(bfloat16_t const& lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  __nv_bfloat16 const a = lhs.to_nv_bfloat16();
  __nv_bfloat16 const b = rhs.to_nv_bfloat16();
  return __heq(a, b);
#else
  float const a = float(lhs);
  float const b = float(rhs);
  return a == b;
#endif
}
/// Inequality. Device code for __CUDA_ARCH__ >= 800 uses the `__hne` intrinsic; every other build
/// compares the operands after widening them to float.
CUTLASS_HOST_DEVICE
bool operator!=(bfloat16_t const& lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  __nv_bfloat16 const a = lhs.to_nv_bfloat16();
  __nv_bfloat16 const b = rhs.to_nv_bfloat16();
  return __hne(a, b);
#else
  float const a = float(lhs);
  float const b = float(rhs);
  return a != b;
#endif
}
/// Less-than. Device code for __CUDA_ARCH__ >= 800 uses the `__hlt` intrinsic; every other build
/// compares the operands after widening them to float.
CUTLASS_HOST_DEVICE
bool operator<(bfloat16_t const& lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  __nv_bfloat16 const a = lhs.to_nv_bfloat16();
  __nv_bfloat16 const b = rhs.to_nv_bfloat16();
  return __hlt(a, b);
#else
  float const a = float(lhs);
  float const b = float(rhs);
  return a < b;
#endif
}
/// Less-than-or-equal. Device code for __CUDA_ARCH__ >= 800 uses the `__hle` intrinsic; every
/// other build compares the operands after widening them to float.
CUTLASS_HOST_DEVICE
bool operator<=(bfloat16_t const& lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  __nv_bfloat16 const a = lhs.to_nv_bfloat16();
  __nv_bfloat16 const b = rhs.to_nv_bfloat16();
  return __hle(a, b);
#else
  float const a = float(lhs);
  float const b = float(rhs);
  return a <= b;
#endif
}
/// Greater-than. Device code for __CUDA_ARCH__ >= 800 uses the `__hgt` intrinsic; every other
/// build compares the operands after widening them to float.
CUTLASS_HOST_DEVICE
bool operator>(bfloat16_t const& lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  __nv_bfloat16 const a = lhs.to_nv_bfloat16();
  __nv_bfloat16 const b = rhs.to_nv_bfloat16();
  return __hgt(a, b);
#else
  float const a = float(lhs);
  float const b = float(rhs);
  return a > b;
#endif
}
/// Greater-than-or-equal. Device code for __CUDA_ARCH__ >= 800 uses the `__hge` intrinsic; every
/// other build compares the operands after widening them to float.
CUTLASS_HOST_DEVICE
bool operator>=(bfloat16_t const& lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  __nv_bfloat16 const a = lhs.to_nv_bfloat16();
  __nv_bfloat16 const b = rhs.to_nv_bfloat16();
  return __hge(a, b);
#else
  float const a = float(lhs);
  float const b = float(rhs);
  return a >= b;
#endif
}
/// Addition. Device code for __CUDA_ARCH__ >= 800 uses the `__hadd` intrinsic; every other build
/// adds in float and converts the sum back to bfloat16_t.
CUTLASS_HOST_DEVICE
bfloat16_t operator+(bfloat16_t const& lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  __nv_bfloat16 const sum = __hadd(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16());
  return bfloat16_t(sum);
#else
  float const sum = float(lhs) + float(rhs);
  return bfloat16_t(sum);
#endif
}
/// Unary negation. Device code for __CUDA_ARCH__ >= 800 uses the `__hneg` intrinsic; every other
/// build negates in float and converts the result back to bfloat16_t.
CUTLASS_HOST_DEVICE
bfloat16_t operator-(bfloat16_t const& lhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  __nv_bfloat16 const negated = __hneg(lhs.to_nv_bfloat16());
  return bfloat16_t(negated);
#else
  float const negated = -float(lhs);
  return bfloat16_t(negated);
#endif
}
/// Subtraction. Device code for __CUDA_ARCH__ >= 800 uses the `__hsub` intrinsic; every other
/// build subtracts in float and converts the difference back to bfloat16_t.
CUTLASS_HOST_DEVICE
bfloat16_t operator-(bfloat16_t const& lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  __nv_bfloat16 const difference = __hsub(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16());
  return bfloat16_t(difference);
#else
  float const difference = float(lhs) - float(rhs);
  return bfloat16_t(difference);
#endif
}
/// Multiplication. Device code for __CUDA_ARCH__ >= 800 uses the `__hmul` intrinsic; every other
/// build multiplies in float and converts the product back to bfloat16_t.
CUTLASS_HOST_DEVICE
bfloat16_t operator*(bfloat16_t const& lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  __nv_bfloat16 const product = __hmul(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16());
  return bfloat16_t(product);
#else
  float const product = float(lhs) * float(rhs);
  return bfloat16_t(product);
#endif
}
/// Division. Device code for __CUDA_ARCH__ >= 800 uses the `__hdiv` intrinsic; every other build
/// divides in float and converts the quotient back to bfloat16_t.
CUTLASS_HOST_DEVICE
bfloat16_t operator/(bfloat16_t const& lhs, bfloat16_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  __nv_bfloat16 const quotient = __hdiv(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16());
  return bfloat16_t(quotient);
#else
  float const quotient = float(lhs) / float(rhs);
  return bfloat16_t(quotient);
#endif
}
/// In-place addition; stores `lhs + rhs` into `lhs` and returns a reference to it.
/// The binary operator+ selects between `__hadd` (device, __CUDA_ARCH__ >= 800) and float
/// arithmetic in the same way this operator previously did inline.
CUTLASS_HOST_DEVICE
bfloat16_t& operator+=(bfloat16_t & lhs, bfloat16_t const& rhs) {
  lhs = lhs + rhs;
  return lhs;
}
/// In-place subtraction; stores `lhs - rhs` into `lhs` and returns a reference to it.
/// The binary operator- selects between `__hsub` (device, __CUDA_ARCH__ >= 800) and float
/// arithmetic in the same way this operator previously did inline.
CUTLASS_HOST_DEVICE
bfloat16_t& operator-=(bfloat16_t & lhs, bfloat16_t const& rhs) {
  lhs = lhs - rhs;
  return lhs;
}
/// In-place multiplication; stores `lhs * rhs` into `lhs` and returns a reference to it.
/// The binary operator* selects between `__hmul` (device, __CUDA_ARCH__ >= 800) and float
/// arithmetic in the same way this operator previously did inline.
CUTLASS_HOST_DEVICE
bfloat16_t& operator*=(bfloat16_t & lhs, bfloat16_t const& rhs) {
  lhs = lhs * rhs;
  return lhs;
}
/// In-place division; stores `lhs / rhs` into `lhs` and returns a reference to it.
/// The binary operator/ selects between `__hdiv` (device, __CUDA_ARCH__ >= 800) and float
/// arithmetic in the same way this operator previously did inline.
CUTLASS_HOST_DEVICE
bfloat16_t& operator/=(bfloat16_t & lhs, bfloat16_t const& rhs) {
  lhs = lhs / rhs;
  return lhs;
}
/// Pre-increment: adds one to `lhs` and returns a reference to the updated value.
CUTLASS_HOST_DEVICE
bfloat16_t& operator++(bfloat16_t & lhs) {
  // 1.0f converts to bfloat16_t exactly, so adding it through operator+ gives `__hadd(lhs, 1)`
  // on device (__CUDA_ARCH__ >= 800) and `float(lhs) + 1.0f` everywhere else.
  lhs = lhs + bfloat16_t(1.0f);
  return lhs;
}
/// Pre-decrement: subtracts one from `lhs` and returns a reference to the updated value.
CUTLASS_HOST_DEVICE
bfloat16_t& operator--(bfloat16_t & lhs) {
  // 1.0f converts to bfloat16_t exactly, so subtracting it through operator- gives
  // `__hsub(lhs, 1)` on device (__CUDA_ARCH__ >= 800) and `float(lhs) - 1.0f` everywhere else.
  lhs = lhs - bfloat16_t(1.0f);
  return lhs;
}
/// Post-increment: adds one to `lhs` and returns the value it held before the increment.
CUTLASS_HOST_DEVICE
bfloat16_t operator++(bfloat16_t & lhs, int) {
  bfloat16_t const previous(lhs);
  // The pre-increment operator performs the same add-one update on both device and host paths.
  ++lhs;
  return previous;
}
/// Post-decrement: subtracts one from `lhs` and returns the value it held before the decrement.
CUTLASS_HOST_DEVICE
bfloat16_t operator--(bfloat16_t & lhs, int) {
  bfloat16_t const previous(lhs);
  // The pre-decrement operator performs the same subtract-one update on both device and host paths.
  --lhs;
  return previous;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// User-defined literals
//
/// Floating-point literal suffix: `1.5_bf16` yields a cutlass::bfloat16_t.
///
/// The literal arrives as long double, is narrowed to float, and is then converted by the
/// bfloat16_t float constructor.
///
/// Spelled without whitespace between `""` and the suffix: the spaced form is deprecated and
/// turns `_bf16` into an ordinary identifier, which is reserved at global scope.
CUTLASS_HOST_DEVICE
cutlass::bfloat16_t operator""_bf16(long double x) {
  return cutlass::bfloat16_t(static_cast<float>(x));
}
/// Integer literal suffix: `3_bf16` yields a cutlass::bfloat16_t.
///
/// The literal arrives as unsigned long long and is converted through float rather than being
/// narrowed to int first, so literals larger than INT_MAX keep their magnitude instead of
/// wrapping. The final step is the bfloat16_t float constructor.
///
/// Spelled without whitespace between `""` and the suffix: the spaced form is deprecated and
/// turns `_bf16` into an ordinary identifier, which is reserved at global scope.
CUTLASS_HOST_DEVICE
cutlass::bfloat16_t operator""_bf16(unsigned long long int x) {
  return cutlass::bfloat16_t(static_cast<float>(x));
}
/////////////////////////////////////////////////////////////////////////////////////////////////