Eigen  3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)
 
Loading...
Searching...
No Matches
PacketMathFP16.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4//
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_PACKET_MATH_FP16_AVX512_H
11#define EIGEN_PACKET_MATH_FP16_AVX512_H
12
13// IWYU pragma: private
14#include "../../InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
20typedef __m512h Packet32h;
21typedef eigen_packet_wrapper<__m256i, 1> Packet16h;
22typedef eigen_packet_wrapper<__m128i, 2> Packet8h;
23
24template <>
25struct is_arithmetic<Packet8h> {
26 enum { value = true };
27};
28
29template <>
30struct packet_traits<half> : default_packet_traits {
31 typedef Packet32h type;
32 typedef Packet16h half;
33 enum {
34 Vectorizable = 1,
35 AlignedOnScalar = 1,
36 size = 32,
37
38 HasCmp = 1,
39 HasAdd = 1,
40 HasSub = 1,
41 HasMul = 1,
42 HasDiv = 1,
43 HasNegate = 1,
44 HasAbs = 1,
45 HasAbs2 = 0,
46 HasMin = 1,
47 HasMax = 1,
48 HasConj = 1,
49 HasSetLinear = 0,
50 HasLog = 1,
51 HasLog1p = 1,
52 HasExp = 1,
53 HasExpm1 = 1,
54 HasSqrt = 1,
55 HasRsqrt = 1,
56 // These ones should be implemented in future
57 HasBessel = 0,
58 HasNdtri = 0,
59 HasSin = EIGEN_FAST_MATH,
60 HasCos = EIGEN_FAST_MATH,
61 HasTanh = EIGEN_FAST_MATH,
62 HasErf = 0, // EIGEN_FAST_MATH,
63 HasBlend = 0
64 };
65};
66
67template <>
68struct unpacket_traits<Packet32h> {
69 typedef Eigen::half type;
70 typedef Packet16h half;
71 enum {
72 size = 32,
73 alignment = Aligned64,
74 vectorizable = true,
75 masked_load_available = false,
76 masked_store_available = false
77 };
78};
79
80template <>
81struct unpacket_traits<Packet16h> {
82 typedef Eigen::half type;
83 typedef Packet8h half;
84 enum {
85 size = 16,
86 alignment = Aligned32,
87 vectorizable = true,
88 masked_load_available = false,
89 masked_store_available = false
90 };
91};
92
93template <>
94struct unpacket_traits<Packet8h> {
95 typedef Eigen::half type;
96 typedef Packet8h half;
97 enum {
98 size = 8,
99 alignment = Aligned16,
100 vectorizable = true,
101 masked_load_available = false,
102 masked_store_available = false
103 };
104};
105
106// Memory functions
107
108// pset1
109
110template <>
111EIGEN_STRONG_INLINE Packet32h pset1<Packet32h>(const Eigen::half& from) {
112 return _mm512_set1_ph(static_cast<_Float16>(from));
113}
114
115// pset1frombits
116template <>
117EIGEN_STRONG_INLINE Packet32h pset1frombits<Packet32h>(unsigned short from) {
118 return _mm512_castsi512_ph(_mm512_set1_epi16(from));
119}
120
121// pfirst
122
123template <>
124EIGEN_STRONG_INLINE Eigen::half pfirst<Packet32h>(const Packet32h& from) {
125#ifdef EIGEN_VECTORIZE_AVX512DQ
126 return half_impl::raw_uint16_to_half(
127 static_cast<unsigned short>(_mm256_extract_epi16(_mm512_extracti32x8_epi32(_mm512_castph_si512(from), 0), 0)));
128#else
129 Eigen::half dest[32];
130 _mm512_storeu_ph(dest, from);
131 return dest[0];
132#endif
133}
134
135// pload
136
137template <>
138EIGEN_STRONG_INLINE Packet32h pload<Packet32h>(const Eigen::half* from) {
139 EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_ph(from);
140}
141
142// ploadu
143
144template <>
145EIGEN_STRONG_INLINE Packet32h ploadu<Packet32h>(const Eigen::half* from) {
146 EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_ph(from);
147}
148
149// pstore
150
151template <>
152EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet32h& from) {
153 EIGEN_DEBUG_ALIGNED_STORE _mm512_store_ph(to, from);
154}
155
156// pstoreu
157
158template <>
159EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet32h& from) {
160 EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_ph(to, from);
161}
162
163// ploaddup
164template <>
165EIGEN_STRONG_INLINE Packet32h ploaddup<Packet32h>(const Eigen::half* from) {
166 __m512h a = _mm512_castph256_ph512(_mm256_loadu_ph(from));
167 return _mm512_permutexvar_ph(_mm512_set_epi16(15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8, 7, 7, 6, 6,
168 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0),
169 a);
170}
171
172// ploadquad
173template <>
174EIGEN_STRONG_INLINE Packet32h ploadquad<Packet32h>(const Eigen::half* from) {
175 __m512h a = _mm512_castph128_ph512(_mm_loadu_ph(from));
176 return _mm512_permutexvar_ph(
177 _mm512_set_epi16(7, 7, 7, 7, 6, 6, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0),
178 a);
179}
180
181// pabs
182
183template <>
184EIGEN_STRONG_INLINE Packet32h pabs<Packet32h>(const Packet32h& a) {
185 return _mm512_abs_ph(a);
186}
187
188// psignbit
189
190template <>
191EIGEN_STRONG_INLINE Packet32h psignbit<Packet32h>(const Packet32h& a) {
192 return _mm512_castsi512_ph(_mm512_srai_epi16(_mm512_castph_si512(a), 15));
193}
194
195// pmin
196
197template <>
198EIGEN_STRONG_INLINE Packet32h pmin<Packet32h>(const Packet32h& a, const Packet32h& b) {
199 return _mm512_min_ph(a, b);
200}
201
202// pmax
203
204template <>
205EIGEN_STRONG_INLINE Packet32h pmax<Packet32h>(const Packet32h& a, const Packet32h& b) {
206 return _mm512_max_ph(a, b);
207}
208
209// plset
210template <>
211EIGEN_STRONG_INLINE Packet32h plset<Packet32h>(const half& a) {
212 return _mm512_add_ph(_mm512_set1_ph(a),
213 _mm512_set_ph(31.0f, 30.0f, 29.0f, 28.0f, 27.0f, 26.0f, 25.0f, 24.0f, 23.0f, 22.0f, 21.0f, 20.0f,
214 19.0f, 18.0f, 17.0f, 16.0f, 15.0f, 14.0f, 13.0f, 12.0f, 11.0f, 10.0f, 9.0f, 8.0f,
215 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f));
216}
217
218// por
219
220template <>
221EIGEN_STRONG_INLINE Packet32h por(const Packet32h& a, const Packet32h& b) {
222 return _mm512_castsi512_ph(_mm512_or_si512(_mm512_castph_si512(a), _mm512_castph_si512(b)));
223}
224
225// pxor
226
227template <>
228EIGEN_STRONG_INLINE Packet32h pxor(const Packet32h& a, const Packet32h& b) {
229 return _mm512_castsi512_ph(_mm512_xor_si512(_mm512_castph_si512(a), _mm512_castph_si512(b)));
230}
231
232// pand
233
234template <>
235EIGEN_STRONG_INLINE Packet32h pand(const Packet32h& a, const Packet32h& b) {
236 return _mm512_castsi512_ph(_mm512_and_si512(_mm512_castph_si512(a), _mm512_castph_si512(b)));
237}
238
239// pandnot
240
241template <>
242EIGEN_STRONG_INLINE Packet32h pandnot(const Packet32h& a, const Packet32h& b) {
243 return _mm512_castsi512_ph(_mm512_andnot_si512(_mm512_castph_si512(b), _mm512_castph_si512(a)));
244}
245
246// pselect
247
248template <>
249EIGEN_DEVICE_FUNC inline Packet32h pselect(const Packet32h& mask, const Packet32h& a, const Packet32h& b) {
250 __mmask32 mask32 = _mm512_cmp_epi16_mask(_mm512_castph_si512(mask), _mm512_setzero_epi32(), _MM_CMPINT_EQ);
251 return _mm512_mask_blend_ph(mask32, a, b);
252}
253
254// pcmp_eq
255
256template <>
257EIGEN_STRONG_INLINE Packet32h pcmp_eq(const Packet32h& a, const Packet32h& b) {
258 __mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_EQ_OQ);
259 return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, 0xffffu));
260}
261
262// pcmp_le
263
264template <>
265EIGEN_STRONG_INLINE Packet32h pcmp_le(const Packet32h& a, const Packet32h& b) {
266 __mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_LE_OQ);
267 return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, 0xffffu));
268}
269
270// pcmp_lt
271
272template <>
273EIGEN_STRONG_INLINE Packet32h pcmp_lt(const Packet32h& a, const Packet32h& b) {
274 __mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_LT_OQ);
275 return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, 0xffffu));
276}
277
278// pcmp_lt_or_nan
279
280template <>
281EIGEN_STRONG_INLINE Packet32h pcmp_lt_or_nan(const Packet32h& a, const Packet32h& b) {
282 __mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_NGE_UQ);
283 return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi16(0), mask, 0xffffu));
284}
285
286// padd
287
288template <>
289EIGEN_STRONG_INLINE Packet32h padd<Packet32h>(const Packet32h& a, const Packet32h& b) {
290 return _mm512_add_ph(a, b);
291}
292
293template <>
294EIGEN_STRONG_INLINE Packet16h padd<Packet16h>(const Packet16h& a, const Packet16h& b) {
295 return _mm256_castph_si256(_mm256_add_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b)));
296}
297
298template <>
299EIGEN_STRONG_INLINE Packet8h padd<Packet8h>(const Packet8h& a, const Packet8h& b) {
300 return _mm_castph_si128(_mm_add_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b)));
301}
302
303// psub
304
305template <>
306EIGEN_STRONG_INLINE Packet32h psub<Packet32h>(const Packet32h& a, const Packet32h& b) {
307 return _mm512_sub_ph(a, b);
308}
309
310template <>
311EIGEN_STRONG_INLINE Packet16h psub<Packet16h>(const Packet16h& a, const Packet16h& b) {
312 return _mm256_castph_si256(_mm256_sub_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b)));
313}
314
315template <>
316EIGEN_STRONG_INLINE Packet8h psub<Packet8h>(const Packet8h& a, const Packet8h& b) {
317 return _mm_castph_si128(_mm_sub_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b)));
318}
319
320// pmul
321
322template <>
323EIGEN_STRONG_INLINE Packet32h pmul<Packet32h>(const Packet32h& a, const Packet32h& b) {
324 return _mm512_mul_ph(a, b);
325}
326
327template <>
328EIGEN_STRONG_INLINE Packet16h pmul<Packet16h>(const Packet16h& a, const Packet16h& b) {
329 return _mm256_castph_si256(_mm256_mul_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b)));
330}
331
332template <>
333EIGEN_STRONG_INLINE Packet8h pmul<Packet8h>(const Packet8h& a, const Packet8h& b) {
334 return _mm_castph_si128(_mm_mul_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b)));
335}
336
337// pdiv
338
339template <>
340EIGEN_STRONG_INLINE Packet32h pdiv<Packet32h>(const Packet32h& a, const Packet32h& b) {
341 return _mm512_div_ph(a, b);
342}
343
344template <>
345EIGEN_STRONG_INLINE Packet16h pdiv<Packet16h>(const Packet16h& a, const Packet16h& b) {
346 return _mm256_castph_si256(_mm256_div_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b)));
347}
348
349template <>
350EIGEN_STRONG_INLINE Packet8h pdiv<Packet8h>(const Packet8h& a, const Packet8h& b) {
351 return _mm_castph_si128(_mm_div_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b)));
352}
353
354// pround
355
356template <>
357EIGEN_STRONG_INLINE Packet32h pround<Packet32h>(const Packet32h& a) {
358 // Work-around for default std::round rounding mode.
359
360 // Mask for the sign bit
361 const Packet32h signMask = pset1frombits<Packet32h>(static_cast<numext::uint16_t>(0x8000u));
362 // The largest half-preicision float less than 0.5
363 const Packet32h prev0dot5 = pset1frombits<Packet32h>(static_cast<numext::uint16_t>(0x37FFu));
364
365 return _mm512_roundscale_ph(padd(por(pand(a, signMask), prev0dot5), a), _MM_FROUND_TO_ZERO);
366}
367
368// print
369
370template <>
371EIGEN_STRONG_INLINE Packet32h print<Packet32h>(const Packet32h& a) {
372 return _mm512_roundscale_ph(a, _MM_FROUND_CUR_DIRECTION);
373}
374
375// pceil
376
377template <>
378EIGEN_STRONG_INLINE Packet32h pceil<Packet32h>(const Packet32h& a) {
379 return _mm512_roundscale_ph(a, _MM_FROUND_TO_POS_INF);
380}
381
382// pfloor
383
384template <>
385EIGEN_STRONG_INLINE Packet32h pfloor<Packet32h>(const Packet32h& a) {
386 return _mm512_roundscale_ph(a, _MM_FROUND_TO_NEG_INF);
387}
388
389// ptrunc
390
391template <>
392EIGEN_STRONG_INLINE Packet32h ptrunc<Packet32h>(const Packet32h& a) {
393 return _mm512_roundscale_ph(a, _MM_FROUND_TO_ZERO);
394}
395
396// predux
397template <>
398EIGEN_STRONG_INLINE half predux<Packet32h>(const Packet32h& a) {
399 return (half)_mm512_reduce_add_ph(a);
400}
401
402template <>
403EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& a) {
404 return (half)_mm256_reduce_add_ph(_mm256_castsi256_ph(a));
405}
406
407template <>
408EIGEN_STRONG_INLINE half predux<Packet8h>(const Packet8h& a) {
409 return (half)_mm_reduce_add_ph(_mm_castsi128_ph(a));
410}
411
412// predux_half_dowto4
413template <>
414EIGEN_STRONG_INLINE Packet16h predux_half_dowto4<Packet32h>(const Packet32h& a) {
415#ifdef EIGEN_VECTORIZE_AVX512DQ
416 __m256i lowHalf = _mm256_castps_si256(_mm512_extractf32x8_ps(_mm512_castph_ps(a), 0));
417 __m256i highHalf = _mm256_castps_si256(_mm512_extractf32x8_ps(_mm512_castph_ps(a), 1));
418
419 return Packet16h(padd<Packet16h>(lowHalf, highHalf));
420#else
421 Eigen::half data[32];
422 _mm512_storeu_ph(data, a);
423
424 __m256i lowHalf = _mm256_castph_si256(_mm256_loadu_ph(data));
425 __m256i highHalf = _mm256_castph_si256(_mm256_loadu_ph(data + 16));
426
427 return Packet16h(padd<Packet16h>(lowHalf, highHalf));
428#endif
429}
430
431// predux_max
432
433// predux_min
434
435// predux_mul
436
437#ifdef EIGEN_VECTORIZE_FMA
438
439// pmadd
440
441template <>
442EIGEN_STRONG_INLINE Packet32h pmadd(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
443 return _mm512_fmadd_ph(a, b, c);
444}
445
446template <>
447EIGEN_STRONG_INLINE Packet16h pmadd(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
448 return _mm256_castph_si256(_mm256_fmadd_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b), _mm256_castsi256_ph(c)));
449}
450
451template <>
452EIGEN_STRONG_INLINE Packet8h pmadd(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
453 return _mm_castph_si128(_mm_fmadd_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b), _mm_castsi128_ph(c)));
454}
455
456// pmsub
457
458template <>
459EIGEN_STRONG_INLINE Packet32h pmsub(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
460 return _mm512_fmsub_ph(a, b, c);
461}
462
463template <>
464EIGEN_STRONG_INLINE Packet16h pmsub(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
465 return _mm256_castph_si256(_mm256_fmsub_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b), _mm256_castsi256_ph(c)));
466}
467
468template <>
469EIGEN_STRONG_INLINE Packet8h pmsub(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
470 return _mm_castph_si128(_mm_fmsub_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b), _mm_castsi128_ph(c)));
471}
472
473// pnmadd
474
475template <>
476EIGEN_STRONG_INLINE Packet32h pnmadd(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
477 return _mm512_fnmadd_ph(a, b, c);
478}
479
480template <>
481EIGEN_STRONG_INLINE Packet16h pnmadd(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
482 return _mm256_castph_si256(_mm256_fnmadd_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b), _mm256_castsi256_ph(c)));
483}
484
485template <>
486EIGEN_STRONG_INLINE Packet8h pnmadd(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
487 return _mm_castph_si128(_mm_fnmadd_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b), _mm_castsi128_ph(c)));
488}
489
490// pnmsub
491
492template <>
493EIGEN_STRONG_INLINE Packet32h pnmsub(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
494 return _mm512_fnmsub_ph(a, b, c);
495}
496
497template <>
498EIGEN_STRONG_INLINE Packet16h pnmsub(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
499 return _mm256_castph_si256(_mm256_fnmsub_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b), _mm256_castsi256_ph(c)));
500}
501
502template <>
503EIGEN_STRONG_INLINE Packet8h pnmsub(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
504 return _mm_castph_si128(_mm_fnmsub_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b), _mm_castsi128_ph(c)));
505}
506
507#endif
508
509// pnegate
510
511template <>
512EIGEN_STRONG_INLINE Packet32h pnegate<Packet32h>(const Packet32h& a) {
513 return _mm512_sub_ph(_mm512_set1_ph(0.0), a);
514}
515
516// pconj
517
518template <>
519EIGEN_STRONG_INLINE Packet32h pconj<Packet32h>(const Packet32h& a) {
520 return a;
521}
522
523// psqrt
524
525template <>
526EIGEN_STRONG_INLINE Packet32h psqrt<Packet32h>(const Packet32h& a) {
527 return _mm512_sqrt_ph(a);
528}
529
530// prsqrt
531
532template <>
533EIGEN_STRONG_INLINE Packet32h prsqrt<Packet32h>(const Packet32h& a) {
534 return _mm512_rsqrt_ph(a);
535}
536
537// preciprocal
538
539template <>
540EIGEN_STRONG_INLINE Packet32h preciprocal<Packet32h>(const Packet32h& a) {
541 return _mm512_rcp_ph(a);
542}
543
544// ptranspose
545
546EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet32h, 32>& a) {
547 __m512i t[32];
548
549 EIGEN_UNROLL_LOOP
550 for (int i = 0; i < 16; i++) {
551 t[2 * i] = _mm512_unpacklo_epi16(_mm512_castph_si512(a.packet[2 * i]), _mm512_castph_si512(a.packet[2 * i + 1]));
552 t[2 * i + 1] =
553 _mm512_unpackhi_epi16(_mm512_castph_si512(a.packet[2 * i]), _mm512_castph_si512(a.packet[2 * i + 1]));
554 }
555
556 __m512i p[32];
557
558 EIGEN_UNROLL_LOOP
559 for (int i = 0; i < 8; i++) {
560 p[4 * i] = _mm512_unpacklo_epi32(t[4 * i], t[4 * i + 2]);
561 p[4 * i + 1] = _mm512_unpackhi_epi32(t[4 * i], t[4 * i + 2]);
562 p[4 * i + 2] = _mm512_unpacklo_epi32(t[4 * i + 1], t[4 * i + 3]);
563 p[4 * i + 3] = _mm512_unpackhi_epi32(t[4 * i + 1], t[4 * i + 3]);
564 }
565
566 __m512i q[32];
567
568 EIGEN_UNROLL_LOOP
569 for (int i = 0; i < 4; i++) {
570 q[8 * i] = _mm512_unpacklo_epi64(p[8 * i], p[8 * i + 4]);
571 q[8 * i + 1] = _mm512_unpackhi_epi64(p[8 * i], p[8 * i + 4]);
572 q[8 * i + 2] = _mm512_unpacklo_epi64(p[8 * i + 1], p[8 * i + 5]);
573 q[8 * i + 3] = _mm512_unpackhi_epi64(p[8 * i + 1], p[8 * i + 5]);
574 q[8 * i + 4] = _mm512_unpacklo_epi64(p[8 * i + 2], p[8 * i + 6]);
575 q[8 * i + 5] = _mm512_unpackhi_epi64(p[8 * i + 2], p[8 * i + 6]);
576 q[8 * i + 6] = _mm512_unpacklo_epi64(p[8 * i + 3], p[8 * i + 7]);
577 q[8 * i + 7] = _mm512_unpackhi_epi64(p[8 * i + 3], p[8 * i + 7]);
578 }
579
580 __m512i f[32];
581
582#define PACKET32H_TRANSPOSE_HELPER(X, Y) \
583 do { \
584 f[Y * 8] = _mm512_inserti32x4(f[Y * 8], _mm512_extracti32x4_epi32(q[X * 8], Y), X); \
585 f[Y * 8 + 1] = _mm512_inserti32x4(f[Y * 8 + 1], _mm512_extracti32x4_epi32(q[X * 8 + 1], Y), X); \
586 f[Y * 8 + 2] = _mm512_inserti32x4(f[Y * 8 + 2], _mm512_extracti32x4_epi32(q[X * 8 + 2], Y), X); \
587 f[Y * 8 + 3] = _mm512_inserti32x4(f[Y * 8 + 3], _mm512_extracti32x4_epi32(q[X * 8 + 3], Y), X); \
588 f[Y * 8 + 4] = _mm512_inserti32x4(f[Y * 8 + 4], _mm512_extracti32x4_epi32(q[X * 8 + 4], Y), X); \
589 f[Y * 8 + 5] = _mm512_inserti32x4(f[Y * 8 + 5], _mm512_extracti32x4_epi32(q[X * 8 + 5], Y), X); \
590 f[Y * 8 + 6] = _mm512_inserti32x4(f[Y * 8 + 6], _mm512_extracti32x4_epi32(q[X * 8 + 6], Y), X); \
591 f[Y * 8 + 7] = _mm512_inserti32x4(f[Y * 8 + 7], _mm512_extracti32x4_epi32(q[X * 8 + 7], Y), X); \
592 } while (false);
593
594 PACKET32H_TRANSPOSE_HELPER(0, 0);
595 PACKET32H_TRANSPOSE_HELPER(1, 1);
596 PACKET32H_TRANSPOSE_HELPER(2, 2);
597 PACKET32H_TRANSPOSE_HELPER(3, 3);
598
599 PACKET32H_TRANSPOSE_HELPER(1, 0);
600 PACKET32H_TRANSPOSE_HELPER(2, 0);
601 PACKET32H_TRANSPOSE_HELPER(3, 0);
602 PACKET32H_TRANSPOSE_HELPER(2, 1);
603 PACKET32H_TRANSPOSE_HELPER(3, 1);
604 PACKET32H_TRANSPOSE_HELPER(3, 2);
605
606 PACKET32H_TRANSPOSE_HELPER(0, 1);
607 PACKET32H_TRANSPOSE_HELPER(0, 2);
608 PACKET32H_TRANSPOSE_HELPER(0, 3);
609 PACKET32H_TRANSPOSE_HELPER(1, 2);
610 PACKET32H_TRANSPOSE_HELPER(1, 3);
611 PACKET32H_TRANSPOSE_HELPER(2, 3);
612
613#undef PACKET32H_TRANSPOSE_HELPER
614
615 EIGEN_UNROLL_LOOP
616 for (int i = 0; i < 32; i++) {
617 a.packet[i] = _mm512_castsi512_ph(f[i]);
618 }
619}
620
621EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet32h, 4>& a) {
622 __m512i p0, p1, p2, p3, t0, t1, t2, t3, a0, a1, a2, a3;
623 t0 = _mm512_unpacklo_epi16(_mm512_castph_si512(a.packet[0]), _mm512_castph_si512(a.packet[1]));
624 t1 = _mm512_unpackhi_epi16(_mm512_castph_si512(a.packet[0]), _mm512_castph_si512(a.packet[1]));
625 t2 = _mm512_unpacklo_epi16(_mm512_castph_si512(a.packet[2]), _mm512_castph_si512(a.packet[3]));
626 t3 = _mm512_unpackhi_epi16(_mm512_castph_si512(a.packet[2]), _mm512_castph_si512(a.packet[3]));
627
628 p0 = _mm512_unpacklo_epi32(t0, t2);
629 p1 = _mm512_unpackhi_epi32(t0, t2);
630 p2 = _mm512_unpacklo_epi32(t1, t3);
631 p3 = _mm512_unpackhi_epi32(t1, t3);
632
633 a0 = p0;
634 a1 = p1;
635 a2 = p2;
636 a3 = p3;
637
638 a0 = _mm512_inserti32x4(a0, _mm512_extracti32x4_epi32(p1, 0), 1);
639 a1 = _mm512_inserti32x4(a1, _mm512_extracti32x4_epi32(p0, 1), 0);
640
641 a0 = _mm512_inserti32x4(a0, _mm512_extracti32x4_epi32(p2, 0), 2);
642 a2 = _mm512_inserti32x4(a2, _mm512_extracti32x4_epi32(p0, 2), 0);
643
644 a0 = _mm512_inserti32x4(a0, _mm512_extracti32x4_epi32(p3, 0), 3);
645 a3 = _mm512_inserti32x4(a3, _mm512_extracti32x4_epi32(p0, 3), 0);
646
647 a1 = _mm512_inserti32x4(a1, _mm512_extracti32x4_epi32(p2, 1), 2);
648 a2 = _mm512_inserti32x4(a2, _mm512_extracti32x4_epi32(p1, 2), 1);
649
650 a2 = _mm512_inserti32x4(a2, _mm512_extracti32x4_epi32(p3, 2), 3);
651 a3 = _mm512_inserti32x4(a3, _mm512_extracti32x4_epi32(p2, 3), 2);
652
653 a1 = _mm512_inserti32x4(a1, _mm512_extracti32x4_epi32(p3, 1), 3);
654 a3 = _mm512_inserti32x4(a3, _mm512_extracti32x4_epi32(p1, 3), 1);
655
656 a.packet[0] = _mm512_castsi512_ph(a0);
657 a.packet[1] = _mm512_castsi512_ph(a1);
658 a.packet[2] = _mm512_castsi512_ph(a2);
659 a.packet[3] = _mm512_castsi512_ph(a3);
660}
661
662// preverse
663
664template <>
665EIGEN_STRONG_INLINE Packet32h preverse(const Packet32h& a) {
666 return _mm512_permutexvar_ph(_mm512_set_epi16(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
667 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31),
668 a);
669}
670
671// pscatter
672
673template <>
674EIGEN_STRONG_INLINE void pscatter<half, Packet32h>(half* to, const Packet32h& from, Index stride) {
675 EIGEN_ALIGN64 half aux[32];
676 pstore(aux, from);
677
678 EIGEN_UNROLL_LOOP
679 for (int i = 0; i < 32; i++) {
680 to[stride * i] = aux[i];
681 }
682}
683
684// pgather
685
686template <>
687EIGEN_STRONG_INLINE Packet32h pgather<Eigen::half, Packet32h>(const Eigen::half* from, Index stride) {
688 return _mm512_castsi512_ph(_mm512_set_epi16(
689 from[31 * stride].x, from[30 * stride].x, from[29 * stride].x, from[28 * stride].x, from[27 * stride].x,
690 from[26 * stride].x, from[25 * stride].x, from[24 * stride].x, from[23 * stride].x, from[22 * stride].x,
691 from[21 * stride].x, from[20 * stride].x, from[19 * stride].x, from[18 * stride].x, from[17 * stride].x,
692 from[16 * stride].x, from[15 * stride].x, from[14 * stride].x, from[13 * stride].x, from[12 * stride].x,
693 from[11 * stride].x, from[10 * stride].x, from[9 * stride].x, from[8 * stride].x, from[7 * stride].x,
694 from[6 * stride].x, from[5 * stride].x, from[4 * stride].x, from[3 * stride].x, from[2 * stride].x,
695 from[1 * stride].x, from[0 * stride].x));
696}
697
698template <>
699EIGEN_STRONG_INLINE Packet16h pcos<Packet16h>(const Packet16h&);
700template <>
701EIGEN_STRONG_INLINE Packet16h psin<Packet16h>(const Packet16h&);
702template <>
703EIGEN_STRONG_INLINE Packet16h plog<Packet16h>(const Packet16h&);
704template <>
705EIGEN_STRONG_INLINE Packet16h plog2<Packet16h>(const Packet16h&);
706template <>
707EIGEN_STRONG_INLINE Packet16h plog1p<Packet16h>(const Packet16h&);
708template <>
709EIGEN_STRONG_INLINE Packet16h pexp<Packet16h>(const Packet16h&);
710template <>
711EIGEN_STRONG_INLINE Packet16h pexpm1<Packet16h>(const Packet16h&);
712template <>
713EIGEN_STRONG_INLINE Packet16h ptanh<Packet16h>(const Packet16h&);
714template <>
715EIGEN_STRONG_INLINE Packet16h pfrexp<Packet16h>(const Packet16h&, Packet16h&);
716template <>
717EIGEN_STRONG_INLINE Packet16h pldexp<Packet16h>(const Packet16h&, const Packet16h&);
718
719EIGEN_STRONG_INLINE Packet32h combine2Packet16h(const Packet16h& a, const Packet16h& b) {
720 __m512d result = _mm512_undefined_pd();
721 result = _mm512_insertf64x4(result, _mm256_castsi256_pd(a), 0);
722 result = _mm512_insertf64x4(result, _mm256_castsi256_pd(b), 1);
723 return _mm512_castpd_ph(result);
724}
725
726EIGEN_STRONG_INLINE void extract2Packet16h(const Packet32h& x, Packet16h& a, Packet16h& b) {
727 a = _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(x), 0));
728 b = _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(x), 1));
729}
730
731// psin
732template <>
733EIGEN_STRONG_INLINE Packet32h psin<Packet32h>(const Packet32h& a) {
734 Packet16h low;
735 Packet16h high;
736 extract2Packet16h(a, low, high);
737
738 Packet16h lowOut = psin(low);
739 Packet16h highOut = psin(high);
740
741 return combine2Packet16h(lowOut, highOut);
742}
743
744// pcos
745template <>
746EIGEN_STRONG_INLINE Packet32h pcos<Packet32h>(const Packet32h& a) {
747 Packet16h low;
748 Packet16h high;
749 extract2Packet16h(a, low, high);
750
751 Packet16h lowOut = pcos(low);
752 Packet16h highOut = pcos(high);
753
754 return combine2Packet16h(lowOut, highOut);
755}
756
757// plog
758template <>
759EIGEN_STRONG_INLINE Packet32h plog<Packet32h>(const Packet32h& a) {
760 Packet16h low;
761 Packet16h high;
762 extract2Packet16h(a, low, high);
763
764 Packet16h lowOut = plog(low);
765 Packet16h highOut = plog(high);
766
767 return combine2Packet16h(lowOut, highOut);
768}
769
770// plog2
771template <>
772EIGEN_STRONG_INLINE Packet32h plog2<Packet32h>(const Packet32h& a) {
773 Packet16h low;
774 Packet16h high;
775 extract2Packet16h(a, low, high);
776
777 Packet16h lowOut = plog2(low);
778 Packet16h highOut = plog2(high);
779
780 return combine2Packet16h(lowOut, highOut);
781}
782
783// plog1p
784template <>
785EIGEN_STRONG_INLINE Packet32h plog1p<Packet32h>(const Packet32h& a) {
786 Packet16h low;
787 Packet16h high;
788 extract2Packet16h(a, low, high);
789
790 Packet16h lowOut = plog1p(low);
791 Packet16h highOut = plog1p(high);
792
793 return combine2Packet16h(lowOut, highOut);
794}
795
796// pexp
797template <>
798EIGEN_STRONG_INLINE Packet32h pexp<Packet32h>(const Packet32h& a) {
799 Packet16h low;
800 Packet16h high;
801 extract2Packet16h(a, low, high);
802
803 Packet16h lowOut = pexp(low);
804 Packet16h highOut = pexp(high);
805
806 return combine2Packet16h(lowOut, highOut);
807}
808
809// pexpm1
810template <>
811EIGEN_STRONG_INLINE Packet32h pexpm1<Packet32h>(const Packet32h& a) {
812 Packet16h low;
813 Packet16h high;
814 extract2Packet16h(a, low, high);
815
816 Packet16h lowOut = pexpm1(low);
817 Packet16h highOut = pexpm1(high);
818
819 return combine2Packet16h(lowOut, highOut);
820}
821
822// ptanh
823template <>
824EIGEN_STRONG_INLINE Packet32h ptanh<Packet32h>(const Packet32h& a) {
825 Packet16h low;
826 Packet16h high;
827 extract2Packet16h(a, low, high);
828
829 Packet16h lowOut = ptanh(low);
830 Packet16h highOut = ptanh(high);
831
832 return combine2Packet16h(lowOut, highOut);
833}
834
835// pfrexp
836template <>
837EIGEN_STRONG_INLINE Packet32h pfrexp<Packet32h>(const Packet32h& a, Packet32h& exponent) {
838 Packet16h low;
839 Packet16h high;
840 extract2Packet16h(a, low, high);
841
842 Packet16h exp1 = _mm256_undefined_si256();
843 Packet16h exp2 = _mm256_undefined_si256();
844
845 Packet16h lowOut = pfrexp(low, exp1);
846 Packet16h highOut = pfrexp(high, exp2);
847
848 exponent = combine2Packet16h(exp1, exp2);
849
850 return combine2Packet16h(lowOut, highOut);
851}
852
853// pldexp
854template <>
855EIGEN_STRONG_INLINE Packet32h pldexp<Packet32h>(const Packet32h& a, const Packet32h& exponent) {
856 Packet16h low;
857 Packet16h high;
858 extract2Packet16h(a, low, high);
859
860 Packet16h exp1;
861 Packet16h exp2;
862 extract2Packet16h(exponent, exp1, exp2);
863
864 Packet16h lowOut = pldexp(low, exp1);
865 Packet16h highOut = pldexp(high, exp2);
866
867 return combine2Packet16h(lowOut, highOut);
868}
869
870} // end namespace internal
871} // end namespace Eigen
872
873#endif // EIGEN_PACKET_MATH_FP16_AVX512_H
@ Aligned64
Definition Constants.h:239
@ Aligned32
Definition Constants.h:238
@ Aligned16
Definition Constants.h:237
Namespace containing all symbols from the Eigen library.
Definition Core:137