#ifndef EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
#define EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
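
// Descriptive note (added): idA maps a logical (i, j) coordinate of A to its
// linear offset for either storage order, so the solve kernels below can be
// written once and instantiated for row-major and column-major A.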
template <bool isARowMajor = true>
EIGEN_ALWAYS_INLINE int64_t idA(int64_t i, int64_t j, int64_t LDA) {
  EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j;
  else return i + j * LDA;
}
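
// Descriptive note (added): remMask builds an integer mask covering the first
// m lanes of an N-lane packet (N = 16, 8 or 4); it is used below for masked
// loads/stores when handling remainder rows or columns.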
template <int64_t N>
EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
  EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); }
  else EIGEN_IF_CONSTEXPR(N == 8) {
    return 0xFF >> (8 - m);
  }
  else EIGEN_IF_CONSTEXPR(N == 4) {
    return 0x0F >> (4 - m);
  }
  return 0;
}
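
// Descriptive note (added): trans8x8blocks performs an in-register transpose
// of 8x8 blocks held across eight packets. For Packet16f each 512-bit register
// carries two 8-float halves, so the specialization below effectively
// transposes two 8x8 blocks side by side; the Packet8d case is a plain 8x8
// transpose.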
template <typename Packet>
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet, 8> &kernel);

template <>
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet16f, 8> &kernel) {
  __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
  __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]);
  __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]);
  __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]);
  __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]);
  __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]);

  kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
  kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
  kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
  kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
  kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
  kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
  kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
  kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));

  T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E));
  T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
  T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E));
  T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
  T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E));
  T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
  T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E));
  T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
  T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E));
  T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
  T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E));
  T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
  T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E));
  T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
  T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
  T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);

  kernel.packet[0] = T0;
  kernel.packet[1] = T1;
  kernel.packet[2] = T2;
  kernel.packet[3] = T3;
  kernel.packet[4] = T4;
  kernel.packet[5] = T5;
  kernel.packet[6] = T6;
  kernel.packet[7] = T7;
}

template <>
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet8d, 8> &kernel) {
  // A Packet8d already spans a full 8-element row, so the generic 8x8 packet
  // transpose applies directly.
  ptranspose(kernel);
}
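
/**
 * Descriptive note (added): helpers for writing a block of C back in
 * transposed form. vec/vecHalf select the full- and half-width packet types
 * for the given Scalar; the accumulators live in zmm registers and are stored
 * to C row by row.
 */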
template <typename Scalar>
class trans {
 public:
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
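
  /**
   * Descriptive note (added): aux_storeC is a 1-D compile-time unroll over
   * startN = 0,...,endN-1. Each step loads a half-width row of C, adds the
   * corresponding accumulator packet from zmm and stores the sum back, with
   * masked loads/stores when remM is set. For float, rows at or beyond
   * EIGEN_AVX_MAX_NUM_ROW first swap the halves of the zmm register so the
   * upper half can be stored.
   */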
  template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    EIGEN_IF_CONSTEXPR(startN < EIGEN_AVX_MAX_NUM_ROW) {
      EIGEN_IF_CONSTEXPR(remM) {
        pstoreu<Scalar>(
            C_arr + LDC * startN,
            padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
                 preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]),
                 remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
            remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
      }
      else {
        pstoreu<Scalar>(C_arr + LDC * startN,
                        padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
                             preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN])));
      }
    }
    else {  // startN >= EIGEN_AVX_MAX_NUM_ROW: only reached in the float case.
      // Reinterpret as vecFullFloat and swap the lower and upper 256-bit halves of the register.
      vecFullFloat zmm2vecFullFloat = preinterpret<vecFullFloat>(
          zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]);
      zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)] =
          preinterpret<vec>(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110));

      EIGEN_IF_CONSTEXPR(remM) {
        pstoreu<Scalar>(
            C_arr + LDC * startN,
            padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
                 preinterpret<vecHalf>(
                     zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])),
            remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
      }
      else {
        pstoreu<Scalar>(
            C_arr + LDC * startN,
            padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
                 preinterpret<vecHalf>(
                     zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])));
      }
    }
    aux_storeC<endN, counter - 1, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
  }

  template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && endN <= PacketSize)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(C_arr);
    EIGEN_UNUSED_VARIABLE(LDC);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(remM_);
  }

  template <int64_t endN, int64_t unrollN, int64_t packetIndexOffset, bool remM>
  static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
                                         PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                         int64_t remM_ = 0) {
    aux_storeC<endN, endN, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
  }
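
  /**
   * Descriptive note (added): transpose gathers the EIGEN_AVX_MAX_NUM_ROW rows
   * of a row-major block (spaced zmmStride packets apart inside zmm, starting
   * at packetIndexOffset) into r, transposes them as 8x8 blocks and writes
   * them back in place.
   */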
  template <int64_t unrollN, int64_t packetIndexOffset>
  static EIGEN_ALWAYS_INLINE void transpose(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    // Note: the unrolls below assume EIGEN_AVX_MAX_NUM_ROW = 8 and must be
    // adjusted if that constant changes.
    constexpr int64_t zmmStride = unrollN / PacketSize;
    PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> r;
    r.packet[0] = zmm.packet[packetIndexOffset + zmmStride * 0];
    r.packet[1] = zmm.packet[packetIndexOffset + zmmStride * 1];
    r.packet[2] = zmm.packet[packetIndexOffset + zmmStride * 2];
    r.packet[3] = zmm.packet[packetIndexOffset + zmmStride * 3];
    r.packet[4] = zmm.packet[packetIndexOffset + zmmStride * 4];
    r.packet[5] = zmm.packet[packetIndexOffset + zmmStride * 5];
    r.packet[6] = zmm.packet[packetIndexOffset + zmmStride * 6];
    r.packet[7] = zmm.packet[packetIndexOffset + zmmStride * 7];
    trans8x8blocks(r);
    zmm.packet[packetIndexOffset + zmmStride * 0] = r.packet[0];
    zmm.packet[packetIndexOffset + zmmStride * 1] = r.packet[1];
    zmm.packet[packetIndexOffset + zmmStride * 2] = r.packet[2];
    zmm.packet[packetIndexOffset + zmmStride * 3] = r.packet[3];
    zmm.packet[packetIndexOffset + zmmStride * 4] = r.packet[4];
    zmm.packet[packetIndexOffset + zmmStride * 5] = r.packet[5];
    zmm.packet[packetIndexOffset + zmmStride * 6] = r.packet[6];
    zmm.packet[packetIndexOffset + zmmStride * 7] = r.packet[7];
  }
};
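
/**
 * Descriptive note (added): helpers for loading, transposing and storing
 * blocks of B, optionally staging them through a small column-major temporary
 * buffer (B_temp) when toTemp is set.
 */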
template <typename Scalar>
class transB {
 public:
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
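
  /**
   * Descriptive note (added): aux_loadB is a 1-D unroll over startN; it loads
   * row startN of B into ymm.packet[packetIndexOffset + startN], using a
   * masked load when remM or remN_ indicate a partial row.
   */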
  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
      int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    EIGEN_IF_CONSTEXPR(remM) {
      ymm.packet[packetIndexOffset + startN] =
          ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
    }
    else {
      EIGEN_IF_CONSTEXPR(remN_ == 0) {
        ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB]);
      }
      else {
        ymm.packet[packetIndexOffset + startN] =
            ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remN_));
      }
    }
    aux_loadB<endN, counter - 1, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
  }

  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
      int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(ymm);
    EIGEN_UNUSED_VARIABLE(remM_);
  }
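
  /**
   * Descriptive note (added): aux_storeB is a 1-D unroll over startN; it
   * stores ymm.packet[packetIndexOffset + startN] into row startN of B, using
   * a masked store when a K- or M-remainder is active.
   */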
  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    EIGEN_IF_CONSTEXPR(remK || remM) {
      pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN],
                      remMask<EIGEN_AVX_MAX_NUM_ROW>(rem_));
    }
    else {
      pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN]);
    }
    aux_storeB<endN, counter - 1, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
  }

  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(ymm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }
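
  /**
   * Descriptive note (added): aux_loadBBlock unrolls in steps of
   * EIGEN_AVX_MAX_NUM_ROW and loads an L x unrollN block from the column-major
   * temporary B_temp into ymm.
   */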
  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;
    transB::template loadB<EIGEN_AVX_MAX_NUM_ROW, startN, false, (toTemp ? 0 : remN_)>(&B_temp[startN], LDB_, ymm);
    aux_loadBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }

  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(B_temp);
    EIGEN_UNUSED_VARIABLE(LDB_);
    EIGEN_UNUSED_VARIABLE(ymm);
    EIGEN_UNUSED_VARIABLE(remM_);
  }
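
  /**
   * Descriptive note (added): aux_storeBBlock unrolls in steps of
   * EIGEN_AVX_MAX_NUM_ROW and stores the block either into the temporary
   * B_temp (toTemp) or back into B_arr.
   */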
  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    EIGEN_IF_CONSTEXPR(toTemp) {
      transB::template storeB<EIGEN_AVX_MAX_NUM_ROW, startN, remK_ != 0, false>(&B_temp[startN], LDB_, ymm, remK_);
    }
    else {
      transB::template storeB<std::min(EIGEN_AVX_MAX_NUM_ROW, endN), startN, false, remM>(&B_arr[0 + startN * LDB], LDB,
                                                                                          ymm, remM_);
    }
    aux_storeBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }

  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(B_temp);
    EIGEN_UNUSED_VARIABLE(LDB_);
    EIGEN_UNUSED_VARIABLE(ymm);
    EIGEN_UNUSED_VARIABLE(remM_);
  }

  template <int64_t endN, int64_t packetIndexOffset, bool remM, int64_t remN_>
  static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB,
                                        PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                        int64_t remM_ = 0) {
    aux_loadB<endN, endN, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
  }

  template <int64_t endN, int64_t packetIndexOffset, bool remK, bool remM>
  static EIGEN_ALWAYS_INLINE void storeB(Scalar *B_arr, int64_t LDB,
                                         PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                         int64_t rem_ = 0) {
    aux_storeB<endN, endN, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
  }

  template <int64_t unrollN, bool toTemp, bool remM, int64_t remN_ = 0>
  static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                             PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                             int64_t remM_ = 0) {
    EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB<unrollN, 0, remM, 0>(&B_arr[0], LDB, ymm, remM_); }
    else {
      aux_loadBBlock<unrollN, unrollN, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
  }

  template <int64_t unrollN, bool toTemp, bool remM, int64_t remK_>
  static EIGEN_ALWAYS_INLINE void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                              PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                              int64_t remM_ = 0) {
    aux_storeBBlock<unrollN, unrollN, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }

  template <int64_t packetIndexOffset>
  static EIGEN_ALWAYS_INLINE void transposeLxL(PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm) {
    // Note: the unrolls below assume EIGEN_AVX_MAX_NUM_ROW = 8 and must be
    // adjusted if that constant changes.
    PacketBlock<vecHalf, EIGEN_AVX_MAX_NUM_ROW> r;
    r.packet[0] = ymm.packet[packetIndexOffset + 0];
    r.packet[1] = ymm.packet[packetIndexOffset + 1];
    r.packet[2] = ymm.packet[packetIndexOffset + 2];
    r.packet[3] = ymm.packet[packetIndexOffset + 3];
    r.packet[4] = ymm.packet[packetIndexOffset + 4];
    r.packet[5] = ymm.packet[packetIndexOffset + 5];
    r.packet[6] = ymm.packet[packetIndexOffset + 6];
    r.packet[7] = ymm.packet[packetIndexOffset + 7];
    ptranspose(r);  // transpose the gathered LxL block in registers
    ymm.packet[packetIndexOffset + 0] = r.packet[0];
    ymm.packet[packetIndexOffset + 1] = r.packet[1];
    ymm.packet[packetIndexOffset + 2] = r.packet[2];
    ymm.packet[packetIndexOffset + 3] = r.packet[3];
    ymm.packet[packetIndexOffset + 4] = r.packet[4];
    ymm.packet[packetIndexOffset + 5] = r.packet[5];
    ymm.packet[packetIndexOffset + 6] = r.packet[6];
    ymm.packet[packetIndexOffset + 7] = r.packet[7];
  }
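
  /**
   * Descriptive note (added): transB_kernel dispatches on the compile-time
   * width unrollN (3*PacketSize down to 1) and performs
   * load -> LxL transposes -> store, optionally routing the block through
   * B_temp when toTemp is set.
   */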
  template <int64_t unrollN, bool toTemp, bool remM>
  static EIGEN_ALWAYS_INLINE void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                                PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                                int64_t remM_ = 0) {
    constexpr int64_t U3 = PacketSize * 3;
    constexpr int64_t U2 = PacketSize * 2;
    constexpr int64_t U1 = PacketSize * 1;

    EIGEN_IF_CONSTEXPR(unrollN == U3) {
      // Load LxU3 of B col-major, transpose LxU3 row-major.
      constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U3);
      transB::template loadBBlock<maxUBlock, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);

      EIGEN_IF_CONSTEXPR(maxUBlock < U3) {
        transB::template loadBBlock<maxUBlock, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
                                                             ymm, remM_);
        transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
        transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
        transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
        transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock],
                                                                 LDB_, ymm, remM_);
      }
    }
    else EIGEN_IF_CONSTEXPR(unrollN == U2) {
      // Load LxU2 of B col-major, transpose LxU2 row-major.
      constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U2);
      transB::template loadBBlock<maxUBlock, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      EIGEN_IF_CONSTEXPR(maxUBlock < U2) transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);

      EIGEN_IF_CONSTEXPR(maxUBlock < U2) {
        transB::template loadBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB,
                                                                         &B_temp[maxUBlock], LDB_, ymm, remM_);
        transB::template transposeLxL<0>(ymm);
        transB::template storeBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB,
                                                                             &B_temp[maxUBlock], LDB_, ymm, remM_);
      }
    }
    else EIGEN_IF_CONSTEXPR(unrollN == U1) {
      // Load LxU1 of B col-major, transpose LxU1 row-major.
      transB::template loadBBlock<U1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); }
      transB::template storeBBlock<U1, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
    else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) {
      // Load Lx8 of B col-major, transpose Lx8 row-major.
      transB::template loadBBlock<8, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      transB::template storeBBlock<8, toTemp, remM, 8>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
    else EIGEN_IF_CONSTEXPR(unrollN == 4 && U1 > 4) {
      // Load Lx4 of B col-major, transpose Lx4 row-major.
      transB::template loadBBlock<4, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      transB::template storeBBlock<4, toTemp, remM, 4>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
    else EIGEN_IF_CONSTEXPR(unrollN == 2) {
      // Load Lx2 of B col-major, transpose Lx2 row-major.
      transB::template loadBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      transB::template storeBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
    else EIGEN_IF_CONSTEXPR(unrollN == 1) {
      // Load Lx1 of B col-major, transpose Lx1 row-major.
      transB::template loadBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      transB::template storeBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
  }
};
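
/**
 * Descriptive note (added): building blocks of the triangular-solve
 * micro-kernel: packetized loads/stores of the right-hand side, scaling of a
 * RHS row by the broadcast reciprocal of a diagonal entry of A, and the
 * updates of the remaining rows.
 */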
template <typename Scalar>
class trsm {
 public:
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
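
  /**
   * Descriptive note (added): aux_loadRHS is a 2-D unroll over
   * (startM, startK); it loads endK packets per row of the right-hand side B
   * into RHSInPacket, masked when krem is set. For backward solves the rows
   * are addressed with a negated row offset.
   */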
  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadRHS(
      Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    constexpr int64_t counterReverse = endM * endK - counter;
    constexpr int64_t startM = counterReverse / (endK);
    constexpr int64_t startK = counterReverse % endK;

    constexpr int64_t packetIndex = startM * endK + startK;
    constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
    const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB;
    EIGEN_IF_CONSTEXPR(krem) {
      RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex], remMask<PacketSize>(rem));
    }
    else {
      RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex]);
    }
    aux_loadRHS<isFWDSolve, endM, endK, counter - 1, krem>(B_arr, LDB, RHSInPacket, rem);
  }

  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadRHS(
      Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
    EIGEN_UNUSED_VARIABLE(rem);
  }

  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeRHS(
      Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    constexpr int64_t counterReverse = endM * endK - counter;
    constexpr int64_t startM = counterReverse / (endK);
    constexpr int64_t startK = counterReverse % endK;

    constexpr int64_t packetIndex = startM * endK + startK;
    constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
    const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB;
    EIGEN_IF_CONSTEXPR(krem) {
      pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex], remMask<PacketSize>(rem));
    }
    else {
      pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex]);
    }
    aux_storeRHS<isFWDSolve, endM, endK, counter - 1, krem>(B_arr, LDB, RHSInPacket, rem);
  }

  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeRHS(
      Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
    EIGEN_UNUSED_VARIABLE(rem);
  }

  template <int64_t currM, int64_t endK, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> aux_divRHSByDiag(
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    constexpr int64_t counterReverse = endK - counter;
    constexpr int64_t startK = counterReverse;

    constexpr int64_t packetIndex = currM * endK + startK;
    RHSInPacket.packet[packetIndex] = pmul(AInPacket.packet[currM], RHSInPacket.packet[packetIndex]);
    aux_divRHSByDiag<currM, endK, counter - 1>(RHSInPacket, AInPacket);
  }

  template <int64_t currM, int64_t endK, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && currM >= 0)> aux_divRHSByDiag(
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
    EIGEN_UNUSED_VARIABLE(AInPacket);
  }
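
  /**
   * Descriptive note (added): aux_updateRHS is a 2-D unroll over
   * (startM, startK). When currentM > 0, each step subtracts the previously
   * broadcast entry of A times the already-solved RHS row (currentM - 1) from
   * RHS row startM. At the end of each row it broadcasts the next entry of A
   * needed later: the reciprocal of the diagonal for row currentM (unless
   * isUnitDiag), or the off-diagonal multiplier A(startM, currentM).
   */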
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
            int64_t counter, int64_t currentM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateRHS(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    constexpr int64_t counterReverse = (endM - initM) * endK - counter;
    constexpr int64_t startM = initM + counterReverse / (endK);
    constexpr int64_t startK = counterReverse % endK;

    // Update RHS row startM using the row solved in the previous step.
    constexpr int64_t packetIndex = startM * endK + startK;
    EIGEN_IF_CONSTEXPR(currentM > 0) {
      RHSInPacket.packet[packetIndex] =
          pnmadd(AInPacket.packet[startM], RHSInPacket.packet[(currentM - 1) * endK + startK],
                 RHSInPacket.packet[packetIndex]);
    }

    EIGEN_IF_CONSTEXPR(startK == endK - 1) {
      // Once the whole row has been handled, broadcast the next entry of A.
      EIGEN_IF_CONSTEXPR(startM == currentM && !isUnitDiag) {
        // Broadcast the reciprocal of the diagonal; it is applied later in divRHSByDiag.
        EIGEN_IF_CONSTEXPR(isFWDSolve)
        AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(currentM, currentM, LDA)]);
        else AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(-currentM, -currentM, LDA)]);
      }
      else {
        // Broadcast the off-diagonal multiplier A(startM, currentM).
        EIGEN_IF_CONSTEXPR(isFWDSolve)
        AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(startM, currentM, LDA)]);
        else AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(-startM, -currentM, LDA)]);
      }
    }

    aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, initM, endM, endK, counter - 1, currentM>(
        A_arr, LDA, RHSInPacket, AInPacket);
  }

  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
            int64_t counter, int64_t currentM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateRHS(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    EIGEN_UNUSED_VARIABLE(A_arr);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
    EIGEN_UNUSED_VARIABLE(AInPacket);
  }
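
  /**
   * Descriptive note (added): aux_triSolveMicroKernel unrolls over the rows of
   * A (startM). Each step divides the previously updated RHS row by its
   * diagonal (unless unit diagonal), updates the remaining rows via updateRHS,
   * and performs the final row's division on the last iteration.
   */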
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_triSolveMicroKernel(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    constexpr int64_t counterReverse = endM - counter;
    constexpr int64_t startM = counterReverse;

    constexpr int64_t currentM = startM;
    // Divide RHS row startM - 1 by the diagonal reciprocal broadcast in the
    // previous iteration (skipped for unit-diagonal A and for the first row).
    EIGEN_IF_CONSTEXPR(!isUnitDiag && startM > 0)
    trsm::template divRHSByDiag<startM - 1, numK>(RHSInPacket, AInPacket);

    // Update the remaining RHS rows with the row just solved and broadcast the
    // entries of A needed by the next iteration.
    trsm::template updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, numK, currentM>(A_arr, LDA, RHSInPacket,
                                                                                                AInPacket);

    // The last row is divided here since no further iteration would do it.
    EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM - 1)
    trsm::template divRHSByDiag<startM, numK>(RHSInPacket, AInPacket);

    aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, counter - 1, numK>(A_arr, LDA, RHSInPacket,
                                                                                          AInPacket);
  }

  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_triSolveMicroKernel(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    EIGEN_UNUSED_VARIABLE(A_arr);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
    EIGEN_UNUSED_VARIABLE(AInPacket);
  }

  template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
  static EIGEN_ALWAYS_INLINE void loadRHS(Scalar *B_arr, int64_t LDB,
                                          PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    aux_loadRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
  }

  template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
  static EIGEN_ALWAYS_INLINE void storeRHS(Scalar *B_arr, int64_t LDB,
                                           PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    aux_storeRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
  }

  template <int64_t currM, int64_t endK>
  static EIGEN_ALWAYS_INLINE void divRHSByDiag(PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                               PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    aux_divRHSByDiag<currM, endK, endK>(RHSInPacket, AInPacket);
  }

  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t startM, int64_t endM, int64_t endK,
            int64_t currentM>
  static EIGEN_ALWAYS_INLINE void updateRHS(Scalar *A_arr, int64_t LDA,
                                            PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                            PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, endK, (endM - startM) * endK, currentM>(
        A_arr, LDA, RHSInPacket, AInPacket);
  }

  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t numK>
  static EIGEN_ALWAYS_INLINE void triSolveMicroKernel(Scalar *A_arr, int64_t LDA,
                                                      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                                      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    static_assert(numK >= 1 && numK <= 3, "numK out of range");
    aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, endM, numK>(A_arr, LDA, RHSInPacket, AInPacket);
  }
};
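
/**
 * Descriptive note (added): unrolled GEMM micro-kernel helpers: zeroing the
 * accumulators, loading/broadcasting A and B, the FMA inner loop, and
 * updating/storing C. The isAdd flag selects between accumulating
 * (C += A*B) and subtracting (C -= A*B).
 */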
template <typename Scalar, bool isAdd>
class gemm {
 public:
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
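
  /**
   * Descriptive note (added): aux_setzero is a 2-D unroll over
   * (startM, startN) that zeroes the endM x endN accumulator packets in zmm.
   */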
  template <int64_t endM, int64_t endN, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_setzero(
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    constexpr int64_t counterReverse = endM * endN - counter;
    constexpr int64_t startM = counterReverse / (endN);
    constexpr int64_t startN = counterReverse % endN;

    zmm.packet[startN * endM + startM] = pzero(zmm.packet[startN * endM + startM]);
    aux_setzero<endM, endN, counter - 1>(zmm);
  }

  template <int64_t endM, int64_t endN, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_setzero(
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    EIGEN_UNUSED_VARIABLE(zmm);
  }
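
  /**
   * Descriptive note (added): aux_updateC is a 2-D unroll over
   * (startM, startN) that adds the existing contents of C into the
   * accumulators in zmm, with a masked load when rem is set.
   */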
  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endM * endN - counter;
    constexpr int64_t startM = counterReverse / (endN);
    constexpr int64_t startN = counterReverse % endN;

    EIGEN_IF_CONSTEXPR(rem)
    zmm.packet[startN * endM + startM] =
        padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize], remMask<PacketSize>(rem_)),
             zmm.packet[startN * endM + startM], remMask<PacketSize>(rem_));
    else zmm.packet[startN * endM + startM] =
        padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]);
    aux_updateC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
  }

  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(C_arr);
    EIGEN_UNUSED_VARIABLE(LDC);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }

  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endM * endN - counter;
    constexpr int64_t startM = counterReverse / (endN);
    constexpr int64_t startN = counterReverse % endN;

    EIGEN_IF_CONSTEXPR(rem)
    pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM],
                    remMask<PacketSize>(rem_));
    else pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM]);
    aux_storeC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
  }

  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(C_arr);
    EIGEN_UNUSED_VARIABLE(LDC);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }
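
  /**
   * Descriptive note (added): aux_startLoadB preloads the first endL packets
   * of B into the registers placed just above the unrollM x unrollN
   * accumulators, masked when rem is set.
   */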
  template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startLoadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endL - counter;
    constexpr int64_t startL = counterReverse;

    EIGEN_IF_CONSTEXPR(rem)
    zmm.packet[unrollM * unrollN + startL] =
        ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask<PacketSize>(rem_));
    else zmm.packet[unrollM * unrollN + startL] =
        ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize]);

    aux_startLoadB<unrollM, unrollN, endL, counter - 1, rem>(B_t, LDB, zmm, rem_);
  }

  template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startLoadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_t);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }

  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startBCastA(
      Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    constexpr int64_t counterReverse = endB - counter;
    constexpr int64_t startB = counterReverse;

    zmm.packet[unrollM * unrollN + numLoad + startB] = pload1<vec>(&A_t[idA<isARowMajor>(startB, 0, LDA)]);

    aux_startBCastA<isARowMajor, unrollM, unrollN, endB, counter - 1, numLoad>(A_t, LDA, zmm);
  }

  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startBCastA(
      Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    EIGEN_UNUSED_VARIABLE(A_t);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(zmm);
  }
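
  /**
   * Descriptive note (added): aux_loadB loads the next set of endM B packets
   * for step currK of the K-unroll into the numLoad rotating registers above
   * the accumulators; the loads are skipped once the end of the K block is
   * reached.
   */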
  template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    if ((numLoad / endM + currK < unrollK)) {
      constexpr int64_t counterReverse = endM - counter;
      constexpr int64_t startM = counterReverse;

      EIGEN_IF_CONSTEXPR(rem) {
        zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
            ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize], remMask<PacketSize>(rem_));
      }
      else {
        zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
            ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize]);
      }

      aux_loadB<endM, counter - 1, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
    }
  }

  template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_t);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }
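
  /**
   * Descriptive note (added): aux_microKernel is a 3-D unroll over
   * (startM, startN, startK) implementing the GEMM inner kernel: the first
   * step preloads B and broadcasts A, then every step issues one FMA (pmadd or
   * pnmadd depending on isAdd), interleaved with the A broadcasts and B loads
   * needed by later steps.
   */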
  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_microKernel(
      Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
      int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endM * endN * endK - counter;
    constexpr int startK = counterReverse / (endM * endN);
    constexpr int startN = (counterReverse / (endM)) % endN;
    constexpr int startM = counterReverse % endM;

    EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) {
      gemm::template startLoadB<endM, endN, numLoad, rem>(B_t, LDB, zmm, rem_);
      gemm::template startBCastA<isARowMajor, endM, endN, numBCast, numLoad>(A_t, LDA, zmm);
    }

    // Interleave the FMAs with the broadcasts of A.
    EIGEN_IF_CONSTEXPR(isAdd) {
      zmm.packet[startN * endM + startM] =
          pmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
                zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
    }
    else {
      zmm.packet[startN * endM + startM] =
          pnmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
                 zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
    }
    // Broadcast the next entry of A once the current one has been consumed.
    EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) {
      zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast] = pload1<vec>(&A_t[idA<isARowMajor>(
          (numBCast + startN + startK * endN) % endN, (numBCast + startN + startK * endN) / endN, LDA)]);
    }

    // All accumulators for this K step are updated; load the next set of B packets.
    EIGEN_IF_CONSTEXPR((startN == endN - 1) && (startM == endM - 1)) {
      gemm::template loadB<endM, endN, startK, endK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
    }
    aux_microKernel<isARowMajor, endM, endN, endK, counter - 1, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm, rem_);
  }

  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_microKernel(
      Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
      int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_t);
    EIGEN_UNUSED_VARIABLE(A_t);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }

  template <int64_t endM, int64_t endN>
  static EIGEN_ALWAYS_INLINE void setzero(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    aux_setzero<endM, endN, endM * endN>(zmm);
  }

  template <int64_t endM, int64_t endN, bool rem = false>
  static EIGEN_ALWAYS_INLINE void updateC(Scalar *C_arr, int64_t LDC,
                                          PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                          int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_updateC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
  }

  template <int64_t endM, int64_t endN, bool rem = false>
  static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
                                         PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                         int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_storeC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
  }

  template <int64_t unrollM, int64_t unrollN, int64_t endL, bool rem>
  static EIGEN_ALWAYS_INLINE void startLoadB(Scalar *B_t, int64_t LDB,
                                             PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                             int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_startLoadB<unrollM, unrollN, endL, endL, rem>(B_t, LDB, zmm, rem_);
  }

  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t numLoad>
  static EIGEN_ALWAYS_INLINE void startBCastA(Scalar *A_t, int64_t LDA,
                                              PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    aux_startBCastA<isARowMajor, unrollM, unrollN, endB, endB, numLoad>(A_t, LDA, zmm);
  }

  template <int64_t endM, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast,
            bool rem>
  static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_t, int64_t LDB,
                                        PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                        int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_loadB<endM, endM, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
  }
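
  /**
   * Descriptive note (added): microKernel is the public entry point that kicks
   * off the fully unrolled endM x endN x endK GEMM inner kernel defined by
   * aux_microKernel above.
   */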
  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t numLoad, int64_t numBCast,
            bool rem = false>
  static EIGEN_ALWAYS_INLINE void microKernel(Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA,
                                              PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                              int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_microKernel<isARowMajor, endM, endN, endK, endM * endN * endK, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm,
                                                                                               rem_);
  }