Eigen  3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)
 
Loading...
Searching...
No Matches
AVX512/TypeCasting.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2019 Rasmus Munk Larsen <[email protected]>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_TYPE_CASTING_AVX512_H
11#define EIGEN_TYPE_CASTING_AVX512_H
12
13// IWYU pragma: private
14#include "../../InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
20template <>
21struct type_casting_traits<float, bool> : vectorized_type_casting_traits<float, bool> {};
22template <>
23struct type_casting_traits<bool, float> : vectorized_type_casting_traits<bool, float> {};
24
25template <>
26struct type_casting_traits<float, int> : vectorized_type_casting_traits<float, int> {};
27template <>
28struct type_casting_traits<int, float> : vectorized_type_casting_traits<int, float> {};
29
30template <>
31struct type_casting_traits<float, double> : vectorized_type_casting_traits<float, double> {};
32template <>
33struct type_casting_traits<double, float> : vectorized_type_casting_traits<double, float> {};
34
35template <>
36struct type_casting_traits<double, int> : vectorized_type_casting_traits<double, int> {};
37template <>
38struct type_casting_traits<int, double> : vectorized_type_casting_traits<int, double> {};
39
40template <>
41struct type_casting_traits<double, int64_t> : vectorized_type_casting_traits<double, int64_t> {};
42template <>
43struct type_casting_traits<int64_t, double> : vectorized_type_casting_traits<int64_t, double> {};
44
45#ifndef EIGEN_VECTORIZE_AVX512FP16
46template <>
47struct type_casting_traits<half, float> : vectorized_type_casting_traits<half, float> {};
48template <>
49struct type_casting_traits<float, half> : vectorized_type_casting_traits<float, half> {};
50#endif
51
52template <>
53struct type_casting_traits<bfloat16, float> : vectorized_type_casting_traits<bfloat16, float> {};
54template <>
55struct type_casting_traits<float, bfloat16> : vectorized_type_casting_traits<float, bfloat16> {};
56
57template <>
58EIGEN_STRONG_INLINE Packet16b pcast<Packet16f, Packet16b>(const Packet16f& a) {
59 __mmask16 mask = _mm512_cmpneq_ps_mask(a, pzero(a));
60 return _mm512_maskz_cvtepi32_epi8(mask, _mm512_set1_epi32(1));
61}
62
63template <>
64EIGEN_STRONG_INLINE Packet16f pcast<Packet16b, Packet16f>(const Packet16b& a) {
65 return _mm512_cvtepi32_ps(_mm512_and_si512(_mm512_cvtepi8_epi32(a), _mm512_set1_epi32(1)));
66}
67
68template <>
69EIGEN_STRONG_INLINE Packet16i pcast<Packet16f, Packet16i>(const Packet16f& a) {
70 return _mm512_cvttps_epi32(a);
71}
72
73template <>
74EIGEN_STRONG_INLINE Packet8d pcast<Packet16f, Packet8d>(const Packet16f& a) {
75 return _mm512_cvtps_pd(_mm512_castps512_ps256(a));
76}
77
78template <>
79EIGEN_STRONG_INLINE Packet8d pcast<Packet8f, Packet8d>(const Packet8f& a) {
80 return _mm512_cvtps_pd(a);
81}
82
83template <>
84EIGEN_STRONG_INLINE Packet8l pcast<Packet8d, Packet8l>(const Packet8d& a) {
85#if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVS512VL)
86 return _mm512_cvttpd_epi64(a);
87#else
88 EIGEN_ALIGN16 double aux[8];
89 pstore(aux, a);
90 return _mm512_set_epi64(static_cast<int64_t>(aux[7]), static_cast<int64_t>(aux[6]), static_cast<int64_t>(aux[5]),
91 static_cast<int64_t>(aux[4]), static_cast<int64_t>(aux[3]), static_cast<int64_t>(aux[2]),
92 static_cast<int64_t>(aux[1]), static_cast<int64_t>(aux[0]));
93#endif
94}
95
96template <>
97EIGEN_STRONG_INLINE Packet16f pcast<Packet16i, Packet16f>(const Packet16i& a) {
98 return _mm512_cvtepi32_ps(a);
99}
100
101template <>
102EIGEN_STRONG_INLINE Packet8d pcast<Packet16i, Packet8d>(const Packet16i& a) {
103 return _mm512_cvtepi32_pd(_mm512_castsi512_si256(a));
104}
105
106template <>
107EIGEN_STRONG_INLINE Packet8d pcast<Packet8i, Packet8d>(const Packet8i& a) {
108 return _mm512_cvtepi32_pd(a);
109}
110
111template <>
112EIGEN_STRONG_INLINE Packet8d pcast<Packet8l, Packet8d>(const Packet8l& a) {
113#if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVS512VL)
114 return _mm512_cvtepi64_pd(a);
115#else
116 EIGEN_ALIGN16 int64_t aux[8];
117 pstore(aux, a);
118 return _mm512_set_pd(static_cast<double>(aux[7]), static_cast<double>(aux[6]), static_cast<double>(aux[5]),
119 static_cast<double>(aux[4]), static_cast<double>(aux[3]), static_cast<double>(aux[2]),
120 static_cast<double>(aux[1]), static_cast<double>(aux[0]));
121#endif
122}
123
124template <>
125EIGEN_STRONG_INLINE Packet16f pcast<Packet8d, Packet16f>(const Packet8d& a, const Packet8d& b) {
126 return cat256(_mm512_cvtpd_ps(a), _mm512_cvtpd_ps(b));
127}
128
129template <>
130EIGEN_STRONG_INLINE Packet16i pcast<Packet8d, Packet16i>(const Packet8d& a, const Packet8d& b) {
131 return cat256i(_mm512_cvttpd_epi32(a), _mm512_cvttpd_epi32(b));
132}
133
134template <>
135EIGEN_STRONG_INLINE Packet8i pcast<Packet8d, Packet8i>(const Packet8d& a) {
136 return _mm512_cvtpd_epi32(a);
137}
138template <>
139EIGEN_STRONG_INLINE Packet8f pcast<Packet8d, Packet8f>(const Packet8d& a) {
140 return _mm512_cvtpd_ps(a);
141}
142
143template <>
144EIGEN_STRONG_INLINE Packet16i preinterpret<Packet16i, Packet16f>(const Packet16f& a) {
145 return _mm512_castps_si512(a);
146}
147
148template <>
149EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16i>(const Packet16i& a) {
150 return _mm512_castsi512_ps(a);
151}
152
153template <>
154EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet16f>(const Packet16f& a) {
155 return _mm512_castps_pd(a);
156}
157
158template <>
159EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet8l>(const Packet8l& a) {
160 return _mm512_castsi512_pd(a);
161}
162
163template <>
164EIGEN_STRONG_INLINE Packet8l preinterpret<Packet8l, Packet8d>(const Packet8d& a) {
165 return _mm512_castpd_si512(a);
166}
167
168template <>
169EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet8d>(const Packet8d& a) {
170 return _mm512_castpd_ps(a);
171}
172
173template <>
174EIGEN_STRONG_INLINE Packet8f preinterpret<Packet8f, Packet16f>(const Packet16f& a) {
175 return _mm512_castps512_ps256(a);
176}
177
178template <>
179EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f, Packet16f>(const Packet16f& a) {
180 return _mm512_castps512_ps128(a);
181}
182
183template <>
184EIGEN_STRONG_INLINE Packet4d preinterpret<Packet4d, Packet8d>(const Packet8d& a) {
185 return _mm512_castpd512_pd256(a);
186}
187
188template <>
189EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d, Packet8d>(const Packet8d& a) {
190 return _mm512_castpd512_pd128(a);
191}
192
193template <>
194EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet8f>(const Packet8f& a) {
195 return _mm512_castps256_ps512(a);
196}
197
198template <>
199EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet4f>(const Packet4f& a) {
200 return _mm512_castps128_ps512(a);
201}
202
203template <>
204EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet4d>(const Packet4d& a) {
205 return _mm512_castpd256_pd512(a);
206}
207
208template <>
209EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet2d>(const Packet2d& a) {
210 return _mm512_castpd128_pd512(a);
211}
212
213template <>
214EIGEN_STRONG_INLINE Packet8i preinterpret<Packet8i, Packet16i>(const Packet16i& a) {
215 return _mm512_castsi512_si256(a);
216}
217template <>
218EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i, Packet16i>(const Packet16i& a) {
219 return _mm512_castsi512_si128(a);
220}
221
222template <>
223EIGEN_STRONG_INLINE Packet8h preinterpret<Packet8h, Packet16h>(const Packet16h& a) {
224 return _mm256_castsi256_si128(a);
225}
226
227template <>
228EIGEN_STRONG_INLINE Packet8bf preinterpret<Packet8bf, Packet16bf>(const Packet16bf& a) {
229 return _mm256_castsi256_si128(a);
230}
231
232#ifndef EIGEN_VECTORIZE_AVX512FP16
233
234template <>
235EIGEN_STRONG_INLINE Packet16f pcast<Packet16h, Packet16f>(const Packet16h& a) {
236 return half2float(a);
237}
238
239template <>
240EIGEN_STRONG_INLINE Packet16h pcast<Packet16f, Packet16h>(const Packet16f& a) {
241 return float2half(a);
242}
243
244#endif
245
246template <>
247EIGEN_STRONG_INLINE Packet16f pcast<Packet16bf, Packet16f>(const Packet16bf& a) {
248 return Bf16ToF32(a);
249}
250
251template <>
252EIGEN_STRONG_INLINE Packet16bf pcast<Packet16f, Packet16bf>(const Packet16f& a) {
253 return F32ToBf16(a);
254}
255
256#ifdef EIGEN_VECTORIZE_AVX512FP16
257
258template <>
259EIGEN_STRONG_INLINE Packet16h preinterpret<Packet16h, Packet32h>(const Packet32h& a) {
260 return _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(a), 0));
261}
262template <>
263EIGEN_STRONG_INLINE Packet8h preinterpret<Packet8h, Packet32h>(const Packet32h& a) {
264 return _mm256_castsi256_si128(preinterpret<Packet16h>(a));
265}
266
267template <>
268EIGEN_STRONG_INLINE Packet16f pcast<Packet32h, Packet16f>(const Packet32h& a) {
269 // Discard second-half of input.
270 Packet16h low = _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(a), 0));
271 return _mm512_cvtxph_ps(_mm256_castsi256_ph(low));
272}
273
274template <>
275EIGEN_STRONG_INLINE Packet32h pcast<Packet16f, Packet32h>(const Packet16f& a, const Packet16f& b) {
276 __m512d result = _mm512_undefined_pd();
277 result = _mm512_insertf64x4(
278 result, _mm256_castsi256_pd(_mm512_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 0);
279 result = _mm512_insertf64x4(
280 result, _mm256_castsi256_pd(_mm512_cvtps_ph(b, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 1);
281 return _mm512_castpd_ph(result);
282}
283
284template <>
285EIGEN_STRONG_INLINE Packet8f pcast<Packet16h, Packet8f>(const Packet16h& a) {
286 // Discard second-half of input.
287 Packet8h low = _mm_castps_si128(_mm256_extractf32x4_ps(_mm256_castsi256_ps(a), 0));
288 return _mm256_cvtxph_ps(_mm_castsi128_ph(low));
289}
290
291template <>
292EIGEN_STRONG_INLINE Packet16h pcast<Packet8f, Packet16h>(const Packet8f& a, const Packet8f& b) {
293 __m256d result = _mm256_undefined_pd();
294 result = _mm256_insertf64x2(result,
295 _mm_castsi128_pd(_mm256_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 0);
296 result = _mm256_insertf64x2(result,
297 _mm_castsi128_pd(_mm256_cvtps_ph(b, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 1);
298 return _mm256_castpd_si256(result);
299}
300
301template <>
302EIGEN_STRONG_INLINE Packet4f pcast<Packet8h, Packet4f>(const Packet8h& a) {
303 Packet8f full = _mm256_cvtxph_ps(_mm_castsi128_ph(a));
304 // Discard second-half of input.
305 return _mm256_extractf32x4_ps(full, 0);
306}
307
308template <>
309EIGEN_STRONG_INLINE Packet8h pcast<Packet4f, Packet8h>(const Packet4f& a, const Packet4f& b) {
310 __m256 result = _mm256_undefined_ps();
311 result = _mm256_insertf128_ps(result, a, 0);
312 result = _mm256_insertf128_ps(result, b, 1);
313 return _mm256_cvtps_ph(result, _MM_FROUND_TO_NEAREST_INT);
314}
315
316#endif
317
318} // end namespace internal
319
320} // end namespace Eigen
321
322#endif // EIGEN_TYPE_CASTING_AVX512_H
Namespace containing all symbols from the Eigen library.
Definition Core:137