Eigen  3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)
 
TrsmUnrolls.inc
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2022 Intel Corporation
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
#define EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H

template <bool isARowMajor = true>
EIGEN_ALWAYS_INLINE int64_t idA(int64_t i, int64_t j, int64_t LDA) {
  EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j;
  else return i + j * LDA;
}
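
// Illustrative note (not part of the original source): idA maps a logical (i, j)
// coordinate of A to a linear offset for either storage order. For example, with
// LDA = 5, idA<true>(2, 3, 5) == 2 * 5 + 3 == 13 (row-major), while
// idA<false>(2, 3, 5) == 2 + 3 * 5 == 17 (column-major).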

namespace unrolls {

template <int64_t N>
EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
  EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); }
  else EIGEN_IF_CONSTEXPR(N == 8) {
    return 0xFF >> (8 - m);
  }
  else EIGEN_IF_CONSTEXPR(N == 4) {
    return 0x0F >> (4 - m);
  }
  return 0;
}
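
// Illustrative note (not part of the original source): remMask<N>(m) builds the
// low-bit mask that selects the m remaining lanes of an N-lane packet, e.g.
// remMask<16>(3) == 0x0007 and remMask<8>(5) == 0x1F. The result is passed as the
// mask argument of the masked ploadu/pstoreu calls used throughout this file.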

template <typename Packet>
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet, 8> &kernel);

template <>
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet16f, 8> &kernel) {
  __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
  __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]);
  __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]);
  __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]);
  __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]);
  __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]);

  kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
  kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
  kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
  kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
  kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
  kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
  kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
  kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));

  T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E));
  T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
  T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E));
  T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
  T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E));
  T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
  T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E));
  T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
  T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E));
  T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
  T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E));
  T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
  T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E));
  T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
  T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
  T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);

  kernel.packet[0] = T0;
  kernel.packet[1] = T1;
  kernel.packet[2] = T2;
  kernel.packet[3] = T3;
  kernel.packet[4] = T4;
  kernel.packet[5] = T5;
  kernel.packet[6] = T6;
  kernel.packet[7] = T7;
}

template <>
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet8d, 8> &kernel) {
  ptranspose(kernel);
}
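
// Illustrative note (not part of the original source): the Packet16f overload treats
// the eight 16-lane registers as an 8x16 tile and transposes its two 8x8 halves in
// place: the _ps unpacks interleave rows at 32-bit granularity, the _pd unpacks at
// 64-bit granularity, and the final permutex/mask_blend pass exchanges 128-bit
// quarters, so register k ends up with column k of the tile in its low 8 lanes and
// column k + 8 in its high 8 lanes. The Packet8d overload simply defers to
// ptranspose, since eight doubles fill one 512-bit register exactly.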

/***
 * Unrolls for transposed C stores
 */
template <typename Scalar>
class trans {
 public:
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;

  /***********************************
   * Auxiliary Functions for:
   * - storeC
   ***********************************
   */

  template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    EIGEN_IF_CONSTEXPR(startN < EIGEN_AVX_MAX_NUM_ROW) {
      EIGEN_IF_CONSTEXPR(remM) {
        pstoreu<Scalar>(
            C_arr + LDC * startN,
            padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
                 preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]),
                 remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
            remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
      }
      else {
        pstoreu<Scalar>(C_arr + LDC * startN,
                        padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
                             preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN])));
      }
    }
    else {  // This block is only needed for the fp32 case.
      // Reinterpret as __m512 for _mm512_shuffle_f32x4.
      vecFullFloat zmm2vecFullFloat = preinterpret<vecFullFloat>(
          zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]);
      // Swap lower and upper half of the avx register.
      zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)] =
          preinterpret<vec>(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110));

      EIGEN_IF_CONSTEXPR(remM) {
        pstoreu<Scalar>(
            C_arr + LDC * startN,
            padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
                 preinterpret<vecHalf>(
                     zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])),
            remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
      }
      else {
        pstoreu<Scalar>(
            C_arr + LDC * startN,
            padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
                 preinterpret<vecHalf>(
                     zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])));
      }
    }
    aux_storeC<endN, counter - 1, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
  }

  template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && endN <= PacketSize)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(C_arr);
    EIGEN_UNUSED_VARIABLE(LDC);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(remM_);
  }

  template <int64_t endN, int64_t unrollN, int64_t packetIndexOffset, bool remM>
  static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
                                         PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                         int64_t remM_ = 0) {
    aux_storeC<endN, endN, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
  }

  template <int64_t unrollN, int64_t packetIndexOffset>
  static EIGEN_ALWAYS_INLINE void transpose(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
    // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
    constexpr int64_t zmmStride = unrollN / PacketSize;
    PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> r;
    r.packet[0] = zmm.packet[packetIndexOffset + zmmStride * 0];
    r.packet[1] = zmm.packet[packetIndexOffset + zmmStride * 1];
    r.packet[2] = zmm.packet[packetIndexOffset + zmmStride * 2];
    r.packet[3] = zmm.packet[packetIndexOffset + zmmStride * 3];
    r.packet[4] = zmm.packet[packetIndexOffset + zmmStride * 4];
    r.packet[5] = zmm.packet[packetIndexOffset + zmmStride * 5];
    r.packet[6] = zmm.packet[packetIndexOffset + zmmStride * 6];
    r.packet[7] = zmm.packet[packetIndexOffset + zmmStride * 7];
    trans8x8blocks(r);
    zmm.packet[packetIndexOffset + zmmStride * 0] = r.packet[0];
    zmm.packet[packetIndexOffset + zmmStride * 1] = r.packet[1];
    zmm.packet[packetIndexOffset + zmmStride * 2] = r.packet[2];
    zmm.packet[packetIndexOffset + zmmStride * 3] = r.packet[3];
    zmm.packet[packetIndexOffset + zmmStride * 4] = r.packet[4];
    zmm.packet[packetIndexOffset + zmmStride * 5] = r.packet[5];
    zmm.packet[packetIndexOffset + zmmStride * 6] = r.packet[6];
    zmm.packet[packetIndexOffset + zmmStride * 7] = r.packet[7];
  }
};
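
// Illustrative note (not part of the original source): trans<Scalar> is used when the
// solve result has to be accumulated into a transposed C. A typical sequence for one
// column strip is, roughly,
//
//   trans<float>::transpose<unrollN, 0>(zmm);                  // transpose the 8x8 sub-blocks held in zmm
//   trans<float>::storeC<8, unrollN, 0, false>(C, LDC, zmm);   // read-modify-write them into C row by row
//
// where unrollN is the compile-time column unroll and zmm holds the accumulators.
// For fp32 the upper 8 lanes of each 16-lane register hold a second 8x8 block, which
// is why aux_storeC swaps register halves with _mm512_shuffle_f32x4 before storing
// rows 8 and above.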

template <typename Scalar>
class transB {
 public:
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;

  /***********************************
   * Auxiliary Functions for:
   * - loadB
   * - storeB
   * - loadBBlock
   * - storeBBlock
   ***********************************
   */

  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
      int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    EIGEN_IF_CONSTEXPR(remM) {
      ymm.packet[packetIndexOffset + startN] =
          ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
    }
    else {
      EIGEN_IF_CONSTEXPR(remN_ == 0) {
        ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB]);
      }
      else ymm.packet[packetIndexOffset + startN] =
          ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remN_));
    }

    aux_loadB<endN, counter - 1, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
  }

  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
      int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(ymm);
    EIGEN_UNUSED_VARIABLE(remM_);
  }

  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    EIGEN_IF_CONSTEXPR(remK || remM) {
      pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN],
                      remMask<EIGEN_AVX_MAX_NUM_ROW>(rem_));
    }
    else {
      pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN]);
    }

    aux_storeB<endN, counter - 1, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
  }

  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(ymm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }

  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;
    transB::template loadB<EIGEN_AVX_MAX_NUM_ROW, startN, false, (toTemp ? 0 : remN_)>(&B_temp[startN], LDB_, ymm);
    aux_loadBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }

  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(B_temp);
    EIGEN_UNUSED_VARIABLE(LDB_);
    EIGEN_UNUSED_VARIABLE(ymm);
    EIGEN_UNUSED_VARIABLE(remM_);
  }

  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    EIGEN_IF_CONSTEXPR(toTemp) {
      transB::template storeB<EIGEN_AVX_MAX_NUM_ROW, startN, remK_ != 0, false>(&B_temp[startN], LDB_, ymm, remK_);
    }
    else {
      transB::template storeB<std::min(EIGEN_AVX_MAX_NUM_ROW, endN), startN, false, remM>(&B_arr[0 + startN * LDB], LDB,
                                                                                          ymm, remM_);
    }
    aux_storeBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }

  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(B_temp);
    EIGEN_UNUSED_VARIABLE(LDB_);
    EIGEN_UNUSED_VARIABLE(ymm);
    EIGEN_UNUSED_VARIABLE(remM_);
  }

  /********************************************************
   * Wrappers for aux_XXXX to hide counter parameter
   ********************************************************/

  template <int64_t endN, int64_t packetIndexOffset, bool remM, int64_t remN_>
  static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB,
                                        PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                        int64_t remM_ = 0) {
    aux_loadB<endN, endN, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
  }

  template <int64_t endN, int64_t packetIndexOffset, bool remK, bool remM>
  static EIGEN_ALWAYS_INLINE void storeB(Scalar *B_arr, int64_t LDB,
                                         PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                         int64_t rem_ = 0) {
    aux_storeB<endN, endN, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
  }

  template <int64_t unrollN, bool toTemp, bool remM, int64_t remN_ = 0>
  static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                             PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                             int64_t remM_ = 0) {
    EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB<unrollN, 0, remM, 0>(&B_arr[0], LDB, ymm, remM_); }
    else {
      aux_loadBBlock<unrollN, unrollN, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
  }

  template <int64_t unrollN, bool toTemp, bool remM, int64_t remK_>
  static EIGEN_ALWAYS_INLINE void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                              PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                              int64_t remM_ = 0) {
    aux_storeBBlock<unrollN, unrollN, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }

  template <int64_t packetIndexOffset>
  static EIGEN_ALWAYS_INLINE void transposeLxL(PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm) {
    // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
    // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
    PacketBlock<vecHalf, EIGEN_AVX_MAX_NUM_ROW> r;
    r.packet[0] = ymm.packet[packetIndexOffset + 0];
    r.packet[1] = ymm.packet[packetIndexOffset + 1];
    r.packet[2] = ymm.packet[packetIndexOffset + 2];
    r.packet[3] = ymm.packet[packetIndexOffset + 3];
    r.packet[4] = ymm.packet[packetIndexOffset + 4];
    r.packet[5] = ymm.packet[packetIndexOffset + 5];
    r.packet[6] = ymm.packet[packetIndexOffset + 6];
    r.packet[7] = ymm.packet[packetIndexOffset + 7];
    ptranspose(r);
    ymm.packet[packetIndexOffset + 0] = r.packet[0];
    ymm.packet[packetIndexOffset + 1] = r.packet[1];
    ymm.packet[packetIndexOffset + 2] = r.packet[2];
    ymm.packet[packetIndexOffset + 3] = r.packet[3];
    ymm.packet[packetIndexOffset + 4] = r.packet[4];
    ymm.packet[packetIndexOffset + 5] = r.packet[5];
    ymm.packet[packetIndexOffset + 6] = r.packet[6];
    ymm.packet[packetIndexOffset + 7] = r.packet[7];
  }

  template <int64_t unrollN, bool toTemp, bool remM>
  static EIGEN_ALWAYS_INLINE void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                                PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                                int64_t remM_ = 0) {
    constexpr int64_t U3 = PacketSize * 3;
    constexpr int64_t U2 = PacketSize * 2;
    constexpr int64_t U1 = PacketSize * 1;
    EIGEN_IF_CONSTEXPR(unrollN == U3) {
      // load LxU3 B col major, transpose LxU3 row major
      constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U3);
      transB::template loadBBlock<maxUBlock, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);

      EIGEN_IF_CONSTEXPR(maxUBlock < U3) {
        transB::template loadBBlock<maxUBlock, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
                                                             ymm, remM_);
        transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
        transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
        transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
        transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
                                                                 ymm, remM_);
      }
    }
    else EIGEN_IF_CONSTEXPR(unrollN == U2) {
      // load LxU2 B col major, transpose LxU2 row major
      constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U2);
      transB::template loadBBlock<maxUBlock, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      EIGEN_IF_CONSTEXPR(maxUBlock < U2) transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);

      EIGEN_IF_CONSTEXPR(maxUBlock < U2) {
        transB::template loadBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB,
                                                                         &B_temp[maxUBlock], LDB_, ymm, remM_);
        transB::template transposeLxL<0>(ymm);
        transB::template storeBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB,
                                                                             &B_temp[maxUBlock], LDB_, ymm, remM_);
      }
    }
    else EIGEN_IF_CONSTEXPR(unrollN == U1) {
      // load LxU1 B col major, transpose LxU1 row major
      transB::template loadBBlock<U1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); }
      transB::template storeBBlock<U1, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
    else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) {
      // load Lx8 B col major, transpose Lx8 row major
      transB::template loadBBlock<8, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      transB::template storeBBlock<8, toTemp, remM, 8>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
    else EIGEN_IF_CONSTEXPR(unrollN == 4 && U1 > 4) {
      // load Lx4 B col major, transpose Lx4 row major
      transB::template loadBBlock<4, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      transB::template storeBBlock<4, toTemp, remM, 4>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
    else EIGEN_IF_CONSTEXPR(unrollN == 2) {
      // load Lx2 B col major, transpose Lx2 row major
      transB::template loadBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      transB::template storeBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
    else EIGEN_IF_CONSTEXPR(unrollN == 1) {
      // load Lx1 B col major, transpose Lx1 row major
      transB::template loadBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      transB::template storeBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
  }
};
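
// Illustrative note (not part of the original source): transB<Scalar> packs panels of
// B into the transposed layout the solve kernels expect. transB_kernel dispatches on
// the compile-time panel width unrollN; for unrollN == U1 (one full packet of
// columns) the sequence is essentially
//
//   transB::template loadBBlock<U1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
//   transB::template transposeLxL<0>(ymm);
//   transB::template storeBBlock<U1, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
//
// with toTemp choosing whether the transposed panel is written to the temporary
// buffer B_temp (packing) or back into B_arr (unpacking).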

template <typename Scalar>
class trsm {
 public:
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;

  /***********************************
   * Auxiliary Functions for:
   * - loadRHS
   * - storeRHS
   * - divRHSByDiag
   * - updateRHS
   * - triSolveMicroKernel
   ************************************/

  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadRHS(
      Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    constexpr int64_t counterReverse = endM * endK - counter;
    constexpr int64_t startM = counterReverse / (endK);
    constexpr int64_t startK = counterReverse % endK;

    constexpr int64_t packetIndex = startM * endK + startK;
    constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
    const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB;
    EIGEN_IF_CONSTEXPR(krem) {
      RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex], remMask<PacketSize>(rem));
    }
    else {
      RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex]);
    }
    aux_loadRHS<isFWDSolve, endM, endK, counter - 1, krem>(B_arr, LDB, RHSInPacket, rem);
  }

  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadRHS(
      Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
    EIGEN_UNUSED_VARIABLE(rem);
  }

  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeRHS(
      Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    constexpr int64_t counterReverse = endM * endK - counter;
    constexpr int64_t startM = counterReverse / (endK);
    constexpr int64_t startK = counterReverse % endK;

    constexpr int64_t packetIndex = startM * endK + startK;
    constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
    const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB;
    EIGEN_IF_CONSTEXPR(krem) {
      pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex], remMask<PacketSize>(rem));
    }
    else {
      pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex]);
    }
    aux_storeRHS<isFWDSolve, endM, endK, counter - 1, krem>(B_arr, LDB, RHSInPacket, rem);
  }

  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeRHS(
      Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
    EIGEN_UNUSED_VARIABLE(rem);
  }

  template <int64_t currM, int64_t endK, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> aux_divRHSByDiag(
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    constexpr int64_t counterReverse = endK - counter;
    constexpr int64_t startK = counterReverse;

    constexpr int64_t packetIndex = currM * endK + startK;
    RHSInPacket.packet[packetIndex] = pmul(AInPacket.packet[currM], RHSInPacket.packet[packetIndex]);
    aux_divRHSByDiag<currM, endK, counter - 1>(RHSInPacket, AInPacket);
  }

  template <int64_t currM, int64_t endK, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && currM >= 0)> aux_divRHSByDiag(
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
    EIGEN_UNUSED_VARIABLE(AInPacket);
  }

  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
            int64_t counter, int64_t currentM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateRHS(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    constexpr int64_t counterReverse = (endM - initM) * endK - counter;
    constexpr int64_t startM = initM + counterReverse / (endK);
    constexpr int64_t startK = counterReverse % endK;

    // For each row of A, first update all corresponding RHS.
    constexpr int64_t packetIndex = startM * endK + startK;
    EIGEN_IF_CONSTEXPR(currentM > 0) {
      RHSInPacket.packet[packetIndex] =
          pnmadd(AInPacket.packet[startM], RHSInPacket.packet[(currentM - 1) * endK + startK],
                 RHSInPacket.packet[packetIndex]);
    }

    EIGEN_IF_CONSTEXPR(startK == endK - 1) {
      // Once all RHS for the previous row of A are updated, we broadcast the next element in the column A_{i, currentM}.
      EIGEN_IF_CONSTEXPR(startM == currentM && !isUnitDiag) {
        // If the diagonal is not unit, we broadcast the reciprocal of the diagonal into AInPacket.packet[currentM].
        // This will be used in divRHSByDiag.
        EIGEN_IF_CONSTEXPR(isFWDSolve)
        AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(currentM, currentM, LDA)]);
        else AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(-currentM, -currentM, LDA)]);
      }
      else {
        // Broadcast the next off-diagonal element of A.
        EIGEN_IF_CONSTEXPR(isFWDSolve)
        AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(startM, currentM, LDA)]);
        else AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(-startM, -currentM, LDA)]);
      }
    }

    aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, initM, endM, endK, counter - 1, currentM>(
        A_arr, LDA, RHSInPacket, AInPacket);
  }

  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
            int64_t counter, int64_t currentM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateRHS(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    EIGEN_UNUSED_VARIABLE(A_arr);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
    EIGEN_UNUSED_VARIABLE(AInPacket);
  }

  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_triSolveMicroKernel(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    constexpr int64_t counterReverse = endM - counter;
    constexpr int64_t startM = counterReverse;

    constexpr int64_t currentM = startM;
    // Divides the right-hand side in row startM by the diagonal value of A
    // broadcast to AInPacket.packet[startM-1] in the previous iteration.
    //
    // Without "if constexpr" the compiler instantiates the case <-1, numK>;
    // this is handled with enable_if to prevent out-of-bounds warnings
    // from the compiler.
    EIGEN_IF_CONSTEXPR(!isUnitDiag && startM > 0)
    trsm::template divRHSByDiag<startM - 1, numK>(RHSInPacket, AInPacket);

    // After division, the RHS corresponding to subsequent rows of A can be partially updated.
    // We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed)
    // to be used in the next iteration.
    trsm::template updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, numK, currentM>(A_arr, LDA, RHSInPacket,
                                                                                                AInPacket);

    // Handle division for the RHS corresponding to the final row of A.
    EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM - 1)
    trsm::template divRHSByDiag<startM, numK>(RHSInPacket, AInPacket);

    aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, counter - 1, numK>(A_arr, LDA, RHSInPacket,
                                                                                          AInPacket);
  }

  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_triSolveMicroKernel(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    EIGEN_UNUSED_VARIABLE(A_arr);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
    EIGEN_UNUSED_VARIABLE(AInPacket);
  }

  /********************************************************
   * Wrappers for aux_XXXX to hide counter parameter
   ********************************************************/

  template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
  static EIGEN_ALWAYS_INLINE void loadRHS(Scalar *B_arr, int64_t LDB,
                                          PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    aux_loadRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
  }

  template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
  static EIGEN_ALWAYS_INLINE void storeRHS(Scalar *B_arr, int64_t LDB,
                                           PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    aux_storeRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
  }

  template <int64_t currM, int64_t endK>
  static EIGEN_ALWAYS_INLINE void divRHSByDiag(PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                               PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    aux_divRHSByDiag<currM, endK, endK>(RHSInPacket, AInPacket);
  }

  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t startM, int64_t endM, int64_t endK,
            int64_t currentM>
  static EIGEN_ALWAYS_INLINE void updateRHS(Scalar *A_arr, int64_t LDA,
                                            PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                            PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, endK, (endM - startM) * endK, currentM>(
        A_arr, LDA, RHSInPacket, AInPacket);
  }

  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t numK>
  static EIGEN_ALWAYS_INLINE void triSolveMicroKernel(Scalar *A_arr, int64_t LDA,
                                                      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                                      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    static_assert(numK >= 1 && numK <= 3, "numK out of range");
    aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, endM, numK>(A_arr, LDA, RHSInPacket, AInPacket);
  }
};
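
// Illustrative note (not part of the original source): trsm<Scalar> solves a small
// triangular system whose right-hand side is held entirely in registers. Per row
// startM the micro-kernel is, in effect,
//
//   X(startM, :)  = B(startM, :) * (1 / A(startM, startM))          // divRHSByDiag (skipped for a unit diagonal)
//   B(i, :)      -= A(i, startM) * X(startM, :)   for i > startM    // updateRHS, interleaved with broadcasting
//                                                                   // the next diagonal reciprocal into AInPacket
//
// so a full forward-solve call looks roughly like
//
//   trsm<double>::loadRHS<true, endM, numK>(B, LDB, RHSInPacket);
//   trsm<double>::triSolveMicroKernel<isARowMajor, true, false, endM, numK>(A, LDA, RHSInPacket, AInPacket);
//   trsm<double>::storeRHS<true, endM, numK>(B, LDB, RHSInPacket);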

template <typename Scalar, bool isAdd>
class gemm {
 public:
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;

  /***********************************
   * Auxiliary Functions for:
   * - setzero
   * - updateC
   * - storeC
   * - startLoadB
   * - startBCastA
   * - loadB
   * - microKernel
   ************************************/

  template <int64_t endM, int64_t endN, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_setzero(
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    constexpr int64_t counterReverse = endM * endN - counter;
    constexpr int64_t startM = counterReverse / (endN);
    constexpr int64_t startN = counterReverse % endN;

    zmm.packet[startN * endM + startM] = pzero(zmm.packet[startN * endM + startM]);
    aux_setzero<endM, endN, counter - 1>(zmm);
  }

  template <int64_t endM, int64_t endN, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_setzero(
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    EIGEN_UNUSED_VARIABLE(zmm);
  }

  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endM * endN - counter;
    constexpr int64_t startM = counterReverse / (endN);
    constexpr int64_t startN = counterReverse % endN;

    EIGEN_IF_CONSTEXPR(rem)
    zmm.packet[startN * endM + startM] =
        padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize], remMask<PacketSize>(rem_)),
             zmm.packet[startN * endM + startM], remMask<PacketSize>(rem_));
    else zmm.packet[startN * endM + startM] =
        padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]);
    aux_updateC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
  }

  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(C_arr);
    EIGEN_UNUSED_VARIABLE(LDC);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }

  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endM * endN - counter;
    constexpr int64_t startM = counterReverse / (endN);
    constexpr int64_t startN = counterReverse % endN;

    EIGEN_IF_CONSTEXPR(rem)
    pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM],
                    remMask<PacketSize>(rem_));
    else pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM]);
    aux_storeC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
  }

  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(C_arr);
    EIGEN_UNUSED_VARIABLE(LDC);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }

  template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startLoadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endL - counter;
    constexpr int64_t startL = counterReverse;

    EIGEN_IF_CONSTEXPR(rem)
    zmm.packet[unrollM * unrollN + startL] =
        ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask<PacketSize>(rem_));
    else zmm.packet[unrollM * unrollN + startL] =
        ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize]);

    aux_startLoadB<unrollM, unrollN, endL, counter - 1, rem>(B_t, LDB, zmm, rem_);
  }

  template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startLoadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_t);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }

  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startBCastA(
      Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    constexpr int64_t counterReverse = endB - counter;
    constexpr int64_t startB = counterReverse;

    zmm.packet[unrollM * unrollN + numLoad + startB] = pload1<vec>(&A_t[idA<isARowMajor>(startB, 0, LDA)]);

    aux_startBCastA<isARowMajor, unrollM, unrollN, endB, counter - 1, numLoad>(A_t, LDA, zmm);
  }

  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startBCastA(
      Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    EIGEN_UNUSED_VARIABLE(A_t);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(zmm);
  }

  template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    if ((numLoad / endM + currK < unrollK)) {
      constexpr int64_t counterReverse = endM - counter;
      constexpr int64_t startM = counterReverse;

      EIGEN_IF_CONSTEXPR(rem) {
        zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
            ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize], remMask<PacketSize>(rem_));
      }
      else {
        zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
            ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize]);
      }

      aux_loadB<endM, counter - 1, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
    }
  }

  template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_t);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }

  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_microKernel(
      Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
      int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endM * endN * endK - counter;
    constexpr int startK = counterReverse / (endM * endN);
    constexpr int startN = (counterReverse / (endM)) % endN;
    constexpr int startM = counterReverse % endM;

    EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) {
      gemm::template startLoadB<endM, endN, numLoad, rem>(B_t, LDB, zmm, rem_);
      gemm::template startBCastA<isARowMajor, endM, endN, numBCast, numLoad>(A_t, LDA, zmm);
    }

    {
      // Interleave FMA and Bcast
      EIGEN_IF_CONSTEXPR(isAdd) {
        zmm.packet[startN * endM + startM] =
            pmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
                  zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
      }
      else {
        zmm.packet[startN * endM + startM] =
            pnmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
                   zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
      }
      // Bcast
      EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) {
        zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast] = pload1<vec>(&A_t[idA<isARowMajor>(
            (numBCast + startN + startK * endN) % endN, (numBCast + startN + startK * endN) / endN, LDA)]);
      }
    }

    // We have updated all accumulators; time to load the next set of B's.
    EIGEN_IF_CONSTEXPR((startN == endN - 1) && (startM == endM - 1)) {
      gemm::template loadB<endM, endN, startK, endK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
    }
    aux_microKernel<isARowMajor, endM, endN, endK, counter - 1, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm, rem_);
  }

  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_microKernel(
      Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
      int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_t);
    EIGEN_UNUSED_VARIABLE(A_t);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }

  /********************************************************
   * Wrappers for aux_XXXX to hide counter parameter
   ********************************************************/

  template <int64_t endM, int64_t endN>
  static EIGEN_ALWAYS_INLINE void setzero(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    aux_setzero<endM, endN, endM * endN>(zmm);
  }

  template <int64_t endM, int64_t endN, bool rem = false>
  static EIGEN_ALWAYS_INLINE void updateC(Scalar *C_arr, int64_t LDC,
                                          PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                          int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_updateC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
  }

  template <int64_t endM, int64_t endN, bool rem = false>
  static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
                                         PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                         int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_storeC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
  }

  template <int64_t unrollM, int64_t unrollN, int64_t endL, bool rem>
  static EIGEN_ALWAYS_INLINE void startLoadB(Scalar *B_t, int64_t LDB,
                                             PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                             int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_startLoadB<unrollM, unrollN, endL, endL, rem>(B_t, LDB, zmm, rem_);
  }

  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t numLoad>
  static EIGEN_ALWAYS_INLINE void startBCastA(Scalar *A_t, int64_t LDA,
                                              PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    aux_startBCastA<isARowMajor, unrollM, unrollN, endB, endB, numLoad>(A_t, LDA, zmm);
  }

  template <int64_t endM, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_t, int64_t LDB,
                                        PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                        int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_loadB<endM, endM, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
  }

  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t numLoad, int64_t numBCast,
            bool rem = false>
  static EIGEN_ALWAYS_INLINE void microKernel(Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA,
                                              PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                              int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_microKernel<isARowMajor, endM, endN, endK, endM * endN * endK, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm,
                                                                                               rem_);
  }
};
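
// Illustrative note (not part of the original source): gemm<Scalar, isAdd> implements
// the register-blocked update C (+/-)= A * B used by the blocked triangular solver.
// Accumulators live in zmm.packet[0 .. endM*endN-1] (index startN * endM + startM),
// and the remaining registers are split between numLoad loaded B packets and numBCast
// broadcast A values. A typical call sequence is roughly
//
//   gemm<float, false>::setzero<endM, endN>(zmm);
//   gemm<float, false>::microKernel<isARowMajor, endM, endN, endK, numLoad, numBCast>(B_t, A_t, LDB, LDA, zmm);
//   gemm<float, false>::updateC<endM, endN>(C, LDC, zmm);
//   gemm<float, false>::storeC<endM, endN>(C, LDC, zmm);
//
// so that C accumulates -A*B when isAdd is false (the pnmadd path needed for
// forward substitution) and +A*B when isAdd is true.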
}  // namespace unrolls

#endif  // EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H