Eigen  3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)
 
Loading...
Searching...
No Matches
AVX512/Complex.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2018 Gael Guennebaud <[email protected]>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_COMPLEX_AVX512_H
11#define EIGEN_COMPLEX_AVX512_H
12
13// IWYU pragma: private
14#include "../../InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
20//---------- float ----------
21struct Packet8cf {
22 EIGEN_STRONG_INLINE Packet8cf() {}
23 EIGEN_STRONG_INLINE explicit Packet8cf(const __m512& a) : v(a) {}
24 __m512 v;
25};
26
27template <>
28struct packet_traits<std::complex<float> > : default_packet_traits {
29 typedef Packet8cf type;
30 typedef Packet4cf half;
31 enum {
32 Vectorizable = 1,
33 AlignedOnScalar = 1,
34 size = 8,
35
36 HasAdd = 1,
37 HasSub = 1,
38 HasMul = 1,
39 HasDiv = 1,
40 HasNegate = 1,
41 HasSqrt = 1,
42 HasLog = 1,
43 HasExp = 1,
44 HasAbs = 0,
45 HasAbs2 = 0,
46 HasMin = 0,
47 HasMax = 0,
48 HasSetLinear = 0
49 };
50};
51
52template <>
53struct unpacket_traits<Packet8cf> {
54 typedef std::complex<float> type;
55 typedef Packet4cf half;
56 typedef Packet16f as_real;
57 enum {
58 size = 8,
59 alignment = unpacket_traits<Packet16f>::alignment,
60 vectorizable = true,
61 masked_load_available = false,
62 masked_store_available = false
63 };
64};
65
66template <>
67EIGEN_STRONG_INLINE Packet8cf ptrue<Packet8cf>(const Packet8cf& a) {
68 return Packet8cf(ptrue(Packet16f(a.v)));
69}
70template <>
71EIGEN_STRONG_INLINE Packet8cf padd<Packet8cf>(const Packet8cf& a, const Packet8cf& b) {
72 return Packet8cf(_mm512_add_ps(a.v, b.v));
73}
74template <>
75EIGEN_STRONG_INLINE Packet8cf psub<Packet8cf>(const Packet8cf& a, const Packet8cf& b) {
76 return Packet8cf(_mm512_sub_ps(a.v, b.v));
77}
78template <>
79EIGEN_STRONG_INLINE Packet8cf pnegate(const Packet8cf& a) {
80 return Packet8cf(pnegate(a.v));
81}
82template <>
83EIGEN_STRONG_INLINE Packet8cf pconj(const Packet8cf& a) {
84 const __m512 mask = _mm512_castsi512_ps(_mm512_setr_epi32(
85 0x00000000, 0x80000000, 0x00000000, 0x80000000, 0x00000000, 0x80000000, 0x00000000, 0x80000000, 0x00000000,
86 0x80000000, 0x00000000, 0x80000000, 0x00000000, 0x80000000, 0x00000000, 0x80000000));
87 return Packet8cf(pxor(a.v, mask));
88}
89
90template <>
91EIGEN_STRONG_INLINE Packet8cf pmul<Packet8cf>(const Packet8cf& a, const Packet8cf& b) {
92 __m512 tmp2 = _mm512_mul_ps(_mm512_movehdup_ps(a.v), _mm512_permute_ps(b.v, _MM_SHUFFLE(2, 3, 0, 1)));
93 return Packet8cf(_mm512_fmaddsub_ps(_mm512_moveldup_ps(a.v), b.v, tmp2));
94}
95
96template <>
97EIGEN_STRONG_INLINE Packet8cf pand<Packet8cf>(const Packet8cf& a, const Packet8cf& b) {
98 return Packet8cf(pand(a.v, b.v));
99}
100template <>
101EIGEN_STRONG_INLINE Packet8cf por<Packet8cf>(const Packet8cf& a, const Packet8cf& b) {
102 return Packet8cf(por(a.v, b.v));
103}
104template <>
105EIGEN_STRONG_INLINE Packet8cf pxor<Packet8cf>(const Packet8cf& a, const Packet8cf& b) {
106 return Packet8cf(pxor(a.v, b.v));
107}
108template <>
109EIGEN_STRONG_INLINE Packet8cf pandnot<Packet8cf>(const Packet8cf& a, const Packet8cf& b) {
110 return Packet8cf(pandnot(a.v, b.v));
111}
112
113template <>
114EIGEN_STRONG_INLINE Packet8cf pcmp_eq(const Packet8cf& a, const Packet8cf& b) {
115 __m512 eq = pcmp_eq<Packet16f>(a.v, b.v);
116 return Packet8cf(pand(eq, _mm512_permute_ps(eq, 0xB1)));
117}
118
119template <>
120EIGEN_STRONG_INLINE Packet8cf pload<Packet8cf>(const std::complex<float>* from) {
121 EIGEN_DEBUG_ALIGNED_LOAD return Packet8cf(pload<Packet16f>(&numext::real_ref(*from)));
122}
123template <>
124EIGEN_STRONG_INLINE Packet8cf ploadu<Packet8cf>(const std::complex<float>* from) {
125 EIGEN_DEBUG_UNALIGNED_LOAD return Packet8cf(ploadu<Packet16f>(&numext::real_ref(*from)));
126}
127
128template <>
129EIGEN_STRONG_INLINE Packet8cf pset1<Packet8cf>(const std::complex<float>& from) {
130 const float re = std::real(from);
131 const float im = std::imag(from);
132 return Packet8cf(_mm512_set_ps(im, re, im, re, im, re, im, re, im, re, im, re, im, re, im, re));
133}
134
135template <>
136EIGEN_STRONG_INLINE Packet8cf ploaddup<Packet8cf>(const std::complex<float>* from) {
137 return Packet8cf(_mm512_castpd_ps(ploaddup<Packet8d>((const double*)(const void*)from)));
138}
139template <>
140EIGEN_STRONG_INLINE Packet8cf ploadquad<Packet8cf>(const std::complex<float>* from) {
141 return Packet8cf(_mm512_castpd_ps(ploadquad<Packet8d>((const double*)(const void*)from)));
142}
143
144template <>
145EIGEN_STRONG_INLINE void pstore<std::complex<float> >(std::complex<float>* to, const Packet8cf& from) {
146 EIGEN_DEBUG_ALIGNED_STORE pstore(&numext::real_ref(*to), from.v);
147}
148template <>
149EIGEN_STRONG_INLINE void pstoreu<std::complex<float> >(std::complex<float>* to, const Packet8cf& from) {
150 EIGEN_DEBUG_UNALIGNED_STORE pstoreu(&numext::real_ref(*to), from.v);
151}
152
153template <>
154EIGEN_DEVICE_FUNC inline Packet8cf pgather<std::complex<float>, Packet8cf>(const std::complex<float>* from,
155 Index stride) {
156 return Packet8cf(_mm512_castpd_ps(pgather<double, Packet8d>((const double*)(const void*)from, stride)));
157}
158
159template <>
160EIGEN_DEVICE_FUNC inline void pscatter<std::complex<float>, Packet8cf>(std::complex<float>* to, const Packet8cf& from,
161 Index stride) {
162 pscatter((double*)(void*)to, _mm512_castps_pd(from.v), stride);
163}
164
165template <>
166EIGEN_STRONG_INLINE std::complex<float> pfirst<Packet8cf>(const Packet8cf& a) {
167 return pfirst(Packet2cf(_mm512_castps512_ps128(a.v)));
168}
169
170template <>
171EIGEN_STRONG_INLINE Packet8cf preverse(const Packet8cf& a) {
172 return Packet8cf(_mm512_castsi512_ps(_mm512_permutexvar_epi64(
173 _mm512_set_epi32(0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7), _mm512_castps_si512(a.v))));
174}
175
176template <>
177EIGEN_STRONG_INLINE std::complex<float> predux<Packet8cf>(const Packet8cf& a) {
178 return predux(padd(Packet4cf(extract256<0>(a.v)), Packet4cf(extract256<1>(a.v))));
179}
180
181template <>
182EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet8cf>(const Packet8cf& a) {
183 return predux_mul(pmul(Packet4cf(extract256<0>(a.v)), Packet4cf(extract256<1>(a.v))));
184}
185
186template <>
187EIGEN_STRONG_INLINE Packet4cf predux_half_dowto4<Packet8cf>(const Packet8cf& a) {
188 __m256 lane0 = extract256<0>(a.v);
189 __m256 lane1 = extract256<1>(a.v);
190 __m256 res = _mm256_add_ps(lane0, lane1);
191 return Packet4cf(res);
192}
193
194EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet8cf, Packet16f)
195
196template <>
197EIGEN_STRONG_INLINE Packet8cf pdiv<Packet8cf>(const Packet8cf& a, const Packet8cf& b) {
198 return pdiv_complex(a, b);
199}
200
201template <>
202EIGEN_STRONG_INLINE Packet8cf pcplxflip<Packet8cf>(const Packet8cf& x) {
203 return Packet8cf(_mm512_shuffle_ps(x.v, x.v, _MM_SHUFFLE(2, 3, 0, 1)));
204}
205
206//---------- double ----------
207struct Packet4cd {
208 EIGEN_STRONG_INLINE Packet4cd() {}
209 EIGEN_STRONG_INLINE explicit Packet4cd(const __m512d& a) : v(a) {}
210 __m512d v;
211};
212
213template <>
214struct packet_traits<std::complex<double> > : default_packet_traits {
215 typedef Packet4cd type;
216 typedef Packet2cd half;
217 enum {
218 Vectorizable = 1,
219 AlignedOnScalar = 0,
220 size = 4,
221
222 HasAdd = 1,
223 HasSub = 1,
224 HasMul = 1,
225 HasDiv = 1,
226 HasNegate = 1,
227 HasSqrt = 1,
228 HasLog = 1,
229 HasAbs = 0,
230 HasAbs2 = 0,
231 HasMin = 0,
232 HasMax = 0,
233 HasSetLinear = 0
234 };
235};
236
237template <>
238struct unpacket_traits<Packet4cd> {
239 typedef std::complex<double> type;
240 typedef Packet2cd half;
241 typedef Packet8d as_real;
242 enum {
243 size = 4,
244 alignment = unpacket_traits<Packet8d>::alignment,
245 vectorizable = true,
246 masked_load_available = false,
247 masked_store_available = false
248 };
249};
250
251template <>
252EIGEN_STRONG_INLINE Packet4cd padd<Packet4cd>(const Packet4cd& a, const Packet4cd& b) {
253 return Packet4cd(_mm512_add_pd(a.v, b.v));
254}
255template <>
256EIGEN_STRONG_INLINE Packet4cd psub<Packet4cd>(const Packet4cd& a, const Packet4cd& b) {
257 return Packet4cd(_mm512_sub_pd(a.v, b.v));
258}
259template <>
260EIGEN_STRONG_INLINE Packet4cd pnegate(const Packet4cd& a) {
261 return Packet4cd(pnegate(a.v));
262}
263template <>
264EIGEN_STRONG_INLINE Packet4cd pconj(const Packet4cd& a) {
265 const __m512d mask = _mm512_castsi512_pd(_mm512_set_epi32(0x80000000, 0x0, 0x0, 0x0, 0x80000000, 0x0, 0x0, 0x0,
266 0x80000000, 0x0, 0x0, 0x0, 0x80000000, 0x0, 0x0, 0x0));
267 return Packet4cd(pxor(a.v, mask));
268}
269
270template <>
271EIGEN_STRONG_INLINE Packet4cd pmul<Packet4cd>(const Packet4cd& a, const Packet4cd& b) {
272 __m512d tmp1 = _mm512_shuffle_pd(a.v, a.v, 0x0);
273 __m512d tmp2 = _mm512_shuffle_pd(a.v, a.v, 0xFF);
274 __m512d tmp3 = _mm512_shuffle_pd(b.v, b.v, 0x55);
275 __m512d odd = _mm512_mul_pd(tmp2, tmp3);
276 return Packet4cd(_mm512_fmaddsub_pd(tmp1, b.v, odd));
277}
278
279template <>
280EIGEN_STRONG_INLINE Packet4cd ptrue<Packet4cd>(const Packet4cd& a) {
281 return Packet4cd(ptrue(Packet8d(a.v)));
282}
283template <>
284EIGEN_STRONG_INLINE Packet4cd pand<Packet4cd>(const Packet4cd& a, const Packet4cd& b) {
285 return Packet4cd(pand(a.v, b.v));
286}
287template <>
288EIGEN_STRONG_INLINE Packet4cd por<Packet4cd>(const Packet4cd& a, const Packet4cd& b) {
289 return Packet4cd(por(a.v, b.v));
290}
291template <>
292EIGEN_STRONG_INLINE Packet4cd pxor<Packet4cd>(const Packet4cd& a, const Packet4cd& b) {
293 return Packet4cd(pxor(a.v, b.v));
294}
295template <>
296EIGEN_STRONG_INLINE Packet4cd pandnot<Packet4cd>(const Packet4cd& a, const Packet4cd& b) {
297 return Packet4cd(pandnot(a.v, b.v));
298}
299
300template <>
301EIGEN_STRONG_INLINE Packet4cd pcmp_eq(const Packet4cd& a, const Packet4cd& b) {
302 __m512d eq = pcmp_eq<Packet8d>(a.v, b.v);
303 return Packet4cd(pand(eq, _mm512_permute_pd(eq, 0x55)));
304}
305
306template <>
307EIGEN_STRONG_INLINE Packet4cd pload<Packet4cd>(const std::complex<double>* from) {
308 EIGEN_DEBUG_ALIGNED_LOAD return Packet4cd(pload<Packet8d>((const double*)from));
309}
310template <>
311EIGEN_STRONG_INLINE Packet4cd ploadu<Packet4cd>(const std::complex<double>* from) {
312 EIGEN_DEBUG_UNALIGNED_LOAD return Packet4cd(ploadu<Packet8d>((const double*)from));
313}
314
315template <>
316EIGEN_STRONG_INLINE Packet4cd pset1<Packet4cd>(const std::complex<double>& from) {
317 return Packet4cd(_mm512_castps_pd(_mm512_broadcast_f32x4(_mm_castpd_ps(pset1<Packet1cd>(from).v))));
318}
319
320template <>
321EIGEN_STRONG_INLINE Packet4cd ploaddup<Packet4cd>(const std::complex<double>* from) {
322 return Packet4cd(
323 _mm512_insertf64x4(_mm512_castpd256_pd512(ploaddup<Packet2cd>(from).v), ploaddup<Packet2cd>(from + 1).v, 1));
324}
325
326template <>
327EIGEN_STRONG_INLINE void pstore<std::complex<double> >(std::complex<double>* to, const Packet4cd& from) {
328 EIGEN_DEBUG_ALIGNED_STORE pstore((double*)to, from.v);
329}
330template <>
331EIGEN_STRONG_INLINE void pstoreu<std::complex<double> >(std::complex<double>* to, const Packet4cd& from) {
332 EIGEN_DEBUG_UNALIGNED_STORE pstoreu((double*)to, from.v);
333}
334
335template <>
336EIGEN_DEVICE_FUNC inline Packet4cd pgather<std::complex<double>, Packet4cd>(const std::complex<double>* from,
337 Index stride) {
338 return Packet4cd(_mm512_insertf64x4(
339 _mm512_castpd256_pd512(_mm256_insertf128_pd(_mm256_castpd128_pd256(ploadu<Packet1cd>(from + 0 * stride).v),
340 ploadu<Packet1cd>(from + 1 * stride).v, 1)),
341 _mm256_insertf128_pd(_mm256_castpd128_pd256(ploadu<Packet1cd>(from + 2 * stride).v),
342 ploadu<Packet1cd>(from + 3 * stride).v, 1),
343 1));
344}
345
346template <>
347EIGEN_DEVICE_FUNC inline void pscatter<std::complex<double>, Packet4cd>(std::complex<double>* to, const Packet4cd& from,
348 Index stride) {
349 __m512i fromi = _mm512_castpd_si512(from.v);
350 double* tod = (double*)(void*)to;
351 _mm_storeu_pd(tod + 0 * stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi, 0)));
352 _mm_storeu_pd(tod + 2 * stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi, 1)));
353 _mm_storeu_pd(tod + 4 * stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi, 2)));
354 _mm_storeu_pd(tod + 6 * stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi, 3)));
355}
356
357template <>
358EIGEN_STRONG_INLINE std::complex<double> pfirst<Packet4cd>(const Packet4cd& a) {
359 __m128d low = extract128<0>(a.v);
360 EIGEN_ALIGN16 double res[2];
361 _mm_store_pd(res, low);
362 return std::complex<double>(res[0], res[1]);
363}
364
365template <>
366EIGEN_STRONG_INLINE Packet4cd preverse(const Packet4cd& a) {
367 return Packet4cd(_mm512_shuffle_f64x2(a.v, a.v, (shuffle_mask<3, 2, 1, 0>::mask)));
368}
369
370template <>
371EIGEN_STRONG_INLINE std::complex<double> predux<Packet4cd>(const Packet4cd& a) {
372 return predux(padd(Packet2cd(_mm512_extractf64x4_pd(a.v, 0)), Packet2cd(_mm512_extractf64x4_pd(a.v, 1))));
373}
374
375template <>
376EIGEN_STRONG_INLINE std::complex<double> predux_mul<Packet4cd>(const Packet4cd& a) {
377 return predux_mul(pmul(Packet2cd(_mm512_extractf64x4_pd(a.v, 0)), Packet2cd(_mm512_extractf64x4_pd(a.v, 1))));
378}
379
380EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet4cd, Packet8d)
381
382template <>
383EIGEN_STRONG_INLINE Packet4cd pdiv<Packet4cd>(const Packet4cd& a, const Packet4cd& b) {
384 return pdiv_complex(a, b);
385}
386
387template <>
388EIGEN_STRONG_INLINE Packet4cd pcplxflip<Packet4cd>(const Packet4cd& x) {
389 return Packet4cd(_mm512_permute_pd(x.v, 0x55));
390}
391
392EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8cf, 4>& kernel) {
393 PacketBlock<Packet8d, 4> pb;
394
395 pb.packet[0] = _mm512_castps_pd(kernel.packet[0].v);
396 pb.packet[1] = _mm512_castps_pd(kernel.packet[1].v);
397 pb.packet[2] = _mm512_castps_pd(kernel.packet[2].v);
398 pb.packet[3] = _mm512_castps_pd(kernel.packet[3].v);
399 ptranspose(pb);
400 kernel.packet[0].v = _mm512_castpd_ps(pb.packet[0]);
401 kernel.packet[1].v = _mm512_castpd_ps(pb.packet[1]);
402 kernel.packet[2].v = _mm512_castpd_ps(pb.packet[2]);
403 kernel.packet[3].v = _mm512_castpd_ps(pb.packet[3]);
404}
405
406EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8cf, 8>& kernel) {
407 PacketBlock<Packet8d, 8> pb;
408
409 pb.packet[0] = _mm512_castps_pd(kernel.packet[0].v);
410 pb.packet[1] = _mm512_castps_pd(kernel.packet[1].v);
411 pb.packet[2] = _mm512_castps_pd(kernel.packet[2].v);
412 pb.packet[3] = _mm512_castps_pd(kernel.packet[3].v);
413 pb.packet[4] = _mm512_castps_pd(kernel.packet[4].v);
414 pb.packet[5] = _mm512_castps_pd(kernel.packet[5].v);
415 pb.packet[6] = _mm512_castps_pd(kernel.packet[6].v);
416 pb.packet[7] = _mm512_castps_pd(kernel.packet[7].v);
417 ptranspose(pb);
418 kernel.packet[0].v = _mm512_castpd_ps(pb.packet[0]);
419 kernel.packet[1].v = _mm512_castpd_ps(pb.packet[1]);
420 kernel.packet[2].v = _mm512_castpd_ps(pb.packet[2]);
421 kernel.packet[3].v = _mm512_castpd_ps(pb.packet[3]);
422 kernel.packet[4].v = _mm512_castpd_ps(pb.packet[4]);
423 kernel.packet[5].v = _mm512_castpd_ps(pb.packet[5]);
424 kernel.packet[6].v = _mm512_castpd_ps(pb.packet[6]);
425 kernel.packet[7].v = _mm512_castpd_ps(pb.packet[7]);
426}
427
428EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet4cd, 4>& kernel) {
429 __m512d T0 =
430 _mm512_shuffle_f64x2(kernel.packet[0].v, kernel.packet[1].v, (shuffle_mask<0, 1, 0, 1>::mask)); // [a0 a1 b0 b1]
431 __m512d T1 =
432 _mm512_shuffle_f64x2(kernel.packet[0].v, kernel.packet[1].v, (shuffle_mask<2, 3, 2, 3>::mask)); // [a2 a3 b2 b3]
433 __m512d T2 =
434 _mm512_shuffle_f64x2(kernel.packet[2].v, kernel.packet[3].v, (shuffle_mask<0, 1, 0, 1>::mask)); // [c0 c1 d0 d1]
435 __m512d T3 =
436 _mm512_shuffle_f64x2(kernel.packet[2].v, kernel.packet[3].v, (shuffle_mask<2, 3, 2, 3>::mask)); // [c2 c3 d2 d3]
437
438 kernel.packet[3] = Packet4cd(_mm512_shuffle_f64x2(T1, T3, (shuffle_mask<1, 3, 1, 3>::mask))); // [a3 b3 c3 d3]
439 kernel.packet[2] = Packet4cd(_mm512_shuffle_f64x2(T1, T3, (shuffle_mask<0, 2, 0, 2>::mask))); // [a2 b2 c2 d2]
440 kernel.packet[1] = Packet4cd(_mm512_shuffle_f64x2(T0, T2, (shuffle_mask<1, 3, 1, 3>::mask))); // [a1 b1 c1 d1]
441 kernel.packet[0] = Packet4cd(_mm512_shuffle_f64x2(T0, T2, (shuffle_mask<0, 2, 0, 2>::mask))); // [a0 b0 c0 d0]
442}
443
444template <>
445EIGEN_STRONG_INLINE Packet4cd psqrt<Packet4cd>(const Packet4cd& a) {
446 return psqrt_complex<Packet4cd>(a);
447}
448
449template <>
450EIGEN_STRONG_INLINE Packet8cf psqrt<Packet8cf>(const Packet8cf& a) {
451 return psqrt_complex<Packet8cf>(a);
452}
453
454template <>
455EIGEN_STRONG_INLINE Packet4cd plog<Packet4cd>(const Packet4cd& a) {
456 return plog_complex<Packet4cd>(a);
457}
458
459template <>
460EIGEN_STRONG_INLINE Packet8cf plog<Packet8cf>(const Packet8cf& a) {
461 return plog_complex<Packet8cf>(a);
462}
463
464template <>
465EIGEN_STRONG_INLINE Packet8cf pexp<Packet8cf>(const Packet8cf& a) {
466 return pexp_complex<Packet8cf>(a);
467}
468
469} // end namespace internal
470} // end namespace Eigen
471
472#endif // EIGEN_COMPLEX_AVX512_H
Namespace containing all symbols from the Eigen library.
Definition Core:137