16#ifndef EIGEN_BFLOAT16_H
17#define EIGEN_BFLOAT16_H
20#include "../../InternalHeaderCheck.h"
22#if defined(EIGEN_HAS_HIP_BF16)
29#pragma push_macro("EIGEN_CONSTEXPR")
31#define EIGEN_CONSTEXPR
34#define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD) \
36 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED PACKET_BF16 METHOD<PACKET_BF16>( \
37 const PACKET_BF16& _x) { \
38 return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(_x))); \
42#if defined(EIGEN_HAS_HIP_BF16) && defined(EIGEN_GPU_COMPILE_PHASE)
43#define EIGEN_USE_HIP_BF16
52EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(
const uint16_t& src);
55EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(
const Eigen::bfloat16& src);
57namespace bfloat16_impl {
59#if defined(EIGEN_USE_HIP_BF16)
61struct __bfloat16_raw :
public hip_bfloat16 {
62 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() {}
63 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(hip_bfloat16 hb) : hip_bfloat16(hb) {}
64 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(
unsigned short raw) : hip_bfloat16(raw) {}
70struct __bfloat16_raw {
71#if defined(EIGEN_HAS_HIP_BF16) && !defined(EIGEN_GPU_COMPILE_PHASE)
72 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() {}
74 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() : value(0) {}
76 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(
unsigned short raw) : value(raw) {}
82EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(
unsigned short value);
83template <
bool AssumeArgumentIsNormalOrInfinityOrZero>
84EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(
float ff);
88EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(
float ff);
90EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(
float ff);
91EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
float bfloat16_to_float(__bfloat16_raw h);
93struct bfloat16_base :
public __bfloat16_raw {
94 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base() {}
95 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base(
const __bfloat16_raw& h) : __bfloat16_raw(h) {}
101struct bfloat16 :
public bfloat16_impl::bfloat16_base {
102 typedef bfloat16_impl::__bfloat16_raw __bfloat16_raw;
104 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16() {}
106 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(
const __bfloat16_raw& h) : bfloat16_impl::bfloat16_base(h) {}
108 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(
bool b)
109 : bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}
112 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(T val)
113 : bfloat16_impl::bfloat16_base(
114 bfloat16_impl::float_to_bfloat16_rtne<internal::is_integral<T>::value>(static_cast<float>(val))) {}
116 explicit EIGEN_DEVICE_FUNC bfloat16(
float f)
117 : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(f)) {}
121 template <
typename RealScalar>
122 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(
const std::complex<RealScalar>& val)
123 : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(static_cast<float>(val.
real()))) {}
125 EIGEN_DEVICE_FUNC
operator float()
const {
126 return bfloat16_impl::bfloat16_to_float(*
this);
132namespace bfloat16_impl {
133template <
typename =
void>
134struct numeric_limits_bfloat16_impl {
135 static EIGEN_CONSTEXPR
const bool is_specialized =
true;
136 static EIGEN_CONSTEXPR
const bool is_signed =
true;
137 static EIGEN_CONSTEXPR
const bool is_integer =
false;
138 static EIGEN_CONSTEXPR
const bool is_exact =
false;
139 static EIGEN_CONSTEXPR
const bool has_infinity =
true;
140 static EIGEN_CONSTEXPR
const bool has_quiet_NaN =
true;
141 static EIGEN_CONSTEXPR
const bool has_signaling_NaN =
true;
142#if __cplusplus >= 202302L
143 EIGEN_DIAGNOSTICS(push)
144 EIGEN_DISABLE_DEPRECATED_WARNING
146 static EIGEN_CONSTEXPR
const std::float_denorm_style has_denorm = std::denorm_present;
147 static EIGEN_CONSTEXPR
const bool has_denorm_loss =
false;
148#if __cplusplus >= 202302L
149 EIGEN_DIAGNOSTICS(pop)
151 static EIGEN_CONSTEXPR
const std::float_round_style round_style = std::numeric_limits<float>::round_style;
152 static EIGEN_CONSTEXPR
const bool is_iec559 =
true;
155 static EIGEN_CONSTEXPR
const bool is_bounded =
true;
156 static EIGEN_CONSTEXPR
const bool is_modulo =
false;
157 static EIGEN_CONSTEXPR
const int digits = 8;
158 static EIGEN_CONSTEXPR
const int digits10 = 2;
159 static EIGEN_CONSTEXPR
const int max_digits10 = 4;
160 static EIGEN_CONSTEXPR
const int radix = std::numeric_limits<float>::radix;
161 static EIGEN_CONSTEXPR
const int min_exponent = std::numeric_limits<float>::min_exponent;
162 static EIGEN_CONSTEXPR
const int min_exponent10 = std::numeric_limits<float>::min_exponent10;
163 static EIGEN_CONSTEXPR
const int max_exponent = std::numeric_limits<float>::max_exponent;
164 static EIGEN_CONSTEXPR
const int max_exponent10 = std::numeric_limits<float>::max_exponent10;
165 static EIGEN_CONSTEXPR
const bool traps = std::numeric_limits<float>::traps;
168 static EIGEN_CONSTEXPR
const bool tinyness_before = std::numeric_limits<float>::tinyness_before;
170 static EIGEN_CONSTEXPR Eigen::bfloat16(min)() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0080); }
171 static EIGEN_CONSTEXPR Eigen::bfloat16 lowest() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0xff7f); }
172 static EIGEN_CONSTEXPR Eigen::bfloat16(max)() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f7f); }
173 static EIGEN_CONSTEXPR Eigen::bfloat16 epsilon() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); }
174 static EIGEN_CONSTEXPR Eigen::bfloat16 round_error() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3f00); }
175 static EIGEN_CONSTEXPR Eigen::bfloat16 infinity() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); }
176 static EIGEN_CONSTEXPR Eigen::bfloat16 quiet_NaN() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); }
177 static EIGEN_CONSTEXPR Eigen::bfloat16 signaling_NaN() {
178 return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fa0);
180 static EIGEN_CONSTEXPR Eigen::bfloat16 denorm_min() {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0001); }
184EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::is_specialized;
186EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::is_signed;
188EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::is_integer;
190EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::is_exact;
192EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::has_infinity;
194EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::has_quiet_NaN;
196EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::has_signaling_NaN;
197#if __cplusplus >= 202302L
198EIGEN_DIAGNOSTICS(push)
199EIGEN_DISABLE_DEPRECATED_WARNING
202EIGEN_CONSTEXPR
const std::float_denorm_style numeric_limits_bfloat16_impl<T>::has_denorm;
204EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::has_denorm_loss;
205#if __cplusplus >= 202302L
206EIGEN_DIAGNOSTICS(pop)
209EIGEN_CONSTEXPR
const std::float_round_style numeric_limits_bfloat16_impl<T>::round_style;
211EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::is_iec559;
213EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::is_bounded;
215EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::is_modulo;
217EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::digits;
219EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::digits10;
221EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::max_digits10;
223EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::radix;
225EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::min_exponent;
227EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::min_exponent10;
229EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::max_exponent;
231EIGEN_CONSTEXPR
const int numeric_limits_bfloat16_impl<T>::max_exponent10;
233EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::traps;
235EIGEN_CONSTEXPR
const bool numeric_limits_bfloat16_impl<T>::tinyness_before;
245class numeric_limits<
Eigen::bfloat16> :
public Eigen::bfloat16_impl::numeric_limits_bfloat16_impl<> {};
247class numeric_limits<const
Eigen::bfloat16> :
public numeric_limits<Eigen::bfloat16> {};
249class numeric_limits<volatile
Eigen::bfloat16> :
public numeric_limits<Eigen::bfloat16> {};
251class numeric_limits<const volatile
Eigen::bfloat16> :
public numeric_limits<Eigen::bfloat16> {};
256namespace bfloat16_impl {
261#if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC)
263#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
265#pragma push_macro("EIGEN_DEVICE_FUNC")
266#undef EIGEN_DEVICE_FUNC
267#if (defined(EIGEN_HAS_GPU_BF16) && defined(EIGEN_HAS_NATIVE_BF16))
268#define EIGEN_DEVICE_FUNC __host__
270#define EIGEN_DEVICE_FUNC __host__ __device__
277EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(
const bfloat16& a,
const bfloat16& b) {
278 return bfloat16(
float(a) +
float(b));
280EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(
const bfloat16& a,
const int& b) {
281 return bfloat16(
float(a) +
static_cast<float>(b));
283EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(
const int& a,
const bfloat16& b) {
284 return bfloat16(
static_cast<float>(a) +
float(b));
286EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator*(
const bfloat16& a,
const bfloat16& b) {
287 return bfloat16(
float(a) *
float(b));
289EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator-(
const bfloat16& a,
const bfloat16& b) {
290 return bfloat16(
float(a) -
float(b));
292EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator/(
const bfloat16& a,
const bfloat16& b) {
293 return bfloat16(
float(a) /
float(b));
295EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator-(
const bfloat16& a) {
296 numext::uint16_t x = numext::bit_cast<uint16_t>(a) ^ 0x8000;
297 return numext::bit_cast<bfloat16>(x);
299EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator+=(bfloat16& a,
const bfloat16& b) {
300 a = bfloat16(
float(a) +
float(b));
303EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator*=(bfloat16& a,
const bfloat16& b) {
304 a = bfloat16(
float(a) *
float(b));
307EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator-=(bfloat16& a,
const bfloat16& b) {
308 a = bfloat16(
float(a) -
float(b));
311EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator/=(bfloat16& a,
const bfloat16& b) {
312 a = bfloat16(
float(a) /
float(b));
315EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) {
319EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) {
323EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a,
int) {
324 bfloat16 original_value = a;
326 return original_value;
328EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a,
int) {
329 bfloat16 original_value = a;
331 return original_value;
333EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator==(
const bfloat16& a,
const bfloat16& b) {
334 return numext::equal_strict(
float(a),
float(b));
336EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator!=(
const bfloat16& a,
const bfloat16& b) {
337 return numext::not_equal_strict(
float(a),
float(b));
339EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator<(
const bfloat16& a,
const bfloat16& b) {
340 return float(a) < float(b);
342EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator<=(
const bfloat16& a,
const bfloat16& b) {
343 return float(a) <= float(b);
345EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator>(
const bfloat16& a,
const bfloat16& b) {
346 return float(a) > float(b);
348EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
bool operator>=(
const bfloat16& a,
const bfloat16& b) {
349 return float(a) >= float(b);
352#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
353#pragma pop_macro("EIGEN_DEVICE_FUNC")
359EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator/(
const bfloat16& a, Index b) {
360 return bfloat16(
static_cast<float>(a) /
static_cast<float>(b));
363EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(
const float v) {
364#if defined(EIGEN_USE_HIP_BF16)
365 return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(v, __bfloat16_raw::truncate));
367 __bfloat16_raw output;
368 if (numext::isnan EIGEN_NOT_A_MACRO(v)) {
369 output.value = std::signbit(v) ? 0xFFC0 : 0x7FC0;
372 output.value =
static_cast<numext::uint16_t
>(numext::bit_cast<numext::uint32_t>(v) >> 16);
377EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(numext::uint16_t value) {
378#if defined(EIGEN_USE_HIP_BF16)
383 return __bfloat16_raw(value);
387EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR numext::uint16_t raw_bfloat16_as_uint16(
388 const __bfloat16_raw& bf) {
389#if defined(EIGEN_USE_HIP_BF16)
399EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(
float ff) {
400#if defined(EIGEN_USE_HIP_BF16)
401 return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(ff));
403 __bfloat16_raw output;
405 if (numext::isnan EIGEN_NOT_A_MACRO(ff)) {
411 output.value = std::signbit(ff) ? 0xFFC0 : 0x7FC0;
562 output = float_to_bfloat16_rtne<true>(ff);
573EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(
float ff) {
574#if defined(EIGEN_USE_HIP_BF16)
575 return __bfloat16_raw(__bfloat16_raw::round_to_bfloat16(ff));
577 numext::uint32_t input = numext::bit_cast<numext::uint32_t>(ff);
578 __bfloat16_raw output;
581 numext::uint32_t lsb = (input >> 16) & 1;
582 numext::uint32_t rounding_bias = 0x7fff + lsb;
583 input += rounding_bias;
584 output.value =
static_cast<numext::uint16_t
>(input >> 16);
589EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
float bfloat16_to_float(__bfloat16_raw h) {
590#if defined(EIGEN_USE_HIP_BF16)
591 return static_cast<float>(h);
593 return numext::bit_cast<float>(
static_cast<numext::uint32_t
>(h.value) << 16);
599EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isinf)(
const bfloat16& a) {
600 EIGEN_USING_STD(isinf);
601#if defined(EIGEN_USE_HIP_BF16)
604 return (isinf)(float(a));
607EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isnan)(
const bfloat16& a) {
608 EIGEN_USING_STD(isnan);
609#if defined(EIGEN_USE_HIP_BF16)
612 return (isnan)(float(a));
615EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isfinite)(
const bfloat16& a) {
616 return !(isinf EIGEN_NOT_A_MACRO(a)) && !(isnan EIGEN_NOT_A_MACRO(a));
619EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(
const bfloat16& a) {
620 numext::uint16_t x = numext::bit_cast<numext::uint16_t>(a) & 0x7FFF;
621 return numext::bit_cast<bfloat16>(x);
623EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(
const bfloat16& a) {
return bfloat16(::expf(
float(a))); }
624EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(
const bfloat16& a) {
return bfloat16(numext::expm1(
float(a))); }
625EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log(
const bfloat16& a) {
return bfloat16(::logf(
float(a))); }
626EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(
const bfloat16& a) {
return bfloat16(numext::log1p(
float(a))); }
627EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(
const bfloat16& a) {
return bfloat16(::log10f(
float(a))); }
628EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log2(
const bfloat16& a) {
629 return bfloat16(
static_cast<float>(EIGEN_LOG2E) * ::logf(
float(a)));
631EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(
const bfloat16& a) {
return bfloat16(::sqrtf(
float(a))); }
632EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(
const bfloat16& a,
const bfloat16& b) {
633 return bfloat16(::powf(
float(a),
float(b)));
635EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan2(
const bfloat16& a,
const bfloat16& b) {
636 return bfloat16(::atan2f(
float(a),
float(b)));
638EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(
const bfloat16& a) {
return bfloat16(::sinf(
float(a))); }
639EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cos(
const bfloat16& a) {
return bfloat16(::cosf(
float(a))); }
640EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tan(
const bfloat16& a) {
return bfloat16(::tanf(
float(a))); }
641EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asin(
const bfloat16& a) {
return bfloat16(::asinf(
float(a))); }
642EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acos(
const bfloat16& a) {
return bfloat16(::acosf(
float(a))); }
643EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan(
const bfloat16& a) {
return bfloat16(::atanf(
float(a))); }
644EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sinh(
const bfloat16& a) {
return bfloat16(::sinhf(
float(a))); }
645EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cosh(
const bfloat16& a) {
return bfloat16(::coshf(
float(a))); }
646EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(
const bfloat16& a) {
return bfloat16(::tanhf(
float(a))); }
647EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(
const bfloat16& a) {
return bfloat16(::asinhf(
float(a))); }
648EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(
const bfloat16& a) {
return bfloat16(::acoshf(
float(a))); }
649EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(
const bfloat16& a) {
return bfloat16(::atanhf(
float(a))); }
650EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(
const bfloat16& a) {
return bfloat16(::floorf(
float(a))); }
651EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(
const bfloat16& a) {
return bfloat16(::ceilf(
float(a))); }
652EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 rint(
const bfloat16& a) {
return bfloat16(::rintf(
float(a))); }
653EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 round(
const bfloat16& a) {
return bfloat16(::roundf(
float(a))); }
654EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 trunc(
const bfloat16& a) {
return bfloat16(::truncf(
float(a))); }
655EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(
const bfloat16& a,
const bfloat16& b) {
656 return bfloat16(::fmodf(
float(a),
float(b)));
659EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16(min)(
const bfloat16& a,
const bfloat16& b) {
660 const float f1 =
static_cast<float>(a);
661 const float f2 =
static_cast<float>(b);
662 return f2 < f1 ? b : a;
665EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16(max)(
const bfloat16& a,
const bfloat16& b) {
666 const float f1 =
static_cast<float>(a);
667 const float f2 =
static_cast<float>(b);
668 return f1 < f2 ? b : a;
671EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmin(
const bfloat16& a,
const bfloat16& b) {
672 const float f1 =
static_cast<float>(a);
673 const float f2 =
static_cast<float>(b);
674 return bfloat16(::fminf(f1, f2));
677EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(
const bfloat16& a,
const bfloat16& b) {
678 const float f1 =
static_cast<float>(a);
679 const float f2 =
static_cast<float>(b);
680 return bfloat16(::fmaxf(f1, f2));
684EIGEN_ALWAYS_INLINE std::ostream& operator<<(std::ostream& os,
const bfloat16& v) {
685 os << static_cast<float>(v);
695struct is_arithmetic<bfloat16> {
696 enum { value =
true };
700struct random_impl<bfloat16> {
701 enum :
int { MantissaBits = 7 };
702 using Impl = random_impl<float>;
703 static EIGEN_DEVICE_FUNC
inline bfloat16 run(
const bfloat16& x,
const bfloat16& y) {
704 float result = Impl::run(x, y, MantissaBits);
705 return bfloat16(result);
707 static EIGEN_DEVICE_FUNC
inline bfloat16 run() {
708 float result = Impl::run(MantissaBits);
709 return bfloat16(result);
716struct NumTraits<
Eigen::bfloat16> : GenericNumTraits<Eigen::bfloat16> {
717 enum { IsSigned =
true, IsInteger =
false, IsComplex =
false, RequireInitialization =
false };
719 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE Eigen::bfloat16 epsilon() {
720 return bfloat16_impl::raw_uint16_to_bfloat16(0x3c00);
722 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE Eigen::bfloat16 dummy_precision() {
723 return bfloat16_impl::raw_uint16_to_bfloat16(0x3D4D);
725 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE Eigen::bfloat16 highest() {
726 return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F);
728 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE Eigen::bfloat16 lowest() {
729 return bfloat16_impl::raw_uint16_to_bfloat16(0xFF7F);
731 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE Eigen::bfloat16 infinity() {
732 return bfloat16_impl::raw_uint16_to_bfloat16(0x7f80);
734 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static EIGEN_STRONG_INLINE Eigen::bfloat16 quiet_NaN() {
735 return bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0);
741#if defined(EIGEN_HAS_HIP_BF16)
742#pragma pop_macro("EIGEN_CONSTEXPR")
749EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isnan)(
const Eigen::bfloat16& h) {
750 return (bfloat16_impl::isnan)(h);
754EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isinf)(
const Eigen::bfloat16& h) {
755 return (bfloat16_impl::isinf)(h);
759EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isfinite)(
const Eigen::bfloat16& h) {
760 return (bfloat16_impl::isfinite)(h);
764EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(
const uint16_t& src) {
765 return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(src);
769EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(
const Eigen::bfloat16& src) {
770 return Eigen::bfloat16_impl::raw_bfloat16_as_uint16(src);
776#if EIGEN_HAS_STD_HASH
779struct hash<
Eigen::bfloat16> {
780 EIGEN_STRONG_INLINE std::size_t operator()(
const Eigen::bfloat16& a)
const {
781 return static_cast<std::size_t
>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(a));
798#if defined(EIGEN_HIPCC)
800#if defined(EIGEN_HAS_HIP_BF16)
802__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl(Eigen::bfloat16 var,
int srcLane,
int width = warpSize) {
803 const int ivar =
static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
804 return Eigen::numext::bit_cast<Eigen::bfloat16>(
static_cast<Eigen::numext::uint16_t
>(__shfl(ivar, srcLane, width)));
807__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_up(Eigen::bfloat16 var,
unsigned int delta,
808 int width = warpSize) {
809 const int ivar =
static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
810 return Eigen::numext::bit_cast<Eigen::bfloat16>(
static_cast<Eigen::numext::uint16_t
>(__shfl_up(ivar, delta, width)));
813__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_down(Eigen::bfloat16 var,
unsigned int delta,
814 int width = warpSize) {
815 const int ivar =
static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
816 return Eigen::numext::bit_cast<Eigen::bfloat16>(
817 static_cast<Eigen::numext::uint16_t
>(__shfl_down(ivar, delta, width)));
820__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_xor(Eigen::bfloat16 var,
int laneMask,
int width = warpSize) {
821 const int ivar =
static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
822 return Eigen::numext::bit_cast<Eigen::bfloat16>(
823 static_cast<Eigen::numext::uint16_t
>(__shfl_xor(ivar, laneMask, width)));
830#if defined(EIGEN_HIPCC)
831EIGEN_STRONG_INLINE __device__ Eigen::bfloat16 __ldg(
const Eigen::bfloat16* ptr) {
832 return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(
833 __ldg(Eigen::numext::bit_cast<const Eigen::numext::uint16_t*>(ptr)));
Namespace containing all symbols from the Eigen library.
Definition Core:137
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_real_op< typename Derived::Scalar >, const Derived > real(const Eigen::ArrayBase< Derived > &x)