#ifndef EIGEN_MATRIX_PRODUCT_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_ALTIVEC_H

#ifndef EIGEN_ALTIVEC_USE_CUSTOM_PACK
#define EIGEN_ALTIVEC_USE_CUSTOM_PACK 1
#endif

#if !defined(EIGEN_ALTIVEC_DISABLE_MMA)
#define EIGEN_ALTIVEC_DISABLE_MMA 0
#endif

// Check for MMA builtin support.
#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__has_builtin)
#if __has_builtin(__builtin_mma_assemble_acc)
#define EIGEN_ALTIVEC_MMA_SUPPORT
#endif
#endif

#if defined(EIGEN_ALTIVEC_MMA_SUPPORT)

#if !defined(EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH)
#define EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH 0
#endif

// Dynamic dispatch is not supported with LLVM; otherwise MMA is used directly.
#if EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH && !EIGEN_COMP_LLVM
#define EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH 1
#else
#define EIGEN_ALTIVEC_MMA_ONLY 1
#endif

#endif  // EIGEN_ALTIVEC_MMA_SUPPORT

#include "MatrixProductCommon.h"

#if defined(EIGEN_ALTIVEC_MMA_ONLY) || defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
#include "MatrixProductMMA.h"
#endif

#include "../../InternalHeaderCheck.h"

namespace Eigen {

namespace internal {
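// quad_traits describes the vector geometry each scalar type uses in the kernels
// below: the packet type, how many packets make up a micro-tile ("type"), the rhs
// packet shape, and the packet width ("vectorsize").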
template <typename Scalar>
struct quad_traits {
  typedef typename packet_traits<Scalar>::type vectortype;
  typedef PacketBlock<vectortype, 4> type;
  typedef vectortype rhstype;
  enum { vectorsize = packet_traits<Scalar>::size, size = 4, rows = 4 };
};

template <>
struct quad_traits<double> {
  typedef Packet2d vectortype;
  typedef PacketBlock<vectortype, 4> type;
  typedef PacketBlock<Packet2d, 2> rhstype;
  enum { vectorsize = packet_traits<double>::size, size = 2, rows = 4 };
};

template <>
struct quad_traits<bfloat16> {
  typedef Packet8bf vectortype;
  typedef PacketBlock<vectortype, 4> type;
  typedef vectortype rhstype;
  enum { vectorsize = packet_traits<bfloat16>::size, size = 8, rows = 4 };
};
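// vec_perm control vectors for splitting packets of std::complex<float> into a
// real and an imaginary vector. Byte indices 0-15 select from the first input,
// 16-31 from the second; e.g. given {r0,i0,r1,i1} and {r2,i2,r3,i3},
// p16uc_GETREAL32 yields {r0,r1,r2,r3} and p16uc_GETIMAG32 yields {i0,i1,i2,i3}.
// The "b" variants produce the interleaved order {x0,x2,x1,x3} used on the
// transposed path.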
const static Packet16uc p16uc_GETREAL32 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};

const static Packet16uc p16uc_GETIMAG32 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};

const static Packet16uc p16uc_GETREAL32b = {0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 24, 25, 26, 27};

const static Packet16uc p16uc_GETIMAG32b = {4, 5, 6, 7, 20, 21, 22, 23, 12, 13, 14, 15, 28, 29, 30, 31};
// Grab the value at (i, j) from a symmetric/Hermitian source, conjugating when
// reading from the transposed half and forcing a real diagonal.
template <typename Scalar, int StorageOrder>
EIGEN_ALWAYS_INLINE std::complex<Scalar> getAdjointVal(
    Index i, Index j, const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder>& dt) {
  std::complex<Scalar> v;
  if (i < j) {
    v.real(dt(j, i).real());
    v.imag(-dt(j, i).imag());
  } else if (i > j) {
    v.real(dt(i, j).real());
    v.imag(dt(i, j).imag());
  } else {
    v.real(dt(i, j).real());
    v.imag((Scalar)0.0);
  }
  return v;
}
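// Symmetric/Hermitian packing helpers. They expand the stored triangle into full
// panels via getAdjointVal and, for the complex variants, split each panel into a
// real plane (written at rir) followed by an imaginary plane (written at rii) so
// the kernels can operate on decoupled real/imaginary data.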
template <typename Scalar, int StorageOrder, int N>
EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex<Scalar>* blockB, const std::complex<Scalar>* _rhs,
                                                      Index rhsStride, Index rows, Index cols, Index k2) {
  const Index depth = k2 + rows;
  const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> rhs(_rhs, rhsStride);
  const Index vectorSize = N * quad_traits<Scalar>::vectorsize;
  const Index vectorDelta = vectorSize * rows;
  Scalar* blockBf = reinterpret_cast<Scalar*>(blockB);

  Index rir = 0, rii, j = 0;
  for (; j + vectorSize <= cols; j += vectorSize) {
    rii = rir + vectorDelta;

    for (Index i = k2; i < depth; i++) {
      for (Index k = 0; k < vectorSize; k++) {
        std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(i, j + k, rhs);

        blockBf[rir + k] = v.real();
        blockBf[rii + k] = v.imag();
      }
      rir += vectorSize;
      rii += vectorSize;
    }

    rir += vectorDelta;
  }

  for (; j < cols; j++) {
    rii = rir + rows;

    for (Index i = k2; i < depth; i++) {
      std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(i, j, rhs);

      blockBf[rir] = v.real();
      blockBf[rii] = v.imag();

      rir += 1;
      rii += 1;
    }

    rir += rows;
  }
}
template <typename Scalar, int StorageOrder>
EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex<Scalar>* blockA, const std::complex<Scalar>* _lhs,
                                                      Index lhsStride, Index cols, Index rows) {
  const Index depth = cols;
  const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> lhs(_lhs, lhsStride);
  const Index vectorSize = quad_traits<Scalar>::vectorsize;
  const Index vectorDelta = vectorSize * depth;
  Scalar* blockAf = reinterpret_cast<Scalar*>(blockA);

  Index rir = 0, rii, j = 0;
  for (; j + vectorSize <= rows; j += vectorSize) {
    rii = rir + vectorDelta;

    for (Index i = 0; i < depth; i++) {
      for (Index k = 0; k < vectorSize; k++) {
        std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(j + k, i, lhs);

        blockAf[rir + k] = v.real();
        blockAf[rii + k] = v.imag();
      }
      rir += vectorSize;
      rii += vectorSize;
    }

    rir += vectorDelta;
  }

  if (j < rows) {
    rii = rir + ((rows - j) * depth);

    for (Index i = 0; i < depth; i++) {
      Index k = j;
      for (; k < rows; k++) {
        std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(k, i, lhs);

        blockAf[rir] = v.real();
        blockAf[rii] = v.imag();

        rir += 1;
        rii += 1;
      }
    }
  }
}
template <typename Scalar, int StorageOrder, int N>
EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar* blockB, const Scalar* _rhs, Index rhsStride, Index rows,
                                              Index cols, Index k2) {
  const Index depth = k2 + rows;
  const_blas_data_mapper<Scalar, Index, StorageOrder> rhs(_rhs, rhsStride);
  const Index vectorSize = quad_traits<Scalar>::vectorsize;
  Index ri = 0, j = 0;

  for (; j + N * vectorSize <= cols; j += N * vectorSize) {
    Index i = k2;
    for (; i < depth; i++) {
      for (Index k = 0; k < N * vectorSize; k++) {
        if (i <= j + k)
          blockB[ri + k] = rhs(j + k, i);
        else
          blockB[ri + k] = rhs(i, j + k);
      }
      ri += N * vectorSize;
    }
  }

  for (; j < cols; j++) {
    for (Index i = k2; i < depth; i++) {
      if (j <= i)
        blockB[ri] = rhs(i, j);
      else
        blockB[ri] = rhs(j, i);
      ri += 1;
    }
  }
}
template <typename Scalar, int StorageOrder>
EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar* blockA, const Scalar* _lhs, Index lhsStride, Index cols,
                                              Index rows) {
  const Index depth = cols;
  const_blas_data_mapper<Scalar, Index, StorageOrder> lhs(_lhs, lhsStride);
  const Index vectorSize = quad_traits<Scalar>::vectorsize;
  Index ri = 0, j = 0;

  for (; j + vectorSize <= rows; j += vectorSize) {
    Index i = 0;
    for (; i < depth; i++) {
      for (Index k = 0; k < vectorSize; k++) {
        if (i <= j + k)
          blockA[ri + k] = lhs(j + k, i);
        else
          blockA[ri + k] = lhs(i, j + k);
      }
      ri += vectorSize;
    }
  }

  if (j < rows) {
    for (Index i = 0; i < depth; i++) {
      Index k = j;
      for (; k < rows; k++) {
        if (i <= k)
          blockA[ri] = lhs(k, i);
        else
          blockA[ri] = lhs(i, k);
        ri += 1;
      }
    }
  }
}
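// Specializations of Eigen's symm_pack_rhs/symm_pack_lhs that forward to the
// helpers above. The double and complex<double> rhs variants pass N = 2 because a
// 128-bit Packet2d holds only two doubles, so two packets form one panel column
// block.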
template <typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<std::complex<float>, Index, nr, StorageOrder> {
  void operator()(std::complex<float>* blockB, const std::complex<float>* _rhs, Index rhsStride, Index rows,
                  Index cols, Index k2) {
    symm_pack_complex_rhs_helper<float, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template <typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<std::complex<float>, Index, Pack1, Pack2_dummy, StorageOrder> {
  void operator()(std::complex<float>* blockA, const std::complex<float>* _lhs, Index lhsStride, Index cols,
                  Index rows) {
    symm_pack_complex_lhs_helper<float, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};

template <typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<std::complex<double>, Index, nr, StorageOrder> {
  void operator()(std::complex<double>* blockB, const std::complex<double>* _rhs, Index rhsStride, Index rows,
                  Index cols, Index k2) {
    symm_pack_complex_rhs_helper<double, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template <typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<std::complex<double>, Index, Pack1, Pack2_dummy, StorageOrder> {
  void operator()(std::complex<double>* blockA, const std::complex<double>* _lhs, Index lhsStride, Index cols,
                  Index rows) {
    symm_pack_complex_lhs_helper<double, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};

template <typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<float, Index, nr, StorageOrder> {
  void operator()(float* blockB, const float* _rhs, Index rhsStride, Index rows, Index cols, Index k2) {
    symm_pack_rhs_helper<float, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template <typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<float, Index, Pack1, Pack2_dummy, StorageOrder> {
  void operator()(float* blockA, const float* _lhs, Index lhsStride, Index cols, Index rows) {
    symm_pack_lhs_helper<float, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};

template <typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<double, Index, nr, StorageOrder> {
  void operator()(double* blockB, const double* _rhs, Index rhsStride, Index rows, Index cols, Index k2) {
    symm_pack_rhs_helper<double, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template <typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<double, Index, Pack1, Pack2_dummy, StorageOrder> {
  void operator()(double* blockA, const double* _lhs, Index lhsStride, Index cols, Index rows) {
    symm_pack_lhs_helper<double, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};
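// Store a PacketBlock to consecutive 16-byte slots; packets 2 and 3 are only
// written when N is large enough, so the same helper serves 2- and 4-packet
// blocks.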
template <typename Scalar, typename Packet, int N>
EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet, N>& block) {
  const Index size = 16 / sizeof(Scalar);
  pstore<Scalar>(to + (0 * size), block.packet[0]);
  pstore<Scalar>(to + (1 * size), block.packet[1]);
  if (N > 2) {
    pstore<Scalar>(to + (2 * size), block.packet[2]);
  }
  if (N > 3) {
    pstore<Scalar>(to + (3 * size), block.packet[3]);
  }
}
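// General template for lhs & rhs complex packing. The packed buffer is written as
// two planes: all real parts first (index rir), then all imaginary parts (index
// rii), negating the imaginary plane when Conjugate is set. dhs_cblock performs
// the 4x4 complex transpose needed when the source ordering does not match the
// packed layout.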
template <typename Scalar, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate,
          bool PanelMode, bool UseLhs>
struct dhs_cpack {
  template <bool transpose>
  EIGEN_ALWAYS_INLINE void dhs_cblock(PacketBlock<PacketC, 8>& cblock, PacketBlock<Packet, 4>& block,
                                      Packet16uc permute) {
    if (transpose) {
      block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, permute);
      block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, permute);
      block.packet[2] = vec_perm(cblock.packet[4].v, cblock.packet[5].v, permute);
      block.packet[3] = vec_perm(cblock.packet[6].v, cblock.packet[7].v, permute);

      Packet4f t0, t1, t2, t3;
#ifdef EIGEN_VECTORIZE_VSX
      t0 = reinterpret_cast<Packet>(
          vec_mergeh(reinterpret_cast<Packet2ul>(block.packet[0]), reinterpret_cast<Packet2ul>(block.packet[1])));
      t1 = reinterpret_cast<Packet>(
          vec_mergel(reinterpret_cast<Packet2ul>(block.packet[0]), reinterpret_cast<Packet2ul>(block.packet[1])));
      t2 = reinterpret_cast<Packet>(
          vec_mergeh(reinterpret_cast<Packet2ul>(block.packet[2]), reinterpret_cast<Packet2ul>(block.packet[3])));
      t3 = reinterpret_cast<Packet>(
          vec_mergel(reinterpret_cast<Packet2ul>(block.packet[2]), reinterpret_cast<Packet2ul>(block.packet[3])));
#else
      t0 = reinterpret_cast<Packet>(vec_perm(block.packet[0], block.packet[1], p16uc_TRANSPOSE64_HI));
      t1 = reinterpret_cast<Packet>(vec_perm(block.packet[0], block.packet[1], p16uc_TRANSPOSE64_LO));
      t2 = reinterpret_cast<Packet>(vec_perm(block.packet[2], block.packet[3], p16uc_TRANSPOSE64_HI));
      t3 = reinterpret_cast<Packet>(vec_perm(block.packet[2], block.packet[3], p16uc_TRANSPOSE64_LO));
#endif

      block.packet[0] = t0;
      block.packet[1] = t1;
      block.packet[2] = t2;
      block.packet[3] = t3;
    } else {
      block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, permute);
      block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, permute);
      block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, permute);
      block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, permute);
    }
  }

  EIGEN_ALWAYS_INLINE void dhs_ccopy(Scalar* blockAt, const DataMapper& lhs2, Index& i, Index& rir, Index& rii,
                                     Index depth, const Index vectorSize) {
    PacketBlock<Packet, 4> blockr, blocki;
    PacketBlock<PacketC, 8> cblock;

    for (; i + vectorSize <= depth; i += vectorSize) {
      if (UseLhs) {
        bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, 0, i);
      } else {
        bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, i, 0);
      }

      if (((StorageOrder == RowMajor) && UseLhs) || (((StorageOrder == ColMajor) && !UseLhs))) {
        dhs_cblock<true>(cblock, blockr, p16uc_GETREAL32b);
        dhs_cblock<true>(cblock, blocki, p16uc_GETIMAG32b);
      } else {
        dhs_cblock<false>(cblock, blockr, p16uc_GETREAL32);
        dhs_cblock<false>(cblock, blocki, p16uc_GETIMAG32);
      }

      if (Conjugate) {
        blocki.packet[0] = -blocki.packet[0];
        blocki.packet[1] = -blocki.packet[1];
        blocki.packet[2] = -blocki.packet[2];
        blocki.packet[3] = -blocki.packet[3];
      }

      storeBlock<Scalar, Packet, 4>(blockAt + rir, blockr);
      storeBlock<Scalar, Packet, 4>(blockAt + rii, blocki);

      rir += 4 * vectorSize;
      rii += 4 * vectorSize;
    }
  }

  EIGEN_STRONG_INLINE void operator()(std::complex<Scalar>* blockA, const DataMapper& lhs, Index depth, Index rows,
                                      Index stride, Index offset) {
    const Index vectorSize = quad_traits<Scalar>::vectorsize;
    const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
    Index rir = ((PanelMode) ? (vectorSize * offset) : 0), rii;
    Scalar* blockAt = reinterpret_cast<Scalar*>(blockA);
    Index j = 0;

    for (; j + vectorSize <= rows; j += vectorSize) {
      const DataMapper lhs2 = UseLhs ? lhs.getSubMapper(j, 0) : lhs.getSubMapper(0, j);
      Index i = 0;

      rii = rir + vectorDelta;

      dhs_ccopy(blockAt, lhs2, i, rir, rii, depth, vectorSize);

      for (; i < depth; i++) {
        PacketBlock<Packet, 1> blockr, blocki;
        PacketBlock<PacketC, 2> cblock;

        if (((StorageOrder == ColMajor) && UseLhs) || (((StorageOrder == RowMajor) && !UseLhs))) {
          if (UseLhs) {
            cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i);
            cblock.packet[1] = lhs2.template loadPacket<PacketC>(2, i);
          } else {
            cblock.packet[0] = lhs2.template loadPacket<PacketC>(i, 0);
            cblock.packet[1] = lhs2.template loadPacket<PacketC>(i, 2);
          }
        } else {
          if (UseLhs) {
            cblock.packet[0] = pload2(lhs2(0, i), lhs2(1, i));
            cblock.packet[1] = pload2(lhs2(2, i), lhs2(3, i));
          } else {
            cblock.packet[0] = pload2(lhs2(i, 0), lhs2(i, 1));
            cblock.packet[1] = pload2(lhs2(i, 2), lhs2(i, 3));
          }
        }

        blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL32);
        blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG32);

        if (Conjugate) {
          blocki.packet[0] = -blocki.packet[0];
        }

        pstore<Scalar>(blockAt + rir, blockr.packet[0]);
        pstore<Scalar>(blockAt + rii, blocki.packet[0]);

        rir += vectorSize;
        rii += vectorSize;
      }

      rir += ((PanelMode) ? (vectorSize * (2 * stride - depth)) : vectorDelta);
    }

    if (!UseLhs) {
      if (PanelMode) rir -= (offset * (vectorSize - 1));

      for (; j < rows; j++) {
        const DataMapper lhs2 = lhs.getSubMapper(0, j);
        rii = rir + ((PanelMode) ? stride : depth);

        for (Index i = 0; i < depth; i++) {
          blockAt[rir] = lhs2(i, 0).real();

          if (Conjugate)
            blockAt[rii] = -lhs2(i, 0).imag();
          else
            blockAt[rii] = lhs2(i, 0).imag();

          rir += 1;
          rii += 1;
        }

        rir += ((PanelMode) ? (2 * stride - depth) : depth);
      }
    } else {
      if (j < rows) {
        if (PanelMode) rir += (offset * (rows - j - vectorSize));
        rii = rir + (((PanelMode) ? stride : depth) * (rows - j));

        for (Index i = 0; i < depth; i++) {
          Index k = j;
          for (; k < rows; k++) {
            blockAt[rir] = lhs(k, i).real();

            if (Conjugate)
              blockAt[rii] = -lhs(k, i).imag();
            else
              blockAt[rii] = lhs(k, i).imag();

            rir += 1;
            rii += 1;
          }
        }
      }
    }
  }
};
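// General template for lhs & rhs real packing. dhs_copy<n> moves n blocks of four
// packets per iteration (transposing when the source ordering requires it); the
// scalar tails below it handle the remaining depth and rows.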
template <typename Scalar, typename DataMapper, typename Packet, int StorageOrder, bool PanelMode, bool UseLhs>
struct dhs_pack {
  template <Index n>
  EIGEN_ALWAYS_INLINE void dhs_copy(Scalar* blockA, const DataMapper& lhs2, Index& i, Index& ri, Index depth,
                                    const Index vectorSize) {
    PacketBlock<Packet, 4> block[n];

    for (; i + n * vectorSize <= depth; i += n * vectorSize) {
      for (Index k = 0; k < n; k++) {
        if (UseLhs) {
          bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block[k], lhs2, 0, i + k * vectorSize);
        } else {
          bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block[k], lhs2, i + k * vectorSize, 0);
        }
      }
      if (((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs)) {
        for (Index k = 0; k < n; k++) {
          ptranspose(block[k]);
        }
      }
      for (Index k = 0; k < n; k++) {
        storeBlock<Scalar, Packet, 4>(blockA + ri + k * 4 * vectorSize, block[k]);
      }
      ri += n * 4 * vectorSize;
    }
  }

  EIGEN_STRONG_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride,
                                      Index offset) {
    const Index vectorSize = quad_traits<Scalar>::vectorsize;
    Index ri = 0, j = 0;

    for (; j + vectorSize <= rows; j += vectorSize) {
      const DataMapper lhs2 = UseLhs ? lhs.getSubMapper(j, 0) : lhs.getSubMapper(0, j);
      Index i = 0;

      if (PanelMode) ri += vectorSize * offset;

      dhs_copy<4>(blockA, lhs2, i, ri, depth, vectorSize);
      dhs_copy<2>(blockA, lhs2, i, ri, depth, vectorSize);
      dhs_copy<1>(blockA, lhs2, i, ri, depth, vectorSize);

      for (; i < depth; i++) {
        if (((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs)) {
          if (UseLhs) {
            blockA[ri + 0] = lhs2(0, i);
            blockA[ri + 1] = lhs2(1, i);
            blockA[ri + 2] = lhs2(2, i);
            blockA[ri + 3] = lhs2(3, i);
          } else {
            blockA[ri + 0] = lhs2(i, 0);
            blockA[ri + 1] = lhs2(i, 1);
            blockA[ri + 2] = lhs2(i, 2);
            blockA[ri + 3] = lhs2(i, 3);
          }
        } else {
          Packet lhsV;
          if (UseLhs) {
            lhsV = lhs2.template loadPacket<Packet>(0, i);
          } else {
            lhsV = lhs2.template loadPacket<Packet>(i, 0);
          }
          pstore<Scalar>(blockA + ri, lhsV);
        }

        ri += vectorSize;
      }

      if (PanelMode) ri += vectorSize * (stride - offset - depth);
    }

    if (!UseLhs) {
      if (PanelMode) ri += offset;

      for (; j < rows; j++) {
        const DataMapper lhs2 = lhs.getSubMapper(0, j);
        for (Index i = 0; i < depth; i++) {
          blockA[ri] = lhs2(i, 0);
          ri += 1;
        }

        if (PanelMode) ri += stride - depth;
      }
    } else {
      if (j < rows) {
        if (PanelMode) ri += offset * (rows - j);

        for (Index i = 0; i < depth; i++) {
          Index k = j;
          for (; k < rows; k++) {
            blockA[ri] = lhs(k, i);
            ri += 1;
          }
        }
      }
    }
  }
};
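// lhs packing, float64 specialization. A Packet2d holds two doubles, so panels
// are two rows wide and row-major sources are handled with a 2x2 ptranspose.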
template <typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, true> {
  template <Index n>
  EIGEN_ALWAYS_INLINE void dhs_copy(double* blockA, const DataMapper& lhs2, Index& i, Index& ri, Index depth,
                                    const Index vectorSize) {
    PacketBlock<Packet2d, 2> block[n];

    for (; i + n * vectorSize <= depth; i += n * vectorSize) {
      for (Index k = 0; k < n; k++) {
        if (StorageOrder == RowMajor) {
          block[k].packet[0] = lhs2.template loadPacket<Packet2d>(0, i + k * vectorSize);
          block[k].packet[1] = lhs2.template loadPacket<Packet2d>(1, i + k * vectorSize);
        } else {
          block[k].packet[0] = lhs2.template loadPacket<Packet2d>(0, i + k * vectorSize + 0);
          block[k].packet[1] = lhs2.template loadPacket<Packet2d>(0, i + k * vectorSize + 1);
        }
      }
      if (StorageOrder == RowMajor) {
        for (Index k = 0; k < n; k++) {
          ptranspose(block[k]);
        }
      }
      for (Index k = 0; k < n; k++) {
        storeBlock<double, Packet2d, 2>(blockA + ri + k * 2 * vectorSize, block[k]);
      }
      ri += n * 2 * vectorSize;
    }
  }

  EIGEN_STRONG_INLINE void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride,
                                      Index offset) {
    const Index vectorSize = quad_traits<double>::vectorsize;
    Index ri = 0, j = 0;

    for (; j + vectorSize <= rows; j += vectorSize) {
      const DataMapper lhs2 = lhs.getSubMapper(j, 0);
      Index i = 0;

      if (PanelMode) ri += vectorSize * offset;

      dhs_copy<4>(blockA, lhs2, i, ri, depth, vectorSize);
      dhs_copy<2>(blockA, lhs2, i, ri, depth, vectorSize);
      dhs_copy<1>(blockA, lhs2, i, ri, depth, vectorSize);

      for (; i < depth; i++) {
        if (StorageOrder == RowMajor) {
          blockA[ri + 0] = lhs2(0, i);
          blockA[ri + 1] = lhs2(1, i);
        } else {
          Packet2d lhsV = lhs2.template loadPacket<Packet2d>(0, i);
          pstore<double>(blockA + ri, lhsV);
        }

        ri += vectorSize;
      }

      if (PanelMode) ri += vectorSize * (stride - offset - depth);
    }

    if (j < rows) {
      if (PanelMode) ri += offset * (rows - j);

      for (Index i = 0; i < depth; i++) {
        Index k = j;
        for (; k < rows; k++) {
          blockA[ri] = lhs(k, i);
          ri += 1;
        }
      }
    }
  }
};
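// rhs packing, float64 specialization. Panels cover four columns (2 * vectorsize);
// column-major sources transpose two 2x2 blocks, row-major sources store a
// 4-packet block directly.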
template <typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, false> {
  template <Index n>
  EIGEN_ALWAYS_INLINE void dhs_copy(double* blockB, const DataMapper& rhs2, Index& i, Index& ri, Index depth,
                                    const Index vectorSize) {
    PacketBlock<Packet2d, 2> block1[n], block2[n];
    PacketBlock<Packet2d, 4> block3[n];

    for (; i + n * vectorSize <= depth; i += n * vectorSize) {
      for (Index k = 0; k < n; k++) {
        if (StorageOrder == ColMajor) {
          block1[k].packet[0] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize, 0);
          block1[k].packet[1] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize, 1);
          block2[k].packet[0] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize, 2);
          block2[k].packet[1] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize, 3);
        } else {
          block3[k].packet[0] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize + 0, 0);
          block3[k].packet[1] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize + 0, 2);
          block3[k].packet[2] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize + 1, 0);
          block3[k].packet[3] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize + 1, 2);
        }
      }
      if (StorageOrder == ColMajor) {
        for (Index k = 0; k < n; k++) {
          ptranspose(block1[k]);
          ptranspose(block2[k]);
        }
      }
      for (Index k = 0; k < n; k++) {
        if (StorageOrder == ColMajor) {
          pstore<double>(blockB + ri + k * 4 * vectorSize, block1[k].packet[0]);
          pstore<double>(blockB + ri + k * 4 * vectorSize + 2, block2[k].packet[0]);
          pstore<double>(blockB + ri + k * 4 * vectorSize + 4, block1[k].packet[1]);
          pstore<double>(blockB + ri + k * 4 * vectorSize + 6, block2[k].packet[1]);
        } else {
          storeBlock<double, Packet2d, 4>(blockB + ri + k * 4 * vectorSize, block3[k]);
        }
      }
      ri += n * 4 * vectorSize;
    }
  }

  EIGEN_STRONG_INLINE void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride,
                                      Index offset) {
    const Index vectorSize = quad_traits<double>::vectorsize;
    Index ri = 0, j = 0;

    for (; j + 2 * vectorSize <= cols; j += 2 * vectorSize) {
      const DataMapper rhs2 = rhs.getSubMapper(0, j);
      Index i = 0;

      if (PanelMode) ri += offset * (2 * vectorSize);

      dhs_copy<4>(blockB, rhs2, i, ri, depth, vectorSize);
      dhs_copy<2>(blockB, rhs2, i, ri, depth, vectorSize);
      dhs_copy<1>(blockB, rhs2, i, ri, depth, vectorSize);

      for (; i < depth; i++) {
        if (StorageOrder == ColMajor) {
          blockB[ri + 0] = rhs2(i, 0);
          blockB[ri + 1] = rhs2(i, 1);

          ri += vectorSize;

          blockB[ri + 0] = rhs2(i, 2);
          blockB[ri + 1] = rhs2(i, 3);
        } else {
          Packet2d rhsV = rhs2.template loadPacket<Packet2d>(i, 0);
          pstore<double>(blockB + ri, rhsV);

          ri += vectorSize;

          rhsV = rhs2.template loadPacket<Packet2d>(i, 2);
          pstore<double>(blockB + ri, rhsV);
        }
        ri += vectorSize;
      }

      if (PanelMode) ri += (2 * vectorSize) * (stride - offset - depth);
    }

    if (PanelMode) ri += offset;

    for (; j < cols; j++) {
      const DataMapper rhs2 = rhs.getSubMapper(0, j);
      for (Index i = 0; i < depth; i++) {
        blockB[ri] = rhs2(i, 0);
        ri += 1;
      }

      if (PanelMode) ri += stride - depth;
    }
  }
};
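// lhs packing, bfloat16 specialization. Rows are processed in panels of
// 2 * vectorsize, then vectorsize, then 4, then scalars. The interleaving is an
// 8x8 transpose built from vec_mergeh/vec_mergel at 32-bit and then 64-bit
// granularity (vec_perm with p16uc_TRANSPOSE64_* when VSX is unavailable).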
template <typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, true> {
  EIGEN_STRONG_INLINE void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride,
                                      Index offset) {
    const Index vectorSize = quad_traits<bfloat16>::vectorsize;
    Index ri = 0, j = 0;

    for (; j + 2 * vectorSize <= rows; j += 2 * vectorSize) {
      const DataMapper lhs2 = lhs.getSubMapper(j, 0);
      Index i = 0;

      if (PanelMode) ri += 2 * vectorSize * offset;

      if (StorageOrder == ColMajor) {
        for (; i + 2 <= depth; i += 2) {
          PacketBlock<Packet8bf, 4> block;

          block.packet[0] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
          block.packet[1] = lhs2.template loadPacket<Packet8bf>(1 * vectorSize, i + 0);
          block.packet[2] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 1);
          block.packet[3] = lhs2.template loadPacket<Packet8bf>(1 * vectorSize, i + 1);

          Packet8bf t0, t1;
          t0 = vec_mergeh(block.packet[0].m_val, block.packet[2].m_val);
          t1 = vec_mergel(block.packet[0].m_val, block.packet[2].m_val);
          block.packet[2] = vec_mergeh(block.packet[1].m_val, block.packet[3].m_val);
          block.packet[3] = vec_mergel(block.packet[1].m_val, block.packet[3].m_val);
          block.packet[0] = t0;
          block.packet[1] = t1;

          storeBlock<bfloat16, Packet8bf, 4>(blockA + ri, block);

          ri += 2 * 2 * vectorSize;
        }
        if (depth & 1) {
          PacketBlock<Packet8bf, 2> block;

          block.packet[0] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
          block.packet[1] = lhs2.template loadPacket<Packet8bf>(1 * vectorSize, i + 0);

          storeBlock<bfloat16, Packet8bf, 2>(blockA + ri, block);

          ri += 2 * vectorSize;
        }
      } else {
        for (; i + vectorSize <= depth; i += vectorSize) {
          PacketBlock<Packet8bf, 8> block1, block2;

          bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block1, lhs2, 0 * vectorSize, i);
          bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block2, lhs2, 1 * vectorSize, i);

          Packet4ui v1[8], v2[8];

          v1[0] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[0].m_val),
                             reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
          v1[1] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[0].m_val),
                             reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
          v1[2] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[2].m_val),
                             reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
          v1[3] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[2].m_val),
                             reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
          v1[4] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[4].m_val),
                             reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
          v1[5] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[4].m_val),
                             reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
          v1[6] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[6].m_val),
                             reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
          v1[7] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[6].m_val),
                             reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
          v2[0] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[0].m_val),
                             reinterpret_cast<Packet4ui>(block2.packet[1].m_val));
          v2[1] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[0].m_val),
                             reinterpret_cast<Packet4ui>(block2.packet[1].m_val));
          v2[2] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[2].m_val),
                             reinterpret_cast<Packet4ui>(block2.packet[3].m_val));
          v2[3] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[2].m_val),
                             reinterpret_cast<Packet4ui>(block2.packet[3].m_val));
          v2[4] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[4].m_val),
                             reinterpret_cast<Packet4ui>(block2.packet[5].m_val));
          v2[5] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[4].m_val),
                             reinterpret_cast<Packet4ui>(block2.packet[5].m_val));
          v2[6] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[6].m_val),
                             reinterpret_cast<Packet4ui>(block2.packet[7].m_val));
          v2[7] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[6].m_val),
                             reinterpret_cast<Packet4ui>(block2.packet[7].m_val));

#ifdef EIGEN_VECTORIZE_VSX
          block1.packet[0] = reinterpret_cast<Packet8us>(
              vec_mergeh(reinterpret_cast<Packet2ul>(v1[0]), reinterpret_cast<Packet2ul>(v1[2])));
          block1.packet[2] = reinterpret_cast<Packet8us>(
              vec_mergel(reinterpret_cast<Packet2ul>(v1[0]), reinterpret_cast<Packet2ul>(v1[2])));
          block1.packet[4] = reinterpret_cast<Packet8us>(
              vec_mergeh(reinterpret_cast<Packet2ul>(v1[1]), reinterpret_cast<Packet2ul>(v1[3])));
          block1.packet[6] = reinterpret_cast<Packet8us>(
              vec_mergel(reinterpret_cast<Packet2ul>(v1[1]), reinterpret_cast<Packet2ul>(v1[3])));
          block1.packet[1] = reinterpret_cast<Packet8us>(
              vec_mergeh(reinterpret_cast<Packet2ul>(v1[4]), reinterpret_cast<Packet2ul>(v1[6])));
          block1.packet[3] = reinterpret_cast<Packet8us>(
              vec_mergel(reinterpret_cast<Packet2ul>(v1[4]), reinterpret_cast<Packet2ul>(v1[6])));
          block1.packet[5] = reinterpret_cast<Packet8us>(
              vec_mergeh(reinterpret_cast<Packet2ul>(v1[5]), reinterpret_cast<Packet2ul>(v1[7])));
          block1.packet[7] = reinterpret_cast<Packet8us>(
              vec_mergel(reinterpret_cast<Packet2ul>(v1[5]), reinterpret_cast<Packet2ul>(v1[7])));
          block2.packet[0] = reinterpret_cast<Packet8us>(
              vec_mergeh(reinterpret_cast<Packet2ul>(v2[0]), reinterpret_cast<Packet2ul>(v2[2])));
          block2.packet[2] = reinterpret_cast<Packet8us>(
              vec_mergel(reinterpret_cast<Packet2ul>(v2[0]), reinterpret_cast<Packet2ul>(v2[2])));
          block2.packet[4] = reinterpret_cast<Packet8us>(
              vec_mergeh(reinterpret_cast<Packet2ul>(v2[1]), reinterpret_cast<Packet2ul>(v2[3])));
          block2.packet[6] = reinterpret_cast<Packet8us>(
              vec_mergel(reinterpret_cast<Packet2ul>(v2[1]), reinterpret_cast<Packet2ul>(v2[3])));
          block2.packet[1] = reinterpret_cast<Packet8us>(
              vec_mergeh(reinterpret_cast<Packet2ul>(v2[4]), reinterpret_cast<Packet2ul>(v2[6])));
          block2.packet[3] = reinterpret_cast<Packet8us>(
              vec_mergel(reinterpret_cast<Packet2ul>(v2[4]), reinterpret_cast<Packet2ul>(v2[6])));
          block2.packet[5] = reinterpret_cast<Packet8us>(
              vec_mergeh(reinterpret_cast<Packet2ul>(v2[5]), reinterpret_cast<Packet2ul>(v2[7])));
          block2.packet[7] = reinterpret_cast<Packet8us>(
              vec_mergel(reinterpret_cast<Packet2ul>(v2[5]), reinterpret_cast<Packet2ul>(v2[7])));
#else
          block1.packet[0] = reinterpret_cast<Packet8us>(vec_perm(v1[0], v1[2], p16uc_TRANSPOSE64_HI));
          block1.packet[2] = reinterpret_cast<Packet8us>(vec_perm(v1[0], v1[2], p16uc_TRANSPOSE64_LO));
          block1.packet[4] = reinterpret_cast<Packet8us>(vec_perm(v1[1], v1[3], p16uc_TRANSPOSE64_HI));
          block1.packet[6] = reinterpret_cast<Packet8us>(vec_perm(v1[1], v1[3], p16uc_TRANSPOSE64_LO));
          block1.packet[1] = reinterpret_cast<Packet8us>(vec_perm(v1[4], v1[6], p16uc_TRANSPOSE64_HI));
          block1.packet[3] = reinterpret_cast<Packet8us>(vec_perm(v1[4], v1[6], p16uc_TRANSPOSE64_LO));
          block1.packet[5] = reinterpret_cast<Packet8us>(vec_perm(v1[5], v1[7], p16uc_TRANSPOSE64_HI));
          block1.packet[7] = reinterpret_cast<Packet8us>(vec_perm(v1[5], v1[7], p16uc_TRANSPOSE64_LO));
          block2.packet[0] = reinterpret_cast<Packet8us>(vec_perm(v2[0], v2[2], p16uc_TRANSPOSE64_HI));
          block2.packet[2] = reinterpret_cast<Packet8us>(vec_perm(v2[0], v2[2], p16uc_TRANSPOSE64_LO));
          block2.packet[4] = reinterpret_cast<Packet8us>(vec_perm(v2[1], v2[3], p16uc_TRANSPOSE64_HI));
          block2.packet[6] = reinterpret_cast<Packet8us>(vec_perm(v2[1], v2[3], p16uc_TRANSPOSE64_LO));
          block2.packet[1] = reinterpret_cast<Packet8us>(vec_perm(v2[4], v2[6], p16uc_TRANSPOSE64_HI));
          block2.packet[3] = reinterpret_cast<Packet8us>(vec_perm(v2[4], v2[6], p16uc_TRANSPOSE64_LO));
          block2.packet[5] = reinterpret_cast<Packet8us>(vec_perm(v2[5], v2[7], p16uc_TRANSPOSE64_HI));
          block2.packet[7] = reinterpret_cast<Packet8us>(vec_perm(v2[5], v2[7], p16uc_TRANSPOSE64_LO));
#endif

          for (Index M = 0; M < 8; M += 2) {
            pstore<bfloat16>(blockA + ri + (0 * vectorSize) + (2 * vectorSize * M), block1.packet[M + 0]);
            pstore<bfloat16>(blockA + ri + (1 * vectorSize) + (2 * vectorSize * M), block1.packet[M + 1]);
            pstore<bfloat16>(blockA + ri + (2 * vectorSize) + (2 * vectorSize * M), block2.packet[M + 0]);
            pstore<bfloat16>(blockA + ri + (3 * vectorSize) + (2 * vectorSize * M), block2.packet[M + 1]);
          }

          ri += 2 * vectorSize * vectorSize;
        }
        for (; i + 2 <= depth; i += 2) {
          for (Index M = 0; M < 2 * vectorSize; M++) {
            blockA[ri + (M * 2) + 0] = lhs2(M, i + 0);
            blockA[ri + (M * 2) + 1] = lhs2(M, i + 1);
          }

          ri += 2 * 2 * vectorSize;
        }
        if (depth & 1) {
          for (Index M = 0; M < 2 * vectorSize; M++) {
            blockA[ri + M] = lhs2(M, i);
          }

          ri += 2 * vectorSize;
        }
      }

      if (PanelMode) ri += 2 * vectorSize * (stride - offset - depth);
    }
    for (; j + vectorSize <= rows; j += vectorSize) {
      const DataMapper lhs2 = lhs.getSubMapper(j, 0);
      Index i = 0;

      if (PanelMode) ri += vectorSize * offset;

      if (StorageOrder == ColMajor) {
        for (; i + 2 <= depth; i += 2) {
          PacketBlock<Packet8bf, 2> block;

          block.packet[0] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
          block.packet[1] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 1);

          Packet8bf t0;
          t0 = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
          block.packet[1] = vec_mergel(block.packet[0].m_val, block.packet[1].m_val);
          block.packet[0] = t0;

          storeBlock<bfloat16, Packet8bf, 2>(blockA + ri, block);

          ri += 2 * vectorSize;
        }
        if (depth & 1) {
          Packet8bf lhsV = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
          pstore<bfloat16>(blockA + ri, lhsV);

          ri += vectorSize;
        }
      } else {
        for (; i + vectorSize <= depth; i += vectorSize) {
          PacketBlock<Packet8bf, 8> block1;

          bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block1, lhs2, 0 * vectorSize, i);

          Packet4ui v1[8];

          v1[0] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[0].m_val),
                             reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
          v1[1] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[0].m_val),
                             reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
          v1[2] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[2].m_val),
                             reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
          v1[3] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[2].m_val),
                             reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
          v1[4] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[4].m_val),
                             reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
          v1[5] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[4].m_val),
                             reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
          v1[6] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[6].m_val),
                             reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
          v1[7] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[6].m_val),
                             reinterpret_cast<Packet4ui>(block1.packet[7].m_val));

#ifdef EIGEN_VECTORIZE_VSX
          block1.packet[0] = reinterpret_cast<Packet8us>(
              vec_mergeh(reinterpret_cast<Packet2ul>(v1[0]), reinterpret_cast<Packet2ul>(v1[2])));
          block1.packet[2] = reinterpret_cast<Packet8us>(
              vec_mergel(reinterpret_cast<Packet2ul>(v1[0]), reinterpret_cast<Packet2ul>(v1[2])));
          block1.packet[4] = reinterpret_cast<Packet8us>(
              vec_mergeh(reinterpret_cast<Packet2ul>(v1[1]), reinterpret_cast<Packet2ul>(v1[3])));
          block1.packet[6] = reinterpret_cast<Packet8us>(
              vec_mergel(reinterpret_cast<Packet2ul>(v1[1]), reinterpret_cast<Packet2ul>(v1[3])));
          block1.packet[1] = reinterpret_cast<Packet8us>(
              vec_mergeh(reinterpret_cast<Packet2ul>(v1[4]), reinterpret_cast<Packet2ul>(v1[6])));
          block1.packet[3] = reinterpret_cast<Packet8us>(
              vec_mergel(reinterpret_cast<Packet2ul>(v1[4]), reinterpret_cast<Packet2ul>(v1[6])));
          block1.packet[5] = reinterpret_cast<Packet8us>(
              vec_mergeh(reinterpret_cast<Packet2ul>(v1[5]), reinterpret_cast<Packet2ul>(v1[7])));
          block1.packet[7] = reinterpret_cast<Packet8us>(
              vec_mergel(reinterpret_cast<Packet2ul>(v1[5]), reinterpret_cast<Packet2ul>(v1[7])));
#else
          block1.packet[0] = reinterpret_cast<Packet8us>(vec_perm(v1[0], v1[2], p16uc_TRANSPOSE64_HI));
          block1.packet[2] = reinterpret_cast<Packet8us>(vec_perm(v1[0], v1[2], p16uc_TRANSPOSE64_LO));
          block1.packet[4] = reinterpret_cast<Packet8us>(vec_perm(v1[1], v1[3], p16uc_TRANSPOSE64_HI));
          block1.packet[6] = reinterpret_cast<Packet8us>(vec_perm(v1[1], v1[3], p16uc_TRANSPOSE64_LO));
          block1.packet[1] = reinterpret_cast<Packet8us>(vec_perm(v1[4], v1[6], p16uc_TRANSPOSE64_HI));
          block1.packet[3] = reinterpret_cast<Packet8us>(vec_perm(v1[4], v1[6], p16uc_TRANSPOSE64_LO));
          block1.packet[5] = reinterpret_cast<Packet8us>(vec_perm(v1[5], v1[7], p16uc_TRANSPOSE64_HI));
          block1.packet[7] = reinterpret_cast<Packet8us>(vec_perm(v1[5], v1[7], p16uc_TRANSPOSE64_LO));
#endif

          for (Index M = 0; M < 8; M++) {
            pstore<bfloat16>(blockA + ri + (vectorSize * M), block1.packet[M]);
          }

          ri += vectorSize * vectorSize;
        }
        for (; i + 2 <= depth; i += 2) {
          for (Index M = 0; M < vectorSize; M++) {
            blockA[ri + (M * 2) + 0] = lhs2(M, i + 0);
            blockA[ri + (M * 2) + 1] = lhs2(M, i + 1);
          }

          ri += 2 * vectorSize;
        }
        if (depth & 1) {
          for (Index M = 0; M < vectorSize; M++) {
            blockA[ri + M] = lhs2(M, i);
          }

          ri += vectorSize;
        }
      }

      if (PanelMode) ri += vectorSize * (stride - offset - depth);
    }
    if (j + 4 <= rows) {
      const DataMapper lhs2 = lhs.getSubMapper(j, 0);
      Index i = 0;

      if (PanelMode) ri += 4 * offset;

      for (; i + 2 <= depth; i += 2) {
        if (StorageOrder == ColMajor) {
          PacketBlock<Packet8bf, 2> block;

          block.packet[0] = lhs2.template loadPacketPartial<Packet8bf>(0, i + 0, 4);
          block.packet[1] = lhs2.template loadPacketPartial<Packet8bf>(0, i + 1, 4);

          block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);

          pstore<bfloat16>(blockA + ri, block.packet[0]);
        } else {
          blockA[ri + 0] = lhs2(0, i + 0);
          blockA[ri + 1] = lhs2(0, i + 1);
          blockA[ri + 2] = lhs2(1, i + 0);
          blockA[ri + 3] = lhs2(1, i + 1);
          blockA[ri + 4] = lhs2(2, i + 0);
          blockA[ri + 5] = lhs2(2, i + 1);
          blockA[ri + 6] = lhs2(3, i + 0);
          blockA[ri + 7] = lhs2(3, i + 1);
        }

        ri += 2 * 4;
      }
      if (depth & 1) {
        if (StorageOrder == ColMajor) {
          Packet8bf lhsV = lhs2.template loadPacketPartial<Packet8bf>(0, i + 0, 4);

          pstore_partial<bfloat16>(blockA + ri, lhsV, 4);
        } else {
          blockA[ri + 0] = lhs2(0, i);
          blockA[ri + 1] = lhs2(1, i);
          blockA[ri + 2] = lhs2(2, i);
          blockA[ri + 3] = lhs2(3, i);
        }

        ri += 4;
      }

      if (PanelMode) ri += 4 * (stride - offset - depth);
      j += 4;
    }
    if (j < rows) {
      if (PanelMode) ri += offset * (rows - j);

      Index i = 0;
      for (; i + 2 <= depth; i += 2) {
        Index k = j;
        for (; k < rows; k++) {
          blockA[ri + 0] = lhs(k, i + 0);
          blockA[ri + 1] = lhs(k, i + 1);

          ri += 2;
        }
      }
      if (depth & 1) {
        for (; j < rows; j++) {
          blockA[ri] = lhs(j, i);

          ri += 1;
        }
      }
    }
  }
};
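// rhs packing, bfloat16 specialization. Columns are packed four at a time; the
// column-major path transposes with the same 32-bit/64-bit merge scheme as the
// lhs pack, while the row-major path gathers partial packets of four elements.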
template <typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, false> {
  EIGEN_STRONG_INLINE void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride,
                                      Index offset) {
    const Index vectorSize = quad_traits<bfloat16>::vectorsize;
    Index ri = 0, j = 0;

    for (; j + 4 <= cols; j += 4) {
      const DataMapper rhs2 = rhs.getSubMapper(0, j);
      Index i = 0;

      if (PanelMode) ri += 4 * offset;

      for (; i + vectorSize <= depth; i += vectorSize) {
        if (StorageOrder == ColMajor) {
          PacketBlock<Packet8bf, 4> block;

          bload<DataMapper, Packet8bf, 4, StorageOrder, false, 4>(block, rhs2, i, 0);

          Packet4ui t0, t1, t2, t3;

          t0 = vec_mergeh(reinterpret_cast<Packet4ui>(block.packet[0].m_val),
                          reinterpret_cast<Packet4ui>(block.packet[1].m_val));
          t1 = vec_mergel(reinterpret_cast<Packet4ui>(block.packet[0].m_val),
                          reinterpret_cast<Packet4ui>(block.packet[1].m_val));
          t2 = vec_mergeh(reinterpret_cast<Packet4ui>(block.packet[2].m_val),
                          reinterpret_cast<Packet4ui>(block.packet[3].m_val));
          t3 = vec_mergel(reinterpret_cast<Packet4ui>(block.packet[2].m_val),
                          reinterpret_cast<Packet4ui>(block.packet[3].m_val));

#ifdef EIGEN_VECTORIZE_VSX
          block.packet[0] = reinterpret_cast<Packet8us>(
              vec_mergeh(reinterpret_cast<Packet2ul>(t0), reinterpret_cast<Packet2ul>(t2)));
          block.packet[1] = reinterpret_cast<Packet8us>(
              vec_mergel(reinterpret_cast<Packet2ul>(t0), reinterpret_cast<Packet2ul>(t2)));
          block.packet[2] = reinterpret_cast<Packet8us>(
              vec_mergeh(reinterpret_cast<Packet2ul>(t1), reinterpret_cast<Packet2ul>(t3)));
          block.packet[3] = reinterpret_cast<Packet8us>(
              vec_mergel(reinterpret_cast<Packet2ul>(t1), reinterpret_cast<Packet2ul>(t3)));
#else
          block.packet[0] = reinterpret_cast<Packet8us>(vec_perm(t0, t2, p16uc_TRANSPOSE64_HI));
          block.packet[1] = reinterpret_cast<Packet8us>(vec_perm(t0, t2, p16uc_TRANSPOSE64_LO));
          block.packet[2] = reinterpret_cast<Packet8us>(vec_perm(t1, t3, p16uc_TRANSPOSE64_HI));
          block.packet[3] = reinterpret_cast<Packet8us>(vec_perm(t1, t3, p16uc_TRANSPOSE64_LO));
#endif

          storeBlock<bfloat16, Packet8bf, 4>(blockB + ri, block);
        } else {
          PacketBlock<Packet8bf, 8> block;

          for (int M = 0; M < 8; M++) {
            block.packet[M] = rhs2.template loadPacketPartial<Packet8bf>(i + M, 0, 4);
          }

          block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
          block.packet[1] = vec_mergeh(block.packet[2].m_val, block.packet[3].m_val);
          block.packet[2] = vec_mergeh(block.packet[4].m_val, block.packet[5].m_val);
          block.packet[3] = vec_mergeh(block.packet[6].m_val, block.packet[7].m_val);

          const Index size = 16 / sizeof(bfloat16);

          for (int M = 0; M < 4; M++) {
            pstore<bfloat16>(blockB + ri + (M * size), block.packet[M]);
          }
        }

        ri += 4 * vectorSize;
      }
      for (; i + 2 <= depth; i += 2) {
        if (StorageOrder == ColMajor) {
          blockB[ri + 0] = rhs2(i + 0, 0);
          blockB[ri + 1] = rhs2(i + 1, 0);
          blockB[ri + 2] = rhs2(i + 0, 1);
          blockB[ri + 3] = rhs2(i + 1, 1);
          blockB[ri + 4] = rhs2(i + 0, 2);
          blockB[ri + 5] = rhs2(i + 1, 2);
          blockB[ri + 6] = rhs2(i + 0, 3);
          blockB[ri + 7] = rhs2(i + 1, 3);
        } else {
          PacketBlock<Packet8bf, 2> block;

          for (int M = 0; M < 2; M++) {
            block.packet[M] = rhs2.template loadPacketPartial<Packet8bf>(i + M, 0, 4);
          }

          block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);

          pstore<bfloat16>(blockB + ri, block.packet[0]);
        }

        ri += 4 * 2;
      }
      if (depth & 1) {
        blockB[ri + 0] = rhs2(i, 0);
        blockB[ri + 1] = rhs2(i, 1);
        blockB[ri + 2] = rhs2(i, 2);
        blockB[ri + 3] = rhs2(i, 3);

        ri += 4;
      }

      if (PanelMode) ri += 4 * (stride - offset - depth);
    }

    if (j < cols) {
      if (PanelMode) ri += offset * (cols - j);

      Index i = 0;
      for (; i + 2 <= depth; i += 2) {
        Index k = j;
        for (; k < cols; k++) {
          blockB[ri + 0] = rhs(i + 0, k);
          blockB[ri + 1] = rhs(i + 1, k);

          ri += 2;
        }
      }
      if (depth & 1) {
        for (; j < cols; j++) {
          blockB[ri] = rhs(i, j);

          ri += 1;
        }
      }
    }
  }
};
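// Complex float64 packing specializations. Since one PacketC holds exactly one
// std::complex<double>, the complex split reduces to vec_mergeh of the raw
// Packet2d views for the real parts and vec_mergel for the imaginary parts.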
template <typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, true> {
  EIGEN_ALWAYS_INLINE void dhs_ccopy(double* blockAt, const DataMapper& lhs2, Index& i, Index& rir, Index& rii,
                                     Index depth, const Index vectorSize) {
    PacketBlock<Packet, 2> blockr, blocki;
    PacketBlock<PacketC, 4> cblock;

    for (; i + vectorSize <= depth; i += vectorSize) {
      if (StorageOrder == ColMajor) {
        cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i + 0);
        cblock.packet[1] = lhs2.template loadPacket<PacketC>(0, i + 1);

        cblock.packet[2] = lhs2.template loadPacket<PacketC>(1, i + 0);
        cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1);

        blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[2].v);
        blockr.packet[1] = vec_mergeh(cblock.packet[1].v, cblock.packet[3].v);

        blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[2].v);
        blocki.packet[1] = vec_mergel(cblock.packet[1].v, cblock.packet[3].v);
      } else {
        cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i);
        cblock.packet[1] = lhs2.template loadPacket<PacketC>(1, i);

        cblock.packet[2] = lhs2.template loadPacket<PacketC>(0, i + 1);
        cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1);

        blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v);
        blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v);

        blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
        blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v);
      }

      if (Conjugate) {
        blocki.packet[0] = -blocki.packet[0];
        blocki.packet[1] = -blocki.packet[1];
      }

      storeBlock<double, Packet, 2>(blockAt + rir, blockr);
      storeBlock<double, Packet, 2>(blockAt + rii, blocki);

      rir += 2 * vectorSize;
      rii += 2 * vectorSize;
    }
  }

  EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows,
                                      Index stride, Index offset) {
    const Index vectorSize = quad_traits<double>::vectorsize;
    const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
    Index rir = ((PanelMode) ? (vectorSize * offset) : 0), rii;
    double* blockAt = reinterpret_cast<double*>(blockA);
    Index j = 0;

    for (; j + vectorSize <= rows; j += vectorSize) {
      const DataMapper lhs2 = lhs.getSubMapper(j, 0);
      Index i = 0;

      rii = rir + vectorDelta;

      dhs_ccopy(blockAt, lhs2, i, rir, rii, depth, vectorSize);

      for (; i < depth; i++) {
        PacketBlock<Packet, 1> blockr, blocki;
        PacketBlock<PacketC, 2> cblock;

        cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i);
        cblock.packet[1] = lhs2.template loadPacket<PacketC>(1, i);

        blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v);
        blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);

        if (Conjugate) {
          blocki.packet[0] = -blocki.packet[0];
        }

        pstore<double>(blockAt + rir, blockr.packet[0]);
        pstore<double>(blockAt + rii, blocki.packet[0]);

        rir += vectorSize;
        rii += vectorSize;
      }

      rir += ((PanelMode) ? (vectorSize * (2 * stride - depth)) : vectorDelta);
    }

    if (j < rows) {
      if (PanelMode) rir += (offset * (rows - j - vectorSize));
      rii = rir + (((PanelMode) ? stride : depth) * (rows - j));

      for (Index i = 0; i < depth; i++) {
        Index k = j;
        for (; k < rows; k++) {
          blockAt[rir] = lhs(k, i).real();

          if (Conjugate)
            blockAt[rii] = -lhs(k, i).imag();
          else
            blockAt[rii] = lhs(k, i).imag();

          rir += 1;
          rii += 1;
        }
      }
    }
  }
};
template <typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, false> {
  EIGEN_ALWAYS_INLINE void dhs_ccopy(double* blockBt, const DataMapper& rhs2, Index& i, Index& rir, Index& rii,
                                     Index depth, const Index vectorSize) {
    for (; i < depth; i++) {
      PacketBlock<PacketC, 4> cblock;
      PacketBlock<Packet, 2> blockr, blocki;

      bload<DataMapper, PacketC, 2, ColMajor, false, 4>(cblock, rhs2, i, 0);

      blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v);
      blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v);

      blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
      blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v);

      if (Conjugate) {
        blocki.packet[0] = -blocki.packet[0];
        blocki.packet[1] = -blocki.packet[1];
      }

      storeBlock<double, Packet, 2>(blockBt + rir, blockr);
      storeBlock<double, Packet, 2>(blockBt + rii, blocki);

      rir += 2 * vectorSize;
      rii += 2 * vectorSize;
    }
  }

  EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols,
                                      Index stride, Index offset) {
    const Index vectorSize = quad_traits<double>::vectorsize;
    const Index vectorDelta = 2 * vectorSize * ((PanelMode) ? stride : depth);
    Index rir = ((PanelMode) ? (2 * vectorSize * offset) : 0), rii;
    double* blockBt = reinterpret_cast<double*>(blockB);
    Index j = 0;

    for (; j + 2 * vectorSize <= cols; j += 2 * vectorSize) {
      const DataMapper rhs2 = rhs.getSubMapper(0, j);
      Index i = 0;

      rii = rir + vectorDelta;

      dhs_ccopy(blockBt, rhs2, i, rir, rii, depth, vectorSize);

      rir += ((PanelMode) ? (2 * vectorSize * (2 * stride - depth)) : vectorDelta);
    }

    if (PanelMode) rir -= (offset * (2 * vectorSize - 1));

    for (; j < cols; j++) {
      const DataMapper rhs2 = rhs.getSubMapper(0, j);
      rii = rir + ((PanelMode) ? stride : depth);

      for (Index i = 0; i < depth; i++) {
        blockBt[rir] = rhs2(i, 0).real();

        if (Conjugate)
          blockBt[rii] = -rhs2(i, 0).imag();
        else
          blockBt[rii] = rhs2(i, 0).imag();

        rir += 1;
        rii += 1;
      }

      rir += ((PanelMode) ? (2 * stride - depth) : depth);
    }
  }
};
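// GEMM micro-kernel helpers. pger/pger_common implement a rank-1 update of the
// accumulator tile: one fused multiply-add (vec_madd) or multiply-subtract
// (vec_nmsub) per accumulator row, the latter used for the sign flips that
// complex multiplication needs.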
template <typename Packet, bool NegativeAccumulate, int N>
EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet, N>* acc, const Packet& lhsV, const Packet* rhsV) {
  if (NegativeAccumulate) {
    for (int M = 0; M < N; M++) {
      acc->packet[M] = vec_nmsub(lhsV, rhsV[M], acc->packet[M]);
    }
  } else {
    for (int M = 0; M < N; M++) {
      acc->packet[M] = vec_madd(lhsV, rhsV[M], acc->packet[M]);
    }
  }
}
template <int N, typename Scalar, typename Packet, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet, N>* acc, const Scalar* lhs, const Packet* rhsV) {
  Packet lhsV = pload<Packet>(lhs);

  pger_common<Packet, NegativeAccumulate, N>(acc, lhsV, rhsV);
}
// Rank-1 update of a complex accumulator pair. Takes decoupled real/imaginary
// accumulators and handles the mixed real/complex cases.
template <int N, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgerc_common(PacketBlock<Packet, N>* accReal, PacketBlock<Packet, N>* accImag,
                                      const Packet& lhsV, Packet& lhsVi, const Packet* rhsV, const Packet* rhsVi) {
  pger_common<Packet, false, N>(accReal, lhsV, rhsV);
  if (LhsIsReal) {
    pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
    EIGEN_UNUSED_VARIABLE(lhsVi);
  } else {
    if (!RhsIsReal) {
      pger_common<Packet, ConjugateLhs == ConjugateRhs, N>(accReal, lhsVi, rhsVi);
      pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
    } else {
      EIGEN_UNUSED_VARIABLE(rhsVi);
    }
    pger_common<Packet, ConjugateLhs, N>(accImag, lhsVi, rhsV);
  }
}
template <int N, typename Scalar, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal,
          bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet, N>* accReal, PacketBlock<Packet, N>* accImag, const Scalar* lhs_ptr,
                               const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi) {
  Packet lhsV = ploadLhs<Packet>(lhs_ptr);
  Packet lhsVi;
  if (!LhsIsReal)
    lhsVi = ploadLhs<Packet>(lhs_ptr_imag);
  else
    EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);

  pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV,
                                                                            rhsVi);
}
template <typename Packet>
EIGEN_ALWAYS_INLINE Packet ploadLhs(const __UNPACK_TYPE__(Packet) * lhs) {
  return ploadu<Packet>(lhs);
}
// Zero an accumulator PacketBlock.
template <typename Packet, int N>
EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet, N>& acc) {
  for (int M = 0; M < N; M++) {
    acc.packet[M] = pset1<Packet>((__UNPACK_TYPE__(Packet))0);
  }
}
template <typename Packet, int N>
EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet, N>& acc, PacketBlock<Packet, N>& accZ,
                                        const Packet& pAlpha) {
  for (int M = 0; M < N; M++) {
    acc.packet[M] = vec_mul(accZ.packet[M], pAlpha);
  }
}
template <typename Packet, int N>
EIGEN_ALWAYS_INLINE void band(PacketBlock<Packet, N>& acc, const Packet& pMask) {
  for (int M = 0; M < N; M++) {
    acc.packet[M] = pand<Packet>(acc.packet[M], pMask);
  }
}
// Scale a complex accumulator pair by a complex alpha, optionally masking off
// unused lanes first.
template <typename Packet, int N, bool mask>
EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet, N>& aReal, PacketBlock<Packet, N>& aImag, const Packet& bReal,
                                 const Packet& bImag, PacketBlock<Packet, N>& cReal, PacketBlock<Packet, N>& cImag,
                                 const Packet& pMask) {
  if (mask && (sizeof(__UNPACK_TYPE__(Packet)) == sizeof(float))) {
    band<Packet, N>(aReal, pMask);
    band<Packet, N>(aImag, pMask);
  } else {
    EIGEN_UNUSED_VARIABLE(pMask);
  }

  bscalec_common<Packet, N>(cReal, aReal, bReal);

  bscalec_common<Packet, N>(cImag, aImag, bReal);

  pger_common<Packet, true, N>(&cReal, bImag, aImag.packet);

  pger_common<Packet, false, N>(&cImag, bImag, aReal.packet);
}
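// Load/store an accumulator tile from/to the result mapper. For complex tiles the
// second half of the PacketBlock is filled from positions offset by accCols, i.e.
// the other half of the micro-tile.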
template <typename DataMapper, typename Packet, const Index accCols, int StorageOrder, bool Complex, int N,
          bool full = true>
EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet, N*(Complex ? 2 : 1)>& acc, const DataMapper& res, Index row,
                               Index col) {
  if (StorageOrder == RowMajor) {
    for (int M = 0; M < N; M++) {
      acc.packet[M] = res.template loadPacket<Packet>(row + M, col);
    }
    if (Complex) {
      for (int M = 0; M < N; M++) {
        acc.packet[M + N] = res.template loadPacket<Packet>(row + M, col + accCols);
      }
    }
  } else {
    for (int M = 0; M < N; M++) {
      acc.packet[M] = res.template loadPacket<Packet>(row, col + M);
    }
    if (Complex && full) {
      for (int M = 0; M < N; M++) {
        acc.packet[M + N] = res.template loadPacket<Packet>(row + accCols, col + M);
      }
    }
  }
}
template <typename DataMapper, typename Packet, int N>
EIGEN_ALWAYS_INLINE void bstore(PacketBlock<Packet, N>& acc, const DataMapper& res, Index row) {
  for (int M = 0; M < N; M++) {
    res.template storePacket<Packet>(row, M, acc.packet[M]);
  }
}
#ifdef USE_PARTIAL_PACKETS
template <typename DataMapper, typename Packet, const Index accCols, bool Complex, Index N, bool full>
EIGEN_ALWAYS_INLINE void bload_partial(PacketBlock<Packet, N*(Complex ? 2 : 1)>& acc, const DataMapper& res, Index row,
                                       Index elements) {
  for (Index M = 0; M < N; M++) {
    acc.packet[M] = res.template loadPacketPartial<Packet>(row, M, elements);
  }
  if (Complex && full) {
    for (Index M = 0; M < N; M++) {
      acc.packet[M + N] = res.template loadPacketPartial<Packet>(row + accCols, M, elements);
    }
  }
}

template <typename DataMapper, typename Packet, Index N>
EIGEN_ALWAYS_INLINE void bstore_partial(PacketBlock<Packet, N>& acc, const DataMapper& res, Index row,
                                        Index elements) {
  for (Index M = 0; M < N; M++) {
    res.template storePacketPartial<Packet>(row, M, acc.packet[M], elements);
  }
}
#endif
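// Masks for handling a partial micro-tile (fewer than vectorsize rows). bmask
// builds a packet whose first remaining_rows lanes are all-ones, e.g.
// remaining_rows == 2 gives {-1, -1, 0, 0} for Packet4i; band/bscale use it to
// zero the unused lanes before accumulating into the result.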
// Power10 mask-generation builtins need GCC >= 11 or LLVM (guard assumed: _ARCH_PWR10).
#ifdef _ARCH_PWR10
#define USE_P10_AND_PVIPR2_0 (EIGEN_COMP_LLVM || (__GNUC__ >= 11))
#else
#define USE_P10_AND_PVIPR2_0 0
#endif

#if !USE_P10_AND_PVIPR2_0
const static Packet4i mask4[4] = {{0, 0, 0, 0}, {-1, 0, 0, 0}, {-1, -1, 0, 0}, {-1, -1, -1, 0}};
#endif
template <typename Packet>
EIGEN_ALWAYS_INLINE Packet bmask(const Index remaining_rows) {
#if USE_P10_AND_PVIPR2_0
#ifdef _BIG_ENDIAN
  return Packet(vec_reve(vec_genwm((1 << remaining_rows) - 1)));
#else
  return Packet(vec_genwm((1 << remaining_rows) - 1));
#endif
#else
  return Packet(mask4[remaining_rows]);
#endif
}

template <>
EIGEN_ALWAYS_INLINE Packet2d bmask<Packet2d>(const Index remaining_rows) {
#if USE_P10_AND_PVIPR2_0
  Packet2d mask2 = Packet2d(vec_gendm(remaining_rows));
#ifdef _BIG_ENDIAN
  return preverse(mask2);
#else
  return mask2;
#endif
#else
  Packet2l ret = {-remaining_rows, 0};
  return Packet2d(ret);
#endif
}
// Scale the PacketBlock vectors by alpha and accumulate into acc.
template <typename Packet, int N>
EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet, N>& acc, PacketBlock<Packet, N>& accZ, const Packet& pAlpha) {
  for (int M = 0; M < N; M++) {
    acc.packet[M] = pmadd<Packet>(pAlpha, accZ.packet[M], acc.packet[M]);
  }
}
template <typename Packet, int N, bool mask>
EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet, N>& acc, PacketBlock<Packet, N>& accZ, const Packet& pAlpha,
                                const Packet& pMask) {
  if (mask) {
    band<Packet, N>(accZ, pMask);
  } else {
    EIGEN_UNUSED_VARIABLE(pMask);
  }

  bscale<Packet, N>(acc, accZ, pAlpha);
}
// Broadcast N rhs values into splatted packets; with a single rhs pointer (N == 4)
// consecutive values come from ap0, otherwise one value per pointer.
template <typename Packet, int N, bool real>
EIGEN_ALWAYS_INLINE void pbroadcastN(const __UNPACK_TYPE__(Packet) * ap0, const __UNPACK_TYPE__(Packet) * ap1,
                                     const __UNPACK_TYPE__(Packet) * ap2, Packet& a0, Packet& a1, Packet& a2,
                                     Packet& a3) {
  a0 = pset1<Packet>(ap0[0]);
  if (N == 4) {
    a1 = pset1<Packet>(ap0[1]);
    a2 = pset1<Packet>(ap0[2]);
    a3 = pset1<Packet>(ap0[3]);
    EIGEN_UNUSED_VARIABLE(ap1);
    EIGEN_UNUSED_VARIABLE(ap2);
  } else {
    if (N > 1) {
      a1 = pset1<Packet>(ap1[0]);
    } else {
      EIGEN_UNUSED_VARIABLE(a1);
      EIGEN_UNUSED_VARIABLE(ap1);
    }
    if (N > 2) {
      a2 = pset1<Packet>(ap2[0]);
    } else {
      EIGEN_UNUSED_VARIABLE(a2);
      EIGEN_UNUSED_VARIABLE(ap2);
    }
    EIGEN_UNUSED_VARIABLE(a3);
  }
}
template <>
EIGEN_ALWAYS_INLINE void pbroadcastN<Packet4f, 4, true>(const float* ap0, const float*, const float*, Packet4f& a0,
                                                        Packet4f& a1, Packet4f& a2, Packet4f& a3) {
  pbroadcast4<Packet4f>(ap0, a0, a1, a2, a3);
}

template <>
EIGEN_ALWAYS_INLINE void pbroadcastN<Packet4f, 4, false>(const float* ap0, const float* ap1, const float* ap2,
                                                         Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3) {
  pbroadcastN<Packet4f, 4, true>(ap0, ap1, ap2, a0, a1, a2, a3);
}

template <>
EIGEN_ALWAYS_INLINE void pbroadcastN<Packet2d, 4, false>(const double* ap0, const double*, const double*, Packet2d& a0,
                                                         Packet2d& a1, Packet2d& a2, Packet2d& a3) {
  a1 = pload<Packet2d>(ap0);
  a3 = pload<Packet2d>(ap0 + 2);
  a0 = vec_splat(a1, 0);
  a1 = vec_splat(a1, 1);
  a2 = vec_splat(a3, 0);
  a3 = vec_splat(a3, 1);
}
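// Recouple the decoupled real/imaginary accumulators into complex packets:
// vec_mergeh interleaves the low halves into the first result packet, vec_mergel
// the high halves into the second, and bcouple then adds the previous result
// tile.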
template <typename Packet, typename Packetc, int N, bool full>
EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock<Packet, N>& taccReal, PacketBlock<Packet, N>& taccImag,
                                        PacketBlock<Packetc, N>& acc1, PacketBlock<Packetc, N>& acc2) {
  for (int M = 0; M < N; M++) {
    acc1.packet[M].v = vec_mergeh(taccReal.packet[M], taccImag.packet[M]);
  }

  if (full) {
    for (int M = 0; M < N; M++) {
      acc2.packet[M].v = vec_mergel(taccReal.packet[M], taccImag.packet[M]);
    }
  }
}
template <typename Packet, typename Packetc, int N, bool full>
EIGEN_ALWAYS_INLINE void bcouple(PacketBlock<Packet, N>& taccReal, PacketBlock<Packet, N>& taccImag,
                                 PacketBlock<Packetc, N * 2>& tRes, PacketBlock<Packetc, N>& acc1,
                                 PacketBlock<Packetc, N>& acc2) {
  bcouple_common<Packet, Packetc, N, full>(taccReal, taccImag, acc1, acc2);

  for (int M = 0; M < N; M++) {
    acc1.packet[M] = padd<Packetc>(tRes.packet[M], acc1.packet[M]);
  }

  if (full) {
    for (int M = 0; M < N; M++) {
      acc2.packet[M] = padd<Packetc>(tRes.packet[M + N], acc2.packet[M]);
    }
  }
}
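// MICRO_* macros: building blocks of the unrolled row kernels. MICRO_RHS names
// one of up to three rhs pointers (a single pointer strides by accRows in the
// normal case, three separate pointers otherwise), the *_PEEL macros unroll
// PEEL_ROW depth iterations into accZero0..7, and *_ADD_PEEL folds those partial
// accumulators back into accZero0.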
1793#define MICRO_UNROLL(func) func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
1795#define MICRO_NORMAL_ROWS accRows == quad_traits<Scalar>::rows || accRows == 1
1797#define MICRO_NEW_ROWS ((MICRO_NORMAL_ROWS) ? accRows : 1)
1799#define MICRO_RHS(ptr, N) rhs_##ptr##N
1801#define MICRO_ZERO_PEEL(peel) \
1802 if ((PEEL_ROW > peel) && (peel != 0)) { \
1803 bsetzero<Packet, accRows>(accZero##peel); \
1805 EIGEN_UNUSED_VARIABLE(accZero##peel); \
1808#define MICRO_ADD(ptr, N) \
1809 if (MICRO_NORMAL_ROWS) { \
1810 MICRO_RHS(ptr, 0) += (accRows * N); \
1812 MICRO_RHS(ptr, 0) += N; \
1813 MICRO_RHS(ptr, 1) += N; \
1814 if (accRows == 3) { \
1815 MICRO_RHS(ptr, 2) += N; \
1819#define MICRO_ADD_ROWS(N) MICRO_ADD(ptr, N)
1821#define MICRO_BROADCAST1(peel, ptr, rhsV, real) \
1822 if (MICRO_NORMAL_ROWS) { \
1823 pbroadcastN<Packet, accRows, real>(MICRO_RHS(ptr, 0) + (accRows * peel), MICRO_RHS(ptr, 0), MICRO_RHS(ptr, 0), \
1824 rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
1826 pbroadcastN<Packet, accRows, real>(MICRO_RHS(ptr, 0) + peel, MICRO_RHS(ptr, 1) + peel, MICRO_RHS(ptr, 2) + peel, \
1827 rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
1830#define MICRO_BROADCAST(peel) MICRO_BROADCAST1(peel, ptr, rhsV, true)
#define MICRO_BROADCAST_EXTRA1(ptr, rhsV, real)                                                                 \
  pbroadcastN<Packet, accRows, real>(MICRO_RHS(ptr, 0), MICRO_RHS(ptr, 1), MICRO_RHS(ptr, 2), rhsV[0], rhsV[1], \
                                     rhsV[2], rhsV[3]);

#define MICRO_BROADCAST_EXTRA             \
  Packet rhsV[4];                         \
  MICRO_BROADCAST_EXTRA1(ptr, rhsV, true) \
  MICRO_ADD_ROWS(1)
#define MICRO_SRC2(ptr, N, M)                   \
  if (MICRO_NORMAL_ROWS) {                      \
    EIGEN_UNUSED_VARIABLE(strideB);             \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr, 1));   \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr, 2));   \
  } else {                                      \
    MICRO_RHS(ptr, 1) = rhs_base + N + M;       \
    if (accRows == 3) {                         \
      MICRO_RHS(ptr, 2) = rhs_base + N * 2 + M; \
    } else {                                    \
      EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr, 2)); \
    }                                           \
  }

#define MICRO_SRC2_PTR MICRO_SRC2(ptr, strideB, 0)

#define MICRO_ZERO_PEEL_ROW MICRO_UNROLL(MICRO_ZERO_PEEL)
#define MICRO_WORK_PEEL(peel)                                                                            \
  if (PEEL_ROW > peel) {                                                                                 \
    MICRO_BROADCAST(peel)                                                                                \
    pger<accRows, Scalar, Packet, false>(&accZero##peel, lhs_ptr + (remaining_rows * peel), rhsV##peel); \
  } else {                                                                                               \
    EIGEN_UNUSED_VARIABLE(rhsV##peel);                                                                   \
  }

#define MICRO_WORK_PEEL_ROW                                                              \
  Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4], rhsV4[4], rhsV5[4], rhsV6[4], rhsV7[4]; \
  MICRO_UNROLL(MICRO_WORK_PEEL)                                                          \
  lhs_ptr += (remaining_rows * PEEL_ROW);                                                \
  MICRO_ADD_ROWS(PEEL_ROW)
#define MICRO_ADD_PEEL(peel, sum)                        \
  if (PEEL_ROW > peel) {                                 \
    for (Index i = 0; i < accRows; i++) {                \
      accZero##sum.packet[i] += accZero##peel.packet[i]; \
    }                                                    \
  }

#define MICRO_ADD_PEEL_ROW                                                             \
  MICRO_ADD_PEEL(4, 0) MICRO_ADD_PEEL(5, 1) MICRO_ADD_PEEL(6, 2) MICRO_ADD_PEEL(7, 3) \
  MICRO_ADD_PEEL(2, 0) MICRO_ADD_PEEL(3, 1) MICRO_ADD_PEEL(1, 0)
#define MICRO_PREFETCHN1(ptr, N)               \
  EIGEN_POWER_PREFETCH(MICRO_RHS(ptr, 0));     \
  if (N == 2 || N == 3) {                      \
    EIGEN_POWER_PREFETCH(MICRO_RHS(ptr, 1));   \
    if (N == 3) {                              \
      EIGEN_POWER_PREFETCH(MICRO_RHS(ptr, 2)); \
    }                                          \
  }

#define MICRO_PREFETCHN(N) MICRO_PREFETCHN1(ptr, N)

#define MICRO_COMPLEX_PREFETCHN(N) \
  MICRO_PREFETCHN1(ptr_real, N);   \
  if (!RhsIsReal) {                \
    MICRO_PREFETCHN1(ptr_imag, N); \
  }
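// Single depth step for the row-remainder kernel: broadcast up to four rhs values
// and accumulate one rank-1 update (pger) over the leftover lhs rows.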
template <typename Scalar, typename Packet, const Index accRows, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW(const Scalar*& lhs_ptr, const Scalar*& rhs_ptr0, const Scalar*& rhs_ptr1,
                                         const Scalar*& rhs_ptr2, PacketBlock<Packet, accRows>& accZero) {
  MICRO_BROADCAST_EXTRA
  pger<accRows, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
  lhs_ptr += remaining_rows;
}
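// Row-remainder GEMM iteration: peels PEEL_ROW depth steps into eight accumulators,
// reduces them with MICRO_ADD_PEEL_ROW, finishes the leftover depth one step at a
// time, then scales by alpha and stores (masked, or element-wise for partial rows).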
template <typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols,
          const Index remaining_rows>
EIGEN_ALWAYS_INLINE void gemm_unrolled_row_iteration(const DataMapper& res, const Scalar* lhs_base,
                                                     const Scalar* rhs_base, Index depth, Index strideA, Index offsetA,
                                                     Index strideB, Index row, Index rows, const Packet& pAlpha,
                                                     const Packet& pMask) {
  const Scalar *rhs_ptr0 = rhs_base, *rhs_ptr1 = NULL, *rhs_ptr2 = NULL;
  const Scalar* lhs_ptr = lhs_base + row * strideA + remaining_rows * offsetA;
  PacketBlock<Packet, accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7, acc;

  MICRO_SRC2_PTR
  bsetzero<Packet, accRows>(accZero0);

  Index remaining_depth = depth & -quad_traits<Scalar>::rows;
  Index k = 0;
  if (remaining_depth >= PEEL_ROW) {
    MICRO_ZERO_PEEL_ROW
    do {
      MICRO_PREFETCHN(accRows)
      EIGEN_POWER_PREFETCH(lhs_ptr);
      MICRO_WORK_PEEL_ROW
    } while ((k += PEEL_ROW) + PEEL_ROW <= remaining_depth);
    MICRO_ADD_PEEL_ROW
  }
  for (; k < depth; k++) {
    MICRO_EXTRA_ROW<Scalar, Packet, accRows, remaining_rows>(lhs_ptr, rhs_ptr0, rhs_ptr1, rhs_ptr2, accZero0);
  }

#ifdef USE_PARTIAL_PACKETS
  EIGEN_UNUSED_VARIABLE(rows);
  EIGEN_UNUSED_VARIABLE(pMask);
  bload_partial<DataMapper, Packet, 0, false, accRows>(acc, res, row, remaining_rows);
  bscale<Packet, accRows>(acc, accZero0, pAlpha);
  bstore_partial<DataMapper, Packet, accRows>(acc, res, row, remaining_rows);
#else
  bload<DataMapper, Packet, 0, ColMajor, false, accRows>(acc, res, row, 0);
  if ((accRows == 1) || (rows >= accCols)) {
    bscale<Packet, accRows, true>(acc, accZero0, pAlpha, pMask);
    bstore<DataMapper, Packet, accRows>(acc, res, row);
  } else {
    bscale<Packet, accRows, false>(acc, accZero0, pAlpha, pMask);
    for (Index j = 0; j < accRows; j++) {
      for (Index i = 0; i < remaining_rows; i++) {
        res(row + i, j) = acc.packet[j][i];
      }
    }
  }
#endif
}
#define MICRO_EXTRA(MICRO_EXTRA_UNROLL, value, is_col)   \
  switch (value) {                                       \
    default:                                             \
      MICRO_EXTRA_UNROLL(1)                              \
      break;                                             \
    case 2:                                              \
      if (is_col || (sizeof(Scalar) == sizeof(float))) { \
        MICRO_EXTRA_UNROLL(2)                            \
      }                                                  \
      break;                                             \
    case 3:                                              \
      if (is_col || (sizeof(Scalar) == sizeof(float))) { \
        MICRO_EXTRA_UNROLL(3)                            \
      }                                                  \
      break;                                             \
  }

#define MICRO_EXTRA_ROWS(N)                                                     \
  gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, accRows, accCols, N>( \
      res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, pAlpha, pMask);
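// Dispatch on the number of leftover rows or columns (1 to 3) via MICRO_EXTRA.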
template <typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemm_extra_row(const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base,
                                        Index depth, Index strideA, Index offsetA, Index strideB, Index row,
                                        Index rows, Index remaining_rows, const Packet& pAlpha, const Packet& pMask) {
  MICRO_EXTRA(MICRO_EXTRA_ROWS, remaining_rows, false)
}
#define MICRO_UNROLL_WORK(func, func2, peel) \
  MICRO_UNROLL(func2);                       \
  func(0, peel) func(1, peel) func(2, peel) func(3, peel) func(4, peel) func(5, peel) func(6, peel) func(7, peel)

#define MICRO_WORK_ONE(iter, peel)                                               \
  if (unroll_factor > iter) {                                                    \
    pger_common<Packet, false, accRows>(&accZero##iter, lhsV##iter, rhsV##peel); \
  }
#define MICRO_TYPE_PEEL4(func, func2, peel)                        \
  if (PEEL > peel) {                                               \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
    MICRO_BROADCAST(peel)                                          \
    MICRO_UNROLL_WORK(func, func2, peel)                           \
  } else {                                                         \
    EIGEN_UNUSED_VARIABLE(rhsV##peel);                             \
  }

#define MICRO_UNROLL_TYPE_PEEL(M, func, func1, func2)                                                           \
  Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M];                        \
  func(func1, func2, 0) func(func1, func2, 1) func(func1, func2, 2) func(func1, func2, 3) func(func1, func2, 4) \
      func(func1, func2, 5) func(func1, func2, 6) func(func1, func2, 7)

#define MICRO_UNROLL_TYPE_ONE(M, func, func1, func2) \
  Packet rhsV0[M];                                   \
  func(func1, func2, 0)

#define MICRO_UNROLL_TYPE(MICRO_TYPE, size)                       \
  MICRO_TYPE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE) \
  MICRO_ADD_ROWS(size)

#define MICRO_ONE_PEEL4 MICRO_UNROLL_TYPE(MICRO_UNROLL_TYPE_PEEL, PEEL)

#define MICRO_ONE4 MICRO_UNROLL_TYPE(MICRO_UNROLL_TYPE_ONE, 1)
#define MICRO_DST_PTR_ONE(iter)               \
  if (unroll_factor > iter) {                 \
    bsetzero<Packet, accRows>(accZero##iter); \
  } else {                                    \
    EIGEN_UNUSED_VARIABLE(accZero##iter);     \
  }

#define MICRO_DST_PTR MICRO_UNROLL(MICRO_DST_PTR_ONE)

#define MICRO_SRC_PTR MICRO_UNROLL(MICRO_SRC_PTR_ONE)

#define MICRO_PREFETCH MICRO_UNROLL(MICRO_PREFETCH_ONE)
#ifdef USE_PARTIAL_PACKETS
#define MICRO_STORE_ONE(iter)                                                                         \
  if (unroll_factor > iter) {                                                                         \
    if (MICRO_NORMAL_PARTIAL(iter)) {                                                                 \
      bload<DataMapper, Packet, 0, ColMajor, false, accRows>(acc, res, row + iter * accCols, 0);      \
      bscale<Packet, accRows>(acc, accZero##iter, pAlpha);                                            \
      bstore<DataMapper, Packet, accRows>(acc, res, row + iter * accCols);                            \
    } else {                                                                                          \
      bload_partial<DataMapper, Packet, 0, false, accRows>(acc, res, row + iter * accCols, accCols2); \
      bscale<Packet, accRows>(acc, accZero##iter, pAlpha);                                            \
      bstore_partial<DataMapper, Packet, accRows>(acc, res, row + iter * accCols, accCols2);          \
    }                                                                                                 \
  }
#else
#define MICRO_STORE_ONE(iter)                                                                  \
  if (unroll_factor > iter) {                                                                  \
    bload<DataMapper, Packet, 0, ColMajor, false, accRows>(acc, res, row + iter * accCols, 0); \
    bscale<Packet, accRows, !(MICRO_NORMAL(iter))>(acc, accZero##iter, pAlpha, pMask);         \
    bstore<DataMapper, Packet, accRows>(acc, res, row + iter * accCols);                       \
  }
#endif

#define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE)
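// Main unrolled GEMM iteration: computes unroll_factor blocks of accCols rows each,
// peeling PEEL depth steps per pass before storing all accumulators.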
#ifdef USE_PARTIAL_PACKETS
template <int unroll_factor, typename Scalar, typename Packet, typename DataMapper, const Index accRows,
          const Index accCols, bool full>
#else
template <int unroll_factor, typename Scalar, typename Packet, typename DataMapper, const Index accRows,
          const Index accCols, const Index accCols2>
#endif
EIGEN_ALWAYS_INLINE void gemm_unrolled_iteration(const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base,
                                                 Index depth, Index strideA, Index offsetA, Index strideB, Index& row,
                                                 const Packet& pAlpha,
#ifdef USE_PARTIAL_PACKETS
                                                 Index accCols2
#else
                                                 const Packet& pMask
#endif
) {
  const Scalar *rhs_ptr0 = rhs_base, *rhs_ptr1 = NULL, *rhs_ptr2 = NULL;
  const Scalar *lhs_ptr0 = NULL, *lhs_ptr1 = NULL, *lhs_ptr2 = NULL, *lhs_ptr3 = NULL, *lhs_ptr4 = NULL,
               *lhs_ptr5 = NULL, *lhs_ptr6 = NULL, *lhs_ptr7 = NULL;
  PacketBlock<Packet, accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
  PacketBlock<Packet, accRows> acc;

  MICRO_SRC2_PTR
  MICRO_SRC_PTR
  MICRO_DST_PTR

  Index k = 0;
  for (; k + PEEL <= depth; k += PEEL) {
    MICRO_PREFETCHN(accRows)
    MICRO_PREFETCH
    MICRO_ONE_PEEL4
  }
  for (; k < depth; k++) {
    MICRO_ONE4
  }
  MICRO_STORE

  MICRO_UPDATE
}
#ifdef USE_PARTIAL_PACKETS
#define MICRO_UNROLL_ITER2(N, M)                                                                              \
  gemm_unrolled_iteration<N + ((M) ? 1 : 0), Scalar, Packet, DataMapper, accRows, accCols, !M>(               \
      res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlpha, M ? remaining_rows : accCols); \
  if (M) return;
#else
#define MICRO_UNROLL_ITER2(N, M)                                                                             \
  gemm_unrolled_iteration<N + ((M) ? 1 : 0), Scalar, Packet, DataMapper, accRows, accCols, M ? M : accCols>( \
      res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlpha, pMask);                       \
  if (M) return;
#endif
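// Process one panel of accRows columns: consume rows in MAX_UNROLL-sized blocks,
// then a single smaller unroll for what remains, and finally the row remainder.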
template <typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemm_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index depth,
                                   Index strideA, Index offsetA, Index strideB, Index offsetB, Index col, Index rows,
                                   Index remaining_rows, const Packet& pAlpha, const Packet& pMask) {
  const DataMapper res3 = res.getSubMapper(0, col);

  const Scalar* rhs_base = blockB + col * strideB + MICRO_NEW_ROWS * offsetB;
  const Scalar* lhs_base = blockA + accCols * offsetA;
  Index row = 0;

#define MAX_UNROLL 7
  while (row + MAX_UNROLL * accCols <= rows) {
    MICRO_UNROLL_ITER2(MAX_UNROLL, 0);
  }
  switch ((rows - row) / accCols) {
#if MAX_UNROLL > 7
    case 7:
      MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 7)
      break;
#endif
#if MAX_UNROLL > 6
    case 6:
      MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 6)
      break;
#endif
#if MAX_UNROLL > 5
    case 5:
      MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 5)
      break;
#endif
#if MAX_UNROLL > 4
    case 4:
      MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 4)
      break;
#endif
#if MAX_UNROLL > 3
    case 3:
      MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 3)
      break;
#endif
#if MAX_UNROLL > 2
    case 2:
      MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 2)
      break;
#endif
#if MAX_UNROLL > 1
    case 1:
      MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 1)
      break;
#endif
    default:
      break;
  }
#undef MAX_UNROLL

  if (remaining_rows > 0) {
    gemm_extra_row<Scalar, Packet, DataMapper, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA,
                                                                 strideB, row, rows, remaining_rows, pAlpha, pMask);
  }
}
#define MICRO_EXTRA_COLS(N)                                                                                         \
  gemm_cols<Scalar, Packet, DataMapper, N, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, \
                                                    col, rows, remaining_rows, pAlpha, pMask);

template <typename Scalar, typename Packet, typename DataMapper, const Index accCols>
EIGEN_ALWAYS_INLINE void gemm_extra_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index depth,
                                         Index strideA, Index offsetA, Index strideB, Index offsetB, Index col,
                                         Index rows, Index cols, Index remaining_rows, const Packet& pAlpha,
                                         const Packet& pMask) {
  MICRO_EXTRA(MICRO_EXTRA_COLS, cols - col, true)
}
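// Top-level real GEMM entry point for VSX: loops over full panels of accRows
// columns and hands any leftover columns to gemm_extra_cols.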
template <typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows,
          const Index accCols>
EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows,
                              Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA,
                              Index offsetB) {
  const Index remaining_rows = rows % accCols;

  if (strideA == -1) strideA = depth;
  if (strideB == -1) strideB = depth;

  const Packet pAlpha = pset1<Packet>(alpha);
  const Packet pMask = bmask<Packet>(remaining_rows);

  Index col = 0;
  for (; col + accRows <= cols; col += accRows) {
    gemm_cols<Scalar, Packet, DataMapper, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB,
                                                            offsetB, col, rows, remaining_rows, pAlpha, pMask);
  }

  if (col != cols) {
    gemm_extra_cols<Scalar, Packet, DataMapper, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB,
                                                         offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
  }
}
#define accColsC (accCols / 2)
#define advanceRows ((LhsIsReal) ? 1 : 2)
#define advanceCols ((RhsIsReal) ? 1 : 2)

#define PEEL_COMPLEX 3
#define PEEL_COMPLEX_ROW 3

#define MICRO_COMPLEX_UNROLL(func) func(0) func(1) func(2) func(3)
#define MICRO_COMPLEX_ZERO_PEEL(peel)             \
  if ((PEEL_COMPLEX_ROW > peel) && (peel != 0)) { \
    bsetzero<Packet, accRows>(accReal##peel);     \
    bsetzero<Packet, accRows>(accImag##peel);     \
  } else {                                        \
    EIGEN_UNUSED_VARIABLE(accReal##peel);         \
    EIGEN_UNUSED_VARIABLE(accImag##peel);         \
  }
#define MICRO_COMPLEX_ADD_ROWS(N, used)            \
  MICRO_ADD(ptr_real, N)                           \
  if (!RhsIsReal) {                                \
    MICRO_ADD(ptr_imag, N)                         \
  } else if (used) {                               \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 0)); \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 1)); \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 2)); \
  }
#define MICRO_COMPLEX_BROADCAST(peel)              \
  MICRO_BROADCAST1(peel, ptr_real, rhsV, false)    \
  if (!RhsIsReal) {                                \
    MICRO_BROADCAST1(peel, ptr_imag, rhsVi, false) \
  } else {                                         \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel);            \
  }

#define MICRO_COMPLEX_BROADCAST_EXTRA              \
  Packet rhsV[4], rhsVi[4];                        \
  MICRO_BROADCAST_EXTRA1(ptr_real, rhsV, false)    \
  if (!RhsIsReal) {                                \
    MICRO_BROADCAST_EXTRA1(ptr_imag, rhsVi, false) \
  } else {                                         \
    EIGEN_UNUSED_VARIABLE(rhsVi);                  \
  }                                                \
  MICRO_COMPLEX_ADD_ROWS(1, true)
#define MICRO_COMPLEX_SRC2_PTR                                    \
  MICRO_SRC2(ptr_real, strideB* advanceCols, 0)                   \
  if (!RhsIsReal) {                                               \
    MICRO_RHS(ptr_imag, 0) = rhs_base + MICRO_NEW_ROWS * strideB; \
    MICRO_SRC2(ptr_imag, strideB* advanceCols, strideB)           \
  } else {                                                        \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 0));                \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 1));                \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 2));                \
  }

#define MICRO_COMPLEX_ZERO_PEEL_ROW MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_ZERO_PEEL)
#define MICRO_COMPLEX_WORK_PEEL(peel)                                                 \
  if (PEEL_COMPLEX_ROW > peel) {                                                      \
    MICRO_COMPLEX_BROADCAST(peel)                                                     \
    pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>( \
        &accReal##peel, &accImag##peel, lhs_ptr_real + (remaining_rows * peel),       \
        lhs_ptr_imag + (remaining_rows * peel), rhsV##peel, rhsVi##peel);             \
  } else {                                                                            \
    EIGEN_UNUSED_VARIABLE(rhsV##peel);                                                \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel);                                               \
  }

#define MICRO_COMPLEX_ADD_COLS(size)         \
  lhs_ptr_real += (remaining_rows * size);   \
  if (!LhsIsReal)                            \
    lhs_ptr_imag += (remaining_rows * size); \
  else                                       \
    EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
#define MICRO_COMPLEX_WORK_PEEL_ROW                  \
  Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4];     \
  Packet rhsVi0[4], rhsVi1[4], rhsVi2[4], rhsVi3[4]; \
  MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_WORK_PEEL)      \
  MICRO_COMPLEX_ADD_COLS(PEEL_COMPLEX_ROW)           \
  MICRO_COMPLEX_ADD_ROWS(PEEL_COMPLEX_ROW, false)

#define MICRO_COMPLEX_ADD_PEEL(peel, sum)                \
  if (PEEL_COMPLEX_ROW > peel) {                         \
    for (Index i = 0; i < accRows; i++) {                \
      accReal##sum.packet[i] += accReal##peel.packet[i]; \
      accImag##sum.packet[i] += accImag##peel.packet[i]; \
    }                                                    \
  }

#define MICRO_COMPLEX_ADD_PEEL_ROW \
  MICRO_COMPLEX_ADD_PEEL(2, 0) MICRO_COMPLEX_ADD_PEEL(3, 1) MICRO_COMPLEX_ADD_PEEL(1, 0)
template <typename Scalar, typename Packet, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal,
          bool RhsIsReal, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW(const Scalar*& lhs_ptr_real, const Scalar*& lhs_ptr_imag,
                                                 const Scalar*& rhs_ptr_real0, const Scalar*& rhs_ptr_real1,
                                                 const Scalar*& rhs_ptr_real2, const Scalar*& rhs_ptr_imag0,
                                                 const Scalar*& rhs_ptr_imag1, const Scalar*& rhs_ptr_imag2,
                                                 PacketBlock<Packet, accRows>& accReal,
                                                 PacketBlock<Packet, accRows>& accImag) {
  MICRO_COMPLEX_BROADCAST_EXTRA
  pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real,
                                                                                   lhs_ptr_imag, rhsV, rhsVi);
  MICRO_COMPLEX_ADD_COLS(1)
}
template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows,
          const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal,
          const Index remaining_rows>
EIGEN_ALWAYS_INLINE void gemm_unrolled_complex_row_iteration(const DataMapper& res, const Scalar* lhs_base,
                                                             const Scalar* rhs_base, Index depth, Index strideA,
                                                             Index offsetA, Index strideB, Index row, Index rows,
                                                             const Packet& pAlphaReal, const Packet& pAlphaImag,
                                                             const Packet& pMask) {
  const Scalar *rhs_ptr_real0 = rhs_base, *rhs_ptr_real1 = NULL, *rhs_ptr_real2 = NULL;
  const Scalar *rhs_ptr_imag0 = NULL, *rhs_ptr_imag1 = NULL, *rhs_ptr_imag2 = NULL;
  const Scalar* lhs_ptr_real = lhs_base + advanceRows * row * strideA + remaining_rows * offsetA;
  const Scalar* lhs_ptr_imag = NULL;
  if (!LhsIsReal)
    lhs_ptr_imag = lhs_ptr_real + remaining_rows * strideA;
  else
    EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
  PacketBlock<Packet, accRows> accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3;
  PacketBlock<Packet, accRows> taccReal, taccImag;
  PacketBlock<Packetc, accRows> acc0, acc1;
  PacketBlock<Packetc, accRows * 2> tRes;

  MICRO_COMPLEX_SRC2_PTR

  bsetzero<Packet, accRows>(accReal0);
  bsetzero<Packet, accRows>(accImag0);

  Index remaining_depth = depth & -quad_traits<Scalar>::rows;
  Index k = 0;
  if (remaining_depth >= PEEL_COMPLEX_ROW) {
    MICRO_COMPLEX_ZERO_PEEL_ROW
    do {
      MICRO_COMPLEX_PREFETCHN(accRows)
      EIGEN_POWER_PREFETCH(lhs_ptr_real);
      if (!LhsIsReal) {
        EIGEN_POWER_PREFETCH(lhs_ptr_imag);
      }
      MICRO_COMPLEX_WORK_PEEL_ROW
    } while ((k += PEEL_COMPLEX_ROW) + PEEL_COMPLEX_ROW <= remaining_depth);
    MICRO_COMPLEX_ADD_PEEL_ROW
  }
  for (; k < depth; k++) {
    MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, remaining_rows>(
        lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real0, rhs_ptr_real1, rhs_ptr_real2, rhs_ptr_imag0, rhs_ptr_imag1,
        rhs_ptr_imag2, accReal0, accImag0);
  }

  constexpr bool full = (remaining_rows > accColsC);
  bload<DataMapper, Packetc, accColsC, ColMajor, true, accRows, full>(tRes, res, row, 0);
  if ((accRows == 1) || (rows >= accCols)) {
    bscalec<Packet, accRows, true>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
    bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1);
    bstore<DataMapper, Packetc, accRows>(acc0, res, row + 0);
    if (full) {
      bstore<DataMapper, Packetc, accRows>(acc1, res, row + accColsC);
    }
  } else {
    bscalec<Packet, accRows, false>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
    bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1);

    if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1)) {
      for (Index j = 0; j < accRows; j++) {
        res(row + 0, j) = pfirst<Packetc>(acc0.packet[j]);
      }
    } else {
      bstore<DataMapper, Packetc, accRows>(acc0, res, row + 0);
      if (full) {
        for (Index j = 0; j < accRows; j++) {
          res(row + accColsC, j) = pfirst<Packetc>(acc1.packet[j]);
        }
      }
    }
  }
}
#define MICRO_COMPLEX_EXTRA_ROWS(N)                                                                        \
  gemm_unrolled_complex_row_iteration<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, \
                                      ConjugateRhs, LhsIsReal, RhsIsReal, N>(                              \
      res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, pAlphaReal, pAlphaImag, pMask);
template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows,
          const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_extra_row(const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base,
                                                Index depth, Index strideA, Index offsetA, Index strideB, Index row,
                                                Index rows, Index remaining_rows, const Packet& pAlphaReal,
                                                const Packet& pAlphaImag, const Packet& pMask) {
  MICRO_EXTRA(MICRO_COMPLEX_EXTRA_ROWS, remaining_rows, false)
}
#define MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
  MICRO_COMPLEX_UNROLL(func2);                       \
  func(0, peel) func(1, peel) func(2, peel) func(3, peel)

#define MICRO_COMPLEX_WORK_ONE4(iter, peel)                                                \
  if (unroll_factor > iter) {                                                              \
    pgerc_common<accRows, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(       \
        &accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
  }
#define MICRO_COMPLEX_TYPE_PEEL4(func, func2, peel)  \
  if (PEEL_COMPLEX > peel) {                         \
    Packet lhsV0, lhsV1, lhsV2, lhsV3;               \
    Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3;           \
    MICRO_COMPLEX_BROADCAST(peel)                    \
    MICRO_COMPLEX_UNROLL_WORK(func, func2, peel)     \
  } else {                                           \
    EIGEN_UNUSED_VARIABLE(rhsV##peel);               \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel);              \
  }

#define MICRO_COMPLEX_UNROLL_TYPE_PEEL(M, func, func1, func2) \
  Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M];              \
  Packet rhsVi0[M], rhsVi1[M], rhsVi2[M], rhsVi3[M];          \
  func(func1, func2, 0) func(func1, func2, 1) func(func1, func2, 2) func(func1, func2, 3)

#define MICRO_COMPLEX_UNROLL_TYPE_ONE(M, func, func1, func2) \
  Packet rhsV0[M], rhsVi0[M];                                \
  func(func1, func2, 0)

#define MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_TYPE, size)                                        \
  MICRO_COMPLEX_TYPE(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE) \
  MICRO_COMPLEX_ADD_ROWS(size, false)

#define MICRO_COMPLEX_ONE_PEEL4 MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_UNROLL_TYPE_PEEL, PEEL_COMPLEX)

#define MICRO_COMPLEX_ONE4 MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_UNROLL_TYPE_ONE, 1)
#define MICRO_COMPLEX_DST_PTR_ONE(iter)       \
  if (unroll_factor > iter) {                 \
    bsetzero<Packet, accRows>(accReal##iter); \
    bsetzero<Packet, accRows>(accImag##iter); \
  } else {                                    \
    EIGEN_UNUSED_VARIABLE(accReal##iter);     \
    EIGEN_UNUSED_VARIABLE(accImag##iter);     \
  }

#define MICRO_COMPLEX_DST_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_DST_PTR_ONE)

#define MICRO_COMPLEX_SRC_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE)

#define MICRO_COMPLEX_PREFETCH MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)
#define MICRO_COMPLEX_STORE_ONE(iter)                                                                               \
  if (unroll_factor > iter) {                                                                                       \
    constexpr bool full = ((MICRO_NORMAL(iter)) || (accCols2 > accColsC));                                          \
    bload<DataMapper, Packetc, accColsC, ColMajor, true, accRows, full>(tRes, res, row + iter * accCols, 0);        \
    bscalec<Packet, accRows, !(MICRO_NORMAL(iter))>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, \
                                                    taccImag, pMask);                                               \
    bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1);                                  \
    bstore<DataMapper, Packetc, accRows>(acc0, res, row + iter * accCols + 0);                                      \
    if (full) {                                                                                                     \
      bstore<DataMapper, Packetc, accRows>(acc1, res, row + iter * accCols + accColsC);                             \
    }                                                                                                               \
  }

#define MICRO_COMPLEX_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_STORE_ONE)
template <int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper,
          const Index accRows, const Index accCols, const Index accCols2, bool ConjugateLhs, bool ConjugateRhs,
          bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_iteration(const DataMapper& res, const Scalar* lhs_base,
                                                         const Scalar* rhs_base, Index depth, Index strideA,
                                                         Index offsetA, Index strideB, Index& row,
                                                         const Packet& pAlphaReal, const Packet& pAlphaImag,
                                                         const Packet& pMask) {
  const Scalar *rhs_ptr_real0 = rhs_base, *rhs_ptr_real1 = NULL, *rhs_ptr_real2 = NULL;
  const Scalar *rhs_ptr_imag0 = NULL, *rhs_ptr_imag1 = NULL, *rhs_ptr_imag2 = NULL;
  const Index imag_delta = accCols * strideA;
  const Index imag_delta2 = accCols2 * strideA;
  const Scalar *lhs_ptr_real0 = NULL, *lhs_ptr_real1 = NULL;
  const Scalar *lhs_ptr_real2 = NULL, *lhs_ptr_real3 = NULL;
  PacketBlock<Packet, accRows> accReal0, accImag0, accReal1, accImag1;
  PacketBlock<Packet, accRows> accReal2, accImag2, accReal3, accImag3;
  PacketBlock<Packet, accRows> taccReal, taccImag;
  PacketBlock<Packetc, accRows> acc0, acc1;
  PacketBlock<Packetc, accRows * 2> tRes;

  MICRO_COMPLEX_SRC2_PTR
  MICRO_COMPLEX_SRC_PTR
  MICRO_COMPLEX_DST_PTR

  Index k = 0;
  for (; k + PEEL_COMPLEX <= depth; k += PEEL_COMPLEX) {
    MICRO_COMPLEX_PREFETCHN(accRows)
    MICRO_COMPLEX_PREFETCH
    MICRO_COMPLEX_ONE_PEEL4
  }
  for (; k < depth; k++) {
    MICRO_COMPLEX_ONE4
  }
  MICRO_COMPLEX_STORE

  MICRO_COMPLEX_UPDATE
}
#define MICRO_COMPLEX_UNROLL_ITER2(N, M)                                                                  \
  gemm_complex_unrolled_iteration<N + (M ? 1 : 0), Scalar, Packet, Packetc, DataMapper, accRows, accCols, \
                                  M ? M : accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(     \
      res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask);    \
  if (M) return;
template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows,
          const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
                                           Index depth, Index strideA, Index offsetA, Index strideB, Index offsetB,
                                           Index col, Index rows, Index remaining_rows, const Packet& pAlphaReal,
                                           const Packet& pAlphaImag, const Packet& pMask) {
  const DataMapper res3 = res.getSubMapper(0, col);

  const Scalar* rhs_base = blockB + advanceCols * col * strideB + MICRO_NEW_ROWS * offsetB;
  const Scalar* lhs_base = blockA + accCols * offsetA;
  Index row = 0;

#define MAX_COMPLEX_UNROLL 4
  while (row + MAX_COMPLEX_UNROLL * accCols <= rows) {
    MICRO_COMPLEX_UNROLL_ITER2(MAX_COMPLEX_UNROLL, 0);
  }
  switch ((rows - row) / accCols) {
#if MAX_COMPLEX_UNROLL > 4
    case 4:
      MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 4)
      break;
#endif
#if MAX_COMPLEX_UNROLL > 3
    case 3:
      MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 3)
      break;
#endif
#if MAX_COMPLEX_UNROLL > 2
    case 2:
      MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 2)
      break;
#endif
#if MAX_COMPLEX_UNROLL > 1
    case 1:
      MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 1)
      break;
#endif
    default:
      break;
  }
#undef MAX_COMPLEX_UNROLL

  if (remaining_rows > 0) {
    gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal,
                           RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows,
                                      remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }
}
#define MICRO_COMPLEX_EXTRA_COLS(N)                                                                         \
  gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, N, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, \
                    RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows,   \
                               remaining_rows, pAlphaReal, pAlphaImag, pMask);

template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accCols,
          bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_extra_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
                                                 Index depth, Index strideA, Index offsetA, Index strideB,
                                                 Index offsetB, Index col, Index rows, Index cols, Index remaining_rows,
                                                 const Packet& pAlphaReal, const Packet& pAlphaImag,
                                                 const Packet& pMask) {
  MICRO_EXTRA(MICRO_COMPLEX_EXTRA_COLS, cols - col, true)
}
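// Top-level complex GEMM entry point. The packed operands store real and imaginary
// parts in separate planes (unless LhsIsReal/RhsIsReal), so the kernels above walk
// paired real/imag pointers.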
template <typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Packet, typename Packetc,
          typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs,
          bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc,
                                      Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB,
                                      Index offsetA, Index offsetB) {
  const Index remaining_rows = rows % accCols;

  if (strideA == -1) strideA = depth;
  if (strideB == -1) strideB = depth;

  const Packet pAlphaReal = pset1<Packet>(alpha.real());
  const Packet pAlphaImag = pset1<Packet>(alpha.imag());
  const Packet pMask = bmask<Packet>(remaining_rows);

  const Scalar* blockA = (Scalar*)blockAc;
  const Scalar* blockB = (Scalar*)blockBc;

  Index col = 0;
  for (; col + accRows <= cols; col += accRows) {
    gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal,
                      RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows,
                                 remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }

  if (col != cols) {
    gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal,
                            RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols,
                                       remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }
}
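// Runtime MMA availability check; with dynamic dispatch enabled this queries the
// CPU for ISA 3.1 (Power10) and MMA support.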
EIGEN_ALWAYS_INLINE bool supportsMMA() {
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
  return true;
#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH) && defined(__BUILTIN_CPU_SUPPORTS__)
  return __builtin_cpu_supports("arch_3_1") && __builtin_cpu_supports("mma");
#else
  return false;
#endif
}
EIGEN_ALWAYS_INLINE Packet4f loadAndMultiplyF32(Packet4f acc, const Packet4f pAlpha, float* result) {
  Packet4f result_block = ploadu<Packet4f>(result);
  return pmadd(acc, pAlpha, result_block);
}
template <bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void storeF32(float*& result, Packet4f result_block, Index rows, Index extra_rows) {
  if (lhsExtraRows) {
    pstoreu_partial(result, result_block, extra_rows);
  } else {
    pstoreu(result, result_block);
  }
  result += rows;
}
template <bool rhsExtraCols, bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result,
                                      Index extra_cols, Index extra_rows) {
  Index x = 0;
  if (rhsExtraCols) {
    do {
      Packet4f result_block = loadAndMultiplyF32(acc[x], pAlpha, result);
      storeF32<lhsExtraRows>(result, result_block, rows, extra_rows);
    } while (++x < extra_cols);
  } else {
    Packet4f result_block[4];
    float* result2 = result;
    do {
      result_block[x] = loadAndMultiplyF32(acc[x], pAlpha, result);
      result += rows;
    } while (++x < 4);
    x = 0;
    do {
      storeF32<lhsExtraRows>(result2, result_block[x], rows, extra_rows);
    } while (++x < 4);
  }
}
EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Hi(Packet8us data) {
  Packet8us z = pset1<Packet8us>(0);
#ifdef _BIG_ENDIAN
  return reinterpret_cast<Packet4f>(vec_mergeh(data, z));
#else
  return reinterpret_cast<Packet4f>(vec_mergeh(z, data));
#endif
}

EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Lo(Packet8us data) {
  Packet8us z = pset1<Packet8us>(0);
#ifdef _BIG_ENDIAN
  return reinterpret_cast<Packet4f>(vec_mergel(data, z));
#else
  return reinterpret_cast<Packet4f>(vec_mergel(z, data));
#endif
}
template <Index N, Index M>
EIGEN_ALWAYS_INLINE void storeConvertTwoBF16(float* to, PacketBlock<Packet8bf, (N + 7) / 8>& block, Index extra = 0) {
  if (N < 4) {
    pstoreu_partial(to + 0, oneConvertBF16Hi(block.packet[0].m_val), extra);
  } else if (N >= (M * 8 + 4)) {
    pstoreu(to + 0, oneConvertBF16Hi(block.packet[M].m_val));
    if (N >= (M * 8 + 8)) {
      pstoreu(to + 4, oneConvertBF16Lo(block.packet[M].m_val));
    }
  }
}
template <Index N>
EIGEN_ALWAYS_INLINE void storeConvertBlockBF16(float* to, PacketBlock<Packet8bf, (N + 7) / 8>& block, Index extra) {
  storeConvertTwoBF16<N, 0>(to + 0, block, extra);
  if (N >= 16) {
    storeConvertTwoBF16<N, 1>(to + 8, block);
  }
  if (N >= 32) {
    storeConvertTwoBF16<N, 2>(to + 16, block);
    storeConvertTwoBF16<N, 3>(to + 24, block);
  }
}
template <bool non_unit_stride, Index delta>
EIGEN_ALWAYS_INLINE Packet8bf loadBF16fromResult(bfloat16* src, Index resInc) {
  if (non_unit_stride) {
    return pgather<bfloat16, Packet8bf>(src + delta * resInc, resInc);
  } else {
    return ploadu<Packet8bf>(src + delta);
  }
}
static Packet16uc p16uc_MERGE16_32_1 = {0, 1, 16, 17, 2, 3, 18, 19, 0, 1, 16, 17, 2, 3, 18, 19};
static Packet16uc p16uc_MERGE16_32_2 = {4, 5, 20, 21, 6, 7, 22, 23, 4, 5, 20, 21, 6, 7, 22, 23};
static Packet16uc p16uc_MERGE16_32_3 = {8, 9, 24, 25, 10, 11, 26, 27, 8, 9, 24, 25, 10, 11, 26, 27};
static Packet16uc p16uc_MERGE16_32_4 = {12, 13, 28, 29, 14, 15, 30, 31, 12, 13, 28, 29, 14, 15, 30, 31};

static Packet16uc p16uc_MERGE16_32_5 = {0, 1, 16, 17, 16, 17, 16, 17, 0, 1, 16, 17, 16, 17, 16, 17};
static Packet16uc p16uc_MERGE16_32_6 = {2, 3, 18, 19, 18, 19, 18, 19, 2, 3, 18, 19, 18, 19, 18, 19};
static Packet16uc p16uc_MERGE16_32_7 = {4, 5, 20, 21, 20, 21, 20, 21, 4, 5, 20, 21, 20, 21, 20, 21};
static Packet16uc p16uc_MERGE16_32_8 = {6, 7, 22, 23, 22, 23, 22, 23, 6, 7, 22, 23, 22, 23, 22, 23};
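// Convert one bfloat16 lane to float by permuting it into the high 16 bits of a
// 32-bit lane; the masks above select the even or odd bfloat16 elements.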
EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Perm(Packet8us data, Packet16uc mask) {
  Packet8us z = pset1<Packet8us>(0);
#ifdef _BIG_ENDIAN
  return reinterpret_cast<Packet4f>(vec_perm(data, z, mask));
#else
  return reinterpret_cast<Packet4f>(vec_perm(z, data, mask));
#endif
}
template <bool lhsExtraRows, bool odd, Index size>
EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32DupOne(float* result, Index rows, const bfloat16* src,
                                                            Index extra_rows) {
  Packet4f dup[4 * 4];
  Packet8bf data[4];

  for (Index i = 0; i < size; i++) {
    data[i] = ploadu<Packet8bf>(src + rows * i);
  }

  for (Index i = 0, j = 0; i < size; i++, j += 4) {
    dup[j + 0] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_5 : p16uc_MERGE16_32_1);
    dup[j + 1] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_6 : p16uc_MERGE16_32_2);
    dup[j + 2] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_7 : p16uc_MERGE16_32_3);
    dup[j + 3] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_8 : p16uc_MERGE16_32_4);
  }

  for (Index j = 0; j < 4 * size; j += 4) {
    if (lhsExtraRows) {
      Packet4f z = pset1<Packet4f>(float(0));
      Index i = 0;
      do {
        pstoreu(result + (j + i) * 4, dup[j + i]);
      } while (++i < extra_rows);
      do {
        pstoreu(result + (j + i) * 4, z);
      } while (++i < 4);
    } else {
      for (Index i = 0; i < 4; i++) {
        pstoreu(result + (j + i) * 4, dup[j + i]);
      }
    }
  }
}
template <bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32Dup(float* result, Index cols, Index rows, const bfloat16* src,
                                                         Index delta, Index extra_rows) {
  Index col = 0;
  src += delta * 2;
  for (; col + 4 * 2 <= cols; col += 4 * 2, result += 4 * 4 * 4, src += 4 * rows) {
    convertArrayPointerBF16toF32DupOne<lhsExtraRows, false, 4>(result, rows, src, extra_rows);
  }
  for (; col + 2 <= cols; col += 2, result += 4 * 4, src += rows) {
    convertArrayPointerBF16toF32DupOne<lhsExtraRows, false, 1>(result, rows, src, extra_rows);
  }
  if (cols & 1) {
    convertArrayPointerBF16toF32DupOne<lhsExtraRows, true, 1>(result, rows, src - delta, extra_rows);
  }
}
template <const Index size, bool non_unit_stride>
EIGEN_ALWAYS_INLINE void convertPointerBF16toF32(Index& i, float* result, Index rows, bfloat16*& src, Index resInc) {
  constexpr Index extra = ((size < 4) ? 4 : size);
  while (i + size <= rows) {
    PacketBlock<Packet8bf, (size + 7) / 8> r32;
    r32.packet[0] = loadBF16fromResult<non_unit_stride, 0>(src, resInc);
    if (size >= 16) {
      r32.packet[1] = loadBF16fromResult<non_unit_stride, 8>(src, resInc);
    }
    if (size >= 32) {
      r32.packet[2] = loadBF16fromResult<non_unit_stride, 16>(src, resInc);
      r32.packet[3] = loadBF16fromResult<non_unit_stride, 24>(src, resInc);
    }
    storeConvertBlockBF16<size>(result + i, r32, rows & 3);
    i += extra;
    src += extra * resInc;
    if (size != 32) break;
  }
}
template <bool non_unit_stride = false>
EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float* result, Index cols, Index rows, bfloat16* src,
                                                      Index resInc = 1) {
  for (Index col = 0; col < cols; col++, src += (rows * resInc), result += rows) {
    Index i = 0;
    bfloat16* src2 = src;
    convertPointerBF16toF32<32, non_unit_stride>(i, result, rows, src2, resInc);
    convertPointerBF16toF32<16, non_unit_stride>(i, result, rows, src2, resInc);
    convertPointerBF16toF32<8, non_unit_stride>(i, result, rows, src2, resInc);
    convertPointerBF16toF32<4, non_unit_stride>(i, result, rows, src2, resInc);
    convertPointerBF16toF32<1, non_unit_stride>(i, result, rows, src2, resInc);
  }
}
template <Index num_acc, Index size = 4>
EIGEN_ALWAYS_INLINE void zeroAccumulators(Packet4f (&acc)[num_acc][size]) {
  Packet4f z = pset1<Packet4f>(float(0));

  for (Index k = 0; k < num_acc; k++) {
    for (Index j = 0; j < size; j++) {
      acc[k][j] = z;
    }
  }
}
template <Index num_acc>
EIGEN_ALWAYS_INLINE void tranposeResults(Packet4f (&acc)[num_acc][4]) {
  for (Index i = 0; i < num_acc; i++) {
    Packet4ui t0, t1, t2, t3;
    t0 = vec_mergeh(reinterpret_cast<Packet4ui>(acc[i][0]), reinterpret_cast<Packet4ui>(acc[i][2]));
    t1 = vec_mergel(reinterpret_cast<Packet4ui>(acc[i][0]), reinterpret_cast<Packet4ui>(acc[i][2]));
    t2 = vec_mergeh(reinterpret_cast<Packet4ui>(acc[i][1]), reinterpret_cast<Packet4ui>(acc[i][3]));
    t3 = vec_mergel(reinterpret_cast<Packet4ui>(acc[i][1]), reinterpret_cast<Packet4ui>(acc[i][3]));
    acc[i][0] = reinterpret_cast<Packet4f>(vec_mergeh(t0, t2));
    acc[i][1] = reinterpret_cast<Packet4f>(vec_mergel(t0, t2));
    acc[i][2] = reinterpret_cast<Packet4f>(vec_mergeh(t1, t3));
    acc[i][3] = reinterpret_cast<Packet4f>(vec_mergel(t1, t3));
  }
}
template <Index num_acc>
EIGEN_ALWAYS_INLINE void addResults(Packet4f (&acc)[num_acc][4]) {
  for (Index i = 0, j = 0; j < num_acc; i++, j += 2) {
    for (Index x = 0, y = 0; x < 2; x++, y += 2) {
      for (Index w = 0, z = 0; w < 2; w++, z += 2) {
        acc[i][y + w] = acc[j + x][z + 0] + acc[j + x][z + 1];
      }
    }
  }
}
template <Index num_acc, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs>
EIGEN_ALWAYS_INLINE void outputResultsVSX(Packet4f (&acc)[num_acc][4], Index rows, const Packet4f pAlpha,
                                          float* result, const Index extra_cols, Index extra_rows) {
  tranposeResults<num_acc>(acc);
  addResults<num_acc>(acc);

  constexpr Index real_rhs = ((num_rhs / 2) - (rhsExtraCols ? 1 : 0));
  Index k = 0;
  for (Index i = 0; i < real_rhs; i++, result += 4 * rows, k++) {
    storeResults<false, lhsExtraRows>(acc[k], rows, pAlpha, result, extra_cols, extra_rows);
  }
  if (rhsExtraCols) {
    storeResults<rhsExtraCols, lhsExtraRows>(acc[k], rows, pAlpha, result, extra_cols, extra_rows);
  }
}
template <bool zero>
EIGEN_ALWAYS_INLINE void loadTwoRhsFloat32(const float* block, Index strideB, Index i, Packet4f& dhs0,
                                           Packet4f& dhs1) {
  dhs0 = ploadu<Packet4f>(block + strideB * i + 0);
  if (zero) {
    Packet4f dhs2 = pset1<Packet4f>(float(0));
    dhs1 = vec_mergel(dhs0, dhs2);
    dhs0 = vec_mergeh(dhs0, dhs2);
  } else {
    dhs1 = ploadu<Packet4f>(block + strideB * i + 4);
  }
}
template <Index num_acc, bool zero, bool rhsExtraCols, Index num_rhs>
EIGEN_ALWAYS_INLINE void KLoop(const float* indexA, const float* indexB, Packet4f (&acc)[num_acc][4], Index strideB,
                               Index k, Index offsetB, Index extra_cols) {
  constexpr Index num_lhs = 4;
  Packet4f lhs[num_lhs], rhs[num_rhs];

  constexpr Index real_rhs = (num_rhs - (rhsExtraCols ? 2 : 0));
  for (Index i = 0; i < real_rhs; i += 2) {
    loadTwoRhsFloat32<zero>(indexB + k * 4, strideB, i, rhs[i + 0], rhs[i + 1]);
  }
  if (rhsExtraCols) {
    loadTwoRhsFloat32<zero>(indexB + k * extra_cols - offsetB, strideB, real_rhs, rhs[real_rhs + 0], rhs[real_rhs + 1]);
  }

  indexA += 2 * k * 4;
  for (Index j = 0; j < num_lhs; j++) {
    lhs[j] = ploadu<Packet4f>(indexA + j * 4);
  }

  for (Index j = 0; j < num_rhs; j++) {
    for (Index i = 0; i < num_lhs; i++) {
      acc[j][i] = pmadd(rhs[j], lhs[i], acc[j][i]);
    }
  }
}
template <const Index num_acc, bool rhsExtraCols, bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void colVSXLoopBodyIter(Index depth, Index rows, const Packet4f pAlpha, const float* indexA,
                                            const float* indexB, Index strideB, Index offsetB, float* result,
                                            const Index extra_cols, const Index extra_rows) {
  constexpr Index num_rhs = num_acc;

  Packet4f acc[num_acc][4];

  zeroAccumulators<num_acc>(acc);

  Index k;
  for (k = 0; k + 2 <= depth; k += 2) {
    KLoop<num_acc, false, rhsExtraCols, num_rhs>(indexA, indexB, acc, strideB, k, offsetB, extra_cols);
  }
  if (depth & 1) {
    KLoop<num_acc, true, rhsExtraCols, num_rhs>(indexA, indexB, acc, strideB, k, offsetB, extra_cols);
  }

  outputResultsVSX<num_acc, rhsExtraCols, lhsExtraRows, num_rhs>(acc, rows, pAlpha, result, extra_cols, extra_rows);
}
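// Walk the rhs in strips of step (= num_acc * 4) columns; multiIters keeps
// iterating while full-width strips remain.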
#define MAX_BFLOAT16_ACC_VSX 4

template <const Index num_acc, bool rhsExtraCols, bool lhsExtraRows>
void colVSXLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float* indexA,
                    const float* indexB, Index strideB, Index offsetB, float* result) {
  constexpr Index step = (num_acc * 4);
  const Index extra_cols = (rhsExtraCols) ? (cols & 3) : 0;
  const Index extra_rows = (lhsExtraRows) ? (rows & 3) : 0;
  constexpr bool multiIters = !rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC_VSX);

  do {
    colVSXLoopBodyIter<num_acc * 2, rhsExtraCols, lhsExtraRows>(depth, rows, pAlpha, indexA, indexB, strideB, offsetB,
                                                                result, extra_cols, extra_rows);

    indexB += strideB * (num_acc * 2);
    result += rows * step;
  } while (multiIters && (step <= cols - (col += step)));
}
template <const Index num_acc, bool rhsExtraCols, bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void colVSXLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha,
                                              const float* indexA, const float* blockB, Index strideB, Index offsetB,
                                              float* result) {
  if (MAX_BFLOAT16_ACC_VSX > num_acc) {
    colVSXLoopBody<num_acc + (rhsExtraCols ? 1 : 0), rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA,
                                                                                 blockB, strideB, offsetB, result);
  }
}
template <bool rhsExtraCols, bool lhsExtraRows>
void colVSXLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float* indexA,
                         const float* blockB, Index strideB, Index offsetB, float* result) {
  switch ((cols - col) >> 2) {
    case 3:
      colVSXLoopBodyExtraN<3, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB,
                                                          offsetB, result);
      break;
    case 2:
      colVSXLoopBodyExtraN<2, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB,
                                                          offsetB, result);
      break;
    case 1:
      colVSXLoopBodyExtraN<1, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB,
                                                          offsetB, result);
      break;
    default:
      if (rhsExtraCols) {
        colVSXLoopBody<1, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
      }
      break;
  }
}
template <Index size, bool lhsExtraRows = false>
EIGEN_ALWAYS_INLINE void colVSXLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha,
                                     const bfloat16* indexA, const float* indexA2, const float* blockB2, Index strideA,
                                     Index strideB, Index offsetB, float* result2) {
  Index delta_rows = 2 * (lhsExtraRows ? (rows & 3) : size);
  for (Index row = 0; row < size; row += 4) {
    convertArrayPointerBF16toF32Dup<lhsExtraRows>(const_cast<float*>(indexA2), strideA, delta_rows, indexA, row,
                                                  rows & 3);

    const float* blockB = blockB2;
    float* result = result2 + row;

    Index col = 0;
    if (cols >= (MAX_BFLOAT16_ACC_VSX * 4)) {
      colVSXLoopBody<MAX_BFLOAT16_ACC_VSX, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA2, blockB,
                                                                strideB, 0, result);
      blockB += (strideB >> 1) * col;
      result += rows * col;
    }
    if (cols & 3) {
      colVSXLoopBodyExtra<true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA2, blockB, strideB, offsetB,
                                              result);
    } else {
      colVSXLoopBodyExtra<false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA2, blockB, strideB, 0, result);
    }
  }
}
template <Index size>
EIGEN_ALWAYS_INLINE void calcVSXColLoops(const bfloat16*& indexA, const float* indexA2, Index& row, Index depth,
                                         Index cols, Index rows, const Packet4f pAlpha, const float* indexB,
                                         Index strideA, Index strideB, Index offsetA, Index offsetB, Index bigSuffix,
                                         float* result) {
  if ((size == 16) || (rows & size)) {
    indexA += size * offsetA;
    colVSXLoops<size>(depth, cols, rows, pAlpha, indexA, indexA2, indexB, strideA, strideB, offsetB, result + row);
    row += size;
    indexA += bigSuffix * size / 16;
  }
}
template <const Index size, typename DataMapper>
EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float* result, Index rows, const DataMapper& src) {
  constexpr Index extra = ((size < 4) ? 4 : size);
  while (i + size <= rows) {
    PacketBlock<Packet8bf, (size + 7) / 8> r32;
    r32.packet[0] = src.template loadPacket<Packet8bf>(i + 0);
    if (size >= 16) {
      r32.packet[1] = src.template loadPacket<Packet8bf>(i + 8);
    }
    if (size >= 32) {
      r32.packet[2] = src.template loadPacket<Packet8bf>(i + 16);
      r32.packet[3] = src.template loadPacket<Packet8bf>(i + 24);
    }
    storeConvertBlockBF16<size>(result + i, r32, rows & 3);
    i += extra;
    if (size != 32) break;
  }
}
template <typename DataMapper>
EIGEN_ALWAYS_INLINE void convertArrayBF16toF32(float* result, Index cols, Index rows, const DataMapper& src) {
  typedef typename DataMapper::LinearMapper LinearMapper;
  for (Index j = 0; j < cols; j++, result += rows) {
    const LinearMapper src2 = src.getLinearMapper(0, j);
    Index i = 0;
    convertBF16toF32<32, LinearMapper>(i, result, rows, src2);
    convertBF16toF32<16, LinearMapper>(i, result, rows, src2);
    convertBF16toF32<8, LinearMapper>(i, result, rows, src2);
    convertBF16toF32<4, LinearMapper>(i, result, rows, src2);
    convertBF16toF32<1, LinearMapper>(i, result, rows, src2);
  }
}
EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16VSX(const float* res) {
  return F32ToBf16Both(ploadu<Packet4f>(res + 0), ploadu<Packet4f>(res + 4));
}
template <typename DataMapper, const Index size>
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16ColVSX(float* result, Index col, Index rows, const DataMapper& res) {
  const DataMapper res2 = res.getSubMapper(0, col);
  Index row;
  float* result2 = result + col * rows;
  for (row = 0; row + 8 <= rows; row += 8, result2 += 8) {
    PacketBlock<Packet8bf, size> block;
    for (Index j = 0; j < size; j++) {
      block.packet[j] = convertF32toBF16VSX(result2 + j * rows);
    }
    res2.template storePacketBlock<Packet8bf, size>(row, 0, block);
  }
  if (row < rows) {
    for (Index j = 0; j < size; j++) {
      Packet8bf fp16 = convertF32toBF16VSX(result2 + j * rows);
      res2.template storePacketPartial<Packet8bf>(row, j, fp16, rows & 7);
    }
  }
}
template <typename DataMapper>
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16VSX(float* result, Index cols, Index rows, const DataMapper& res) {
  Index col;
  for (col = 0; col + 4 <= cols; col += 4) {
    convertArrayF32toBF16ColVSX<DataMapper, 4>(result, col, rows, res);
  }
  switch (cols - col) {
    case 1:
      convertArrayF32toBF16ColVSX<DataMapper, 1>(result, col, rows, res);
      break;
    case 2:
      convertArrayF32toBF16ColVSX<DataMapper, 2>(result, col, rows, res);
      break;
    case 3:
      convertArrayF32toBF16ColVSX<DataMapper, 3>(result, col, rows, res);
      break;
  }
}
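// bfloat16 GEMM driver: converts the result and rhs blocks to float once, runs the
// lhs in row blocks of 16/8/4 plus the row remainder, then converts back.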
template <typename DataMapper>
void gemmbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat16* indexB, Index rows, Index depth,
                  Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
  const Packet4f pAlpha = pset1<Packet4f>(falpha);

  if (strideA == -1) strideA = depth;
  if (strideB == -1) strideB = depth;

  ei_declare_aligned_stack_constructed_variable(float, result, cols* rows, 0);
  ei_declare_aligned_stack_constructed_variable(float, indexB2, strideB* cols, 0);
  ei_declare_aligned_stack_constructed_variable(float, indexA2, ((strideA + 1) & -2) * 4 * 2, 0);

  convertArrayBF16toF32<DataMapper>(result, cols, rows, res);
  convertArrayPointerBF16toF32(indexB2, cols, strideB, const_cast<bfloat16*>(indexB));

  Index bigSuffix = 2 * 8 * (strideA - offsetA);
  float* indexBF32 = indexB2 + 4 * offsetB;
  offsetB *= 3;
  strideB *= 2;

  Index row = 0;
  while (row + 16 <= rows) {
    calcVSXColLoops<16>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB,
                        bigSuffix, result);
  }
  calcVSXColLoops<8>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB,
                     bigSuffix, result);
  calcVSXColLoops<4>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB,
                     bigSuffix, result);
  if (rows & 3) {
    colVSXLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexA2, indexBF32, strideA, strideB, offsetB,
                         result + row);
  }

  convertArrayF32toBF16VSX<DataMapper>(result, cols, rows, res);
}
#undef MAX_BFLOAT16_ACC_VSX

#include "MatrixVectorProduct.h"
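// gemm_pack_lhs/gemm_pack_rhs specializations routing Eigen's packing through the
// VSX-optimized dhs_pack/dhs_cpack helpers.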
template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
  void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>::operator()(
    double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
  dhs_pack<double, DataMapper, Packet2d, ColMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
  void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>::operator()(
    double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
  dhs_pack<double, DataMapper, Packet2d, RowMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}
#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
  void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
    double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
  dhs_pack<double, DataMapper, Packet2d, ColMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
  void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
    double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
  dhs_pack<double, DataMapper, Packet2d, RowMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
#endif
template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<bfloat16, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
  void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<bfloat16, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
    bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
  dhs_pack<bfloat16, DataMapper, Packet8bf, ColMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<bfloat16, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
  void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<bfloat16, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
    bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
  dhs_pack<bfloat16, DataMapper, Packet8bf, RowMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
  void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>::operator()(
    bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
  dhs_pack<bfloat16, DataMapper, Packet8bf, ColMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
  void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>::operator()(
    bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
  dhs_pack<bfloat16, DataMapper, Packet8bf, RowMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}
template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
  void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>::operator()(
    float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
  dhs_pack<float, DataMapper, Packet4f, RowMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
  void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>::operator()(
    float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
  dhs_pack<float, DataMapper, Packet4f, ColMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}
template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
  void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate,
                   PanelMode>::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows,
                                          Index stride, Index offset) {
  dhs_cpack<float, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
  void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate,
                   PanelMode>::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows,
                                          Index stride, Index offset) {
  dhs_cpack<float, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}
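// The rhs packs follow the same forwarding pattern; the only difference from
// the lhs packs is the final bool template argument of dhs_pack/dhs_cpack
// (true for lhs packing above, false for rhs packing below).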
#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
  void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
    float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
  dhs_pack<float, DataMapper, Packet4f, ColMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
  void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
    float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
  dhs_pack<float, DataMapper, Packet4f, RowMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
#endif
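// Note: the #endif above closes EIGEN_ALTIVEC_USE_CUSTOM_PACK, so only the
// real-valued float rhs packs are optional; the complex rhs packs below are
// always defined here.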
template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
  void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
    std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
  dhs_cpack<float, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
  void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
    std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
  dhs_cpack<float, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
  void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate,
                   PanelMode>::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows,
                                          Index stride, Index offset) {
  dhs_cpack<double, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
  void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate,
                   PanelMode>::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows,
                                          Index stride, Index offset) {
  dhs_cpack<double, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}
template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
  void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
    std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
  dhs_cpack<double, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
  void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
    std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
  dhs_cpack<double, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
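// gebp_kernel specializations. Each operator() selects a GEMM implementation
// once and caches it in a function-local static pointer: when MMA support was
// compiled in (EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H is defined by including
// MatrixProductMMA.h), supportsMMA() picks the Power10 MMA kernel at runtime;
// otherwise the generic VSX kernel is used.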
template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
  typedef typename quad_traits<float>::vectortype Packet;
  typedef typename quad_traits<float>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const float* blockA, const float* blockB, Index rows, Index depth, Index cols,
                  float alpha, Index strideA = -1, Index strideB = -1, Index offsetA = 0, Index offsetB = 0);
};

template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
    const DataMapper& res, const float* blockA, const float* blockB, Index rows, Index depth, Index cols, float alpha,
    Index strideA, Index strideB, Index offsetA, Index offsetB) {
  const Index accRows = quad_traits<float>::rows;
  const Index accCols = quad_traits<float>::size;
  static void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index,
                               Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
      (supportsMMA()) ? &Eigen::internal::gemmMMA<float, Packet, RhsPacket, DataMapper, accRows, accCols> :
#endif
                      &Eigen::internal::gemm<float, Packet, RhsPacket, DataMapper, accRows, accCols>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
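// Because gemm_function is a function-local static, the supportsMMA() check
// runs only once per template instantiation (on first call); subsequent calls
// reuse the cached kernel pointer.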
template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
  typedef Packet4f Packet;
  typedef Packet2cf Packetc;
  typedef Packet4f RhsPacket;

  void operator()(const DataMapper& res, const std::complex<float>* blockA, const std::complex<float>* blockB,
                  Index rows, Index depth, Index cols, std::complex<float> alpha, Index strideA = -1,
                  Index strideB = -1, Index offsetA = 0, Index offsetB = 0);
};

template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs,
                 ConjugateRhs>::operator()(const DataMapper& res, const std::complex<float>* blockA,
                                           const std::complex<float>* blockB, Index rows, Index depth, Index cols,
                                           std::complex<float> alpha, Index strideA, Index strideB, Index offsetA,
                                           Index offsetB) {
  const Index accRows = quad_traits<float>::rows;
  const Index accCols = quad_traits<float>::size;
  static void (*gemm_function)(const DataMapper&, const std::complex<float>*, const std::complex<float>*, Index, Index,
                               Index, std::complex<float>, Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
      (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>,
                                                          float, Packet, Packetc, RhsPacket, DataMapper, accRows,
                                                          accCols, ConjugateLhs, ConjugateRhs, false, false>
                      :
#endif
                      &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>,
                                                     float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
                                                     ConjugateLhs, ConjugateRhs, false, false>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
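// The mixed real/complex kernels below differ only in the two trailing bool
// template arguments passed to gemm_complex/gemm_complexMMA, which flag a
// real-valued lhs or rhs (e.g. <..., true, false> for a real lhs with a
// complex rhs).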
template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
  typedef Packet4f Packet;
  typedef Packet2cf Packetc;
  typedef Packet4f RhsPacket;

  void operator()(const DataMapper& res, const float* blockA, const std::complex<float>* blockB, Index rows,
                  Index depth, Index cols, std::complex<float> alpha, Index strideA = -1, Index strideB = -1,
                  Index offsetA = 0, Index offsetB = 0);
};

template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
    const DataMapper& res, const float* blockA, const std::complex<float>* blockB, Index rows, Index depth, Index cols,
    std::complex<float> alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  const Index accRows = quad_traits<float>::rows;
  const Index accCols = quad_traits<float>::size;
  static void (*gemm_function)(const DataMapper&, const float*, const std::complex<float>*, Index, Index, Index,
                               std::complex<float>, Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
      (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float,
                                                          Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
                                                          ConjugateLhs, ConjugateRhs, true, false>
                      :
#endif
                      &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Packet,
                                                     Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
                                                     ConjugateRhs, true, false>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
  typedef Packet4f Packet;
  typedef Packet2cf Packetc;
  typedef Packet4f RhsPacket;

  void operator()(const DataMapper& res, const std::complex<float>* blockA, const float* blockB, Index rows,
                  Index depth, Index cols, std::complex<float> alpha, Index strideA = -1, Index strideB = -1,
                  Index offsetA = 0, Index offsetB = 0);
};

template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
    const DataMapper& res, const std::complex<float>* blockA, const float* blockB, Index rows, Index depth, Index cols,
    std::complex<float> alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  const Index accRows = quad_traits<float>::rows;
  const Index accCols = quad_traits<float>::size;
  static void (*gemm_function)(const DataMapper&, const std::complex<float>*, const float*, Index, Index, Index,
                               std::complex<float>, Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
      (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float,
                                                          Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
                                                          ConjugateLhs, ConjugateRhs, false, true>
                      :
#endif
                      &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Packet,
                                                     Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
                                                     ConjugateRhs, false, true>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
  typedef typename quad_traits<double>::vectortype Packet;
  typedef typename quad_traits<double>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const double* blockA, const double* blockB, Index rows, Index depth,
                  Index cols, double alpha, Index strideA = -1, Index strideB = -1, Index offsetA = 0,
                  Index offsetB = 0);
};

template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
    const DataMapper& res, const double* blockA, const double* blockB, Index rows, Index depth, Index cols,
    double alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  const Index accRows = quad_traits<double>::rows;
  const Index accCols = quad_traits<double>::size;
  static void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index,
                               Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
      (supportsMMA()) ? &Eigen::internal::gemmMMA<double, Packet, RhsPacket, DataMapper, accRows, accCols> :
#endif
                      &Eigen::internal::gemm<double, Packet, RhsPacket, DataMapper, accRows, accCols>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
  typedef quad_traits<double>::vectortype Packet;
  typedef Packet1cd Packetc;
  typedef quad_traits<double>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const std::complex<double>* blockA, const std::complex<double>* blockB,
                  Index rows, Index depth, Index cols, std::complex<double> alpha, Index strideA = -1,
                  Index strideB = -1, Index offsetA = 0, Index offsetB = 0);
};

template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs,
                 ConjugateRhs>::operator()(const DataMapper& res, const std::complex<double>* blockA,
                                           const std::complex<double>* blockB, Index rows, Index depth, Index cols,
                                           std::complex<double> alpha, Index strideA, Index strideB, Index offsetA,
                                           Index offsetB) {
  const Index accRows = quad_traits<double>::rows;
  const Index accCols = quad_traits<double>::size;
  static void (*gemm_function)(const DataMapper&, const std::complex<double>*, const std::complex<double>*, Index,
                               Index, Index, std::complex<double>, Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
      (supportsMMA())
          ? &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double,
                                              Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
                                              ConjugateRhs, false, false>
          :
#endif
          &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double,
                                         Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
                                         ConjugateRhs, false, false>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
  typedef quad_traits<double>::vectortype Packet;
  typedef Packet1cd Packetc;
  typedef quad_traits<double>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const std::complex<double>* blockA, const double* blockB, Index rows,
                  Index depth, Index cols, std::complex<double> alpha, Index strideA = -1, Index strideB = -1,
                  Index offsetA = 0, Index offsetB = 0);
};

template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
    const DataMapper& res, const std::complex<double>* blockA, const double* blockB, Index rows, Index depth,
    Index cols, std::complex<double> alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  const Index accRows = quad_traits<double>::rows;
  const Index accCols = quad_traits<double>::size;
  static void (*gemm_function)(const DataMapper&, const std::complex<double>*, const double*, Index, Index, Index,
                               std::complex<double>, Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
      (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double,
                                                          Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
                                                          ConjugateLhs, ConjugateRhs, false, true>
                      :
#endif
                      &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Packet,
                                                     Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
                                                     ConjugateRhs, false, true>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
  typedef quad_traits<double>::vectortype Packet;
  typedef Packet1cd Packetc;
  typedef quad_traits<double>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const double* blockA, const std::complex<double>* blockB, Index rows,
                  Index depth, Index cols, std::complex<double> alpha, Index strideA = -1, Index strideB = -1,
                  Index offsetA = 0, Index offsetB = 0);
};

template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
    const DataMapper& res, const double* blockA, const std::complex<double>* blockB, Index rows, Index depth,
    Index cols, std::complex<double> alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  const Index accRows = quad_traits<double>::rows;
  const Index accCols = quad_traits<double>::size;
  static void (*gemm_function)(const DataMapper&, const double*, const std::complex<double>*, Index, Index, Index,
                               std::complex<double>, Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
      (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double,
                                                          Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
                                                          ConjugateLhs, ConjugateRhs, true, false>
                      :
#endif
                      &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Packet,
                                                     Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
                                                     ConjugateRhs, true, false>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<bfloat16, bfloat16, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
  typedef typename quad_traits<bfloat16>::vectortype Packet;
  typedef typename quad_traits<bfloat16>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const bfloat16* blockA, const bfloat16* blockB, Index rows, Index depth,
                  Index cols, bfloat16 alpha, Index strideA = -1, Index strideB = -1, Index offsetA = 0,
                  Index offsetB = 0);
};

template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<bfloat16, bfloat16, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
    const DataMapper& res, const bfloat16* blockA, const bfloat16* blockB, Index rows, Index depth, Index cols,
    bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  static void (*gemm_function)(const DataMapper&, const bfloat16*, const bfloat16*, Index, Index, Index, bfloat16,
                               Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
      (supportsMMA()) ? &Eigen::internal::gemmMMAbfloat16<DataMapper> :
#endif
                      &Eigen::internal::gemmbfloat16<DataMapper>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}

}  // end namespace internal

}  // end namespace Eigen

#endif  // EIGEN_MATRIX_PRODUCT_ALTIVEC_H