#ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H

// With dynamic dispatch, build this file for Power10 (MMA) even if the baseline target is older.
#if defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
#pragma GCC push_options
#pragma GCC target("cpu=power10,htm")
#endif

#if !__has_builtin(__builtin_vsx_assemble_pair)
#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
#endif
#if !__has_builtin(__builtin_vsx_disassemble_pair)
#define __builtin_vsx_disassemble_pair __builtin_mma_disassemble_pair
#endif

#include "../../InternalHeaderCheck.h"

#include "MatrixProductMMAbfloat16.h"

namespace Eigen {

namespace internal {
#define accColsC (accCols / 2)
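// Zero a 512-bit MMA accumulator (a __vector_quad holds a 4x4 tile of partial results).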
EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc) { __builtin_mma_xxsetaccz(acc); }
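// storeAccumulator: disassemble an accumulator into four packets, load the existing result values,
// apply the alpha scaling (bscale) and store back; the partial path touches only `elements`
// entries per column for the bottom edge of the matrix.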
template <typename DataMapper, typename Packet, bool full>
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, const Index elements,
                                          __vector_quad* acc) {
  PacketBlock<Packet, 4> result;
  __builtin_mma_disassemble_acc(&result.packet, acc);

  PacketBlock<Packet, 4> tRes;
  if (full) {
    EIGEN_UNUSED_VARIABLE(elements);
    bload<DataMapper, Packet, 0, ColMajor, false, 4>(tRes, data, i, 0);
    bscale<Packet, 4>(tRes, result, alpha);
    bstore<DataMapper, Packet, 4>(tRes, data, i);
  } else {
    bload_partial<DataMapper, Packet, 0, false, 4>(tRes, data, i, elements);
    bscale<Packet, 4>(tRes, result, alpha);
    bstore_partial<DataMapper, Packet, 4>(tRes, data, i, elements);
  }
}
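// Complex variant of storeAccumulator: real and imaginary parts live in separate accumulators,
// are scaled by the real/imaginary parts of alpha (bscalec), recombined into complex packets
// (bcouple) and stored; the second half is only written when a full column block is present.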
template <typename DataMapper, typename Packet, typename Packetc, const Index accCols, const Index accCols2>
EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, const DataMapper& data, const Packet& alphaReal,
                                                 const Packet& alphaImag, const Packet& pMask, __vector_quad* accReal,
                                                 __vector_quad* accImag) {
  constexpr bool full = (accCols2 > accColsC);
  PacketBlock<Packet, 4> resultReal, resultImag;
  __builtin_mma_disassemble_acc(&resultReal.packet, accReal);
  __builtin_mma_disassemble_acc(&resultImag.packet, accImag);

  PacketBlock<Packetc, 8> tRes;
  bload<DataMapper, Packetc, accColsC, ColMajor, true, 4, full>(tRes, data, i, 0);

  PacketBlock<Packet, 4> taccReal, taccImag;
  bscalec<Packet, 4, (accCols != accCols2)>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag, pMask);

  PacketBlock<Packetc, 4> acc1, acc2;
  bcouple<Packet, Packetc, 4, full>(taccReal, taccImag, tRes, acc1, acc2);

  bstore<DataMapper, Packetc, 4>(acc1, data, i);
  if (full) {
    bstore<DataMapper, Packetc, 4>(acc2, data, i + accColsC);
  }
}
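// Rank-1 update of an MMA accumulator: acc +/-= outer(a, b). xvf32gerpp accumulates the outer
// product positively, xvf32gernp negates it before accumulating (used for NegativeAccumulate).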
template <typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b) {
  if (NegativeAccumulate) {
    __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  }
}
template <typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, const Packet2d& b) {
  if (NegativeAccumulate) {
    __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b);
  }
}
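// Complex rank-1 update built from real GER operations: the real accumulator gets the re*re term
// (and the im*im cross term with a sign that depends on the conjugation flags), the imaginary
// accumulator gets the mixed terms; LhsIsReal/RhsIsReal skip the parts that do not exist.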
template <typename Packet, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, Packet& lhsVi,
                                  const RhsPacket& rhsV, RhsPacket& rhsVi) {
  pgerMMA<Packet, RhsPacket, false>(accReal, rhsV, lhsV);
  if (LhsIsReal) {
    pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
    EIGEN_UNUSED_VARIABLE(lhsVi);
  } else {
    if (!RhsIsReal) {
      pgerMMA<Packet, RhsPacket, ConjugateLhs == ConjugateRhs>(accReal, rhsVi, lhsVi);
      pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
    } else {
      EIGEN_UNUSED_VARIABLE(rhsVi);
    }
    pgerMMA<Packet, RhsPacket, ConjugateLhs>(accImag, rhsV, lhsVi);
  }
}
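// RHS loads: a plain unaligned packet load for float, plus a specialization below that fills a
// __vector_pair (two Packet2d) for double, either through __builtin_vsx_assemble_pair or by a
// direct paired load depending on the compiler.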
template <typename Packet>
EIGEN_ALWAYS_INLINE Packet ploadRhs(const __UNPACK_TYPE__(Packet) * rhs) {
  return ploadu<Packet>(rhs);
}
template <typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV) {
  rhsV = ploadRhs<Packet>(rhs);
}
template <>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const double* rhs, __vector_pair& rhsV) {
#if EIGEN_COMP_LLVM
  __builtin_vsx_assemble_pair(
      &rhsV, reinterpret_cast<__vector unsigned char>(ploadRhs<Packet2d>(rhs + (sizeof(Packet2d) / sizeof(double)))),
      reinterpret_cast<__vector unsigned char>(ploadRhs<Packet2d>(rhs)));
#else
  rhsV = *reinterpret_cast<__vector_pair*>(const_cast<double*>(rhs));
#endif
}
EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV) { ploadRhsMMA(lhs, lhsV); }
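// Tuning switches: GEMM_MULTIPLE_COLS lets one iteration work on up to four column panels at once
// (accItr = 1, 2 or 4), and VECTOR_PAIR_LOADS_LHS fetches the LHS two packets at a time through
// __vector_pair so fewer load instructions are issued.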
#define GEMM_MULTIPLE_COLS

#define VECTOR_PAIR_LOADS_LHS

#ifdef GEMM_MULTIPLE_COLS
#define PEEL_MMA 8
#else
#if EIGEN_COMP_LLVM || (__GNUC__ < 12) || defined(VECTOR_PAIR_LOADS_LHS)
#define PEEL_MMA 7
#endif
#endif
#define MICRO_MMA_UNROLL(func) func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
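// In the MICRO_MMA_* macros below, `iter` indexes the 8 accumulators, `left` the LHS row block and
// `right` the RHS column panel; the accItr branches map accumulators to (left, right) pairs.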
#define MICRO_MMA_WORK(func, type, peel) \
  if (accItr == 1) { \
    func(0, type, peel, 0, 0) func(1, type, peel, 1, 0) func(2, type, peel, 2, 0) func(3, type, peel, 3, 0) \
    func(4, type, peel, 4, 0) func(5, type, peel, 5, 0) func(6, type, peel, 6, 0) func(7, type, peel, 7, 0) \
  } else if (accItr == 2) { \
    func(0, type, peel, 0, 0) func(1, type, peel, 0, 1) func(2, type, peel, 1, 0) func(3, type, peel, 1, 1) \
    func(4, type, peel, 2, 0) func(5, type, peel, 2, 1) func(6, type, peel, 3, 0) func(7, type, peel, 3, 1) \
  } else { \
    func(0, type, peel, 0, 0) func(1, type, peel, 0, 1) func(2, type, peel, 0, 2) func(3, type, peel, 0, 3) \
    func(4, type, peel, 1, 0) func(5, type, peel, 1, 1) func(6, type, peel, 1, 2) func(7, type, peel, 1, 3) \
  }
#define MICRO_MMA_WORK_ONE(iter, type, peel, left, right) \
  if (unroll_factor > left) { \
    pgerMMA<Packet, type, false>(&accZero##iter, rhsV##right[peel], lhsV##left); \
  }
#ifdef VECTOR_PAIR_LOADS_LHS
#define MICRO_MMA_WORK_TWO(iter, type, peel, left, right) \
  if (unroll_factor > left) { \
    pgerMMA<Packet, type, false>(&accZero##iter, rhsV##right[peel], lhsV2##left.packet[peel & 1]); \
  }

#define MICRO_MMA_LOAD1_TWO(lhs_ptr, left) \
  if (unroll_factor > left) { \
    if (MICRO_NORMAL(left)) { \
      ploadLhsMMA(reinterpret_cast<const double*>(lhs_ptr##left), plhsV##left); \
      __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&lhsV2##left.packet), &plhsV##left); \
      lhs_ptr##left += accCols * 2; \
    } else { \
      lhsV2##left.packet[0] = ploadLhs<Packet>(lhs_ptr##left); \
      lhsV2##left.packet[1] = ploadLhs<Packet>(lhs_ptr##left + accCols2); \
      lhs_ptr##left += accCols2 * 2; \
      EIGEN_UNUSED_VARIABLE(plhsV##left); \
    } \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV2##left); \
    EIGEN_UNUSED_VARIABLE(plhsV##left); \
  }

#define MICRO_MMA_LOAD_TWO(left) MICRO_MMA_LOAD1_TWO(lhs_ptr, left)
#endif
#define MICRO_MMA_UNROLL_ITER(func, val) \
  func(val, 0) if (accItr > 1) { \
    func(val, 1) if (accItr > 2) { func(val, 2) func(val, 3) } \
  }

#define MICRO_MMA_LOAD_ONE_RHS1(peel, right) ploadRhsMMA(rhs_ptr##right + (accRows * peel), rhsV##right[peel]);

#define MICRO_MMA_LOAD_ONE_RHS(peel) MICRO_MMA_UNROLL_ITER(MICRO_MMA_LOAD_ONE_RHS1, peel)

#define MICRO_MMA_TYPE_PEEL(funcw, funcl, type, peel) \
  if (PEEL_MMA > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
    MICRO_MMA_LOAD_ONE_RHS(peel) \
    MICRO_MMA_UNROLL(funcl) \
    MICRO_MMA_WORK(funcw, type, peel) \
  }
#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_MMA_UNROLL_TYPE_PEEL(funcw, funcl, type) \
  type rhsV0[8], rhsV1[(accItr > 1) ? 8 : 1], rhsV2[(accItr > 2) ? 8 : 1], rhsV3[(accItr > 2) ? 8 : 1]; \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 0) \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 1) \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 2) \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 3) \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 4) \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 5) \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 6) MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 7)
#else
#define MICRO_MMA_LOAD_TWO_RHS(peel1, right) \
  ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr##right + (accRows * peel1)), prhsV##peel1); \
  __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsV##right[peel1]), &prhsV##peel1);

#define MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, peel1, peel2) \
  if (PEEL_MMA > peel2) { \
    PacketBlock<Packet, 2> lhsV20, lhsV21, lhsV22, lhsV23, lhsV24, lhsV25, lhsV26, lhsV27; \
    __vector_pair plhsV0, plhsV1, plhsV2, plhsV3, plhsV4, plhsV5, plhsV6, plhsV7; \
    if (sizeof(type) == 16) { \
      MICRO_MMA_UNROLL_ITER(MICRO_MMA_LOAD_TWO_RHS, peel1) \
    } else { \
      EIGEN_UNUSED_VARIABLE(prhsV##peel1); \
      MICRO_MMA_LOAD_ONE_RHS(peel1) \
      MICRO_MMA_LOAD_ONE_RHS(peel2) \
    } \
    MICRO_MMA_UNROLL(funcl2) \
    MICRO_MMA_WORK(funcw2, type, peel1) \
    MICRO_MMA_WORK(funcw2, type, peel2) \
  } else { \
    EIGEN_UNUSED_VARIABLE(prhsV##peel1); \
    MICRO_MMA_TYPE_PEEL(funcw1, funcl1, type, peel1) \
  }
#define MICRO_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type) \
  type rhsV0[8], rhsV1[(accItr > 1) ? 8 : 1], rhsV2[(accItr > 2) ? 8 : 1], rhsV3[(accItr > 2) ? 8 : 1]; \
  __vector_pair prhsV0, prhsV2, prhsV4, prhsV6; \
  MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 0, 1) \
  MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 2, 3) \
  MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 4, 5) \
  MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 6, 7)
#endif
#define MICRO_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \
  type rhsV0[1], rhsV1[1], rhsV2[1], rhsV3[1]; \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 0)

#define MICRO_MMA_UPDATE_RHS1(size, right) rhs_ptr##right += (accRows * size);

#define MICRO_MMA_UPDATE_RHS(size) MICRO_MMA_UNROLL_ITER(MICRO_MMA_UPDATE_RHS1, size)

#define MICRO_MMA_UNROLL_TYPE(MICRO_MMA_TYPE, size) \
  MICRO_MMA_TYPE(MICRO_MMA_WORK_ONE, MICRO_LOAD_ONE, RhsPacket) \
  MICRO_MMA_UPDATE_RHS(size)
#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_MMA_ONE_PEEL MICRO_MMA_UNROLL_TYPE(MICRO_MMA_UNROLL_TYPE_PEEL, PEEL_MMA)
#else
#define MICRO_MMA_UNROLL_TYPE2(MICRO_MMA_TYPE, size) \
  MICRO_MMA_TYPE(MICRO_MMA_WORK_ONE, MICRO_LOAD_ONE, MICRO_MMA_WORK_TWO, MICRO_MMA_LOAD_TWO, RhsPacket) \
  MICRO_MMA_UPDATE_RHS(size)

#define MICRO_MMA_ONE_PEEL MICRO_MMA_UNROLL_TYPE2(MICRO_MMA_UNROLL_TYPE_PEEL2, PEEL_MMA)
#endif

#define MICRO_MMA_ONE MICRO_MMA_UNROLL_TYPE(MICRO_MMA_UNROLL_TYPE_ONE, 1)
#define MICRO_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor * accItr > iter) { \
    bsetzeroMMA(&accZero##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accZero##iter); \
  }

#define MICRO_MMA_DST_PTR MICRO_MMA_UNROLL(MICRO_MMA_DST_PTR_ONE)

#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_SRC_PTR_ONE)

#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_PREFETCH_ONE)
#define MICRO_MMA_STORE_ONE(iter, left, right) \
  if (unroll_factor > left) { \
    storeAccumulator<DataMapper, Packet, MICRO_NORMAL_PARTIAL(left)>(row + left * accCols, res##right, pAlpha, \
                                                                     accCols2, &accZero##iter); \
  }
#define MICRO_MMA_ITER_UNROLL(func) \
  if (accItr == 1) { \
    func(0, 0, 0) func(1, 1, 0) func(2, 2, 0) func(3, 3, 0) func(4, 4, 0) func(5, 5, 0) func(6, 6, 0) func(7, 7, 0) \
  } else if (accItr == 2) { \
    func(0, 0, 0) func(1, 0, 1) func(2, 1, 0) func(3, 1, 1) func(4, 2, 0) func(5, 2, 1) func(6, 3, 0) func(7, 3, 1) \
  } else { \
    func(0, 0, 0) func(1, 0, 1) func(2, 0, 2) func(3, 0, 3) func(4, 1, 0) func(5, 1, 1) func(6, 1, 2) func(7, 1, 3) \
  }

#define MICRO_MMA_STORE MICRO_MMA_ITER_UNROLL(MICRO_MMA_STORE_ONE)
#define MICRO_MMA_EXTRA_ROWS(right) \
  gemm_extra_row<Scalar, Packet, DataMapper, accRows, accCols>( \
      res3##right, blockA, rhs_base + right * accRows * strideB, depth, strideA, offsetA, strideB, row, rows, \
      remaining_rows, pAlpha, pMask);

#define MICRO_MMA_EXTRA_ROWS1(val, right) MICRO_MMA_EXTRA_ROWS(right);
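// One unrolled iteration of the real kernel: zero up to 8 accumulators, run the depth loop
// (peeled by PEEL_MMA, then one column at a time) and store the scaled results. res0..res3 and
// rhs_ptr0..rhs_ptr3 correspond to the accItr column panels being processed together.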
template <int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper,
          const Index accRows, const Index accCols, bool full, const Index accItr>
EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration(const DataMapper& res0, const DataMapper& res1,
                                                     const DataMapper& res2, const DataMapper& res3,
                                                     const Scalar* lhs_base, const Scalar* rhs_base, Index depth,
                                                     Index strideA, Index strideB, Index offsetA, Index& row,
                                                     const Packet& pAlpha, Index accCols2) {
  const Scalar *rhs_ptr0 = rhs_base, *rhs_ptr1 = NULL, *rhs_ptr2 = NULL, *rhs_ptr3 = NULL;
  const Scalar *lhs_ptr0 = NULL, *lhs_ptr1 = NULL, *lhs_ptr2 = NULL, *lhs_ptr3 = NULL, *lhs_ptr4 = NULL,
               *lhs_ptr5 = NULL, *lhs_ptr6 = NULL, *lhs_ptr7 = NULL;
  __vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;

  if (accItr > 1) {
    rhs_ptr1 = rhs_base + (accRows * strideB);
  } else {
    EIGEN_UNUSED_VARIABLE(strideB);
    EIGEN_UNUSED_VARIABLE(rhs_ptr1);
    EIGEN_UNUSED_VARIABLE(res1);
  }
  if (accItr > 2) {
    rhs_ptr2 = rhs_base + (2 * accRows * strideB);
    rhs_ptr3 = rhs_base + (3 * accRows * strideB);
  } else {
    EIGEN_UNUSED_VARIABLE(rhs_ptr2);
    EIGEN_UNUSED_VARIABLE(rhs_ptr3);
    EIGEN_UNUSED_VARIABLE(res2);
    EIGEN_UNUSED_VARIABLE(res3);
  }

  MICRO_MMA_SRC_PTR
  MICRO_MMA_DST_PTR

  Index k = 0, depth2 = depth - PEEL_MMA;
  for (; k <= depth2; k += PEEL_MMA) {
    EIGEN_POWER_PREFETCH(rhs_ptr);
    MICRO_MMA_PREFETCH
    MICRO_MMA_ONE_PEEL
  }
  for (; k < depth; k++) {
    MICRO_MMA_ONE
  }
  MICRO_MMA_STORE

  row += unroll_factor * accCols;
}
#define MICRO_MMA_UNROLL_ITER2(N, M) \
  gemm_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, RhsPacket, DataMapper, accRows, accCols, !M, accItr>( \
      res30, res31, res32, res33, lhs_base, rhs_base, depth, strideA, strideB, offsetA, row, pAlpha, \
      M ? remaining_rows : accCols); \
  if (M) return;

#define MICRO_MMA_ROWS(n) \
  while (row + n * accCols <= rows) { \
    MICRO_MMA_UNROLL_ITER2(n, 0); \
  }
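// Process one group of accItr column panels: run as many full row blocks as the 8-accumulator
// budget allows, then handle the left-over row blocks via the switch below and the remaining rows
// that do not fill a whole packet via gemm_extra_row.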
template <typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows,
          const Index accCols, const Index accItr>
EIGEN_ALWAYS_INLINE void gemmMMA_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index depth,
                                      Index strideA, Index offsetA, Index strideB, Index offsetB, Index col,
                                      Index rows, Index remaining_rows, const Packet& pAlpha, const Packet& pMask) {
  const DataMapper res30 = res.getSubMapper(0, col);
  const DataMapper res31 = (accItr > 1) ? res30.getSubMapper(0, accRows * 1) : res30;
  const DataMapper res32 = (accItr > 2) ? res30.getSubMapper(0, accRows * 2) : res30;
  const DataMapper res33 = (accItr > 2) ? res30.getSubMapper(0, accRows * 3) : res30;

  const Scalar* rhs_base = blockB + col * strideB + accRows * offsetB;
  const Scalar* lhs_base = blockA + accCols * offsetA;
  Index row = 0;
#define MAX_MMA_UNROLL 7

#if MAX_MMA_UNROLL < 2
  if (1) {
#elif MAX_MMA_UNROLL < 4
  if (accItr <= 2) {
#else
  if (accItr == 1) {
#endif
    MICRO_MMA_ROWS(MAX_MMA_UNROLL);
  } else if (accItr == 2) {
    MICRO_MMA_ROWS(4);
  } else {
    MICRO_MMA_ROWS(2);
  }
  switch ((rows - row) / accCols) {
#if MAX_MMA_UNROLL > 7
    case 7:
      MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 7)
      break;
#endif
#if MAX_MMA_UNROLL > 6
    case 6:
      MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 6)
      break;
#endif
#if MAX_MMA_UNROLL > 5
    case 5:
      MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 5)
      break;
#endif
#if MAX_MMA_UNROLL > 4
    case 4:
      MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 4)
      break;
#endif
#if MAX_MMA_UNROLL > 3
    case 3:
      MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 3)
      break;
#endif
#if MAX_MMA_UNROLL > 2
    case 2:
      MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 2)
      break;
#endif
#if MAX_MMA_UNROLL > 1
    case 1:
      MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 1)
      break;
#endif
    default:
      break;
  }
#undef MAX_MMA_UNROLL

  if (remaining_rows > 0) {
    MICRO_MMA_UNROLL_ITER(MICRO_MMA_EXTRA_ROWS1, 0)
  }
}
#define MICRO_MMA_COLS(n) \
  for (; col + n * accRows <= cols; col += n * accRows) { \
    gemmMMA_cols<Scalar, Packet, RhsPacket2, DataMapper, accRows, accCols, n>( \
        res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask); \
  }
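// Main MMA entry point for real scalars, called with packed blocks blockA/blockB. Columns are
// consumed in groups of 4, 2 and finally 1 panels; gemm_extra_cols finishes whatever is left.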
template <typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows,
          const Index accCols>
void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols,
             Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  const Index remaining_rows = rows % accCols;

  if (strideA == -1) strideA = depth;
  if (strideB == -1) strideB = depth;

  const Packet pAlpha = pset1<Packet>(alpha);
  const Packet pMask = bmask<Packet>(remaining_rows);

  typedef typename std::conditional_t<(sizeof(Scalar) == sizeof(float)), RhsPacket, __vector_pair> RhsPacket2;

  Index col = 0;
#ifdef GEMM_MULTIPLE_COLS
  MICRO_MMA_COLS(4);
  MICRO_MMA_COLS(2);
#endif
  MICRO_MMA_COLS(1);

  gemm_extra_cols<Scalar, Packet, DataMapper, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB,
                                                       col, rows, cols, remaining_rows, pAlpha, pMask);
}
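// Complex kernels. The packed complex blocks store the real parts of a panel separately from its
// imaginary parts, so advanceRows/advanceCols give the stride multiplier (1 when an operand is
// real, 2 when it is complex).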
#define advanceRows ((LhsIsReal) ? 1 : 2)
#define advanceCols ((RhsIsReal) ? 1 : 2)

#ifdef GEMM_MULTIPLE_COLS
#define PEEL_COMPLEX_MMA 4
#else
#define PEEL_COMPLEX_MMA 3
#endif

#define MICRO_COMPLEX_MMA_UNROLL(func) func(0) func(1) func(2) func(3)
#define MICRO_COMPLEX_MMA_WORK(func, type, peel) \
  if (accItr == 1) { \
    func(0, type, peel, 0, 0) func(1, type, peel, 1, 0) func(2, type, peel, 2, 0) func(3, type, peel, 3, 0) \
  } else if (accItr == 2) { \
    func(0, type, peel, 0, 0) func(1, type, peel, 0, 1) func(2, type, peel, 1, 0) func(3, type, peel, 1, 1) \
  } else { \
    func(0, type, peel, 0, 0) func(1, type, peel, 0, 1) func(2, type, peel, 0, 2) func(3, type, peel, 0, 3) \
  }
#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel, left, right) \
  if (unroll_factor > left) { \
    pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>( \
        &accReal##iter, &accImag##iter, lhsV##left, lhsVi##left, rhsV##right[peel], rhsVi##right[peel]); \
  }
#ifdef VECTOR_PAIR_LOADS_LHS
#define MICRO_COMPLEX_MMA_WORK_TWO(iter, type, peel, left, right) \
  if (unroll_factor > left) { \
    pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>( \
        &accReal##iter, &accImag##iter, lhsV2##left.packet[peel & 1], lhsVi2##left.packet[peel & 1], \
        rhsV##right[peel], rhsVi##right[peel]); \
  }

#define MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, left) \
  if (!LhsIsReal && (unroll_factor > left)) { \
    if (MICRO_NORMAL(left)) { \
      ploadLhsMMA(reinterpret_cast<const double*>(lhs_ptr_real##left + imag_delta), plhsVi##left); \
      __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&lhsVi2##left.packet), &plhsVi##left); \
    } else { \
      lhsVi2##left.packet[0] = ploadLhs<Packet>(lhs_ptr_real##left + imag_delta2); \
      lhsVi2##left.packet[1] = ploadLhs<Packet>(lhs_ptr_real##left + imag_delta2 + accCols2); \
      EIGEN_UNUSED_VARIABLE(plhsVi##left); \
    } \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsVi2##left); \
    EIGEN_UNUSED_VARIABLE(plhsVi##left); \
  } \
  MICRO_MMA_LOAD1_TWO(lhs_ptr_real, left)

#define MICRO_COMPLEX_MMA_LOAD_TWO(left) MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, left)
#endif
#define MICRO_COMPLEX_MMA_LOAD_RHS1(peel, right) \
  ploadRhsMMA(rhs_ptr_real##right + (accRows * peel), rhsV##right[peel]); \
  if (!RhsIsReal) { \
    ploadRhsMMA(rhs_ptr_imag##right + (accRows * peel), rhsVi##right[peel]); \
  }

#define MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel) MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_LOAD_RHS1, peel)

#define MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, peel) \
  if (PEEL_COMPLEX_MMA > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3; \
    Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
    MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel) \
    MICRO_COMPLEX_MMA_UNROLL(funcl) \
    MICRO_COMPLEX_MMA_WORK(funcw, type, peel) \
  }
#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(funcw, funcl, type) \
  type rhsV0[4], rhsVi0[4], rhsV1[(accItr > 1) ? 4 : 1], rhsVi1[(accItr > 1) ? 4 : 1], rhsV2[(accItr > 2) ? 4 : 1], \
      rhsVi2[(accItr > 2) ? 4 : 1], rhsV3[(accItr > 2) ? 4 : 1], rhsVi3[(accItr > 2) ? 4 : 1]; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, 0) \
  MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, 1) \
  MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, 2) MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, 3)
#else
#define MICRO_COMPLEX_MMA_LOAD_TWO_RHS(peel1, right) \
  ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr_real##right + (accRows * peel1)), prhsV##peel1); \
  __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsV##right[peel1]), &prhsV##peel1); \
  if (!RhsIsReal) { \
    ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr_imag##right + (accRows * peel1)), prhsVi##peel1); \
    __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsVi##right[peel1]), &prhsVi##peel1); \
  } else { \
    EIGEN_UNUSED_VARIABLE(prhsVi##peel1); \
  }
#define MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, peel1, peel2) \
  if (PEEL_COMPLEX_MMA > peel2) { \
    PacketBlock<Packet, 2> lhsV20, lhsV21, lhsV22, lhsV23; \
    PacketBlock<Packet, 2> lhsVi20, lhsVi21, lhsVi22, lhsVi23; \
    __vector_pair plhsV0, plhsV1, plhsV2, plhsV3; \
    __vector_pair plhsVi0, plhsVi1, plhsVi2, plhsVi3; \
    if (sizeof(type) == 16) { \
      MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_LOAD_TWO_RHS, peel1) \
    } else { \
      EIGEN_UNUSED_VARIABLE(prhsV##peel1); \
      EIGEN_UNUSED_VARIABLE(prhsVi##peel1); \
      MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel1); \
      MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel2); \
    } \
    MICRO_COMPLEX_MMA_UNROLL(funcl2) \
    MICRO_COMPLEX_MMA_WORK(funcw2, type, peel1) \
    MICRO_COMPLEX_MMA_WORK(funcw2, type, peel2) \
  } else { \
    EIGEN_UNUSED_VARIABLE(prhsV##peel1); \
    EIGEN_UNUSED_VARIABLE(prhsVi##peel1); \
    MICRO_COMPLEX_MMA_TYPE_PEEL(funcw1, funcl1, type, peel1) \
  }
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type) \
  type rhsV0[4], rhsVi0[4], rhsV1[(accItr > 1) ? 4 : 1], rhsVi1[(accItr > 1) ? 4 : 1], rhsV2[(accItr > 2) ? 4 : 1], \
      rhsVi2[(accItr > 2) ? 4 : 1], rhsV3[(accItr > 2) ? 4 : 1], rhsVi3[(accItr > 2) ? 4 : 1]; \
  __vector_pair prhsV0, prhsV2; \
  __vector_pair prhsVi0, prhsVi2; \
  MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 0, 1) \
  MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 2, 3)
#endif
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \
  type rhsV0[1], rhsVi0[1], rhsV1[1], rhsVi1[1], rhsV2[1], rhsVi2[1], rhsV3[1], rhsVi3[1]; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, 0)

#define MICRO_COMPLEX_MMA_UPDATE_RHS1(size, right) \
  rhs_ptr_real##right += (accRows * size); \
  if (!RhsIsReal) rhs_ptr_imag##right += (accRows * size);

#define MICRO_COMPLEX_MMA_UPDATE_RHS(size) MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_UPDATE_RHS1, size)

#define MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_TYPE, size) \
  MICRO_COMPLEX_MMA_TYPE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_LOAD_ONE, RhsPacket) \
  MICRO_COMPLEX_MMA_UPDATE_RHS(size);
#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_COMPLEX_MMA_ONE_PEEL MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL, PEEL_COMPLEX_MMA)
#else
#define MICRO_COMPLEX_MMA_UNROLL_TYPE2(MICRO_COMPLEX_MMA_TYPE, size) \
  MICRO_COMPLEX_MMA_TYPE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_LOAD_ONE, MICRO_COMPLEX_MMA_WORK_TWO, \
                         MICRO_COMPLEX_MMA_LOAD_TWO, RhsPacket) \
  MICRO_COMPLEX_MMA_UPDATE_RHS(size);

#define MICRO_COMPLEX_MMA_ONE_PEEL MICRO_COMPLEX_MMA_UNROLL_TYPE2(MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2, PEEL_COMPLEX_MMA)
#endif

#define MICRO_COMPLEX_MMA_ONE MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE, 1)
#define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor * accItr > iter) { \
    bsetzeroMMA(&accReal##iter); \
    bsetzeroMMA(&accImag##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accReal##iter); \
    EIGEN_UNUSED_VARIABLE(accImag##iter); \
  }

#define MICRO_COMPLEX_MMA_DST_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_DST_PTR_ONE)

#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE)

#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)
#define MICRO_COMPLEX_MMA_STORE_ONE(iter, left, right) \
  if (unroll_factor > left) { \
    storeComplexAccumulator<DataMapper, Packet, Packetc, accCols, (unroll_factor != (left + 1)) ? accCols : accCols2>( \
        row + left * accCols, res##right, pAlphaReal, pAlphaImag, pMask, &accReal##iter, &accImag##iter); \
  }
#define MICRO_COMPLEX_MMA_ITER_UNROLL(func) \
  if (accItr == 1) { \
    func(0, 0, 0) func(1, 1, 0) func(2, 2, 0) func(3, 3, 0) \
  } else if (accItr == 2) { \
    func(0, 0, 0) func(1, 0, 1) func(2, 1, 0) func(3, 1, 1) \
  } else { \
    func(0, 0, 0) func(1, 0, 1) func(2, 0, 2) func(3, 0, 3) \
  }

#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_ITER_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)
#define MICRO_COMPLEX_MMA_EXTRA_ROWS(right) \
  gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, \
                         RhsIsReal>(res3##right, blockA, rhs_base + right * accRows * (RhsIsReal ? 1 : 2) * strideB, \
                                    depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlphaReal, \
                                    pAlphaImag, pMask);

#define MICRO_COMPLEX_MMA_EXTRA_ROWS1(val, right) MICRO_COMPLEX_MMA_EXTRA_ROWS(right);
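// Complex counterpart of gemm_unrolled_MMA_iteration: up to 4 real/imaginary accumulator pairs,
// with separate real and imaginary RHS pointers per column panel (the imaginary ones stay unused
// when RhsIsReal).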
template <int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket,
          typename DataMapper, const Index accRows, const Index accCols, const Index accCols2, bool ConjugateLhs,
          bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index accItr>
EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(const DataMapper& res0, const DataMapper& res1,
                                                             const DataMapper& res2, const DataMapper& res3,
                                                             const Scalar* lhs_base, const Scalar* rhs_base,
                                                             Index depth, Index strideA, Index offsetA, Index strideB,
                                                             Index& row, const Packet& pAlphaReal,
                                                             const Packet& pAlphaImag, const Packet& pMask) {
  const Scalar *rhs_ptr_real0 = rhs_base, *rhs_ptr_real1 = NULL, *rhs_ptr_real2 = NULL, *rhs_ptr_real3 = NULL;
  const Scalar *rhs_ptr_imag0 = NULL, *rhs_ptr_imag1 = NULL, *rhs_ptr_imag2 = NULL, *rhs_ptr_imag3 = NULL;
  const Index imag_delta = accCols * strideA;
  const Index imag_delta2 = accCols2 * strideA;

  if (!RhsIsReal) {
    rhs_ptr_imag0 = rhs_base + accRows * strideB;
  } else {
    EIGEN_UNUSED_VARIABLE(rhs_ptr_imag0);
  }
  if (accItr > 1) {
    if (!RhsIsReal) {
      rhs_ptr_real1 = rhs_base + (2 * accRows * strideB);
      rhs_ptr_imag1 = rhs_base + (3 * accRows * strideB);
    } else {
      rhs_ptr_real1 = rhs_base + accRows * strideB;
      EIGEN_UNUSED_VARIABLE(rhs_ptr_imag1);
    }
  } else {
    EIGEN_UNUSED_VARIABLE(rhs_ptr_real1);
    EIGEN_UNUSED_VARIABLE(rhs_ptr_imag1);
    EIGEN_UNUSED_VARIABLE(res1);
  }
  if (accItr > 2) {
    if (!RhsIsReal) {
      rhs_ptr_real2 = rhs_base + (4 * accRows * strideB);
      rhs_ptr_imag2 = rhs_base + (5 * accRows * strideB);
      rhs_ptr_real3 = rhs_base + (6 * accRows * strideB);
      rhs_ptr_imag3 = rhs_base + (7 * accRows * strideB);
    } else {
      rhs_ptr_real2 = rhs_base + (2 * accRows * strideB);
      rhs_ptr_real3 = rhs_base + (3 * accRows * strideB);
      EIGEN_UNUSED_VARIABLE(rhs_ptr_imag2);
      EIGEN_UNUSED_VARIABLE(rhs_ptr_imag3);
    }
  } else {
    EIGEN_UNUSED_VARIABLE(rhs_ptr_real2);
    EIGEN_UNUSED_VARIABLE(rhs_ptr_real3);
    EIGEN_UNUSED_VARIABLE(rhs_ptr_imag2);
    EIGEN_UNUSED_VARIABLE(rhs_ptr_imag3);
    EIGEN_UNUSED_VARIABLE(res2);
    EIGEN_UNUSED_VARIABLE(res3);
  }
  const Scalar *lhs_ptr_real0 = NULL, *lhs_ptr_real1 = NULL;
  const Scalar *lhs_ptr_real2 = NULL, *lhs_ptr_real3 = NULL;
  __vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3;

  MICRO_COMPLEX_MMA_SRC_PTR
  MICRO_COMPLEX_MMA_DST_PTR

  Index k = 0, depth2 = depth - PEEL_COMPLEX_MMA;
  for (; k <= depth2; k += PEEL_COMPLEX_MMA) {
    EIGEN_POWER_PREFETCH(rhs_ptr_real);
    if (!RhsIsReal) {
      EIGEN_POWER_PREFETCH(rhs_ptr_imag);
    }
    MICRO_COMPLEX_MMA_PREFETCH
    MICRO_COMPLEX_MMA_ONE_PEEL
  }
  for (; k < depth; k++) {
    MICRO_COMPLEX_MMA_ONE
  }
  MICRO_COMPLEX_MMA_STORE

  row += unroll_factor * accCols;
}
#define MICRO_COMPLEX_MMA_UNROLL_ITER2(N, M) \
  gemm_complex_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, Packetc, RhsPacket, DataMapper, accRows, \
                                      accCols, M ? M : accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, \
                                      accItr>(res30, res31, res32, res33, lhs_base, rhs_base, depth, strideA, offsetA, \
                                              strideB, row, pAlphaReal, pAlphaImag, pMask); \
  if (M) return;

#define MICRO_COMPLEX_MMA_ROWS(n) \
  while (row + n * accCols <= rows) { \
    MICRO_COMPLEX_MMA_UNROLL_ITER2(n, 0); \
  }
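// Complex analogue of gemmMMA_cols: the unroll budget is 4 accumulator pairs, shared between the
// row unroll factor and the number of column panels (accItr).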
template <typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper,
          const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal,
          bool RhsIsReal, const Index accItr>
EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
                                              Index depth, Index strideA, Index offsetA, Index strideB, Index offsetB,
                                              Index col, Index rows, Index remaining_rows, const Packet& pAlphaReal,
                                              const Packet& pAlphaImag, const Packet& pMask) {
  const DataMapper res30 = res.getSubMapper(0, col);
  const DataMapper res31 = (accItr > 1) ? res30.getSubMapper(0, accRows * 1) : res30;
  const DataMapper res32 = (accItr > 2) ? res30.getSubMapper(0, accRows * 2) : res30;
  const DataMapper res33 = (accItr > 2) ? res30.getSubMapper(0, accRows * 3) : res30;

  const Scalar* rhs_base = blockB + advanceCols * col * strideB + accRows * offsetB;
  const Scalar* lhs_base = blockA + accCols * offsetA;
  Index row = 0;
#define MAX_COMPLEX_MMA_UNROLL 4

#if MAX_COMPLEX_MMA_UNROLL < 2
  if (1) {
#elif MAX_COMPLEX_MMA_UNROLL < 4
  if (accItr <= 2) {
#else
  if (accItr == 1) {
#endif
    MICRO_COMPLEX_MMA_ROWS(MAX_COMPLEX_MMA_UNROLL);
  } else if (accItr == 2) {
    MICRO_COMPLEX_MMA_ROWS(2);
  } else {
    MICRO_COMPLEX_MMA_ROWS(1);
  }
  switch ((rows - row) / accCols) {
#if MAX_COMPLEX_MMA_UNROLL > 3
    case 3:
      MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 3)
      break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 2
    case 2:
      MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 2)
      break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 1
    case 1:
      MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 1)
      break;
#endif
    default:
      break;
  }
#undef MAX_COMPLEX_MMA_UNROLL

  if (remaining_rows > 0) {
    MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_EXTRA_ROWS1, 0)
  }
}
#define MICRO_COMPLEX_MMA_COLS(n) \
  for (; col + n * accRows <= cols; col += n * accRows) { \
    gemmMMA_complex_cols<Scalar, Packet, Packetc, RhsPacket2, DataMapper, accRows, accCols, ConjugateLhs, \
                         ConjugateRhs, LhsIsReal, RhsIsReal, n>(res, blockA, blockB, depth, strideA, offsetA, strideB, \
                                                                offsetB, col, rows, remaining_rows, pAlphaReal, \
                                                                pAlphaImag, pMask); \
  }
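// Main MMA entry point for complex scalars. The packed blocks are reinterpreted as arrays of the
// underlying real scalar type; alpha is split into pAlphaReal/pAlphaImag broadcast packets.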
template <typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Packet, typename Packetc,
          typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs,
          bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth,
                     Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  const Index remaining_rows = rows % accCols;

  if (strideA == -1) strideA = depth;
  if (strideB == -1) strideB = depth;

  const Packet pAlphaReal = pset1<Packet>(alpha.real());
  const Packet pAlphaImag = pset1<Packet>(alpha.imag());
  const Packet pMask = bmask<Packet>(remaining_rows);

  const Scalar* blockA = (Scalar*)blockAc;
  const Scalar* blockB = (Scalar*)blockBc;

  typedef typename std::conditional_t<(sizeof(Scalar) == sizeof(float)), RhsPacket, __vector_pair> RhsPacket2;

  Index col = 0;
#ifdef GEMM_MULTIPLE_COLS
  MICRO_COMPLEX_MMA_COLS(4);
  MICRO_COMPLEX_MMA_COLS(2);
#endif
  MICRO_COMPLEX_MMA_COLS(1);

  gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal,
                          RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols,
                                     remaining_rows, pAlphaReal, pAlphaImag, pMask);
}
}  // end namespace internal

}  // end namespace Eigen

#if defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
#pragma GCC pop_options
#endif

#endif  // EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H