Eigen 3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)
 
GemmKernel.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2022 Intel Corporation
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H
11#define EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H
12
13#if EIGEN_COMP_MSVC
14#include <intrin.h>
15#else
16#include <x86intrin.h>
17#endif
18#include <immintrin.h>
19#include <type_traits>
20
21// IWYU pragma: private
22#include "../../InternalHeaderCheck.h"
23
24#if !defined(EIGEN_USE_AVX512_GEMM_KERNELS)
25#define EIGEN_USE_AVX512_GEMM_KERNELS 1
26#endif
27
28#define SECOND_FETCH (32)
29#if (EIGEN_COMP_GNUC_STRICT != 0) && !defined(EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS)
30// Use fewer registers to load A elements to work around compiler spills. This loses a
31// bit of performance (less than ~2%).
32#define EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
33#endif
34
35namespace Eigen {
36namespace internal {
37
38template <typename Scalar, bool is_unit_inc>
39class gemm_class {
40 using vec = typename packet_traits<Scalar>::type;
41 using vec_ymm = typename unpacket_traits<vec>::half;
42 using vec_xmm = typename unpacket_traits<vec_ymm>::half;
43 using umask_t = typename unpacket_traits<vec>::mask_t;
44
45 static constexpr bool is_f32 = sizeof(Scalar) == sizeof(float);
46 static constexpr bool is_f64 = sizeof(Scalar) == sizeof(double);
47
48#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
49 static constexpr bool use_less_a_regs = !is_unit_inc;
50#else
51 static constexpr bool use_less_a_regs = true;
52#endif
53#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_B_REGS
54 static constexpr bool use_less_b_regs = !is_unit_inc;
55#else
56 static constexpr bool use_less_b_regs = true;
57#endif
58
59 static constexpr int a_regs[] = {0, 1, 2, use_less_a_regs ? 0 : 3, use_less_a_regs ? 1 : 4, use_less_a_regs ? 2 : 5};
60 static constexpr int b_regs[] = {6, use_less_b_regs ? 6 : 7};
61 static constexpr int c_regs[] = {
62 8, 16, 24, 9, 17, 25, 10, 18, 26, 11, 19, 27, 12, 20, 28, 13, 21, 29, 14, 22, 30, 15, 23, 31,
63 };
64
65 static constexpr int alpha_load_reg = 0;
66 static constexpr int c_load_regs[] = {1, 2, 6};
67
68 static constexpr int a_shift = 128;
69 static constexpr int b_shift = 128;
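  // Note: compute_kern() pre-advances the packed a/b pointers by a_shift/b_shift elements
  // (a -= -a_shift; b -= -b_shift;), and every load/prefetch address below is formed with the
  // shift subtracted back out.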
70
71 static constexpr int nelems_in_cache_line = is_f32 ? 16 : 8;
72 static constexpr int a_prefetch_size = nelems_in_cache_line * 2;
73 static constexpr int b_prefetch_size = nelems_in_cache_line * 8;
74
75 vec zmm[32];
76 umask_t mask;
77
78 // gemm arguments.
79 Index m;
80 const Index n, k, ldc;
81 const Index inc;
82 const Scalar *alpha;
83
84 const Scalar *a, *b;
85 Scalar *c;
86
87 const bool is_alpha1;
88 const bool is_beta0;
89
90 const Index a_stride, b_stride;
91 const Index a_off, b_off;
92
93 EIGEN_ALWAYS_INLINE void prefetch_a(const Scalar *a_addr) {
94 _mm_prefetch((char *)(a_prefetch_size + a_addr - a_shift), _MM_HINT_T0);
95 }
96
97 EIGEN_ALWAYS_INLINE void prefetch_b(const Scalar *b_addr) {
98 _mm_prefetch((char *)(b_prefetch_size + b_addr - b_shift), _MM_HINT_T0);
99 }
100
101 EIGEN_ALWAYS_INLINE void prefetch_x(const Scalar *x_addr) { _mm_prefetch((char *)(x_addr - a_shift), _MM_HINT_T2); }
102
103 EIGEN_ALWAYS_INLINE void prefetch_c(const Scalar *c_addr) {
104#if defined(__PRFCHW__) && __PRFCHW__ == 1
105 _m_prefetchw((void *)c_addr);
106#else
107 _mm_prefetch((char *)c_addr, _MM_HINT_T0);
108#endif
109 }
110
111 template <int nelems>
112 EIGEN_ALWAYS_INLINE void a_load(vec &a_reg, const Scalar *a_addr) {
113 switch (nelems * sizeof(*a_addr) * 8) {
114 default:
115 case 512 * 3:
116 a_reg = ploadu<vec>(a_addr);
117 break;
118 case 512 * 2:
119 a_reg = ploadu<vec>(a_addr);
120 break;
121 case 512 * 1:
122 a_reg = ploadu<vec>(a_addr);
123 break;
124 case 256 * 1:
125 a_reg = preinterpret<vec>(_mm512_broadcast_f64x4(ploadu<Packet4d>(reinterpret_cast<const double *>(a_addr))));
126 break;
127 case 128 * 1:
128 a_reg = preinterpret<vec>(_mm512_broadcast_f32x4(ploadu<Packet4f>(reinterpret_cast<const float *>(a_addr))));
129 break;
130 case 64 * 1:
131 a_reg = preinterpret<vec>(pload1<Packet8d>(reinterpret_cast<const double *>(a_addr)));
132 break;
133 case 32 * 1:
134 a_reg = pload1<vec>(a_addr);
135 break;
136 }
137 }
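  // For illustration: the switch keys on the load width in bits, nelems * sizeof(Scalar) * 8.
  // For float, nelems = 48/32/16 hits the full 512-bit ploadu cases, nelems = 8 the 256-bit
  // broadcast, nelems = 4 the 128-bit broadcast, nelems = 2 the 64-bit pload1<Packet8d>, and
  // nelems = 1 the 32-bit pload1; double reaches the same widths at half the nelems.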
138
139 EIGEN_ALWAYS_INLINE void b_load(vec &b_reg, const Scalar *b_addr) { b_reg = pload1<vec>(b_addr); }
140
141 template <int nelems>
142 EIGEN_ALWAYS_INLINE void c_store(Scalar *mem, vec &src) {
143 if (is_unit_inc) {
144 switch (nelems * sizeof(*mem) * 8) {
145 default:
146 case 512 * 3:
147 pstoreu(mem, src);
148 break;
149 case 512 * 2:
150 pstoreu(mem, src);
151 break;
152 case 512 * 1:
153 pstoreu(mem, src);
154 break;
155 case 256 * 1:
156 pstoreu(mem, preinterpret<vec_ymm>(src));
157 break;
158 case 128 * 1:
159 pstoreu(mem, preinterpret<vec_xmm>(src));
160 break;
161 case 64 * 1:
162 pstorel(mem, preinterpret<vec_xmm>(src));
163 break;
164 case 32 * 1:
165 pstores(mem, preinterpret<vec_xmm>(src));
166 break;
167 }
168 } else {
169 switch (nelems * sizeof(*mem) * 8) {
170 default:
171 case 512 * 3:
172 pscatter(mem, src, inc);
173 break;
174 case 512 * 2:
175 pscatter(mem, src, inc);
176 break;
177 case 512 * 1:
178 pscatter(mem, src, inc);
179 break;
180 case 256 * 1:
181 pscatter(mem, src, inc, mask);
182 break;
183 case 128 * 1:
184 pscatter(mem, src, inc, mask);
185 break;
186 case 64 * 1:
187 pscatter(mem, src, inc, mask);
188 break;
189 case 32 * 1:
190 pscatter(mem, src, inc, mask);
191 break;
192 }
193 }
194 }
195
196 template <int nelems>
197 EIGEN_ALWAYS_INLINE void vaddm(vec &dst, const Scalar *mem, vec &src, vec &reg) {
198 if (is_unit_inc) {
199 switch (nelems * sizeof(*mem) * 8) {
200 default:
201 case 512 * 3:
202 dst = padd(src, ploadu<vec>(mem));
203 break;
204 case 512 * 2:
205 dst = padd(src, ploadu<vec>(mem));
206 break;
207 case 512 * 1:
208 dst = padd(src, ploadu<vec>(mem));
209 break;
210 case 256 * 1:
211 dst = preinterpret<vec>(padd(preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem)));
212 break;
213 case 128 * 1:
214 dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem)));
215 break;
216 case 64 * 1:
217 dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
218 break;
219 case 32 * 1:
220 dst = preinterpret<vec>(padds(preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
221 break;
222 }
223 } else {
224 // Zero out scratch register
225 reg = pzero(reg);
226
227 switch (nelems * sizeof(*mem) * 8) {
228 default:
229 case 512 * 3:
230 reg = pgather<Scalar, vec>(mem, inc);
231 dst = padd(src, reg);
232 break;
233 case 512 * 2:
234 reg = pgather<Scalar, vec>(mem, inc);
235 dst = padd(src, reg);
236 break;
237 case 512 * 1:
238 reg = pgather<Scalar, vec>(mem, inc);
239 dst = padd(src, reg);
240 break;
241 case 256 * 1:
242 reg = preinterpret<vec>(pgather<Scalar, vec_ymm>(mem, inc));
243 dst = preinterpret<vec>(padd(preinterpret<vec_ymm>(src), preinterpret<vec_ymm>(reg)));
244 break;
245 case 128 * 1:
246 reg = preinterpret<vec>(pgather<Scalar, vec_xmm>(mem, inc));
247 dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
248 break;
249 case 64 * 1:
250 if (is_f32) {
251 reg = pgather(reg, mem, inc, mask);
252 dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
253 } else {
254 dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
255 }
256 break;
257 case 32 * 1:
258 dst = preinterpret<vec>(padds(preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
259 break;
260 }
261 }
262 }
263
264 EIGEN_STRONG_INLINE void vfmadd(vec &dst, const vec &src1, const vec &src2) {
265 dst = pmadd(src1, src2, dst);
266
267#if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
268 // Workaround register spills for gcc and clang
269 __asm__("#" : [dst] "+v"(dst) : [src1] "%v"(src1), [src2] "v"(src2));
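    // The empty asm body emits no instructions; it only ties dst to a vector register ("+v") with
    // src1/src2 as inputs ("%v" marks them as a commutative pair), keeping the accumulator live in
    // a register across the unrolled FMA chain so the compiler is less inclined to spill it.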
270#endif
271 }
272
273 template <int nelems>
274 EIGEN_ALWAYS_INLINE void vfmaddm(vec &dst, const Scalar *mem, vec &src, vec &scale, vec &reg) {
275 if (is_unit_inc) {
276 switch (nelems * sizeof(*mem) * 8) {
277 default:
278 case 512 * 3:
279 dst = pmadd(scale, src, ploadu<vec>(mem));
280 break;
281 case 512 * 2:
282 dst = pmadd(scale, src, ploadu<vec>(mem));
283 break;
284 case 512 * 1:
285 dst = pmadd(scale, src, ploadu<vec>(mem));
286 break;
287 case 256 * 1:
288 dst =
289 preinterpret<vec>(pmadd(preinterpret<vec_ymm>(scale), preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem)));
290 break;
291 case 128 * 1:
292 dst =
293 preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem)));
294 break;
295 case 64 * 1:
296 dst =
297 preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
298 break;
299 case 32 * 1:
300 dst =
301 preinterpret<vec>(pmadds(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
302 break;
303 }
304 } else {
305 // Zero out scratch register
306 reg = pzero(reg);
307
308 switch (nelems * sizeof(*mem) * 8) {
309 default:
310 case 512 * 3:
311 reg = pgather<Scalar, vec>(mem, inc);
312 dst = pmadd(scale, src, reg);
313 break;
314 case 512 * 2:
315 reg = pgather<Scalar, vec>(mem, inc);
316 dst = pmadd(scale, src, reg);
317 break;
318 case 512 * 1:
319 reg = pgather<Scalar, vec>(mem, inc);
320 dst = pmadd(scale, src, reg);
321 break;
322 case 256 * 1:
323 reg = preinterpret<vec>(pgather<Scalar, vec_ymm>(mem, inc));
324 dst = preinterpret<vec>(
325 pmadd(preinterpret<vec_ymm>(scale), preinterpret<vec_ymm>(src), preinterpret<vec_ymm>(reg)));
326 break;
327 case 128 * 1:
328 reg = preinterpret<vec>(pgather<Scalar, vec_xmm>(mem, inc));
329 dst = preinterpret<vec>(
330 pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
331 break;
332 case 64 * 1:
333 if (is_f32) {
334 reg = pgather(reg, mem, inc, mask);
335 dst = preinterpret<vec>(
336 pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
337 } else {
338 dst = preinterpret<vec>(
339 pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
340 }
341 break;
342 case 32 * 1:
343 dst =
344 preinterpret<vec>(pmadds(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
345 break;
346 }
347 }
348 }
349
350 template <int j, int endX, int i, int endY, int nelems>
351 EIGEN_ALWAYS_INLINE std::enable_if_t<(j > endX) || (i > endY)> a_loads(const Scalar *ao) {
352 EIGEN_UNUSED_VARIABLE(ao);
353 }
354
355 template <int j, int endX, int i, int endY, int nelems>
356 EIGEN_ALWAYS_INLINE std::enable_if_t<(j <= endX) && (i <= endY)> a_loads(const Scalar *ao) {
357 if (j < endX) {
358 if (i < endY) {
359 auto &a_reg = zmm[a_regs[i + (j % 2) * 3]];
360 const Scalar *a_addr = ao + nelems * j + nelems_in_cache_line * i - a_shift;
361 a_load<nelems>(a_reg, a_addr);
362
363 a_loads<j, endX, i + 1, endY, nelems>(ao);
364 } else {
365 a_loads<j + 1, endX, 0, endY, nelems>(ao);
366 }
367 }
368 }
369
370 template <int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
371 EIGEN_ALWAYS_INLINE std::enable_if_t<(un > max_b_unroll) || (i > um_vecs)> prefetch_cs(const Scalar *co1,
372 const Scalar *co2) {
373 EIGEN_UNUSED_VARIABLE(co1);
374 EIGEN_UNUSED_VARIABLE(co2);
375 }
376
377 /* C prefetch loop structure.
378 * for (int un = 0; un < 8; un++) {
379 * if (b_unroll >= un + 1) {
380 * if (un == 4) co2 = co1 + 4 * ldc;
381 *
382 * for (int i = 0; i < um_vecs; i++) {
383 * Scalar *co = (un + 1 <= 4) ? co1 : co2;
384 * auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co;
385 * prefetch_c(co + co_off);
386 * }
387 * }
388 * }
389 */
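  // For example, with a_unroll = 48 (float) and b_unroll = max_b_unroll = 8 this recursion issues
  // um_vecs * 8 = 3 * 8 = 24 prefetches, roughly one per cache line of the 48 x 8 C tile.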
390
391 template <int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
392 EIGEN_ALWAYS_INLINE std::enable_if_t<(un <= max_b_unroll) && (i <= um_vecs)> prefetch_cs(Scalar *&co1, Scalar *&co2) {
393 if (un < max_b_unroll) {
394 if (b_unroll >= un + 1) {
395 if (un == 4 && i == 0) co2 = co1 + 4 * ldc;
396
397 if (i < um_vecs) {
398 Scalar *co = (un + 1 <= 4) ? co1 : co2;
399 auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co;
400 prefetch_c(co + co_off);
401
402 prefetch_cs<un, max_b_unroll, i + 1, um_vecs, a_unroll, b_unroll>(co1, co2);
403 } else {
404 prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
405 }
406
407 } else {
408 prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
409 }
410 }
411 }
412
413 // load_c
414 template <int i, int um_vecs, int idx, int nelems>
415 EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> scale_load_c(const Scalar *cox, vec &alpha_reg) {
416 EIGEN_UNUSED_VARIABLE(cox);
417 EIGEN_UNUSED_VARIABLE(alpha_reg);
418 }
419
420 template <int i, int um_vecs, int idx, int nelems>
421 EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> scale_load_c(const Scalar *cox, vec &alpha_reg) {
422 if (i < um_vecs) {
423 auto &c_reg = zmm[c_regs[i + idx * 3]];
424 auto &c_load_reg = zmm[c_load_regs[i % 3]];
425 auto c_mem = cox;
426 if (is_unit_inc)
427 c_mem += i * nelems_in_cache_line;
428 else
429 c_mem += i * nelems_in_cache_line * inc;
430
431 if (!is_beta0 && is_alpha1)
432 vaddm<nelems>(c_reg, c_mem, c_reg, c_load_reg);
433 else if (!is_beta0 && !is_alpha1)
434 vfmaddm<nelems>(c_reg, c_mem, c_reg, alpha_reg, c_load_reg);
435 else if (is_beta0 && !is_alpha1)
436 c_reg = pmul(alpha_reg, c_reg);
437
438 scale_load_c<i + 1, um_vecs, idx, nelems>(cox, alpha_reg);
439 }
440 }
441
442 // store_c
443 template <int i, int um_vecs, int idx, int nelems>
444 EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> write_c(Scalar *cox) {
445 EIGEN_UNUSED_VARIABLE(cox);
446 }
447
448 template <int i, int um_vecs, int idx, int nelems>
449 EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> write_c(Scalar *cox) {
450 if (i < um_vecs) {
451 auto &c_reg = zmm[c_regs[i + idx * 3]];
452 auto c_mem = cox;
453 if (is_unit_inc)
454 c_mem += i * nelems_in_cache_line;
455 else
456 c_mem += i * nelems_in_cache_line * inc;
457
458 c_store<nelems>(c_mem, c_reg);
459 c_reg = pzero(c_reg);
460
461 write_c<i + 1, um_vecs, idx, nelems>(cox);
462 }
463 }
464
465 /* C update loop structure.
466 * co2 = co1 + ldc;
467 *
468 * auto &alpha_reg = zmm[alpha_load_reg];
469 * if (!is_alpha1) alpha_reg = pload1<vec>(alpha);
470 *
471 * int idx = 0;
472 * for (pow = 1; pow <= 8; pow <<= 1) {
473 *
474 * if (b_unroll >= pow) {
475 * for (count = 1; count < (pow + 1) / 2 + 1; count++) {
476 * if (pow >= 4) co2 += ldc;
477 *
478 * const Scalar *cox = (idx == 0) ? co1 : co2;
479 *
480 * const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);
481 * scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg);
482 * write_c<0, um_vecs, idx, a_unroll>(cox);
483 *
484 * idx++;
485 * }
486 * }
487 * }
488 *
489 * if (b_unroll == 1)
490 * co1 += ldc;
491 * else
492 * co1 = co2 + ldc;
493 */
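  // For example, with b_unroll = 8 the pow loop runs counts of 1, 1, 2 and 4, so idx takes the
  // values 0..7 and each of the eight C columns is scaled, accumulated and stored exactly once.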
494
495 template <int pow, int a_unroll, int idx>
496 EIGEN_ALWAYS_INLINE void c_update_1count(Scalar *&cox) {
497 if (pow >= 4) cox += ldc;
498
499 const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);
500 auto &alpha_reg = zmm[alpha_load_reg];
501
502 scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg);
503 write_c<0, um_vecs, idx, a_unroll>(cox);
504 }
505
506 template <int pow, int a_unroll>
507 EIGEN_ALWAYS_INLINE void c_update_1pow(Scalar *&co1, Scalar *&co2) {
508 constexpr int idx = pow / 2;
509 Scalar *&cox = idx == 0 ? co1 : co2;
510
511 constexpr int max_count = (pow + 1) / 2;
512 static_assert(max_count <= 4, "Unsupported max_count.");
513
514 if (1 <= max_count) c_update_1count<pow, a_unroll, idx + 0>(cox);
515 if (2 <= max_count) c_update_1count<pow, a_unroll, idx + 1>(cox);
516 if (3 <= max_count) c_update_1count<pow, a_unroll, idx + 2>(cox);
517 if (4 <= max_count) c_update_1count<pow, a_unroll, idx + 3>(cox);
518 }
519
520 template <int max_b_unroll, int a_unroll, int b_unroll>
521 EIGEN_ALWAYS_INLINE void c_update(Scalar *&co1, Scalar *&co2) {
522 auto &alpha_reg = zmm[alpha_load_reg];
523
524 co2 = co1 + ldc;
525 if (!is_alpha1) alpha_reg = pload1<vec>(alpha);
526 if (!is_unit_inc && a_unroll < nelems_in_cache_line) mask = static_cast<umask_t>((1ull << a_unroll) - 1);
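    // For example, for float (nelems_in_cache_line = 16) with a_unroll = 4 and a strided C,
    // mask = 0b1111, so the masked pscatter in c_store only writes the four valid lanes.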
527
528 static_assert(max_b_unroll <= 8, "Unsupported max_b_unroll");
529
530 if (1 <= max_b_unroll && 1 <= b_unroll) c_update_1pow<1, a_unroll>(co1, co2);
531 if (2 <= max_b_unroll && 2 <= b_unroll) c_update_1pow<2, a_unroll>(co1, co2);
532 if (4 <= max_b_unroll && 4 <= b_unroll) c_update_1pow<4, a_unroll>(co1, co2);
533 if (8 <= max_b_unroll && 8 <= b_unroll) c_update_1pow<8, a_unroll>(co1, co2);
534
535 if (b_unroll == 1)
536 co1 += ldc;
537 else
538 co1 = co2 + ldc;
539 }
540
541 // compute
542 template <int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
543 EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx,
544 int &fetchB_idx, vec &b_reg) {
545 EIGEN_UNUSED_VARIABLE(ao);
546 EIGEN_UNUSED_VARIABLE(bo);
547 EIGEN_UNUSED_VARIABLE(fetchA_idx);
548 EIGEN_UNUSED_VARIABLE(fetchB_idx);
549 EIGEN_UNUSED_VARIABLE(b_reg);
550 }
551
552 template <int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
553 EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx,
554 int &fetchB_idx, vec &b_reg) {
555 if (um < um_vecs) {
556 auto &c_reg = zmm[c_regs[um + idx * 3]];
557 auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];
558
559 vfmadd(c_reg, a_reg, b_reg);
560
561 if (!fetch_x && um == 0 &&
562 (((idx == 0 || idx == 6) && (uk % 2 == 0 || is_f64 || ktail)) ||
563 (idx == 3 && (uk % 2 == 1 || is_f64 || ktail)))) {
564 prefetch_a(ao + nelems_in_cache_line * fetchA_idx);
565 fetchA_idx++;
566 }
567
568 if (um == 0 && idx == 1 && (uk % 2 == 0 || is_f64 || ktail)) {
569 prefetch_b(bo + nelems_in_cache_line * fetchB_idx);
570 fetchB_idx++;
571 }
572
573 compute<um + 1, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
574 }
575 }
576
577 // load_a
578 template <int um, int um_vecs, int uk, int nelems, bool ktail>
579 EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> load_a(const Scalar *ao) {
580 EIGEN_UNUSED_VARIABLE(ao);
581 }
582
583 template <int um, int um_vecs, int uk, int nelems, bool ktail>
584 EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> load_a(const Scalar *ao) {
585 if (um < um_vecs) {
586 auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];
587 const Scalar *a_addr = ao + nelems * (1 + !ktail * !use_less_a_regs + uk) + nelems_in_cache_line * um - a_shift;
588 a_load<nelems>(a_reg, a_addr);
589
590 load_a<um + 1, um_vecs, uk, nelems, ktail>(ao);
591 }
592 }
593 template <int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
594 EIGEN_ALWAYS_INLINE std::enable_if_t<(count > (pow + 1) / 2)> innerkernel_1pow(const Scalar *&aa,
595 const Scalar *const &ao,
596 const Scalar *const &bo, Scalar *&co2,
597 int &fetchA_idx, int &fetchB_idx) {
598 EIGEN_UNUSED_VARIABLE(aa);
599 EIGEN_UNUSED_VARIABLE(ao);
600 EIGEN_UNUSED_VARIABLE(bo);
601 EIGEN_UNUSED_VARIABLE(co2);
602 EIGEN_UNUSED_VARIABLE(fetchA_idx);
603 EIGEN_UNUSED_VARIABLE(fetchB_idx);
604 }
605
606 template <int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
607 EIGEN_ALWAYS_INLINE std::enable_if_t<(count <= (pow + 1) / 2)> innerkernel_1pow(const Scalar *&aa,
608 const Scalar *const &ao,
609 const Scalar *const &bo, Scalar *&co2,
610 int &fetchA_idx, int &fetchB_idx) {
611 const int idx = (pow / 2) + count;
612
613 if (count < (pow + 1) / 2) {
614 auto &b_reg = zmm[b_regs[idx % 2]];
615
616 if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa);
617 if (fetch_x && uk == 3 && idx == 4) aa += 8;
618
619 if (b_unroll >= pow) {
620 compute<0, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
621
622 const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) * !use_less_b_regs - b_shift;
623 b_load(b_reg, b_addr);
624 }
625
626 // Go to the next count.
627 innerkernel_1pow<uk, pow, count + 1, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx,
628 fetchB_idx);
629
630 } else {
631 // Maybe prefetch C data after count-loop.
632 if (pow == 2 && c_fetch) {
633 if (uk % 3 == 0 && uk > 0) {
634 co2 += ldc;
635 } else {
636 prefetch_c(co2 + (uk % 3) * nelems_in_cache_line);
637 }
638 }
639 }
640 }
641
642 template <int uk, int max_b_unroll, int a_unroll, int b_unroll, bool ktail, bool fetch_x, bool c_fetch,
643 bool no_a_preload = false>
644 EIGEN_ALWAYS_INLINE void innerkernel_1uk(const Scalar *&aa, const Scalar *const &ao, const Scalar *const &bo,
645 Scalar *&co2, int &fetchA_idx, int &fetchB_idx) {
646 const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);
647
648 if (max_b_unroll >= 1)
649 innerkernel_1pow<uk, 1, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
650 if (max_b_unroll >= 2)
651 innerkernel_1pow<uk, 2, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
652 if (max_b_unroll >= 4)
653 innerkernel_1pow<uk, 4, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
654 if (max_b_unroll >= 8)
655 innerkernel_1pow<uk, 8, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
656
657 // Load A after pow-loop. Skip this at the end to prevent running over the buffer.
658 if (!no_a_preload) load_a<0, um_vecs, uk, a_unroll, ktail>(ao);
659 }
660
661 /* Inner kernel loop structure.
662 * for (int uk = 0; uk < kfactor; uk++) {
663 * int idx = 0;
664 *
665 * for (pow = 1; pow < max_b_unroll << 1; pow <<= 1) {
666 * for (int count = 0; count < (pow + 1) / 2; count++) {
667 * auto &b_reg = zmm[b_regs[idx % 2]];
668 *
669 * if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa);
670 * if (fetch_x && uk == 3 && idx == 4) aa += 8;
671 *
672 * if (b_unroll >= pow) {
673 * compute<0, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
674 *
675 * const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) - b_shift;
676 * b_load(b_reg, b_addr);
677 * }
678 * idx++;
679 * }
680 *
681 * Maybe prefetch C data.
682 * if (pow == 2 && c_fetch) {
683 * if (uk % 3 == 0 && uk > 0) {
684 * co2 += ldc;
685 * } else {
686 * prefetch_c(co2 + (uk % 3) * nelems_in_cache_line);
687 * }
688 * }
689 * }
690 *
691 * Load A.
692 * load_a<0, um_vecs, uk, a_unroll, ktail>(ao);
693 * }
694 *
695 * Advance A/B pointers after uk-loop.
696 * ao += a_unroll * kfactor;
697 * bo += b_unroll * kfactor;
698 */
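  // For example, with b_unroll = max_b_unroll = 8 each uk step issues 1 + 1 + 2 + 4 = 8 broadcast
  // loads of B and 8 * um_vecs FMAs into the C accumulators, followed (unless no_a_preload) by
  // um_vecs loads of the next A slice.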
699
700 template <int a_unroll, int b_unroll, int k_factor, int max_b_unroll, int max_k_factor, bool c_fetch,
701 bool no_a_preload = false>
702 EIGEN_ALWAYS_INLINE void innerkernel(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co2) {
703 int fetchA_idx = 0;
704 int fetchB_idx = 0;
705
706 const bool fetch_x = k_factor == max_k_factor;
707 const bool ktail = k_factor == 1;
708
709 static_assert(k_factor <= 4 && k_factor > 0, "innerkernel maximum k_factor supported is 4");
710 static_assert(no_a_preload == false || (no_a_preload == true && k_factor == 1),
711 "skipping a preload only allowed when k unroll is 1");
712
713 if (k_factor > 0)
714 innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
715 aa, ao, bo, co2, fetchA_idx, fetchB_idx);
716 if (k_factor > 1)
717 innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
718 aa, ao, bo, co2, fetchA_idx, fetchB_idx);
719 if (k_factor > 2)
720 innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
721 aa, ao, bo, co2, fetchA_idx, fetchB_idx);
722 if (k_factor > 3)
723 innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
724 aa, ao, bo, co2, fetchA_idx, fetchB_idx);
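    // k_factor is a compile-time constant, so the guards above simply unroll uk = 0..k_factor-1.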
725
726 // Advance A/B pointers after uk-loop.
727 ao += a_unroll * k_factor;
728 bo += b_unroll * k_factor;
729 }
730
731 template <int a_unroll, int b_unroll, int max_b_unroll>
732 EIGEN_ALWAYS_INLINE void kloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
733 const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);
734 if (!use_less_a_regs && k > 1)
735 a_loads<0, 2, 0, um_vecs, a_unroll>(ao);
736 else
737 a_loads<0, 1, 0, um_vecs, a_unroll>(ao);
738
739 b_load(zmm[b_regs[0]], bo - b_shift + 0);
740 if (!use_less_b_regs) b_load(zmm[b_regs[1]], bo - b_shift + 1);
741
742#ifndef SECOND_FETCH
743 prefetch_cs<0, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
744#endif // SECOND_FETCH
745
746 // Unrolling k-loop by a factor of 4.
747 const int max_k_factor = 4;
748 Index kRem = k % max_k_factor;
749 Index k_ = k - kRem;
750 if (k_ >= max_k_factor) {
751 k_ -= max_k_factor;
752 kRem += max_k_factor;
753 }
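    // Fold one full factor-of-4 group into the remainder whenever possible so the remainder loop
    // below is non-empty and its last k step can run with no_a_preload = true, i.e. without
    // preloading A past the end of the packed buffer.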
754 Index loop_count = k_ / max_k_factor;
755
756 if (loop_count > 0) {
757#ifdef SECOND_FETCH
758 loop_count -= SECOND_FETCH;
759#endif
760 while (loop_count > 0) {
761 innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
762 loop_count--;
763 }
764#ifdef SECOND_FETCH
765 co2 = co1 + nelems_in_cache_line - 1;
766
767 loop_count += b_unroll;
768 while (loop_count > 0) {
769 innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 1>(aa, ao, bo, co2);
770 loop_count--;
771 }
772
773 loop_count += SECOND_FETCH - b_unroll;
774 while (loop_count > 0) {
775 innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
776 loop_count--;
777 }
778#endif
779 }
780
781 // k-loop remainder handling.
782 loop_count = kRem;
783 while (loop_count > 1) {
784 innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
785 loop_count--;
786 }
787 if (loop_count > 0) {
788 innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0, true>(aa, ao, bo, co2);
789 }
790
791 // Update C matrix.
792 c_update<max_b_unroll, a_unroll, b_unroll>(co1, co2);
793 }
794
795 template <int a_unroll, int b_unroll, int max_b_unroll>
796 EIGEN_ALWAYS_INLINE void nloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
797 // Set A matrix pointer.
798 ao = a + a_off * a_unroll;
799
800 // Set B matrix pointer if needed.
801 bo += b_unroll * b_off;
802
803 kloop<a_unroll, b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);
804
805 // Advance B matrix pointer if needed.
806 bo += b_unroll * (b_stride - k - b_off);
807
808 // Advance prefetch A pointer.
809 aa += 16;
810 }
811
812 template <int a_unroll, int max_a_unroll, int max_b_unroll>
813 EIGEN_ALWAYS_INLINE void mloop(const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
814 // Set prefetch A pointers.
815 const Scalar *aa = a + a_unroll * a_stride;
816
817 // Set C matrix pointers.
818 co1 = c;
819 if (a_unroll >= max_a_unroll) co2 = c + 2 * ldc;
820 if (is_unit_inc)
821 c += a_unroll;
822 else
823 c += a_unroll * inc;
824
825 // Set B matrix pointer.
826 bo = b;
827
828 // Main n-loop.
829 for (Index i = n / max_b_unroll; i > 0; i--) nloop<a_unroll, max_b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);
830
831 // n-remainders.
832 if (n & 4 && max_b_unroll > 4) nloop<a_unroll, 4, max_b_unroll>(aa, ao, bo, co1, co2);
833#if 0
834 if (n & 2 && max_b_unroll > 2) nloop<a_unroll, 2, max_b_unroll>(aa, ao, bo, co1, co2);
835 if (n & 1 && max_b_unroll > 1) nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2);
836#else
837 // Copy kernels don't support tails of n = 2 for single/double precision.
838 // Loop over ones.
839 int n_rem = 2 * ((n & 2) != 0) + 1 * ((n & 1) != 0);
840 while (n_rem > 0) {
841 nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2);
842 n_rem--;
843 }
844#endif
845
846 // Advance A matrix pointer.
847 a = ao + a_unroll * (a_stride - k - a_off);
848 }
849
850 public:
851 // Compute kernel unrolling C matrix by max_a_unroll x max_b_unroll.
852 template <int max_a_unroll, int max_b_unroll>
853 EIGEN_ALWAYS_INLINE void compute_kern() {
854 a -= -a_shift;
855 b -= -b_shift;
856
857 const Scalar *ao = nullptr;
858 const Scalar *bo = nullptr;
859 Scalar *co1 = nullptr;
860 Scalar *co2 = nullptr;
861
862 // Main m-loop.
863 for (; m >= max_a_unroll; m -= max_a_unroll) mloop<max_a_unroll, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
864
865 // m-remainders.
866 if (m & 32 && max_a_unroll > 32) mloop<32, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
867 if (m & 16 && max_a_unroll > 16) mloop<16, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
868 if (m & 8 && max_a_unroll > 8) mloop<8, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
869 if (m & 4 && max_a_unroll > 4) mloop<4, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
870 if (m & 2 && max_a_unroll > 2 && is_f64) mloop<2, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
871 if (m & 1 && max_a_unroll > 1 && is_f64) mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
872
873 // Copy kernels don't support tails of m = 2 for single precision.
874 // Loop over ones.
875 if (is_f32) {
876 int m_rem = 2 * ((m & 2) != 0) + 1 * ((m & 1) != 0);
877 while (m_rem > 0) {
878 mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
879 m_rem--;
880 }
881 }
882 }
883
884 gemm_class(Index m_, Index n_, Index k_, Index ldc_, Index inc_, const Scalar *alpha_, const Scalar *a_,
885 const Scalar *b_, Scalar *c_, bool is_alpha1_, bool is_beta0_, Index a_stride_, Index b_stride_,
886 Index a_off_, Index b_off_)
887 : m(m_),
888 n(n_),
889 k(k_),
890 ldc(ldc_),
891 inc(inc_),
892 alpha(alpha_),
893 a(a_),
894 b(b_),
895 c(c_),
896 is_alpha1(is_alpha1_),
897 is_beta0(is_beta0_),
898 a_stride(a_stride_),
899 b_stride(b_stride_),
900 a_off(a_off_),
901 b_off(b_off_) {
902 // Zero out all accumulation registers.
903 zmm[8] = pzero(zmm[8]);
904 zmm[9] = pzero(zmm[9]);
905 zmm[10] = pzero(zmm[10]);
906 zmm[11] = pzero(zmm[11]);
907 zmm[12] = pzero(zmm[12]);
908 zmm[13] = pzero(zmm[13]);
909 zmm[14] = pzero(zmm[14]);
910 zmm[15] = pzero(zmm[15]);
911 zmm[16] = pzero(zmm[16]);
912 zmm[17] = pzero(zmm[17]);
913 zmm[18] = pzero(zmm[18]);
914 zmm[19] = pzero(zmm[19]);
915 zmm[20] = pzero(zmm[20]);
916 zmm[21] = pzero(zmm[21]);
917 zmm[22] = pzero(zmm[22]);
918 zmm[23] = pzero(zmm[23]);
919 zmm[24] = pzero(zmm[24]);
920 zmm[25] = pzero(zmm[25]);
921 zmm[26] = pzero(zmm[26]);
922 zmm[27] = pzero(zmm[27]);
923 zmm[28] = pzero(zmm[28]);
924 zmm[29] = pzero(zmm[29]);
925 zmm[30] = pzero(zmm[30]);
926 zmm[31] = pzero(zmm[31]);
927 }
928};
929
930// Compute kernel with max unroll support of:
931// Single precision:
932// max_a_unroll: 48, 32, 16, 8, 4, 2, 1
933// max_b_unroll: 8, 4, 2, 1
934// Double precision:
935// max_a_unroll: 24, 16, 8, 4, 2, 1
936// max_b_unroll: 8, 4, 2, 1
937template <typename Scalar, int max_a_unroll, int max_b_unroll, bool is_alpha1, bool is_beta0, bool is_unit_inc>
938EIGEN_DONT_INLINE void gemm_kern_avx512(Index m, Index n, Index k, Scalar *alpha, const Scalar *a, const Scalar *b,
939 Scalar *c, Index ldc, Index inc = 1, Index a_stride = -1, Index b_stride = -1,
940 Index a_off = 0, Index b_off = 0) {
941 if (a_stride == -1) a_stride = k;
942 if (b_stride == -1) b_stride = k;
943
944 gemm_class<Scalar, is_unit_inc> g(m, n, k, ldc, inc, alpha, a, b, c, is_alpha1, is_beta0, a_stride, b_stride, a_off,
945 b_off);
946 g.template compute_kern<max_a_unroll, max_b_unroll>();
947}
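// For illustration, the float/double GEBP specializations below instantiate this entry point as
// gemm_kern_avx512<Scalar, mr, 8, is_alpha1, /*is_beta0=*/false, is_unit_inc>(...), where mr is the
// lhs register blocking from gebp_traits (48 for single precision, 24 for double, matching the
// comment above) and is_unit_inc reflects res.incr() == 1.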
948
949// Template specializations of GEBP kernels with nr = 8.
950#if EIGEN_USE_AVX512_GEMM_KERNELS
951template <bool ConjLhs_, bool ConjRhs_, int PacketSize_>
952class gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Target, PacketSize_>
953 : public gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_> {
954 using Base = gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_>;
955
956 public:
957 enum { nr = Base::Vectorizable ? 8 : 4 };
958};
959
960template <bool ConjLhs_, bool ConjRhs_, int PacketSize_>
961class gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Target, PacketSize_>
962 : public gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_> {
963 using Base = gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_>;
964
965 public:
966 enum { nr = Base::Vectorizable ? 8 : 4 };
967};
968
969template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
970struct gemm_pack_rhs<Scalar, Index, DataMapper, 8, ColMajor, Conjugate, PanelMode> {
971 typedef typename packet_traits<Scalar>::type Packet;
972 typedef typename DataMapper::LinearMapper LinearMapper;
973 enum { PacketSize = packet_traits<Scalar>::size };
974 EIGEN_DONT_INLINE void operator()(Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride = 0,
975 Index offset = 0);
976};
977
978template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
979EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, DataMapper, 8, ColMajor, Conjugate, PanelMode>::operator()(
980 Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride, Index offset) {
981 constexpr int nr = 8;
982 EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS COLMAJOR");
983 EIGEN_UNUSED_VARIABLE(stride);
984 EIGEN_UNUSED_VARIABLE(offset);
985 eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
986 conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
987 Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
988 Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
989 Index count = 0;
990 const Index peeled_k = (depth / PacketSize) * PacketSize;
991 if (nr >= 8) {
992 for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
993 // skip what we have before
994 if (PanelMode) count += 8 * offset;
995 const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
996 const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
997 const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
998 const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
999 const LinearMapper dm4 = rhs.getLinearMapper(0, j2 + 4);
1000 const LinearMapper dm5 = rhs.getLinearMapper(0, j2 + 5);
1001 const LinearMapper dm6 = rhs.getLinearMapper(0, j2 + 6);
1002 const LinearMapper dm7 = rhs.getLinearMapper(0, j2 + 7);
1003 Index k = 0;
1004 if ((PacketSize % 8) == 0) // TODO enable vectorized transposition for PacketSize==4
1005 {
1006 for (; k < peeled_k; k += PacketSize) {
1007 PacketBlock<Packet, (PacketSize % 8) == 0 ? 8 : PacketSize> kernel;
1008
1009 kernel.packet[0] = dm0.template loadPacket<Packet>(k);
1010 kernel.packet[1] = dm1.template loadPacket<Packet>(k);
1011 kernel.packet[2] = dm2.template loadPacket<Packet>(k);
1012 kernel.packet[3] = dm3.template loadPacket<Packet>(k);
1013 kernel.packet[4] = dm4.template loadPacket<Packet>(k);
1014 kernel.packet[5] = dm5.template loadPacket<Packet>(k);
1015 kernel.packet[6] = dm6.template loadPacket<Packet>(k);
1016 kernel.packet[7] = dm7.template loadPacket<Packet>(k);
1017
1018 ptranspose(kernel);
1019
1020 pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0]));
1021 pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize]));
1022 pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize]));
1023 pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize]));
1024 pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel.packet[4 % PacketSize]));
1025 pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel.packet[5 % PacketSize]));
1026 pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel.packet[6 % PacketSize]));
1027 pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel.packet[7 % PacketSize]));
1028 count += 8 * PacketSize;
1029 }
1030 }
1031 for (; k < depth; k++) {
1032 blockB[count + 0] = cj(dm0(k));
1033 blockB[count + 1] = cj(dm1(k));
1034 blockB[count + 2] = cj(dm2(k));
1035 blockB[count + 3] = cj(dm3(k));
1036 blockB[count + 4] = cj(dm4(k));
1037 blockB[count + 5] = cj(dm5(k));
1038 blockB[count + 6] = cj(dm6(k));
1039 blockB[count + 7] = cj(dm7(k));
1040 count += 8;
1041 }
1042 // skip what we have after
1043 if (PanelMode) count += 8 * (stride - offset - depth);
1044 }
1045 }
1046
1047 if (nr >= 4) {
1048 for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
1049 // skip what we have before
1050 if (PanelMode) count += 4 * offset;
1051 const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1052 const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1053 const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1054 const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1055
1056 Index k = 0;
1057 if ((PacketSize % 4) == 0) // TODO enable vectorized transposition for PacketSize==2 ??
1058 {
1059 for (; k < peeled_k; k += PacketSize) {
1060 PacketBlock<Packet, (PacketSize % 4) == 0 ? 4 : PacketSize> kernel;
1061 kernel.packet[0] = dm0.template loadPacket<Packet>(k);
1062 kernel.packet[1 % PacketSize] = dm1.template loadPacket<Packet>(k);
1063 kernel.packet[2 % PacketSize] = dm2.template loadPacket<Packet>(k);
1064 kernel.packet[3 % PacketSize] = dm3.template loadPacket<Packet>(k);
1065 ptranspose(kernel);
1066 pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0]));
1067 pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize]));
1068 pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize]));
1069 pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize]));
1070 count += 4 * PacketSize;
1071 }
1072 }
1073 for (; k < depth; k++) {
1074 blockB[count + 0] = cj(dm0(k));
1075 blockB[count + 1] = cj(dm1(k));
1076 blockB[count + 2] = cj(dm2(k));
1077 blockB[count + 3] = cj(dm3(k));
1078 count += 4;
1079 }
1080 // skip what we have after
1081 if (PanelMode) count += 4 * (stride - offset - depth);
1082 }
1083 }
1084
1085 // copy the remaining columns one at a time (nr==1)
1086 for (Index j2 = packet_cols4; j2 < cols; ++j2) {
1087 if (PanelMode) count += offset;
1088 const LinearMapper dm0 = rhs.getLinearMapper(0, j2);
1089 for (Index k = 0; k < depth; k++) {
1090 blockB[count] = cj(dm0(k));
1091 count += 1;
1092 }
1093 if (PanelMode) count += (stride - offset - depth);
1094 }
1095}
1096
1097template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
1098struct gemm_pack_rhs<Scalar, Index, DataMapper, 8, RowMajor, Conjugate, PanelMode> {
1099 typedef typename packet_traits<Scalar>::type Packet;
1100 typedef typename unpacket_traits<Packet>::half HalfPacket;
1101 typedef typename unpacket_traits<typename unpacket_traits<Packet>::half>::half QuarterPacket;
1102 typedef typename DataMapper::LinearMapper LinearMapper;
1103 enum {
1104 PacketSize = packet_traits<Scalar>::size,
1105 HalfPacketSize = unpacket_traits<HalfPacket>::size,
1106 QuarterPacketSize = unpacket_traits<QuarterPacket>::size
1107 };
1108 EIGEN_DONT_INLINE void operator()(Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride = 0,
1109 Index offset = 0) {
1110 constexpr int nr = 8;
1111 EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS ROWMAJOR");
1112 EIGEN_UNUSED_VARIABLE(stride);
1113 EIGEN_UNUSED_VARIABLE(offset);
1114 eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
1115 const bool HasHalf = (int)HalfPacketSize < (int)PacketSize;
1116 const bool HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize;
1117 conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
1118 Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
1119 Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
1120 Index count = 0;
1121
1122 if (nr >= 8) {
1123 for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
1124 // skip what we have before
1125 if (PanelMode) count += 8 * offset;
1126 for (Index k = 0; k < depth; k++) {
1127 if (PacketSize == 8) {
1128 // Packet A = ploadu<Packet>(&rhs.data()[k*rhs.stride() + j2]);
1129 Packet A = rhs.template loadPacket<Packet>(k, j2);
1130 pstoreu(blockB + count, cj.pconj(A));
1131 } else if (HasHalf && HalfPacketSize == 8) {
1132 HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
1133 pstoreu(blockB + count, cj.pconj(A));
1134 } else if (HasQuarter && QuarterPacketSize == 8) {
1135 QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
1136 pstoreu(blockB + count, cj.pconj(A));
1137 } else if (PacketSize == 4) {
1138 // Packet A = ploadu<Packet>(&rhs.data()[k*rhs.stride() + j2]);
1139 // Packet B = ploadu<Packet>(&rhs.data()[k*rhs.stride() + j2 + PacketSize]);
1140 Packet A = rhs.template loadPacket<Packet>(k, j2);
1141 Packet B = rhs.template loadPacket<Packet>(k, j2 + PacketSize);
1142 pstoreu(blockB + count, cj.pconj(A));
1143 pstoreu(blockB + count + PacketSize, cj.pconj(B));
1144 } else {
1145 // const Scalar* b0 = &rhs.data()[k*rhs.stride() + j2];
1146 const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
1147 blockB[count + 0] = cj(dm0(0));
1148 blockB[count + 1] = cj(dm0(1));
1149 blockB[count + 2] = cj(dm0(2));
1150 blockB[count + 3] = cj(dm0(3));
1151 blockB[count + 4] = cj(dm0(4));
1152 blockB[count + 5] = cj(dm0(5));
1153 blockB[count + 6] = cj(dm0(6));
1154 blockB[count + 7] = cj(dm0(7));
1155 }
1156 count += 8;
1157 }
1158 // skip what we have after
1159 if (PanelMode) count += 8 * (stride - offset - depth);
1160 }
1161 }
1162
1163 if (nr >= 4) {
1164 for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
1165 // skip what we have before
1166 if (PanelMode) count += 4 * offset;
1167 for (Index k = 0; k < depth; k++) {
1168 if (PacketSize == 4) {
1169 Packet A = rhs.template loadPacket<Packet>(k, j2);
1170 pstoreu(blockB + count, cj.pconj(A));
1171 count += PacketSize;
1172 } else if (HasHalf && HalfPacketSize == 4) {
1173 HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
1174 pstoreu(blockB + count, cj.pconj(A));
1175 count += HalfPacketSize;
1176 } else if (HasQuarter && QuarterPacketSize == 4) {
1177 QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
1178 pstoreu(blockB + count, cj.pconj(A));
1179 count += QuarterPacketSize;
1180 } else {
1181 const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
1182 blockB[count + 0] = cj(dm0(0));
1183 blockB[count + 1] = cj(dm0(1));
1184 blockB[count + 2] = cj(dm0(2));
1185 blockB[count + 3] = cj(dm0(3));
1186 count += 4;
1187 }
1188 }
1189 // skip what we have after
1190 if (PanelMode) count += 4 * (stride - offset - depth);
1191 }
1192 }
1193 // copy the remaining columns one at a time (nr==1)
1194 for (Index j2 = packet_cols4; j2 < cols; ++j2) {
1195 if (PanelMode) count += offset;
1196 for (Index k = 0; k < depth; k++) {
1197 blockB[count] = cj(rhs(k, j2));
1198 count += 1;
1199 }
1200 if (PanelMode) count += stride - offset - depth;
1201 }
1202 }
1203};
1204
1205template <typename Scalar, typename Index, typename DataMapper, int mr, bool ConjugateLhs, bool ConjugateRhs>
1206struct gebp_kernel<Scalar, Scalar, Index, DataMapper, mr, 8, ConjugateLhs, ConjugateRhs> {
1207 EIGEN_ALWAYS_INLINE void operator()(const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index rows,
1208 Index depth, Index cols, Scalar alpha, Index strideA = -1, Index strideB = -1,
1209 Index offsetA = 0, Index offsetB = 0);
1210};
1211
1212template <typename Scalar, typename Index, typename DataMapper, int mr, bool ConjugateLhs, bool ConjugateRhs>
1213EIGEN_ALWAYS_INLINE void gebp_kernel<Scalar, Scalar, Index, DataMapper, mr, 8, ConjugateLhs, ConjugateRhs>::operator()(
1214 const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index rows, Index depth, Index cols,
1215 Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
1216 if (res.incr() == 1) {
1217 if (alpha == 1) {
1218 gemm_kern_avx512<Scalar, mr, 8, true, false, true>(rows, cols, depth, &alpha, blockA, blockB,
1219 (Scalar *)res.data(), res.stride(), res.incr(), strideA,
1220 strideB, offsetA, offsetB);
1221 } else {
1222 gemm_kern_avx512<Scalar, mr, 8, false, false, true>(rows, cols, depth, &alpha, blockA, blockB,
1223 (Scalar *)res.data(), res.stride(), res.incr(), strideA,
1224 strideB, offsetA, offsetB);
1225 }
1226 } else {
1227 if (alpha == 1) {
1228 gemm_kern_avx512<Scalar, mr, 8, true, false, false>(rows, cols, depth, &alpha, blockA, blockB,
1229 (Scalar *)res.data(), res.stride(), res.incr(), strideA,
1230 strideB, offsetA, offsetB);
1231 } else {
1232 gemm_kern_avx512<Scalar, mr, 8, false, false, false>(rows, cols, depth, &alpha, blockA, blockB,
1233 (Scalar *)res.data(), res.stride(), res.incr(), strideA,
1234 strideB, offsetA, offsetB);
1235 }
1236 }
1237}
1238#endif // EIGEN_USE_AVX512_GEMM_KERNELS
1239
1240} // namespace internal
1241} // namespace Eigen
1242
1243#undef SECOND_FETCH
1244
1245#endif // EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H
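// A minimal usage sketch (illustrative assumption, not part of this header): in an AVX-512 build
// with EIGEN_USE_AVX512_GEMM_KERNELS enabled, an ordinary matrix product dispatches through the
// nr = 8 gebp_kernel specialization above and hence through gemm_kern_avx512.
//
//   #include <Eigen/Dense>
//
//   int main() {
//     Eigen::MatrixXf A = Eigen::MatrixXf::Random(256, 128);
//     Eigen::MatrixXf B = Eigen::MatrixXf::Random(128, 64);
//     Eigen::MatrixXf C = A * B;  // GEBP path; uses the AVX-512 kernel when available
//     return C.allFinite() ? 0 : 1;
//   }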