#ifndef EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H
#define EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H

#include "../../InternalHeaderCheck.h"

#if !defined(EIGEN_USE_AVX512_GEMM_KERNELS)
#define EIGEN_USE_AVX512_GEMM_KERNELS 1
#endif

#define SECOND_FETCH (32)
#if (EIGEN_COMP_GNUC_STRICT != 0) && !defined(EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS)
// Use fewer registers for the A-matrix loads to work around compiler register spills.
#define EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
#endif

namespace Eigen {
namespace internal {
template <typename Scalar, bool is_unit_inc>
class gemm_class {
  using vec = typename packet_traits<Scalar>::type;
  using vec_ymm = typename unpacket_traits<vec>::half;
  using vec_xmm = typename unpacket_traits<vec_ymm>::half;
  using umask_t = typename unpacket_traits<vec>::mask_t;

  static constexpr bool is_f32 = sizeof(Scalar) == sizeof(float);
  static constexpr bool is_f64 = sizeof(Scalar) == sizeof(double);
#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
  static constexpr bool use_less_a_regs = !is_unit_inc;
#else
  static constexpr bool use_less_a_regs = true;
#endif
#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_B_REGS
  static constexpr bool use_less_b_regs = !is_unit_inc;
#else
  static constexpr bool use_less_b_regs = true;
#endif
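
  // zmm register assignment: up to 6 registers hold loads of A (3 when use_less_a_regs),
  // up to 2 hold broadcasts of B, and 24 registers (c_regs) accumulate the C tile;
  // alpha_load_reg and c_load_regs are scratch used only during the C update.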
  static constexpr int a_regs[] = {0, 1, 2, use_less_a_regs ? 0 : 3, use_less_a_regs ? 1 : 4, use_less_a_regs ? 2 : 5};
  static constexpr int b_regs[] = {6, use_less_b_regs ? 6 : 7};
  static constexpr int c_regs[] = {
      8, 16, 24, 9, 17, 25, 10, 18, 26, 11, 19, 27, 12, 20, 28, 13, 21, 29, 14, 22, 30, 15, 23, 31,
  };

  static constexpr int alpha_load_reg = 0;
  static constexpr int c_load_regs[] = {1, 2, 6};

  static constexpr int a_shift = 128;
  static constexpr int b_shift = 128;

  static constexpr int nelems_in_cache_line = is_f32 ? 16 : 8;
  static constexpr int a_prefetch_size = nelems_in_cache_line * 2;
  static constexpr int b_prefetch_size = nelems_in_cache_line * 8;

  vec zmm[32];
  umask_t mask;

  // GEMM arguments.
  Index m;
  const Index n, k, ldc;
  const Index inc;
  const Scalar *alpha;
  const Scalar *a, *b;
  Scalar *c;
  const bool is_alpha1, is_beta0;
  const Index a_stride, b_stride;
  const Index a_off, b_off;
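
  // Software prefetch helpers for the packed A/B panels, the next A panel (prefetch_x)
  // and the C tile (prefetch_c, using prefetchw when available).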
  EIGEN_ALWAYS_INLINE void prefetch_a(const Scalar *a_addr) {
    _mm_prefetch((char *)(a_prefetch_size + a_addr - a_shift), _MM_HINT_T0);
  }

  EIGEN_ALWAYS_INLINE void prefetch_b(const Scalar *b_addr) {
    _mm_prefetch((char *)(b_prefetch_size + b_addr - b_shift), _MM_HINT_T0);
  }

  EIGEN_ALWAYS_INLINE void prefetch_x(const Scalar *x_addr) { _mm_prefetch((char *)(x_addr - a_shift), _MM_HINT_T2); }

  EIGEN_ALWAYS_INLINE void prefetch_c(const Scalar *c_addr) {
#if defined(__PRFCHW__) && __PRFCHW__ == 1
    _m_prefetchw((void *)c_addr);
#else
    _mm_prefetch((char *)c_addr, _MM_HINT_T0);
#endif
  }
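
  // Load nelems elements of packed A into a zmm register: full-width unaligned loads
  // for whole cache lines, and 256/128/64/32-bit loads broadcast across the register
  // for the tails.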
  template <int nelems>
  EIGEN_ALWAYS_INLINE void a_load(vec &a_reg, const Scalar *a_addr) {
    switch (nelems * sizeof(*a_addr) * 8) {
      case 512 * 3: a_reg = ploadu<vec>(a_addr); break;
      case 512 * 2: a_reg = ploadu<vec>(a_addr); break;
      case 512 * 1: a_reg = ploadu<vec>(a_addr); break;
      case 256: a_reg = preinterpret<vec>(_mm512_broadcast_f64x4(ploadu<Packet4d>(reinterpret_cast<const double *>(a_addr)))); break;
      case 128: a_reg = preinterpret<vec>(_mm512_broadcast_f32x4(ploadu<Packet4f>(reinterpret_cast<const float *>(a_addr)))); break;
      case 64: a_reg = preinterpret<vec>(pload1<Packet8d>(reinterpret_cast<const double *>(a_addr))); break;
      case 32: a_reg = pload1<vec>(a_addr); break;
    }
  }

  EIGEN_ALWAYS_INLINE void b_load(vec &b_reg, const Scalar *b_addr) { b_reg = pload1<vec>(b_addr); }
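
  // Store a zmm accumulator back to C. The unit-increment path uses (partial) vector
  // stores; the strided path scatters, with a mask for sub-cache-line tails.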
  template <int nelems>
  EIGEN_ALWAYS_INLINE void c_store(Scalar *mem, vec &src) {
    if (is_unit_inc) {
      switch (nelems * sizeof(*mem) * 8) {
        case 512 * 3: pstoreu(mem, src); break;
        case 512 * 2: pstoreu(mem, src); break;
        case 512 * 1: pstoreu(mem, src); break;
        case 256: pstoreu(mem, preinterpret<vec_ymm>(src)); break;
        case 128: pstoreu(mem, preinterpret<vec_xmm>(src)); break;
        case 64: pstorel(mem, preinterpret<vec_xmm>(src)); break;
        case 32: pstores(mem, preinterpret<vec_xmm>(src)); break;
      }
    } else {
      switch (nelems * sizeof(*mem) * 8) {
        case 512 * 3: pscatter(mem, src, inc); break;
        case 512 * 2: pscatter(mem, src, inc); break;
        case 512 * 1: pscatter(mem, src, inc); break;
        case 256: pscatter(mem, src, inc, mask); break;
        case 128: pscatter(mem, src, inc, mask); break;
        case 64: pscatter(mem, src, inc, mask); break;
        case 32: pscatter(mem, src, inc, mask); break;
      }
    }
  }
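
  // dst = src + C(mem): accumulate into a value loaded (or gathered) from C,
  // used for the beta != 0, alpha == 1 update.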
  template <int nelems>
  EIGEN_ALWAYS_INLINE void vaddm(vec &dst, const Scalar *mem, vec &src, vec &reg) {
    if (is_unit_inc) {
      switch (nelems * sizeof(*mem) * 8) {
        case 512 * 3: dst = padd(src, ploadu<vec>(mem)); break;
        case 512 * 2: dst = padd(src, ploadu<vec>(mem)); break;
        case 512 * 1: dst = padd(src, ploadu<vec>(mem)); break;
        case 256: dst = preinterpret<vec>(padd(preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem))); break;
        case 128: dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem))); break;
        case 64: dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem))); break;
        case 32: dst = preinterpret<vec>(padds(preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem))); break;
      }
    } else {
      switch (nelems * sizeof(*mem) * 8) {
        case 512 * 3:
          reg = pgather<Scalar, vec>(mem, inc);
          dst = padd(src, reg);
          break;
        case 512 * 2:
          reg = pgather<Scalar, vec>(mem, inc);
          dst = padd(src, reg);
          break;
        case 512 * 1:
          reg = pgather<Scalar, vec>(mem, inc);
          dst = padd(src, reg);
          break;
        case 256:
          reg = preinterpret<vec>(pgather<Scalar, vec_ymm>(mem, inc));
          dst = preinterpret<vec>(padd(preinterpret<vec_ymm>(src), preinterpret<vec_ymm>(reg)));
          break;
        case 128:
          reg = preinterpret<vec>(pgather<Scalar, vec_xmm>(mem, inc));
          dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
          break;
        case 64:
          if (is_f32) {
            reg = pgather(reg, mem, inc, mask);
            dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
          } else {
            dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
          }
          break;
        case 32: dst = preinterpret<vec>(padds(preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem))); break;
      }
    }
  }
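
  // Accumulator FMA. The empty inline asm ties the operands to vector registers,
  // which helps keep GCC/Clang from spilling or reordering them in the hot loop.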
  EIGEN_STRONG_INLINE void vfmadd(vec &dst, const vec &src1, const vec &src2) {
    dst = pmadd(src1, src2, dst);
#if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
    __asm__("#" : [dst] "+v"(dst) : [src1] "%v"(src1), [src2] "v"(src2));
#endif
  }
  template <int nelems>
  EIGEN_ALWAYS_INLINE void vfmaddm(vec &dst, const Scalar *mem, vec &src, vec &scale, vec &reg) {
    if (is_unit_inc) {
      switch (nelems * sizeof(*mem) * 8) {
        case 512 * 3: dst = pmadd(scale, src, ploadu<vec>(mem)); break;
        case 512 * 2: dst = pmadd(scale, src, ploadu<vec>(mem)); break;
        case 512 * 1: dst = pmadd(scale, src, ploadu<vec>(mem)); break;
        case 256:
          dst = preinterpret<vec>(pmadd(preinterpret<vec_ymm>(scale), preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem)));
          break;
        case 128:
          dst = preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem)));
          break;
        case 64:
          dst = preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
          break;
        case 32:
          dst = preinterpret<vec>(pmadds(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
          break;
      }
    } else {
      switch (nelems * sizeof(*mem) * 8) {
        case 512 * 3:
          reg = pgather<Scalar, vec>(mem, inc);
          dst = pmadd(scale, src, reg);
          break;
        case 512 * 2:
          reg = pgather<Scalar, vec>(mem, inc);
          dst = pmadd(scale, src, reg);
          break;
        case 512 * 1:
          reg = pgather<Scalar, vec>(mem, inc);
          dst = pmadd(scale, src, reg);
          break;
        case 256:
          reg = preinterpret<vec>(pgather<Scalar, vec_ymm>(mem, inc));
          dst = preinterpret<vec>(pmadd(preinterpret<vec_ymm>(scale), preinterpret<vec_ymm>(src), preinterpret<vec_ymm>(reg)));
          break;
        case 128:
          reg = preinterpret<vec>(pgather<Scalar, vec_xmm>(mem, inc));
          dst = preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
          break;
        case 64:
          if (is_f32) {
            reg = pgather(reg, mem, inc, mask);
            dst = preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
          } else {
            dst = preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
          }
          break;
        case 32:
          dst = preinterpret<vec>(pmadds(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
          break;
      }
    }
  }
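
  // Load the first one or two k iterations of the packed A panel into the a_regs
  // registers before entering the k loop (template recursion over j and i).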
  template <int j, int endX, int i, int endY, int nelems>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(j > endX) || (i > endY)> a_loads(const Scalar *ao) {
    EIGEN_UNUSED_VARIABLE(ao);
  }

  template <int j, int endX, int i, int endY, int nelems>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(j <= endX) && (i <= endY)> a_loads(const Scalar *ao) {
    if (j < endX) {
      if (i < endY) {
        auto &a_reg = zmm[a_regs[i + (j % 2) * 3]];
        const Scalar *a_addr = ao + nelems * j + nelems_in_cache_line * i - a_shift;
        a_load<nelems>(a_reg, a_addr);

        a_loads<j, endX, i + 1, endY, nelems>(ao);
      } else {
        a_loads<j + 1, endX, 0, endY, nelems>(ao);
      }
    }
  }
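
  // Prefetch the C tile touched by this (a_unroll x b_unroll) block, walking the
  // un (column) and i (cache line) dimensions by template recursion.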
  template <int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(un > max_b_unroll) || (i > um_vecs)> prefetch_cs(const Scalar *co1,
                                                                                         const Scalar *co2) {
    EIGEN_UNUSED_VARIABLE(co1);
    EIGEN_UNUSED_VARIABLE(co2);
  }

  template <int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(un <= max_b_unroll) && (i <= um_vecs)> prefetch_cs(Scalar *&co1, Scalar *&co2) {
    if (un < max_b_unroll) {
      if (b_unroll >= un + 1) {
        if (un == 4 && i == 0) co2 = co1 + 4 * ldc;

        if (i < um_vecs) {
          Scalar *co = (un + 1 <= 4) ? co1 : co2;
          auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co;
          prefetch_c(co + co_off);

          prefetch_cs<un, max_b_unroll, i + 1, um_vecs, a_unroll, b_unroll>(co1, co2);
        } else {
          prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
        }
      } else {
        prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
      }
    }
  }
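
  // Combine the accumulators of one column (idx) with the existing C values: unless
  // beta == 0 the C values are added in (or fused-multiply-added with alpha); with
  // beta == 0 the accumulators are only scaled. One cache line per recursion step.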
  template <int i, int um_vecs, int idx, int nelems>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> scale_load_c(const Scalar *cox, vec &alpha_reg) {
    EIGEN_UNUSED_VARIABLE(cox);
    EIGEN_UNUSED_VARIABLE(alpha_reg);
  }

  template <int i, int um_vecs, int idx, int nelems>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> scale_load_c(const Scalar *cox, vec &alpha_reg) {
    if (i < um_vecs) {
      auto &c_reg = zmm[c_regs[i + idx * 3]];
      auto &c_load_reg = zmm[c_load_regs[i % 3]];
      const Scalar *c_mem = cox;
      if (is_unit_inc)
        c_mem += i * nelems_in_cache_line;
      else
        c_mem += i * nelems_in_cache_line * inc;

      if (!is_beta0 && is_alpha1)
        vaddm<nelems>(c_reg, c_mem, c_reg, c_load_reg);
      else if (!is_beta0 && !is_alpha1)
        vfmaddm<nelems>(c_reg, c_mem, c_reg, alpha_reg, c_load_reg);
      else if (is_beta0 && !is_alpha1)
        c_reg = pmul(alpha_reg, c_reg);

      scale_load_c<i + 1, um_vecs, idx, nelems>(cox, alpha_reg);
    }
  }
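
  // Store the updated accumulators of one column back to C and reset them to zero
  // for the next k loop.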
  template <int i, int um_vecs, int idx, int nelems>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> write_c(Scalar *cox) {
    EIGEN_UNUSED_VARIABLE(cox);
  }

  template <int i, int um_vecs, int idx, int nelems>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> write_c(Scalar *cox) {
    if (i < um_vecs) {
      auto &c_reg = zmm[c_regs[i + idx * 3]];
      Scalar *c_mem = cox;
      if (is_unit_inc)
        c_mem += i * nelems_in_cache_line;
      else
        c_mem += i * nelems_in_cache_line * inc;

      c_store<nelems>(c_mem, c_reg);
      c_reg = pzero(c_reg);

      write_c<i + 1, um_vecs, idx, nelems>(cox);
    }
  }
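
  // C update drivers: c_update_1count updates one column, c_update_1pow covers the
  // columns of one power-of-two group of the b unroll, and c_update walks all groups
  // (1, 2, 4, 8) that are active for this b_unroll.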
  template <int pow, int a_unroll, int idx>
  EIGEN_ALWAYS_INLINE void c_update_1count(Scalar *&cox) {
    if (pow >= 4) cox += ldc;

    const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);
    auto &alpha_reg = zmm[alpha_load_reg];

    scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg);
    write_c<0, um_vecs, idx, a_unroll>(cox);
  }
  template <int pow, int a_unroll>
  EIGEN_ALWAYS_INLINE void c_update_1pow(Scalar *&co1, Scalar *&co2) {
    constexpr int idx = pow / 2;
    Scalar *&cox = idx == 0 ? co1 : co2;

    constexpr int max_count = (pow + 1) / 2;
    static_assert(max_count <= 4, "Unsupported max_count.");

    if (1 <= max_count) c_update_1count<pow, a_unroll, idx + 0>(cox);
    if (2 <= max_count) c_update_1count<pow, a_unroll, idx + 1>(cox);
    if (3 <= max_count) c_update_1count<pow, a_unroll, idx + 2>(cox);
    if (4 <= max_count) c_update_1count<pow, a_unroll, idx + 3>(cox);
  }
  template <int max_b_unroll, int a_unroll, int b_unroll>
  EIGEN_ALWAYS_INLINE void c_update(Scalar *&co1, Scalar *&co2) {
    auto &alpha_reg = zmm[alpha_load_reg];

    // co1 tracks column 0 of the tile; co2 tracks column 1 and is advanced by
    // c_update_1count for the remaining columns.
    co2 = co1 + ldc;

    if (!is_alpha1) alpha_reg = pload1<vec>(alpha);
    if (!is_unit_inc && a_unroll < nelems_in_cache_line) mask = static_cast<umask_t>((1ull << a_unroll) - 1);

    static_assert(max_b_unroll <= 8, "Unsupported max_b_unroll");

    if (1 <= max_b_unroll && 1 <= b_unroll) c_update_1pow<1, a_unroll>(co1, co2);
    if (2 <= max_b_unroll && 2 <= b_unroll) c_update_1pow<2, a_unroll>(co1, co2);
    if (4 <= max_b_unroll && 4 <= b_unroll) c_update_1pow<4, a_unroll>(co1, co2);
    if (8 <= max_b_unroll && 8 <= b_unroll) c_update_1pow<8, a_unroll>(co1, co2);
  }
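
  // One FMA step of the micro-kernel: multiply the um-th A register with the current
  // B broadcast into the matching C accumulator, issuing A/B prefetches on selected
  // (um, idx, uk) combinations so they are spread across the unrolled iterations.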
  template <int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx,
                                                               int &fetchB_idx, vec &b_reg) {
    EIGEN_UNUSED_VARIABLE(ao);
    EIGEN_UNUSED_VARIABLE(bo);
    EIGEN_UNUSED_VARIABLE(fetchA_idx);
    EIGEN_UNUSED_VARIABLE(fetchB_idx);
    EIGEN_UNUSED_VARIABLE(b_reg);
  }

  template <int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx,
                                                                int &fetchB_idx, vec &b_reg) {
    if (um < um_vecs) {
      auto &c_reg = zmm[c_regs[um + idx * 3]];
      auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];

      vfmadd(c_reg, a_reg, b_reg);

      if (!fetch_x && um == 0 &&
          (((idx == 0 || idx == 6) && (uk % 2 == 0 || is_f64 || ktail)) ||
           (idx == 3 && (uk % 2 == 1 || is_f64 || ktail)))) {
        prefetch_a(ao + nelems_in_cache_line * fetchA_idx);
        fetchA_idx++;
      }

      if (um == 0 && idx == 1 && (uk % 2 == 0 || is_f64 || ktail)) {
        prefetch_b(bo + nelems_in_cache_line * fetchB_idx);
        fetchB_idx++;
      }

      compute<um + 1, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
    }
  }
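
  // Preload the A registers for an upcoming k iteration (one or two ahead, depending
  // on ktail and use_less_a_regs) while the current iteration is being computed.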
  template <int um, int um_vecs, int uk, int nelems, bool ktail>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> load_a(const Scalar *ao) {
    EIGEN_UNUSED_VARIABLE(ao);
  }

  template <int um, int um_vecs, int uk, int nelems, bool ktail>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> load_a(const Scalar *ao) {
    if (um < um_vecs) {
      auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];
      const Scalar *a_addr = ao + nelems * (1 + !ktail * !use_less_a_regs + uk) + nelems_in_cache_line * um - a_shift;
      a_load<nelems>(a_reg, a_addr);

      load_a<um + 1, um_vecs, uk, nelems, ktail>(ao);
    }
  }
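
  // One "pow" group of the unrolled n dimension for a single k iteration (uk):
  // broadcast the next B element, run the FMA column (compute), and recurse over
  // count; after the last count, optionally prefetch the next C cache lines.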
  template <int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(count > (pow + 1) / 2)> innerkernel_1pow(const Scalar *&aa,
                                                                                 const Scalar *const &ao,
                                                                                 const Scalar *const &bo, Scalar *&co2,
                                                                                 int &fetchA_idx, int &fetchB_idx) {
    EIGEN_UNUSED_VARIABLE(aa);
    EIGEN_UNUSED_VARIABLE(ao);
    EIGEN_UNUSED_VARIABLE(bo);
    EIGEN_UNUSED_VARIABLE(co2);
    EIGEN_UNUSED_VARIABLE(fetchA_idx);
    EIGEN_UNUSED_VARIABLE(fetchB_idx);
  }

  template <int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(count <= (pow + 1) / 2)> innerkernel_1pow(const Scalar *&aa,
                                                                                  const Scalar *const &ao,
                                                                                  const Scalar *const &bo, Scalar *&co2,
                                                                                  int &fetchA_idx, int &fetchB_idx) {
    const int idx = (pow / 2) + count;

    if (count < (pow + 1) / 2) {
      auto &b_reg = zmm[b_regs[idx % 2]];

      if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa);
      if (fetch_x && uk == 3 && idx == 4) aa += 8;

      if (b_unroll >= pow) {
        compute<0, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);

        const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) * !use_less_b_regs - b_shift;
        b_load(b_reg, b_addr);
      }

      innerkernel_1pow<uk, pow, count + 1, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx,
                                                                                       fetchB_idx);
    } else {
      if (pow == 2 && c_fetch) {
        if (uk % 3 == 0 && uk > 0) {
          co2 += ldc;
        } else {
          prefetch_c(co2 + (uk % 3) * nelems_in_cache_line);
        }
      }
    }
  }
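
  // One k iteration (uk) of the inner kernel: process the active power-of-two column
  // groups and then preload the A registers for a later iteration.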
  template <int uk, int max_b_unroll, int a_unroll, int b_unroll, bool ktail, bool fetch_x, bool c_fetch,
            bool no_a_preload = false>
  EIGEN_ALWAYS_INLINE void innerkernel_1uk(const Scalar *&aa, const Scalar *const &ao, const Scalar *const &bo,
                                           Scalar *&co2, int &fetchA_idx, int &fetchB_idx) {
    const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);

    if (max_b_unroll >= 1)
      innerkernel_1pow<uk, 1, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
    if (max_b_unroll >= 2)
      innerkernel_1pow<uk, 2, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
    if (max_b_unroll >= 4)
      innerkernel_1pow<uk, 4, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
    if (max_b_unroll >= 8)
      innerkernel_1pow<uk, 8, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);

    // Load A after the pow loop.
    if (!no_a_preload) load_a<0, um_vecs, uk, a_unroll, ktail>(ao);
  }
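
  // Body of the k loop for one unrolling factor: run k_factor consecutive k
  // iterations (uk = 0..k_factor-1) and advance the packed A/B pointers.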
  template <int a_unroll, int b_unroll, int k_factor, int max_b_unroll, int max_k_factor, bool c_fetch,
            bool no_a_preload = false>
  EIGEN_ALWAYS_INLINE void innerkernel(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co2) {
    int fetchA_idx = 0;
    int fetchB_idx = 0;

    const bool fetch_x = k_factor == max_k_factor;
    const bool ktail = k_factor == 1;

    static_assert(k_factor <= 4 && k_factor > 0, "innerkernel maximum k_factor supported is 4");
    static_assert(no_a_preload == false || (no_a_preload == true && k_factor == 1),
                  "skipping a preload only allowed when k unroll is 1");

    if (k_factor > 0)
      innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
          aa, ao, bo, co2, fetchA_idx, fetchB_idx);
    if (k_factor > 1)
      innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
          aa, ao, bo, co2, fetchA_idx, fetchB_idx);
    if (k_factor > 2)
      innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
          aa, ao, bo, co2, fetchA_idx, fetchB_idx);
    if (k_factor > 3)
      innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
          aa, ao, bo, co2, fetchA_idx, fetchB_idx);

    // Advance the packed A/B pointers past the consumed k iterations.
    ao += a_unroll * k_factor;
    bo += b_unroll * k_factor;
  }
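
  // k loop for one (a_unroll, b_unroll) tile: preload A and the first B broadcasts,
  // run the main loop unrolled by max_k_factor (switching to C prefetch for the last
  // SECOND_FETCH iterations), handle the k remainder, then update C.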
  template <int a_unroll, int b_unroll, int max_b_unroll>
  EIGEN_ALWAYS_INLINE void kloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
    const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);
    if (!use_less_a_regs && k > 1)
      a_loads<0, 2, 0, um_vecs, a_unroll>(ao);
    else
      a_loads<0, 1, 0, um_vecs, a_unroll>(ao);

    b_load(zmm[b_regs[0]], bo - b_shift + 0);
    if (!use_less_b_regs) b_load(zmm[b_regs[1]], bo - b_shift + 1);

    prefetch_cs<0, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);

    // k loop unrolled by max_k_factor.
    const int max_k_factor = 4;
    Index kRem = k % max_k_factor;
    Index k_ = k - kRem;
    if (k_ >= max_k_factor) {
      k_ -= max_k_factor;
      kRem += max_k_factor;
    }
    Index loop_count = k_ / max_k_factor;

    if (loop_count > 0) {
      loop_count -= SECOND_FETCH;
      while (loop_count > 0) {
        innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
        loop_count--;
      }

      // Use co2 as the C prefetch cursor for the last SECOND_FETCH unrolled iterations.
      co2 = co1 + nelems_in_cache_line - 1;

      loop_count += b_unroll;
      while (loop_count > 0) {
        innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 1>(aa, ao, bo, co2);
        loop_count--;
      }

      loop_count += SECOND_FETCH - b_unroll;
      while (loop_count > 0) {
        innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
        loop_count--;
      }
    }

    // k remainder, one iteration at a time; the last iteration skips the A preload.
    loop_count = kRem;
    while (loop_count > 1) {
      innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
      loop_count--;
    }
    if (loop_count > 0) {
      innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0, true>(aa, ao, bo, co2);
    }

    c_update<max_b_unroll, a_unroll, b_unroll>(co1, co2);
  }
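
  // One panel of b_unroll columns: reset the packed A pointer, advance the packed B
  // pointer, run the k loop, then step B and the C column pointer to the next panel.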
  template <int a_unroll, int b_unroll, int max_b_unroll>
  EIGEN_ALWAYS_INLINE void nloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
    // Set the A matrix pointer for this panel.
    ao = a + a_off * a_unroll;

    // Set the B matrix pointer if needed.
    bo += b_unroll * b_off;

    kloop<a_unroll, b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);

    // Advance the B matrix pointer if needed.
    bo += b_unroll * (b_stride - k - b_off);

    // Advance the C matrix pointer to the next panel of columns.
    co1 += b_unroll * ldc;
  }
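
  // Loop over the n dimension for a fixed a_unroll: full panels of max_b_unroll
  // columns first, then the n remainder; finally advance the A and C base pointers.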
  template <int a_unroll, int max_a_unroll, int max_b_unroll>
  EIGEN_ALWAYS_INLINE void mloop(const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
    // Set the prefetch pointer for the next panel of A.
    const Scalar *aa = a + a_unroll * a_stride;

    // Set the C matrix pointers for this block of rows.
    co1 = c;
    if (a_unroll >= max_a_unroll) co2 = c + 2 * ldc;

    // Set the B matrix pointer.
    bo = b;

    // Main n loop.
    for (Index i = n / max_b_unroll; i > 0; i--) nloop<a_unroll, max_b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);
    // n remainders.
    if (n & 4 && max_b_unroll > 4) nloop<a_unroll, 4, max_b_unroll>(aa, ao, bo, co1, co2);
#if 0
    if (n & 2 && max_b_unroll > 2) nloop<a_unroll, 2, max_b_unroll>(aa, ao, bo, co1, co2);
    if (n & 1 && max_b_unroll > 1) nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2);
#else
    // Tails of n = 1 or 2 are handled one column at a time.
    int n_rem = 2 * ((n & 2) != 0) + 1 * ((n & 1) != 0);
    while (n_rem > 0) {
      nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2);
      n_rem--;
    }
#endif

    // Advance the A and C base pointers past this block of rows.
    a = ao + a_unroll * (a_stride - k - a_off);
    c += a_unroll * inc;
  }
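
  // Top-level driver: iterate over blocks of rows (m), dispatching to mloop with
  // decreasing a_unroll for the remainders.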
 public:
  template <int max_a_unroll, int max_b_unroll>
  EIGEN_ALWAYS_INLINE void compute_kern() {
    const Scalar *ao = nullptr;
    const Scalar *bo = nullptr;
    Scalar *co1 = nullptr;
    Scalar *co2 = nullptr;

    // Main m loop.
    for (; m >= max_a_unroll; m -= max_a_unroll) mloop<max_a_unroll, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);

    // m remainders.
    if (m & 32 && max_a_unroll > 32) mloop<32, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
    if (m & 16 && max_a_unroll > 16) mloop<16, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
    if (m & 8 && max_a_unroll > 8) mloop<8, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
    if (m & 4 && max_a_unroll > 4) mloop<4, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
    if (m & 2 && max_a_unroll > 2 && is_f64) mloop<2, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
    if (m & 1 && max_a_unroll > 1 && is_f64) mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);

    // For float, tails of m = 1 or 2 are handled one row at a time.
    if (is_f32) {
      int m_rem = 2 * ((m & 2) != 0) + 1 * ((m & 1) != 0);
      while (m_rem > 0) {
        mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
        m_rem--;
      }
    }
  }
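
  // Construct the kernel state: store the GEMM arguments and clear the 24 C
  // accumulation registers (zmm8 through zmm31).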
  gemm_class(Index m_, Index n_, Index k_, Index ldc_, Index inc_, const Scalar *alpha_, const Scalar *a_,
             const Scalar *b_, Scalar *c_, bool is_alpha1_, bool is_beta0_, Index a_stride_, Index b_stride_,
             Index a_off_, Index b_off_)
      : m(m_), n(n_), k(k_), ldc(ldc_), inc(inc_), alpha(alpha_), a(a_), b(b_), c(c_),
        is_alpha1(is_alpha1_), is_beta0(is_beta0_),
        a_stride(a_stride_), b_stride(b_stride_), a_off(a_off_), b_off(b_off_) {
    zmm[8] = pzero(zmm[8]);
    zmm[9] = pzero(zmm[9]);
    zmm[10] = pzero(zmm[10]);
    zmm[11] = pzero(zmm[11]);
    zmm[12] = pzero(zmm[12]);
    zmm[13] = pzero(zmm[13]);
    zmm[14] = pzero(zmm[14]);
    zmm[15] = pzero(zmm[15]);
    zmm[16] = pzero(zmm[16]);
    zmm[17] = pzero(zmm[17]);
    zmm[18] = pzero(zmm[18]);
    zmm[19] = pzero(zmm[19]);
    zmm[20] = pzero(zmm[20]);
    zmm[21] = pzero(zmm[21]);
    zmm[22] = pzero(zmm[22]);
    zmm[23] = pzero(zmm[23]);
    zmm[24] = pzero(zmm[24]);
    zmm[25] = pzero(zmm[25]);
    zmm[26] = pzero(zmm[26]);
    zmm[27] = pzero(zmm[27]);
    zmm[28] = pzero(zmm[28]);
    zmm[29] = pzero(zmm[29]);
    zmm[30] = pzero(zmm[30]);
    zmm[31] = pzero(zmm[31]);
  }
};
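
// Entry point used by the gebp_kernel specialization below: construct the kernel
// object for the requested alpha/beta/increment configuration and run it with the
// given maximum m/n unroll factors (typically 48 x 8 for float and 24 x 8 for double).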
template <typename Scalar, int max_a_unroll, int max_b_unroll, bool is_alpha1, bool is_beta0, bool is_unit_inc>
EIGEN_DONT_INLINE void gemm_kern_avx512(Index m, Index n, Index k, Scalar *alpha, const Scalar *a, const Scalar *b,
                                        Scalar *c, Index ldc, Index inc = 1, Index a_stride = -1, Index b_stride = -1,
                                        Index a_off = 0, Index b_off = 0) {
  if (a_stride == -1) a_stride = k;
  if (b_stride == -1) b_stride = k;

  gemm_class<Scalar, is_unit_inc> g(m, n, k, ldc, inc, alpha, a, b, c, is_alpha1, is_beta0, a_stride, b_stride, a_off,
                                    b_off);
  g.template compute_kern<max_a_unroll, max_b_unroll>();
}
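
// With the AVX512 kernels enabled, widen the GEBP register blocking to nr = 8
// columns for float and double by overriding the Architecture::Target traits.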
#if EIGEN_USE_AVX512_GEMM_KERNELS
template <bool ConjLhs_, bool ConjRhs_, int PacketSize_>
class gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Target, PacketSize_>
    : public gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_> {
  using Base = gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_>;

 public:
  enum { nr = Base::Vectorizable ? 8 : 4 };
};

template <bool ConjLhs_, bool ConjRhs_, int PacketSize_>
class gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Target, PacketSize_>
    : public gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_> {
  using Base = gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_>;

 public:
  enum { nr = Base::Vectorizable ? 8 : 4 };
};
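
// Packing of the right-hand side for nr = 8: column-major B is packed 8 columns at a
// time (transposing full packets), with 4-column and single-column fallbacks for the
// remainder.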
template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<Scalar, Index, DataMapper, 8, ColMajor, Conjugate, PanelMode> {
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename DataMapper::LinearMapper LinearMapper;
  enum { PacketSize = packet_traits<Scalar>::size };
  EIGEN_DONT_INLINE void operator()(Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride = 0,
                                    Index offset = 0);
};
template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, DataMapper, 8, ColMajor, Conjugate, PanelMode>::operator()(
    Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride, Index offset) {
  constexpr int nr = 8;
  EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS COLMAJOR");
  EIGEN_UNUSED_VARIABLE(stride);
  EIGEN_UNUSED_VARIABLE(offset);
  eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
  conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
  Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
  Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
  Index count = 0;
  const Index peeled_k = (depth / PacketSize) * PacketSize;

  // Pack 8 columns at a time.
  for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
    if (PanelMode) count += 8 * offset;
    const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
    const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
    const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
    const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
    const LinearMapper dm4 = rhs.getLinearMapper(0, j2 + 4);
    const LinearMapper dm5 = rhs.getLinearMapper(0, j2 + 5);
    const LinearMapper dm6 = rhs.getLinearMapper(0, j2 + 6);
    const LinearMapper dm7 = rhs.getLinearMapper(0, j2 + 7);

    Index k = 0;
    if ((PacketSize % 8) == 0) {
      for (; k < peeled_k; k += PacketSize) {
        PacketBlock<Packet, (PacketSize % 8) == 0 ? 8 : PacketSize> kernel;

        kernel.packet[0] = dm0.template loadPacket<Packet>(k);
        kernel.packet[1] = dm1.template loadPacket<Packet>(k);
        kernel.packet[2] = dm2.template loadPacket<Packet>(k);
        kernel.packet[3] = dm3.template loadPacket<Packet>(k);
        kernel.packet[4] = dm4.template loadPacket<Packet>(k);
        kernel.packet[5] = dm5.template loadPacket<Packet>(k);
        kernel.packet[6] = dm6.template loadPacket<Packet>(k);
        kernel.packet[7] = dm7.template loadPacket<Packet>(k);

        ptranspose(kernel);

        pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0]));
        pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize]));
        pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize]));
        pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize]));
        pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel.packet[4 % PacketSize]));
        pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel.packet[5 % PacketSize]));
        pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel.packet[6 % PacketSize]));
        pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel.packet[7 % PacketSize]));
        count += 8 * PacketSize;
      }
    }
    for (; k < depth; k++) {
      blockB[count + 0] = cj(dm0(k));
      blockB[count + 1] = cj(dm1(k));
      blockB[count + 2] = cj(dm2(k));
      blockB[count + 3] = cj(dm3(k));
      blockB[count + 4] = cj(dm4(k));
      blockB[count + 5] = cj(dm5(k));
      blockB[count + 6] = cj(dm6(k));
      blockB[count + 7] = cj(dm7(k));
      count += 8;
    }
    // Skip what remains of the panel.
    if (PanelMode) count += 8 * (stride - offset - depth);
  }

  // Pack 4 columns at a time.
  for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
    if (PanelMode) count += 4 * offset;
    const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
    const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
    const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
    const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

    Index k = 0;
    if ((PacketSize % 4) == 0) {
      for (; k < peeled_k; k += PacketSize) {
        PacketBlock<Packet, (PacketSize % 4) == 0 ? 4 : PacketSize> kernel;
        kernel.packet[0] = dm0.template loadPacket<Packet>(k);
        kernel.packet[1 % PacketSize] = dm1.template loadPacket<Packet>(k);
        kernel.packet[2 % PacketSize] = dm2.template loadPacket<Packet>(k);
        kernel.packet[3 % PacketSize] = dm3.template loadPacket<Packet>(k);

        ptranspose(kernel);

        pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0]));
        pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize]));
        pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize]));
        pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize]));
        count += 4 * PacketSize;
      }
    }
    for (; k < depth; k++) {
      blockB[count + 0] = cj(dm0(k));
      blockB[count + 1] = cj(dm1(k));
      blockB[count + 2] = cj(dm2(k));
      blockB[count + 3] = cj(dm3(k));
      count += 4;
    }
    // Skip what remains of the panel.
    if (PanelMode) count += 4 * (stride - offset - depth);
  }

  // Pack the remaining columns one at a time.
  for (Index j2 = packet_cols4; j2 < cols; ++j2) {
    if (PanelMode) count += offset;
    const LinearMapper dm0 = rhs.getLinearMapper(0, j2);
    for (Index k = 0; k < depth; k++) {
      blockB[count] = cj(dm0(k));
      count += 1;
    }
    if (PanelMode) count += (stride - offset - depth);
  }
}
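
// Row-major B: for nr = 8 a packet of 8 consecutive elements of one row already has
// (up to conjugation) the desired packed layout, so packing reduces to packet copies,
// with 4-column and single-column fallbacks for the remainder.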
template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<Scalar, Index, DataMapper, 8, RowMajor, Conjugate, PanelMode> {
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename unpacket_traits<Packet>::half HalfPacket;
  typedef typename unpacket_traits<typename unpacket_traits<Packet>::half>::half QuarterPacket;
  typedef typename DataMapper::LinearMapper LinearMapper;
  enum {
    PacketSize = packet_traits<Scalar>::size,
    HalfPacketSize = unpacket_traits<HalfPacket>::size,
    QuarterPacketSize = unpacket_traits<QuarterPacket>::size
  };
  EIGEN_DONT_INLINE void operator()(Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) {
    constexpr int nr = 8;
    EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS ROWMAJOR");
    EIGEN_UNUSED_VARIABLE(stride);
    EIGEN_UNUSED_VARIABLE(offset);
    eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
    const bool HasHalf = (int)HalfPacketSize < (int)PacketSize;
    const bool HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize;
    conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
    Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
    Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
    Index count = 0;

    // Pack 8 columns at a time.
    for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
      if (PanelMode) count += 8 * offset;
      for (Index k = 0; k < depth; k++) {
        if (PacketSize == 8) {
          Packet A = rhs.template loadPacket<Packet>(k, j2);
          pstoreu(blockB + count, cj.pconj(A));
        } else if (HasHalf && HalfPacketSize == 8) {
          HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
          pstoreu(blockB + count, cj.pconj(A));
        } else if (HasQuarter && QuarterPacketSize == 8) {
          QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
          pstoreu(blockB + count, cj.pconj(A));
        } else if (PacketSize == 4) {
          // Copy the 8 elements as two packets of four.
          Packet A = rhs.template loadPacket<Packet>(k, j2);
          Packet B = rhs.template loadPacket<Packet>(k, j2 + PacketSize);
          pstoreu(blockB + count, cj.pconj(A));
          pstoreu(blockB + count + PacketSize, cj.pconj(B));
        } else {
          const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
          blockB[count + 0] = cj(dm0(0));
          blockB[count + 1] = cj(dm0(1));
          blockB[count + 2] = cj(dm0(2));
          blockB[count + 3] = cj(dm0(3));
          blockB[count + 4] = cj(dm0(4));
          blockB[count + 5] = cj(dm0(5));
          blockB[count + 6] = cj(dm0(6));
          blockB[count + 7] = cj(dm0(7));
        }
        count += 8;
      }
      // Skip what remains of the panel.
      if (PanelMode) count += 8 * (stride - offset - depth);
    }

    // Pack 4 columns at a time.
    for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
      if (PanelMode) count += 4 * offset;
      for (Index k = 0; k < depth; k++) {
        if (PacketSize == 4) {
          Packet A = rhs.template loadPacket<Packet>(k, j2);
          pstoreu(blockB + count, cj.pconj(A));
          count += PacketSize;
        } else if (HasHalf && HalfPacketSize == 4) {
          HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
          pstoreu(blockB + count, cj.pconj(A));
          count += HalfPacketSize;
        } else if (HasQuarter && QuarterPacketSize == 4) {
          QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
          pstoreu(blockB + count, cj.pconj(A));
          count += QuarterPacketSize;
        } else {
          const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
          blockB[count + 0] = cj(dm0(0));
          blockB[count + 1] = cj(dm0(1));
          blockB[count + 2] = cj(dm0(2));
          blockB[count + 3] = cj(dm0(3));
          count += 4;
        }
      }
      // Skip what remains of the panel.
      if (PanelMode) count += 4 * (stride - offset - depth);
    }

    // Pack the remaining columns one at a time.
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      if (PanelMode) count += offset;
      for (Index k = 0; k < depth; k++) {
        blockB[count] = cj(rhs(k, j2));
        count += 1;
      }
      if (PanelMode) count += stride - offset - depth;
    }
  }
};
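
// gebp_kernel specialization for nr = 8: forward the packed blocks to gemm_kern_avx512,
// selecting the instantiation that matches alpha == 1 and whether the output has unit
// increment.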
template <typename Scalar, typename Index, typename DataMapper, int mr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<Scalar, Scalar, Index, DataMapper, mr, 8, ConjugateLhs, ConjugateRhs> {
  EIGEN_ALWAYS_INLINE void operator()(const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index rows,
                                      Index depth, Index cols, Scalar alpha, Index strideA = -1, Index strideB = -1,
                                      Index offsetA = 0, Index offsetB = 0);
};
template <typename Scalar, typename Index, typename DataMapper, int mr, bool ConjugateLhs, bool ConjugateRhs>
EIGEN_ALWAYS_INLINE void gebp_kernel<Scalar, Scalar, Index, DataMapper, mr, 8, ConjugateLhs, ConjugateRhs>::operator()(
    const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index rows, Index depth, Index cols,
    Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  if (res.incr() == 1) {
    if (alpha == 1) {
      gemm_kern_avx512<Scalar, mr, 8, true, false, true>(rows, cols, depth, &alpha, blockA, blockB,
                                                         (Scalar *)res.data(), res.stride(), res.incr(), strideA,
                                                         strideB, offsetA, offsetB);
    } else {
      gemm_kern_avx512<Scalar, mr, 8, false, false, true>(rows, cols, depth, &alpha, blockA, blockB,
                                                          (Scalar *)res.data(), res.stride(), res.incr(), strideA,
                                                          strideB, offsetA, offsetB);
    }
  } else {
    if (alpha == 1) {
      gemm_kern_avx512<Scalar, mr, 8, true, false, false>(rows, cols, depth, &alpha, blockA, blockB,
                                                          (Scalar *)res.data(), res.stride(), res.incr(), strideA,
                                                          strideB, offsetA, offsetB);
    } else {
      gemm_kern_avx512<Scalar, mr, 8, false, false, false>(rows, cols, depth, &alpha, blockA, blockB,
                                                           (Scalar *)res.data(), res.stride(), res.incr(), strideA,
                                                           strideB, offsetA, offsetB);
    }
  }
}

#endif  // EIGEN_USE_AVX512_GEMM_KERNELS

}  // namespace internal
}  // namespace Eigen

#endif  // EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H