Eigen 3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)
MatrixProductMMA.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2020 Everton Constantino ([email protected])
5// Copyright (C) 2021 Chip Kerchner ([email protected])
6//
7// This Source Code Form is subject to the terms of the Mozilla
8// Public License v. 2.0. If a copy of the MPL was not distributed
9// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10
11#ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
12#define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
13
14// If using dynamic dispatch, set the CPU target.
15#if defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
16#pragma GCC push_options
17#pragma GCC target("cpu=power10,htm")
18#endif
19
20#ifdef __has_builtin
21#if !__has_builtin(__builtin_vsx_assemble_pair)
22#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
23#endif
24#if !__has_builtin(__builtin_vsx_disassemble_pair)
25#define __builtin_vsx_disassemble_pair __builtin_mma_disassemble_pair
26#endif
27#endif
28
29// IWYU pragma: private
30#include "../../InternalHeaderCheck.h"
31
32#include "MatrixProductMMAbfloat16.h"
33
34namespace Eigen {
35
36namespace internal {
37
38#define accColsC (accCols / 2)
39
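// Zero a 512-bit MMA accumulator before a new output tile is computed.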
40EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc) { __builtin_mma_xxsetaccz(acc); }
41
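// Store one accumulator tile into the result: disassemble the MMA accumulator into four
// packets, load the existing destination values, scale and accumulate (effectively
// C += alpha * A*B for this tile), then store full packets, or only `elements` lanes per
// packet when `full` is false.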
42template <typename DataMapper, typename Packet, bool full>
43EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, const Index elements,
44 __vector_quad* acc) {
45 PacketBlock<Packet, 4> result;
46 __builtin_mma_disassemble_acc(&result.packet, acc);
47
48 PacketBlock<Packet, 4> tRes;
49 if (full) {
50 EIGEN_UNUSED_VARIABLE(elements);
51 bload<DataMapper, Packet, 0, ColMajor, false, 4>(tRes, data, i, 0);
52 bscale<Packet, 4>(tRes, result, alpha);
53 bstore<DataMapper, Packet, 4>(tRes, data, i);
54 } else {
55 bload_partial<DataMapper, Packet, 0, false, 4>(tRes, data, i, elements);
56 bscale<Packet, 4>(tRes, result, alpha);
57 bstore_partial<DataMapper, Packet, 4>(tRes, data, i, elements);
58 }
59}
60
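// Complex counterpart of storeAccumulator: real and imaginary parts live in separate
// accumulators, are scaled by alphaReal/alphaImag, recombined into interleaved complex
// packets (bcouple) and stored; the second half of the tile is written only when the
// full width is present.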
61template <typename DataMapper, typename Packet, typename Packetc, const Index accCols, const Index accCols2>
62EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, const DataMapper& data, const Packet& alphaReal,
63 const Packet& alphaImag, const Packet& pMask, __vector_quad* accReal,
64 __vector_quad* accImag) {
65 constexpr bool full = (accCols2 > accColsC);
66 PacketBlock<Packet, 4> resultReal, resultImag;
67 __builtin_mma_disassemble_acc(&resultReal.packet, accReal);
68 __builtin_mma_disassemble_acc(&resultImag.packet, accImag);
69
70 PacketBlock<Packetc, 8> tRes;
71 bload<DataMapper, Packetc, accColsC, ColMajor, true, 4, full>(tRes, data, i, 0);
72
73 PacketBlock<Packet, 4> taccReal, taccImag;
74 bscalec<Packet, 4, (accCols != accCols2)>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag, pMask);
75
76 PacketBlock<Packetc, 4> acc1, acc2;
77 bcouple<Packet, Packetc, 4, full>(taccReal, taccImag, tRes, acc1, acc2);
78
79 bstore<DataMapper, Packetc, 4>(acc1, data, i);
80 if (full) {
81 bstore<DataMapper, Packetc, 4>(acc2, data, i + accColsC);
82 }
83}
84
85// Defaults to float32. Since Eigen still supports C++03, default template arguments can't be used here.
86template <typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
87EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b) {
88 if (NegativeAccumulate) {
89 __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
90 } else {
91 __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
92 }
93}
94
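// Double-precision rank-1 update: the RHS operand is a __vector_pair (two VSX vectors),
// as required by the xvf64ger* builtins.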
95template <typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
96EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, const Packet2d& b) {
97 if (NegativeAccumulate) {
98 __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b);
99 } else {
100 __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b);
101 }
102}
103
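// Complex rank-1 update: real and imaginary parts accumulate into separate accumulators,
// the sign of each partial product follows the conjugation flags, and the imaginary input
// of a purely real operand is left unused.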
104template <typename Packet, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
105EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, Packet& lhsVi,
106 const RhsPacket& rhsV, RhsPacket& rhsVi) {
107 pgerMMA<Packet, RhsPacket, false>(accReal, rhsV, lhsV);
108 if (LhsIsReal) {
109 pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
110 EIGEN_UNUSED_VARIABLE(lhsVi);
111 } else {
112 if (!RhsIsReal) {
113 pgerMMA<Packet, RhsPacket, ConjugateLhs == ConjugateRhs>(accReal, rhsVi, lhsVi);
114 pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
115 } else {
116 EIGEN_UNUSED_VARIABLE(rhsVi);
117 }
118 pgerMMA<Packet, RhsPacket, ConjugateLhs>(accImag, rhsV, lhsVi);
119 }
120}
121
122// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled.
123template <typename Packet>
124EIGEN_ALWAYS_INLINE Packet ploadRhs(const __UNPACK_TYPE__(Packet) * rhs) {
125 return ploadu<Packet>(rhs);
126}
127
128template <typename Scalar, typename Packet>
129EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV) {
130 rhsV = ploadRhs<Packet>(rhs);
131}
132
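// Load 32 bytes of RHS data as a __vector_pair. Clang assembles the pair from two
// 16-byte loads via __builtin_vsx_assemble_pair, while GCC loads the pair directly.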
133template <>
134EIGEN_ALWAYS_INLINE void ploadRhsMMA(const double* rhs, __vector_pair& rhsV) {
135#if EIGEN_COMP_LLVM
136 __builtin_vsx_assemble_pair(
137 &rhsV, reinterpret_cast<__vector unsigned char>(ploadRhs<Packet2d>(rhs + (sizeof(Packet2d) / sizeof(double)))),
138 reinterpret_cast<__vector unsigned char>(ploadRhs<Packet2d>(rhs)));
139#else
140 rhsV = *reinterpret_cast<__vector_pair*>(const_cast<double*>(rhs));
141#endif
142}
143
144EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV) { ploadRhsMMA(lhs, lhsV); }
145
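// When GEMM_MULTIPLE_COLS is defined, the drivers below walk the output in groups of
// up to four column panels per pass (accItr = 4, 2 or 1); the intent appears to be to
// keep more of the eight MMA accumulators busy per sweep.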
146#define GEMM_MULTIPLE_COLS
147
148// Disable in GCC until unnecessary register moves are fixed
149// #if (EIGEN_COMP_LLVM || (__GNUC__ >= 11))
150#if EIGEN_COMP_LLVM
151#define VECTOR_PAIR_LOADS_LHS
152#endif
153
154// PEEL_MMA loop factor.
155#ifdef GEMM_MULTIPLE_COLS
156#define PEEL_MMA 8
157#else
158// Register spilling occurs with GCC 12+, so a smaller peel factor is used.
159#if EIGEN_COMP_LLVM || (__GNUC__ < 12) || defined(VECTOR_PAIR_LOADS_LHS)
160#define PEEL_MMA 7
161#else
162#define PEEL_MMA 6
163#endif
164#endif
165
166#define MICRO_MMA_UNROLL(func) func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
167
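// Maps the eight accumulators onto (row tile, column panel) pairs: 8x1 when a single
// column panel is processed (accItr == 1), 4x2 for two panels, 2x4 for four.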
168#define MICRO_MMA_WORK(func, type, peel) \
169 if (accItr == 1) { \
170 func(0, type, peel, 0, 0) func(1, type, peel, 1, 0) func(2, type, peel, 2, 0) func(3, type, peel, 3, 0) \
171 func(4, type, peel, 4, 0) func(5, type, peel, 5, 0) func(6, type, peel, 6, 0) func(7, type, peel, 7, 0) \
172 } else if (accItr == 2) { \
173 func(0, type, peel, 0, 0) func(1, type, peel, 0, 1) func(2, type, peel, 1, 0) func(3, type, peel, 1, 1) \
174 func(4, type, peel, 2, 0) func(5, type, peel, 2, 1) func(6, type, peel, 3, 0) func(7, type, peel, 3, 1) \
175 } else { \
176 func(0, type, peel, 0, 0) func(1, type, peel, 0, 1) func(2, type, peel, 0, 2) func(3, type, peel, 0, 3) \
177 func(4, type, peel, 1, 0) func(5, type, peel, 1, 1) func(6, type, peel, 1, 2) func(7, type, peel, 1, 3) \
178 }
179
180#define MICRO_MMA_WORK_ONE(iter, type, peel, left, right) \
181 if (unroll_factor > left) { \
182 pgerMMA<Packet, type, false>(&accZero##iter, rhsV##right[peel], lhsV##left); \
183 }
184
185#ifdef VECTOR_PAIR_LOADS_LHS
186#define MICRO_MMA_WORK_TWO(iter, type, peel, left, right) \
187 if (unroll_factor > left) { \
188 pgerMMA<Packet, type, false>(&accZero##iter, rhsV##right[peel], lhsV2##left.packet[peel & 1]); \
189 }
190
191#define MICRO_MMA_LOAD1_TWO(lhs_ptr, left) \
192 if (unroll_factor > left) { \
193 if (MICRO_NORMAL(left)) { \
194 ploadLhsMMA(reinterpret_cast<const double*>(lhs_ptr##left), plhsV##left); \
195 __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&lhsV2##left.packet), &plhsV##left); \
196 lhs_ptr##left += accCols * 2; \
197 } else { \
198 lhsV2##left.packet[0] = ploadLhs<Packet>(lhs_ptr##left); \
199 lhsV2##left.packet[1] = ploadLhs<Packet>(lhs_ptr##left + accCols2); \
200 lhs_ptr##left += accCols2 * 2; \
201 EIGEN_UNUSED_VARIABLE(plhsV##left); \
202 } \
203 } else { \
204 EIGEN_UNUSED_VARIABLE(lhsV2##left); \
205 EIGEN_UNUSED_VARIABLE(plhsV##left); \
206 }
207
208#define MICRO_MMA_LOAD_TWO(left) MICRO_MMA_LOAD1_TWO(lhs_ptr, left)
209#endif
210
211#define MICRO_MMA_UNROLL_ITER(func, val) \
212 func(val, 0) if (accItr > 1) { \
213 func(val, 1) if (accItr > 2) { func(val, 2) func(val, 3) } \
214 }
215
216#define MICRO_MMA_LOAD_ONE_RHS1(peel, right) ploadRhsMMA(rhs_ptr##right + (accRows * peel), rhsV##right[peel]);
217
218#define MICRO_MMA_LOAD_ONE_RHS(peel) MICRO_MMA_UNROLL_ITER(MICRO_MMA_LOAD_ONE_RHS1, peel)
219
220#define MICRO_MMA_TYPE_PEEL(funcw, funcl, type, peel) \
221 if (PEEL_MMA > peel) { \
222 Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
223 MICRO_MMA_LOAD_ONE_RHS(peel) \
224 MICRO_MMA_UNROLL(funcl) \
225 MICRO_MMA_WORK(funcw, type, peel) \
226 }
227
228#ifndef VECTOR_PAIR_LOADS_LHS
229#define MICRO_MMA_UNROLL_TYPE_PEEL(funcw, funcl, type) \
230 type rhsV0[8], rhsV1[(accItr > 1) ? 8 : 1], rhsV2[(accItr > 2) ? 8 : 1], rhsV3[(accItr > 2) ? 8 : 1]; \
231 MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 0) \
232 MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 1) \
233 MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 2) \
234 MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 3) \
235 MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 4) \
236 MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 5) \
237 MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 6) MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 7)
238#else
239#define MICRO_MMA_LOAD_TWO_RHS(peel1, right) \
240 ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr##right + (accRows * peel1)), prhsV##peel1); \
241 __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsV##right[peel1]), &prhsV##peel1);
242
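// Processes two peel steps at once. For 16-byte types the two RHS vectors are fetched
// with a single paired load and split via __builtin_vsx_disassemble_pair; otherwise it
// falls back to two ordinary single-vector loads.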
243#define MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, peel1, peel2) \
244 if (PEEL_MMA > peel2) { \
245 PacketBlock<Packet, 2> lhsV20, lhsV21, lhsV22, lhsV23, lhsV24, lhsV25, lhsV26, lhsV27; \
246 __vector_pair plhsV0, plhsV1, plhsV2, plhsV3, plhsV4, plhsV5, plhsV6, plhsV7; \
247 if (sizeof(type) == 16) { \
248 MICRO_MMA_UNROLL_ITER(MICRO_MMA_LOAD_TWO_RHS, peel1) \
249 } else { \
250 EIGEN_UNUSED_VARIABLE(prhsV##peel1); \
251 MICRO_MMA_LOAD_ONE_RHS(peel1) \
252 MICRO_MMA_LOAD_ONE_RHS(peel2) \
253 } \
254 MICRO_MMA_UNROLL(funcl2) \
255 MICRO_MMA_WORK(funcw2, type, peel1) \
256 MICRO_MMA_WORK(funcw2, type, peel2) \
257 } else { \
258 EIGEN_UNUSED_VARIABLE(prhsV##peel1); \
259 MICRO_MMA_TYPE_PEEL(funcw1, funcl1, type, peel1) \
260 }
261
262#define MICRO_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type) \
263 type rhsV0[8], rhsV1[(accItr > 1) ? 8 : 1], rhsV2[(accItr > 2) ? 8 : 1], rhsV3[(accItr > 2) ? 8 : 1]; \
264 __vector_pair prhsV0, prhsV2, prhsV4, prhsV6; \
265 MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 0, 1) \
266 MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 2, 3) \
267 MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 4, 5) \
268 MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 6, 7)
269#endif
270
271#define MICRO_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \
272 type rhsV0[1], rhsV1[1], rhsV2[1], rhsV3[1]; \
273 MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 0)
274
275#define MICRO_MMA_UPDATE_RHS1(size, right) rhs_ptr##right += (accRows * size);
276
277#define MICRO_MMA_UPDATE_RHS(size) MICRO_MMA_UNROLL_ITER(MICRO_MMA_UPDATE_RHS1, size)
278
279#define MICRO_MMA_UNROLL_TYPE(MICRO_MMA_TYPE, size) \
280 MICRO_MMA_TYPE(MICRO_MMA_WORK_ONE, MICRO_LOAD_ONE, RhsPacket) \
281 MICRO_MMA_UPDATE_RHS(size)
282
283#ifndef VECTOR_PAIR_LOADS_LHS
284#define MICRO_MMA_ONE_PEEL MICRO_MMA_UNROLL_TYPE(MICRO_MMA_UNROLL_TYPE_PEEL, PEEL_MMA)
285#else
286#define MICRO_MMA_UNROLL_TYPE2(MICRO_MMA_TYPE, size) \
287 MICRO_MMA_TYPE(MICRO_MMA_WORK_ONE, MICRO_LOAD_ONE, MICRO_MMA_WORK_TWO, MICRO_MMA_LOAD_TWO, RhsPacket) \
288 MICRO_MMA_UPDATE_RHS(size)
289
290#define MICRO_MMA_ONE_PEEL MICRO_MMA_UNROLL_TYPE2(MICRO_MMA_UNROLL_TYPE_PEEL2, PEEL_MMA)
291#endif
292
293#define MICRO_MMA_ONE MICRO_MMA_UNROLL_TYPE(MICRO_MMA_UNROLL_TYPE_ONE, 1)
294
295#define MICRO_MMA_DST_PTR_ONE(iter) \
296 if (unroll_factor * accItr > iter) { \
297 bsetzeroMMA(&accZero##iter); \
298 } else { \
299 EIGEN_UNUSED_VARIABLE(accZero##iter); \
300 }
301
302#define MICRO_MMA_DST_PTR MICRO_MMA_UNROLL(MICRO_MMA_DST_PTR_ONE)
303
304#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_SRC_PTR_ONE)
305
306#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_PREFETCH_ONE)
307
308#define MICRO_MMA_STORE_ONE(iter, left, right) \
309 if (unroll_factor > left) { \
310 storeAccumulator<DataMapper, Packet, MICRO_NORMAL_PARTIAL(left)>(row + left * accCols, res##right, pAlpha, \
311 accCols2, &accZero##iter); \
312 }
313
314#define MICRO_MMA_ITER_UNROLL(func) \
315 if (accItr == 1) { \
316 func(0, 0, 0) func(1, 1, 0) func(2, 2, 0) func(3, 3, 0) func(4, 4, 0) func(5, 5, 0) func(6, 6, 0) func(7, 7, 0) \
317 } else if (accItr == 2) { \
318 func(0, 0, 0) func(1, 0, 1) func(2, 1, 0) func(3, 1, 1) func(4, 2, 0) func(5, 2, 1) func(6, 3, 0) func(7, 3, 1) \
319 } else { \
320 func(0, 0, 0) func(1, 0, 1) func(2, 0, 2) func(3, 0, 3) func(4, 1, 0) func(5, 1, 1) func(6, 1, 2) func(7, 1, 3) \
321 }
322
323#define MICRO_MMA_STORE MICRO_MMA_ITER_UNROLL(MICRO_MMA_STORE_ONE)
324
325#define MICRO_MMA_EXTRA_ROWS(right) \
326 gemm_extra_row<Scalar, Packet, DataMapper, accRows, accCols>( \
327 res3##right, blockA, rhs_base + right * accRows * strideB, depth, strideA, offsetA, strideB, row, rows, \
328 remaining_rows, pAlpha, pMask);
329
330#define MICRO_MMA_EXTRA_ROWS1(val, right) MICRO_MMA_EXTRA_ROWS(right);
331
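// Micro-kernel: accumulates an (unroll_factor * accCols) x (accItr * accRows) block of
// the result over the full depth, with the k loop peeled by PEEL_MMA, then scales by
// alpha and stores the accumulators.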
332template <int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper,
333 const Index accRows, const Index accCols, bool full, const Index accItr>
334EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration(const DataMapper& res0, const DataMapper& res1,
335 const DataMapper& res2, const DataMapper& res3,
336 const Scalar* lhs_base, const Scalar* rhs_base, Index depth,
337 Index strideA, Index strideB, Index offsetA, Index& row,
338 const Packet& pAlpha, Index accCols2) {
339 const Scalar *rhs_ptr0 = rhs_base, *rhs_ptr1 = NULL, *rhs_ptr2 = NULL, *rhs_ptr3 = NULL;
340 const Scalar *lhs_ptr0 = NULL, *lhs_ptr1 = NULL, *lhs_ptr2 = NULL, *lhs_ptr3 = NULL, *lhs_ptr4 = NULL,
341 *lhs_ptr5 = NULL, *lhs_ptr6 = NULL, *lhs_ptr7 = NULL;
342 __vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
343
344 if (accItr > 1) {
345 rhs_ptr1 = rhs_base + (accRows * strideB);
346 } else {
347 EIGEN_UNUSED_VARIABLE(strideB);
348 EIGEN_UNUSED_VARIABLE(rhs_ptr1);
349 EIGEN_UNUSED_VARIABLE(res1);
350 }
351 if (accItr > 2) {
352 rhs_ptr2 = rhs_base + (2 * accRows * strideB);
353 rhs_ptr3 = rhs_base + (3 * accRows * strideB);
354 } else {
355 EIGEN_UNUSED_VARIABLE(rhs_ptr2);
356 EIGEN_UNUSED_VARIABLE(rhs_ptr3);
357 EIGEN_UNUSED_VARIABLE(res2);
358 EIGEN_UNUSED_VARIABLE(res3);
359 }
360
361 MICRO_MMA_SRC_PTR
362 MICRO_MMA_DST_PTR
363
364 Index k = 0, depth2 = depth - PEEL_MMA;
365 for (; k <= depth2; k += PEEL_MMA) {
366 EIGEN_POWER_PREFETCH(rhs_ptr);
367 MICRO_MMA_PREFETCH
368 MICRO_MMA_ONE_PEEL
369 }
370 for (; k < depth; k++) {
371 MICRO_MMA_ONE
372 }
373 MICRO_MMA_STORE
374
375 MICRO_UPDATE
376}
377
378#define MICRO_MMA_UNROLL_ITER2(N, M) \
379 gemm_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, RhsPacket, DataMapper, accRows, accCols, !M, accItr>( \
380 res30, res31, res32, res33, lhs_base, rhs_base, depth, strideA, strideB, offsetA, row, pAlpha, \
381 M ? remaining_rows : accCols); \
382 if (M) return;
383
384#define MICRO_MMA_ROWS(n) \
385 while (row + n * accCols <= rows) { \
386 MICRO_MMA_UNROLL_ITER2(n, 0); \
387 }
388
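// Handles one group of accItr column panels: repeatedly runs the largest row unroll
// allowed for this accItr, mops up the remaining full row tiles via the switch below,
// and finally passes any rows left over (fewer than accCols) to gemm_extra_row.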
389template <typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows,
390 const Index accCols, const Index accItr>
391EIGEN_ALWAYS_INLINE void gemmMMA_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index depth,
392 Index strideA, Index offsetA, Index strideB, Index offsetB, Index col, Index rows,
393 Index remaining_rows, const Packet& pAlpha, const Packet& pMask) {
394 const DataMapper res30 = res.getSubMapper(0, col);
395 const DataMapper res31 = (accItr > 1) ? res30.getSubMapper(0, accRows * 1) : res30;
396 const DataMapper res32 = (accItr > 2) ? res30.getSubMapper(0, accRows * 2) : res30;
397 const DataMapper res33 = (accItr > 2) ? res30.getSubMapper(0, accRows * 3) : res30;
398
399 const Scalar* rhs_base = blockB + col * strideB + accRows * offsetB;
400 const Scalar* lhs_base = blockA + accCols * offsetA;
401 Index row = 0;
402
403#define MAX_MMA_UNROLL 7
404
405#if MAX_MMA_UNROLL < 2
406 if (1) {
407#elif MAX_MMA_UNROLL < 4
408 if (accItr <= 2) {
409#else
410 if (accItr == 1) {
411#endif
412 MICRO_MMA_ROWS(MAX_MMA_UNROLL);
413 } else if (accItr == 2) {
414 MICRO_MMA_ROWS(4);
415 } else {
416 MICRO_MMA_ROWS(2);
417 }
418 switch ((rows - row) / accCols) {
419#if MAX_MMA_UNROLL > 7
420 case 7:
421 if (accItr == 1) {
422 MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 7)
423 }
424 break;
425#endif
426#if MAX_MMA_UNROLL > 6
427 case 6:
428 if (accItr == 1) {
429 MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 6)
430 }
431 break;
432#endif
433#if MAX_MMA_UNROLL > 5
434 case 5:
435 if (accItr == 1) {
436 MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 5)
437 }
438 break;
439#endif
440#if MAX_MMA_UNROLL > 4
441 case 4:
442 if (accItr == 1) {
443 MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 4)
444 }
445 break;
446#endif
447#if MAX_MMA_UNROLL > 3
448 case 3:
449 if (accItr <= 2) {
450 MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 3)
451 }
452 break;
453#endif
454#if MAX_MMA_UNROLL > 2
455 case 2:
456 if (accItr <= 2) {
457 MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 2)
458 }
459 break;
460#endif
461#if MAX_MMA_UNROLL > 1
462 case 1:
463 MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 1)
464 break;
465#endif
466 default:
467 break;
468 }
469#undef MAX_MMA_UNROLL
470
471 if (remaining_rows > 0) {
472 MICRO_MMA_UNROLL_ITER(MICRO_MMA_EXTRA_ROWS1, 0)
473 }
474}
475
476#define MICRO_MMA_COLS(n) \
477 for (; col + n * accRows <= cols; col += n * accRows) { \
478 gemmMMA_cols<Scalar, Packet, RhsPacket2, DataMapper, accRows, accCols, n>( \
479 res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask); \
480 }
481
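// Top-level real-valued GEMM driver for MMA: columns are consumed in groups of 4, 2 and
// finally 1 panel(s); any leftover columns are handled by gemm_extra_cols. For double,
// the RHS packet type becomes a __vector_pair (RhsPacket2 below). blockA/blockB are the
// packed operand panels produced by Eigen's packing routines. A minimal, purely
// illustrative invocation (hypothetical tile sizes, not taken from this file):
//   gemmMMA<float, Packet4f, Packet4f, DataMapper, 4, 4>(res, blockA, blockB, rows,
//                                                        depth, cols, alpha, strideA,
//                                                        strideB, offsetA, offsetB);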
482template <typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows,
483 const Index accCols>
484void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols,
485 Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
486 const Index remaining_rows = rows % accCols;
487
488 if (strideA == -1) strideA = depth;
489 if (strideB == -1) strideB = depth;
490
491 const Packet pAlpha = pset1<Packet>(alpha);
492 const Packet pMask = bmask<Packet>(remaining_rows);
493
494 typedef typename std::conditional_t<(sizeof(Scalar) == sizeof(float)), RhsPacket, __vector_pair> RhsPacket2;
495
496 Index col = 0;
497#ifdef GEMM_MULTIPLE_COLS
498 MICRO_MMA_COLS(4);
499 MICRO_MMA_COLS(2);
500#endif
501 MICRO_MMA_COLS(1);
502
503 if (col != cols) {
504 gemm_extra_cols<Scalar, Packet, DataMapper, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB,
505 col, rows, cols, remaining_rows, pAlpha, pMask);
506 }
507}
508
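// Packed complex operands keep real and imaginary parts in separate planes, so offsets
// into blockA/blockB advance twice as fast unless the corresponding side is purely real.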
509#define advanceRows ((LhsIsReal) ? 1 : 2)
510#define advanceCols ((RhsIsReal) ? 1 : 2)
511
512// PEEL_COMPLEX_MMA loop factor.
513#ifdef GEMM_MULTIPLE_COLS
514#define PEEL_COMPLEX_MMA 4
515#else
516#define PEEL_COMPLEX_MMA 3
517#endif
518
519#define MICRO_COMPLEX_MMA_UNROLL(func) func(0) func(1) func(2) func(3)
520
521#define MICRO_COMPLEX_MMA_WORK(func, type, peel) \
522 if (accItr == 1) { \
523 func(0, type, peel, 0, 0) func(1, type, peel, 1, 0) func(2, type, peel, 2, 0) func(3, type, peel, 3, 0) \
524 } else if (accItr == 2) { \
525 func(0, type, peel, 0, 0) func(1, type, peel, 0, 1) func(2, type, peel, 1, 0) func(3, type, peel, 1, 1) \
526 } else { \
527 func(0, type, peel, 0, 0) func(1, type, peel, 0, 1) func(2, type, peel, 0, 2) func(3, type, peel, 0, 3) \
528 }
529
530#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel, left, right) \
531 if (unroll_factor > left) { \
532 pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>( \
533 &accReal##iter, &accImag##iter, lhsV##left, lhsVi##left, rhsV##right[peel], rhsVi##right[peel]); \
534 }
535
536#ifdef VECTOR_PAIR_LOADS_LHS
537#define MICRO_COMPLEX_MMA_WORK_TWO(iter, type, peel, left, right) \
538 if (unroll_factor > left) { \
539 pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>( \
540 &accReal##iter, &accImag##iter, lhsV2##left.packet[peel & 1], lhsVi2##left.packet[peel & 1], \
541 rhsV##right[peel], rhsVi##right[peel]); \
542 }
543
544#define MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, left) \
545 if (!LhsIsReal && (unroll_factor > left)) { \
546 if (MICRO_NORMAL(left)) { \
547 ploadLhsMMA(reinterpret_cast<const double*>(lhs_ptr_real##left + imag_delta), plhsVi##left); \
548 __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&lhsVi2##left.packet), &plhsVi##left); \
549 } else { \
550 lhsVi2##left.packet[0] = ploadLhs<Packet>(lhs_ptr_real##left + imag_delta2); \
551 lhsVi2##left.packet[1] = ploadLhs<Packet>(lhs_ptr_real##left + imag_delta2 + accCols2); \
552 EIGEN_UNUSED_VARIABLE(plhsVi##left); \
553 } \
554 } else { \
555 EIGEN_UNUSED_VARIABLE(lhsVi2##left); \
556 EIGEN_UNUSED_VARIABLE(plhsVi##left); \
557 } \
558 MICRO_MMA_LOAD1_TWO(lhs_ptr_real, left)
559
560#define MICRO_COMPLEX_MMA_LOAD_TWO(left) MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, left)
561#endif
562
563#define MICRO_COMPLEX_MMA_LOAD_RHS1(peel, right) \
564 ploadRhsMMA(rhs_ptr_real##right + (accRows * peel), rhsV##right[peel]); \
565 if (!RhsIsReal) { \
566 ploadRhsMMA(rhs_ptr_imag##right + (accRows * peel), rhsVi##right[peel]); \
567 }
568
569#define MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel) MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_LOAD_RHS1, peel)
570
571#define MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, peel) \
572 if (PEEL_COMPLEX_MMA > peel) { \
573 Packet lhsV0, lhsV1, lhsV2, lhsV3; \
574 Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
575 MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel) \
576 MICRO_COMPLEX_MMA_UNROLL(funcl) \
577 MICRO_COMPLEX_MMA_WORK(funcw, type, peel) \
578 }
579
580#ifndef VECTOR_PAIR_LOADS_LHS
581#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(funcw, funcl, type) \
582 type rhsV0[4], rhsVi0[4], rhsV1[(accItr > 1) ? 4 : 1], rhsVi1[(accItr > 1) ? 4 : 1], rhsV2[(accItr > 2) ? 4 : 1], \
583 rhsVi2[(accItr > 2) ? 4 : 1], rhsV3[(accItr > 2) ? 4 : 1], rhsVi3[(accItr > 2) ? 4 : 1]; \
584 MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, 0) \
585 MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, 1) \
586 MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, 2) MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, 3)
587#else
588#define MICRO_COMPLEX_MMA_LOAD_TWO_RHS(peel1, right) \
589 ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr_real##right + (accRows * peel1)), prhsV##peel1); \
590 __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsV##right[peel1]), &prhsV##peel1); \
591 if (!RhsIsReal) { \
592 ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr_imag##right + (accRows * peel1)), prhsVi##peel1); \
593 __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsVi##right[peel1]), &prhsVi##peel1); \
594 } else { \
595 EIGEN_UNUSED_VARIABLE(prhsVi##peel1); \
596 }
597
598#define MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, peel1, peel2) \
599 if (PEEL_COMPLEX_MMA > peel2) { \
600 PacketBlock<Packet, 2> lhsV20, lhsV21, lhsV22, lhsV23; \
601 PacketBlock<Packet, 2> lhsVi20, lhsVi21, lhsVi22, lhsVi23; \
602 __vector_pair plhsV0, plhsV1, plhsV2, plhsV3; \
603 __vector_pair plhsVi0, plhsVi1, plhsVi2, plhsVi3; \
604 if (sizeof(type) == 16) { \
605 MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_LOAD_TWO_RHS, peel1) \
606 } else { \
607 EIGEN_UNUSED_VARIABLE(prhsV##peel1); \
608 EIGEN_UNUSED_VARIABLE(prhsVi##peel1); \
609 MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel1); \
610 MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel2); \
611 } \
612 MICRO_COMPLEX_MMA_UNROLL(funcl2) \
613 MICRO_COMPLEX_MMA_WORK(funcw2, type, peel1) \
614 MICRO_COMPLEX_MMA_WORK(funcw2, type, peel2) \
615 } else { \
616 EIGEN_UNUSED_VARIABLE(prhsV##peel1); \
617 EIGEN_UNUSED_VARIABLE(prhsVi##peel1); \
618 MICRO_COMPLEX_MMA_TYPE_PEEL(funcw1, funcl1, type, peel1) \
619 }
620
621#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type) \
622 type rhsV0[4], rhsVi0[4], rhsV1[(accItr > 1) ? 4 : 1], rhsVi1[(accItr > 1) ? 4 : 1], rhsV2[(accItr > 2) ? 4 : 1], \
623 rhsVi2[(accItr > 2) ? 4 : 1], rhsV3[(accItr > 2) ? 4 : 1], rhsVi3[(accItr > 2) ? 4 : 1]; \
624 __vector_pair prhsV0, prhsV2; \
625 __vector_pair prhsVi0, prhsVi2; \
626 MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 0, 1) \
627 MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 2, 3)
628#endif
629
630#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \
631 type rhsV0[1], rhsVi0[1], rhsV1[1], rhsVi1[1], rhsV2[1], rhsVi2[1], rhsV3[1], rhsVi3[1]; \
632 MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, 0)
633
634#define MICRO_COMPLEX_MMA_UPDATE_RHS1(size, right) \
635 rhs_ptr_real##right += (accRows * size); \
636 if (!RhsIsReal) rhs_ptr_imag##right += (accRows * size);
637
638#define MICRO_COMPLEX_MMA_UPDATE_RHS(size) MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_UPDATE_RHS1, size)
639
640#define MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_TYPE, size) \
641 MICRO_COMPLEX_MMA_TYPE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_LOAD_ONE, RhsPacket) \
642 MICRO_COMPLEX_MMA_UPDATE_RHS(size);
643
644#ifndef VECTOR_PAIR_LOADS_LHS
645#define MICRO_COMPLEX_MMA_ONE_PEEL MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL, PEEL_COMPLEX_MMA)
646#else
647#define MICRO_COMPLEX_MMA_UNROLL_TYPE2(MICRO_COMPLEX_MMA_TYPE, size) \
648 MICRO_COMPLEX_MMA_TYPE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_LOAD_ONE, MICRO_COMPLEX_MMA_WORK_TWO, \
649 MICRO_COMPLEX_MMA_LOAD_TWO, RhsPacket) \
650 MICRO_COMPLEX_MMA_UPDATE_RHS(size);
651
652#define MICRO_COMPLEX_MMA_ONE_PEEL MICRO_COMPLEX_MMA_UNROLL_TYPE2(MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2, PEEL_COMPLEX_MMA)
653#endif
654
655#define MICRO_COMPLEX_MMA_ONE MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE, 1)
656
657#define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \
658 if (unroll_factor * accItr > iter) { \
659 bsetzeroMMA(&accReal##iter); \
660 bsetzeroMMA(&accImag##iter); \
661 } else { \
662 EIGEN_UNUSED_VARIABLE(accReal##iter); \
663 EIGEN_UNUSED_VARIABLE(accImag##iter); \
664 }
665
666#define MICRO_COMPLEX_MMA_DST_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_DST_PTR_ONE)
667
668#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE)
669
670#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)
671
672#define MICRO_COMPLEX_MMA_STORE_ONE(iter, left, right) \
673 if (unroll_factor > left) { \
674 storeComplexAccumulator<DataMapper, Packet, Packetc, accCols, (unroll_factor != (left + 1)) ? accCols : accCols2>( \
675 row + left * accCols, res##right, pAlphaReal, pAlphaImag, pMask, &accReal##iter, &accImag##iter); \
676 }
677
678#define MICRO_COMPLEX_MMA_ITER_UNROLL(func) \
679 if (accItr == 1) { \
680 func(0, 0, 0) func(1, 1, 0) func(2, 2, 0) func(3, 3, 0) \
681 } else if (accItr == 2) { \
682 func(0, 0, 0) func(1, 0, 1) func(2, 1, 0) func(3, 1, 1) \
683 } else { \
684 func(0, 0, 0) func(1, 0, 1) func(2, 0, 2) func(3, 0, 3) \
685 }
686
687#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_ITER_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)
688
689#define MICRO_COMPLEX_MMA_EXTRA_ROWS(right) \
690 gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, \
691 RhsIsReal>(res3##right, blockA, rhs_base + right * accRows * (RhsIsReal ? 1 : 2) * strideB, \
692 depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlphaReal, \
693 pAlphaImag, pMask);
694
695#define MICRO_COMPLEX_MMA_EXTRA_ROWS1(val, right) MICRO_COMPLEX_MMA_EXTRA_ROWS(right);
696
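// Complex micro-kernel: one real and (unless RhsIsReal) one imaginary RHS pointer per
// column panel, with imag_delta/imag_delta2 locating the imaginary plane of the packed
// LHS; otherwise the structure mirrors gemm_unrolled_MMA_iteration.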
697template <int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket,
698 typename DataMapper, const Index accRows, const Index accCols, const Index accCols2, bool ConjugateLhs,
699 bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index accItr>
700EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(const DataMapper& res0, const DataMapper& res1,
701 const DataMapper& res2, const DataMapper& res3,
702 const Scalar* lhs_base, const Scalar* rhs_base,
703 Index depth, Index strideA, Index offsetA, Index strideB,
704 Index& row, const Packet& pAlphaReal,
705 const Packet& pAlphaImag, const Packet& pMask) {
706 const Scalar *rhs_ptr_real0 = rhs_base, *rhs_ptr_real1 = NULL, *rhs_ptr_real2 = NULL, *rhs_ptr_real3 = NULL;
707 const Scalar *rhs_ptr_imag0 = NULL, *rhs_ptr_imag1 = NULL, *rhs_ptr_imag2 = NULL, *rhs_ptr_imag3 = NULL;
708 const Index imag_delta = accCols * strideA;
709 const Index imag_delta2 = accCols2 * strideA;
710
711 if (!RhsIsReal) {
712 rhs_ptr_imag0 = rhs_base + accRows * strideB;
713 } else {
714 EIGEN_UNUSED_VARIABLE(rhs_ptr_imag0);
715 }
716 if (accItr > 1) {
717 if (!RhsIsReal) {
718 rhs_ptr_real1 = rhs_base + (2 * accRows * strideB);
719 rhs_ptr_imag1 = rhs_base + (3 * accRows * strideB);
720 } else {
721 rhs_ptr_real1 = rhs_base + accRows * strideB;
722 EIGEN_UNUSED_VARIABLE(rhs_ptr_imag1);
723 }
724 } else {
725 EIGEN_UNUSED_VARIABLE(rhs_ptr_real1);
726 EIGEN_UNUSED_VARIABLE(rhs_ptr_imag1);
727 EIGEN_UNUSED_VARIABLE(res1);
728 }
729 if (accItr > 2) {
730 if (!RhsIsReal) {
731 rhs_ptr_real2 = rhs_base + (4 * accRows * strideB);
732 rhs_ptr_imag2 = rhs_base + (5 * accRows * strideB);
733 rhs_ptr_real3 = rhs_base + (6 * accRows * strideB);
734 rhs_ptr_imag3 = rhs_base + (7 * accRows * strideB);
735 } else {
736 rhs_ptr_real2 = rhs_base + (2 * accRows * strideB);
737 rhs_ptr_real3 = rhs_base + (3 * accRows * strideB);
738 EIGEN_UNUSED_VARIABLE(rhs_ptr_imag2);
739 EIGEN_UNUSED_VARIABLE(rhs_ptr_imag3);
740 }
741 } else {
742 EIGEN_UNUSED_VARIABLE(rhs_ptr_real2);
743 EIGEN_UNUSED_VARIABLE(rhs_ptr_real3);
744 EIGEN_UNUSED_VARIABLE(rhs_ptr_imag2);
745 EIGEN_UNUSED_VARIABLE(rhs_ptr_imag3);
746 EIGEN_UNUSED_VARIABLE(res2);
747 EIGEN_UNUSED_VARIABLE(res3);
748 }
749 const Scalar *lhs_ptr_real0 = NULL, *lhs_ptr_real1 = NULL;
750 const Scalar *lhs_ptr_real2 = NULL, *lhs_ptr_real3 = NULL;
751 __vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3;
752
753 MICRO_COMPLEX_MMA_SRC_PTR
754 MICRO_COMPLEX_MMA_DST_PTR
755
756 Index k = 0, depth2 = depth - PEEL_COMPLEX_MMA;
757 for (; k <= depth2; k += PEEL_COMPLEX_MMA) {
758 EIGEN_POWER_PREFETCH(rhs_ptr_real);
759 if (!RhsIsReal) {
760 EIGEN_POWER_PREFETCH(rhs_ptr_imag);
761 }
762 MICRO_COMPLEX_MMA_PREFETCH
763 MICRO_COMPLEX_MMA_ONE_PEEL
764 }
765 for (; k < depth; k++) {
766 MICRO_COMPLEX_MMA_ONE
767 }
768 MICRO_COMPLEX_MMA_STORE
769
770 MICRO_COMPLEX_UPDATE
771}
772
773#define MICRO_COMPLEX_MMA_UNROLL_ITER2(N, M) \
774 gemm_complex_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, Packetc, RhsPacket, DataMapper, accRows, \
775 accCols, M ? M : accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, \
776 accItr>(res30, res31, res32, res33, lhs_base, rhs_base, depth, strideA, offsetA, \
777 strideB, row, pAlphaReal, pAlphaImag, pMask); \
778 if (M) return;
779
780#define MICRO_COMPLEX_MMA_ROWS(n) \
781 while (row + n * accCols <= rows) { \
782 MICRO_COMPLEX_MMA_UNROLL_ITER2(n, 0); \
783 }
784
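// Complex analogue of gemmMMA_cols, with a smaller maximum row unroll (4) since each
// row tile needs both a real and an imaginary accumulator.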
785template <typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper,
786 const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal,
787 bool RhsIsReal, const Index accItr>
788EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
789 Index depth, Index strideA, Index offsetA, Index strideB, Index offsetB,
790 Index col, Index rows, Index remaining_rows, const Packet& pAlphaReal,
791 const Packet& pAlphaImag, const Packet& pMask) {
792 const DataMapper res30 = res.getSubMapper(0, col);
793 const DataMapper res31 = (accItr > 1) ? res30.getSubMapper(0, accRows * 1) : res30;
794 const DataMapper res32 = (accItr > 2) ? res30.getSubMapper(0, accRows * 2) : res30;
795 const DataMapper res33 = (accItr > 2) ? res30.getSubMapper(0, accRows * 3) : res30;
796
797 const Scalar* rhs_base = blockB + advanceCols * col * strideB + accRows * offsetB;
798 const Scalar* lhs_base = blockA + accCols * offsetA;
799 Index row = 0;
800
801#define MAX_COMPLEX_MMA_UNROLL 4
802
803#if MAX_COMPLEX_MMA_UNROLL < 2
804 if (1) {
805#elif MAX_COMPLEX_MMA_UNROLL < 4
806 if (accItr <= 2) {
807#else
808 if (accItr == 1) {
809#endif
810 MICRO_COMPLEX_MMA_ROWS(MAX_COMPLEX_MMA_UNROLL);
811 } else if (accItr == 2) {
812 MICRO_COMPLEX_MMA_ROWS(2);
813 } else {
814 MICRO_COMPLEX_MMA_ROWS(1);
815 }
816 switch ((rows - row) / accCols) {
817#if MAX_COMPLEX_MMA_UNROLL > 3
818 case 3:
819 if (accItr == 1) {
820 MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 3)
821 }
822 break;
823#endif
824#if MAX_COMPLEX_MMA_UNROLL > 2
825 case 2:
826 if (accItr == 1) {
827 MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 2)
828 }
829 break;
830#endif
831#if MAX_COMPLEX_MMA_UNROLL > 1
832 case 1:
833 if (accItr <= 2) {
834 MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 1)
835 }
836 break;
837#endif
838 default:
839 break;
840 }
841#undef MAX_COMPLEX_MMA_UNROLL
842
843 if (remaining_rows > 0) {
844 MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_EXTRA_ROWS1, 0)
845 }
846}
847
848#define MICRO_COMPLEX_MMA_COLS(n) \
849 for (; col + n * accRows <= cols; col += n * accRows) { \
850 gemmMMA_complex_cols<Scalar, Packet, Packetc, RhsPacket2, DataMapper, accRows, accCols, ConjugateLhs, \
851 ConjugateRhs, LhsIsReal, RhsIsReal, n>(res, blockA, blockB, depth, strideA, offsetA, strideB, \
852 offsetB, col, rows, remaining_rows, pAlphaReal, \
853 pAlphaImag, pMask); \
854 }
855
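// Top-level complex GEMM driver for MMA: the complex operand blocks are reinterpreted
// as packed real scalars, alpha is split into broadcast real/imaginary parts, and the
// column loop mirrors gemmMMA, with gemm_complex_extra_cols handling leftover columns.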
856template <typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Packet, typename Packetc,
857 typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs,
858 bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
859void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth,
860 Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
861 const Index remaining_rows = rows % accCols;
862
863 if (strideA == -1) strideA = depth;
864 if (strideB == -1) strideB = depth;
865
866 const Packet pAlphaReal = pset1<Packet>(alpha.real());
867 const Packet pAlphaImag = pset1<Packet>(alpha.imag());
868 const Packet pMask = bmask<Packet>(remaining_rows);
869
870 const Scalar* blockA = (Scalar*)blockAc;
871 const Scalar* blockB = (Scalar*)blockBc;
872
873 typedef typename std::conditional_t<(sizeof(Scalar) == sizeof(float)), RhsPacket, __vector_pair> RhsPacket2;
874
875 Index col = 0;
876#ifdef GEMM_MULTIPLE_COLS
877 MICRO_COMPLEX_MMA_COLS(4);
878 MICRO_COMPLEX_MMA_COLS(2);
879#endif
880 MICRO_COMPLEX_MMA_COLS(1);
881
882 if (col != cols) {
883 gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal,
884 RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols,
885 remaining_rows, pAlphaReal, pAlphaImag, pMask);
886 }
887}
888
889#undef accColsC
890#undef advanceRows
891#undef advanceCols
892
893} // end namespace internal
894
895} // end namespace Eigen
896
897#if defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
898#pragma GCC pop_options
899#endif
900
901#endif // EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H