Eigen  3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)
 
BFloat16.h
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef EIGEN_BFLOAT16_H
#define EIGEN_BFLOAT16_H

// IWYU pragma: private
#include "../../InternalHeaderCheck.h"

#if defined(EIGEN_HAS_HIP_BF16)
// When compiling with GPU support, the "hip_bfloat16" base class as well as
// some other routines are defined in the GPU compiler header files
// (hip_bfloat16.h), and they are not tagged constexpr.
// As a consequence, we get compile failures when compiling Eigen with
// GPU support. Hence the need to disable EIGEN_CONSTEXPR when building
// Eigen with GPU support.
#pragma push_macro("EIGEN_CONSTEXPR")
#undef EIGEN_CONSTEXPR
#define EIGEN_CONSTEXPR
#endif

#define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD)                                         \
  template <>                                                                                       \
  EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED PACKET_BF16 METHOD<PACKET_BF16>( \
      const PACKET_BF16& _x) {                                                                      \
    return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(_x)));                                              \
  }
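
// Illustrative note (added for exposition, not part of the original header):
// an architecture-specific PacketMath header is expected to invoke this macro
// once per unary packet op, e.g.
//   BF16_PACKET_FUNCTION(Packet4f, Packet4bf, psin)
// (packet type names here are placeholders). The expansion defines the bf16
// overload by widening the packet to float, calling the float implementation,
// and narrowing the result back to bfloat16.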

// Only use HIP GPU bf16 in kernels
#if defined(EIGEN_HAS_HIP_BF16) && defined(EIGEN_GPU_COMPILE_PHASE)
#define EIGEN_USE_HIP_BF16
#endif

namespace Eigen {

struct bfloat16;

namespace numext {
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(const uint16_t& src);

template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(const Eigen::bfloat16& src);
}  // namespace numext
namespace bfloat16_impl {

#if defined(EIGEN_USE_HIP_BF16)

struct __bfloat16_raw : public hip_bfloat16 {
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() {}
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(hip_bfloat16 hb) : hip_bfloat16(hb) {}
  explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(unsigned short raw) : hip_bfloat16(raw) {}
};

#else

// Make our own __bfloat16_raw definition.
struct __bfloat16_raw {
#if defined(EIGEN_HAS_HIP_BF16) && !defined(EIGEN_GPU_COMPILE_PHASE)
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() {}
#else
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() : value(0) {}
#endif
  explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(unsigned short raw) : value(raw) {}
  unsigned short value;
};

#endif  // defined(EIGEN_USE_HIP_BF16)

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value);
template <bool AssumeArgumentIsNormalOrInfinityOrZero>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff);
// Forward declarations of template specializations, to avoid Visual C++ 2019 errors, saying:
// > error C2908: explicit specialization; 'float_to_bfloat16_rtne' has already been instantiated
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff);
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff);
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h);

struct bfloat16_base : public __bfloat16_raw {
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base() {}
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base(const __bfloat16_raw& h) : __bfloat16_raw(h) {}
};

}  // namespace bfloat16_impl

// Class definition.
struct bfloat16 : public bfloat16_impl::bfloat16_base {
  typedef bfloat16_impl::__bfloat16_raw __bfloat16_raw;

  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16() {}

  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const __bfloat16_raw& h) : bfloat16_impl::bfloat16_base(h) {}

  explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(bool b)
      : bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}

  template <class T>
  explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(T val)
      : bfloat16_impl::bfloat16_base(
            bfloat16_impl::float_to_bfloat16_rtne<internal::is_integral<T>::value>(static_cast<float>(val))) {}

  explicit EIGEN_DEVICE_FUNC bfloat16(float f)
      : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(f)) {}

  // Following the convention of numpy, converting between complex and
  // float will lead to loss of imag value.
  template <typename RealScalar>
  explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const std::complex<RealScalar>& val)
      : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(static_cast<float>(val.real()))) {}

  EIGEN_DEVICE_FUNC operator float() const {  // NOLINT: Allow implicit conversion to float, because it is lossless.
    return bfloat16_impl::bfloat16_to_float(*this);
  }
};
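
// Illustrative example (added for exposition, not part of the original header):
// construction from float rounds to the nearest bfloat16, and the implicit
// conversion back to float returns the stored value exactly:
//   Eigen::bfloat16 x(0.3f);    // 0.3f rounds to 0.30078125 (7 mantissa bits)
//   float y = x;                // y == 0.30078125f exactly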

// TODO(majnemer): Get rid of this once we can rely on C++17 inline variables to
// solve the ODR issue.
namespace bfloat16_impl {
template <typename = void>
struct numeric_limits_bfloat16_impl {
  static EIGEN_CONSTEXPR const bool is_specialized = true;
  static EIGEN_CONSTEXPR const bool is_signed = true;
  static EIGEN_CONSTEXPR const bool is_integer = false;
  static EIGEN_CONSTEXPR const bool is_exact = false;
  static EIGEN_CONSTEXPR const bool has_infinity = true;
  static EIGEN_CONSTEXPR const bool has_quiet_NaN = true;
  static EIGEN_CONSTEXPR const bool has_signaling_NaN = true;
#if __cplusplus >= 202302L
  EIGEN_DIAGNOSTICS(push)
  EIGEN_DISABLE_DEPRECATED_WARNING
#endif
  static EIGEN_CONSTEXPR const std::float_denorm_style has_denorm = std::denorm_present;
  static EIGEN_CONSTEXPR const bool has_denorm_loss = false;
#if __cplusplus >= 202302L
  EIGEN_DIAGNOSTICS(pop)
#endif
  static EIGEN_CONSTEXPR const std::float_round_style round_style = std::numeric_limits<float>::round_style;
  static EIGEN_CONSTEXPR const bool is_iec559 = true;
  // The C++ standard defines this as "true if the set of values representable
  // by the type is finite." bfloat16 has a finite set of representable values.
  static EIGEN_CONSTEXPR const bool is_bounded = true;
  static EIGEN_CONSTEXPR const bool is_modulo = false;
  static EIGEN_CONSTEXPR const int digits = 8;
  static EIGEN_CONSTEXPR const int digits10 = 2;
  static EIGEN_CONSTEXPR const int max_digits10 = 4;
  static EIGEN_CONSTEXPR const int radix = std::numeric_limits<float>::radix;
  static EIGEN_CONSTEXPR const int min_exponent = std::numeric_limits<float>::min_exponent;
  static EIGEN_CONSTEXPR const int min_exponent10 = std::numeric_limits<float>::min_exponent10;
  static EIGEN_CONSTEXPR const int max_exponent = std::numeric_limits<float>::max_exponent;
  static EIGEN_CONSTEXPR const int max_exponent10 = std::numeric_limits<float>::max_exponent10;
  static EIGEN_CONSTEXPR const bool traps = std::numeric_limits<float>::traps;
  // IEEE 754: "The implementer shall choose how tininess is detected, but shall
  // detect tininess in the same way for all operations in radix two."
  static EIGEN_CONSTEXPR const bool tinyness_before = std::numeric_limits<float>::tinyness_before;

  static EIGEN_CONSTEXPR Eigen::bfloat16(min)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0080); }
  static EIGEN_CONSTEXPR Eigen::bfloat16 lowest() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0xff7f); }
  static EIGEN_CONSTEXPR Eigen::bfloat16(max)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f7f); }
  static EIGEN_CONSTEXPR Eigen::bfloat16 epsilon() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); }
  static EIGEN_CONSTEXPR Eigen::bfloat16 round_error() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3f00); }
  static EIGEN_CONSTEXPR Eigen::bfloat16 infinity() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); }
  static EIGEN_CONSTEXPR Eigen::bfloat16 quiet_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); }
  static EIGEN_CONSTEXPR Eigen::bfloat16 signaling_NaN() {
    return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fa0);
  }
  static EIGEN_CONSTEXPR Eigen::bfloat16 denorm_min() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0001); }
};
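
// For reference (added for exposition, not part of the original header), the
// raw bit patterns above decode as:
//   min()        = 0x0080 = 2^-126                       (smallest positive normal)
//   max()        = 0x7f7f = (2 - 2^-7) * 2^127 ~= 3.39e38
//   epsilon()    = 0x3c00 = 2^-7 = 0.0078125              (consistent with digits == 8)
//   denorm_min() = 0x0001 = 2^-133                        (smallest positive subnormal)
// bfloat16 shares float's 8-bit exponent, which is why the exponent limits are
// forwarded from std::numeric_limits<float>.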

template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_specialized;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_signed;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_integer;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_exact;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::has_infinity;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::has_quiet_NaN;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::has_signaling_NaN;
#if __cplusplus >= 202302L
EIGEN_DIAGNOSTICS(push)
EIGEN_DISABLE_DEPRECATED_WARNING
#endif
template <typename T>
EIGEN_CONSTEXPR const std::float_denorm_style numeric_limits_bfloat16_impl<T>::has_denorm;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::has_denorm_loss;
#if __cplusplus >= 202302L
EIGEN_DIAGNOSTICS(pop)
#endif
template <typename T>
EIGEN_CONSTEXPR const std::float_round_style numeric_limits_bfloat16_impl<T>::round_style;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_iec559;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_bounded;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_modulo;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::digits;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::digits10;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::max_digits10;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::radix;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::min_exponent;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::min_exponent10;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::max_exponent;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::max_exponent10;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::traps;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::tinyness_before;
}  // end namespace bfloat16_impl
}  // end namespace Eigen

namespace std {
// If std::numeric_limits<T> is specialized, we should also specialize
// std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
// std::numeric_limits<const volatile T>.
// https://stackoverflow.com/a/16519653/
template <>
class numeric_limits<Eigen::bfloat16> : public Eigen::bfloat16_impl::numeric_limits_bfloat16_impl<> {};
template <>
class numeric_limits<const Eigen::bfloat16> : public numeric_limits<Eigen::bfloat16> {};
template <>
class numeric_limits<volatile Eigen::bfloat16> : public numeric_limits<Eigen::bfloat16> {};
template <>
class numeric_limits<const volatile Eigen::bfloat16> : public numeric_limits<Eigen::bfloat16> {};
}  // end namespace std

namespace Eigen {

namespace bfloat16_impl {

// We need to distinguish 'clang as the CUDA compiler' from 'clang as the host compiler,
// invoked by NVCC' (e.g. on macOS). The former needs to see both host and device implementations
// of the functions, while the latter can only deal with one of them.
#if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC)  // Emulate support for bfloat16 floats

#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
// We need to provide emulated *host-side* BF16 operators for clang.
#pragma push_macro("EIGEN_DEVICE_FUNC")
#undef EIGEN_DEVICE_FUNC
#if (defined(EIGEN_HAS_GPU_BF16) && defined(EIGEN_HAS_NATIVE_BF16))
#define EIGEN_DEVICE_FUNC __host__
#else  // both host and device need emulated ops.
#define EIGEN_DEVICE_FUNC __host__ __device__
#endif
#endif

// Definitions for CPUs, mostly working through conversion
// to/from fp32.

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(const bfloat16& a, const bfloat16& b) {
  return bfloat16(float(a) + float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(const bfloat16& a, const int& b) {
  return bfloat16(float(a) + static_cast<float>(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(const int& a, const bfloat16& b) {
  return bfloat16(static_cast<float>(a) + float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator*(const bfloat16& a, const bfloat16& b) {
  return bfloat16(float(a) * float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator-(const bfloat16& a, const bfloat16& b) {
  return bfloat16(float(a) - float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator/(const bfloat16& a, const bfloat16& b) {
  return bfloat16(float(a) / float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator-(const bfloat16& a) {
  numext::uint16_t x = numext::bit_cast<uint16_t>(a) ^ 0x8000;
  return numext::bit_cast<bfloat16>(x);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator+=(bfloat16& a, const bfloat16& b) {
  a = bfloat16(float(a) + float(b));
  return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator*=(bfloat16& a, const bfloat16& b) {
  a = bfloat16(float(a) * float(b));
  return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator-=(bfloat16& a, const bfloat16& b) {
  a = bfloat16(float(a) - float(b));
  return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator/=(bfloat16& a, const bfloat16& b) {
  a = bfloat16(float(a) / float(b));
  return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) {
  a += bfloat16(1);
  return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) {
  a -= bfloat16(1);
  return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a, int) {
  bfloat16 original_value = a;
  ++a;
  return original_value;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a, int) {
  bfloat16 original_value = a;
  --a;
  return original_value;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator==(const bfloat16& a, const bfloat16& b) {
  return numext::equal_strict(float(a), float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator!=(const bfloat16& a, const bfloat16& b) {
  return numext::not_equal_strict(float(a), float(b));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const bfloat16& a, const bfloat16& b) {
  return float(a) < float(b);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const bfloat16& a, const bfloat16& b) {
  return float(a) <= float(b);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const bfloat16& a, const bfloat16& b) {
  return float(a) > float(b);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const bfloat16& a, const bfloat16& b) {
  return float(a) >= float(b);
}

#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
#pragma pop_macro("EIGEN_DEVICE_FUNC")
#endif
#endif  // Emulate support for bfloat16 floats

// Division by an index. Do it in full float precision to avoid accuracy
// issues in converting the denominator to bfloat16.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator/(const bfloat16& a, Index b) {
  return bfloat16(static_cast<float>(a) / static_cast<float>(b));
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const float v) {
#if defined(EIGEN_USE_HIP_BF16)
  return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(v, __bfloat16_raw::truncate));
#else
  __bfloat16_raw output;
  if (numext::isnan EIGEN_NOT_A_MACRO(v)) {
    output.value = std::signbit(v) ? 0xFFC0 : 0x7FC0;
    return output;
  }
  output.value = static_cast<numext::uint16_t>(numext::bit_cast<numext::uint32_t>(v) >> 16);
  return output;
#endif
}
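
// Note (added for exposition, not part of the original header):
// truncate_to_bfloat16 simply keeps the upper 16 bits of the float, i.e. it
// always rounds the magnitude toward zero. For example, 1.99f (mantissa bits
// 1111110 1 011...) truncates to 1.984375, whereas the round-to-nearest-even
// path below yields 1.9921875. NaNs are first canonicalized to a quiet NaN so
// that dropping mantissa bits cannot accidentally produce an infinity.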

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(numext::uint16_t value) {
#if defined(EIGEN_USE_HIP_BF16)
  __bfloat16_raw bf;
  bf.data = value;
  return bf;
#else
  return __bfloat16_raw(value);
#endif
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR numext::uint16_t raw_bfloat16_as_uint16(
    const __bfloat16_raw& bf) {
#if defined(EIGEN_USE_HIP_BF16)
  return bf.data;
#else
  return bf.value;
#endif
}

// float_to_bfloat16_rtne template specialization that does not make any
// assumption about the value of its function argument (ff).
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff) {
#if defined(EIGEN_USE_HIP_BF16)
  return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(ff));
#else
  __bfloat16_raw output;

  if (numext::isnan EIGEN_NOT_A_MACRO(ff)) {
    // If the value is a NaN, squash it to a qNaN with the msb of the fraction
    // set; this makes sure that after truncation we don't end up with an inf.
    //
    // qNaN magic: All exponent bits set + most significant bit of fraction
    // set.
    output.value = std::signbit(ff) ? 0xFFC0 : 0x7FC0;
  } else {
    // Fast rounding algorithm that rounds a half value to nearest even. This
    // reduces expected error when we convert a large number of floats. Here
    // is how it works:
    //
    // Definitions:
    // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
    // with the following tags:
    //
    // Sign |  Exp (8 bits) | Frac (23 bits)
    //  S     EEEEEEEE        FFFFFFLRTTTTTTTTTTTTTTT
    //
    //  S: Sign bit.
    //  E: Exponent bits.
    //  F: First 6 bits of fraction.
    //  L: Least significant bit of resulting bfloat16 if we truncate away the
    //     rest of the float32. This is also the 7th bit of fraction.
    //  R: Rounding bit, 8th bit of fraction.
    //  T: Sticky bits, rest of fraction, 15 bits.
    //
    // To round half to nearest even, there are three cases where we want to
    // round down (simply truncate the rest of the bits away, which consists of
    // the rounding bit and sticky bits) and two cases where we want to round up
    // (truncate then add one to the result).
    //
    // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
    // 1s) as the rounding bias, adds the rounding bias to the input, then
    // truncates the last 16 bits away.
    //
    // To understand how it works, we can analyze this algorithm case by case:
    //
    // 1. L = 0, R = 0:
    //    Expect: round down, this is less than half value.
    //
    //    Algorithm:
    //    - Rounding bias: 0x7fff + 0 = 0x7fff
    //    - Adding rounding bias to input may create any carry, depending on
    //      whether there is any value set to 1 in T bits.
    //    - R may be set to 1 if there is a carry.
    //    - L remains 0.
    //    - Note that this case also handles Inf and -Inf, where all fraction
    //      bits, including L, R and Ts are all 0. The output remains Inf after
    //      this algorithm.
    //
    // 2. L = 1, R = 0:
    //    Expect: round down, this is less than half value.
    //
    //    Algorithm:
    //    - Rounding bias: 0x7fff + 1 = 0x8000
    //    - Adding rounding bias to input doesn't change sticky bits but
    //      adds 1 to rounding bit.
    //    - L remains 1.
    //
    // 3. L = 0, R = 1, all of T are 0:
    //    Expect: round down, this is exactly at half, the result is already
    //    even (L=0).
    //
    //    Algorithm:
    //    - Rounding bias: 0x7fff + 0 = 0x7fff
    //    - Adding rounding bias to input sets all sticky bits to 1, but
    //      doesn't create a carry.
    //    - R remains 1.
    //    - L remains 0.
    //
    // 4. L = 1, R = 1:
    //    Expect: round up, this is exactly at half, the result needs to be
    //    rounded to the next even number.
    //
    //    Algorithm:
    //    - Rounding bias: 0x7fff + 1 = 0x8000
    //    - Adding rounding bias to input doesn't change sticky bits, but
    //      creates a carry from the rounding bit.
    //    - The carry sets L to 0, creates another carry bit and propagates
    //      forward into the F bits.
    //    - If all the F bits are 1, a carry then propagates to the exponent
    //      bits, which then creates the minimum value with the next exponent
    //      value. Note that we won't have the case where exponents are all 1,
    //      since that's either a NaN (handled in the other if condition) or inf
    //      (handled in case 1).
    //
    // 5. L = 0, R = 1, any of T is 1:
    //    Expect: round up, this is greater than half.
    //
    //    Algorithm:
    //    - Rounding bias: 0x7fff + 0 = 0x7fff
    //    - Adding rounding bias to input creates a carry from sticky bits,
    //      sets rounding bit to 0, then creates another carry.
    //    - The second carry sets L to 1.
    //
    // Examples:
    //
    //  Exact half value that is already even:
    //    Input:
    //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
    //     S     E E E E E E E E      F F F F F F L        RTTTTTTTTTTTTTTT
    //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 1 0        1000000000000000
    //
    //     This falls into case 3. We truncate the rest of the 16 bits and no
    //     carry is created into F and L:
    //
    //    Output:
    //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
    //     S     E E E E E E E E      F F F F F F L
    //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 1 0
    //
    //  Exact half value, round to next even number:
    //    Input:
    //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
    //     S     E E E E E E E E      F F F F F F L        RTTTTTTTTTTTTTTT
    //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 0 1        1000000000000000
    //
    //     This falls into case 4. We create a carry from R and T,
    //     which then propagates into L and F:
    //
    //    Output:
    //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
    //     S     E E E E E E E E      F F F F F F L
    //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 1 0
    //
    //
    //  Max denormal value round to min normal value:
    //    Input:
    //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
    //     S     E E E E E E E E      F F F F F F L        RTTTTTTTTTTTTTTT
    //     0     0 0 0 0 0 0 0 0      1 1 1 1 1 1 1        1111111111111111
    //
    //     This falls into case 4. We create a carry from R and T,
    //     propagate into L and F, which then propagates into the exponent
    //     bits:
    //
    //    Output:
    //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
    //     S     E E E E E E E E      F F F F F F L
    //     0     0 0 0 0 0 0 0 1      0 0 0 0 0 0 0
    //
    //  Max normal value round to Inf:
    //    Input:
    //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
    //     S     E E E E E E E E      F F F F F F L        RTTTTTTTTTTTTTTT
    //     0     1 1 1 1 1 1 1 0      1 1 1 1 1 1 1        1111111111111111
    //
    //     This falls into case 4. We create a carry from R and T,
    //     propagate into L and F, which then propagates into the exponent
    //     bits:
    //
    //    Output:
    //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
    //     S     E E E E E E E E      F F F F F F L
    //     0     1 1 1 1 1 1 1 1      0 0 0 0 0 0 0

    // At this point, ff must be either a normal float, or +/-infinity.
    output = float_to_bfloat16_rtne<true>(ff);
  }
  return output;
#endif
}

// float_to_bfloat16_rtne template specialization that assumes that its function
// argument (ff) is either a normal floating point number, or +/-infinity, or
// zero. Used to improve the runtime performance of conversion from an integer
// type to bfloat16.
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff) {
#if defined(EIGEN_USE_HIP_BF16)
  return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(ff));
#else
  numext::uint32_t input = numext::bit_cast<numext::uint32_t>(ff);
  __bfloat16_raw output;

  // Least significant bit of resulting bfloat.
  numext::uint32_t lsb = (input >> 16) & 1;
  numext::uint32_t rounding_bias = 0x7fff + lsb;
  input += rounding_bias;
  output.value = static_cast<numext::uint16_t>(input >> 16);
  return output;
#endif
}
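
// Worked example (added for exposition, not part of the original header): for
// ff = 1.01171875f the bit pattern is 0x3F818000, so lsb = (0x3F81 & 1) = 1 and
// the rounding bias is 0x7fff + 1 = 0x8000. Adding the bias gives 0x3F820000,
// and the final shift keeps 0x3F82, i.e. bfloat16(1.015625): the tie is rounded
// up to the even mantissa, exactly as round-to-nearest-even requires.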

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) {
#if defined(EIGEN_USE_HIP_BF16)
  return static_cast<float>(h);
#else
  return numext::bit_cast<float>(static_cast<numext::uint32_t>(h.value) << 16);
#endif
}
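
// Note (added for exposition, not part of the original header): this widening
// conversion is exact. Every bfloat16 value is also a float value, since the
// 16 low mantissa bits appended here are zero; only the float -> bfloat16
// direction can lose precision.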

// --- standard functions ---

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isinf)(const bfloat16& a) {
  EIGEN_USING_STD(isinf);
#if defined(EIGEN_USE_HIP_BF16)
  return (isinf)(a);  // Uses HIP hip_bfloat16 isinf operator
#else
  return (isinf)(float(a));
#endif
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isnan)(const bfloat16& a) {
  EIGEN_USING_STD(isnan);
#if defined(EIGEN_USE_HIP_BF16)
  return (isnan)(a);  // Uses HIP hip_bfloat16 isnan operator
#else
  return (isnan)(float(a));
#endif
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isfinite)(const bfloat16& a) {
  return !(isinf EIGEN_NOT_A_MACRO(a)) && !(isnan EIGEN_NOT_A_MACRO(a));
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(const bfloat16& a) {
  numext::uint16_t x = numext::bit_cast<numext::uint16_t>(a) & 0x7FFF;
  return numext::bit_cast<bfloat16>(x);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(const bfloat16& a) { return bfloat16(::expf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(const bfloat16& a) { return bfloat16(numext::expm1(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log(const bfloat16& a) { return bfloat16(::logf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(const bfloat16& a) { return bfloat16(numext::log1p(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(const bfloat16& a) { return bfloat16(::log10f(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log2(const bfloat16& a) {
  return bfloat16(static_cast<float>(EIGEN_LOG2E) * ::logf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) { return bfloat16(::sqrtf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) {
  return bfloat16(::powf(float(a), float(b)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan2(const bfloat16& a, const bfloat16& b) {
  return bfloat16(::atan2f(float(a), float(b)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) { return bfloat16(::sinf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cos(const bfloat16& a) { return bfloat16(::cosf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tan(const bfloat16& a) { return bfloat16(::tanf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asin(const bfloat16& a) { return bfloat16(::asinf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acos(const bfloat16& a) { return bfloat16(::acosf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan(const bfloat16& a) { return bfloat16(::atanf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sinh(const bfloat16& a) { return bfloat16(::sinhf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cosh(const bfloat16& a) { return bfloat16(::coshf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(const bfloat16& a) { return bfloat16(::tanhf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(const bfloat16& a) { return bfloat16(::asinhf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(const bfloat16& a) { return bfloat16(::acoshf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(const bfloat16& a) { return bfloat16(::atanhf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(const bfloat16& a) { return bfloat16(::floorf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(const bfloat16& a) { return bfloat16(::ceilf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 rint(const bfloat16& a) { return bfloat16(::rintf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 round(const bfloat16& a) { return bfloat16(::roundf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 trunc(const bfloat16& a) { return bfloat16(::truncf(float(a))); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(const bfloat16& a, const bfloat16& b) {
  return bfloat16(::fmodf(float(a), float(b)));
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16(min)(const bfloat16& a, const bfloat16& b) {
  const float f1 = static_cast<float>(a);
  const float f2 = static_cast<float>(b);
  return f2 < f1 ? b : a;
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16(max)(const bfloat16& a, const bfloat16& b) {
  const float f1 = static_cast<float>(a);
  const float f2 = static_cast<float>(b);
  return f1 < f2 ? b : a;
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmin(const bfloat16& a, const bfloat16& b) {
  const float f1 = static_cast<float>(a);
  const float f2 = static_cast<float>(b);
  return bfloat16(::fminf(f1, f2));
}

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(const bfloat16& a, const bfloat16& b) {
  const float f1 = static_cast<float>(a);
  const float f2 = static_cast<float>(b);
  return bfloat16(::fmaxf(f1, f2));
}
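
// Note (added for exposition, not part of the original header): (min) and (max)
// above use a plain '<' comparison, so if either operand is NaN the comparison
// is false and 'a' is returned. fmin and fmax instead defer to ::fminf/::fmaxf,
// which follow the C library convention of preferring the non-NaN operand.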

#ifndef EIGEN_NO_IO
EIGEN_ALWAYS_INLINE std::ostream& operator<<(std::ostream& os, const bfloat16& v) {
  os << static_cast<float>(v);
  return os;
}
#endif

}  // namespace bfloat16_impl

namespace internal {

template <>
struct is_arithmetic<bfloat16> {
  enum { value = true };
};

template <>
struct random_impl<bfloat16> {
  enum : int { MantissaBits = 7 };
  using Impl = random_impl<float>;
  static EIGEN_DEVICE_FUNC inline bfloat16 run(const bfloat16& x, const bfloat16& y) {
    float result = Impl::run(x, y, MantissaBits);
    return bfloat16(result);
  }
  static EIGEN_DEVICE_FUNC inline bfloat16 run() {
    float result = Impl::run(MantissaBits);
    return bfloat16(result);
  }
};
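
// Note (added for exposition, not part of the original header): random bfloat16
// values are produced by the float generator restricted to MantissaBits == 7
// significant mantissa bits, so that the drawn float is representable as a
// bfloat16 and the final conversion should not introduce additional rounding.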

}  // namespace internal

template <>
struct NumTraits<Eigen::bfloat16> : GenericNumTraits<Eigen::bfloat16> {
  enum { IsSigned = true, IsInteger = false, IsComplex = false, RequireInitialization = false };

  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 epsilon() {
    return bfloat16_impl::raw_uint16_to_bfloat16(0x3c00);
  }
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 dummy_precision() {
    return bfloat16_impl::raw_uint16_to_bfloat16(0x3D4D);  // bfloat16(5e-2f);
  }
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 highest() {
    return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F);
  }
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 lowest() {
    return bfloat16_impl::raw_uint16_to_bfloat16(0xFF7F);
  }
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 infinity() {
    return bfloat16_impl::raw_uint16_to_bfloat16(0x7f80);
  }
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 quiet_NaN() {
    return bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0);
  }
};

}  // namespace Eigen

#if defined(EIGEN_HAS_HIP_BF16)
#pragma pop_macro("EIGEN_CONSTEXPR")
#endif

namespace Eigen {
namespace numext {

template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isnan)(const Eigen::bfloat16& h) {
  return (bfloat16_impl::isnan)(h);
}

template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isinf)(const Eigen::bfloat16& h) {
  return (bfloat16_impl::isinf)(h);
}

template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isfinite)(const Eigen::bfloat16& h) {
  return (bfloat16_impl::isfinite)(h);
}

template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(const uint16_t& src) {
  return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(src);
}

template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(const Eigen::bfloat16& src) {
  return Eigen::bfloat16_impl::raw_bfloat16_as_uint16(src);
}

}  // namespace numext
}  // namespace Eigen

#if EIGEN_HAS_STD_HASH
namespace std {
template <>
struct hash<Eigen::bfloat16> {
  EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::bfloat16& a) const {
    return static_cast<std::size_t>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(a));
  }
};
}  // namespace std
#endif
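
// Note (added for exposition, not part of the original header): the hash
// specialization above lets bfloat16 be used directly as a key in
// std::unordered_map / std::unordered_set. The hash is simply the raw 16-bit
// pattern, so +0.0 and -0.0 (which compare equal) hash to different values,
// which is worth keeping in mind when using bfloat16 as a container key.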

// Add the missing shfl* intrinsics.
// The __shfl* functions are only valid on HIP or __CUDA_ARCH__ >= 300.
// CUDA defines them for (__CUDA_ARCH__ >= 300 || !defined(__CUDA_ARCH__))
//
// HIP and CUDA prior to SDK 9.0 define
//    __shfl, __shfl_up, __shfl_down, __shfl_xor for int and float
// CUDA since 9.0 deprecates those and instead defines
//    __shfl_sync, __shfl_up_sync, __shfl_down_sync, __shfl_xor_sync,
// with native support for __half and __nv_bfloat16
//
// Note that the following are __device__-only functions.
#if defined(EIGEN_HIPCC)

#if defined(EIGEN_HAS_HIP_BF16)

__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl(Eigen::bfloat16 var, int srcLane, int width = warpSize) {
  const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
  return Eigen::numext::bit_cast<Eigen::bfloat16>(static_cast<Eigen::numext::uint16_t>(__shfl(ivar, srcLane, width)));
}

__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_up(Eigen::bfloat16 var, unsigned int delta,
                                                         int width = warpSize) {
  const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
  return Eigen::numext::bit_cast<Eigen::bfloat16>(static_cast<Eigen::numext::uint16_t>(__shfl_up(ivar, delta, width)));
}

__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_down(Eigen::bfloat16 var, unsigned int delta,
                                                           int width = warpSize) {
  const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
  return Eigen::numext::bit_cast<Eigen::bfloat16>(
      static_cast<Eigen::numext::uint16_t>(__shfl_down(ivar, delta, width)));
}

__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_xor(Eigen::bfloat16 var, int laneMask, int width = warpSize) {
  const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
  return Eigen::numext::bit_cast<Eigen::bfloat16>(
      static_cast<Eigen::numext::uint16_t>(__shfl_xor(ivar, laneMask, width)));
}

#endif  // HIP

#endif  // __shfl*

#if defined(EIGEN_HIPCC)
EIGEN_STRONG_INLINE __device__ Eigen::bfloat16 __ldg(const Eigen::bfloat16* ptr) {
  return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(
      __ldg(Eigen::numext::bit_cast<const Eigen::numext::uint16_t*>(ptr)));
}
#endif  // __ldg

#endif  // EIGEN_BFLOAT16_H