Eigen  3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)
 
MatrixProduct.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2020 Everton Constantino ([email protected])
5// Copyright (C) 2021 Chip Kerchner ([email protected])
6//
7// This Source Code Form is subject to the terms of the Mozilla
8// Public License v. 2.0. If a copy of the MPL was not distributed
9// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10
11#ifndef EIGEN_MATRIX_PRODUCT_ALTIVEC_H
12#define EIGEN_MATRIX_PRODUCT_ALTIVEC_H
13
14#ifndef EIGEN_ALTIVEC_USE_CUSTOM_PACK
15#define EIGEN_ALTIVEC_USE_CUSTOM_PACK 1
16#endif
17
18#if !defined(EIGEN_ALTIVEC_DISABLE_MMA)
19#define EIGEN_ALTIVEC_DISABLE_MMA 0
20#endif
21
22// Check for MMA builtin support.
23#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__has_builtin)
24#if __has_builtin(__builtin_mma_assemble_acc)
25#define EIGEN_ALTIVEC_MMA_SUPPORT
26#endif
27#endif
28
29// Check if and how we should actually use MMA if supported.
30#if defined(EIGEN_ALTIVEC_MMA_SUPPORT)
31
32#if !defined(EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH)
33#define EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH 0
34#endif
35
36// Check if we want to enable dynamic dispatch. Not supported by LLVM.
37#if EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH && !EIGEN_COMP_LLVM
38#define EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH 1
39// Otherwise, use MMA by default if available.
40#elif defined(__MMA__)
41#define EIGEN_ALTIVEC_MMA_ONLY 1
42#endif
43
44#endif // EIGEN_ALTIVEC_MMA_SUPPORT
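// Editor's note (illustrative, not part of the original file): the macros above select
// one of three configurations. Compiling with MMA enabled (e.g. GCC with -mcpu=power10,
// which defines __MMA__) selects the MMA-only kernels via EIGEN_ALTIVEC_MMA_ONLY.
// Defining EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH=1 (not supported by LLVM) instead
// builds both kernel sets and chooses between them at runtime. Defining
// EIGEN_ALTIVEC_DISABLE_MMA=1 skips MMA detection entirely and keeps the VSX-only path.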
45
46#include "MatrixProductCommon.h"
47
48#if defined(EIGEN_ALTIVEC_MMA_ONLY) || defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
49#include "MatrixProductMMA.h"
50#endif
51
52// IWYU pragma: private
53#include "../../InternalHeaderCheck.h"
54
55namespace Eigen {
56
57namespace internal {
58
59/**************************
60 * Constants and typedefs *
61 **************************/
62template <typename Scalar>
63struct quad_traits {
64 typedef typename packet_traits<Scalar>::type vectortype;
65 typedef PacketBlock<vectortype, 4> type;
66 typedef vectortype rhstype;
67 enum { vectorsize = packet_traits<Scalar>::size, size = 4, rows = 4 };
68};
69
70template <>
71struct quad_traits<double> {
72 typedef Packet2d vectortype;
73 typedef PacketBlock<vectortype, 4> type;
74 typedef PacketBlock<Packet2d, 2> rhstype;
75 enum { vectorsize = packet_traits<double>::size, size = 2, rows = 4 };
76};
77
78template <>
79struct quad_traits<bfloat16> {
80 typedef Packet8bf vectortype;
81 typedef PacketBlock<vectortype, 4> type;
82 typedef vectortype rhstype;
83 enum { vectorsize = packet_traits<bfloat16>::size, size = 8, rows = 4 };
84};
85
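// Editor's illustration (not part of the original file): on 128-bit AltiVec/VSX targets
// the traits above resolve as sketched below; the block is kept disabled so the listing
// is unchanged.
#if 0
static_assert(quad_traits<float>::vectorsize == 4, "Packet4f holds 4 floats");
static_assert(quad_traits<double>::vectorsize == 2, "Packet2d holds 2 doubles");
static_assert(quad_traits<bfloat16>::vectorsize == 8, "Packet8bf holds 8 bfloat16 values");
static_assert(quad_traits<float>::rows == 4 && quad_traits<double>::rows == 4, "quads span 4 rows");
#endif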
86// MatrixProduct decomposes complex vectors into a separate real vector and a separate imaginary vector; this turned
87// out to be faster than Eigen's usual approach of keeping real/imaginary pairs interleaved in a single vector. The
88// constants below are used to convert between Eigen's layout and the MatrixProduct layout.
89
90const static Packet16uc p16uc_GETREAL32 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
91
92const static Packet16uc p16uc_GETIMAG32 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
93
94const static Packet16uc p16uc_GETREAL32b = {0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 24, 25, 26, 27};
95
96const static Packet16uc p16uc_GETIMAG32b = {4, 5, 6, 7, 20, 21, 22, 23, 12, 13, 14, 15, 28, 29, 30, 31};
97
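// Editor's worked example (not part of the original file): let a = [ar0 ai0 ar1 ai1] and
// b = [br0 bi0 br1 bi1] be two Packet2cf values viewed as Packet4f. Byte-wise vec_perm
// with the masks above yields
//   vec_perm(a, b, p16uc_GETREAL32)  -> [ar0 ar1 br0 br1]
//   vec_perm(a, b, p16uc_GETIMAG32)  -> [ai0 ai1 bi0 bi1]
//   vec_perm(a, b, p16uc_GETREAL32b) -> [ar0 br0 ar1 br1]
//   vec_perm(a, b, p16uc_GETIMAG32b) -> [ai0 bi0 ai1 bi1]
// which is how the packing code below splits interleaved complex data into separate real
// and imaginary vectors.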
98/*********************************************
99 * Single precision real and complex packing *
100 * *******************************************/
101
116template <typename Scalar, int StorageOrder>
117EIGEN_ALWAYS_INLINE std::complex<Scalar> getAdjointVal(
118 Index i, Index j, const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder>& dt) {
119 std::complex<Scalar> v;
120 if (i < j) {
121 v.real(dt(j, i).real());
122 v.imag(-dt(j, i).imag());
123 } else if (i > j) {
124 v.real(dt(i, j).real());
125 v.imag(dt(i, j).imag());
126 } else {
127 v.real(dt(i, j).real());
128 v.imag((Scalar)0.0);
129 }
130 return v;
131}
132
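// Editor's note (not part of the original file): getAdjointVal reconstructs an arbitrary
// element of a self-adjoint matrix from the triangle stored in the mapper `dt`:
//   i > j  -> dt(i, j)         (lower-triangle element as stored)
//   i < j  -> conj(dt(j, i))   (mirrored from the lower triangle, imaginary part negated)
//   i == j -> real(dt(i, i))   (diagonal forced to be real)
// For example, getAdjointVal(0, 1, dt) returns conj(dt(1, 0)).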
133template <typename Scalar, int StorageOrder, int N>
134EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex<Scalar>* blockB, const std::complex<Scalar>* _rhs,
135 Index rhsStride, Index rows, Index cols, Index k2) {
136 const Index depth = k2 + rows;
137 const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> rhs(_rhs, rhsStride);
138 const Index vectorSize = N * quad_traits<Scalar>::vectorsize;
139 const Index vectorDelta = vectorSize * rows;
140 Scalar* blockBf = reinterpret_cast<Scalar*>(blockB);
141
142 Index rir = 0, rii, j = 0;
143 for (; j + vectorSize <= cols; j += vectorSize) {
144 rii = rir + vectorDelta;
145
146 for (Index i = k2; i < depth; i++) {
147 for (Index k = 0; k < vectorSize; k++) {
148 std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(i, j + k, rhs);
149
150 blockBf[rir + k] = v.real();
151 blockBf[rii + k] = v.imag();
152 }
153 rir += vectorSize;
154 rii += vectorSize;
155 }
156
157 rir += vectorDelta;
158 }
159
160 for (; j < cols; j++) {
161 rii = rir + rows;
162
163 for (Index i = k2; i < depth; i++) {
164 std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(i, j, rhs);
165
166 blockBf[rir] = v.real();
167 blockBf[rii] = v.imag();
168
169 rir += 1;
170 rii += 1;
171 }
172
173 rir += rows;
174 }
175}
176
177template <typename Scalar, int StorageOrder>
178EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex<Scalar>* blockA, const std::complex<Scalar>* _lhs,
179 Index lhsStride, Index cols, Index rows) {
180 const Index depth = cols;
181 const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> lhs(_lhs, lhsStride);
182 const Index vectorSize = quad_traits<Scalar>::vectorsize;
183 const Index vectorDelta = vectorSize * depth;
184 Scalar* blockAf = reinterpret_cast<Scalar*>(blockA);
185
186 Index rir = 0, rii, j = 0;
187 for (; j + vectorSize <= rows; j += vectorSize) {
188 rii = rir + vectorDelta;
189
190 for (Index i = 0; i < depth; i++) {
191 for (Index k = 0; k < vectorSize; k++) {
192 std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(j + k, i, lhs);
193
194 blockAf[rir + k] = v.real();
195 blockAf[rii + k] = v.imag();
196 }
197 rir += vectorSize;
198 rii += vectorSize;
199 }
200
201 rir += vectorDelta;
202 }
203
204 if (j < rows) {
205 rii = rir + ((rows - j) * depth);
206
207 for (Index i = 0; i < depth; i++) {
208 Index k = j;
209 for (; k < rows; k++) {
210 std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(k, i, lhs);
211
212 blockAf[rir] = v.real();
213 blockAf[rii] = v.imag();
214
215 rir += 1;
216 rii += 1;
217 }
218 }
219 }
220}
221
222template <typename Scalar, int StorageOrder, int N>
223EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar* blockB, const Scalar* _rhs, Index rhsStride, Index rows,
224 Index cols, Index k2) {
225 const Index depth = k2 + rows;
226 const_blas_data_mapper<Scalar, Index, StorageOrder> rhs(_rhs, rhsStride);
227 const Index vectorSize = quad_traits<Scalar>::vectorsize;
228
229 Index ri = 0, j = 0;
230 for (; j + N * vectorSize <= cols; j += N * vectorSize) {
231 Index i = k2;
232 for (; i < depth; i++) {
233 for (Index k = 0; k < N * vectorSize; k++) {
234 if (i <= j + k)
235 blockB[ri + k] = rhs(j + k, i);
236 else
237 blockB[ri + k] = rhs(i, j + k);
238 }
239 ri += N * vectorSize;
240 }
241 }
242
243 for (; j < cols; j++) {
244 for (Index i = k2; i < depth; i++) {
245 if (j <= i)
246 blockB[ri] = rhs(i, j);
247 else
248 blockB[ri] = rhs(j, i);
249 ri += 1;
250 }
251 }
252}
253
254template <typename Scalar, int StorageOrder>
255EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar* blockA, const Scalar* _lhs, Index lhsStride, Index cols,
256 Index rows) {
257 const Index depth = cols;
258 const_blas_data_mapper<Scalar, Index, StorageOrder> lhs(_lhs, lhsStride);
259 const Index vectorSize = quad_traits<Scalar>::vectorsize;
260
261 Index ri = 0, j = 0;
262 for (; j + vectorSize <= rows; j += vectorSize) {
263 Index i = 0;
264
265 for (; i < depth; i++) {
266 for (Index k = 0; k < vectorSize; k++) {
267 if (i <= j + k)
268 blockA[ri + k] = lhs(j + k, i);
269 else
270 blockA[ri + k] = lhs(i, j + k);
271 }
272 ri += vectorSize;
273 }
274 }
275
276 if (j < rows) {
277 for (Index i = 0; i < depth; i++) {
278 Index k = j;
279 for (; k < rows; k++) {
280 if (i <= k)
281 blockA[ri] = lhs(k, i);
282 else
283 blockA[ri] = lhs(i, k);
284 ri += 1;
285 }
286 }
287 }
288}
289
290template <typename Index, int nr, int StorageOrder>
291struct symm_pack_rhs<std::complex<float>, Index, nr, StorageOrder> {
292 void operator()(std::complex<float>* blockB, const std::complex<float>* _rhs, Index rhsStride, Index rows, Index cols,
293 Index k2) {
294 symm_pack_complex_rhs_helper<float, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
295 }
296};
297
298template <typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
299struct symm_pack_lhs<std::complex<float>, Index, Pack1, Pack2_dummy, StorageOrder> {
300 void operator()(std::complex<float>* blockA, const std::complex<float>* _lhs, Index lhsStride, Index cols,
301 Index rows) {
302 symm_pack_complex_lhs_helper<float, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
303 }
304};
305
306// *********** symm_pack std::complex<float64> ***********
307
308template <typename Index, int nr, int StorageOrder>
309struct symm_pack_rhs<std::complex<double>, Index, nr, StorageOrder> {
310 void operator()(std::complex<double>* blockB, const std::complex<double>* _rhs, Index rhsStride, Index rows,
311 Index cols, Index k2) {
312 symm_pack_complex_rhs_helper<double, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
313 }
314};
315
316template <typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
317struct symm_pack_lhs<std::complex<double>, Index, Pack1, Pack2_dummy, StorageOrder> {
318 void operator()(std::complex<double>* blockA, const std::complex<double>* _lhs, Index lhsStride, Index cols,
319 Index rows) {
320 symm_pack_complex_lhs_helper<double, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
321 }
322};
323
324// *********** symm_pack float32 ***********
325template <typename Index, int nr, int StorageOrder>
326struct symm_pack_rhs<float, Index, nr, StorageOrder> {
327 void operator()(float* blockB, const float* _rhs, Index rhsStride, Index rows, Index cols, Index k2) {
328 symm_pack_rhs_helper<float, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
329 }
330};
331
332template <typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
333struct symm_pack_lhs<float, Index, Pack1, Pack2_dummy, StorageOrder> {
334 void operator()(float* blockA, const float* _lhs, Index lhsStride, Index cols, Index rows) {
335 symm_pack_lhs_helper<float, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
336 }
337};
338
339// *********** symm_pack float64 ***********
340template <typename Index, int nr, int StorageOrder>
341struct symm_pack_rhs<double, Index, nr, StorageOrder> {
342 void operator()(double* blockB, const double* _rhs, Index rhsStride, Index rows, Index cols, Index k2) {
343 symm_pack_rhs_helper<double, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
344 }
345};
346
347template <typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
348struct symm_pack_lhs<double, Index, Pack1, Pack2_dummy, StorageOrder> {
349 void operator()(double* blockA, const double* _lhs, Index lhsStride, Index cols, Index rows) {
350 symm_pack_lhs_helper<double, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
351 }
352};
353
365template <typename Scalar, typename Packet, int N>
366EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet, N>& block) {
367 const Index size = 16 / sizeof(Scalar);
368 pstore<Scalar>(to + (0 * size), block.packet[0]);
369 pstore<Scalar>(to + (1 * size), block.packet[1]);
370 if (N > 2) {
371 pstore<Scalar>(to + (2 * size), block.packet[2]);
372 }
373 if (N > 3) {
374 pstore<Scalar>(to + (3 * size), block.packet[3]);
375 }
376}
377
378// General template for lhs & rhs complex packing.
379template <typename Scalar, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate,
380 bool PanelMode, bool UseLhs>
381struct dhs_cpack {
382 template <bool transpose>
383 EIGEN_ALWAYS_INLINE void dhs_cblock(PacketBlock<PacketC, 8>& cblock, PacketBlock<Packet, 4>& block,
384 Packet16uc permute) {
385 if (transpose) {
386 block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, permute);
387 block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, permute);
388 block.packet[2] = vec_perm(cblock.packet[4].v, cblock.packet[5].v, permute);
389 block.packet[3] = vec_perm(cblock.packet[6].v, cblock.packet[7].v, permute);
390
391 Packet4f t0, t1, t2, t3;
392#ifdef EIGEN_VECTORIZE_VSX
393 t0 = reinterpret_cast<Packet>(
394 vec_mergeh(reinterpret_cast<Packet2ul>(block.packet[0]), reinterpret_cast<Packet2ul>(block.packet[1])));
395 t1 = reinterpret_cast<Packet>(
396 vec_mergel(reinterpret_cast<Packet2ul>(block.packet[0]), reinterpret_cast<Packet2ul>(block.packet[1])));
397 t2 = reinterpret_cast<Packet>(
398 vec_mergeh(reinterpret_cast<Packet2ul>(block.packet[2]), reinterpret_cast<Packet2ul>(block.packet[3])));
399 t3 = reinterpret_cast<Packet>(
400 vec_mergel(reinterpret_cast<Packet2ul>(block.packet[2]), reinterpret_cast<Packet2ul>(block.packet[3])));
401#else
402 t0 = reinterpret_cast<Packet>(vec_perm(block.packet[0], block.packet[1], p16uc_TRANSPOSE64_HI));
403 t1 = reinterpret_cast<Packet>(vec_perm(block.packet[0], block.packet[1], p16uc_TRANSPOSE64_LO));
404 t2 = reinterpret_cast<Packet>(vec_perm(block.packet[2], block.packet[3], p16uc_TRANSPOSE64_HI));
405 t3 = reinterpret_cast<Packet>(vec_perm(block.packet[2], block.packet[3], p16uc_TRANSPOSE64_LO));
406#endif
407
408 block.packet[0] = t0;
409 block.packet[1] = t1;
410 block.packet[2] = t2;
411 block.packet[3] = t3;
412 } else {
413 block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, permute);
414 block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, permute);
415 block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, permute);
416 block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, permute);
417 }
418 }
419
420 EIGEN_ALWAYS_INLINE void dhs_ccopy(Scalar* blockAt, const DataMapper& lhs2, Index& i, Index& rir, Index& rii,
421 Index depth, const Index vectorSize) {
422 PacketBlock<Packet, 4> blockr, blocki;
423 PacketBlock<PacketC, 8> cblock;
424
425 for (; i + vectorSize <= depth; i += vectorSize) {
426 if (UseLhs) {
427 bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, 0, i);
428 } else {
429 bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, i, 0);
430 }
431
432 if (((StorageOrder == RowMajor) && UseLhs) || (((StorageOrder == ColMajor) && !UseLhs))) {
433 dhs_cblock<true>(cblock, blockr, p16uc_GETREAL32b);
434 dhs_cblock<true>(cblock, blocki, p16uc_GETIMAG32b);
435 } else {
436 dhs_cblock<false>(cblock, blockr, p16uc_GETREAL32);
437 dhs_cblock<false>(cblock, blocki, p16uc_GETIMAG32);
438 }
439
440 if (Conjugate) {
441 blocki.packet[0] = -blocki.packet[0];
442 blocki.packet[1] = -blocki.packet[1];
443 blocki.packet[2] = -blocki.packet[2];
444 blocki.packet[3] = -blocki.packet[3];
445 }
446
447 storeBlock<Scalar, Packet, 4>(blockAt + rir, blockr);
448 storeBlock<Scalar, Packet, 4>(blockAt + rii, blocki);
449
450 rir += 4 * vectorSize;
451 rii += 4 * vectorSize;
452 }
453 }
454
455 EIGEN_STRONG_INLINE void operator()(std::complex<Scalar>* blockA, const DataMapper& lhs, Index depth, Index rows,
456 Index stride, Index offset) {
457 const Index vectorSize = quad_traits<Scalar>::vectorsize;
458 const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
459 Index rir = ((PanelMode) ? (vectorSize * offset) : 0), rii;
460 Scalar* blockAt = reinterpret_cast<Scalar*>(blockA);
461 Index j = 0;
462
463 for (; j + vectorSize <= rows; j += vectorSize) {
464 const DataMapper lhs2 = UseLhs ? lhs.getSubMapper(j, 0) : lhs.getSubMapper(0, j);
465 Index i = 0;
466
467 rii = rir + vectorDelta;
468
469 dhs_ccopy(blockAt, lhs2, i, rir, rii, depth, vectorSize);
470
471 for (; i < depth; i++) {
472 PacketBlock<Packet, 1> blockr, blocki;
473 PacketBlock<PacketC, 2> cblock;
474
475 if (((StorageOrder == ColMajor) && UseLhs) || (((StorageOrder == RowMajor) && !UseLhs))) {
476 if (UseLhs) {
477 cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i);
478 cblock.packet[1] = lhs2.template loadPacket<PacketC>(2, i);
479 } else {
480 cblock.packet[0] = lhs2.template loadPacket<PacketC>(i, 0);
481 cblock.packet[1] = lhs2.template loadPacket<PacketC>(i, 2);
482 }
483 } else {
484 if (UseLhs) {
485 cblock.packet[0] = pload2(lhs2(0, i), lhs2(1, i));
486 cblock.packet[1] = pload2(lhs2(2, i), lhs2(3, i));
487 } else {
488 cblock.packet[0] = pload2(lhs2(i, 0), lhs2(i, 1));
489 cblock.packet[1] = pload2(lhs2(i, 2), lhs2(i, 3));
490 }
491 }
492
493 blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL32);
494 blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG32);
495
496 if (Conjugate) {
497 blocki.packet[0] = -blocki.packet[0];
498 }
499
500 pstore<Scalar>(blockAt + rir, blockr.packet[0]);
501 pstore<Scalar>(blockAt + rii, blocki.packet[0]);
502
503 rir += vectorSize;
504 rii += vectorSize;
505 }
506
507 rir += ((PanelMode) ? (vectorSize * (2 * stride - depth)) : vectorDelta);
508 }
509
510 if (!UseLhs) {
511 if (PanelMode) rir -= (offset * (vectorSize - 1));
512
513 for (; j < rows; j++) {
514 const DataMapper lhs2 = lhs.getSubMapper(0, j);
515 rii = rir + ((PanelMode) ? stride : depth);
516
517 for (Index i = 0; i < depth; i++) {
518 blockAt[rir] = lhs2(i, 0).real();
519
520 if (Conjugate)
521 blockAt[rii] = -lhs2(i, 0).imag();
522 else
523 blockAt[rii] = lhs2(i, 0).imag();
524
525 rir += 1;
526 rii += 1;
527 }
528
529 rir += ((PanelMode) ? (2 * stride - depth) : depth);
530 }
531 } else {
532 if (j < rows) {
533 if (PanelMode) rir += (offset * (rows - j - vectorSize));
534 rii = rir + (((PanelMode) ? stride : depth) * (rows - j));
535
536 for (Index i = 0; i < depth; i++) {
537 Index k = j;
538 for (; k < rows; k++) {
539 blockAt[rir] = lhs(k, i).real();
540
541 if (Conjugate)
542 blockAt[rii] = -lhs(k, i).imag();
543 else
544 blockAt[rii] = lhs(k, i).imag();
545
546 rir += 1;
547 rii += 1;
548 }
549 }
550 }
551 }
552 }
553};
554
555// General template for lhs & rhs packing.
556template <typename Scalar, typename DataMapper, typename Packet, int StorageOrder, bool PanelMode, bool UseLhs>
557struct dhs_pack {
558 template <Index n>
559 EIGEN_ALWAYS_INLINE void dhs_copy(Scalar* blockA, const DataMapper& lhs2, Index& i, Index& ri, Index depth,
560 const Index vectorSize) {
561 PacketBlock<Packet, 4> block[n];
562
563 for (; i + n * vectorSize <= depth; i += n * vectorSize) {
564 for (Index k = 0; k < n; k++) {
565 if (UseLhs) {
566 bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block[k], lhs2, 0, i + k * vectorSize);
567 } else {
568 bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block[k], lhs2, i + k * vectorSize, 0);
569 }
570 }
571
572 if (((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs)) {
573 for (Index k = 0; k < n; k++) {
574 ptranspose(block[k]);
575 }
576 }
577
578 for (Index k = 0; k < n; k++) {
579 storeBlock<Scalar, Packet, 4>(blockA + ri + k * 4 * vectorSize, block[k]);
580 }
581
582 ri += n * 4 * vectorSize;
583 }
584 }
585
586 EIGEN_STRONG_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride,
587 Index offset) {
588 const Index vectorSize = quad_traits<Scalar>::vectorsize;
589 Index ri = 0, j = 0;
590
591 for (; j + vectorSize <= rows; j += vectorSize) {
592 const DataMapper lhs2 = UseLhs ? lhs.getSubMapper(j, 0) : lhs.getSubMapper(0, j);
593 Index i = 0;
594
595 if (PanelMode) ri += vectorSize * offset;
596
597 dhs_copy<4>(blockA, lhs2, i, ri, depth, vectorSize);
598 dhs_copy<2>(blockA, lhs2, i, ri, depth, vectorSize);
599 dhs_copy<1>(blockA, lhs2, i, ri, depth, vectorSize);
600
601 for (; i < depth; i++) {
602 if (((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs)) {
603 if (UseLhs) {
604 blockA[ri + 0] = lhs2(0, i);
605 blockA[ri + 1] = lhs2(1, i);
606 blockA[ri + 2] = lhs2(2, i);
607 blockA[ri + 3] = lhs2(3, i);
608 } else {
609 blockA[ri + 0] = lhs2(i, 0);
610 blockA[ri + 1] = lhs2(i, 1);
611 blockA[ri + 2] = lhs2(i, 2);
612 blockA[ri + 3] = lhs2(i, 3);
613 }
614 } else {
615 Packet lhsV;
616 if (UseLhs) {
617 lhsV = lhs2.template loadPacket<Packet>(0, i);
618 } else {
619 lhsV = lhs2.template loadPacket<Packet>(i, 0);
620 }
621 pstore<Scalar>(blockA + ri, lhsV);
622 }
623
624 ri += vectorSize;
625 }
626
627 if (PanelMode) ri += vectorSize * (stride - offset - depth);
628 }
629
630 if (!UseLhs) {
631 if (PanelMode) ri += offset;
632
633 for (; j < rows; j++) {
634 const DataMapper lhs2 = lhs.getSubMapper(0, j);
635 for (Index i = 0; i < depth; i++) {
636 blockA[ri] = lhs2(i, 0);
637 ri += 1;
638 }
639
640 if (PanelMode) ri += stride - depth;
641 }
642 } else {
643 if (j < rows) {
644 if (PanelMode) ri += offset * (rows - j);
645
646 for (Index i = 0; i < depth; i++) {
647 Index k = j;
648 for (; k < rows; k++) {
649 blockA[ri] = lhs(k, i);
650 ri += 1;
651 }
652 }
653 }
654 }
655 }
656};
657
658// General template for lhs packing, float64 specialization.
659template <typename DataMapper, int StorageOrder, bool PanelMode>
660struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, true> {
661 template <Index n>
662 EIGEN_ALWAYS_INLINE void dhs_copy(double* blockA, const DataMapper& lhs2, Index& i, Index& ri, Index depth,
663 const Index vectorSize) {
664 PacketBlock<Packet2d, 2> block[n];
665
666 for (; i + n * vectorSize <= depth; i += n * vectorSize) {
667 for (Index k = 0; k < n; k++) {
668 if (StorageOrder == RowMajor) {
669 block[k].packet[0] = lhs2.template loadPacket<Packet2d>(0, i + k * vectorSize);
670 block[k].packet[1] = lhs2.template loadPacket<Packet2d>(1, i + k * vectorSize);
671 } else {
672 block[k].packet[0] = lhs2.template loadPacket<Packet2d>(0, i + k * vectorSize + 0);
673 block[k].packet[1] = lhs2.template loadPacket<Packet2d>(0, i + k * vectorSize + 1);
674 }
675 }
676
677 if (StorageOrder == RowMajor) {
678 for (Index k = 0; k < n; k++) {
679 ptranspose(block[k]);
680 }
681 }
682
683 for (Index k = 0; k < n; k++) {
684 storeBlock<double, Packet2d, 2>(blockA + ri + k * 2 * vectorSize, block[k]);
685 }
686
687 ri += n * 2 * vectorSize;
688 }
689 }
690
691 EIGEN_STRONG_INLINE void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride,
692 Index offset) {
693 const Index vectorSize = quad_traits<double>::vectorsize;
694 Index ri = 0, j = 0;
695
696 for (; j + vectorSize <= rows; j += vectorSize) {
697 const DataMapper lhs2 = lhs.getSubMapper(j, 0);
698 Index i = 0;
699
700 if (PanelMode) ri += vectorSize * offset;
701
702 dhs_copy<4>(blockA, lhs2, i, ri, depth, vectorSize);
703 dhs_copy<2>(blockA, lhs2, i, ri, depth, vectorSize);
704 dhs_copy<1>(blockA, lhs2, i, ri, depth, vectorSize);
705
706 for (; i < depth; i++) {
707 if (StorageOrder == RowMajor) {
708 blockA[ri + 0] = lhs2(0, i);
709 blockA[ri + 1] = lhs2(1, i);
710 } else {
711 Packet2d lhsV = lhs2.template loadPacket<Packet2d>(0, i);
712 pstore<double>(blockA + ri, lhsV);
713 }
714
715 ri += vectorSize;
716 }
717
718 if (PanelMode) ri += vectorSize * (stride - offset - depth);
719 }
720
721 if (j < rows) {
722 if (PanelMode) ri += offset * (rows - j);
723
724 for (Index i = 0; i < depth; i++) {
725 Index k = j;
726 for (; k < rows; k++) {
727 blockA[ri] = lhs(k, i);
728 ri += 1;
729 }
730 }
731 }
732 }
733};
734
735// General template for rhs packing, float64 specialization.
736template <typename DataMapper, int StorageOrder, bool PanelMode>
737struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, false> {
738 template <Index n>
739 EIGEN_ALWAYS_INLINE void dhs_copy(double* blockB, const DataMapper& rhs2, Index& i, Index& ri, Index depth,
740 const Index vectorSize) {
741 PacketBlock<Packet2d, 2> block1[n], block2[n];
742 PacketBlock<Packet2d, 4> block3[n];
743
744 for (; i + n * vectorSize <= depth; i += n * vectorSize) {
745 for (Index k = 0; k < n; k++) {
746 if (StorageOrder == ColMajor) {
747 block1[k].packet[0] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize, 0);
748 block1[k].packet[1] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize, 1);
749 block2[k].packet[0] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize, 2);
750 block2[k].packet[1] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize, 3);
751 } else {
752 block3[k].packet[0] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize + 0, 0); //[a1 a2]
753 block3[k].packet[1] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize + 0, 2); //[a3 a4]
754 block3[k].packet[2] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize + 1, 0); //[b1 b2]
755 block3[k].packet[3] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize + 1, 2); //[b3 b4]
756 }
757 }
758
759 if (StorageOrder == ColMajor) {
760 for (Index k = 0; k < n; k++) {
761 ptranspose(block1[k]);
762 ptranspose(block2[k]);
763 }
764 }
765
766 for (Index k = 0; k < n; k++) {
767 if (StorageOrder == ColMajor) {
768 pstore<double>(blockB + ri + k * 4 * vectorSize, block1[k].packet[0]);
769 pstore<double>(blockB + ri + k * 4 * vectorSize + 2, block2[k].packet[0]);
770 pstore<double>(blockB + ri + k * 4 * vectorSize + 4, block1[k].packet[1]);
771 pstore<double>(blockB + ri + k * 4 * vectorSize + 6, block2[k].packet[1]);
772 } else {
773 storeBlock<double, Packet2d, 4>(blockB + ri + k * 4 * vectorSize, block3[k]);
774 }
775 }
776
777 ri += n * 4 * vectorSize;
778 }
779 }
780
781 EIGEN_STRONG_INLINE void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride,
782 Index offset) {
783 const Index vectorSize = quad_traits<double>::vectorsize;
784 Index ri = 0, j = 0;
785
786 for (; j + 2 * vectorSize <= cols; j += 2 * vectorSize) {
787 const DataMapper rhs2 = rhs.getSubMapper(0, j);
788 Index i = 0;
789
790 if (PanelMode) ri += offset * (2 * vectorSize);
791
792 dhs_copy<4>(blockB, rhs2, i, ri, depth, vectorSize);
793 dhs_copy<2>(blockB, rhs2, i, ri, depth, vectorSize);
794 dhs_copy<1>(blockB, rhs2, i, ri, depth, vectorSize);
795
796 for (; i < depth; i++) {
797 if (StorageOrder == ColMajor) {
798 blockB[ri + 0] = rhs2(i, 0);
799 blockB[ri + 1] = rhs2(i, 1);
800
801 ri += vectorSize;
802
803 blockB[ri + 0] = rhs2(i, 2);
804 blockB[ri + 1] = rhs2(i, 3);
805 } else {
806 Packet2d rhsV = rhs2.template loadPacket<Packet2d>(i, 0);
807 pstore<double>(blockB + ri, rhsV);
808
809 ri += vectorSize;
810
811 rhsV = rhs2.template loadPacket<Packet2d>(i, 2);
812 pstore<double>(blockB + ri, rhsV);
813 }
814 ri += vectorSize;
815 }
816
817 if (PanelMode) ri += (2 * vectorSize) * (stride - offset - depth);
818 }
819
820 if (PanelMode) ri += offset;
821
822 for (; j < cols; j++) {
823 const DataMapper rhs2 = rhs.getSubMapper(0, j);
824 for (Index i = 0; i < depth; i++) {
825 blockB[ri] = rhs2(i, 0);
826 ri += 1;
827 }
828
829 if (PanelMode) ri += stride - depth;
830 }
831 }
832};
833
834// General template for lhs packing, bfloat16 specialization.
835template <typename DataMapper, int StorageOrder, bool PanelMode>
836struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, true> {
837 EIGEN_STRONG_INLINE void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride,
838 Index offset) {
839 const Index vectorSize = quad_traits<bfloat16>::vectorsize;
840 Index ri = 0, j = 0;
841
842 for (; j + 2 * vectorSize <= rows; j += 2 * vectorSize) {
843 const DataMapper lhs2 = lhs.getSubMapper(j, 0);
844 Index i = 0;
845
846 if (PanelMode) ri += 2 * vectorSize * offset;
847
848 if (StorageOrder == ColMajor) {
849 for (; i + 2 <= depth; i += 2) {
850 PacketBlock<Packet8bf, 4> block;
851
852 block.packet[0] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
853 block.packet[1] = lhs2.template loadPacket<Packet8bf>(1 * vectorSize, i + 0);
854 block.packet[2] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 1);
855 block.packet[3] = lhs2.template loadPacket<Packet8bf>(1 * vectorSize, i + 1);
856
857 Packet8bf t0, t1;
858 t0 = vec_mergeh(block.packet[0].m_val, block.packet[2].m_val);
859 t1 = vec_mergel(block.packet[0].m_val, block.packet[2].m_val);
860 block.packet[2] = vec_mergeh(block.packet[1].m_val, block.packet[3].m_val);
861 block.packet[3] = vec_mergel(block.packet[1].m_val, block.packet[3].m_val);
862 block.packet[0] = t0;
863 block.packet[1] = t1;
864
865 storeBlock<bfloat16, Packet8bf, 4>(blockA + ri, block);
866
867 ri += 2 * 2 * vectorSize;
868 }
869 if (depth & 1) {
870 PacketBlock<Packet8bf, 2> block;
871
872 block.packet[0] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
873 block.packet[1] = lhs2.template loadPacket<Packet8bf>(1 * vectorSize, i + 0);
874
875 storeBlock<bfloat16, Packet8bf, 2>(blockA + ri, block);
876
877 ri += 2 * vectorSize;
878 }
879 } else {
880 for (; i + vectorSize <= depth; i += vectorSize) {
881 PacketBlock<Packet8bf, 8> block1, block2;
882
883 bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block1, lhs2, 0 * vectorSize, i);
884 bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block2, lhs2, 1 * vectorSize, i);
885
886 Packet4ui v1[8], v2[8];
887
888 v1[0] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[0].m_val),
889 reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
890 v1[1] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[0].m_val),
891 reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
892 v1[2] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[2].m_val),
893 reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
894 v1[3] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[2].m_val),
895 reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
896 v1[4] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[4].m_val),
897 reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
898 v1[5] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[4].m_val),
899 reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
900 v1[6] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[6].m_val),
901 reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
902 v1[7] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[6].m_val),
903 reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
904 v2[0] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[0].m_val),
905 reinterpret_cast<Packet4ui>(block2.packet[1].m_val));
906 v2[1] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[0].m_val),
907 reinterpret_cast<Packet4ui>(block2.packet[1].m_val));
908 v2[2] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[2].m_val),
909 reinterpret_cast<Packet4ui>(block2.packet[3].m_val));
910 v2[3] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[2].m_val),
911 reinterpret_cast<Packet4ui>(block2.packet[3].m_val));
912 v2[4] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[4].m_val),
913 reinterpret_cast<Packet4ui>(block2.packet[5].m_val));
914 v2[5] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[4].m_val),
915 reinterpret_cast<Packet4ui>(block2.packet[5].m_val));
916 v2[6] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[6].m_val),
917 reinterpret_cast<Packet4ui>(block2.packet[7].m_val));
918 v2[7] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[6].m_val),
919 reinterpret_cast<Packet4ui>(block2.packet[7].m_val));
920
921#ifdef EIGEN_VECTORIZE_VSX
922 block1.packet[0] = reinterpret_cast<Packet8us>(
923 vec_mergeh(reinterpret_cast<Packet2ul>(v1[0]), reinterpret_cast<Packet2ul>(v1[2])));
924 block1.packet[2] = reinterpret_cast<Packet8us>(
925 vec_mergel(reinterpret_cast<Packet2ul>(v1[0]), reinterpret_cast<Packet2ul>(v1[2])));
926 block1.packet[4] = reinterpret_cast<Packet8us>(
927 vec_mergeh(reinterpret_cast<Packet2ul>(v1[1]), reinterpret_cast<Packet2ul>(v1[3])));
928 block1.packet[6] = reinterpret_cast<Packet8us>(
929 vec_mergel(reinterpret_cast<Packet2ul>(v1[1]), reinterpret_cast<Packet2ul>(v1[3])));
930 block1.packet[1] = reinterpret_cast<Packet8us>(
931 vec_mergeh(reinterpret_cast<Packet2ul>(v1[4]), reinterpret_cast<Packet2ul>(v1[6])));
932 block1.packet[3] = reinterpret_cast<Packet8us>(
933 vec_mergel(reinterpret_cast<Packet2ul>(v1[4]), reinterpret_cast<Packet2ul>(v1[6])));
934 block1.packet[5] = reinterpret_cast<Packet8us>(
935 vec_mergeh(reinterpret_cast<Packet2ul>(v1[5]), reinterpret_cast<Packet2ul>(v1[7])));
936 block1.packet[7] = reinterpret_cast<Packet8us>(
937 vec_mergel(reinterpret_cast<Packet2ul>(v1[5]), reinterpret_cast<Packet2ul>(v1[7])));
938 block2.packet[0] = reinterpret_cast<Packet8us>(
939 vec_mergeh(reinterpret_cast<Packet2ul>(v2[0]), reinterpret_cast<Packet2ul>(v2[2])));
940 block2.packet[2] = reinterpret_cast<Packet8us>(
941 vec_mergel(reinterpret_cast<Packet2ul>(v2[0]), reinterpret_cast<Packet2ul>(v2[2])));
942 block2.packet[4] = reinterpret_cast<Packet8us>(
943 vec_mergeh(reinterpret_cast<Packet2ul>(v2[1]), reinterpret_cast<Packet2ul>(v2[3])));
944 block2.packet[6] = reinterpret_cast<Packet8us>(
945 vec_mergel(reinterpret_cast<Packet2ul>(v2[1]), reinterpret_cast<Packet2ul>(v2[3])));
946 block2.packet[1] = reinterpret_cast<Packet8us>(
947 vec_mergeh(reinterpret_cast<Packet2ul>(v2[4]), reinterpret_cast<Packet2ul>(v2[6])));
948 block2.packet[3] = reinterpret_cast<Packet8us>(
949 vec_mergel(reinterpret_cast<Packet2ul>(v2[4]), reinterpret_cast<Packet2ul>(v2[6])));
950 block2.packet[5] = reinterpret_cast<Packet8us>(
951 vec_mergeh(reinterpret_cast<Packet2ul>(v2[5]), reinterpret_cast<Packet2ul>(v2[7])));
952 block2.packet[7] = reinterpret_cast<Packet8us>(
953 vec_mergel(reinterpret_cast<Packet2ul>(v2[5]), reinterpret_cast<Packet2ul>(v2[7])));
954#else
955 block1.packet[0] = reinterpret_cast<Packet8us>(vec_perm(v1[0], v1[2], p16uc_TRANSPOSE64_HI));
956 block1.packet[2] = reinterpret_cast<Packet8us>(vec_perm(v1[0], v1[2], p16uc_TRANSPOSE64_LO));
957 block1.packet[4] = reinterpret_cast<Packet8us>(vec_perm(v1[1], v1[3], p16uc_TRANSPOSE64_HI));
958 block1.packet[6] = reinterpret_cast<Packet8us>(vec_perm(v1[1], v1[3], p16uc_TRANSPOSE64_LO));
959 block1.packet[1] = reinterpret_cast<Packet8us>(vec_perm(v1[4], v1[6], p16uc_TRANSPOSE64_HI));
960 block1.packet[3] = reinterpret_cast<Packet8us>(vec_perm(v1[4], v1[6], p16uc_TRANSPOSE64_LO));
961 block1.packet[5] = reinterpret_cast<Packet8us>(vec_perm(v1[5], v1[7], p16uc_TRANSPOSE64_HI));
962 block1.packet[7] = reinterpret_cast<Packet8us>(vec_perm(v1[5], v1[7], p16uc_TRANSPOSE64_LO));
963 block2.packet[0] = reinterpret_cast<Packet8us>(vec_perm(v2[0], v2[2], p16uc_TRANSPOSE64_HI));
964 block2.packet[2] = reinterpret_cast<Packet8us>(vec_perm(v2[0], v2[2], p16uc_TRANSPOSE64_LO));
965 block2.packet[4] = reinterpret_cast<Packet8us>(vec_perm(v2[1], v2[3], p16uc_TRANSPOSE64_HI));
966 block2.packet[6] = reinterpret_cast<Packet8us>(vec_perm(v2[1], v2[3], p16uc_TRANSPOSE64_LO));
967 block2.packet[1] = reinterpret_cast<Packet8us>(vec_perm(v2[4], v2[6], p16uc_TRANSPOSE64_HI));
968 block2.packet[3] = reinterpret_cast<Packet8us>(vec_perm(v2[4], v2[6], p16uc_TRANSPOSE64_LO));
969 block2.packet[5] = reinterpret_cast<Packet8us>(vec_perm(v2[5], v2[7], p16uc_TRANSPOSE64_HI));
970 block2.packet[7] = reinterpret_cast<Packet8us>(vec_perm(v2[5], v2[7], p16uc_TRANSPOSE64_LO));
971#endif
972
973 for (Index M = 0; M < 8; M += 2) {
974 pstore<bfloat16>(blockA + ri + (0 * vectorSize) + (2 * vectorSize * M), block1.packet[M + 0]);
975 pstore<bfloat16>(blockA + ri + (1 * vectorSize) + (2 * vectorSize * M), block1.packet[M + 1]);
976 pstore<bfloat16>(blockA + ri + (2 * vectorSize) + (2 * vectorSize * M), block2.packet[M + 0]);
977 pstore<bfloat16>(blockA + ri + (3 * vectorSize) + (2 * vectorSize * M), block2.packet[M + 1]);
978 }
979
980 ri += 2 * vectorSize * vectorSize;
981 }
982 for (; i + 2 <= depth; i += 2) {
983 for (Index M = 0; M < 2 * vectorSize; M++) {
984 blockA[ri + (M * 2) + 0] = lhs2(M, i + 0);
985 blockA[ri + (M * 2) + 1] = lhs2(M, i + 1);
986 }
987
988 ri += 2 * 2 * vectorSize;
989 }
990 if (depth & 1) {
991 for (Index M = 0; M < 2 * vectorSize; M++) {
992 blockA[ri + M] = lhs2(M, i);
993 }
994 ri += 2 * vectorSize;
995 }
996 }
997
998 if (PanelMode) ri += 2 * vectorSize * (stride - offset - depth);
999 }
1000 for (; j + vectorSize <= rows; j += vectorSize) {
1001 const DataMapper lhs2 = lhs.getSubMapper(j, 0);
1002 Index i = 0;
1003
1004 if (PanelMode) ri += vectorSize * offset;
1005
1006 if (StorageOrder == ColMajor) {
1007 for (; i + 2 <= depth; i += 2) {
1008 PacketBlock<Packet8bf, 2> block;
1009
1010 block.packet[0] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
1011 block.packet[1] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 1);
1012
1013 Packet8bf t0;
1014 t0 = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
1015 block.packet[1] = vec_mergel(block.packet[0].m_val, block.packet[1].m_val);
1016 block.packet[0] = t0;
1017
1018 storeBlock<bfloat16, Packet8bf, 2>(blockA + ri, block);
1019
1020 ri += 2 * vectorSize;
1021 }
1022 if (depth & 1) {
1023 Packet8bf lhsV = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
1024 pstore<bfloat16>(blockA + ri, lhsV);
1025
1026 ri += vectorSize;
1027 }
1028 } else {
1029 for (; i + vectorSize <= depth; i += vectorSize) {
1030 PacketBlock<Packet8bf, 8> block1;
1031
1032 bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block1, lhs2, 0 * vectorSize, i);
1033
1034 Packet4ui v1[8];
1035
1036 // Transpose and interleave the data (pairs of bfloat16 are handled as 32-bit lanes)
1037 v1[0] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[0].m_val),
1038 reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
1039 v1[1] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[0].m_val),
1040 reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
1041 v1[2] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[2].m_val),
1042 reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
1043 v1[3] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[2].m_val),
1044 reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
1045 v1[4] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[4].m_val),
1046 reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
1047 v1[5] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[4].m_val),
1048 reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
1049 v1[6] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[6].m_val),
1050 reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
1051 v1[7] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[6].m_val),
1052 reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
1053
1054#ifdef EIGEN_VECTORIZE_VSX
1055 block1.packet[0] = reinterpret_cast<Packet8us>(
1056 vec_mergeh(reinterpret_cast<Packet2ul>(v1[0]), reinterpret_cast<Packet2ul>(v1[2])));
1057 block1.packet[2] = reinterpret_cast<Packet8us>(
1058 vec_mergel(reinterpret_cast<Packet2ul>(v1[0]), reinterpret_cast<Packet2ul>(v1[2])));
1059 block1.packet[4] = reinterpret_cast<Packet8us>(
1060 vec_mergeh(reinterpret_cast<Packet2ul>(v1[1]), reinterpret_cast<Packet2ul>(v1[3])));
1061 block1.packet[6] = reinterpret_cast<Packet8us>(
1062 vec_mergel(reinterpret_cast<Packet2ul>(v1[1]), reinterpret_cast<Packet2ul>(v1[3])));
1063 block1.packet[1] = reinterpret_cast<Packet8us>(
1064 vec_mergeh(reinterpret_cast<Packet2ul>(v1[4]), reinterpret_cast<Packet2ul>(v1[6])));
1065 block1.packet[3] = reinterpret_cast<Packet8us>(
1066 vec_mergel(reinterpret_cast<Packet2ul>(v1[4]), reinterpret_cast<Packet2ul>(v1[6])));
1067 block1.packet[5] = reinterpret_cast<Packet8us>(
1068 vec_mergeh(reinterpret_cast<Packet2ul>(v1[5]), reinterpret_cast<Packet2ul>(v1[7])));
1069 block1.packet[7] = reinterpret_cast<Packet8us>(
1070 vec_mergel(reinterpret_cast<Packet2ul>(v1[5]), reinterpret_cast<Packet2ul>(v1[7])));
1071#else
1072 block1.packet[0] = reinterpret_cast<Packet8us>(vec_perm(v1[0], v1[2], p16uc_TRANSPOSE64_HI));
1073 block1.packet[2] = reinterpret_cast<Packet8us>(vec_perm(v1[0], v1[2], p16uc_TRANSPOSE64_LO));
1074 block1.packet[4] = reinterpret_cast<Packet8us>(vec_perm(v1[1], v1[3], p16uc_TRANSPOSE64_HI));
1075 block1.packet[6] = reinterpret_cast<Packet8us>(vec_perm(v1[1], v1[3], p16uc_TRANSPOSE64_LO));
1076 block1.packet[1] = reinterpret_cast<Packet8us>(vec_perm(v1[4], v1[6], p16uc_TRANSPOSE64_HI));
1077 block1.packet[3] = reinterpret_cast<Packet8us>(vec_perm(v1[4], v1[6], p16uc_TRANSPOSE64_LO));
1078 block1.packet[5] = reinterpret_cast<Packet8us>(vec_perm(v1[5], v1[7], p16uc_TRANSPOSE64_HI));
1079 block1.packet[7] = reinterpret_cast<Packet8us>(vec_perm(v1[5], v1[7], p16uc_TRANSPOSE64_LO));
1080#endif
1081
1082 for (Index M = 0; M < 8; M++) {
1083 pstore<bfloat16>(blockA + ri + (vectorSize * M), block1.packet[M]);
1084 }
1085
1086 ri += vectorSize * vectorSize;
1087 }
1088 for (; i + 2 <= depth; i += 2) {
1089 for (Index M = 0; M < vectorSize; M++) {
1090 blockA[ri + (M * 2) + 0] = lhs2(M, i + 0);
1091 blockA[ri + (M * 2) + 1] = lhs2(M, i + 1);
1092 }
1093
1094 ri += 2 * vectorSize;
1095 }
1096 if (depth & 1) {
1097 for (Index M = 0; M < vectorSize; M++) {
1098 blockA[ri + M] = lhs2(M, i);
1099 }
1100
1101 ri += vectorSize;
1102 }
1103 }
1104
1105 if (PanelMode) ri += vectorSize * (stride - offset - depth);
1106 }
1107 if (j + 4 <= rows) {
1108 const DataMapper lhs2 = lhs.getSubMapper(j, 0);
1109 Index i = 0;
1110
1111 if (PanelMode) ri += 4 * offset;
1112
1113 for (; i + 2 <= depth; i += 2) {
1114 if (StorageOrder == ColMajor) {
1115 PacketBlock<Packet8bf, 2> block;
1116
1117 block.packet[0] = lhs2.template loadPacketPartial<Packet8bf>(0, i + 0, 4);
1118 block.packet[1] = lhs2.template loadPacketPartial<Packet8bf>(0, i + 1, 4);
1119
1120 block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
1121
1122 pstore<bfloat16>(blockA + ri, block.packet[0]);
1123 } else {
1124 blockA[ri + 0] = lhs2(0, i + 0);
1125 blockA[ri + 1] = lhs2(0, i + 1);
1126 blockA[ri + 2] = lhs2(1, i + 0);
1127 blockA[ri + 3] = lhs2(1, i + 1);
1128 blockA[ri + 4] = lhs2(2, i + 0);
1129 blockA[ri + 5] = lhs2(2, i + 1);
1130 blockA[ri + 6] = lhs2(3, i + 0);
1131 blockA[ri + 7] = lhs2(3, i + 1);
1132 }
1133
1134 ri += 2 * 4;
1135 }
1136 if (depth & 1) {
1137 if (StorageOrder == ColMajor) {
1138 Packet8bf lhsV = lhs2.template loadPacketPartial<Packet8bf>(0, i + 0, 4);
1139
1140 pstore_partial<bfloat16>(blockA + ri, lhsV, 4);
1141 } else {
1142 blockA[ri + 0] = lhs2(0, i);
1143 blockA[ri + 1] = lhs2(1, i);
1144 blockA[ri + 2] = lhs2(2, i);
1145 blockA[ri + 3] = lhs2(3, i);
1146 }
1147
1148 ri += 4;
1149 }
1150
1151 if (PanelMode) ri += 4 * (stride - offset - depth);
1152 j += 4;
1153 }
1154
1155 if (j < rows) {
1156 if (PanelMode) ri += offset * (rows - j);
1157
1158 Index i = 0;
1159 for (; i + 2 <= depth; i += 2) {
1160 Index k = j;
1161 for (; k < rows; k++) {
1162 blockA[ri + 0] = lhs(k, i + 0);
1163 blockA[ri + 1] = lhs(k, i + 1);
1164 ri += 2;
1165 }
1166 }
1167 if (depth & 1) {
1168 for (; j < rows; j++) {
1169 blockA[ri] = lhs(j, i);
1170 ri += 1;
1171 }
1172 }
1173 }
1174 }
1175};
1176
1177// General template for rhs packing, bfloat16 specialization.
1178template <typename DataMapper, int StorageOrder, bool PanelMode>
1179struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, false> {
1180 EIGEN_STRONG_INLINE void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride,
1181 Index offset) {
1182 const Index vectorSize = quad_traits<bfloat16>::vectorsize;
1183 Index ri = 0, j = 0;
1184
1185 for (; j + 4 <= cols; j += 4) {
1186 const DataMapper rhs2 = rhs.getSubMapper(0, j);
1187 Index i = 0;
1188
1189 if (PanelMode) ri += 4 * offset;
1190
1191 for (; i + vectorSize <= depth; i += vectorSize) {
1192 if (StorageOrder == ColMajor) {
1193 PacketBlock<Packet8bf, 4> block;
1194
1195 bload<DataMapper, Packet8bf, 4, StorageOrder, false, 4>(block, rhs2, i, 0);
1196
1197 Packet4ui t0, t1, t2, t3;
1198
1199 t0 = vec_mergeh(reinterpret_cast<Packet4ui>(block.packet[0].m_val),
1200 reinterpret_cast<Packet4ui>(block.packet[1].m_val));
1201 t1 = vec_mergel(reinterpret_cast<Packet4ui>(block.packet[0].m_val),
1202 reinterpret_cast<Packet4ui>(block.packet[1].m_val));
1203 t2 = vec_mergeh(reinterpret_cast<Packet4ui>(block.packet[2].m_val),
1204 reinterpret_cast<Packet4ui>(block.packet[3].m_val));
1205 t3 = vec_mergel(reinterpret_cast<Packet4ui>(block.packet[2].m_val),
1206 reinterpret_cast<Packet4ui>(block.packet[3].m_val));
1207
1208#ifdef EIGEN_VECTORIZE_VSX
1209 block.packet[0] =
1210 reinterpret_cast<Packet8us>(vec_mergeh(reinterpret_cast<Packet2ul>(t0), reinterpret_cast<Packet2ul>(t2)));
1211 block.packet[1] =
1212 reinterpret_cast<Packet8us>(vec_mergel(reinterpret_cast<Packet2ul>(t0), reinterpret_cast<Packet2ul>(t2)));
1213 block.packet[2] =
1214 reinterpret_cast<Packet8us>(vec_mergeh(reinterpret_cast<Packet2ul>(t1), reinterpret_cast<Packet2ul>(t3)));
1215 block.packet[3] =
1216 reinterpret_cast<Packet8us>(vec_mergel(reinterpret_cast<Packet2ul>(t1), reinterpret_cast<Packet2ul>(t3)));
1217#else
1218 block.packet[0] = reinterpret_cast<Packet8us>(vec_perm(t0, t2, p16uc_TRANSPOSE64_HI));
1219 block.packet[1] = reinterpret_cast<Packet8us>(vec_perm(t0, t2, p16uc_TRANSPOSE64_LO));
1220 block.packet[2] = reinterpret_cast<Packet8us>(vec_perm(t1, t3, p16uc_TRANSPOSE64_HI));
1221 block.packet[3] = reinterpret_cast<Packet8us>(vec_perm(t1, t3, p16uc_TRANSPOSE64_LO));
1222#endif
1223
1224 storeBlock<bfloat16, Packet8bf, 4>(blockB + ri, block);
1225 } else {
1226 PacketBlock<Packet8bf, 8> block;
1227
1228 for (int M = 0; M < 8; M++) {
1229 block.packet[M] = rhs2.template loadPacketPartial<Packet8bf>(i + M, 0, 4);
1230 }
1231
1232 block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
1233 block.packet[1] = vec_mergeh(block.packet[2].m_val, block.packet[3].m_val);
1234 block.packet[2] = vec_mergeh(block.packet[4].m_val, block.packet[5].m_val);
1235 block.packet[3] = vec_mergeh(block.packet[6].m_val, block.packet[7].m_val);
1236
1237 const Index size = 16 / sizeof(bfloat16);
1238
1239 for (int M = 0; M < 4; M++) {
1240 pstore<bfloat16>(blockB + ri + (M * size), block.packet[M]);
1241 }
1242 }
1243
1244 ri += 4 * vectorSize;
1245 }
1246 for (; i + 2 <= depth; i += 2) {
1247 if (StorageOrder == ColMajor) {
1248 blockB[ri + 0] = rhs2(i + 0, 0);
1249 blockB[ri + 1] = rhs2(i + 1, 0);
1250 blockB[ri + 2] = rhs2(i + 0, 1);
1251 blockB[ri + 3] = rhs2(i + 1, 1);
1252 blockB[ri + 4] = rhs2(i + 0, 2);
1253 blockB[ri + 5] = rhs2(i + 1, 2);
1254 blockB[ri + 6] = rhs2(i + 0, 3);
1255 blockB[ri + 7] = rhs2(i + 1, 3);
1256 } else {
1257 PacketBlock<Packet8bf, 2> block;
1258
1259 for (int M = 0; M < 2; M++) {
1260 block.packet[M] = rhs2.template loadPacketPartial<Packet8bf>(i + M, 0, 4);
1261 }
1262
1263 block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
1264
1265 pstore<bfloat16>(blockB + ri, block.packet[0]);
1266 }
1267
1268 ri += 4 * 2;
1269 }
1270 if (depth & 1) {
1271 blockB[ri + 0] = rhs2(i, 0);
1272 blockB[ri + 1] = rhs2(i, 1);
1273 blockB[ri + 2] = rhs2(i, 2);
1274 blockB[ri + 3] = rhs2(i, 3);
1275
1276 ri += 4;
1277 }
1278
1279 if (PanelMode) ri += 4 * (stride - offset - depth);
1280 }
1281
1282 if (j < cols) {
1283 if (PanelMode) ri += offset * (cols - j);
1284
1285 Index i = 0;
1286 for (; i + 2 <= depth; i += 2) {
1287 Index k = j;
1288 for (; k < cols; k++) {
1289 blockB[ri + 0] = rhs(i + 0, k);
1290 blockB[ri + 1] = rhs(i + 1, k);
1291 ri += 2;
1292 }
1293 }
1294 if (depth & 1) {
1295 for (; j < cols; j++) {
1296 blockB[ri] = rhs(i, j);
1297 ri += 1;
1298 }
1299 }
1300 }
1301 }
1302};
1303
1304// General template for lhs complex packing, float64 specialization.
1305template <typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
1306struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, true> {
1307 EIGEN_ALWAYS_INLINE void dhs_ccopy(double* blockAt, const DataMapper& lhs2, Index& i, Index& rir, Index& rii,
1308 Index depth, const Index vectorSize) {
1309 PacketBlock<Packet, 2> blockr, blocki;
1310 PacketBlock<PacketC, 4> cblock;
1311
1312 for (; i + vectorSize <= depth; i += vectorSize) {
1313 if (StorageOrder == ColMajor) {
1314 cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i + 0); //[a1 a1i]
1315 cblock.packet[1] = lhs2.template loadPacket<PacketC>(0, i + 1); //[b1 b1i]
1316
1317 cblock.packet[2] = lhs2.template loadPacket<PacketC>(1, i + 0); //[a2 a2i]
1318 cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1); //[b2 b2i]
1319
1320 blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[2].v); //[a1 a2]
1321 blockr.packet[1] = vec_mergeh(cblock.packet[1].v, cblock.packet[3].v); //[b1 b2]
1322
1323 blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[2].v);
1324 blocki.packet[1] = vec_mergel(cblock.packet[1].v, cblock.packet[3].v);
1325 } else {
1326 cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i); //[a1 a1i]
1327 cblock.packet[1] = lhs2.template loadPacket<PacketC>(1, i); //[a2 a2i]
1328
1329 cblock.packet[2] = lhs2.template loadPacket<PacketC>(0, i + 1); //[b1 b1i]
1330 cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1); //[b2 b2i]
1331
1332 blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v); //[a1 a2]
1333 blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v); //[b1 b2]
1334
1335 blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
1336 blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v);
1337 }
1338
1339 if (Conjugate) {
1340 blocki.packet[0] = -blocki.packet[0];
1341 blocki.packet[1] = -blocki.packet[1];
1342 }
1343
1344 storeBlock<double, Packet, 2>(blockAt + rir, blockr);
1345 storeBlock<double, Packet, 2>(blockAt + rii, blocki);
1346
1347 rir += 2 * vectorSize;
1348 rii += 2 * vectorSize;
1349 }
1350 }
1351
1352 EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows,
1353 Index stride, Index offset) {
1354 const Index vectorSize = quad_traits<double>::vectorsize;
1355 const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
1356 Index rir = ((PanelMode) ? (vectorSize * offset) : 0), rii;
1357 double* blockAt = reinterpret_cast<double*>(blockA);
1358 Index j = 0;
1359
1360 for (; j + vectorSize <= rows; j += vectorSize) {
1361 const DataMapper lhs2 = lhs.getSubMapper(j, 0);
1362 Index i = 0;
1363
1364 rii = rir + vectorDelta;
1365
1366 dhs_ccopy(blockAt, lhs2, i, rir, rii, depth, vectorSize);
1367
1368 for (; i < depth; i++) {
1369 PacketBlock<Packet, 1> blockr, blocki;
1370 PacketBlock<PacketC, 2> cblock;
1371
1372 cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i);
1373 cblock.packet[1] = lhs2.template loadPacket<PacketC>(1, i);
1374
1375 blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v);
1376 blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
1377
1378 if (Conjugate) {
1379 blocki.packet[0] = -blocki.packet[0];
1380 }
1381
1382 pstore<double>(blockAt + rir, blockr.packet[0]);
1383 pstore<double>(blockAt + rii, blocki.packet[0]);
1384
1385 rir += vectorSize;
1386 rii += vectorSize;
1387 }
1388
1389 rir += ((PanelMode) ? (vectorSize * (2 * stride - depth)) : vectorDelta);
1390 }
1391
1392 if (j < rows) {
1393 if (PanelMode) rir += (offset * (rows - j - vectorSize));
1394 rii = rir + (((PanelMode) ? stride : depth) * (rows - j));
1395
1396 for (Index i = 0; i < depth; i++) {
1397 Index k = j;
1398 for (; k < rows; k++) {
1399 blockAt[rir] = lhs(k, i).real();
1400
1401 if (Conjugate)
1402 blockAt[rii] = -lhs(k, i).imag();
1403 else
1404 blockAt[rii] = lhs(k, i).imag();
1405
1406 rir += 1;
1407 rii += 1;
1408 }
1409 }
1410 }
1411 }
1412};
1413
1414// General template for rhs complex packing, float64 specialization.
1415template <typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
1416struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, false> {
1417 EIGEN_ALWAYS_INLINE void dhs_ccopy(double* blockBt, const DataMapper& rhs2, Index& i, Index& rir, Index& rii,
1418 Index depth, const Index vectorSize) {
1419 for (; i < depth; i++) {
1420 PacketBlock<PacketC, 4> cblock;
1421 PacketBlock<Packet, 2> blockr, blocki;
1422
1423 bload<DataMapper, PacketC, 2, ColMajor, false, 4>(cblock, rhs2, i, 0);
1424
1425 blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v);
1426 blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v);
1427
1428 blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
1429 blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v);
1430
1431 if (Conjugate) {
1432 blocki.packet[0] = -blocki.packet[0];
1433 blocki.packet[1] = -blocki.packet[1];
1434 }
1435
1436 storeBlock<double, Packet, 2>(blockBt + rir, blockr);
1437 storeBlock<double, Packet, 2>(blockBt + rii, blocki);
1438
1439 rir += 2 * vectorSize;
1440 rii += 2 * vectorSize;
1441 }
1442 }
1443
1444 EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols,
1445 Index stride, Index offset) {
1446 const Index vectorSize = quad_traits<double>::vectorsize;
1447 const Index vectorDelta = 2 * vectorSize * ((PanelMode) ? stride : depth);
1448 Index rir = ((PanelMode) ? (2 * vectorSize * offset) : 0), rii;
1449 double* blockBt = reinterpret_cast<double*>(blockB);
1450 Index j = 0;
1451
1452 for (; j + 2 * vectorSize <= cols; j += 2 * vectorSize) {
1453 const DataMapper rhs2 = rhs.getSubMapper(0, j);
1454 Index i = 0;
1455
1456 rii = rir + vectorDelta;
1457
1458 dhs_ccopy(blockBt, rhs2, i, rir, rii, depth, vectorSize);
1459
1460 rir += ((PanelMode) ? (2 * vectorSize * (2 * stride - depth)) : vectorDelta);
1461 }
1462
1463 if (PanelMode) rir -= (offset * (2 * vectorSize - 1));
1464
1465 for (; j < cols; j++) {
1466 const DataMapper rhs2 = rhs.getSubMapper(0, j);
1467 rii = rir + ((PanelMode) ? stride : depth);
1468
1469 for (Index i = 0; i < depth; i++) {
1470 blockBt[rir] = rhs2(i, 0).real();
1471
1472 if (Conjugate)
1473 blockBt[rii] = -rhs2(i, 0).imag();
1474 else
1475 blockBt[rii] = rhs2(i, 0).imag();
1476
1477 rir += 1;
1478 rii += 1;
1479 }
1480
1481 rir += ((PanelMode) ? (2 * stride - depth) : depth);
1482 }
1483 }
1484};
1485
1486/**************
1487 * GEMM utils *
1488 **************/
1489
1490// 512-bit rank-1 update of acc. It can accumulate either positively or negatively (useful for complex GEMM).
1491template <typename Packet, bool NegativeAccumulate, int N>
1492EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet, N>* acc, const Packet& lhsV, const Packet* rhsV) {
1493 if (NegativeAccumulate) {
1494 for (int M = 0; M < N; M++) {
1495 acc->packet[M] = vec_nmsub(lhsV, rhsV[M], acc->packet[M]);
1496 }
1497 } else {
1498 for (int M = 0; M < N; M++) {
1499 acc->packet[M] = vec_madd(lhsV, rhsV[M], acc->packet[M]);
1500 }
1501 }
1502}
1503
1504template <int N, typename Scalar, typename Packet, bool NegativeAccumulate>
1505EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet, N>* acc, const Scalar* lhs, const Packet* rhsV) {
1506 Packet lhsV = pload<Packet>(lhs);
1507
1508 pger_common<Packet, NegativeAccumulate, N>(acc, lhsV, rhsV);
1509}
1510
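// Editor's worked example (not part of the original file): with Packet4f and N = 4,
// pger_common updates four accumulator packets with one fused multiply-add each:
//   acc->packet[M] = lhsV * rhsV[M] + acc->packet[M]   (NegativeAccumulate == false)
//   acc->packet[M] = acc->packet[M] - lhsV * rhsV[M]   (NegativeAccumulate == true)
// pger is the same operation with lhsV loaded from memory first.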
1511// 512-bit rank-1 update of a complex acc. It takes decoupled accumulators as inputs and also takes care of the
1512// mixed types real * complex and complex * real.
1513template <int N, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
1514EIGEN_ALWAYS_INLINE void pgerc_common(PacketBlock<Packet, N>* accReal, PacketBlock<Packet, N>* accImag,
1515 const Packet& lhsV, Packet& lhsVi, const Packet* rhsV, const Packet* rhsVi) {
1516 pger_common<Packet, false, N>(accReal, lhsV, rhsV);
1517 if (LhsIsReal) {
1518 pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
1519 EIGEN_UNUSED_VARIABLE(lhsVi);
1520 } else {
1521 if (!RhsIsReal) {
1522 pger_common<Packet, ConjugateLhs == ConjugateRhs, N>(accReal, lhsVi, rhsVi);
1523 pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
1524 } else {
1525 EIGEN_UNUSED_VARIABLE(rhsVi);
1526 }
1527 pger_common<Packet, ConjugateLhs, N>(accImag, lhsVi, rhsV);
1528 }
1529}
1530
1531template <int N, typename Scalar, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
1532EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet, N>* accReal, PacketBlock<Packet, N>* accImag, const Scalar* lhs_ptr,
1533 const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi) {
1534 Packet lhsV = ploadLhs<Packet>(lhs_ptr);
1535 Packet lhsVi;
1536 if (!LhsIsReal)
1537 lhsVi = ploadLhs<Packet>(lhs_ptr_imag);
1538 else
1539 EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
1540
1541 pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
1542}
1543
1544template <typename Packet>
1545EIGEN_ALWAYS_INLINE Packet ploadLhs(const __UNPACK_TYPE__(Packet) * lhs) {
1546 return ploadu<Packet>(lhs);
1547}
1548
1549// Zero the accumulators of a PacketBlock.
1550template <typename Packet, int N>
1551EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet, N>& acc) {
1552 for (int M = 0; M < N; M++) {
1553 acc.packet[M] = pset1<Packet>((__UNPACK_TYPE__(Packet))0);
1554 }
1555}
1556
1557template <typename Packet, int N>
1558EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet, N>& acc, PacketBlock<Packet, N>& accZ,
1559 const Packet& pAlpha) {
1560 for (int M = 0; M < N; M++) {
1561 acc.packet[M] = vec_mul(accZ.packet[M], pAlpha);
1562 }
1563}
1564
1565template <typename Packet, int N>
1566EIGEN_ALWAYS_INLINE void band(PacketBlock<Packet, N>& acc, const Packet& pMask) {
1567 for (int M = 0; M < N; M++) {
1568 acc.packet[M] = pand<Packet>(acc.packet[M], pMask);
1569 }
1570}
1571
1572// Complex version of PacketBlock scaling.
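// In effect, with alpha = bReal + i*bImag and accumulator a = aReal + i*aImag:
//   cReal = aReal * bReal - aImag * bImag
//   cImag = aImag * bReal + aReal * bImag
// with the partial-row mask applied to aReal/aImag beforehand (float path only).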
1573template <typename Packet, int N, bool mask>
1574EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet, N>& aReal, PacketBlock<Packet, N>& aImag, const Packet& bReal,
1575 const Packet& bImag, PacketBlock<Packet, N>& cReal, PacketBlock<Packet, N>& cImag,
1576 const Packet& pMask) {
1577 if (mask && (sizeof(__UNPACK_TYPE__(Packet)) == sizeof(float))) {
1578 band<Packet, N>(aReal, pMask);
1579 band<Packet, N>(aImag, pMask);
1580 } else {
1581 EIGEN_UNUSED_VARIABLE(pMask);
1582 }
1583
1584 bscalec_common<Packet, N>(cReal, aReal, bReal);
1585
1586 bscalec_common<Packet, N>(cImag, aImag, bReal);
1587
1588 pger_common<Packet, true, N>(&cReal, bImag, aImag.packet);
1589
1590 pger_common<Packet, false, N>(&cImag, bImag, aReal.packet);
1591}
1592
1593// Load a PacketBlock. The N parameter makes tuning GEMM easier, so we can add more accumulators as needed.
1594//
1595// full = whether to operate (load) on the entire PacketBlock or only on half of it
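// For the complex case the second half of the block (offset by accCols rows or columns,
// depending on the storage order) is loaded into packets N..2N-1, matching the acc2 half
// later combined in bcouple.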
1596template <typename DataMapper, typename Packet, const Index accCols, int StorageOrder, bool Complex, int N, bool full>
1597EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet, N*(Complex ? 2 : 1)>& acc, const DataMapper& res, Index row,
1598 Index col) {
1599 if (StorageOrder == RowMajor) {
1600 for (int M = 0; M < N; M++) {
1601 acc.packet[M] = res.template loadPacket<Packet>(row + M, col);
1602 }
1603 if (Complex) {
1604 for (int M = 0; M < N; M++) {
1605 acc.packet[M + N] = res.template loadPacket<Packet>(row + M, col + accCols);
1606 }
1607 }
1608 } else {
1609 for (int M = 0; M < N; M++) {
1610 acc.packet[M] = res.template loadPacket<Packet>(row, col + M);
1611 }
1612 if (Complex && full) {
1613 for (int M = 0; M < N; M++) {
1614 acc.packet[M + N] = res.template loadPacket<Packet>(row + accCols, col + M);
1615 }
1616 }
1617 }
1618}
1619
1620template <typename DataMapper, typename Packet, int N>
1621EIGEN_ALWAYS_INLINE void bstore(PacketBlock<Packet, N>& acc, const DataMapper& res, Index row) {
1622 for (int M = 0; M < N; M++) {
1623 res.template storePacket<Packet>(row, M, acc.packet[M]);
1624 }
1625}
1626
1627#ifdef USE_PARTIAL_PACKETS
1628template <typename DataMapper, typename Packet, const Index accCols, bool Complex, Index N, bool full>
1629EIGEN_ALWAYS_INLINE void bload_partial(PacketBlock<Packet, N*(Complex ? 2 : 1)>& acc, const DataMapper& res, Index row,
1630 Index elements) {
1631 for (Index M = 0; M < N; M++) {
1632 acc.packet[M] = res.template loadPacketPartial<Packet>(row, M, elements);
1633 }
1634 if (Complex && full) {
1635 for (Index M = 0; M < N; M++) {
1636 acc.packet[M + N] = res.template loadPacketPartial<Packet>(row + accCols, M, elements);
1637 }
1638 }
1639}
1640
1641template <typename DataMapper, typename Packet, Index N>
1642EIGEN_ALWAYS_INLINE void bstore_partial(PacketBlock<Packet, N>& acc, const DataMapper& res, Index row, Index elements) {
1643 for (Index M = 0; M < N; M++) {
1644 res.template storePacketPartial<Packet>(row, M, acc.packet[M], elements);
1645 }
1646}
1647#endif
1648
1649#ifdef _ARCH_PWR10
1650#define USE_P10_AND_PVIPR2_0 (EIGEN_COMP_LLVM || (__GNUC__ >= 11))
1651#else
1652#define USE_P10_AND_PVIPR2_0 0
1653#endif
1654
1655#if !USE_P10_AND_PVIPR2_0
1656const static Packet4i mask4[4] = {{0, 0, 0, 0}, {-1, 0, 0, 0}, {-1, -1, 0, 0}, {-1, -1, -1, 0}};
1657#endif
1658
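// bmask builds a blend mask whose first remaining_rows lanes are all-ones (via
// vec_genwm/vec_gendm on Power10, otherwise from the mask4 table for floats or from
// -remaining_rows for doubles); band() then clears the accumulator lanes that fall
// outside the partial row block.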
1659template <typename Packet>
1660EIGEN_ALWAYS_INLINE Packet bmask(const Index remaining_rows) {
1661#if USE_P10_AND_PVIPR2_0
1662#ifdef _BIG_ENDIAN
1663 return Packet(vec_reve(vec_genwm((1 << remaining_rows) - 1)));
1664#else
1665 return Packet(vec_genwm((1 << remaining_rows) - 1));
1666#endif
1667#else
1668 return Packet(mask4[remaining_rows]);
1669#endif
1670}
1671
1672template <>
1673EIGEN_ALWAYS_INLINE Packet2d bmask<Packet2d>(const Index remaining_rows) {
1674#if USE_P10_AND_PVIPR2_0
1675 Packet2d mask2 = Packet2d(vec_gendm(remaining_rows));
1676#ifdef _BIG_ENDIAN
1677 return preverse(mask2);
1678#else
1679 return mask2;
1680#endif
1681#else
1682 Packet2l ret = {-remaining_rows, 0};
1683 return Packet2d(ret);
1684#endif
1685}
1686
1687template <typename Packet, int N>
1688EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet, N>& acc, PacketBlock<Packet, N>& accZ, const Packet& pAlpha) {
1689 for (int M = 0; M < N; M++) {
1690 acc.packet[M] = pmadd<Packet>(pAlpha, accZ.packet[M], acc.packet[M]);
1691 }
1692}
1693
1694// Scale the PacketBlock vectors by alpha.
1695template <typename Packet, int N, bool mask>
1696EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet, N>& acc, PacketBlock<Packet, N>& accZ, const Packet& pAlpha,
1697 const Packet& pMask) {
1698 if (mask) {
1699 band<Packet, N>(accZ, pMask);
1700 } else {
1701 EIGEN_UNUSED_VARIABLE(pMask);
1702 }
1703
1704 bscale<Packet, N>(acc, accZ, pAlpha);
1705}
1706
1707template <typename Packet, int N, bool real>
1708EIGEN_ALWAYS_INLINE void pbroadcastN(const __UNPACK_TYPE__(Packet) * ap0, const __UNPACK_TYPE__(Packet) * ap1,
1709 const __UNPACK_TYPE__(Packet) * ap2, Packet& a0, Packet& a1, Packet& a2,
1710 Packet& a3) {
1711 a0 = pset1<Packet>(ap0[0]);
1712 if (N == 4) {
1713 a1 = pset1<Packet>(ap0[1]);
1714 a2 = pset1<Packet>(ap0[2]);
1715 a3 = pset1<Packet>(ap0[3]);
1716 EIGEN_UNUSED_VARIABLE(ap1);
1717 EIGEN_UNUSED_VARIABLE(ap2);
1718 } else {
1719 if (N > 1) {
1720 a1 = pset1<Packet>(ap1[0]);
1721 } else {
1722 EIGEN_UNUSED_VARIABLE(a1);
1723 EIGEN_UNUSED_VARIABLE(ap1);
1724 }
1725 if (N > 2) {
1726 a2 = pset1<Packet>(ap2[0]);
1727 } else {
1728 EIGEN_UNUSED_VARIABLE(a2);
1729 EIGEN_UNUSED_VARIABLE(ap2);
1730 }
1731 }
1732}
1733
1734template <>
1735EIGEN_ALWAYS_INLINE void pbroadcastN<Packet4f, 4, true>(const float* ap0, const float*, const float*, Packet4f& a0,
1736 Packet4f& a1, Packet4f& a2, Packet4f& a3) {
1737 pbroadcast4<Packet4f>(ap0, a0, a1, a2, a3);
1738}
1739
1740template <>
1741EIGEN_ALWAYS_INLINE void pbroadcastN<Packet4f, 4, false>(const float* ap0, const float* ap1, const float* ap2,
1742 Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3) {
1743 pbroadcastN<Packet4f, 4, true>(ap0, ap1, ap2, a0, a1, a2, a3);
1744}
1745
1746template <>
1747EIGEN_ALWAYS_INLINE void pbroadcastN<Packet2d, 4, false>(const double* ap0, const double*, const double*, Packet2d& a0,
1748 Packet2d& a1, Packet2d& a2, Packet2d& a3) {
1749 a1 = pload<Packet2d>(ap0);
1750 a3 = pload<Packet2d>(ap0 + 2);
1751 a0 = vec_splat(a1, 0);
1752 a1 = vec_splat(a1, 1);
1753 a2 = vec_splat(a3, 0);
1754 a3 = vec_splat(a3, 1);
1755}
1756
1757// Take two decoupled real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks.
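// Conceptually, taccReal = {r0, r1, r2, r3} and taccImag = {i0, i1, i2, i3} are merged
// into the coupled packets {r0, i0, r1, i1} (acc1) and {r2, i2, r3, i3} (acc2); the
// second merge is skipped when 'full' is false and only half of the block is needed.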
1758template <typename Packet, typename Packetc, int N, bool full>
1759EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock<Packet, N>& taccReal, PacketBlock<Packet, N>& taccImag,
1760 PacketBlock<Packetc, N>& acc1, PacketBlock<Packetc, N>& acc2) {
1761 for (int M = 0; M < N; M++) {
1762 acc1.packet[M].v = vec_mergeh(taccReal.packet[M], taccImag.packet[M]);
1763 }
1764
1765 if (full) {
1766 for (int M = 0; M < N; M++) {
1767 acc2.packet[M].v = vec_mergel(taccReal.packet[M], taccImag.packet[M]);
1768 }
1769 }
1770}
1771
1772template <typename Packet, typename Packetc, int N, bool full>
1773EIGEN_ALWAYS_INLINE void bcouple(PacketBlock<Packet, N>& taccReal, PacketBlock<Packet, N>& taccImag,
1774 PacketBlock<Packetc, N * 2>& tRes, PacketBlock<Packetc, N>& acc1,
1775 PacketBlock<Packetc, N>& acc2) {
1776 bcouple_common<Packet, Packetc, N, full>(taccReal, taccImag, acc1, acc2);
1777
1778 for (int M = 0; M < N; M++) {
1779 acc1.packet[M] = padd<Packetc>(tRes.packet[M], acc1.packet[M]);
1780 }
1781
1782 if (full) {
1783 for (int M = 0; M < N; M++) {
1784 acc2.packet[M] = padd<Packetc>(tRes.packet[M + N], acc2.packet[M]);
1785 }
1786 }
1787}
1788
1789// PEEL loop factor.
1790#define PEEL 7
1791#define PEEL_ROW 7
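// PEEL / PEEL_ROW control how many depth (k) iterations are unrolled per pass of the main
// and remaining-row GEMM loops; the MICRO_* macros expose up to 8 peel slots, each guarded
// by an "if (PEEL > peel)" / "if (PEEL_ROW > peel)" test so that unused slots compile away.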
1792
1793#define MICRO_UNROLL(func) func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
1794
1795#define MICRO_NORMAL_ROWS accRows == quad_traits<Scalar>::rows || accRows == 1
1796
1797#define MICRO_NEW_ROWS ((MICRO_NORMAL_ROWS) ? accRows : 1)
1798
1799#define MICRO_RHS(ptr, N) rhs_##ptr##N
1800
1801#define MICRO_ZERO_PEEL(peel) \
1802 if ((PEEL_ROW > peel) && (peel != 0)) { \
1803 bsetzero<Packet, accRows>(accZero##peel); \
1804 } else { \
1805 EIGEN_UNUSED_VARIABLE(accZero##peel); \
1806 }
1807
1808#define MICRO_ADD(ptr, N) \
1809 if (MICRO_NORMAL_ROWS) { \
1810 MICRO_RHS(ptr, 0) += (accRows * N); \
1811 } else { \
1812 MICRO_RHS(ptr, 0) += N; \
1813 MICRO_RHS(ptr, 1) += N; \
1814 if (accRows == 3) { \
1815 MICRO_RHS(ptr, 2) += N; \
1816 } \
1817 }
1818
1819#define MICRO_ADD_ROWS(N) MICRO_ADD(ptr, N)
1820
1821#define MICRO_BROADCAST1(peel, ptr, rhsV, real) \
1822 if (MICRO_NORMAL_ROWS) { \
1823 pbroadcastN<Packet, accRows, real>(MICRO_RHS(ptr, 0) + (accRows * peel), MICRO_RHS(ptr, 0), MICRO_RHS(ptr, 0), \
1824 rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
1825 } else { \
1826 pbroadcastN<Packet, accRows, real>(MICRO_RHS(ptr, 0) + peel, MICRO_RHS(ptr, 1) + peel, MICRO_RHS(ptr, 2) + peel, \
1827 rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
1828 }
1829
1830#define MICRO_BROADCAST(peel) MICRO_BROADCAST1(peel, ptr, rhsV, true)
1831
1832#define MICRO_BROADCAST_EXTRA1(ptr, rhsV, real) \
1833 pbroadcastN<Packet, accRows, real>(MICRO_RHS(ptr, 0), MICRO_RHS(ptr, 1), MICRO_RHS(ptr, 2), rhsV[0], rhsV[1], \
1834 rhsV[2], rhsV[3]);
1835
1836#define MICRO_BROADCAST_EXTRA \
1837 Packet rhsV[4]; \
1838 MICRO_BROADCAST_EXTRA1(ptr, rhsV, true) \
1839 MICRO_ADD_ROWS(1)
1840
1841#define MICRO_SRC2(ptr, N, M) \
1842 if (MICRO_NORMAL_ROWS) { \
1843 EIGEN_UNUSED_VARIABLE(strideB); \
1844 EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr, 1)); \
1845 EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr, 2)); \
1846 } else { \
1847 MICRO_RHS(ptr, 1) = rhs_base + N + M; \
1848 if (accRows == 3) { \
1849 MICRO_RHS(ptr, 2) = rhs_base + N * 2 + M; \
1850 } else { \
1851 EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr, 2)); \
1852 } \
1853 }
1854
1855#define MICRO_SRC2_PTR MICRO_SRC2(ptr, strideB, 0)
1856
1857#define MICRO_ZERO_PEEL_ROW MICRO_UNROLL(MICRO_ZERO_PEEL)
1858
1859#define MICRO_WORK_PEEL(peel) \
1860 if (PEEL_ROW > peel) { \
1861 MICRO_BROADCAST(peel) \
1862 pger<accRows, Scalar, Packet, false>(&accZero##peel, lhs_ptr + (remaining_rows * peel), rhsV##peel); \
1863 } else { \
1864 EIGEN_UNUSED_VARIABLE(rhsV##peel); \
1865 }
1866
1867#define MICRO_WORK_PEEL_ROW \
1868 Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4], rhsV4[4], rhsV5[4], rhsV6[4], rhsV7[4]; \
1869 MICRO_UNROLL(MICRO_WORK_PEEL) \
1870 lhs_ptr += (remaining_rows * PEEL_ROW); \
1871 MICRO_ADD_ROWS(PEEL_ROW)
1872
1873#define MICRO_ADD_PEEL(peel, sum) \
1874 if (PEEL_ROW > peel) { \
1875 for (Index i = 0; i < accRows; i++) { \
1876 accZero##sum.packet[i] += accZero##peel.packet[i]; \
1877 } \
1878 }
1879
1880#define MICRO_ADD_PEEL_ROW \
1881 MICRO_ADD_PEEL(4, 0) \
1882 MICRO_ADD_PEEL(5, 1) \
1883 MICRO_ADD_PEEL(6, 2) MICRO_ADD_PEEL(7, 3) MICRO_ADD_PEEL(2, 0) MICRO_ADD_PEEL(3, 1) MICRO_ADD_PEEL(1, 0)
1884
1885#define MICRO_PREFETCHN1(ptr, N) \
1886 EIGEN_POWER_PREFETCH(MICRO_RHS(ptr, 0)); \
1887 if (N == 2 || N == 3) { \
1888 EIGEN_POWER_PREFETCH(MICRO_RHS(ptr, 1)); \
1889 if (N == 3) { \
1890 EIGEN_POWER_PREFETCH(MICRO_RHS(ptr, 2)); \
1891 } \
1892 }
1893
1894#define MICRO_PREFETCHN(N) MICRO_PREFETCHN1(ptr, N)
1895
1896#define MICRO_COMPLEX_PREFETCHN(N) \
1897 MICRO_PREFETCHN1(ptr_real, N); \
1898 if (!RhsIsReal) { \
1899 MICRO_PREFETCHN1(ptr_imag, N); \
1900 }
1901
1902template <typename Scalar, typename Packet, const Index accRows, const Index remaining_rows>
1903EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW(const Scalar*& lhs_ptr, const Scalar*& rhs_ptr0, const Scalar*& rhs_ptr1,
1904 const Scalar*& rhs_ptr2, PacketBlock<Packet, accRows>& accZero) {
1905 MICRO_BROADCAST_EXTRA
1906 pger<accRows, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
1907 lhs_ptr += remaining_rows;
1908}
1909
1910template <typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols,
1911 const Index remaining_rows>
1912EIGEN_ALWAYS_INLINE void gemm_unrolled_row_iteration(const DataMapper& res, const Scalar* lhs_base,
1913 const Scalar* rhs_base, Index depth, Index strideA, Index offsetA,
1914 Index strideB, Index row, Index rows, const Packet& pAlpha,
1915 const Packet& pMask) {
1916 const Scalar *rhs_ptr0 = rhs_base, *rhs_ptr1 = NULL, *rhs_ptr2 = NULL;
1917 const Scalar* lhs_ptr = lhs_base + row * strideA + remaining_rows * offsetA;
1918 PacketBlock<Packet, accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7, acc;
1919
1920 MICRO_SRC2_PTR
1921 bsetzero<Packet, accRows>(accZero0);
1922
1923 Index remaining_depth = depth & -quad_traits<Scalar>::rows;
1924 Index k = 0;
1925 if (remaining_depth >= PEEL_ROW) {
1926 MICRO_ZERO_PEEL_ROW
1927 do {
1928 MICRO_PREFETCHN(accRows)
1929 EIGEN_POWER_PREFETCH(lhs_ptr);
1930 MICRO_WORK_PEEL_ROW
1931 } while ((k += PEEL_ROW) + PEEL_ROW <= remaining_depth);
1932 MICRO_ADD_PEEL_ROW
1933 }
1934 for (; k < depth; k++) {
1935 MICRO_EXTRA_ROW<Scalar, Packet, accRows, remaining_rows>(lhs_ptr, rhs_ptr0, rhs_ptr1, rhs_ptr2, accZero0);
1936 }
1937
1938#ifdef USE_PARTIAL_PACKETS
1939 EIGEN_UNUSED_VARIABLE(rows);
1940 EIGEN_UNUSED_VARIABLE(pMask);
1941 bload_partial<DataMapper, Packet, 0, false, accRows>(acc, res, row, remaining_rows);
1942 bscale<Packet, accRows>(acc, accZero0, pAlpha);
1943 bstore_partial<DataMapper, Packet, accRows>(acc, res, row, remaining_rows);
1944#else
1945 bload<DataMapper, Packet, 0, ColMajor, false, accRows>(acc, res, row, 0);
1946 if ((accRows == 1) || (rows >= accCols)) {
1947 bscale<Packet, accRows, true>(acc, accZero0, pAlpha, pMask);
1948 bstore<DataMapper, Packet, accRows>(acc, res, row);
1949 } else {
1950 bscale<Packet, accRows, false>(acc, accZero0, pAlpha, pMask);
1951 for (Index j = 0; j < accRows; j++) {
1952 for (Index i = 0; i < remaining_rows; i++) {
1953 res(row + i, j) = acc.packet[j][i];
1954 }
1955 }
1956 }
1957#endif
1958}
1959
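// MICRO_EXTRA dispatches the leftover handlers on the runtime remainder (1..3 rows or
// columns). Row remainders of 2 or 3 only occur for 4-lane (float-sized) scalars, hence
// the sizeof(Scalar) guard when is_col is false.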
1960#define MICRO_EXTRA(MICRO_EXTRA_UNROLL, value, is_col) \
1961 switch (value) { \
1962 default: \
1963 MICRO_EXTRA_UNROLL(1) \
1964 break; \
1965 case 2: \
1966 if (is_col || (sizeof(Scalar) == sizeof(float))) { \
1967 MICRO_EXTRA_UNROLL(2) \
1968 } \
1969 break; \
1970 case 3: \
1971 if (is_col || (sizeof(Scalar) == sizeof(float))) { \
1972 MICRO_EXTRA_UNROLL(3) \
1973 } \
1974 break; \
1975 }
1976
1977#define MICRO_EXTRA_ROWS(N) \
1978 gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, accRows, accCols, N>( \
1979 res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, pAlpha, pMask);
1980
1981template <typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols>
1982EIGEN_ALWAYS_INLINE void gemm_extra_row(const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base,
1983 Index depth, Index strideA, Index offsetA, Index strideB, Index row, Index rows,
1984 Index remaining_rows, const Packet& pAlpha, const Packet& pMask) {
1985 MICRO_EXTRA(MICRO_EXTRA_ROWS, remaining_rows, false)
1986}
1987
1988#define MICRO_UNROLL_WORK(func, func2, peel) \
1989 MICRO_UNROLL(func2); \
1990 func(0, peel) func(1, peel) func(2, peel) func(3, peel) func(4, peel) func(5, peel) func(6, peel) func(7, peel)
1991
1992#define MICRO_WORK_ONE(iter, peel) \
1993 if (unroll_factor > iter) { \
1994 pger_common<Packet, false, accRows>(&accZero##iter, lhsV##iter, rhsV##peel); \
1995 }
1996
1997#define MICRO_TYPE_PEEL4(func, func2, peel) \
1998 if (PEEL > peel) { \
1999 Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
2000 MICRO_BROADCAST(peel) \
2001 MICRO_UNROLL_WORK(func, func2, peel) \
2002 } else { \
2003 EIGEN_UNUSED_VARIABLE(rhsV##peel); \
2004 }
2005
2006#define MICRO_UNROLL_TYPE_PEEL(M, func, func1, func2) \
2007 Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M]; \
2008 func(func1, func2, 0) func(func1, func2, 1) func(func1, func2, 2) func(func1, func2, 3) func(func1, func2, 4) \
2009 func(func1, func2, 5) func(func1, func2, 6) func(func1, func2, 7)
2010
2011#define MICRO_UNROLL_TYPE_ONE(M, func, func1, func2) \
2012 Packet rhsV0[M]; \
2013 func(func1, func2, 0)
2014
2015#define MICRO_UNROLL_TYPE(MICRO_TYPE, size) \
2016 MICRO_TYPE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE) \
2017 MICRO_ADD_ROWS(size)
2018
2019#define MICRO_ONE_PEEL4 MICRO_UNROLL_TYPE(MICRO_UNROLL_TYPE_PEEL, PEEL)
2020
2021#define MICRO_ONE4 MICRO_UNROLL_TYPE(MICRO_UNROLL_TYPE_ONE, 1)
2022
2023#define MICRO_DST_PTR_ONE(iter) \
2024 if (unroll_factor > iter) { \
2025 bsetzero<Packet, accRows>(accZero##iter); \
2026 } else { \
2027 EIGEN_UNUSED_VARIABLE(accZero##iter); \
2028 }
2029
2030#define MICRO_DST_PTR MICRO_UNROLL(MICRO_DST_PTR_ONE)
2031
2032#define MICRO_SRC_PTR MICRO_UNROLL(MICRO_SRC_PTR_ONE)
2033
2034#define MICRO_PREFETCH MICRO_UNROLL(MICRO_PREFETCH_ONE)
2035
2036#ifdef USE_PARTIAL_PACKETS
2037#define MICRO_STORE_ONE(iter) \
2038 if (unroll_factor > iter) { \
2039 if (MICRO_NORMAL_PARTIAL(iter)) { \
2040 bload<DataMapper, Packet, 0, ColMajor, false, accRows>(acc, res, row + iter * accCols, 0); \
2041 bscale<Packet, accRows>(acc, accZero##iter, pAlpha); \
2042 bstore<DataMapper, Packet, accRows>(acc, res, row + iter * accCols); \
2043 } else { \
2044 bload_partial<DataMapper, Packet, 0, false, accRows>(acc, res, row + iter * accCols, accCols2); \
2045 bscale<Packet, accRows>(acc, accZero##iter, pAlpha); \
2046 bstore_partial<DataMapper, Packet, accRows>(acc, res, row + iter * accCols, accCols2); \
2047 } \
2048 }
2049#else
2050#define MICRO_STORE_ONE(iter) \
2051 if (unroll_factor > iter) { \
2052 bload<DataMapper, Packet, 0, ColMajor, false, accRows>(acc, res, row + iter * accCols, 0); \
2053 bscale<Packet, accRows, !(MICRO_NORMAL(iter))>(acc, accZero##iter, pAlpha, pMask); \
2054 bstore<DataMapper, Packet, accRows>(acc, res, row + iter * accCols); \
2055 }
2056#endif
2057
2058#define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE)
2059
2060#ifdef USE_PARTIAL_PACKETS
2061template <int unroll_factor, typename Scalar, typename Packet, typename DataMapper, const Index accRows,
2062 const Index accCols, bool full>
2063#else
2064template <int unroll_factor, typename Scalar, typename Packet, typename DataMapper, const Index accRows,
2065 const Index accCols, const Index accCols2>
2066#endif
2067EIGEN_ALWAYS_INLINE void gemm_unrolled_iteration(const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base,
2068 Index depth, Index strideA, Index offsetA, Index strideB, Index& row,
2069 const Packet& pAlpha,
2070#ifdef USE_PARTIAL_PACKETS
2071 Index accCols2
2072#else
2073 const Packet& pMask
2074#endif
2075) {
2076 const Scalar *rhs_ptr0 = rhs_base, *rhs_ptr1 = NULL, *rhs_ptr2 = NULL;
2077 const Scalar *lhs_ptr0 = NULL, *lhs_ptr1 = NULL, *lhs_ptr2 = NULL, *lhs_ptr3 = NULL, *lhs_ptr4 = NULL,
2078 *lhs_ptr5 = NULL, *lhs_ptr6 = NULL, *lhs_ptr7 = NULL;
2079 PacketBlock<Packet, accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
2080 PacketBlock<Packet, accRows> acc;
2081
2082 MICRO_SRC2_PTR
2083 MICRO_SRC_PTR
2084 MICRO_DST_PTR
2085
2086 Index k = 0;
2087 for (; k + PEEL <= depth; k += PEEL) {
2088 MICRO_PREFETCHN(accRows)
2089 MICRO_PREFETCH
2090 MICRO_ONE_PEEL4
2091 }
2092 for (; k < depth; k++) {
2093 MICRO_ONE4
2094 }
2095 MICRO_STORE
2096
2097 MICRO_UPDATE
2098}
2099
2100#ifdef USE_PARTIAL_PACKETS
2101#define MICRO_UNROLL_ITER2(N, M) \
2102 gemm_unrolled_iteration<N + ((M) ? 1 : 0), Scalar, Packet, DataMapper, accRows, accCols, !M>( \
2103 res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlpha, M ? remaining_rows : accCols); \
2104 if (M) return;
2105#else
2106#define MICRO_UNROLL_ITER2(N, M) \
2107 gemm_unrolled_iteration<N + ((M) ? 1 : 0), Scalar, Packet, DataMapper, accRows, accCols, M ? M : accCols>( \
2108 res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlpha, pMask); \
2109 if (M) return;
2110#endif
2111
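// gemm_cols handles one panel of accRows columns: rows are first consumed in blocks of
// MAX_UNROLL * accCols with the fully unrolled iteration, the remaining whole accCols
// blocks go through the switch below, and any leftover rows (< accCols) are passed to
// gemm_extra_row.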
2112template <typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols>
2113EIGEN_ALWAYS_INLINE void gemm_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index depth,
2114 Index strideA, Index offsetA, Index strideB, Index offsetB, Index col, Index rows,
2115 Index remaining_rows, const Packet& pAlpha, const Packet& pMask) {
2116 const DataMapper res3 = res.getSubMapper(0, col);
2117
2118 const Scalar* rhs_base = blockB + col * strideB + MICRO_NEW_ROWS * offsetB;
2119 const Scalar* lhs_base = blockA + accCols * offsetA;
2120 Index row = 0;
2121
2122#define MAX_UNROLL 7
2123 while (row + MAX_UNROLL * accCols <= rows) {
2124 MICRO_UNROLL_ITER2(MAX_UNROLL, 0);
2125 }
2126 switch ((rows - row) / accCols) {
2127#if MAX_UNROLL > 7
2128 case 7:
2129 MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 7)
2130 break;
2131#endif
2132#if MAX_UNROLL > 6
2133 case 6:
2134 MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 6)
2135 break;
2136#endif
2137#if MAX_UNROLL > 5
2138 case 5:
2139 MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 5)
2140 break;
2141#endif
2142#if MAX_UNROLL > 4
2143 case 4:
2144 MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 4)
2145 break;
2146#endif
2147#if MAX_UNROLL > 3
2148 case 3:
2149 MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 3)
2150 break;
2151#endif
2152#if MAX_UNROLL > 2
2153 case 2:
2154 MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 2)
2155 break;
2156#endif
2157#if MAX_UNROLL > 1
2158 case 1:
2159 MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 1)
2160 break;
2161#endif
2162 default:
2163 break;
2164 }
2165#undef MAX_UNROLL
2166
2167 if (remaining_rows > 0) {
2168 gemm_extra_row<Scalar, Packet, DataMapper, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA,
2169 strideB, row, rows, remaining_rows, pAlpha, pMask);
2170 }
2171}
2172
2173#define MICRO_EXTRA_COLS(N) \
2174 gemm_cols<Scalar, Packet, DataMapper, N, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, \
2175 col, rows, remaining_rows, pAlpha, pMask);
2176
2177template <typename Scalar, typename Packet, typename DataMapper, const Index accCols>
2178EIGEN_ALWAYS_INLINE void gemm_extra_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index depth,
2179 Index strideA, Index offsetA, Index strideB, Index offsetB, Index col,
2180 Index rows, Index cols, Index remaining_rows, const Packet& pAlpha,
2181 const Packet& pMask) {
2182 MICRO_EXTRA(MICRO_EXTRA_COLS, cols - col, true)
2183}
2184
2185/****************
2186 * GEMM kernels *
2187 ****************/
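// Top-level real GEMM: alpha is broadcast once into pAlpha and the partial-row mask is
// derived from rows % accCols; columns are then processed in panels of accRows by
// gemm_cols and any trailing columns by gemm_extra_cols.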
2188template <typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows,
2189 const Index accCols>
2190EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows,
2191 Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA,
2192 Index offsetB) {
2193 const Index remaining_rows = rows % accCols;
2194
2195 if (strideA == -1) strideA = depth;
2196 if (strideB == -1) strideB = depth;
2197
2198 const Packet pAlpha = pset1<Packet>(alpha);
2199 const Packet pMask = bmask<Packet>(remaining_rows);
2200
2201 Index col = 0;
2202 for (; col + accRows <= cols; col += accRows) {
2203 gemm_cols<Scalar, Packet, DataMapper, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB,
2204 offsetB, col, rows, remaining_rows, pAlpha, pMask);
2205 }
2206
2207 if (col != cols) {
2208 gemm_extra_cols<Scalar, Packet, DataMapper, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB,
2209 col, rows, cols, remaining_rows, pAlpha, pMask);
2210 }
2211}
2212
2213#define accColsC (accCols / 2)
2214#define advanceRows ((LhsIsReal) ? 1 : 2)
2215#define advanceCols ((RhsIsReal) ? 1 : 2)
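// accColsC is the column-block width in complex elements (a complex packet holds half as
// many values as a real one). advanceRows/advanceCols are 2 when the corresponding operand
// is complex, since its packed real and imaginary planes lie one stride apart and the
// kernel must step over both.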
2216
2217// PEEL_COMPLEX loop factor.
2218#define PEEL_COMPLEX 3
2219#define PEEL_COMPLEX_ROW 3
2220
2221#define MICRO_COMPLEX_UNROLL(func) func(0) func(1) func(2) func(3)
2222
2223#define MICRO_COMPLEX_ZERO_PEEL(peel) \
2224 if ((PEEL_COMPLEX_ROW > peel) && (peel != 0)) { \
2225 bsetzero<Packet, accRows>(accReal##peel); \
2226 bsetzero<Packet, accRows>(accImag##peel); \
2227 } else { \
2228 EIGEN_UNUSED_VARIABLE(accReal##peel); \
2229 EIGEN_UNUSED_VARIABLE(accImag##peel); \
2230 }
2231
2232#define MICRO_COMPLEX_ADD_ROWS(N, used) \
2233 MICRO_ADD(ptr_real, N) \
2234 if (!RhsIsReal) { \
2235 MICRO_ADD(ptr_imag, N) \
2236 } else if (used) { \
2237 EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 0)); \
2238 EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 1)); \
2239 EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 2)); \
2240 }
2241
2242#define MICRO_COMPLEX_BROADCAST(peel) \
2243 MICRO_BROADCAST1(peel, ptr_real, rhsV, false) \
2244 if (!RhsIsReal) { \
2245 MICRO_BROADCAST1(peel, ptr_imag, rhsVi, false) \
2246 } else { \
2247 EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
2248 }
2249
2250#define MICRO_COMPLEX_BROADCAST_EXTRA \
2251 Packet rhsV[4], rhsVi[4]; \
2252 MICRO_BROADCAST_EXTRA1(ptr_real, rhsV, false) \
2253 if (!RhsIsReal) { \
2254 MICRO_BROADCAST_EXTRA1(ptr_imag, rhsVi, false) \
2255 } else { \
2256 EIGEN_UNUSED_VARIABLE(rhsVi); \
2257 } \
2258 MICRO_COMPLEX_ADD_ROWS(1, true)
2259
2260#define MICRO_COMPLEX_SRC2_PTR \
2261 MICRO_SRC2(ptr_real, strideB* advanceCols, 0) \
2262 if (!RhsIsReal) { \
2263 MICRO_RHS(ptr_imag, 0) = rhs_base + MICRO_NEW_ROWS * strideB; \
2264 MICRO_SRC2(ptr_imag, strideB* advanceCols, strideB) \
2265 } else { \
2266 EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 0)); \
2267 EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 1)); \
2268 EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 2)); \
2269 }
2270
2271#define MICRO_COMPLEX_ZERO_PEEL_ROW MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_ZERO_PEEL)
2272
2273#define MICRO_COMPLEX_WORK_PEEL(peel) \
2274 if (PEEL_COMPLEX_ROW > peel) { \
2275 MICRO_COMPLEX_BROADCAST(peel) \
2276 pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>( \
2277 &accReal##peel, &accImag##peel, lhs_ptr_real + (remaining_rows * peel), \
2278 lhs_ptr_imag + (remaining_rows * peel), rhsV##peel, rhsVi##peel); \
2279 } else { \
2280 EIGEN_UNUSED_VARIABLE(rhsV##peel); \
2281 EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
2282 }
2283
2284#define MICRO_COMPLEX_ADD_COLS(size) \
2285 lhs_ptr_real += (remaining_rows * size); \
2286 if (!LhsIsReal) \
2287 lhs_ptr_imag += (remaining_rows * size); \
2288 else \
2289 EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
2290
2291#define MICRO_COMPLEX_WORK_PEEL_ROW \
2292 Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4]; \
2293 Packet rhsVi0[4], rhsVi1[4], rhsVi2[4], rhsVi3[4]; \
2294 MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_WORK_PEEL) \
2295 MICRO_COMPLEX_ADD_COLS(PEEL_COMPLEX_ROW) \
2296 MICRO_COMPLEX_ADD_ROWS(PEEL_COMPLEX_ROW, false)
2297
2298#define MICRO_COMPLEX_ADD_PEEL(peel, sum) \
2299 if (PEEL_COMPLEX_ROW > peel) { \
2300 for (Index i = 0; i < accRows; i++) { \
2301 accReal##sum.packet[i] += accReal##peel.packet[i]; \
2302 accImag##sum.packet[i] += accImag##peel.packet[i]; \
2303 } \
2304 }
2305
2306#define MICRO_COMPLEX_ADD_PEEL_ROW \
2307 MICRO_COMPLEX_ADD_PEEL(2, 0) MICRO_COMPLEX_ADD_PEEL(3, 1) MICRO_COMPLEX_ADD_PEEL(1, 0)
2308
2309template <typename Scalar, typename Packet, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal,
2310 bool RhsIsReal, const Index remaining_rows>
2311EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW(const Scalar*& lhs_ptr_real, const Scalar*& lhs_ptr_imag,
2312 const Scalar*& rhs_ptr_real0, const Scalar*& rhs_ptr_real1,
2313 const Scalar*& rhs_ptr_real2, const Scalar*& rhs_ptr_imag0,
2314 const Scalar*& rhs_ptr_imag1, const Scalar*& rhs_ptr_imag2,
2315 PacketBlock<Packet, accRows>& accReal,
2316 PacketBlock<Packet, accRows>& accImag) {
2317 MICRO_COMPLEX_BROADCAST_EXTRA
2318 pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real,
2319 lhs_ptr_imag, rhsV, rhsVi);
2320 MICRO_COMPLEX_ADD_COLS(1)
2321}
2322
2323template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows,
2324 const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal,
2325 const Index remaining_rows>
2326EIGEN_ALWAYS_INLINE void gemm_unrolled_complex_row_iteration(const DataMapper& res, const Scalar* lhs_base,
2327 const Scalar* rhs_base, Index depth, Index strideA,
2328 Index offsetA, Index strideB, Index row, Index rows,
2329 const Packet& pAlphaReal, const Packet& pAlphaImag,
2330 const Packet& pMask) {
2331 const Scalar *rhs_ptr_real0 = rhs_base, *rhs_ptr_real1 = NULL, *rhs_ptr_real2 = NULL;
2332 const Scalar *rhs_ptr_imag0 = NULL, *rhs_ptr_imag1 = NULL, *rhs_ptr_imag2 = NULL;
2333 const Scalar* lhs_ptr_real = lhs_base + advanceRows * row * strideA + remaining_rows * offsetA;
2334 const Scalar* lhs_ptr_imag = NULL;
2335 if (!LhsIsReal)
2336 lhs_ptr_imag = lhs_ptr_real + remaining_rows * strideA;
2337 else
2338 EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
2339 PacketBlock<Packet, accRows> accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3;
2340 PacketBlock<Packet, accRows> taccReal, taccImag;
2341 PacketBlock<Packetc, accRows> acc0, acc1;
2342 PacketBlock<Packetc, accRows * 2> tRes;
2343
2344 MICRO_COMPLEX_SRC2_PTR
2345
2346 bsetzero<Packet, accRows>(accReal0);
2347 bsetzero<Packet, accRows>(accImag0);
2348
2349 Index remaining_depth = depth & -quad_traits<Scalar>::rows;
2350 Index k = 0;
2351 if (remaining_depth >= PEEL_COMPLEX_ROW) {
2352 MICRO_COMPLEX_ZERO_PEEL_ROW
2353 do {
2354 MICRO_COMPLEX_PREFETCHN(accRows)
2355 EIGEN_POWER_PREFETCH(lhs_ptr_real);
2356 if (!LhsIsReal) {
2357 EIGEN_POWER_PREFETCH(lhs_ptr_imag);
2358 }
2359 MICRO_COMPLEX_WORK_PEEL_ROW
2360 } while ((k += PEEL_COMPLEX_ROW) + PEEL_COMPLEX_ROW <= remaining_depth);
2361 MICRO_COMPLEX_ADD_PEEL_ROW
2362 }
2363 for (; k < depth; k++) {
2364 MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, remaining_rows>(
2365 lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real0, rhs_ptr_real1, rhs_ptr_real2, rhs_ptr_imag0, rhs_ptr_imag1,
2366 rhs_ptr_imag2, accReal0, accImag0);
2367 }
2368
2369 constexpr bool full = (remaining_rows > accColsC);
2370 bload<DataMapper, Packetc, accColsC, ColMajor, true, accRows, full>(tRes, res, row, 0);
2371 if ((accRows == 1) || (rows >= accCols)) {
2372 bscalec<Packet, accRows, true>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
2373 bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1);
2374 bstore<DataMapper, Packetc, accRows>(acc0, res, row + 0);
2375 if (full) {
2376 bstore<DataMapper, Packetc, accRows>(acc1, res, row + accColsC);
2377 }
2378 } else {
2379 bscalec<Packet, accRows, false>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
2380 bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1);
2381
2382 if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1)) {
2383 for (Index j = 0; j < accRows; j++) {
2384 res(row + 0, j) = pfirst<Packetc>(acc0.packet[j]);
2385 }
2386 } else {
2387 bstore<DataMapper, Packetc, accRows>(acc0, res, row + 0);
2388 if (full) {
2389 for (Index j = 0; j < accRows; j++) {
2390 res(row + accColsC, j) = pfirst<Packetc>(acc1.packet[j]);
2391 }
2392 }
2393 }
2394 }
2395}
2396
2397#define MICRO_COMPLEX_EXTRA_ROWS(N) \
2398 gemm_unrolled_complex_row_iteration<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, \
2399 ConjugateRhs, LhsIsReal, RhsIsReal, N>( \
2400 res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, pAlphaReal, pAlphaImag, pMask);
2401
2402template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows,
2403 const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
2404EIGEN_ALWAYS_INLINE void gemm_complex_extra_row(const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base,
2405 Index depth, Index strideA, Index offsetA, Index strideB, Index row,
2406 Index rows, Index remaining_rows, const Packet& pAlphaReal,
2407 const Packet& pAlphaImag, const Packet& pMask) {
2408 MICRO_EXTRA(MICRO_COMPLEX_EXTRA_ROWS, remaining_rows, false)
2409}
2410
2411#define MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
2412 MICRO_COMPLEX_UNROLL(func2); \
2413 func(0, peel) func(1, peel) func(2, peel) func(3, peel)
2414
2415#define MICRO_COMPLEX_WORK_ONE4(iter, peel) \
2416 if (unroll_factor > iter) { \
2417 pgerc_common<accRows, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>( \
2418 &accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
2419 }
2420
2421#define MICRO_COMPLEX_TYPE_PEEL4(func, func2, peel) \
2422 if (PEEL_COMPLEX > peel) { \
2423 Packet lhsV0, lhsV1, lhsV2, lhsV3; \
2424 Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
2425 MICRO_COMPLEX_BROADCAST(peel) \
2426 MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
2427 } else { \
2428 EIGEN_UNUSED_VARIABLE(rhsV##peel); \
2429 EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
2430 }
2431
2432#define MICRO_COMPLEX_UNROLL_TYPE_PEEL(M, func, func1, func2) \
2433 Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M]; \
2434 Packet rhsVi0[M], rhsVi1[M], rhsVi2[M], rhsVi3[M]; \
2435 func(func1, func2, 0) func(func1, func2, 1) func(func1, func2, 2) func(func1, func2, 3)
2436
2437#define MICRO_COMPLEX_UNROLL_TYPE_ONE(M, func, func1, func2) \
2438 Packet rhsV0[M], rhsVi0[M]; \
2439 func(func1, func2, 0)
2440
2441#define MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_TYPE, size) \
2442 MICRO_COMPLEX_TYPE(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE) \
2443 MICRO_COMPLEX_ADD_ROWS(size, false)
2444
2445#define MICRO_COMPLEX_ONE_PEEL4 MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_UNROLL_TYPE_PEEL, PEEL_COMPLEX)
2446
2447#define MICRO_COMPLEX_ONE4 MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_UNROLL_TYPE_ONE, 1)
2448
2449#define MICRO_COMPLEX_DST_PTR_ONE(iter) \
2450 if (unroll_factor > iter) { \
2451 bsetzero<Packet, accRows>(accReal##iter); \
2452 bsetzero<Packet, accRows>(accImag##iter); \
2453 } else { \
2454 EIGEN_UNUSED_VARIABLE(accReal##iter); \
2455 EIGEN_UNUSED_VARIABLE(accImag##iter); \
2456 }
2457
2458#define MICRO_COMPLEX_DST_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_DST_PTR_ONE)
2459
2460#define MICRO_COMPLEX_SRC_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE)
2461
2462#define MICRO_COMPLEX_PREFETCH MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)
2463
2464#define MICRO_COMPLEX_STORE_ONE(iter) \
2465 if (unroll_factor > iter) { \
2466 constexpr bool full = ((MICRO_NORMAL(iter)) || (accCols2 > accColsC)); \
2467 bload<DataMapper, Packetc, accColsC, ColMajor, true, accRows, full>(tRes, res, row + iter * accCols, 0); \
2468 bscalec<Packet, accRows, !(MICRO_NORMAL(iter))>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, \
2469 taccImag, pMask); \
2470 bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1); \
2471 bstore<DataMapper, Packetc, accRows>(acc0, res, row + iter * accCols + 0); \
2472 if (full) { \
2473 bstore<DataMapper, Packetc, accRows>(acc1, res, row + iter * accCols + accColsC); \
2474 } \
2475 }
2476
2477#define MICRO_COMPLEX_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_STORE_ONE)
2478
2479template <int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper,
2480 const Index accRows, const Index accCols, const Index accCols2, bool ConjugateLhs, bool ConjugateRhs,
2481 bool LhsIsReal, bool RhsIsReal>
2482EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_iteration(const DataMapper& res, const Scalar* lhs_base,
2483 const Scalar* rhs_base, Index depth, Index strideA,
2484 Index offsetA, Index strideB, Index& row,
2485 const Packet& pAlphaReal, const Packet& pAlphaImag,
2486 const Packet& pMask) {
2487 const Scalar *rhs_ptr_real0 = rhs_base, *rhs_ptr_real1 = NULL, *rhs_ptr_real2 = NULL;
2488 const Scalar *rhs_ptr_imag0 = NULL, *rhs_ptr_imag1 = NULL, *rhs_ptr_imag2 = NULL;
2489 const Index imag_delta = accCols * strideA;
2490 const Index imag_delta2 = accCols2 * strideA;
2491 const Scalar *lhs_ptr_real0 = NULL, *lhs_ptr_real1 = NULL;
2492 const Scalar *lhs_ptr_real2 = NULL, *lhs_ptr_real3 = NULL;
2493 PacketBlock<Packet, accRows> accReal0, accImag0, accReal1, accImag1;
2494 PacketBlock<Packet, accRows> accReal2, accImag2, accReal3, accImag3;
2495 PacketBlock<Packet, accRows> taccReal, taccImag;
2496 PacketBlock<Packetc, accRows> acc0, acc1;
2497 PacketBlock<Packetc, accRows * 2> tRes;
2498
2499 MICRO_COMPLEX_SRC2_PTR
2500 MICRO_COMPLEX_SRC_PTR
2501 MICRO_COMPLEX_DST_PTR
2502
2503 Index k = 0;
2504 for (; k + PEEL_COMPLEX <= depth; k += PEEL_COMPLEX) {
2505 MICRO_COMPLEX_PREFETCHN(accRows)
2506 MICRO_COMPLEX_PREFETCH
2507 MICRO_COMPLEX_ONE_PEEL4
2508 }
2509 for (; k < depth; k++) {
2510 MICRO_COMPLEX_ONE4
2511 }
2512 MICRO_COMPLEX_STORE
2513
2514 MICRO_COMPLEX_UPDATE
2515}
2516
2517#define MICRO_COMPLEX_UNROLL_ITER2(N, M) \
2518 gemm_complex_unrolled_iteration<N + (M ? 1 : 0), Scalar, Packet, Packetc, DataMapper, accRows, accCols, \
2519 M ? M : accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>( \
2520 res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask); \
2521 if (M) return;
2522
2523template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows,
2524 const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
2525EIGEN_ALWAYS_INLINE void gemm_complex_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
2526 Index depth, Index strideA, Index offsetA, Index strideB, Index offsetB,
2527 Index col, Index rows, Index remaining_rows, const Packet& pAlphaReal,
2528 const Packet& pAlphaImag, const Packet& pMask) {
2529 const DataMapper res3 = res.getSubMapper(0, col);
2530
2531 const Scalar* rhs_base = blockB + advanceCols * col * strideB + MICRO_NEW_ROWS * offsetB;
2532 const Scalar* lhs_base = blockA + accCols * offsetA;
2533 Index row = 0;
2534
2535#define MAX_COMPLEX_UNROLL 4
2536 while (row + MAX_COMPLEX_UNROLL * accCols <= rows) {
2537 MICRO_COMPLEX_UNROLL_ITER2(MAX_COMPLEX_UNROLL, 0);
2538 }
2539 switch ((rows - row) / accCols) {
2540#if MAX_COMPLEX_UNROLL > 4
2541 case 4:
2542 MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 4)
2543 break;
2544#endif
2545#if MAX_COMPLEX_UNROLL > 3
2546 case 3:
2547 MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 3)
2548 break;
2549#endif
2550#if MAX_COMPLEX_UNROLL > 2
2551 case 2:
2552 MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 2)
2553 break;
2554#endif
2555#if MAX_COMPLEX_UNROLL > 1
2556 case 1:
2557 MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 1)
2558 break;
2559#endif
2560 default:
2561 break;
2562 }
2563#undef MAX_COMPLEX_UNROLL
2564
2565 if (remaining_rows > 0) {
2566 gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal,
2567 RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows,
2568 remaining_rows, pAlphaReal, pAlphaImag, pMask);
2569 }
2570}
2571
2572#define MICRO_COMPLEX_EXTRA_COLS(N) \
2573 gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, N, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, \
2574 RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, \
2575 remaining_rows, pAlphaReal, pAlphaImag, pMask);
2576
2577template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accCols,
2578 bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
2579EIGEN_ALWAYS_INLINE void gemm_complex_extra_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
2580 Index depth, Index strideA, Index offsetA, Index strideB,
2581 Index offsetB, Index col, Index rows, Index cols, Index remaining_rows,
2582 const Packet& pAlphaReal, const Packet& pAlphaImag,
2583 const Packet& pMask) {
2584 MICRO_EXTRA(MICRO_COMPLEX_EXTRA_COLS, cols - col, true)
2585}
2586
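// Top-level complex GEMM: same panel structure as gemm() above, but operating on the
// decoupled real/imaginary packing. LhsIsReal/RhsIsReal cover the mixed real * complex
// cases and alpha is split into separate real and imaginary broadcasts.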
2587template <typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Packet, typename Packetc,
2588 typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs,
2589 bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
2590EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc,
2591 Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB,
2592 Index offsetA, Index offsetB) {
2593 const Index remaining_rows = rows % accCols;
2594
2595 if (strideA == -1) strideA = depth;
2596 if (strideB == -1) strideB = depth;
2597
2598 const Packet pAlphaReal = pset1<Packet>(alpha.real());
2599 const Packet pAlphaImag = pset1<Packet>(alpha.imag());
2600 const Packet pMask = bmask<Packet>(remaining_rows);
2601
2602 const Scalar* blockA = (Scalar*)blockAc;
2603 const Scalar* blockB = (Scalar*)blockBc;
2604
2605 Index col = 0;
2606 for (; col + accRows <= cols; col += accRows) {
2607 gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal,
2608 RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows,
2609 remaining_rows, pAlphaReal, pAlphaImag, pMask);
2610 }
2611
2612 if (col != cols) {
2613 gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal,
2614 RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols,
2615 remaining_rows, pAlphaReal, pAlphaImag, pMask);
2616 }
2617}
2618
2619#undef accColsC
2620#undef advanceCols
2621#undef advanceRows
2622
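// Reports whether the MMA kernels can be used: always true when compiled for MMA only,
// a __builtin_cpu_supports() check when dynamic dispatch is enabled, false otherwise.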
2623EIGEN_ALWAYS_INLINE bool supportsMMA() {
2624#if defined(EIGEN_ALTIVEC_MMA_ONLY)
2625 return true;
2626#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH) && defined(__BUILTIN_CPU_SUPPORTS__)
2627 return __builtin_cpu_supports("arch_3_1") && __builtin_cpu_supports("mma");
2628#else
2629 return false; // No dynamic dispatch for LLVM or older GCC
2630#endif
2631}
2632
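// Helpers for the non-MMA (pure VSX) bfloat16 GEMM path: accumulation is done in float32;
// each accumulator is scaled by alpha and added into the float32 result buffer, with
// partial stores covering any leftover rows.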
2633EIGEN_ALWAYS_INLINE Packet4f loadAndMultiplyF32(Packet4f acc, const Packet4f pAlpha, float* result) {
2634 Packet4f result_block = ploadu<Packet4f>(result);
2635 return pmadd(acc, pAlpha, result_block);
2636}
2637
2638template <bool lhsExtraRows>
2639EIGEN_ALWAYS_INLINE void storeF32(float*& result, Packet4f result_block, Index rows, Index extra_rows) {
2640 if (lhsExtraRows) {
2641 pstoreu_partial(result, result_block, extra_rows);
2642 } else {
2643 pstoreu(result, result_block);
2644 }
2645 result += rows;
2646}
2647
2648template <bool rhsExtraCols, bool lhsExtraRows>
2649EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result,
2650 Index extra_cols, Index extra_rows) {
2651 Index x = 0;
2652 if (rhsExtraCols) {
2653 do {
2654 Packet4f result_block = loadAndMultiplyF32(acc[x], pAlpha, result);
2655 storeF32<lhsExtraRows>(result, result_block, rows, extra_rows);
2656 } while (++x < extra_cols);
2657 } else {
2658 Packet4f result_block[4];
2659 float* result2 = result;
2660 do {
2661 result_block[x] = loadAndMultiplyF32(acc[x], pAlpha, result);
2662 result += rows;
2663 } while (++x < 4);
2664 x = 0;
2665 do {
2666 storeF32<lhsExtraRows>(result2, result_block[x], rows, extra_rows);
2667 } while (++x < 4);
2668 }
2669}
2670
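// bfloat16 -> float32 conversion: a bfloat16 value is the upper 16 bits of an IEEE float,
// so each 16-bit lane is merged with a zero lane to land in the upper half of a 32-bit
// word. The Hi variant converts lanes 0..3 of a Packet8us, the Lo variant lanes 4..7.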
2671EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Hi(Packet8us data) {
2672 Packet8us z = pset1<Packet8us>(0);
2673#ifdef _BIG_ENDIAN
2674 return reinterpret_cast<Packet4f>(vec_mergeh(data, z));
2675#else
2676 return reinterpret_cast<Packet4f>(vec_mergeh(z, data));
2677#endif
2678}
2679
2680EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Lo(Packet8us data) {
2681 Packet8us z = pset1<Packet8us>(0);
2682#ifdef _BIG_ENDIAN
2683 return reinterpret_cast<Packet4f>(vec_mergel(data, z));
2684#else
2685 return reinterpret_cast<Packet4f>(vec_mergel(z, data));
2686#endif
2687}
2688
2689template <Index N, Index M>
2690EIGEN_ALWAYS_INLINE void storeConvertTwoBF16(float* to, PacketBlock<Packet8bf, (N + 7) / 8>& block, Index extra = 0) {
2691 if (N < 4) {
2692 pstoreu_partial(to + 0, oneConvertBF16Hi(block.packet[0].m_val), extra);
2693 } else if (N >= (M * 8 + 4)) {
2694 pstoreu(to + 0, oneConvertBF16Hi(block.packet[M].m_val));
2695 if (N >= 8) {
2696 pstoreu(to + 4, oneConvertBF16Lo(block.packet[M].m_val));
2697 }
2698 }
2699}
2700
2701template <Index N>
2702EIGEN_ALWAYS_INLINE void storeConvertBlockBF16(float* to, PacketBlock<Packet8bf, (N + 7) / 8>& block, Index extra) {
2703 storeConvertTwoBF16<N, 0>(to + 0, block, extra);
2704 if (N >= 16) {
2705 storeConvertTwoBF16<N, 1>(to + 8, block);
2706 }
2707 if (N >= 32) {
2708 storeConvertTwoBF16<N, 2>(to + 16, block);
2709 storeConvertTwoBF16<N, 3>(to + 24, block);
2710 }
2711}
2712
2713template <bool non_unit_stride, Index delta>
2714EIGEN_ALWAYS_INLINE Packet8bf loadBF16fromResult(bfloat16* src, Index resInc) {
2715 if (non_unit_stride) {
2716 return pgather<bfloat16, Packet8bf>(src + delta * resInc, resInc);
2717 } else {
2718 return ploadu<Packet8bf>(src + delta);
2719 }
2720}
2721
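// Byte-permute masks for oneConvertBF16Perm below. Indices 16..31 select from the zero
// vector, so every resulting 32-bit lane holds one bfloat16 value in its upper half (i.e.
// that value converted to float). Masks 1-4 duplicate consecutive pairs ({x0,x1,x0,x1},
// {x2,x3,x2,x3}, ...); masks 5-8 produce {xk, 0, xk, 0} and serve the odd trailing column.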
2722static Packet16uc p16uc_MERGE16_32_1 = {0, 1, 16, 17, 2, 3, 18, 19, 0, 1, 16, 17, 2, 3, 18, 19};
2723static Packet16uc p16uc_MERGE16_32_2 = {4, 5, 20, 21, 6, 7, 22, 23, 4, 5, 20, 21, 6, 7, 22, 23};
2724static Packet16uc p16uc_MERGE16_32_3 = {8, 9, 24, 25, 10, 11, 26, 27, 8, 9, 24, 25, 10, 11, 26, 27};
2725static Packet16uc p16uc_MERGE16_32_4 = {12, 13, 28, 29, 14, 15, 30, 31, 12, 13, 28, 29, 14, 15, 30, 31};
2726
2727static Packet16uc p16uc_MERGE16_32_5 = {0, 1, 16, 17, 16, 17, 16, 17, 0, 1, 16, 17, 16, 17, 16, 17};
2728static Packet16uc p16uc_MERGE16_32_6 = {2, 3, 18, 19, 18, 19, 18, 19, 2, 3, 18, 19, 18, 19, 18, 19};
2729static Packet16uc p16uc_MERGE16_32_7 = {4, 5, 20, 21, 20, 21, 20, 21, 4, 5, 20, 21, 20, 21, 20, 21};
2730static Packet16uc p16uc_MERGE16_32_8 = {6, 7, 22, 23, 22, 23, 22, 23, 6, 7, 22, 23, 22, 23, 22, 23};
2731
2732EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Perm(Packet8us data, Packet16uc mask) {
2733 Packet8us z = pset1<Packet8us>(0);
2734#ifdef _BIG_ENDIAN
2735 return reinterpret_cast<Packet4f>(vec_perm(data, z, mask));
2736#else
2737 return reinterpret_cast<Packet4f>(vec_perm(z, data, mask));
2738#endif
2739}
2740
2741template <bool lhsExtraRows, bool odd, Index size>
2742EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32DupOne(float* result, Index rows, const bfloat16* src,
2743 Index extra_rows) {
2744 Packet4f dup[4 * 4];
2745 Packet8bf data[4];
2746
2747 for (Index i = 0; i < size; i++) {
2748 data[i] = ploadu<Packet8bf>(src + rows * i);
2749 }
2750
2751 for (Index i = 0, j = 0; i < size; i++, j += 4) {
2752 dup[j + 0] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_5 : p16uc_MERGE16_32_1);
2753 dup[j + 1] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_6 : p16uc_MERGE16_32_2);
2754 dup[j + 2] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_7 : p16uc_MERGE16_32_3);
2755 dup[j + 3] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_8 : p16uc_MERGE16_32_4);
2756 }
2757
2758 for (Index j = 0; j < 4 * size; j += 4) {
2759 if (lhsExtraRows) {
2760 Packet4f z = pset1<Packet4f>(float(0));
2761 Index i = 0;
2762 do {
2763 pstoreu(result + (j + i) * 4, dup[j + i]);
2764 } while (++i < extra_rows);
2765 do {
2766 pstoreu(result + (j + i) * 4, z);
2767 } while (++i < 4);
2768 } else {
2769 for (Index i = 0; i < 4; i++) {
2770 pstoreu(result + (j + i) * 4, dup[j + i]);
2771 }
2772 }
2773 }
2774}
2775
2776template <bool lhsExtraRows>
2777EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32Dup(float* result, Index cols, Index rows, const bfloat16* src,
2778 Index delta, Index extra_rows) {
2779 Index col = 0;
2780 src += delta * 2;
2781 for (; col + 4 * 2 <= cols; col += 4 * 2, result += 4 * 4 * 4, src += 4 * rows) {
2782 convertArrayPointerBF16toF32DupOne<lhsExtraRows, false, 4>(result, rows, src, extra_rows);
2783 }
2784 for (; col + 2 <= cols; col += 2, result += 4 * 4, src += rows) {
2785 convertArrayPointerBF16toF32DupOne<lhsExtraRows, false, 1>(result, rows, src, extra_rows);
2786 }
2787 if (cols & 1) {
2788 convertArrayPointerBF16toF32DupOne<lhsExtraRows, true, 1>(result, rows, src - delta, extra_rows);
2789 }
2790}
2791
2792template <const Index size, bool non_unit_stride>
2793EIGEN_ALWAYS_INLINE void convertPointerBF16toF32(Index& i, float* result, Index rows, bfloat16*& src, Index resInc) {
2794 constexpr Index extra = ((size < 4) ? 4 : size);
2795 while (i + size <= rows) {
2796 PacketBlock<Packet8bf, (size + 7) / 8> r32;
2797 r32.packet[0] = loadBF16fromResult<non_unit_stride, 0>(src, resInc);
2798 if (size >= 16) {
2799 r32.packet[1] = loadBF16fromResult<non_unit_stride, 8>(src, resInc);
2800 }
2801 if (size >= 32) {
2802 r32.packet[2] = loadBF16fromResult<non_unit_stride, 16>(src, resInc);
2803 r32.packet[3] = loadBF16fromResult<non_unit_stride, 24>(src, resInc);
2804 }
2805 storeConvertBlockBF16<size>(result + i, r32, rows & 3);
2806 i += extra;
2807 src += extra * resInc;
2808 if (size != 32) break;
2809 }
2810}
2811
2812template <bool non_unit_stride>
2813EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float* result, Index cols, Index rows, bfloat16* src,
2814 Index resInc) {
2815 for (Index col = 0; col < cols; col++, src += (rows * resInc), result += rows) {
2816 Index i = 0;
2817 bfloat16* src2 = src;
2818 convertPointerBF16toF32<32, non_unit_stride>(i, result, rows, src2, resInc);
2819 convertPointerBF16toF32<16, non_unit_stride>(i, result, rows, src2, resInc);
2820 convertPointerBF16toF32<8, non_unit_stride>(i, result, rows, src2, resInc);
2821 convertPointerBF16toF32<4, non_unit_stride>(i, result, rows, src2, resInc);
2822 convertPointerBF16toF32<1, non_unit_stride>(i, result, rows, src2, resInc);
2823 }
2824}
2825
2826template <Index num_acc, Index size = 4>
2827EIGEN_ALWAYS_INLINE void zeroAccumulators(Packet4f (&acc)[num_acc][size]) {
2828 Packet4f z = pset1<Packet4f>(float(0));
2829
2830 for (Index k = 0; k < num_acc; k++) {
2831 for (Index j = 0; j < size; j++) {
2832 acc[k][j] = z;
2833 }
2834 }
2835}
2836
2837template <Index num_acc>
2838EIGEN_ALWAYS_INLINE void tranposeResults(Packet4f (&acc)[num_acc][4]) {
2839 for (Index i = 0; i < num_acc; i++) {
2840 Packet4ui t0, t1, t2, t3;
2841 t0 = vec_mergeh(reinterpret_cast<Packet4ui>(acc[i][0]), reinterpret_cast<Packet4ui>(acc[i][2]));
2842 t1 = vec_mergel(reinterpret_cast<Packet4ui>(acc[i][0]), reinterpret_cast<Packet4ui>(acc[i][2]));
2843 t2 = vec_mergeh(reinterpret_cast<Packet4ui>(acc[i][1]), reinterpret_cast<Packet4ui>(acc[i][3]));
2844 t3 = vec_mergel(reinterpret_cast<Packet4ui>(acc[i][1]), reinterpret_cast<Packet4ui>(acc[i][3]));
2845 acc[i][0] = reinterpret_cast<Packet4f>(vec_mergeh(t0, t2));
2846 acc[i][1] = reinterpret_cast<Packet4f>(vec_mergel(t0, t2));
2847 acc[i][2] = reinterpret_cast<Packet4f>(vec_mergeh(t1, t3));
2848 acc[i][3] = reinterpret_cast<Packet4f>(vec_mergel(t1, t3));
2849 }
2850}
2851
2852template <Index num_acc>
2853EIGEN_ALWAYS_INLINE void addResults(Packet4f (&acc)[num_acc][4]) {
2854 for (Index i = 0, j = 0; j < num_acc; i++, j += 2) {
2855 for (Index x = 0, y = 0; x < 2; x++, y += 2) {
2856 for (Index w = 0, z = 0; w < 2; w++, z += 2) {
2857 acc[i][y + w] = acc[j + x][z + 0] + acc[j + x][z + 1];
2858 }
2859 }
2860 }
2861}
2862
2863template <Index num_acc, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs>
2864EIGEN_ALWAYS_INLINE void outputResultsVSX(Packet4f (&acc)[num_acc][4], Index rows, const Packet4f pAlpha, float* result,
2865 const Index extra_cols, Index extra_rows) {
2866 tranposeResults<num_acc>(acc);
2867 addResults<num_acc>(acc);
2868
2869 constexpr Index real_rhs = ((num_rhs / 2) - (rhsExtraCols ? 1 : 0));
2870 Index k = 0;
2871 for (Index i = 0; i < real_rhs; i++, result += 4 * rows, k++) {
2872 storeResults<false, lhsExtraRows>(acc[k], rows, pAlpha, result, extra_cols, extra_rows);
2873 }
2874 if (rhsExtraCols) {
2875 storeResults<rhsExtraCols, lhsExtraRows>(acc[k], rows, pAlpha, result, extra_cols, extra_rows);
2876 }
2877}
2878
2879template <bool zero>
2880EIGEN_ALWAYS_INLINE void loadTwoRhsFloat32(const float* block, Index strideB, Index i, Packet4f& dhs0, Packet4f& dhs1) {
2881 dhs0 = ploadu<Packet4f>(block + strideB * i + 0);
2882 if (zero) {
2883 Packet4f dhs2 = pset1<Packet4f>(float(0));
2884 dhs1 = vec_mergel(dhs0, dhs2);
2885 dhs0 = vec_mergeh(dhs0, dhs2);
2886 } else {
2887 dhs1 = ploadu<Packet4f>(block + strideB * i + 4);
2888 }
2889}
2890
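// One step of the bfloat16 VSX depth loop: each call covers two k (depth) values. The rhs
// pairs are loaded (or, when 'zero' is set for an odd depth, interleaved with zeros so the
// missing k step contributes nothing), four lhs packets are loaded, and every rhs x lhs
// product is fused into its accumulator tile.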
2891template <Index num_acc, bool zero, bool rhsExtraCols, Index num_rhs>
2892EIGEN_ALWAYS_INLINE void KLoop(const float* indexA, const float* indexB, Packet4f (&acc)[num_acc][4], Index strideB,
2893 Index k, Index offsetB, Index extra_cols) {
2894 constexpr Index num_lhs = 4;
2895 Packet4f lhs[num_lhs], rhs[num_rhs];
2896
2897 constexpr Index real_rhs = (num_rhs - (rhsExtraCols ? 2 : 0));
2898 for (Index i = 0; i < real_rhs; i += 2) {
2899 loadTwoRhsFloat32<zero>(indexB + k * 4, strideB, i, rhs[i + 0], rhs[i + 1]);
2900 }
2901 if (rhsExtraCols) {
2902 loadTwoRhsFloat32<zero>(indexB + k * extra_cols - offsetB, strideB, real_rhs, rhs[real_rhs + 0], rhs[real_rhs + 1]);
2903 }
2904
2905 indexA += 2 * k * 4;
2906 for (Index j = 0; j < num_lhs; j++) {
2907 lhs[j] = ploadu<Packet4f>(indexA + j * 4);
2908 }
2909
2910 for (Index j = 0; j < num_rhs; j++) {
2911 for (Index i = 0; i < num_lhs; i++) {
2912 acc[j][i] = pmadd(rhs[j], lhs[i], acc[j][i]);
2913 }
2914 }
2915}
2916
2917template <const Index num_acc, bool rhsExtraCols, bool lhsExtraRows>
2918EIGEN_ALWAYS_INLINE void colVSXLoopBodyIter(Index depth, Index rows, const Packet4f pAlpha, const float* indexA,
2919 const float* indexB, Index strideB, Index offsetB, float* result,
2920 const Index extra_cols, const Index extra_rows) {
2921 constexpr Index num_rhs = num_acc;
2922
2923 Packet4f acc[num_acc][4];
2924
2925 zeroAccumulators<num_acc>(acc);
2926
2927 Index k;
2928 for (k = 0; k + 2 <= depth; k += 2) {
2929 KLoop<num_acc, false, rhsExtraCols, num_rhs>(indexA, indexB, acc, strideB, k, offsetB, extra_cols);
2930 }
2931 if (depth & 1) {
2932 KLoop<num_acc, true, rhsExtraCols, num_rhs>(indexA, indexB, acc, strideB, k, offsetB, extra_cols);
2933 }
2934
2935 outputResultsVSX<num_acc, rhsExtraCols, lhsExtraRows, num_rhs>(acc, rows, pAlpha, result, extra_cols, extra_rows);
2936}
2937
2938// No more than 4 accumulator tiles here (the loop body doubles that, i.e. 2X the accumulators, or 8X the number of VSX registers)
2939#define MAX_BFLOAT16_ACC_VSX 4
2940
2941template <const Index num_acc, bool rhsExtraCols, bool lhsExtraRows>
2942void colVSXLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float* indexA,
2943 const float* indexB, Index strideB, Index offsetB, float* result) {
2944 constexpr Index step = (num_acc * 4); // each accumulator has 4 elements
2945 const Index extra_cols = (rhsExtraCols) ? (cols & 3) : 0;
2946 const Index extra_rows = (lhsExtraRows) ? (rows & 3) : 0;
2947 constexpr bool multiIters = !rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC_VSX);
2948
2949 do {
2950 colVSXLoopBodyIter<num_acc * 2, rhsExtraCols, lhsExtraRows>(depth, rows, pAlpha, indexA, indexB, strideB, offsetB,
2951 result, extra_cols, extra_rows);
2952
2953 indexB += strideB * (num_acc * 2);
2954 result += rows * step;
2955 } while (multiIters && (step <= cols - (col += step)));
2956}
2957
2958template <const Index num_acc, bool rhsExtraCols, bool lhsExtraRows>
2959EIGEN_ALWAYS_INLINE void colVSXLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha,
2960 const float* indexA, const float* blockB, Index strideB, Index offsetB,
2961 float* result) {
2962 if (MAX_BFLOAT16_ACC_VSX > num_acc) {
2963 colVSXLoopBody<num_acc + (rhsExtraCols ? 1 : 0), rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA,
2964 blockB, strideB, offsetB, result);
2965 }
2966}
2967
2968template <bool rhsExtraCols, bool lhsExtraRows>
2969void colVSXLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float* indexA,
2970 const float* blockB, Index strideB, Index offsetB, float* result) {
2971 switch ((cols - col) >> 2) {
2972 case 3:
2973 colVSXLoopBodyExtraN<3, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB,
2974 offsetB, result);
2975 break;
2976 case 2:
2977 colVSXLoopBodyExtraN<2, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB,
2978 offsetB, result);
2979 break;
2980 case 1:
2981 colVSXLoopBodyExtraN<1, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB,
2982 offsetB, result);
2983 break;
2984 default:
2985 if (rhsExtraCols) {
2986 colVSXLoopBody<1, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
2987 }
2988 break;
2989 }
2990}
2991
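// Sweep all columns for each group of 4 LHS rows: convert that row group of
// the packed bfloat16 LHS to float with convertArrayPointerBF16toF32Dup, then
// run the main column body followed by the extra-column handlers.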
2992template <Index size, bool lhsExtraRows = false>
2993EIGEN_ALWAYS_INLINE void colVSXLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA,
2994 const float* indexA2, const float* blockB2, Index strideA, Index strideB,
2995 Index offsetB, float* result2) {
2996 Index delta_rows = 2 * (lhsExtraRows ? (rows & 3) : size);
2997 for (Index row = 0; row < size; row += 4) {
2998 convertArrayPointerBF16toF32Dup<lhsExtraRows>(const_cast<float*>(indexA2), strideA, delta_rows, indexA, row,
2999 rows & 3);
3000
3001 const float* blockB = blockB2;
3002 float* result = result2 + row;
3003
3004 Index col = 0;
3005 if (cols >= (MAX_BFLOAT16_ACC_VSX * 4)) {
3006 colVSXLoopBody<MAX_BFLOAT16_ACC_VSX, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA2, blockB,
3007 strideB, 0, result);
3008 blockB += (strideB >> 1) * col;
3009 result += rows * col;
3010 }
3011 if (cols & 3) {
3012 colVSXLoopBodyExtra<true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA2, blockB, strideB, offsetB,
3013 result);
3014 } else {
3015 colVSXLoopBodyExtra<false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA2, blockB, strideB, 0, result);
3016 }
3017 }
3018}
3019
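// Process one LHS panel of `size` rows when present (size == 16 always runs,
// smaller sizes only when the matching bit of `rows` is set) and advance the
// packed LHS pointer past the consumed panel.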
3020template <Index size>
3021EIGEN_ALWAYS_INLINE void calcVSXColLoops(const bfloat16*& indexA, const float* indexA2, Index& row, Index depth,
3022 Index cols, Index rows, const Packet4f pAlpha, const float* indexB,
3023 Index strideA, Index strideB, Index offsetA, Index offsetB, Index bigSuffix,
3024 float* result) {
3025 if ((size == 16) || (rows & size)) {
3026 indexA += size * offsetA;
3027 colVSXLoops<size>(depth, cols, rows, pAlpha, indexA, indexA2, indexB, strideA, strideB, offsetB, result + row);
3028 row += size;
3029 indexA += bigSuffix * size / 16;
3030 }
3031}
3032
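// Convert up to `size` bfloat16 values per iteration from the source mapper
// into the float working buffer; only the size == 32 instantiation loops.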
3033template <const Index size, typename DataMapper>
3034EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float* result, Index rows, const DataMapper& src) {
3035 constexpr Index extra = ((size < 4) ? 4 : size);
3036 while (i + size <= rows) {
3037 PacketBlock<Packet8bf, (size + 7) / 8> r32;
3038 r32.packet[0] = src.template loadPacket<Packet8bf>(i + 0);
3039 if (size >= 16) {
3040 r32.packet[1] = src.template loadPacket<Packet8bf>(i + 8);
3041 }
3042 if (size >= 32) {
3043 r32.packet[2] = src.template loadPacket<Packet8bf>(i + 16);
3044 r32.packet[3] = src.template loadPacket<Packet8bf>(i + 24);
3045 }
3046 storeConvertBlockBF16<size>(result + i, r32, rows & 3);
3047 i += extra;
3048 if (size != 32) break;
3049 }
3050}
3051
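// Convert the whole bfloat16 result matrix to float, one column at a time,
// falling back to progressively smaller block sizes for the row remainder.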
3052template <typename DataMapper>
3053EIGEN_ALWAYS_INLINE void convertArrayBF16toF32(float* result, Index cols, Index rows, const DataMapper& src) {
3054 typedef typename DataMapper::LinearMapper LinearMapper;
3055 for (Index j = 0; j < cols; j++, result += rows) {
3056 const LinearMapper src2 = src.getLinearMapper(0, j);
3057 Index i = 0;
3058 convertBF16toF32<32, LinearMapper>(i, result, rows, src2);
3059 convertBF16toF32<16, LinearMapper>(i, result, rows, src2);
3060 convertBF16toF32<8, LinearMapper>(i, result, rows, src2);
3061 convertBF16toF32<4, LinearMapper>(i, result, rows, src2);
3062 convertBF16toF32<1, LinearMapper>(i, result, rows, src2);
3063 }
3064}
3065
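// Convert 8 consecutive floats to a single Packet8bf.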
3066EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16VSX(const float* res) {
3067 return F32ToBf16Both(ploadu<Packet4f>(res + 0), ploadu<Packet4f>(res + 4));
3068}
3069
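// Convert `size` columns of the float working buffer back to bfloat16 and
// store them through the destination mapper, 8 rows at a time with a partial
// store for the row tail.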
3070template <typename DataMapper, const Index size>
3071EIGEN_ALWAYS_INLINE void convertArrayF32toBF16ColVSX(float* result, Index col, Index rows, const DataMapper& res) {
3072 const DataMapper res2 = res.getSubMapper(0, col);
3073 Index row;
3074 float* result2 = result + col * rows;
3075 for (row = 0; row + 8 <= rows; row += 8, result2 += 8) {
3076 // get and save block
3077 PacketBlock<Packet8bf, size> block;
3078 for (Index j = 0; j < size; j++) {
3079 block.packet[j] = convertF32toBF16VSX(result2 + j * rows);
3080 }
3081 res2.template storePacketBlock<Packet8bf, size>(row, 0, block);
3082 }
3083 // extra rows
3084 if (row < rows) {
3085 for (Index j = 0; j < size; j++) {
3086 Packet8bf fp16 = convertF32toBF16VSX(result2 + j * rows);
3087 res2.template storePacketPartial<Packet8bf>(row, j, fp16, rows & 7);
3088 }
3089 }
3090}
3091
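// Convert the full float working buffer back to bfloat16: 4 columns at a
// time, then up to 3 leftover columns.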
3092template <typename DataMapper>
3093EIGEN_ALWAYS_INLINE void convertArrayF32toBF16VSX(float* result, Index cols, Index rows, const DataMapper& res) {
3094 Index col;
3095 for (col = 0; col + 4 <= cols; col += 4) {
3096 convertArrayF32toBF16ColVSX<DataMapper, 4>(result, col, rows, res);
3097 }
3098 // extra cols
3099 switch (cols - col) {
3100 case 1:
3101 convertArrayF32toBF16ColVSX<DataMapper, 1>(result, col, rows, res);
3102 break;
3103 case 2:
3104 convertArrayF32toBF16ColVSX<DataMapper, 2>(result, col, rows, res);
3105 break;
3106 case 3:
3107 convertArrayF32toBF16ColVSX<DataMapper, 3>(result, col, rows, res);
3108 break;
3109 }
3110}
3111
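// Non-MMA bfloat16 GEMM entry point: widen the result and the packed RHS to
// float, run the VSX kernels over 16/8/4-row LHS panels plus a row tail, and
// narrow the accumulated result back to bfloat16.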
3112template <typename DataMapper>
3113void gemmbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat16* indexB, Index rows, Index depth,
3114 Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
3115 float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
3116 const Packet4f pAlpha = pset1<Packet4f>(falpha);
3117
3118 if (strideA == -1) strideA = depth;
3119 if (strideB == -1) strideB = depth;
3120
3121 ei_declare_aligned_stack_constructed_variable(float, result, cols * rows, 0);
3122 ei_declare_aligned_stack_constructed_variable(float, indexB2, strideB * cols, 0);
3123 ei_declare_aligned_stack_constructed_variable(float, indexA2, ((strideA + 1) & -2) * 4 * 2, 0);
3124
3125 convertArrayBF16toF32<DataMapper>(result, cols, rows, res);
3126 convertArrayPointerBF16toF32(indexB2, cols, strideB, const_cast<bfloat16*>(indexB));
3127
3128 Index bigSuffix = 2 * 8 * (strideA - offsetA);
3129 float* indexBF32 = indexB2 + 4 * offsetB;
3130 offsetB *= 3;
3131 strideB *= 2;
3132
3133 Index row = 0;
3134 // LHS (8x16) block
3135 while (row + 16 <= rows) {
3136 calcVSXColLoops<16>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB,
3137 bigSuffix, result);
3138 }
3139 // LHS (8x8) block
3140 calcVSXColLoops<8>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB,
3141 bigSuffix, result);
3142 // LHS (8x4) block
3143 calcVSXColLoops<4>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB,
3144 bigSuffix, result);
3145 // extra rows
3146 if (rows & 3) {
3147 // This index is the beginning of the remaining block.
3148 colVSXLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexA2, indexBF32, strideA, strideB, offsetB,
3149 result + row);
3150 }
3151
3152 // Convert back to bfloat16
3153 convertArrayF32toBF16VSX<DataMapper>(result, cols, rows, res);
3154}
3155
3156#undef MAX_BFLOAT16_ACC_VSX
3157
3158#include "MatrixVectorProduct.h"
3159
3160/************************************
3161 * ppc64le template specializations *
3162 ************************************/
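// The gemm_pack_lhs / gemm_pack_rhs specializations below forward to the
// generic dhs_pack (real) and dhs_cpack (complex) helpers with the matching
// packet type, storage order and LHS/RHS flag.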
3163template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3164struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
3165 void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
3166};
3167
3168template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3169void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>::operator()(
3170 double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
3171 dhs_pack<double, DataMapper, Packet2d, ColMajor, PanelMode, true> pack;
3172 pack(blockA, lhs, depth, rows, stride, offset);
3173}
3174
3175template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3176struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
3177 void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
3178};
3179
3180template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3181void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>::operator()(
3182 double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
3183 dhs_pack<double, DataMapper, Packet2d, RowMajor, PanelMode, true> pack;
3184 pack(blockA, lhs, depth, rows, stride, offset);
3185}
3186
3187#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
3188template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3189struct gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
3190 void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
3191};
3192
3193template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3194void gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
3195 double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3196 dhs_pack<double, DataMapper, Packet2d, ColMajor, PanelMode, false> pack;
3197 pack(blockB, rhs, depth, cols, stride, offset);
3198}
3199
3200template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3201struct gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
3202 void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
3203};
3204
3205template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3206void gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
3207 double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3208 dhs_pack<double, DataMapper, Packet2d, RowMajor, PanelMode, false> pack;
3209 pack(blockB, rhs, depth, cols, stride, offset);
3210}
3211
3212template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3213struct gemm_pack_rhs<bfloat16, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
3214 void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
3215};
3216
3217template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3218void gemm_pack_rhs<bfloat16, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
3219 bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3220 dhs_pack<bfloat16, DataMapper, Packet8bf, ColMajor, PanelMode, false> pack;
3221 pack(blockB, rhs, depth, cols, stride, offset);
3222}
3223
3224template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3225struct gemm_pack_rhs<bfloat16, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
3226 void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
3227};
3228
3229template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3230void gemm_pack_rhs<bfloat16, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
3231 bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3232 dhs_pack<bfloat16, DataMapper, Packet8bf, RowMajor, PanelMode, false> pack;
3233 pack(blockB, rhs, depth, cols, stride, offset);
3234}
3235#endif
3236
3237template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3238struct gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
3239 void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
3240};
3241
3242template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3243void gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>::operator()(
3244 bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
3245 dhs_pack<bfloat16, DataMapper, Packet8bf, ColMajor, PanelMode, true> pack;
3246 pack(blockA, lhs, depth, rows, stride, offset);
3247}
3248
3249template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3250struct gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
3251 void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
3252};
3253
3254template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3255void gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>::operator()(
3256 bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
3257 dhs_pack<bfloat16, DataMapper, Packet8bf, RowMajor, PanelMode, true> pack;
3258 pack(blockA, lhs, depth, rows, stride, offset);
3259}
3260
3261template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3262struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
3263 void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
3264};
3265
3266template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3267void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>::operator()(
3268 float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
3269 dhs_pack<float, DataMapper, Packet4f, RowMajor, PanelMode, true> pack;
3270 pack(blockA, lhs, depth, rows, stride, offset);
3271}
3272
3273template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3274struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
3275 void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
3276};
3277
3278template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3279void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>::operator()(
3280 float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
3281 dhs_pack<float, DataMapper, Packet4f, ColMajor, PanelMode, true> pack;
3282 pack(blockA, lhs, depth, rows, stride, offset);
3283}
3284
3285template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3286struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
3287 void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
3288 Index offset = 0);
3289};
3290
3291template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3292void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate,
3293 PanelMode>::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows,
3294 Index stride, Index offset) {
3295 dhs_cpack<float, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, true> pack;
3296 pack(blockA, lhs, depth, rows, stride, offset);
3297}
3298
3299template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3300struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
3301 void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
3302 Index offset = 0);
3303};
3304
3305template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3306void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate,
3307 PanelMode>::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows,
3308 Index stride, Index offset) {
3309 dhs_cpack<float, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, true> pack;
3310 pack(blockA, lhs, depth, rows, stride, offset);
3311}
3312
3313#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
3314template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3315struct gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
3316 void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
3317};
3318
3319template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3320void gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
3321 float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3322 dhs_pack<float, DataMapper, Packet4f, ColMajor, PanelMode, false> pack;
3323 pack(blockB, rhs, depth, cols, stride, offset);
3324}
3325
3326template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3327struct gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
3328 void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
3329};
3330
3331template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3332void gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
3333 float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3334 dhs_pack<float, DataMapper, Packet4f, RowMajor, PanelMode, false> pack;
3335 pack(blockB, rhs, depth, cols, stride, offset);
3336}
3337#endif
3338
3339template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3340struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
3341 void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
3342 Index offset = 0);
3343};
3344
3345template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3346void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
3347 std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3348 dhs_cpack<float, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, false> pack;
3349 pack(blockB, rhs, depth, cols, stride, offset);
3350}
3351
3352template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3353struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
3354 void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
3355 Index offset = 0);
3356};
3357
3358template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3359void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
3360 std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3361 dhs_cpack<float, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, false> pack;
3362 pack(blockB, rhs, depth, cols, stride, offset);
3363}
3364
3365template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3366struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
3367 void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
3368 Index offset = 0);
3369};
3370
3371template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3372void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate,
3373 PanelMode>::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows,
3374 Index stride, Index offset) {
3375 dhs_cpack<double, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, true> pack;
3376 pack(blockA, lhs, depth, rows, stride, offset);
3377}
3378
3379template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3380struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
3381 void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
3382 Index offset = 0);
3383};
3384
3385template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3386void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate,
3387 PanelMode>::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows,
3388 Index stride, Index offset) {
3389 dhs_cpack<double, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, true> pack;
3390 pack(blockA, lhs, depth, rows, stride, offset);
3391}
3392
3393template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3394struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
3395 void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
3396 Index offset = 0);
3397};
3398
3399template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3400void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
3401 std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3402 dhs_cpack<double, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, false> pack;
3403 pack(blockB, rhs, depth, cols, stride, offset);
3404}
3405
3406template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3407struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
3408 void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
3409 Index offset = 0);
3410};
3411
3412template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3413void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
3414 std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3415 dhs_cpack<double, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, false> pack;
3416 pack(blockB, rhs, depth, cols, stride, offset);
3417}
3418
3419// ********* gebp specializations *********
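// Each specialization selects its kernel once, via a static function pointer:
// the MMA variant when MatrixProductMMA.h was included and supportsMMA()
// returns true, otherwise the plain VSX variant, and forwards the call as-is.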
3420template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3421struct gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3422 typedef typename quad_traits<float>::vectortype Packet;
3423 typedef typename quad_traits<float>::rhstype RhsPacket;
3424
3425 void operator()(const DataMapper& res, const float* blockA, const float* blockB, Index rows, Index depth, Index cols,
3426 float alpha, Index strideA = -1, Index strideB = -1, Index offsetA = 0, Index offsetB = 0);
3427};
3428
3429template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3430void gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
3431 const DataMapper& res, const float* blockA, const float* blockB, Index rows, Index depth, Index cols, float alpha,
3432 Index strideA, Index strideB, Index offsetA, Index offsetB) {
3433 const Index accRows = quad_traits<float>::rows;
3434 const Index accCols = quad_traits<float>::size;
3435 static void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index,
3436 Index, Index) =
3437#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3438 (supportsMMA()) ? &Eigen::internal::gemmMMA<float, Packet, RhsPacket, DataMapper, accRows, accCols> :
3439#endif
3440 &Eigen::internal::gemm<float, Packet, RhsPacket, DataMapper, accRows, accCols>;
3441 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3442}
3443
3444template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3445struct gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3446 typedef Packet4f Packet;
3447 typedef Packet2cf Packetc;
3448 typedef Packet4f RhsPacket;
3449
3450 void operator()(const DataMapper& res, const std::complex<float>* blockA, const std::complex<float>* blockB,
3451 Index rows, Index depth, Index cols, std::complex<float> alpha, Index strideA = -1,
3452 Index strideB = -1, Index offsetA = 0, Index offsetB = 0);
3453};
3454
3455template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3456void gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs,
3457 ConjugateRhs>::operator()(const DataMapper& res, const std::complex<float>* blockA,
3458 const std::complex<float>* blockB, Index rows, Index depth, Index cols,
3459 std::complex<float> alpha, Index strideA, Index strideB, Index offsetA,
3460 Index offsetB) {
3461 const Index accRows = quad_traits<float>::rows;
3462 const Index accCols = quad_traits<float>::size;
3463 static void (*gemm_function)(const DataMapper&, const std::complex<float>*, const std::complex<float>*, Index, Index,
3464 Index, std::complex<float>, Index, Index, Index, Index) =
3465#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3466 (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>,
3467 float, Packet, Packetc, RhsPacket, DataMapper, accRows,
3468 accCols, ConjugateLhs, ConjugateRhs, false, false>
3469 :
3470#endif
3471 &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>,
3472 float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
3473 ConjugateLhs, ConjugateRhs, false, false>;
3474 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3475}
3476
3477template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3478struct gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3479 typedef Packet4f Packet;
3480 typedef Packet2cf Packetc;
3481 typedef Packet4f RhsPacket;
3482
3483 void operator()(const DataMapper& res, const float* blockA, const std::complex<float>* blockB, Index rows,
3484 Index depth, Index cols, std::complex<float> alpha, Index strideA = -1, Index strideB = -1,
3485 Index offsetA = 0, Index offsetB = 0);
3486};
3487
3488template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3489void gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
3490 const DataMapper& res, const float* blockA, const std::complex<float>* blockB, Index rows, Index depth, Index cols,
3491 std::complex<float> alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
3492 const Index accRows = quad_traits<float>::rows;
3493 const Index accCols = quad_traits<float>::size;
3494 static void (*gemm_function)(const DataMapper&, const float*, const std::complex<float>*, Index, Index, Index,
3495 std::complex<float>, Index, Index, Index, Index) =
3496#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3497 (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float,
3498 Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
3499 ConjugateLhs, ConjugateRhs, true, false>
3500 :
3501#endif
3502 &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Packet,
3503 Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
3504 ConjugateRhs, true, false>;
3505 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3506}
3507
3508template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3509struct gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3510 typedef Packet4f Packet;
3511 typedef Packet2cf Packetc;
3512 typedef Packet4f RhsPacket;
3513
3514 void operator()(const DataMapper& res, const std::complex<float>* blockA, const float* blockB, Index rows,
3515 Index depth, Index cols, std::complex<float> alpha, Index strideA = -1, Index strideB = -1,
3516 Index offsetA = 0, Index offsetB = 0);
3517};
3518
3519template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3520void gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
3521 const DataMapper& res, const std::complex<float>* blockA, const float* blockB, Index rows, Index depth, Index cols,
3522 std::complex<float> alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
3523 const Index accRows = quad_traits<float>::rows;
3524 const Index accCols = quad_traits<float>::size;
3525 static void (*gemm_function)(const DataMapper&, const std::complex<float>*, const float*, Index, Index, Index,
3526 std::complex<float>, Index, Index, Index, Index) =
3527#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3528 (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float,
3529 Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
3530 ConjugateLhs, ConjugateRhs, false, true>
3531 :
3532#endif
3533 &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Packet,
3534 Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
3535 ConjugateRhs, false, true>;
3536 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3537}
3538
3539template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3540struct gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3541 typedef typename quad_traits<double>::vectortype Packet;
3542 typedef typename quad_traits<double>::rhstype RhsPacket;
3543
3544 void operator()(const DataMapper& res, const double* blockA, const double* blockB, Index rows, Index depth,
3545 Index cols, double alpha, Index strideA = -1, Index strideB = -1, Index offsetA = 0,
3546 Index offsetB = 0);
3547};
3548
3549template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3550void gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
3551 const DataMapper& res, const double* blockA, const double* blockB, Index rows, Index depth, Index cols,
3552 double alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
3553 const Index accRows = quad_traits<double>::rows;
3554 const Index accCols = quad_traits<double>::size;
3555 static void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index,
3556 Index, Index, Index) =
3557#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3558 (supportsMMA()) ? &Eigen::internal::gemmMMA<double, Packet, RhsPacket, DataMapper, accRows, accCols> :
3559#endif
3560 &Eigen::internal::gemm<double, Packet, RhsPacket, DataMapper, accRows, accCols>;
3561 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3562}
3563
3564template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3565struct gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3566 typedef quad_traits<double>::vectortype Packet;
3567 typedef Packet1cd Packetc;
3568 typedef quad_traits<double>::rhstype RhsPacket;
3569
3570 void operator()(const DataMapper& res, const std::complex<double>* blockA, const std::complex<double>* blockB,
3571 Index rows, Index depth, Index cols, std::complex<double> alpha, Index strideA = -1,
3572 Index strideB = -1, Index offsetA = 0, Index offsetB = 0);
3573};
3574
3575template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3576void gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs,
3577 ConjugateRhs>::operator()(const DataMapper& res, const std::complex<double>* blockA,
3578 const std::complex<double>* blockB, Index rows, Index depth, Index cols,
3579 std::complex<double> alpha, Index strideA, Index strideB, Index offsetA,
3580 Index offsetB) {
3581 const Index accRows = quad_traits<double>::rows;
3582 const Index accCols = quad_traits<double>::size;
3583 static void (*gemm_function)(const DataMapper&, const std::complex<double>*, const std::complex<double>*, Index,
3584 Index, Index, std::complex<double>, Index, Index, Index, Index) =
3585#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3586 (supportsMMA())
3587 ? &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double,
3588 Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
3589 ConjugateRhs, false, false>
3590 :
3591#endif
3592 &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double,
3593 Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
3594 ConjugateRhs, false, false>;
3595 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3596}
3597
3598template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3599struct gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3600 typedef quad_traits<double>::vectortype Packet;
3601 typedef Packet1cd Packetc;
3602 typedef quad_traits<double>::rhstype RhsPacket;
3603
3604 void operator()(const DataMapper& res, const std::complex<double>* blockA, const double* blockB, Index rows,
3605 Index depth, Index cols, std::complex<double> alpha, Index strideA = -1, Index strideB = -1,
3606 Index offsetA = 0, Index offsetB = 0);
3607};
3608
3609template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3610void gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
3611 const DataMapper& res, const std::complex<double>* blockA, const double* blockB, Index rows, Index depth,
3612 Index cols, std::complex<double> alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
3613 const Index accRows = quad_traits<double>::rows;
3614 const Index accCols = quad_traits<double>::size;
3615 static void (*gemm_function)(const DataMapper&, const std::complex<double>*, const double*, Index, Index, Index,
3616 std::complex<double>, Index, Index, Index, Index) =
3617#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3618 (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double,
3619 Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
3620 ConjugateLhs, ConjugateRhs, false, true>
3621 :
3622#endif
3623 &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Packet,
3624 Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
3625 ConjugateRhs, false, true>;
3626 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3627}
3628
3629template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3630struct gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3631 typedef quad_traits<double>::vectortype Packet;
3632 typedef Packet1cd Packetc;
3633 typedef quad_traits<double>::rhstype RhsPacket;
3634
3635 void operator()(const DataMapper& res, const double* blockA, const std::complex<double>* blockB, Index rows,
3636 Index depth, Index cols, std::complex<double> alpha, Index strideA = -1, Index strideB = -1,
3637 Index offsetA = 0, Index offsetB = 0);
3638};
3639
3640template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3641void gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
3642 const DataMapper& res, const double* blockA, const std::complex<double>* blockB, Index rows, Index depth,
3643 Index cols, std::complex<double> alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
3644 const Index accRows = quad_traits<double>::rows;
3645 const Index accCols = quad_traits<double>::size;
3646 static void (*gemm_function)(const DataMapper&, const double*, const std::complex<double>*, Index, Index, Index,
3647 std::complex<double>, Index, Index, Index, Index) =
3648#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3649 (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double,
3650 Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
3651 ConjugateLhs, ConjugateRhs, true, false>
3652 :
3653#endif
3654 &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Packet,
3655 Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
3656 ConjugateRhs, true, false>;
3657 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3658}
3659
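// Illustrative only: a plain bfloat16 product such as
//   Eigen::Matrix<Eigen::bfloat16, Eigen::Dynamic, Eigen::Dynamic> C = A * B;
// reaches this kernel on a VSX-enabled POWER build (the types above are the
// public Eigen API, shown for orientation; they do not appear in this file).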
3660template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3661struct gebp_kernel<bfloat16, bfloat16, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3662 typedef typename quad_traits<bfloat16>::vectortype Packet;
3663 typedef typename quad_traits<bfloat16>::rhstype RhsPacket;
3664
3665 void operator()(const DataMapper& res, const bfloat16* blockA, const bfloat16* blockB, Index rows, Index depth,
3666 Index cols, bfloat16 alpha, Index strideA = -1, Index strideB = -1, Index offsetA = 0,
3667 Index offsetB = 0);
3668};
3669
3670template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3671void gebp_kernel<bfloat16, bfloat16, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
3672 const DataMapper& res, const bfloat16* blockA, const bfloat16* blockB, Index rows, Index depth, Index cols,
3673 bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
3674 static void (*gemm_function)(const DataMapper&, const bfloat16*, const bfloat16*, Index, Index, Index, bfloat16,
3675 Index, Index, Index, Index) =
3676#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3677 (supportsMMA()) ? &Eigen::internal::gemmMMAbfloat16<DataMapper> :
3678#endif
3679 &Eigen::internal::gemmbfloat16<DataMapper>;
3680 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3681}
3682} // end namespace internal
3683
3684} // end namespace Eigen
3685
3686#endif // EIGEN_MATRIX_PRODUCT_ALTIVEC_H