Eigen  3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)
 
Loading...
Searching...
No Matches
GPU/PacketMath.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014 Benoit Steiner <[email protected]>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_PACKET_MATH_GPU_H
11#define EIGEN_PACKET_MATH_GPU_H
12
13// IWYU pragma: private
14#include "../../InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
// Feature flags for the GPU packet backend.
//
// EIGEN_GPU_HAS_LDG: the read-only data-cache load intrinsic __ldg() may be used.
// Enabled for HIP device compilation and for CUDA when EIGEN_CUDA_ARCH >= 350.
#if defined(EIGEN_HIP_DEVICE_COMPILE)
#define EIGEN_GPU_HAS_LDG 1
#elif defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 350
#define EIGEN_GPU_HAS_LDG 1
#endif

// EIGEN_CUDA_HAS_FP16_ARITHMETIC: native half-precision math on CUDA,
// enabled when EIGEN_CUDA_ARCH >= 530.
#if defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530
#define EIGEN_CUDA_HAS_FP16_ARITHMETIC 1
#endif

// EIGEN_GPU_HAS_FP16_ARITHMETIC: backend-neutral FP16 flag, set for HIP device
// compilation or whenever the CUDA-specific flag above is set.
#if defined(EIGEN_HIP_DEVICE_COMPILE)
#define EIGEN_GPU_HAS_FP16_ARITHMETIC 1
#elif defined(EIGEN_CUDA_HAS_FP16_ARITHMETIC)
#define EIGEN_GPU_HAS_FP16_ARITHMETIC 1
#endif
33
34// Make sure this is only available when targeting a GPU: we don't want to
35// introduce conflicts between these packet_traits definitions and the ones
36// we'll use on the host side (SSE, AVX, ...)
37#if defined(EIGEN_GPUCC) && defined(EIGEN_USE_GPU)
38
39template <>
40struct is_arithmetic<float4> {
41 enum { value = true };
42};
43template <>
44struct is_arithmetic<double2> {
45 enum { value = true };
46};
47
48template <>
49struct packet_traits<float> : default_packet_traits {
50 typedef float4 type;
51 typedef float4 half;
52 enum {
53 Vectorizable = 1,
54 AlignedOnScalar = 1,
55 size = 4,
56
57 HasDiv = 1,
58 HasSin = 0,
59 HasCos = 0,
60 HasLog = 1,
61 HasExp = 1,
62 HasSqrt = 1,
63 HasRsqrt = 1,
64 HasLGamma = 1,
65 HasDiGamma = 1,
66 HasZeta = 1,
67 HasPolygamma = 1,
68 HasErf = 1,
69 HasErfc = 1,
70 HasNdtri = 1,
71 HasBessel = 1,
72 HasIGamma = 1,
73 HasIGammaDerA = 1,
74 HasGammaSampleDerAlpha = 1,
75 HasIGammac = 1,
76 HasBetaInc = 1,
77 HasBlend = 0
78 };
79};
80
81template <>
82struct packet_traits<double> : default_packet_traits {
83 typedef double2 type;
84 typedef double2 half;
85 enum {
86 Vectorizable = 1,
87 AlignedOnScalar = 1,
88 size = 2,
89
90 HasDiv = 1,
91 HasLog = 1,
92 HasExp = 1,
93 HasSqrt = 1,
94 HasRsqrt = 1,
95 HasLGamma = 1,
96 HasDiGamma = 1,
97 HasZeta = 1,
98 HasPolygamma = 1,
99 HasErf = 1,
100 HasErfc = 1,
101 HasNdtri = 1,
102 HasBessel = 1,
103 HasIGamma = 1,
104 HasIGammaDerA = 1,
105 HasGammaSampleDerAlpha = 1,
106 HasIGammac = 1,
107 HasBetaInc = 1,
108 HasBlend = 0,
109 };
110};
111
112template <>
113struct unpacket_traits<float4> {
114 typedef float type;
115 enum {
116 size = 4,
117 alignment = Aligned16,
118 vectorizable = true,
119 masked_load_available = false,
120 masked_store_available = false
121 };
122 typedef float4 half;
123};
124template <>
125struct unpacket_traits<double2> {
126 typedef double type;
127 enum {
128 size = 2,
129 alignment = Aligned16,
130 vectorizable = true,
131 masked_load_available = false,
132 masked_store_available = false
133 };
134 typedef double2 half;
135};
136
137template <>
138EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pset1<float4>(const float& from) {
139 return make_float4(from, from, from, from);
140}
141template <>
142EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pset1<double2>(const double& from) {
143 return make_double2(from, from);
144}
145
// We need to distinguish 'clang as the CUDA compiler' from 'clang as the host compiler
// invoked by NVCC' (e.g. on macOS). The former needs to see both the host and the device
// implementations of these functions, while the latter can only deal with one of them.
149#if defined(EIGEN_CUDA_ARCH) || defined(EIGEN_HIPCC) || (defined(EIGEN_CUDACC) && EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC)
150
151EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float bitwise_and(const float& a, const float& b) {
152 return __int_as_float(__float_as_int(a) & __float_as_int(b));
153}
154EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double bitwise_and(const double& a, const double& b) {
155 return __longlong_as_double(__double_as_longlong(a) & __double_as_longlong(b));
156}
157
158EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float bitwise_or(const float& a, const float& b) {
159 return __int_as_float(__float_as_int(a) | __float_as_int(b));
160}
161EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double bitwise_or(const double& a, const double& b) {
162 return __longlong_as_double(__double_as_longlong(a) | __double_as_longlong(b));
163}
164
165EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float bitwise_xor(const float& a, const float& b) {
166 return __int_as_float(__float_as_int(a) ^ __float_as_int(b));
167}
168EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double bitwise_xor(const double& a, const double& b) {
169 return __longlong_as_double(__double_as_longlong(a) ^ __double_as_longlong(b));
170}
171
172EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float bitwise_andnot(const float& a, const float& b) {
173 return __int_as_float(__float_as_int(a) & ~__float_as_int(b));
174}
175EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double bitwise_andnot(const double& a, const double& b) {
176 return __longlong_as_double(__double_as_longlong(a) & ~__double_as_longlong(b));
177}
178EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float eq_mask(const float& a, const float& b) {
179 return __int_as_float(a == b ? 0xffffffffu : 0u);
180}
181EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double eq_mask(const double& a, const double& b) {
182 return __longlong_as_double(a == b ? 0xffffffffffffffffull : 0ull);
183}
184
185EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float lt_mask(const float& a, const float& b) {
186 return __int_as_float(a < b ? 0xffffffffu : 0u);
187}
188
189EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double lt_mask(const double& a, const double& b) {
190 return __longlong_as_double(a < b ? 0xffffffffffffffffull : 0ull);
191}
192
193EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float le_mask(const float& a, const float& b) {
194 return __int_as_float(a <= b ? 0xffffffffu : 0u);
195}
196
197EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double le_mask(const double& a, const double& b) {
198 return __longlong_as_double(a <= b ? 0xffffffffffffffffull : 0ull);
199}
200
201template <>
202EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pand<float4>(const float4& a, const float4& b) {
203 return make_float4(bitwise_and(a.x, b.x), bitwise_and(a.y, b.y), bitwise_and(a.z, b.z), bitwise_and(a.w, b.w));
204}
205template <>
206EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pand<double2>(const double2& a, const double2& b) {
207 return make_double2(bitwise_and(a.x, b.x), bitwise_and(a.y, b.y));
208}
209
210template <>
211EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 por<float4>(const float4& a, const float4& b) {
212 return make_float4(bitwise_or(a.x, b.x), bitwise_or(a.y, b.y), bitwise_or(a.z, b.z), bitwise_or(a.w, b.w));
213}
214template <>
215EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 por<double2>(const double2& a, const double2& b) {
216 return make_double2(bitwise_or(a.x, b.x), bitwise_or(a.y, b.y));
217}
218
219template <>
220EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pxor<float4>(const float4& a, const float4& b) {
221 return make_float4(bitwise_xor(a.x, b.x), bitwise_xor(a.y, b.y), bitwise_xor(a.z, b.z), bitwise_xor(a.w, b.w));
222}
223template <>
224EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pxor<double2>(const double2& a, const double2& b) {
225 return make_double2(bitwise_xor(a.x, b.x), bitwise_xor(a.y, b.y));
226}
227
228template <>
229EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pandnot<float4>(const float4& a, const float4& b) {
230 return make_float4(bitwise_andnot(a.x, b.x), bitwise_andnot(a.y, b.y), bitwise_andnot(a.z, b.z),
231 bitwise_andnot(a.w, b.w));
232}
233template <>
234EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pandnot<double2>(const double2& a, const double2& b) {
235 return make_double2(bitwise_andnot(a.x, b.x), bitwise_andnot(a.y, b.y));
236}
237
238template <>
239EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcmp_eq<float4>(const float4& a, const float4& b) {
240 return make_float4(eq_mask(a.x, b.x), eq_mask(a.y, b.y), eq_mask(a.z, b.z), eq_mask(a.w, b.w));
241}
242template <>
243EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcmp_lt<float4>(const float4& a, const float4& b) {
244 return make_float4(lt_mask(a.x, b.x), lt_mask(a.y, b.y), lt_mask(a.z, b.z), lt_mask(a.w, b.w));
245}
246template <>
247EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcmp_le<float4>(const float4& a, const float4& b) {
248 return make_float4(le_mask(a.x, b.x), le_mask(a.y, b.y), le_mask(a.z, b.z), le_mask(a.w, b.w));
249}
250template <>
251EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pcmp_eq<double2>(const double2& a, const double2& b) {
252 return make_double2(eq_mask(a.x, b.x), eq_mask(a.y, b.y));
253}
254template <>
255EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pcmp_lt<double2>(const double2& a, const double2& b) {
256 return make_double2(lt_mask(a.x, b.x), lt_mask(a.y, b.y));
257}
258template <>
259EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pcmp_le<double2>(const double2& a, const double2& b) {
260 return make_double2(le_mask(a.x, b.x), le_mask(a.y, b.y));
261}
262#endif // defined(EIGEN_CUDA_ARCH) || defined(EIGEN_HIPCC) || (defined(EIGEN_CUDACC) && EIGEN_COMP_CLANG &&
263 // !EIGEN_COMP_NVCC)
264
265template <>
266EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 plset<float4>(const float& a) {
267 return make_float4(a, a + 1, a + 2, a + 3);
268}
269template <>
270EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 plset<double2>(const double& a) {
271 return make_double2(a, a + 1);
272}
273
274template <>
275EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 padd<float4>(const float4& a, const float4& b) {
276 return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
277}
278template <>
279EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 padd<double2>(const double2& a, const double2& b) {
280 return make_double2(a.x + b.x, a.y + b.y);
281}
282
283template <>
284EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 psub<float4>(const float4& a, const float4& b) {
285 return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
286}
287template <>
288EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 psub<double2>(const double2& a, const double2& b) {
289 return make_double2(a.x - b.x, a.y - b.y);
290}
291
292template <>
293EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pnegate(const float4& a) {
294 return make_float4(-a.x, -a.y, -a.z, -a.w);
295}
296template <>
297EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pnegate(const double2& a) {
298 return make_double2(-a.x, -a.y);
299}
300
301template <>
302EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pconj(const float4& a) {
303 return a;
304}
305template <>
306EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pconj(const double2& a) {
307 return a;
308}
309
310template <>
311EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pmul<float4>(const float4& a, const float4& b) {
312 return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
313}
314template <>
315EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pmul<double2>(const double2& a, const double2& b) {
316 return make_double2(a.x * b.x, a.y * b.y);
317}
318
319template <>
320EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pdiv<float4>(const float4& a, const float4& b) {
321 return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
322}
323template <>
324EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pdiv<double2>(const double2& a, const double2& b) {
325 return make_double2(a.x / b.x, a.y / b.y);
326}
327
328template <>
329EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pmin<float4>(const float4& a, const float4& b) {
330 return make_float4(fminf(a.x, b.x), fminf(a.y, b.y), fminf(a.z, b.z), fminf(a.w, b.w));
331}
332template <>
333EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pmin<double2>(const double2& a, const double2& b) {
334 return make_double2(fmin(a.x, b.x), fmin(a.y, b.y));
335}
336
337template <>
338EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pmax<float4>(const float4& a, const float4& b) {
339 return make_float4(fmaxf(a.x, b.x), fmaxf(a.y, b.y), fmaxf(a.z, b.z), fmaxf(a.w, b.w));
340}
341template <>
342EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pmax<double2>(const double2& a, const double2& b) {
343 return make_double2(fmax(a.x, b.x), fmax(a.y, b.y));
344}
345
346template <>
347EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pload<float4>(const float* from) {
348 return *reinterpret_cast<const float4*>(from);
349}
350
351template <>
352EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pload<double2>(const double* from) {
353 return *reinterpret_cast<const double2*>(from);
354}
355
356template <>
357EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 ploadu<float4>(const float* from) {
358 return make_float4(from[0], from[1], from[2], from[3]);
359}
360template <>
361EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 ploadu<double2>(const double* from) {
362 return make_double2(from[0], from[1]);
363}
364
365template <>
366EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 ploaddup<float4>(const float* from) {
367 return make_float4(from[0], from[0], from[1], from[1]);
368}
369template <>
370EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 ploaddup<double2>(const double* from) {
371 return make_double2(from[0], from[0]);
372}
373
374template <>
375EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore<float>(float* to, const float4& from) {
376 *reinterpret_cast<float4*>(to) = from;
377}
378
379template <>
380EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore<double>(double* to, const double2& from) {
381 *reinterpret_cast<double2*>(to) = from;
382}
383
384template <>
385EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const float4& from) {
386 to[0] = from.x;
387 to[1] = from.y;
388 to[2] = from.z;
389 to[3] = from.w;
390}
391
392template <>
393EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const double2& from) {
394 to[0] = from.x;
395 to[1] = from.y;
396}
397
398template <>
399EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float4 ploadt_ro<float4, Aligned>(const float* from) {
400#if defined(EIGEN_GPU_HAS_LDG)
401 return __ldg(reinterpret_cast<const float4*>(from));
402#else
403 return make_float4(from[0], from[1], from[2], from[3]);
404#endif
405}
406template <>
407EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double2 ploadt_ro<double2, Aligned>(const double* from) {
408#if defined(EIGEN_GPU_HAS_LDG)
409 return __ldg(reinterpret_cast<const double2*>(from));
410#else
411 return make_double2(from[0], from[1]);
412#endif
413}
414
415template <>
416EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float4 ploadt_ro<float4, Unaligned>(const float* from) {
417#if defined(EIGEN_GPU_HAS_LDG)
418 return make_float4(__ldg(from + 0), __ldg(from + 1), __ldg(from + 2), __ldg(from + 3));
419#else
420 return make_float4(from[0], from[1], from[2], from[3]);
421#endif
422}
423template <>
424EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double2 ploadt_ro<double2, Unaligned>(const double* from) {
425#if defined(EIGEN_GPU_HAS_LDG)
426 return make_double2(__ldg(from + 0), __ldg(from + 1));
427#else
428 return make_double2(from[0], from[1]);
429#endif
430}
431
432template <>
433EIGEN_DEVICE_FUNC inline float4 pgather<float, float4>(const float* from, Index stride) {
434 return make_float4(from[0 * stride], from[1 * stride], from[2 * stride], from[3 * stride]);
435}
436
437template <>
438EIGEN_DEVICE_FUNC inline double2 pgather<double, double2>(const double* from, Index stride) {
439 return make_double2(from[0 * stride], from[1 * stride]);
440}
441
442template <>
443EIGEN_DEVICE_FUNC inline void pscatter<float, float4>(float* to, const float4& from, Index stride) {
444 to[stride * 0] = from.x;
445 to[stride * 1] = from.y;
446 to[stride * 2] = from.z;
447 to[stride * 3] = from.w;
448}
449template <>
450EIGEN_DEVICE_FUNC inline void pscatter<double, double2>(double* to, const double2& from, Index stride) {
451 to[stride * 0] = from.x;
452 to[stride * 1] = from.y;
453}
454
455template <>
456EIGEN_DEVICE_FUNC inline float pfirst<float4>(const float4& a) {
457 return a.x;
458}
459template <>
460EIGEN_DEVICE_FUNC inline double pfirst<double2>(const double2& a) {
461 return a.x;
462}
463
464template <>
465EIGEN_DEVICE_FUNC inline float predux<float4>(const float4& a) {
466 return a.x + a.y + a.z + a.w;
467}
468template <>
469EIGEN_DEVICE_FUNC inline double predux<double2>(const double2& a) {
470 return a.x + a.y;
471}
472
473template <>
474EIGEN_DEVICE_FUNC inline float predux_max<float4>(const float4& a) {
475 return fmaxf(fmaxf(a.x, a.y), fmaxf(a.z, a.w));
476}
477template <>
478EIGEN_DEVICE_FUNC inline double predux_max<double2>(const double2& a) {
479 return fmax(a.x, a.y);
480}
481
482template <>
483EIGEN_DEVICE_FUNC inline float predux_min<float4>(const float4& a) {
484 return fminf(fminf(a.x, a.y), fminf(a.z, a.w));
485}
486template <>
487EIGEN_DEVICE_FUNC inline double predux_min<double2>(const double2& a) {
488 return fmin(a.x, a.y);
489}
490
491template <>
492EIGEN_DEVICE_FUNC inline float predux_mul<float4>(const float4& a) {
493 return a.x * a.y * a.z * a.w;
494}
495template <>
496EIGEN_DEVICE_FUNC inline double predux_mul<double2>(const double2& a) {
497 return a.x * a.y;
498}
499
500template <>
501EIGEN_DEVICE_FUNC inline float4 pabs<float4>(const float4& a) {
502 return make_float4(fabsf(a.x), fabsf(a.y), fabsf(a.z), fabsf(a.w));
503}
504template <>
505EIGEN_DEVICE_FUNC inline double2 pabs<double2>(const double2& a) {
506 return make_double2(fabs(a.x), fabs(a.y));
507}
508
509template <>
510EIGEN_DEVICE_FUNC inline float4 pfloor<float4>(const float4& a) {
511 return make_float4(floorf(a.x), floorf(a.y), floorf(a.z), floorf(a.w));
512}
513template <>
514EIGEN_DEVICE_FUNC inline double2 pfloor<double2>(const double2& a) {
515 return make_double2(floor(a.x), floor(a.y));
516}
517
518template <>
519EIGEN_DEVICE_FUNC inline float4 pceil<float4>(const float4& a) {
520 return make_float4(ceilf(a.x), ceilf(a.y), ceilf(a.z), ceilf(a.w));
521}
522template <>
523EIGEN_DEVICE_FUNC inline double2 pceil<double2>(const double2& a) {
524 return make_double2(ceil(a.x), ceil(a.y));
525}
526
527template <>
528EIGEN_DEVICE_FUNC inline float4 print<float4>(const float4& a) {
529 return make_float4(rintf(a.x), rintf(a.y), rintf(a.z), rintf(a.w));
530}
531template <>
532EIGEN_DEVICE_FUNC inline double2 print<double2>(const double2& a) {
533 return make_double2(rint(a.x), rint(a.y));
534}
535
536template <>
537EIGEN_DEVICE_FUNC inline float4 ptrunc<float4>(const float4& a) {
538 return make_float4(truncf(a.x), truncf(a.y), truncf(a.z), truncf(a.w));
539}
540template <>
541EIGEN_DEVICE_FUNC inline double2 ptrunc<double2>(const double2& a) {
542 return make_double2(trunc(a.x), trunc(a.y));
543}
544
545EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<float4, 4>& kernel) {
546 float tmp = kernel.packet[0].y;
547 kernel.packet[0].y = kernel.packet[1].x;
548 kernel.packet[1].x = tmp;
549
550 tmp = kernel.packet[0].z;
551 kernel.packet[0].z = kernel.packet[2].x;
552 kernel.packet[2].x = tmp;
553
554 tmp = kernel.packet[0].w;
555 kernel.packet[0].w = kernel.packet[3].x;
556 kernel.packet[3].x = tmp;
557
558 tmp = kernel.packet[1].z;
559 kernel.packet[1].z = kernel.packet[2].y;
560 kernel.packet[2].y = tmp;
561
562 tmp = kernel.packet[1].w;
563 kernel.packet[1].w = kernel.packet[3].y;
564 kernel.packet[3].y = tmp;
565
566 tmp = kernel.packet[2].w;
567 kernel.packet[2].w = kernel.packet[3].z;
568 kernel.packet[3].z = tmp;
569}
570
571EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<double2, 2>& kernel) {
572 double tmp = kernel.packet[0].y;
573 kernel.packet[0].y = kernel.packet[1].x;
574 kernel.packet[1].x = tmp;
575}
576
577#endif // defined(EIGEN_GPUCC) && defined(EIGEN_USE_GPU)
578
// Half-packet functions are not available on the host for CUDA 9.0-9.2, only
// on the device. There is no benefit to using them on the host anyway, since
// they are emulated there.
582#if (defined(EIGEN_HAS_CUDA_FP16) || defined(EIGEN_HAS_HIP_FP16)) && defined(EIGEN_GPU_COMPILE_PHASE)
583
584typedef ulonglong2 Packet4h2;
585template <>
586struct unpacket_traits<Packet4h2> {
587 typedef Eigen::half type;
588 enum {
589 size = 8,
590 alignment = Aligned16,
591 vectorizable = true,
592 masked_load_available = false,
593 masked_store_available = false
594 };
595 typedef Packet4h2 half;
596};
597template <>
598struct is_arithmetic<Packet4h2> {
599 enum { value = true };
600};
601
602template <>
603struct unpacket_traits<half2> {
604 typedef Eigen::half type;
605 enum {
606 size = 2,
607 alignment = Aligned16,
608 vectorizable = true,
609 masked_load_available = false,
610 masked_store_available = false
611 };
612 typedef half2 half;
613};
614template <>
615struct is_arithmetic<half2> {
616 enum { value = true };
617};
618
619template <>
620struct packet_traits<Eigen::half> : default_packet_traits {
621 typedef Packet4h2 type;
622 typedef Packet4h2 half;
623 enum {
624 Vectorizable = 1,
625 AlignedOnScalar = 1,
626 size = 8,
627 HasAdd = 1,
628 HasSub = 1,
629 HasMul = 1,
630 HasDiv = 1,
631 HasSqrt = 1,
632 HasRsqrt = 1,
633 HasExp = 1,
634 HasExpm1 = 1,
635 HasLog = 1,
636 HasLog1p = 1
637 };
638};
639
640template <>
641EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pset1<half2>(const Eigen::half& from) {
642 return __half2half2(from);
643}
644
645template <>
646EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pset1<Packet4h2>(const Eigen::half& from) {
647 Packet4h2 r;
648 half2* p_alias = reinterpret_cast<half2*>(&r);
649 p_alias[0] = pset1<half2>(from);
650 p_alias[1] = pset1<half2>(from);
651 p_alias[2] = pset1<half2>(from);
652 p_alias[3] = pset1<half2>(from);
653 return r;
654}
655
656namespace {
657
658EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pload(const Eigen::half* from) {
659 return *reinterpret_cast<const half2*>(from);
660}
661
662EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 ploadu(const Eigen::half* from) { return __halves2half2(from[0], from[1]); }
663
664EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 ploaddup(const Eigen::half* from) {
665 return __halves2half2(from[0], from[0]);
666}
667
668EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore(Eigen::half* to, const half2& from) {
669 *reinterpret_cast<half2*>(to) = from;
670}
671
672EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu(Eigen::half* to, const half2& from) {
673 to[0] = __low2half(from);
674 to[1] = __high2half(from);
675}
676
677EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE half2 ploadt_ro_aligned(const Eigen::half* from) {
678#if defined(EIGEN_GPU_HAS_LDG)
679 // Input is guaranteed to be properly aligned.
680 return __ldg(reinterpret_cast<const half2*>(from));
681#else
682 return __halves2half2(*(from + 0), *(from + 1));
683#endif
684}
685
686EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE half2 ploadt_ro_unaligned(const Eigen::half* from) {
687#if defined(EIGEN_GPU_HAS_LDG)
688 return __halves2half2(__ldg(from + 0), __ldg(from + 1));
689#else
690 return __halves2half2(*(from + 0), *(from + 1));
691#endif
692}
693
694EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pgather(const Eigen::half* from, Index stride) {
695 return __halves2half2(from[0 * stride], from[1 * stride]);
696}
697
698EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(Eigen::half* to, const half2& from, Index stride) {
699 to[stride * 0] = __low2half(from);
700 to[stride * 1] = __high2half(from);
701}
702
703EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half pfirst(const half2& a) { return __low2half(a); }
704
705EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pabs(const half2& a) {
706 half a1 = __low2half(a);
707 half a2 = __high2half(a);
708 half result1 = half_impl::raw_uint16_to_half(a1.x & 0x7FFF);
709 half result2 = half_impl::raw_uint16_to_half(a2.x & 0x7FFF);
710 return __halves2half2(result1, result2);
711}
712
713EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 ptrue(const half2& /*a*/) {
714 half true_half = half_impl::raw_uint16_to_half(0xffffu);
715 return pset1<half2>(true_half);
716}
717
718EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pzero(const half2& /*a*/) {
719 half false_half = half_impl::raw_uint16_to_half(0x0000u);
720 return pset1<half2>(false_half);
721}
722
723EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<half2, 2>& kernel) {
724 __half a1 = __low2half(kernel.packet[0]);
725 __half a2 = __high2half(kernel.packet[0]);
726 __half b1 = __low2half(kernel.packet[1]);
727 __half b2 = __high2half(kernel.packet[1]);
728 kernel.packet[0] = __halves2half2(a1, b1);
729 kernel.packet[1] = __halves2half2(a2, b2);
730}
731
732EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plset(const Eigen::half& a) {
733#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
734 return __halves2half2(a, __hadd(a, __float2half(1.0f)));
735#else
736 float f = __half2float(a) + 1.0f;
737 return __halves2half2(a, __float2half(f));
738#endif
739}
740
741EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pselect(const half2& mask, const half2& a, const half2& b) {
742 half mask_low = __low2half(mask);
743 half mask_high = __high2half(mask);
744 half result_low = mask_low == half(0) ? __low2half(b) : __low2half(a);
745 half result_high = mask_high == half(0) ? __high2half(b) : __high2half(a);
746 return __halves2half2(result_low, result_high);
747}
748
749EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_eq(const half2& a, const half2& b) {
750 half true_half = half_impl::raw_uint16_to_half(0xffffu);
751 half false_half = half_impl::raw_uint16_to_half(0x0000u);
752 half a1 = __low2half(a);
753 half a2 = __high2half(a);
754 half b1 = __low2half(b);
755 half b2 = __high2half(b);
756 half eq1 = __half2float(a1) == __half2float(b1) ? true_half : false_half;
757 half eq2 = __half2float(a2) == __half2float(b2) ? true_half : false_half;
758 return __halves2half2(eq1, eq2);
759}
760
761EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_lt(const half2& a, const half2& b) {
762 half true_half = half_impl::raw_uint16_to_half(0xffffu);
763 half false_half = half_impl::raw_uint16_to_half(0x0000u);
764 half a1 = __low2half(a);
765 half a2 = __high2half(a);
766 half b1 = __low2half(b);
767 half b2 = __high2half(b);
768 half eq1 = __half2float(a1) < __half2float(b1) ? true_half : false_half;
769 half eq2 = __half2float(a2) < __half2float(b2) ? true_half : false_half;
770 return __halves2half2(eq1, eq2);
771}
772
773EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_le(const half2& a, const half2& b) {
774 half true_half = half_impl::raw_uint16_to_half(0xffffu);
775 half false_half = half_impl::raw_uint16_to_half(0x0000u);
776 half a1 = __low2half(a);
777 half a2 = __high2half(a);
778 half b1 = __low2half(b);
779 half b2 = __high2half(b);
780 half eq1 = __half2float(a1) <= __half2float(b1) ? true_half : false_half;
781 half eq2 = __half2float(a2) <= __half2float(b2) ? true_half : false_half;
782 return __halves2half2(eq1, eq2);
783}
784
785EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pand(const half2& a, const half2& b) {
786 half a1 = __low2half(a);
787 half a2 = __high2half(a);
788 half b1 = __low2half(b);
789 half b2 = __high2half(b);
790 half result1 = half_impl::raw_uint16_to_half(a1.x & b1.x);
791 half result2 = half_impl::raw_uint16_to_half(a2.x & b2.x);
792 return __halves2half2(result1, result2);
793}
794
795EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 por(const half2& a, const half2& b) {
796 half a1 = __low2half(a);
797 half a2 = __high2half(a);
798 half b1 = __low2half(b);
799 half b2 = __high2half(b);
800 half result1 = half_impl::raw_uint16_to_half(a1.x | b1.x);
801 half result2 = half_impl::raw_uint16_to_half(a2.x | b2.x);
802 return __halves2half2(result1, result2);
803}
804
805EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pxor(const half2& a, const half2& b) {
806 half a1 = __low2half(a);
807 half a2 = __high2half(a);
808 half b1 = __low2half(b);
809 half b2 = __high2half(b);
810 half result1 = half_impl::raw_uint16_to_half(a1.x ^ b1.x);
811 half result2 = half_impl::raw_uint16_to_half(a2.x ^ b2.x);
812 return __halves2half2(result1, result2);
813}
814
815EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pandnot(const half2& a, const half2& b) {
816 half a1 = __low2half(a);
817 half a2 = __high2half(a);
818 half b1 = __low2half(b);
819 half b2 = __high2half(b);
820 half result1 = half_impl::raw_uint16_to_half(a1.x & ~b1.x);
821 half result2 = half_impl::raw_uint16_to_half(a2.x & ~b2.x);
822 return __halves2half2(result1, result2);
823}
824
825EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 padd(const half2& a, const half2& b) {
826#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
827 return __hadd2(a, b);
828#else
829 float a1 = __low2float(a);
830 float a2 = __high2float(a);
831 float b1 = __low2float(b);
832 float b2 = __high2float(b);
833 float r1 = a1 + b1;
834 float r2 = a2 + b2;
835 return __floats2half2_rn(r1, r2);
836#endif
837}
838
839EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 psub(const half2& a, const half2& b) {
840#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
841 return __hsub2(a, b);
842#else
843 float a1 = __low2float(a);
844 float a2 = __high2float(a);
845 float b1 = __low2float(b);
846 float b2 = __high2float(b);
847 float r1 = a1 - b1;
848 float r2 = a2 - b2;
849 return __floats2half2_rn(r1, r2);
850#endif
851}
852
853EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pnegate(const half2& a) {
854#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
855 return __hneg2(a);
856#else
857 float a1 = __low2float(a);
858 float a2 = __high2float(a);
859 return __floats2half2_rn(-a1, -a2);
860#endif
861}
862
863EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pconj(const half2& a) { return a; }
864
865EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmul(const half2& a, const half2& b) {
866#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
867 return __hmul2(a, b);
868#else
869 float a1 = __low2float(a);
870 float a2 = __high2float(a);
871 float b1 = __low2float(b);
872 float b2 = __high2float(b);
873 float r1 = a1 * b1;
874 float r2 = a2 * b2;
875 return __floats2half2_rn(r1, r2);
876#endif
877}
878
879EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmadd(const half2& a, const half2& b, const half2& c) {
880#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
881 return __hfma2(a, b, c);
882#else
883 float a1 = __low2float(a);
884 float a2 = __high2float(a);
885 float b1 = __low2float(b);
886 float b2 = __high2float(b);
887 float c1 = __low2float(c);
888 float c2 = __high2float(c);
889 float r1 = a1 * b1 + c1;
890 float r2 = a2 * b2 + c2;
891 return __floats2half2_rn(r1, r2);
892#endif
893}
894
895EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pdiv(const half2& a, const half2& b) {
896#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
897 return __h2div(a, b);
898#else
899 float a1 = __low2float(a);
900 float a2 = __high2float(a);
901 float b1 = __low2float(b);
902 float b2 = __high2float(b);
903 float r1 = a1 / b1;
904 float r2 = a2 / b2;
905 return __floats2half2_rn(r1, r2);
906#endif
907}
908
909EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmin(const half2& a, const half2& b) {
910 float a1 = __low2float(a);
911 float a2 = __high2float(a);
912 float b1 = __low2float(b);
913 float b2 = __high2float(b);
914 __half r1 = a1 < b1 ? __low2half(a) : __low2half(b);
915 __half r2 = a2 < b2 ? __high2half(a) : __high2half(b);
916 return __halves2half2(r1, r2);
917}
918
919EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmax(const half2& a, const half2& b) {
920 float a1 = __low2float(a);
921 float a2 = __high2float(a);
922 float b1 = __low2float(b);
923 float b2 = __high2float(b);
924 __half r1 = a1 > b1 ? __low2half(a) : __low2half(b);
925 __half r2 = a2 > b2 ? __high2half(a) : __high2half(b);
926 return __halves2half2(r1, r2);
927}
928
929EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux(const half2& a) {
930#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
931 return __hadd(__low2half(a), __high2half(a));
932#else
933 float a1 = __low2float(a);
934 float a2 = __high2float(a);
935 return Eigen::half(__float2half(a1 + a2));
936#endif
937}
938
939EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_max(const half2& a) {
940#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
941 __half first = __low2half(a);
942 __half second = __high2half(a);
943 return __hgt(first, second) ? first : second;
944#else
945 float a1 = __low2float(a);
946 float a2 = __high2float(a);
947 return a1 > a2 ? __low2half(a) : __high2half(a);
948#endif
949}
950
951EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_min(const half2& a) {
952#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
953 __half first = __low2half(a);
954 __half second = __high2half(a);
955 return __hlt(first, second) ? first : second;
956#else
957 float a1 = __low2float(a);
958 float a2 = __high2float(a);
959 return a1 < a2 ? __low2half(a) : __high2half(a);
960#endif
961}
962
963EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_mul(const half2& a) {
964#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
965 return __hmul(__low2half(a), __high2half(a));
966#else
967 float a1 = __low2float(a);
968 float a2 = __high2float(a);
969 return Eigen::half(__float2half(a1 * a2));
970#endif
971}
972
973EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plog1p(const half2& a) {
974 float a1 = __low2float(a);
975 float a2 = __high2float(a);
976 float r1 = log1pf(a1);
977 float r2 = log1pf(a2);
978 return __floats2half2_rn(r1, r2);
979}
980
981EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pexpm1(const half2& a) {
982 float a1 = __low2float(a);
983 float a2 = __high2float(a);
984 float r1 = expm1f(a1);
985 float r2 = expm1f(a2);
986 return __floats2half2_rn(r1, r2);
987}
988
989#if (EIGEN_CUDA_SDK_VER >= 80000 && defined(EIGEN_CUDA_HAS_FP16_ARITHMETIC)) || defined(EIGEN_HIP_DEVICE_COMPILE)
990
991EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plog(const half2& a) { return h2log(a); }
992
993EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pexp(const half2& a) { return h2exp(a); }
994
995EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 psqrt(const half2& a) { return h2sqrt(a); }
996
997EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 prsqrt(const half2& a) { return h2rsqrt(a); }
998
999#else
1000
1001EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plog(const half2& a) {
1002 float a1 = __low2float(a);
1003 float a2 = __high2float(a);
1004 float r1 = logf(a1);
1005 float r2 = logf(a2);
1006 return __floats2half2_rn(r1, r2);
1007}
1008
1009EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pexp(const half2& a) {
1010 float a1 = __low2float(a);
1011 float a2 = __high2float(a);
1012 float r1 = expf(a1);
1013 float r2 = expf(a2);
1014 return __floats2half2_rn(r1, r2);
1015}
1016
1017EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 psqrt(const half2& a) {
1018 float a1 = __low2float(a);
1019 float a2 = __high2float(a);
1020 float r1 = sqrtf(a1);
1021 float r2 = sqrtf(a2);
1022 return __floats2half2_rn(r1, r2);
1023}
1024
1025EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 prsqrt(const half2& a) {
1026 float a1 = __low2float(a);
1027 float a2 = __high2float(a);
1028 float r1 = rsqrtf(a1);
1029 float r2 = rsqrtf(a2);
1030 return __floats2half2_rn(r1, r2);
1031}
1032#endif
1033} // namespace
1034
1035template <>
1036EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pload<Packet4h2>(const Eigen::half* from) {
1037 return *reinterpret_cast<const Packet4h2*>(from);
1038}
1039
1040// unaligned load;
1041template <>
1042EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 ploadu<Packet4h2>(const Eigen::half* from) {
1043 Packet4h2 r;
1044 half2* p_alias = reinterpret_cast<half2*>(&r);
1045 p_alias[0] = ploadu(from + 0);
1046 p_alias[1] = ploadu(from + 2);
1047 p_alias[2] = ploadu(from + 4);
1048 p_alias[3] = ploadu(from + 6);
1049 return r;
1050}
1051
1052template <>
1053EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 ploaddup<Packet4h2>(const Eigen::half* from) {
1054 Packet4h2 r;
1055 half2* p_alias = reinterpret_cast<half2*>(&r);
1056 p_alias[0] = ploaddup(from + 0);
1057 p_alias[1] = ploaddup(from + 1);
1058 p_alias[2] = ploaddup(from + 2);
1059 p_alias[3] = ploaddup(from + 3);
1060 return r;
1061}
1062
1063template <>
1064EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet4h2& from) {
1065 *reinterpret_cast<Packet4h2*>(to) = from;
1066}
1067
1068template <>
1069EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet4h2& from) {
1070 const half2* from_alias = reinterpret_cast<const half2*>(&from);
1071 pstoreu(to + 0, from_alias[0]);
1072 pstoreu(to + 2, from_alias[1]);
1073 pstoreu(to + 4, from_alias[2]);
1074 pstoreu(to + 6, from_alias[3]);
1075}
1076
1077template <>
1078EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet4h2 ploadt_ro<Packet4h2, Aligned>(const Eigen::half* from) {
1079#if defined(EIGEN_GPU_HAS_LDG)
1080 Packet4h2 r;
1081 r = __ldg(reinterpret_cast<const Packet4h2*>(from));
1082 return r;
1083#else
1084 Packet4h2 r;
1085 half2* r_alias = reinterpret_cast<half2*>(&r);
1086 r_alias[0] = ploadt_ro_aligned(from + 0);
1087 r_alias[1] = ploadt_ro_aligned(from + 2);
1088 r_alias[2] = ploadt_ro_aligned(from + 4);
1089 r_alias[3] = ploadt_ro_aligned(from + 6);
1090 return r;
1091#endif
1092}
1093
1094template <>
1095EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet4h2 ploadt_ro<Packet4h2, Unaligned>(const Eigen::half* from) {
1096 Packet4h2 r;
1097 half2* r_alias = reinterpret_cast<half2*>(&r);
1098 r_alias[0] = ploadt_ro_unaligned(from + 0);
1099 r_alias[1] = ploadt_ro_unaligned(from + 2);
1100 r_alias[2] = ploadt_ro_unaligned(from + 4);
1101 r_alias[3] = ploadt_ro_unaligned(from + 6);
1102 return r;
1103}
1104
1105template <>
1106EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pgather<Eigen::half, Packet4h2>(const Eigen::half* from, Index stride) {
1107 Packet4h2 r;
1108 half2* p_alias = reinterpret_cast<half2*>(&r);
1109 p_alias[0] = __halves2half2(from[0 * stride], from[1 * stride]);
1110 p_alias[1] = __halves2half2(from[2 * stride], from[3 * stride]);
1111 p_alias[2] = __halves2half2(from[4 * stride], from[5 * stride]);
1112 p_alias[3] = __halves2half2(from[6 * stride], from[7 * stride]);
1113 return r;
1114}
1115
1116template <>
1117EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet4h2>(Eigen::half* to, const Packet4h2& from,
1118 Index stride) {
1119 const half2* from_alias = reinterpret_cast<const half2*>(&from);
1120 pscatter(to + stride * 0, from_alias[0], stride);
1121 pscatter(to + stride * 2, from_alias[1], stride);
1122 pscatter(to + stride * 4, from_alias[2], stride);
1123 pscatter(to + stride * 6, from_alias[3], stride);
1124}
1125
1126template <>
1127EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half pfirst<Packet4h2>(const Packet4h2& a) {
1128 return pfirst(*(reinterpret_cast<const half2*>(&a)));
1129}
1130
1131template <>
1132EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pabs<Packet4h2>(const Packet4h2& a) {
1133 Packet4h2 r;
1134 half2* p_alias = reinterpret_cast<half2*>(&r);
1135 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1136 p_alias[0] = pabs(a_alias[0]);
1137 p_alias[1] = pabs(a_alias[1]);
1138 p_alias[2] = pabs(a_alias[2]);
1139 p_alias[3] = pabs(a_alias[3]);
1140 return r;
1141}
1142
1143template <>
1144EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 ptrue<Packet4h2>(const Packet4h2& /*a*/) {
1145 half true_half = half_impl::raw_uint16_to_half(0xffffu);
1146 return pset1<Packet4h2>(true_half);
1147}
1148
1149template <>
1150EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pzero<Packet4h2>(const Packet4h2& /*a*/) {
1151 half false_half = half_impl::raw_uint16_to_half(0x0000u);
1152 return pset1<Packet4h2>(false_half);
1153}
1154
1155EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose_double(double* d_row0, double* d_row1, double* d_row2,
1156 double* d_row3, double* d_row4, double* d_row5,
1157 double* d_row6, double* d_row7) {
1158 double d_tmp;
1159 d_tmp = d_row0[1];
1160 d_row0[1] = d_row4[0];
1161 d_row4[0] = d_tmp;
1162
1163 d_tmp = d_row1[1];
1164 d_row1[1] = d_row5[0];
1165 d_row5[0] = d_tmp;
1166
1167 d_tmp = d_row2[1];
1168 d_row2[1] = d_row6[0];
1169 d_row6[0] = d_tmp;
1170
1171 d_tmp = d_row3[1];
1172 d_row3[1] = d_row7[0];
1173 d_row7[0] = d_tmp;
1174}
1175
1176EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose_half2(half2* f_row0, half2* f_row1, half2* f_row2,
1177 half2* f_row3) {
1178 half2 f_tmp;
1179 f_tmp = f_row0[1];
1180 f_row0[1] = f_row2[0];
1181 f_row2[0] = f_tmp;
1182
1183 f_tmp = f_row1[1];
1184 f_row1[1] = f_row3[0];
1185 f_row3[0] = f_tmp;
1186}
1187
1188EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose_half(half2& f0, half2& f1) {
1189 __half a1 = __low2half(f0);
1190 __half a2 = __high2half(f0);
1191 __half b1 = __low2half(f1);
1192 __half b2 = __high2half(f1);
1193 f0 = __halves2half2(a1, b1);
1194 f1 = __halves2half2(a2, b2);
1195}
1196
1197EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4h2, 8>& kernel) {
1198 double* d_row0 = reinterpret_cast<double*>(&kernel.packet[0]);
1199 double* d_row1 = reinterpret_cast<double*>(&kernel.packet[1]);
1200 double* d_row2 = reinterpret_cast<double*>(&kernel.packet[2]);
1201 double* d_row3 = reinterpret_cast<double*>(&kernel.packet[3]);
1202 double* d_row4 = reinterpret_cast<double*>(&kernel.packet[4]);
1203 double* d_row5 = reinterpret_cast<double*>(&kernel.packet[5]);
1204 double* d_row6 = reinterpret_cast<double*>(&kernel.packet[6]);
1205 double* d_row7 = reinterpret_cast<double*>(&kernel.packet[7]);
1206 ptranspose_double(d_row0, d_row1, d_row2, d_row3, d_row4, d_row5, d_row6, d_row7);
1207
1208 half2* f_row0 = reinterpret_cast<half2*>(d_row0);
1209 half2* f_row1 = reinterpret_cast<half2*>(d_row1);
1210 half2* f_row2 = reinterpret_cast<half2*>(d_row2);
1211 half2* f_row3 = reinterpret_cast<half2*>(d_row3);
1212 ptranspose_half2(f_row0, f_row1, f_row2, f_row3);
1213 ptranspose_half(f_row0[0], f_row1[0]);
1214 ptranspose_half(f_row0[1], f_row1[1]);
1215 ptranspose_half(f_row2[0], f_row3[0]);
1216 ptranspose_half(f_row2[1], f_row3[1]);
1217
1218 f_row0 = reinterpret_cast<half2*>(d_row0 + 1);
1219 f_row1 = reinterpret_cast<half2*>(d_row1 + 1);
1220 f_row2 = reinterpret_cast<half2*>(d_row2 + 1);
1221 f_row3 = reinterpret_cast<half2*>(d_row3 + 1);
1222 ptranspose_half2(f_row0, f_row1, f_row2, f_row3);
1223 ptranspose_half(f_row0[0], f_row1[0]);
1224 ptranspose_half(f_row0[1], f_row1[1]);
1225 ptranspose_half(f_row2[0], f_row3[0]);
1226 ptranspose_half(f_row2[1], f_row3[1]);
1227
1228 f_row0 = reinterpret_cast<half2*>(d_row4);
1229 f_row1 = reinterpret_cast<half2*>(d_row5);
1230 f_row2 = reinterpret_cast<half2*>(d_row6);
1231 f_row3 = reinterpret_cast<half2*>(d_row7);
1232 ptranspose_half2(f_row0, f_row1, f_row2, f_row3);
1233 ptranspose_half(f_row0[0], f_row1[0]);
1234 ptranspose_half(f_row0[1], f_row1[1]);
1235 ptranspose_half(f_row2[0], f_row3[0]);
1236 ptranspose_half(f_row2[1], f_row3[1]);
1237
1238 f_row0 = reinterpret_cast<half2*>(d_row4 + 1);
1239 f_row1 = reinterpret_cast<half2*>(d_row5 + 1);
1240 f_row2 = reinterpret_cast<half2*>(d_row6 + 1);
1241 f_row3 = reinterpret_cast<half2*>(d_row7 + 1);
1242 ptranspose_half2(f_row0, f_row1, f_row2, f_row3);
1243 ptranspose_half(f_row0[0], f_row1[0]);
1244 ptranspose_half(f_row0[1], f_row1[1]);
1245 ptranspose_half(f_row2[0], f_row3[0]);
1246 ptranspose_half(f_row2[1], f_row3[1]);
1247}
1248
1249template <>
1250EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 plset<Packet4h2>(const Eigen::half& a) {
1251#if defined(EIGEN_HIP_DEVICE_COMPILE)
1252
1253 Packet4h2 r;
1254 half2* p_alias = reinterpret_cast<half2*>(&r);
1255 p_alias[0] = __halves2half2(a, __hadd(a, __float2half(1.0f)));
1256 p_alias[1] = __halves2half2(__hadd(a, __float2half(2.0f)), __hadd(a, __float2half(3.0f)));
1257 p_alias[2] = __halves2half2(__hadd(a, __float2half(4.0f)), __hadd(a, __float2half(5.0f)));
1258 p_alias[3] = __halves2half2(__hadd(a, __float2half(6.0f)), __hadd(a, __float2half(7.0f)));
1259 return r;
1260#elif defined(EIGEN_CUDA_HAS_FP16_ARITHMETIC)
1261 Packet4h2 r;
1262 half2* r_alias = reinterpret_cast<half2*>(&r);
1263
1264 half2 b = pset1<half2>(a);
1265 half2 c;
1266 half2 half_offset0 = __halves2half2(__float2half(0.0f), __float2half(2.0f));
1267 half2 half_offset1 = __halves2half2(__float2half(4.0f), __float2half(6.0f));
1268
1269 c = __hadd2(b, half_offset0);
1270 r_alias[0] = plset(__low2half(c));
1271 r_alias[1] = plset(__high2half(c));
1272
1273 c = __hadd2(b, half_offset1);
1274 r_alias[2] = plset(__low2half(c));
1275 r_alias[3] = plset(__high2half(c));
1276
1277 return r;
1278
1279#else
1280 float f = __half2float(a);
1281 Packet4h2 r;
1282 half2* p_alias = reinterpret_cast<half2*>(&r);
1283 p_alias[0] = __halves2half2(a, __float2half(f + 1.0f));
1284 p_alias[1] = __halves2half2(__float2half(f + 2.0f), __float2half(f + 3.0f));
1285 p_alias[2] = __halves2half2(__float2half(f + 4.0f), __float2half(f + 5.0f));
1286 p_alias[3] = __halves2half2(__float2half(f + 6.0f), __float2half(f + 7.0f));
1287 return r;
1288#endif
1289}
1290
1291template <>
1292EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pselect<Packet4h2>(const Packet4h2& mask, const Packet4h2& a,
1293 const Packet4h2& b) {
1294 Packet4h2 r;
1295 half2* r_alias = reinterpret_cast<half2*>(&r);
1296 const half2* mask_alias = reinterpret_cast<const half2*>(&mask);
1297 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1298 const half2* b_alias = reinterpret_cast<const half2*>(&b);
1299 r_alias[0] = pselect(mask_alias[0], a_alias[0], b_alias[0]);
1300 r_alias[1] = pselect(mask_alias[1], a_alias[1], b_alias[1]);
1301 r_alias[2] = pselect(mask_alias[2], a_alias[2], b_alias[2]);
1302 r_alias[3] = pselect(mask_alias[3], a_alias[3], b_alias[3]);
1303 return r;
1304}
1305
1306template <>
1307EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pcmp_eq<Packet4h2>(const Packet4h2& a, const Packet4h2& b) {
1308 Packet4h2 r;
1309 half2* r_alias = reinterpret_cast<half2*>(&r);
1310 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1311 const half2* b_alias = reinterpret_cast<const half2*>(&b);
1312 r_alias[0] = pcmp_eq(a_alias[0], b_alias[0]);
1313 r_alias[1] = pcmp_eq(a_alias[1], b_alias[1]);
1314 r_alias[2] = pcmp_eq(a_alias[2], b_alias[2]);
1315 r_alias[3] = pcmp_eq(a_alias[3], b_alias[3]);
1316 return r;
1317}
1318
1319template <>
1320EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pcmp_lt<Packet4h2>(const Packet4h2& a, const Packet4h2& b) {
1321 Packet4h2 r;
1322 half2* r_alias = reinterpret_cast<half2*>(&r);
1323 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1324 const half2* b_alias = reinterpret_cast<const half2*>(&b);
1325 r_alias[0] = pcmp_lt(a_alias[0], b_alias[0]);
1326 r_alias[1] = pcmp_lt(a_alias[1], b_alias[1]);
1327 r_alias[2] = pcmp_lt(a_alias[2], b_alias[2]);
1328 r_alias[3] = pcmp_lt(a_alias[3], b_alias[3]);
1329 return r;
1330}
1331
1332template <>
1333EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pcmp_le<Packet4h2>(const Packet4h2& a, const Packet4h2& b) {
1334 Packet4h2 r;
1335 half2* r_alias = reinterpret_cast<half2*>(&r);
1336 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1337 const half2* b_alias = reinterpret_cast<const half2*>(&b);
1338 r_alias[0] = pcmp_le(a_alias[0], b_alias[0]);
1339 r_alias[1] = pcmp_le(a_alias[1], b_alias[1]);
1340 r_alias[2] = pcmp_le(a_alias[2], b_alias[2]);
1341 r_alias[3] = pcmp_le(a_alias[3], b_alias[3]);
1342 return r;
1343}
1344
1345template <>
1346EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pand<Packet4h2>(const Packet4h2& a, const Packet4h2& b) {
1347 Packet4h2 r;
1348 half2* r_alias = reinterpret_cast<half2*>(&r);
1349 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1350 const half2* b_alias = reinterpret_cast<const half2*>(&b);
1351 r_alias[0] = pand(a_alias[0], b_alias[0]);
1352 r_alias[1] = pand(a_alias[1], b_alias[1]);
1353 r_alias[2] = pand(a_alias[2], b_alias[2]);
1354 r_alias[3] = pand(a_alias[3], b_alias[3]);
1355 return r;
1356}
1357
1358template <>
1359EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 por<Packet4h2>(const Packet4h2& a, const Packet4h2& b) {
1360 Packet4h2 r;
1361 half2* r_alias = reinterpret_cast<half2*>(&r);
1362 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1363 const half2* b_alias = reinterpret_cast<const half2*>(&b);
1364 r_alias[0] = por(a_alias[0], b_alias[0]);
1365 r_alias[1] = por(a_alias[1], b_alias[1]);
1366 r_alias[2] = por(a_alias[2], b_alias[2]);
1367 r_alias[3] = por(a_alias[3], b_alias[3]);
1368 return r;
1369}
1370
1371template <>
1372EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pxor<Packet4h2>(const Packet4h2& a, const Packet4h2& b) {
1373 Packet4h2 r;
1374 half2* r_alias = reinterpret_cast<half2*>(&r);
1375 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1376 const half2* b_alias = reinterpret_cast<const half2*>(&b);
1377 r_alias[0] = pxor(a_alias[0], b_alias[0]);
1378 r_alias[1] = pxor(a_alias[1], b_alias[1]);
1379 r_alias[2] = pxor(a_alias[2], b_alias[2]);
1380 r_alias[3] = pxor(a_alias[3], b_alias[3]);
1381 return r;
1382}
1383
1384template <>
1385EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pandnot<Packet4h2>(const Packet4h2& a, const Packet4h2& b) {
1386 Packet4h2 r;
1387 half2* r_alias = reinterpret_cast<half2*>(&r);
1388 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1389 const half2* b_alias = reinterpret_cast<const half2*>(&b);
1390 r_alias[0] = pandnot(a_alias[0], b_alias[0]);
1391 r_alias[1] = pandnot(a_alias[1], b_alias[1]);
1392 r_alias[2] = pandnot(a_alias[2], b_alias[2]);
1393 r_alias[3] = pandnot(a_alias[3], b_alias[3]);
1394 return r;
1395}
1396
1397template <>
1398EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 padd<Packet4h2>(const Packet4h2& a, const Packet4h2& b) {
1399 Packet4h2 r;
1400 half2* r_alias = reinterpret_cast<half2*>(&r);
1401 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1402 const half2* b_alias = reinterpret_cast<const half2*>(&b);
1403 r_alias[0] = padd(a_alias[0], b_alias[0]);
1404 r_alias[1] = padd(a_alias[1], b_alias[1]);
1405 r_alias[2] = padd(a_alias[2], b_alias[2]);
1406 r_alias[3] = padd(a_alias[3], b_alias[3]);
1407 return r;
1408}
1409
1410template <>
1411EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 psub<Packet4h2>(const Packet4h2& a, const Packet4h2& b) {
1412 Packet4h2 r;
1413 half2* r_alias = reinterpret_cast<half2*>(&r);
1414 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1415 const half2* b_alias = reinterpret_cast<const half2*>(&b);
1416 r_alias[0] = psub(a_alias[0], b_alias[0]);
1417 r_alias[1] = psub(a_alias[1], b_alias[1]);
1418 r_alias[2] = psub(a_alias[2], b_alias[2]);
1419 r_alias[3] = psub(a_alias[3], b_alias[3]);
1420 return r;
1421}
1422
1423template <>
1424EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pnegate(const Packet4h2& a) {
1425 Packet4h2 r;
1426 half2* r_alias = reinterpret_cast<half2*>(&r);
1427 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1428 r_alias[0] = pnegate(a_alias[0]);
1429 r_alias[1] = pnegate(a_alias[1]);
1430 r_alias[2] = pnegate(a_alias[2]);
1431 r_alias[3] = pnegate(a_alias[3]);
1432 return r;
1433}
1434
1435template <>
1436EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pconj(const Packet4h2& a) {
1437 return a;
1438}
1439
1440template <>
1441EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pmul<Packet4h2>(const Packet4h2& a, const Packet4h2& b) {
1442 Packet4h2 r;
1443 half2* r_alias = reinterpret_cast<half2*>(&r);
1444 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1445 const half2* b_alias = reinterpret_cast<const half2*>(&b);
1446 r_alias[0] = pmul(a_alias[0], b_alias[0]);
1447 r_alias[1] = pmul(a_alias[1], b_alias[1]);
1448 r_alias[2] = pmul(a_alias[2], b_alias[2]);
1449 r_alias[3] = pmul(a_alias[3], b_alias[3]);
1450 return r;
1451}
1452
1453template <>
1454EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pmadd<Packet4h2>(const Packet4h2& a, const Packet4h2& b,
1455 const Packet4h2& c) {
1456 Packet4h2 r;
1457 half2* r_alias = reinterpret_cast<half2*>(&r);
1458 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1459 const half2* b_alias = reinterpret_cast<const half2*>(&b);
1460 const half2* c_alias = reinterpret_cast<const half2*>(&c);
1461 r_alias[0] = pmadd(a_alias[0], b_alias[0], c_alias[0]);
1462 r_alias[1] = pmadd(a_alias[1], b_alias[1], c_alias[1]);
1463 r_alias[2] = pmadd(a_alias[2], b_alias[2], c_alias[2]);
1464 r_alias[3] = pmadd(a_alias[3], b_alias[3], c_alias[3]);
1465 return r;
1466}
1467
1468template <>
1469EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pdiv<Packet4h2>(const Packet4h2& a, const Packet4h2& b) {
1470 Packet4h2 r;
1471 half2* r_alias = reinterpret_cast<half2*>(&r);
1472 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1473 const half2* b_alias = reinterpret_cast<const half2*>(&b);
1474 r_alias[0] = pdiv(a_alias[0], b_alias[0]);
1475 r_alias[1] = pdiv(a_alias[1], b_alias[1]);
1476 r_alias[2] = pdiv(a_alias[2], b_alias[2]);
1477 r_alias[3] = pdiv(a_alias[3], b_alias[3]);
1478 return r;
1479}
1480
1481template <>
1482EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pmin<Packet4h2>(const Packet4h2& a, const Packet4h2& b) {
1483 Packet4h2 r;
1484 half2* r_alias = reinterpret_cast<half2*>(&r);
1485 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1486 const half2* b_alias = reinterpret_cast<const half2*>(&b);
1487 r_alias[0] = pmin(a_alias[0], b_alias[0]);
1488 r_alias[1] = pmin(a_alias[1], b_alias[1]);
1489 r_alias[2] = pmin(a_alias[2], b_alias[2]);
1490 r_alias[3] = pmin(a_alias[3], b_alias[3]);
1491 return r;
1492}
1493
1494template <>
1495EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pmax<Packet4h2>(const Packet4h2& a, const Packet4h2& b) {
1496 Packet4h2 r;
1497 half2* r_alias = reinterpret_cast<half2*>(&r);
1498 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1499 const half2* b_alias = reinterpret_cast<const half2*>(&b);
1500 r_alias[0] = pmax(a_alias[0], b_alias[0]);
1501 r_alias[1] = pmax(a_alias[1], b_alias[1]);
1502 r_alias[2] = pmax(a_alias[2], b_alias[2]);
1503 r_alias[3] = pmax(a_alias[3], b_alias[3]);
1504 return r;
1505}
1506
1507template <>
1508EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux<Packet4h2>(const Packet4h2& a) {
1509 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1510
1511 return predux(a_alias[0]) + predux(a_alias[1]) + predux(a_alias[2]) + predux(a_alias[3]);
1512}
1513
1514template <>
1515EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_max<Packet4h2>(const Packet4h2& a) {
1516 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1517 half2 m0 = __halves2half2(predux_max(a_alias[0]), predux_max(a_alias[1]));
1518 half2 m1 = __halves2half2(predux_max(a_alias[2]), predux_max(a_alias[3]));
1519 __half first = predux_max(m0);
1520 __half second = predux_max(m1);
1521#if defined(EIGEN_CUDA_HAS_FP16_ARITHMETIC)
1522 return (__hgt(first, second) ? first : second);
1523#else
1524 float ffirst = __half2float(first);
1525 float fsecond = __half2float(second);
1526 return (ffirst > fsecond) ? first : second;
1527#endif
1528}
1529
1530template <>
1531EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_min<Packet4h2>(const Packet4h2& a) {
1532 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1533 half2 m0 = __halves2half2(predux_min(a_alias[0]), predux_min(a_alias[1]));
1534 half2 m1 = __halves2half2(predux_min(a_alias[2]), predux_min(a_alias[3]));
1535 __half first = predux_min(m0);
1536 __half second = predux_min(m1);
1537#if defined(EIGEN_CUDA_HAS_FP16_ARITHMETIC)
1538 return (__hlt(first, second) ? first : second);
1539#else
1540 float ffirst = __half2float(first);
1541 float fsecond = __half2float(second);
1542 return (ffirst < fsecond) ? first : second;
1543#endif
1544}
1545
1546// likely overflow/underflow
1547template <>
1548EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet4h2>(const Packet4h2& a) {
1549 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1550 return predux_mul(pmul(pmul(a_alias[0], a_alias[1]), pmul(a_alias[2], a_alias[3])));
1551}
1552
1553template <>
1554EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 plog1p<Packet4h2>(const Packet4h2& a) {
1555 Packet4h2 r;
1556 half2* r_alias = reinterpret_cast<half2*>(&r);
1557 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1558 r_alias[0] = plog1p(a_alias[0]);
1559 r_alias[1] = plog1p(a_alias[1]);
1560 r_alias[2] = plog1p(a_alias[2]);
1561 r_alias[3] = plog1p(a_alias[3]);
1562 return r;
1563}
1564
1565template <>
1566EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pexpm1<Packet4h2>(const Packet4h2& a) {
1567 Packet4h2 r;
1568 half2* r_alias = reinterpret_cast<half2*>(&r);
1569 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1570 r_alias[0] = pexpm1(a_alias[0]);
1571 r_alias[1] = pexpm1(a_alias[1]);
1572 r_alias[2] = pexpm1(a_alias[2]);
1573 r_alias[3] = pexpm1(a_alias[3]);
1574 return r;
1575}
1576
1577template <>
1578EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 plog<Packet4h2>(const Packet4h2& a) {
1579 Packet4h2 r;
1580 half2* r_alias = reinterpret_cast<half2*>(&r);
1581 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1582 r_alias[0] = plog(a_alias[0]);
1583 r_alias[1] = plog(a_alias[1]);
1584 r_alias[2] = plog(a_alias[2]);
1585 r_alias[3] = plog(a_alias[3]);
1586 return r;
1587}
1588
1589template <>
1590EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pexp<Packet4h2>(const Packet4h2& a) {
1591 Packet4h2 r;
1592 half2* r_alias = reinterpret_cast<half2*>(&r);
1593 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1594 r_alias[0] = pexp(a_alias[0]);
1595 r_alias[1] = pexp(a_alias[1]);
1596 r_alias[2] = pexp(a_alias[2]);
1597 r_alias[3] = pexp(a_alias[3]);
1598 return r;
1599}
1600
1601template <>
1602EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 psqrt<Packet4h2>(const Packet4h2& a) {
1603 Packet4h2 r;
1604 half2* r_alias = reinterpret_cast<half2*>(&r);
1605 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1606 r_alias[0] = psqrt(a_alias[0]);
1607 r_alias[1] = psqrt(a_alias[1]);
1608 r_alias[2] = psqrt(a_alias[2]);
1609 r_alias[3] = psqrt(a_alias[3]);
1610 return r;
1611}
1612
1613template <>
1614EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 prsqrt<Packet4h2>(const Packet4h2& a) {
1615 Packet4h2 r;
1616 half2* r_alias = reinterpret_cast<half2*>(&r);
1617 const half2* a_alias = reinterpret_cast<const half2*>(&a);
1618 r_alias[0] = prsqrt(a_alias[0]);
1619 r_alias[1] = prsqrt(a_alias[1]);
1620 r_alias[2] = prsqrt(a_alias[2]);
1621 r_alias[3] = prsqrt(a_alias[3]);
1622 return r;
1623}
1624
// The following padd, pmul, pdiv, pmin and pmax specializations for half2 are
// needed by the implementation of the GPU half-precision reduction.
1627template <>
1628EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 padd<half2>(const half2& a, const half2& b) {
1629#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
1630 return __hadd2(a, b);
1631#else
1632 float a1 = __low2float(a);
1633 float a2 = __high2float(a);
1634 float b1 = __low2float(b);
1635 float b2 = __high2float(b);
1636 float r1 = a1 + b1;
1637 float r2 = a2 + b2;
1638 return __floats2half2_rn(r1, r2);
1639#endif
1640}
1641
1642template <>
1643EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmul<half2>(const half2& a, const half2& b) {
1644#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
1645 return __hmul2(a, b);
1646#else
1647 float a1 = __low2float(a);
1648 float a2 = __high2float(a);
1649 float b1 = __low2float(b);
1650 float b2 = __high2float(b);
1651 float r1 = a1 * b1;
1652 float r2 = a2 * b2;
1653 return __floats2half2_rn(r1, r2);
1654#endif
1655}
1656
1657template <>
1658EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pdiv<half2>(const half2& a, const half2& b) {
1659#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC)
1660 return __h2div(a, b);
1661#else
1662 float a1 = __low2float(a);
1663 float a2 = __high2float(a);
1664 float b1 = __low2float(b);
1665 float b2 = __high2float(b);
1666 float r1 = a1 / b1;
1667 float r2 = a2 / b2;
1668 return __floats2half2_rn(r1, r2);
1669#endif
1670}
1671
1672template <>
1673EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmin<half2>(const half2& a, const half2& b) {
1674 float a1 = __low2float(a);
1675 float a2 = __high2float(a);
1676 float b1 = __low2float(b);
1677 float b2 = __high2float(b);
1678 __half r1 = a1 < b1 ? __low2half(a) : __low2half(b);
1679 __half r2 = a2 < b2 ? __high2half(a) : __high2half(b);
1680 return __halves2half2(r1, r2);
1681}
1682
1683template <>
1684EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmax<half2>(const half2& a, const half2& b) {
1685 float a1 = __low2float(a);
1686 float a2 = __high2float(a);
1687 float b1 = __low2float(b);
1688 float b2 = __high2float(b);
1689 __half r1 = a1 > b1 ? __low2half(a) : __low2half(b);
1690 __half r2 = a2 > b2 ? __high2half(a) : __high2half(b);
1691 return __halves2half2(r1, r2);
1692}
1693
1694#endif // (defined(EIGEN_HAS_CUDA_FP16) || defined(EIGEN_HAS_HIP_FP16)) && defined(EIGEN_GPU_COMPILE_PHASE)
1695
1696#undef EIGEN_GPU_HAS_LDG
1697#undef EIGEN_CUDA_HAS_FP16_ARITHMETIC
1698#undef EIGEN_GPU_HAS_FP16_ARITHMETIC
1699
1700} // end namespace internal
1701
1702} // end namespace Eigen
1703
1704#endif // EIGEN_PACKET_MATH_GPU_H
@ Aligned16
Definition Constants.h:237
Namespace containing all symbols from the Eigen library.
Definition Core:137
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_rint_op< typename Derived::Scalar >, const Derived > rint(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_trunc_op< typename Derived::Scalar >, const Derived > trunc(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_floor_op< typename Derived::Scalar >, const Derived > floor(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_ceil_op< typename Derived::Scalar >, const Derived > ceil(const Eigen::ArrayBase< Derived > &x)