Eigen 3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)

TrsmKernel.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2022 Intel Corporation
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H
11#define EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H
12
13// IWYU pragma: private
14#include "../../InternalHeaderCheck.h"
15
16#if !defined(EIGEN_USE_AVX512_TRSM_KERNELS)
17#define EIGEN_USE_AVX512_TRSM_KERNELS 1
18#endif
19
20// TRSM kernels currently unconditionally rely on malloc with AVX512.
21// Disable them if malloc is explicitly disabled at compile-time.
22#ifdef EIGEN_NO_MALLOC
23#undef EIGEN_USE_AVX512_TRSM_KERNELS
24#define EIGEN_USE_AVX512_TRSM_KERNELS 0
25#endif
26
27#if EIGEN_USE_AVX512_TRSM_KERNELS
28#if !defined(EIGEN_USE_AVX512_TRSM_R_KERNELS)
29#define EIGEN_USE_AVX512_TRSM_R_KERNELS 1
30#endif
31#if !defined(EIGEN_USE_AVX512_TRSM_L_KERNELS)
32#define EIGEN_USE_AVX512_TRSM_L_KERNELS 1
33#endif
34#else // EIGEN_USE_AVX512_TRSM_KERNELS == 0
35#define EIGEN_USE_AVX512_TRSM_R_KERNELS 0
36#define EIGEN_USE_AVX512_TRSM_L_KERNELS 0
37#endif
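// Editorial note (not part of the original header): because of the !defined() guards
// above, the AVX512 TRSM kernels can be disabled per translation unit by defining the
// macro to 0 before any Eigen header is included, e.g.
//
//   // hypothetical translation unit
//   #define EIGEN_USE_AVX512_TRSM_KERNELS 0  // fall back to Eigen's generic TRSM path
//   #include <Eigen/Dense>
//
// or on the command line with -DEIGEN_USE_AVX512_TRSM_KERNELS=0. Defining
// EIGEN_NO_MALLOC disables them as well, as handled above.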
38
39// Undefine a stray min macro (e.g. the one from windows.h); it would break the std::min calls below.
40#ifdef min
41#undef min
42#endif
43
44namespace Eigen {
45namespace internal {
46
47#define EIGEN_AVX_MAX_NUM_ACC (int64_t(24))
48#define EIGEN_AVX_MAX_NUM_ROW (int64_t(8)) // Denoted L in code.
49#define EIGEN_AVX_MAX_K_UNROL (int64_t(4))
50#define EIGEN_AVX_B_LOAD_SETS (int64_t(2))
51#define EIGEN_AVX_MAX_A_BCAST (int64_t(2))
52typedef Packet16f vecFullFloat;
53typedef Packet8d vecFullDouble;
54typedef Packet8f vecHalfFloat;
55typedef Packet4d vecHalfDouble;
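// Editorial note (a rough register-budget sketch, not part of the original header):
// AVX-512 provides 32 zmm registers. The GEMM micro-kernel below keeps up to
//   EIGEN_AVX_MAX_NUM_ACC = 3 (B column packets) * EIGEN_AVX_MAX_NUM_ROW (8 A rows) = 24
// accumulators live, which appears to leave the remaining registers for the in-flight
// B loads (EIGEN_AVX_B_LOAD_SETS sets of up to 3 packets) and A broadcasts
// (EIGEN_AVX_MAX_A_BCAST).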
56
57// Compile-time unrolls are implemented here.
58// Note: this depends on macros and typedefs above.
59#include "TrsmUnrolls.inc"
60
61#if (EIGEN_USE_AVX512_TRSM_KERNELS) && (EIGEN_COMP_CLANG != 0)
78#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
79#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 1
80#endif
81
82#if EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS
83
84#if EIGEN_USE_AVX512_TRSM_R_KERNELS
85#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
86#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 1
87#endif // !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
88#endif
89
90#if EIGEN_USE_AVX512_TRSM_L_KERNELS
91#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS)
92#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 1
93#endif
94#endif // EIGEN_USE_AVX512_TRSM_L_KERNELS
95
96#else // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS == 0
97#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
98#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
99#endif // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS
100
101template <typename Scalar>
102int64_t avx512_trsm_cutoff(int64_t L2Size, int64_t N, double L2Cap) {
103 const int64_t U3 = 3 * packet_traits<Scalar>::size;
104 const int64_t MaxNb = 5 * U3;
105 int64_t Nb = std::min(MaxNb, N);
106 double cutoff_d =
107 (((L2Size * L2Cap) / (sizeof(Scalar))) - (EIGEN_AVX_MAX_NUM_ROW)*Nb) / ((EIGEN_AVX_MAX_NUM_ROW) + Nb);
108 int64_t cutoff_l = static_cast<int64_t>(cutoff_d);
109 return (cutoff_l / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
110}
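// Editorial worked example (values assumed, not part of the original header): for float,
// packet size 16 gives U3 = 48 and MaxNb = 240. With L2Size = 262144 (256 KiB),
// L2Cap = 0.5 and a large N:
//   Nb       = min(240, N)                              = 240
//   cutoff_d = (262144 * 0.5 / 4 - 8 * 240) / (8 + 240) ~= 124.4
//   result   = (124 / 8) * 8                            = 120
// i.e. the cutoff appears to be sized so that, roughly, an M x 8 panel of the triangular
// matrix plus an (M + 8) x Nb panel of the RHS fit in the chosen fraction of L2, rounded
// down to a multiple of EIGEN_AVX_MAX_NUM_ROW.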
111#else // !(EIGEN_USE_AVX512_TRSM_KERNELS) || (EIGEN_COMP_CLANG == 0)
112#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 0
113#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
114#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
115#endif
116
120template <typename Scalar, typename vec, int64_t unrollM, int64_t unrollN, bool remM, bool remN>
121EIGEN_ALWAYS_INLINE void transStoreC(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, Scalar *C_arr,
122 int64_t LDC, int64_t remM_ = 0, int64_t remN_ = 0) {
123 EIGEN_UNUSED_VARIABLE(remN_);
124 EIGEN_UNUSED_VARIABLE(remM_);
125 using urolls = unrolls::trans<Scalar>;
126
127 constexpr int64_t U3 = urolls::PacketSize * 3;
128 constexpr int64_t U2 = urolls::PacketSize * 2;
129 constexpr int64_t U1 = urolls::PacketSize * 1;
130
131 static_assert(unrollN == U1 || unrollN == U2 || unrollN == U3, "unrollN should be a multiple of PacketSize");
132 static_assert(unrollM == EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW");
133
134 urolls::template transpose<unrollN, 0>(zmm);
135 EIGEN_IF_CONSTEXPR(unrollN > U2) urolls::template transpose<unrollN, 2>(zmm);
136 EIGEN_IF_CONSTEXPR(unrollN > U1) urolls::template transpose<unrollN, 1>(zmm);
137
138 static_assert((remN && unrollN == U1) || !remN, "When handling N remainder set unrollN=U1");
139 EIGEN_IF_CONSTEXPR(!remN) {
140 urolls::template storeC<std::min(unrollN, U1), unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
141 EIGEN_IF_CONSTEXPR(unrollN > U1) {
142 constexpr int64_t unrollN_ = std::min(unrollN - U1, U1);
143 urolls::template storeC<unrollN_, unrollN, 1, remM>(C_arr + U1 * LDC, LDC, zmm, remM_);
144 }
145 EIGEN_IF_CONSTEXPR(unrollN > U2) {
146 constexpr int64_t unrollN_ = std::min(unrollN - U2, U1);
147 urolls::template storeC<unrollN_, unrollN, 2, remM>(C_arr + U2 * LDC, LDC, zmm, remM_);
148 }
149 }
150 else {
151 EIGEN_IF_CONSTEXPR((std::is_same<Scalar, float>::value)) {
152 // Note: without "if constexpr" this section of code will also be
153 // parsed by the compiler, so each of the storeC calls will still be instantiated.
154 // We use enable_if in aux_storeC to make it an empty function for
155 // these cases.
156 if (remN_ == 15)
157 urolls::template storeC<15, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
158 else if (remN_ == 14)
159 urolls::template storeC<14, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
160 else if (remN_ == 13)
161 urolls::template storeC<13, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
162 else if (remN_ == 12)
163 urolls::template storeC<12, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
164 else if (remN_ == 11)
165 urolls::template storeC<11, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
166 else if (remN_ == 10)
167 urolls::template storeC<10, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
168 else if (remN_ == 9)
169 urolls::template storeC<9, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
170 else if (remN_ == 8)
171 urolls::template storeC<8, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
172 else if (remN_ == 7)
173 urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
174 else if (remN_ == 6)
175 urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
176 else if (remN_ == 5)
177 urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
178 else if (remN_ == 4)
179 urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
180 else if (remN_ == 3)
181 urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
182 else if (remN_ == 2)
183 urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
184 else if (remN_ == 1)
185 urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
186 }
187 else {
188 if (remN_ == 7)
189 urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
190 else if (remN_ == 6)
191 urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
192 else if (remN_ == 5)
193 urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
194 else if (remN_ == 4)
195 urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
196 else if (remN_ == 3)
197 urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
198 else if (remN_ == 2)
199 urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
200 else if (remN_ == 1)
201 urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
202 }
203 }
204}
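// Editorial usage sketch (mirrors the calls made from gemmKernel below): the GEMM
// accumulators hold an up-to-8-row, up-to-3-packet-wide row-major tile of C in zmm
// registers; when C is column-major the tile is transposed in registers and stored
// U1 columns at a time. For the float case (U3 == 48):
//
//   // full 8 x U3 tile:
//   transStoreC<float, vecFullFloat, EIGEN_AVX_MAX_NUM_ROW, U3, false, false>(
//       zmm, &C_arr[i + j * LDC], LDC);
//   // ragged corner with 4 rows and N - j (< U1) columns remaining:
//   transStoreC<float, vecFullFloat, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(
//       zmm, &C_arr[i + j * LDC], LDC, /*remM_=*/4, /*remN_=*/N - j);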
205
220template <typename Scalar, bool isARowMajor, bool isCRowMajor, bool isAdd, bool handleKRem>
221void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr, int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB,
222 int64_t LDC) {
223 using urolls = unrolls::gemm<Scalar, isAdd>;
224 constexpr int64_t U3 = urolls::PacketSize * 3;
225 constexpr int64_t U2 = urolls::PacketSize * 2;
226 constexpr int64_t U1 = urolls::PacketSize * 1;
227 using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
228 int64_t N_ = (N / U3) * U3;
229 int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
230 int64_t K_ = (K / EIGEN_AVX_MAX_K_UNROL) * EIGEN_AVX_MAX_K_UNROL;
231 int64_t j = 0;
232 for (; j < N_; j += U3) {
233 constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 3;
234 int64_t i = 0;
235 for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
236 Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
237 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
238 urolls::template setzero<3, EIGEN_AVX_MAX_NUM_ROW>(zmm);
239 for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
240 urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
241 EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
242 B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
243 EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
244 else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
245 }
246 EIGEN_IF_CONSTEXPR(handleKRem) {
247 for (int64_t k = K_; k < K; k++) {
248 urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 3,
249 EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
250 B_t += LDB;
251 EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
252 else A_t += LDA;
253 }
254 }
255 EIGEN_IF_CONSTEXPR(isCRowMajor) {
256 urolls::template updateC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
257 urolls::template storeC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
258 }
259 else {
260 transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, false, false>(zmm, &C_arr[i + j * LDC], LDC);
261 }
262 }
263 if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW == 8; it should be removed otherwise.
264 Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
265 Scalar *B_t = &B_arr[0 * LDB + j];
266 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
267 urolls::template setzero<3, 4>(zmm);
268 for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
269 urolls::template microKernel<isARowMajor, 3, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
270 EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
271 B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
272 EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
273 else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
274 }
275 EIGEN_IF_CONSTEXPR(handleKRem) {
276 for (int64_t k = K_; k < K; k++) {
277 urolls::template microKernel<isARowMajor, 3, 4, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
278 B_t, A_t, LDB, LDA, zmm);
279 B_t += LDB;
280 EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
281 else A_t += LDA;
282 }
283 }
284 EIGEN_IF_CONSTEXPR(isCRowMajor) {
285 urolls::template updateC<3, 4>(&C_arr[i * LDC + j], LDC, zmm);
286 urolls::template storeC<3, 4>(&C_arr[i * LDC + j], LDC, zmm);
287 }
288 else {
289 transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
290 }
291 i += 4;
292 }
293 if (M - i >= 2) {
294 Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
295 Scalar *B_t = &B_arr[0 * LDB + j];
296 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
297 urolls::template setzero<3, 2>(zmm);
298 for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
299 urolls::template microKernel<isARowMajor, 3, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
300 EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
301 B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
302 EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
303 else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
304 }
305 EIGEN_IF_CONSTEXPR(handleKRem) {
306 for (int64_t k = K_; k < K; k++) {
307 urolls::template microKernel<isARowMajor, 3, 2, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
308 B_t, A_t, LDB, LDA, zmm);
309 B_t += LDB;
310 EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
311 else A_t += LDA;
312 }
313 }
314 EIGEN_IF_CONSTEXPR(isCRowMajor) {
315 urolls::template updateC<3, 2>(&C_arr[i * LDC + j], LDC, zmm);
316 urolls::template storeC<3, 2>(&C_arr[i * LDC + j], LDC, zmm);
317 }
318 else {
319 transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
320 }
321 i += 2;
322 }
323 if (M - i > 0) {
324 Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
325 Scalar *B_t = &B_arr[0 * LDB + j];
326 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
327 urolls::template setzero<3, 1>(zmm);
328 {
329 for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
330 urolls::template microKernel<isARowMajor, 3, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3, 1>(
331 B_t, A_t, LDB, LDA, zmm);
332 B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
333 EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
334 else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
335 }
336 EIGEN_IF_CONSTEXPR(handleKRem) {
337 for (int64_t k = K_; k < K; k++) {
338 urolls::template microKernel<isARowMajor, 3, 1, 1, EIGEN_AVX_B_LOAD_SETS * 3, 1>(B_t, A_t, LDB, LDA, zmm);
339 B_t += LDB;
340 EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
341 else A_t += LDA;
342 }
343 }
344 EIGEN_IF_CONSTEXPR(isCRowMajor) {
345 urolls::template updateC<3, 1>(&C_arr[i * LDC + j], LDC, zmm);
346 urolls::template storeC<3, 1>(&C_arr[i * LDC + j], LDC, zmm);
347 }
348 else {
349 transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
350 }
351 }
352 }
353 }
354 if (N - j >= U2) {
355 constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 2;
356 int64_t i = 0;
357 for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
358 Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
359 EIGEN_IF_CONSTEXPR(isCRowMajor) B_t = &B_arr[0 * LDB + j];
360 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
361 urolls::template setzero<2, EIGEN_AVX_MAX_NUM_ROW>(zmm);
362 for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
363 urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
364 EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
365 B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
366 EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
367 else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
368 }
369 EIGEN_IF_CONSTEXPR(handleKRem) {
370 for (int64_t k = K_; k < K; k++) {
371 urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
372 EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
373 B_t += LDB;
374 EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
375 else A_t += LDA;
376 }
377 }
378 EIGEN_IF_CONSTEXPR(isCRowMajor) {
379 urolls::template updateC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
380 urolls::template storeC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
381 }
382 else {
383 transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, false, false>(zmm, &C_arr[i + j * LDC], LDC);
384 }
385 }
386 if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW == 8; it should be removed otherwise.
387 Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
388 Scalar *B_t = &B_arr[0 * LDB + j];
389 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
390 urolls::template setzero<2, 4>(zmm);
391 for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
392 urolls::template microKernel<isARowMajor, 2, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
393 EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
394 B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
395 EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
396 else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
397 }
398 EIGEN_IF_CONSTEXPR(handleKRem) {
399 for (int64_t k = K_; k < K; k++) {
400 urolls::template microKernel<isARowMajor, 2, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
401 LDA, zmm);
402 B_t += LDB;
403 EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
404 else A_t += LDA;
405 }
406 }
407 EIGEN_IF_CONSTEXPR(isCRowMajor) {
408 urolls::template updateC<2, 4>(&C_arr[i * LDC + j], LDC, zmm);
409 urolls::template storeC<2, 4>(&C_arr[i * LDC + j], LDC, zmm);
410 }
411 else {
412 transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
413 }
414 i += 4;
415 }
416 if (M - i >= 2) {
417 Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
418 Scalar *B_t = &B_arr[0 * LDB + j];
419 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
420 urolls::template setzero<2, 2>(zmm);
421 for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
422 urolls::template microKernel<isARowMajor, 2, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
423 EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
424 B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
425 EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
426 else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
427 }
428 EIGEN_IF_CONSTEXPR(handleKRem) {
429 for (int64_t k = K_; k < K; k++) {
430 urolls::template microKernel<isARowMajor, 2, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
431 LDA, zmm);
432 B_t += LDB;
433 EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
434 else A_t += LDA;
435 }
436 }
437 EIGEN_IF_CONSTEXPR(isCRowMajor) {
438 urolls::template updateC<2, 2>(&C_arr[i * LDC + j], LDC, zmm);
439 urolls::template storeC<2, 2>(&C_arr[i * LDC + j], LDC, zmm);
440 }
441 else {
442 transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
443 }
444 i += 2;
445 }
446 if (M - i > 0) {
447 Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
448 Scalar *B_t = &B_arr[0 * LDB + j];
449 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
450 urolls::template setzero<2, 1>(zmm);
451 for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
452 urolls::template microKernel<isARowMajor, 2, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB,
453 LDA, zmm);
454 B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
455 EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
456 else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
457 }
458 EIGEN_IF_CONSTEXPR(handleKRem) {
459 for (int64_t k = K_; k < K; k++) {
460 urolls::template microKernel<isARowMajor, 2, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB, LDA, zmm);
461 B_t += LDB;
462 EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
463 else A_t += LDA;
464 }
465 }
466 EIGEN_IF_CONSTEXPR(isCRowMajor) {
467 urolls::template updateC<2, 1>(&C_arr[i * LDC + j], LDC, zmm);
468 urolls::template storeC<2, 1>(&C_arr[i * LDC + j], LDC, zmm);
469 }
470 else {
471 transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
472 }
473 }
474 j += U2;
475 }
476 if (N - j >= U1) {
477 constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 1;
478 int64_t i = 0;
479 for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
480 Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
481 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
482 urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
483 for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
484 urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
485 EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
486 B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
487 EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
488 else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
489 }
490 EIGEN_IF_CONSTEXPR(handleKRem) {
491 for (int64_t k = K_; k < K; k++) {
492 urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 1,
493 EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
494 B_t += LDB;
495 EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
496 else A_t += LDA;
497 }
498 }
499 EIGEN_IF_CONSTEXPR(isCRowMajor) {
500 urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
501 urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
502 }
503 else {
504 transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, false, false>(zmm, &C_arr[i + j * LDC], LDC);
505 }
506 }
507 if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW == 8; it should be removed otherwise.
508 Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
509 Scalar *B_t = &B_arr[0 * LDB + j];
510 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
511 urolls::template setzero<1, 4>(zmm);
512 for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
513 urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
514 EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
515 B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
516 EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
517 else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
518 }
519 EIGEN_IF_CONSTEXPR(handleKRem) {
520 for (int64_t k = K_; k < K; k++) {
521 urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
522 LDA, zmm);
523 B_t += LDB;
524 EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
525 else A_t += LDA;
526 }
527 }
528 EIGEN_IF_CONSTEXPR(isCRowMajor) {
529 urolls::template updateC<1, 4>(&C_arr[i * LDC + j], LDC, zmm);
530 urolls::template storeC<1, 4>(&C_arr[i * LDC + j], LDC, zmm);
531 }
532 else {
533 transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
534 }
535 i += 4;
536 }
537 if (M - i >= 2) {
538 Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
539 Scalar *B_t = &B_arr[0 * LDB + j];
540 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
541 urolls::template setzero<1, 2>(zmm);
542 for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
543 urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
544 EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
545 B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
546 EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
547 else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
548 }
549 EIGEN_IF_CONSTEXPR(handleKRem) {
550 for (int64_t k = K_; k < K; k++) {
551 urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
552 LDA, zmm);
553 B_t += LDB;
554 EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
555 else A_t += LDA;
556 }
557 }
558 EIGEN_IF_CONSTEXPR(isCRowMajor) {
559 urolls::template updateC<1, 2>(&C_arr[i * LDC + j], LDC, zmm);
560 urolls::template storeC<1, 2>(&C_arr[i * LDC + j], LDC, zmm);
561 }
562 else {
563 transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
564 }
565 i += 2;
566 }
567 if (M - i > 0) {
568 Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
569 Scalar *B_t = &B_arr[0 * LDB + j];
570 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
571 urolls::template setzero<1, 1>(zmm);
572 {
573 for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
574 urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB,
575 LDA, zmm);
576 B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
577 EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
578 else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
579 }
580 EIGEN_IF_CONSTEXPR(handleKRem) {
581 for (int64_t k = K_; k < K; k++) {
582 urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_B_LOAD_SETS * 1, 1>(B_t, A_t, LDB, LDA, zmm);
583 B_t += LDB;
584 EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
585 else A_t += LDA;
586 }
587 }
588 EIGEN_IF_CONSTEXPR(isCRowMajor) {
589 urolls::template updateC<1, 1>(&C_arr[i * LDC + j], LDC, zmm);
590 urolls::template storeC<1, 1>(&C_arr[i * LDC + j], LDC, zmm);
591 }
592 else {
593 transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
594 }
595 }
596 }
597 j += U1;
598 }
599 if (N - j > 0) {
600 constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 1;
601 int64_t i = 0;
602 for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
603 Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
604 Scalar *B_t = &B_arr[0 * LDB + j];
605 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
606 urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
607 for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
608 urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
609 EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
610 B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
611 EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
612 else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
613 }
614 EIGEN_IF_CONSTEXPR(handleKRem) {
615 for (int64_t k = K_; k < K; k++) {
616 urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
617 EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
618 B_t += LDB;
619 EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
620 else A_t += LDA;
621 }
622 }
623 EIGEN_IF_CONSTEXPR(isCRowMajor) {
624 urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
625 urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
626 }
627 else {
628 transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, false, true>(zmm, &C_arr[i + j * LDC], LDC, 0, N - j);
629 }
630 }
631 if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW == 8; it should be removed otherwise.
632 Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
633 Scalar *B_t = &B_arr[0 * LDB + j];
634 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
635 urolls::template setzero<1, 4>(zmm);
636 for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
637 urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
638 EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
639 B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
640 EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
641 else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
642 }
643 EIGEN_IF_CONSTEXPR(handleKRem) {
644 for (int64_t k = K_; k < K; k++) {
645 urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
646 B_t, A_t, LDB, LDA, zmm, N - j);
647 B_t += LDB;
648 EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
649 else A_t += LDA;
650 }
651 }
652 EIGEN_IF_CONSTEXPR(isCRowMajor) {
653 urolls::template updateC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
654 urolls::template storeC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
655 }
656 else {
657 transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 4, N - j);
658 }
659 i += 4;
660 }
661 if (M - i >= 2) {
662 Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
663 Scalar *B_t = &B_arr[0 * LDB + j];
664 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
665 urolls::template setzero<1, 2>(zmm);
666 for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
667 urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
668 EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
669 B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
670 EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
671 else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
672 }
673 EIGEN_IF_CONSTEXPR(handleKRem) {
674 for (int64_t k = K_; k < K; k++) {
675 urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
676 B_t, A_t, LDB, LDA, zmm, N - j);
677 B_t += LDB;
678 EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
679 else A_t += LDA;
680 }
681 }
682 EIGEN_IF_CONSTEXPR(isCRowMajor) {
683 urolls::template updateC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
684 urolls::template storeC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
685 }
686 else {
687 transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 2, N - j);
688 }
689 i += 2;
690 }
691 if (M - i > 0) {
692 Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
693 Scalar *B_t = &B_arr[0 * LDB + j];
694 PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
695 urolls::template setzero<1, 1>(zmm);
696 for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
697 urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1, true>(
698 B_t, A_t, LDB, LDA, zmm, N - j);
699 B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
700 EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
701 else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
702 }
703 EIGEN_IF_CONSTEXPR(handleKRem) {
704 for (int64_t k = K_; k < K; k++) {
705 urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1, true>(B_t, A_t, LDB, LDA, zmm,
706 N - j);
707 B_t += LDB;
708 EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
709 else A_t += LDA;
710 }
711 }
712 EIGEN_IF_CONSTEXPR(isCRowMajor) {
713 urolls::template updateC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
714 urolls::template storeC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
715 }
716 else {
717 transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 1, N - j);
718 }
719 }
720 }
721}
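// Editorial summary of the loop nest above (no new behaviour):
//
//   for j over N in column blocks of U3, then at most one U2 block, one U1 block,
//       and a masked tail of fewer than U1 columns:
//     for i over M in row blocks of 8, then at most one each of 4, 2 and 1 rows:
//       zmm = 0                                                  (setzero<cols, rows>)
//       for k over K in steps of EIGEN_AVX_MAX_K_UNROL,
//           plus a scalar K tail when handleKRem:                (microKernel)
//         zmm += A(i.., k..) * B(k.., j..)
//       merge zmm into C: updateC/storeC when C is row-major,
//                         transStoreC (register transpose) otherwise
//
// Within this file the kernel is invoked only from triSolve below, for the trailing
// GEMM updates of the blocked triangular solve, which is why A and B are consumed in
// place without packing.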
722
731template <typename Scalar, typename vec, int64_t unrollM, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
732EIGEN_ALWAYS_INLINE void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_t LDB) {
733 static_assert(unrollM <= EIGEN_AVX_MAX_NUM_ROW, "unrollM should be less than or equal to EIGEN_AVX_MAX_NUM_ROW");
734 using urolls = unrolls::trsm<Scalar>;
735 constexpr int64_t U3 = urolls::PacketSize * 3;
736 constexpr int64_t U2 = urolls::PacketSize * 2;
737 constexpr int64_t U1 = urolls::PacketSize * 1;
738
739 PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> RHSInPacket;
740 PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> AInPacket;
741
742 int64_t k = 0;
743 while (K - k >= U3) {
744 urolls::template loadRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
745 urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 3>(A_arr, LDA, RHSInPacket,
746 AInPacket);
747 urolls::template storeRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
748 k += U3;
749 }
750 if (K - k >= U2) {
751 urolls::template loadRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
752 urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 2>(A_arr, LDA, RHSInPacket,
753 AInPacket);
754 urolls::template storeRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
755 k += U2;
756 }
757 if (K - k >= U1) {
758 urolls::template loadRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
759 urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
760 AInPacket);
761 urolls::template storeRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
762 k += U1;
763 }
764 if (K - k > 0) {
765 // Handle remaining number of RHS
766 urolls::template loadRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
767 urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
768 AInPacket);
769 urolls::template storeRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
770 }
771}
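// Editorial worked example: for float, U1 = 16, U2 = 32 and U3 = 48 RHS columns per pass.
// With K = 100 the loop above performs two U3 passes (k = 96); the remaining 4 columns
// fall through to the final masked loadRHS/storeRHS path, which reuses the 1-packet
// micro-kernel with K - k = 4.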
772
781template <typename Scalar, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
782void triSolveKernelLxK(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t K, int64_t LDA, int64_t LDB) {
783 // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
784 // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
785 using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
786 if (M == 8)
787 triSolveKernel<Scalar, vec, 8, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
788 else if (M == 7)
789 triSolveKernel<Scalar, vec, 7, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
790 else if (M == 6)
791 triSolveKernel<Scalar, vec, 6, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
792 else if (M == 5)
793 triSolveKernel<Scalar, vec, 5, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
794 else if (M == 4)
795 triSolveKernel<Scalar, vec, 4, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
796 else if (M == 3)
797 triSolveKernel<Scalar, vec, 3, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
798 else if (M == 2)
799 triSolveKernel<Scalar, vec, 2, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
800 else if (M == 1)
801 triSolveKernel<Scalar, vec, 1, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
802 return;
803}
804
812template <typename Scalar, bool toTemp = true, bool remM = false>
813EIGEN_ALWAYS_INLINE void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K, Scalar *B_temp, int64_t LDB_,
814 int64_t remM_ = 0) {
815 EIGEN_UNUSED_VARIABLE(remM_);
816 using urolls = unrolls::transB<Scalar>;
817 using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
818 PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> ymm;
819 constexpr int64_t U3 = urolls::PacketSize * 3;
820 constexpr int64_t U2 = urolls::PacketSize * 2;
821 constexpr int64_t U1 = urolls::PacketSize * 1;
822 int64_t K_ = K / U3 * U3;
823 int64_t k = 0;
824
825 for (; k < K_; k += U3) {
826 urolls::template transB_kernel<U3, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
827 B_temp += U3;
828 }
829 if (K - k >= U2) {
830 urolls::template transB_kernel<U2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
831 B_temp += U2;
832 k += U2;
833 }
834 if (K - k >= U1) {
835 urolls::template transB_kernel<U1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
836 B_temp += U1;
837 k += U1;
838 }
839 EIGEN_IF_CONSTEXPR(U1 > 8) {
840 // Note: without "if constexpr" this section of code will also be
841 // parsed by the compiler so there is an additional check in {load/store}BBlock
842 // to make sure the counter does not go negative.
843 if (K - k >= 8) {
844 urolls::template transB_kernel<8, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
845 B_temp += 8;
846 k += 8;
847 }
848 }
849 EIGEN_IF_CONSTEXPR(U1 > 4) {
850 // Note: without "if constexpr" this section of code will also be
851 // parsed by the compiler so there is an additional check in {load/store}BBlock
852 // to make sure the counter does not go negative.
853 if (K - k >= 4) {
854 urolls::template transB_kernel<4, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
855 B_temp += 4;
856 k += 4;
857 }
858 }
859 if (K - k >= 2) {
860 urolls::template transB_kernel<2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
861 B_temp += 2;
862 k += 2;
863 }
864 if (K - k >= 1) {
865 urolls::template transB_kernel<1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
866 B_temp += 1;
867 k += 1;
868 }
869}
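// Editorial worked example: for float, U1 = 16, U2 = 32, U3 = 48, so a panel with K = 27
// is transposed as one U1 block (k = 16), then (because U1 > 8) one 8-wide block
// (k = 24), then a 2-wide and a 1-wide block (k = 27). For double, U1 = 8, so the
// "U1 > 8" branch is compiled out and the cascade is 24 / 16 / 8 / 4 / 2 / 1.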
870
898template <typename Scalar, bool isARowMajor = true, bool isBRowMajor = true, bool isFWDSolve = true,
899 bool isUnitDiag = false>
900void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t LDA, int64_t LDB) {
901 constexpr int64_t psize = packet_traits<Scalar>::size;
915 constexpr int64_t kB = (3 * psize) * 5; // 5*U3
916 constexpr int64_t numM = 8 * EIGEN_AVX_MAX_NUM_ROW;
917
918 int64_t sizeBTemp = 0;
919 Scalar *B_temp = NULL;
920 EIGEN_IF_CONSTEXPR(!isBRowMajor) {
926 sizeBTemp = (((std::min(kB, numRHS) + psize - 1) / psize + 4) * psize) * numM;
927 }
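// Editorial worked example: for float, psize = 16, kB = 3 * 16 * 5 = 240 and
// numM = 8 * EIGEN_AVX_MAX_NUM_ROW = 64, so with numRHS >= kB:
//   sizeBTemp = (((240 + 15) / 16 + 4) * 16) * 64 = (19 * 16) * 64 = 19456 floats (~76 KiB).
// The extra "+ 4" packets per row appear to provide headroom for rounding the leading
// dimension LDT (computed in the k loop below) up to a cache-line multiple.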
928
929 EIGEN_IF_CONSTEXPR(!isBRowMajor) B_temp = (Scalar *)handmade_aligned_malloc(sizeof(Scalar) * sizeBTemp, 64);
930
931 for (int64_t k = 0; k < numRHS; k += kB) {
932 int64_t bK = numRHS - k > kB ? kB : numRHS - k;
933 int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW, gemmOff = 0;
934
935 // bK rounded up to next multiple of L=EIGEN_AVX_MAX_NUM_ROW. When B_temp is used, we solve for bkL RHS
936 // instead of bK RHS in triSolveKernelLxK.
937 int64_t bkL = ((bK + (EIGEN_AVX_MAX_NUM_ROW - 1)) / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
938 const int64_t numScalarPerCache = 64 / sizeof(Scalar);
939 // Leading dimension of B_temp, will be a multiple of the cache line size.
940 int64_t LDT = ((bkL + (numScalarPerCache - 1)) / numScalarPerCache) * numScalarPerCache;
941 int64_t offsetBTemp = 0;
942 for (int64_t i = 0; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
943 EIGEN_IF_CONSTEXPR(!isBRowMajor) {
944 int64_t indA_i = isFWDSolve ? i : M - 1 - i;
945 int64_t indB_i = isFWDSolve ? i : M - (i + EIGEN_AVX_MAX_NUM_ROW);
946 int64_t offB_1 = isFWDSolve ? offsetBTemp : sizeBTemp - EIGEN_AVX_MAX_NUM_ROW * LDT - offsetBTemp;
947 int64_t offB_2 = isFWDSolve ? offsetBTemp : sizeBTemp - LDT - offsetBTemp;
948 // Copy values from B to B_temp.
949 copyBToRowMajor<Scalar, true, false>(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT);
950 // Triangular solve with a small block of A and long horizontal blocks of B (or B_temp if B col-major)
951 triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
952 &A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)], B_temp + offB_2, EIGEN_AVX_MAX_NUM_ROW, bkL, LDA, LDT);
953 // Copy values from B_temp back to B. B_temp will be reused in gemm call below.
954 copyBToRowMajor<Scalar, false, false>(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT);
955
956 offsetBTemp += EIGEN_AVX_MAX_NUM_ROW * LDT;
957 }
958 else {
959 int64_t ind = isFWDSolve ? i : M - 1 - i;
960 triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
961 &A_arr[idA<isARowMajor>(ind, ind, LDA)], B_arr + k + ind * LDB, EIGEN_AVX_MAX_NUM_ROW, bK, LDA, LDB);
962 }
963 if (i + EIGEN_AVX_MAX_NUM_ROW < M_) {
976 EIGEN_IF_CONSTEXPR(isBRowMajor) {
977 int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
978 int64_t indA_j = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW);
979 int64_t indB_i = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW);
980 int64_t indB_i2 = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
981 gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
982 &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB,
983 EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW, LDA, LDB, LDB);
984 }
985 else {
986 if (offsetBTemp + EIGEN_AVX_MAX_NUM_ROW * LDT > sizeBTemp) {
995 int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0;
996 int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW);
997 int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0;
998 int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
999 gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
1000 &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB,
1001 M - (i + EIGEN_AVX_MAX_NUM_ROW), bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
1002 offsetBTemp = 0;
1003 gemmOff = i + EIGEN_AVX_MAX_NUM_ROW;
1004 } else {
1008 int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
1009 int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW);
1010 int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
1011 int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
1012 gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
1013 &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB,
1014 EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
1015 }
1016 }
1017 }
1018 }
1019 // Handle the M remainder.
1020 int64_t bM = M - M_;
1021 if (bM > 0) {
1022 if (M_ > 0) {
1023 EIGEN_IF_CONSTEXPR(isBRowMajor) {
1024 int64_t indA_i = isFWDSolve ? M_ : 0;
1025 int64_t indA_j = isFWDSolve ? 0 : bM;
1026 int64_t indB_i = isFWDSolve ? 0 : bM;
1027 int64_t indB_i2 = isFWDSolve ? M_ : 0;
1028 gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
1029 &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB, bM,
1030 bK, M_, LDA, LDB, LDB);
1031 }
1032 else {
1033 int64_t indA_i = isFWDSolve ? M_ : 0;
1034 int64_t indA_j = isFWDSolve ? gemmOff : bM;
1035 int64_t indB_i = isFWDSolve ? M_ : 0;
1036 int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
1037 gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)],
1038 B_temp + offB_1, B_arr + indB_i + (k)*LDB, bM, bK,
1039 M_ - gemmOff, LDA, LDT, LDB);
1040 }
1041 }
1042 EIGEN_IF_CONSTEXPR(!isBRowMajor) {
1043 int64_t indA_i = isFWDSolve ? M_ : M - 1 - M_;
1044 int64_t indB_i = isFWDSolve ? M_ : 0;
1045 int64_t offB_1 = isFWDSolve ? 0 : (bM - 1) * bkL;
1046 copyBToRowMajor<Scalar, true, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
1047 triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)],
1048 B_temp + offB_1, bM, bkL, LDA, bkL);
1049 copyBToRowMajor<Scalar, false, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
1050 }
1051 else {
1052 int64_t ind = isFWDSolve ? M_ : M - 1 - M_;
1053 triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(ind, ind, LDA)],
1054 B_arr + k + ind * LDB, bM, bK, LDA, LDB);
1055 }
1056 }
1057 }
1058
1059 EIGEN_IF_CONSTEXPR(!isBRowMajor) handmade_aligned_free(B_temp);
1060}
1061
1062// Template specializations of trsmKernelL/R for float/double and inner strides of 1.
1063#if (EIGEN_USE_AVX512_TRSM_KERNELS)
1064#if (EIGEN_USE_AVX512_TRSM_R_KERNELS)
1065template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride,
1066 bool Specialized>
1067struct trsmKernelR;
1068
1069template <typename Index, int Mode, int TriStorageOrder>
1070struct trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true> {
1071 static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
1072 Index otherStride);
1073};
1074
1075template <typename Index, int Mode, int TriStorageOrder>
1076struct trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true> {
1077 static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
1078 Index otherStride);
1079};
1080
1081template <typename Index, int Mode, int TriStorageOrder>
1082EIGEN_DONT_INLINE void trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
1083 Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
1084 Index otherStride) {
1085 EIGEN_UNUSED_VARIABLE(otherIncr);
1086#ifdef EIGEN_RUNTIME_NO_MALLOC
1087 if (!is_malloc_allowed()) {
1088 trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
1089 size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
1090 return;
1091 }
1092#endif
1093 triSolve<float, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
1094 const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
1095}
1096
1097template <typename Index, int Mode, int TriStorageOrder>
1098EIGEN_DONT_INLINE void trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
1099 Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
1100 Index otherStride) {
1101 EIGEN_UNUSED_VARIABLE(otherIncr);
1102#ifdef EIGEN_RUNTIME_NO_MALLOC
1103 if (!is_malloc_allowed()) {
1104 trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
1105 size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
1106 return;
1107 }
1108#endif
1109 triSolve<double, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
1110 const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
1111}
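// Editorial usage sketch (assumes an AVX-512 build in which this specialization is
// selected; the actual dispatch also depends on the cutoff heuristics above):
//
//   #include <Eigen/Dense>
//   int main() {
//     Eigen::MatrixXf A = Eigen::MatrixXf::Random(512, 512);
//     Eigen::MatrixXf X = Eigen::MatrixXf::Random(1024, 512);
//     // Overwrites X with the solution of Xnew * U = X, a right-sided TRSM with unit
//     // inner stride, which is what trsmKernelR handles for float and double.
//     A.triangularView<Eigen::Upper>().solveInPlace<Eigen::OnTheRight>(X);
//   }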
1112#endif // (EIGEN_USE_AVX512_TRSM_R_KERNELS)
1113
1114// These trsm kernels require temporary memory allocation
1115#if (EIGEN_USE_AVX512_TRSM_L_KERNELS)
1116template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride,
1117 bool Specialized = true>
1118struct trsmKernelL;
1119
1120template <typename Index, int Mode, int TriStorageOrder>
1121struct trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true> {
1122 static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
1123 Index otherStride);
1124};
1125
1126template <typename Index, int Mode, int TriStorageOrder>
1127struct trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true> {
1128 static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
1129 Index otherStride);
1130};
1131
1132template <typename Index, int Mode, int TriStorageOrder>
1133EIGEN_DONT_INLINE void trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
1134 Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
1135 Index otherStride) {
1136 EIGEN_UNUSED_VARIABLE(otherIncr);
1137#ifdef EIGEN_RUNTIME_NO_MALLOC
1138 if (!is_malloc_allowed()) {
1139 trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
1140 size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
1141 return;
1142 }
1143#endif
1144 triSolve<float, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
1145 const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
1146}
1147
1148template <typename Index, int Mode, int TriStorageOrder>
1149EIGEN_DONT_INLINE void trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
1150 Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
1151 Index otherStride) {
1152 EIGEN_UNUSED_VARIABLE(otherIncr);
1153#ifdef EIGEN_RUNTIME_NO_MALLOC
1154 if (!is_malloc_allowed()) {
1155 trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
1156 size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
1157 return;
1158 }
1159#endif
1160 triSolve<double, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
1161 const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
1162}
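// Editorial usage sketch for the left-sided kernels (same caveats as the trsmKernelR
// example above; note these kernels allocate a temporary, and fall back to the generic
// path when EIGEN_RUNTIME_NO_MALLOC forbids allocation):
//
//   Eigen::MatrixXd L = Eigen::MatrixXd::Random(512, 512);
//   Eigen::MatrixXd B = Eigen::MatrixXd::Random(512, 256);
//   // Overwrites B with the solution of L * X = B (triangular matrix on the left).
//   L.triangularView<Eigen::Lower>().solveInPlace(B);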
1163#endif // EIGEN_USE_AVX512_TRSM_L_KERNELS
1164#endif // EIGEN_USE_AVX512_TRSM_KERNELS
1165} // namespace internal
1166} // namespace Eigen
1167#endif // EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H