#ifndef EIGEN_MATRIX_VECTOR_PRODUCT_ALTIVEC_H
#define EIGEN_MATRIX_VECTOR_PRODUCT_ALTIVEC_H

#include "../../InternalHeaderCheck.h"

#if defined(__MMA__) && !EIGEN_ALTIVEC_DISABLE_MMA
#if EIGEN_COMP_LLVM || (__GNUC__ > 10 || __GNUC_MINOR__ >= 3)

#if !EIGEN_COMP_LLVM && (__GNUC__ < 11)
#define GCC_ONE_VECTORPAIR_BUG

#ifdef EIGEN_POWER_USE_GEMV_PREFETCH
#define EIGEN_POWER_GEMV_PREFETCH(p) prefetch(p)
#define EIGEN_POWER_GEMV_PREFETCH(p)

#if !__has_builtin(__builtin_vsx_assemble_pair)
#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
#if !__has_builtin(__builtin_vsx_disassemble_pair)
#define __builtin_vsx_disassemble_pair __builtin_mma_disassemble_pair

#define GEMV_BUILDPAIR_MMA(dst, src1, src2) \
  __builtin_vsx_assemble_pair(&dst, (__vector unsigned char)src2, (__vector unsigned char)src1)

#if (__GNUC_MINOR__ > 3)
#define GEMV_BUILDPAIR_MMA(dst, src1, src2) \
  __builtin_vsx_assemble_pair(&dst, (__vector unsigned char)src2, (__vector unsigned char)src1)
#define GEMV_BUILDPAIR_MMA(dst, src1, src2) \
  __builtin_vsx_assemble_pair(&dst, (__vector unsigned char)src1, (__vector unsigned char)src2)
#define GEMV_BUILDPAIR_MMA(dst, src1, src2) \
  __builtin_vsx_build_pair(&dst, (__vector unsigned char)src1, (__vector unsigned char)src2)

#define GEMV_IS_COMPLEX_COMPLEX ((sizeof(LhsPacket) == 16) && (sizeof(RhsPacket) == 16))
#define GEMV_IS_FLOAT (ResPacketSize == (16 / sizeof(float)))
#define GEMV_IS_SCALAR (sizeof(ResPacket) != 16)
#define GEMV_IS_COMPLEX_FLOAT (ResPacketSize == (16 / sizeof(std::complex<float>)))
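// Accumulate alpha * data into res, either one packet at a time or as a single scalar.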
template <typename ResPacket, typename ResScalar>
EIGEN_ALWAYS_INLINE void storeMaddData(ResScalar* res, ResPacket& palpha, ResPacket& data) {
  pstoreu(res, pmadd(data, palpha, ploadu<ResPacket>(res)));

template <typename ResScalar>
EIGEN_ALWAYS_INLINE void storeMaddData(ResScalar* res, ResScalar& alpha, ResScalar& data) {
  *res += (alpha * data);
#define GEMV_UNROLL(func, N) func(0, N) func(1, N) func(2, N) func(3, N) func(4, N) func(5, N) func(6, N) func(7, N)

#define GEMV_UNROLL_HALF(func, N) func(0, 0, 1, N) func(1, 2, 3, N) func(2, 4, 5, N) func(3, 6, 7, N)

#define GEMV_GETN(N) (((N) * ResPacketSize) >> 2)

#define GEMV_LOADPACKET_COL(iter) lhs.template load<LhsPacket, LhsAlignment>(i + ((iter) * LhsPacketSize), j)

#define GEMV_UNROLL3(func, N, which) \
  func(0, N, which) func(1, N, which) func(2, N, which) func(3, N, which) func(4, N, which) func(5, N, which) \
      func(6, N, which) func(7, N, which)

#define GEMV_UNUSED_VAR(iter, N, which) \
  if (GEMV_GETN(N) <= iter) { \
    EIGEN_UNUSED_VARIABLE(which##iter); \

#define GEMV_UNUSED_EXTRA_VAR(iter, N, which) \
    EIGEN_UNUSED_VARIABLE(which##iter); \

#define GEMV_UNUSED_EXTRA(N, which) GEMV_UNROLL3(GEMV_UNUSED_EXTRA_VAR, N, which)

#define GEMV_UNUSED(N, which) GEMV_UNROLL3(GEMV_UNUSED_VAR, N, which)
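// Unrolled building blocks for the MMA (Power10) column-major GEMV kernel: set up the
// __vector_quad accumulators, load LHS data as packets or __vector_pairs, and issue the
// rank-1 "ger" updates for each unrolled row panel.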
#define GEMV_INIT_MMA(iter, N) \
  if (GEMV_GETN(N) > iter) { \
    __builtin_mma_xxsetaccz(&e##iter); \

#define GEMV_LOADPAIR_COL_MMA(iter1, iter2) \
  GEMV_BUILDPAIR_MMA(b##iter1, GEMV_LOADPACKET_COL(iter2), GEMV_LOADPACKET_COL((iter2) + 1));

#define GEMV_LOADPAIR_COL_MMA(iter1, iter2) \
  const LhsScalar& src##iter1 = lhs(i + ((iter1 * 32) / sizeof(LhsScalar)), j); \
  b##iter1 = *reinterpret_cast<__vector_pair*>(const_cast<LhsScalar*>(&src##iter1));

#define GEMV_LOAD1A_COL_MMA(iter, N) \
  if (GEMV_GETN(N) > iter) { \
    if (GEMV_IS_FLOAT) { \
      g##iter = GEMV_LOADPACKET_COL(iter); \
      EIGEN_UNUSED_VARIABLE(b##iter); \
      GEMV_LOADPAIR_COL_MMA(iter, iter << 1) \
      EIGEN_UNUSED_VARIABLE(g##iter); \
    EIGEN_UNUSED_VARIABLE(b##iter); \
    EIGEN_UNUSED_VARIABLE(g##iter); \

#define GEMV_WORK1A_COL_MMA(iter, N) \
  if (GEMV_GETN(N) > iter) { \
    if (GEMV_IS_FLOAT) { \
      pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter, a0, g##iter); \
      pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter, b##iter, a0); \

#define GEMV_LOAD1B_COL_MMA(iter1, iter2, iter3, N) \
  if (GEMV_GETN(N) > iter1) { \
    if (GEMV_IS_FLOAT) { \
      GEMV_LOADPAIR_COL_MMA(iter2, iter2) \
      EIGEN_UNUSED_VARIABLE(b##iter3); \
      GEMV_LOADPAIR_COL_MMA(iter2, iter2 << 1) \
      GEMV_LOADPAIR_COL_MMA(iter3, iter3 << 1) \
    EIGEN_UNUSED_VARIABLE(b##iter2); \
    EIGEN_UNUSED_VARIABLE(b##iter3); \
  EIGEN_UNUSED_VARIABLE(g##iter2); \
  EIGEN_UNUSED_VARIABLE(g##iter3);

#define GEMV_WORK1B_COL_MMA(iter1, iter2, iter3, N) \
  if (GEMV_GETN(N) > iter1) { \
    if (GEMV_IS_FLOAT) { \
      __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(h), &b##iter2); \
      pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter2, a0, h[0]); \
      pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter3, a0, h[1]); \
      pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter2, b##iter2, a0); \
      pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter3, b##iter3, a0); \

#define GEMV_LOAD_COL_MMA(N) \
  if (GEMV_GETN(N) > 1) { \
    GEMV_UNROLL_HALF(GEMV_LOAD1B_COL_MMA, (N >> 1)) \
    GEMV_UNROLL(GEMV_LOAD1A_COL_MMA, N) \

#define GEMV_WORK_COL_MMA(N) \
  if (GEMV_GETN(N) > 1) { \
    GEMV_UNROLL_HALF(GEMV_WORK1B_COL_MMA, (N >> 1)) \
    GEMV_UNROLL(GEMV_WORK1A_COL_MMA, N) \

#define GEMV_LOAD_COL_MMA(N) GEMV_UNROLL(GEMV_LOAD1A_COL_MMA, N)

#define GEMV_WORK_COL_MMA(N) GEMV_UNROLL(GEMV_WORK1A_COL_MMA, N)

#define GEMV_DISASSEMBLE_MMA(iter, N) \
  if (GEMV_GETN(N) > iter) { \
    __builtin_mma_disassemble_acc(&result##iter.packet, &e##iter); \
    if (!GEMV_IS_FLOAT) { \
      result##iter.packet[0][1] = result##iter.packet[1][0]; \
      result##iter.packet[2][1] = result##iter.packet[3][0]; \

#define GEMV_LOADPAIR2_COL_MMA(iter1, iter2) \
  b##iter1 = *reinterpret_cast<__vector_pair*>(res + i + ((iter2) * ResPacketSize));
#define GEMV_LOAD2_COL_MMA(iter1, iter2, iter3, N) \
  if (GEMV_GETN(N) > iter1) { \
    if (GEMV_IS_FLOAT) { \
      GEMV_LOADPAIR2_COL_MMA(iter2, iter2); \
      EIGEN_UNUSED_VARIABLE(b##iter3); \
      GEMV_LOADPAIR2_COL_MMA(iter2, iter2 << 1); \
      GEMV_LOADPAIR2_COL_MMA(iter3, iter3 << 1); \
    EIGEN_UNUSED_VARIABLE(b##iter2); \
    EIGEN_UNUSED_VARIABLE(b##iter3); \

#define GEMV_WORKPAIR2_COL_MMA(iter2, iter3, iter4) \
  ResPacket f##iter2[2]; \
  __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(f##iter2), &b##iter2); \
  f##iter2[0] = pmadd(result##iter2.packet[0], palpha, f##iter2[0]); \
  f##iter2[1] = pmadd(result##iter3.packet[(iter2 == iter3) ? 2 : 0], palpha, f##iter2[1]); \
  GEMV_BUILDPAIR_MMA(b##iter2, f##iter2[0], f##iter2[1]);

#define GEMV_WORKPAIR2_COL_MMA(iter2, iter3, iter4) \
  if (GEMV_IS_FLOAT) { \
    __asm__("xvmaddasp %0,%x1,%x3\n\txvmaddasp %L0,%x2,%x3" \
            : "wa"(result##iter3.packet[0]), "wa"(result##iter2.packet[0]), "wa"(palpha)); \
    __asm__("xvmaddadp %0,%x1,%x3\n\txvmaddadp %L0,%x2,%x3" \
            : "wa"(result##iter2.packet[2]), "wa"(result##iter2.packet[0]), "wa"(palpha)); \

#define GEMV_WORK2_COL_MMA(iter1, iter2, iter3, N) \
  if (GEMV_GETN(N) > iter1) { \
    if (GEMV_IS_FLOAT) { \
      GEMV_WORKPAIR2_COL_MMA(iter2, iter3, iter2); \
      GEMV_WORKPAIR2_COL_MMA(iter2, iter2, iter2 << 1); \
      GEMV_WORKPAIR2_COL_MMA(iter3, iter3, iter3 << 1); \

#define GEMV_STOREPAIR2_COL_MMA(iter1, iter2) \
  *reinterpret_cast<__vector_pair*>(res + i + ((iter2) * ResPacketSize)) = b##iter1;
#define GEMV_STORE_COL_MMA(iter, N) \
  if (GEMV_GETN(N) > iter) { \
    if (GEMV_IS_FLOAT) { \
      storeMaddData<ResPacket, ResScalar>(res + i + (iter * ResPacketSize), palpha, result##iter.packet[0]); \
      GEMV_LOADPAIR2_COL_MMA(iter, iter << 1) \
      GEMV_WORKPAIR2_COL_MMA(iter, iter, iter << 1) \
      GEMV_STOREPAIR2_COL_MMA(iter, iter << 1) \

#define GEMV_STORE2_COL_MMA(iter1, iter2, iter3, N) \
  if (GEMV_GETN(N) > iter1) { \
    if (GEMV_IS_FLOAT) { \
      GEMV_STOREPAIR2_COL_MMA(iter2, iter2); \
      GEMV_STOREPAIR2_COL_MMA(iter2, iter2 << 1) \
      GEMV_STOREPAIR2_COL_MMA(iter3, iter3 << 1) \

#define GEMV_PROCESS_COL_ONE_MMA(N) \
  GEMV_UNROLL(GEMV_INIT_MMA, N) \
  __vector_pair b0, b1, b2, b3, b4, b5, b6, b7; \
  LhsPacket g0, g1, g2, g3, g4, g5, g6, g7; \
    RhsPacket a0 = pset1<RhsPacket>(rhs2(j, 0)); \
    GEMV_UNROLL(GEMV_PREFETCH, N) \
    GEMV_LOAD_COL_MMA(N) \
    GEMV_WORK_COL_MMA(N) \
  } while (++j < jend); \
  GEMV_UNROLL(GEMV_DISASSEMBLE_MMA, N) \
  if (GEMV_GETN(N) <= 1) { \
    GEMV_UNROLL(GEMV_STORE_COL_MMA, N) \
    GEMV_UNROLL_HALF(GEMV_LOAD2_COL_MMA, (N >> 1)) \
    GEMV_UNROLL_HALF(GEMV_WORK2_COL_MMA, (N >> 1)) \
    GEMV_UNROLL_HALF(GEMV_STORE2_COL_MMA, (N >> 1)) \
  i += (ResPacketSize * N);
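// Plain VSX (non-MMA) variants of the column-major GEMV building blocks.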
#define GEMV_INIT(iter, N) \
    c##iter = pset1<ResPacket>(ResScalar(0)); \
    EIGEN_UNUSED_VARIABLE(c##iter); \

#ifdef EIGEN_POWER_USE_GEMV_PREFETCH
#define GEMV_PREFETCH(iter, N) \
  if (GEMV_GETN(N) > ((iter >> 1) + ((N >> 1) * (iter & 1)))) { \
    lhs.prefetch(i + (iter * LhsPacketSize) + prefetch_dist, j); \
#define GEMV_PREFETCH(iter, N)

#define GEMV_WORK_COL(iter, N) \
    c##iter = pcj.pmadd(GEMV_LOADPACKET_COL(iter), a0, c##iter); \

#define GEMV_STORE_COL(iter, N) \
    pstoreu(res + i + (iter * ResPacketSize), \
            pmadd(c##iter, palpha, ploadu<ResPacket>(res + i + (iter * ResPacketSize)))); \

#define GEMV_PROCESS_COL_ONE(N) \
  GEMV_UNROLL(GEMV_INIT, N) \
    RhsPacket a0 = pset1<RhsPacket>(rhs2(j, 0)); \
    GEMV_UNROLL(GEMV_PREFETCH, N) \
    GEMV_UNROLL(GEMV_WORK_COL, N) \
  } while (++j < jend); \
  GEMV_UNROLL(GEMV_STORE_COL, N) \
  i += (ResPacketSize * N);

#define GEMV_PROCESS_COL(N) GEMV_PROCESS_COL_ONE_MMA(N)
#define GEMV_PROCESS_COL(N) GEMV_PROCESS_COL_ONE(N)
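// Rank-1 update of an MMA accumulator: the f32 path takes two vector operands, the f64 path a __vector_pair.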
template <typename LhsPacket, typename RhsPacket, bool accumulate>
EIGEN_ALWAYS_INLINE void pger_vecMMA_acc(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b) {
    __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
    __builtin_mma_xvf32ger(acc, (__vector unsigned char)a, (__vector unsigned char)b);

template <typename LhsPacket, typename RhsPacket, bool accumulate>
EIGEN_ALWAYS_INLINE void pger_vecMMA_acc(__vector_quad* acc, __vector_pair& a, const LhsPacket& b) {
    __builtin_mma_xvf64gerpp(acc, a, (__vector unsigned char)b);
    __builtin_mma_xvf64ger(acc, a, (__vector unsigned char)b);
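// Column-major GEMV kernel: walks the matrix in blocks of columns and unrolled panels of rows,
// with a scalar loop for the leftover rows.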
template <typename LhsScalar, typename LhsMapper, typename RhsScalar, typename RhsMapper, typename ResScalar>
EIGEN_STRONG_INLINE void gemv_col(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs, ResScalar* res,
                                  Index resIncr, ResScalar alpha) {
  typedef gemv_traits<LhsScalar, RhsScalar> Traits;

  typedef typename Traits::LhsPacket LhsPacket;
  typedef typename Traits::RhsPacket RhsPacket;
  typedef typename Traits::ResPacket ResPacket;

  EIGEN_UNUSED_VARIABLE(resIncr);
  eigen_internal_assert(resIncr == 1);

  conj_helper<LhsScalar, RhsScalar, false, false> cj;
  conj_helper<LhsPacket, RhsPacket, false, false> pcj;

  const Index lhsStride = lhs.stride();

    ResPacketSize = Traits::ResPacketSize,
    LhsPacketSize = Traits::LhsPacketSize,
    RhsPacketSize = Traits::RhsPacketSize,

#ifndef GCC_ONE_VECTORPAIR_BUG
  const Index n8 = rows - 8 * ResPacketSize + 1;
  const Index n4 = rows - 4 * ResPacketSize + 1;
  const Index n2 = rows - 2 * ResPacketSize + 1;

  const Index n1 = rows - 1 * ResPacketSize + 1;
#ifdef EIGEN_POWER_USE_GEMV_PREFETCH
  const Index prefetch_dist = 64 * LhsPacketSize;

  const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(LhsScalar) < 16000 ? 16 : 8);
  ResPacket palpha = pset1<ResPacket>(alpha);

  for (Index j2 = 0; j2 < cols; j2 += block_cols) {
    Index jend = numext::mini(j2 + block_cols, cols);

    ResPacket c0, c1, c2, c3, c4, c5, c6, c7;

    __vector_quad e0, e1, e2, e3, e4, e5, e6, e7;
    PacketBlock<ResPacket, 4> result0, result1, result2, result3, result4, result5, result6, result7;

    GEMV_UNUSED(8, result)
    GEMV_UNUSED_EXTRA(1, c)

#ifndef GCC_ONE_VECTORPAIR_BUG
    GEMV_PROCESS_COL_ONE(1)

    for (; i < rows; ++i) {
        d0 += cj.pmul(lhs(i, j), rhs2(j, 0));
      } while (++j < jend);
      res[i] += alpha * d0;
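// Scale a partial column sum by alpha and accumulate it into the float result buffer,
// optionally handling a tail of fewer than four rows.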
template <bool extraRows>
EIGEN_ALWAYS_INLINE void outputVecCol(Packet4f acc, float* result, Packet4f pAlpha, Index extra_rows) {
  Packet4f d0 = ploadu<Packet4f>(result);
  d0 = pmadd(acc, pAlpha, d0);
    pstoreu_partial(result, d0, extra_rows);

template <Index num_acc, bool extraRows, Index size>
EIGEN_ALWAYS_INLINE void outputVecColResults(Packet4f (&acc)[num_acc][size], float* result, Packet4f pAlpha,
  constexpr Index real_acc = (num_acc - (extraRows ? 1 : 0));
  for (Index k = 0; k < real_acc; k++) {
    outputVecCol<false>(acc[k][0], result + k * 4, pAlpha, extra_rows);
    outputVecCol<true>(acc[real_acc][0], result + real_acc * 4, pAlpha, extra_rows);

static Packet16uc p16uc_MERGE16_32_V1 = {0, 1, 16, 17, 0, 1, 16, 17, 0, 1, 16, 17, 0, 1, 16, 17};
static Packet16uc p16uc_MERGE16_32_V2 = {2, 3, 18, 19, 2, 3, 18, 19, 2, 3, 18, 19, 2, 3, 18, 19};
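// Load bfloat16 LHS data for one or two accumulators and widen it to float.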
template <Index num_acc, typename LhsMapper, bool zero>
EIGEN_ALWAYS_INLINE void loadVecLoopVSX(Index k, LhsMapper& lhs, Packet4f (&a0)[num_acc][2]) {
  Packet8bf c0 = lhs.template loadPacket<Packet8bf>(k * 4, 0);
    b1 = lhs.template loadPacket<Packet8bf>(k * 4, 1);
    a0[k + 0][1] = oneConvertBF16Hi(b1.m_val);
  a0[k + 0][0] = oneConvertBF16Hi(c0.m_val);
  if (num_acc > (k + 1)) {
    a0[k + 1][0] = oneConvertBF16Lo(c0.m_val);
      a0[k + 1][1] = oneConvertBF16Lo(b1.m_val);

template <Index num_acc, bool zero>
EIGEN_ALWAYS_INLINE void multVecVSX(Packet4f (&acc)[num_acc][2], Packet4f (&a0)[num_acc][2], Packet4f (&b0)[2]) {
  for (Index k = 0; k < num_acc; k++) {
    for (Index i = 0; i < (zero ? 1 : 2); i++) {
      acc[k][i] = pmadd(b0[i], a0[k][i], acc[k][i]);
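// Gather a Packet8bf of RHS values element by element; the specialization below uses a direct
// packet load when the RHS is contiguous (linear) in memory.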
template <typename RhsMapper, bool linear>
struct loadColData_impl {
  static EIGEN_ALWAYS_INLINE Packet8bf run(RhsMapper& rhs, Index j) {
    const Index n = unpacket_traits<Packet8bf>::size;
    EIGEN_ALIGN16 bfloat16 to[n];
    for (Index i = 0; i < n; i++) {
      to[i] = rhs(j + i, 0);
    return pload<Packet8bf>(to);

template <typename RhsMapper>
struct loadColData_impl<RhsMapper, true> {
  static EIGEN_ALWAYS_INLINE Packet8bf run(RhsMapper& rhs, Index j) {
    return rhs.template loadPacket<Packet8bf>(j + 0, 0);

template <typename RhsMapper, bool linear>
EIGEN_ALWAYS_INLINE Packet8bf loadColData(RhsMapper& rhs, Index j) {
  return loadColData_impl<RhsMapper, linear>::run(rhs, j);
template <Index num_acc, typename LhsMapper, typename RhsMapper, bool zero, bool linear>
EIGEN_ALWAYS_INLINE void vecColLoopVSX(Index j, LhsMapper& lhs, RhsMapper& rhs, Packet4f (&acc)[num_acc][2]) {
  Packet4f a0[num_acc][2], b0[2];
  Packet8bf b2 = loadColData<RhsMapper, linear>(rhs, j);

  b0[0] = oneConvertBF16Perm(b2.m_val, p16uc_MERGE16_32_V1);
    b0[1] = oneConvertBF16Perm(b2.m_val, p16uc_MERGE16_32_V2);

  using LhsSubMapper = typename LhsMapper::SubMapper;

  LhsSubMapper lhs2 = lhs.getSubMapper(0, j);
  for (Index k = 0; k < num_acc; k += 2) {
    loadVecLoopVSX<num_acc, LhsSubMapper, zero>(k, lhs2, a0);

  multVecVSX<num_acc, zero>(acc, a0, b0);

template <Index num_acc>
EIGEN_ALWAYS_INLINE void addResultsVSX(Packet4f (&acc)[num_acc][2]) {
  for (Index i = 0; i < num_acc; i++) {
    acc[i][0] = acc[i][0] + acc[i][1];

#define MAX_BFLOAT16_VEC_ACC_VSX 8
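// Body of the column-major bfloat16 GEMV loop: accumulates up to num_acc * 4 rows at a time
// into float accumulators before writing them back with outputVecColResults.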
template <const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
void colVSXVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha,
  constexpr Index step = (num_acc * 4);
  const Index extra_rows = (extraRows) ? (rows & 3) : 0;
  constexpr bool multiIters = !extraRows && (num_acc == MAX_BFLOAT16_VEC_ACC_VSX);

    Packet4f acc[num_acc][2];

    zeroAccumulators<num_acc, 2>(acc);

    using LhsSubMapper = typename LhsMapper::SubMapper;

    LhsSubMapper lhs2 = lhs.getSubMapper(row, 0);
    for (Index j = 0; j + 2 <= cend; j += 2) {
      vecColLoopVSX<num_acc, LhsSubMapper, RhsMapper, false, linear>(j, lhs2, rhs, acc);
      vecColLoopVSX<num_acc, LhsSubMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, acc);

    addResultsVSX<num_acc>(acc);

    outputVecColResults<num_acc, extraRows, 2>(acc, result, pAlpha, extra_rows);

  } while (multiIters && (step <= rows - (row += step)));

template <const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
EIGEN_ALWAYS_INLINE void colVSXVecColLoopBodyExtraN(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs,
                                                    const Packet4f pAlpha, float* result) {
  if (MAX_BFLOAT16_VEC_ACC_VSX > num_acc) {
    colVSXVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs,
                                                                                                 rhs, pAlpha, result);

template <typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
EIGEN_ALWAYS_INLINE void colVSXVecColLoopBodyExtra(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs,
                                                   const Packet4f pAlpha, float* result) {
  switch ((rows - row) >> 2) {
      colVSXVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      colVSXVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      colVSXVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      colVSXVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      colVSXVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      colVSXVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      colVSXVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      colVSXVecColLoopBody<1, LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
template <typename LhsMapper, typename RhsMapper, bool linear>
EIGEN_ALWAYS_INLINE void calcVSXVecColLoops(Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs,
                                            const Packet4f pAlpha, float* result) {
  if (rows >= (MAX_BFLOAT16_VEC_ACC_VSX * 4)) {
    colVSXVecColLoopBody<MAX_BFLOAT16_VEC_ACC_VSX, LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs,
    colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
    colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);

template <const Index size, bool inc, Index delta>
EIGEN_ALWAYS_INLINE void storeBF16fromResult(bfloat16* dst, Packet8bf data, Index resInc, Index extra) {
      pscatter_partial(dst + delta * resInc, data, resInc, extra);
      pscatter(dst + delta * resInc, data, resInc);
      pstoreu_partial(dst + delta, data, extra);
      pstoreu(dst + delta, data);

template <const Index size, bool inc = false>
EIGEN_ALWAYS_INLINE void convertPointerF32toBF16VSX(Index& i, float* result, Index rows, bfloat16*& dst,
  constexpr Index extra = ((size < 8) ? 8 : size);
  while (i + size <= rows) {
    PacketBlock<Packet8bf, (size + 7) / 8> r32;
    r32.packet[0] = convertF32toBF16VSX(result + i + 0);
      r32.packet[1] = convertF32toBF16VSX(result + i + 8);
      r32.packet[2] = convertF32toBF16VSX(result + i + 16);
      r32.packet[3] = convertF32toBF16VSX(result + i + 24);
    storeBF16fromResult<size, inc, 0>(dst, r32.packet[0], resInc, rows & 7);
      storeBF16fromResult<size, inc, 8>(dst, r32.packet[1], resInc);
      storeBF16fromResult<size, inc, 16>(dst, r32.packet[2], resInc);
      storeBF16fromResult<size, inc, 24>(dst, r32.packet[3], resInc);
    dst += extra * resInc;
    if (size != 32) break;
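// Convert the float accumulation buffer back to bfloat16 in chunks of 32, 16, 8 and 1 elements.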
template <bool inc = false>
EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16VSX(float* result, Index rows, bfloat16* dst, Index resInc = 1) {
  convertPointerF32toBF16VSX<32, inc>(i, result, rows, dst, resInc);
  convertPointerF32toBF16VSX<16, inc>(i, result, rows, dst, resInc);
  convertPointerF32toBF16VSX<8, inc>(i, result, rows, dst, resInc);
  convertPointerF32toBF16VSX<1, inc>(i, result, rows, dst, resInc);

template <typename RhsMapper, typename LhsMapper, typename = void>
struct UseStride : std::false_type {
  static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper& lhs, RhsMapper& rhs, Packet4f pAlpha,
    using RhsSubMapper = typename RhsMapper::SubMapper;

    RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
    calcVSXVecColLoops<LhsMapper, RhsSubMapper, false>(jend - j2, rows, lhs, rhs2, pAlpha, result);

template <typename RhsMapper, typename LhsMapper>
struct UseStride<RhsMapper, LhsMapper,
                 std::enable_if_t<std::is_member_function_pointer<decltype(&RhsMapper::stride)>::value>>
  static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper& lhs, RhsMapper& rhs, Packet4f pAlpha,
    using RhsSubMapper = typename RhsMapper::SubMapper;

    RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
    if (rhs.stride() == 1) {
      calcVSXVecColLoops<LhsMapper, RhsSubMapper, true>(jend - j2, rows, lhs, rhs2, pAlpha, result);
      calcVSXVecColLoops<LhsMapper, RhsSubMapper, false>(jend - j2, rows, lhs, rhs2, pAlpha, result);
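// bfloat16 column-major GEMV: accumulates in float and converts the result back to bfloat16 at the end.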
template <typename LhsMapper, typename RhsMapper>
void gemv_bfloat16_col(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs, bfloat16* res,
                       Index resIncr, bfloat16 alpha) {
  EIGEN_UNUSED_VARIABLE(resIncr);
  eigen_internal_assert(resIncr == 1);

  const Index lhsStride = lhs.stride();

  const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(bfloat16) < 16000 ? 16 : 8);
  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
  Packet4f pAlpha = pset1<Packet4f>(falpha);

  ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);

  convertArrayPointerBF16toF32(result, 1, rows, res);

  for (Index j2 = 0; j2 < cols; j2 += block_cols) {
    Index jend = numext::mini(j2 + block_cols, cols);

    using LhsSubMapper = typename LhsMapper::SubMapper;

    LhsSubMapper lhs2 = lhs.getSubMapper(0, j2);
    UseStride<RhsMapper, LhsSubMapper>::run(j2, jend, rows, lhs2, rhs2, pAlpha, result);

  convertArrayPointerF32toBF16VSX(result, rows, res);
template <Index num_acc, Index size>
EIGEN_ALWAYS_INLINE void outputVecResults(Packet4f (&acc)[num_acc][size], float* result, Packet4f pAlpha) {
  constexpr Index extra = num_acc & 3;

  for (Index k = 0; k < num_acc; k += 4) {
    Packet4f d0 = ploadu<Packet4f>(result + k);
    d0 = pmadd(acc[k + 0][0], pAlpha, d0);

    if (num_acc > (k + 3)) {
      pstoreu(result + k, d0);
        pstoreu_partial(result + k, d0, extra);
        memcpy((void*)(result + k), (void*)(&d0), sizeof(float) * extra);

template <Index num_acc>
EIGEN_ALWAYS_INLINE void preduxVecResults2VSX(Packet4f (&acc)[num_acc][2], Index k) {
  if (num_acc > (k + 1)) {
    acc[k][1] = vec_mergel(acc[k + 0][0], acc[k + 1][0]);
    acc[k][0] = vec_mergeh(acc[k + 0][0], acc[k + 1][0]);
    acc[k][0] = acc[k][0] + acc[k][1];
    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 8);
    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 8);
    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 12);
    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 4);

template <Index num_acc>
EIGEN_ALWAYS_INLINE void preduxVecResultsVSX(Packet4f (&acc)[num_acc][2]) {
  for (Index k = 0; k < num_acc; k += 4) {
    preduxVecResults2VSX<num_acc>(acc, k + 0);
    if (num_acc > (k + 2)) {
      preduxVecResults2VSX<num_acc>(acc, k + 2);
#ifdef EIGEN_VECTORIZE_VSX
      acc[k + 0][0] = reinterpret_cast<Packet4f>(
          vec_mergeh(reinterpret_cast<Packet2ul>(acc[k + 0][0]), reinterpret_cast<Packet2ul>(acc[k + 2][0])));
      acc[k + 0][0] = reinterpret_cast<Packet4f>(vec_perm(acc[k + 0][0], acc[k + 2][0], p16uc_TRANSPOSE64_HI));
EIGEN_ALWAYS_INLINE Packet8us loadPacketPartialZero(Packet8us data, Index extra_cols) {
  Packet16uc shift = pset1<Packet16uc>(8 * 2 * (8 - extra_cols));
  return reinterpret_cast<Packet8us>(vec_slo(vec_sro(reinterpret_cast<Packet16uc>(data), shift), shift));
  return reinterpret_cast<Packet8us>(vec_sro(vec_slo(reinterpret_cast<Packet16uc>(data), shift), shift));

template <Index num_acc, typename LhsMapper, typename RhsMapper, bool extra>
EIGEN_ALWAYS_INLINE void multVSXVecLoop(Packet4f (&acc)[num_acc][2], const LhsMapper& lhs, RhsMapper& rhs, Index j,
  Packet4f a0[num_acc][2], b0[2];

    b1 = rhs.template loadPacketPartial<Packet8bf>(j, extra_cols);
      b1 = loadPacketPartialZero(b1.m_val, extra_cols);
    b1 = rhs.template loadPacket<Packet8bf>(j);

  b0[0] = oneConvertBF16Hi(b1.m_val);
  b0[1] = oneConvertBF16Lo(b1.m_val);

  const LhsMapper lhs2 = lhs.getSubMapper(0, j);
  for (Index k = 0; k < num_acc; k++) {
      a1 = lhs2.template loadPacketPartial<Packet8bf>(k, 0, extra_cols);
        a1 = loadPacketPartialZero(a1.m_val, extra_cols);
      a1 = lhs2.template loadPacket<Packet8bf>(k, 0);
    a0[k][0] = oneConvertBF16Hi(a1.m_val);
    a0[k][1] = oneConvertBF16Lo(a1.m_val);

  multVecVSX<num_acc, false>(acc, a0, b0);

template <Index num_acc, typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void vecVSXLoop(Index cols, const LhsMapper& lhs, RhsMapper& rhs, Packet4f (&acc)[num_acc][2],
  for (; j + 8 <= cols; j += 8) {
    multVSXVecLoop<num_acc, LhsMapper, RhsMapper, false>(acc, lhs, rhs, j, extra_cols);
    multVSXVecLoop<num_acc, LhsMapper, RhsMapper, true>(acc, lhs, rhs, j, extra_cols);
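// Body of the row-major bfloat16 GEMV loop: each accumulator reduces one row's dot product
// before the horizontal reductions in preduxVecResultsVSX.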
template <const Index num_acc, typename LhsMapper, typename RhsMapper>
void colVSXVecLoopBody(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha,
  constexpr bool multiIters = (num_acc == MAX_BFLOAT16_VEC_ACC_VSX);
  const Index extra_cols = (cols & 7);

    Packet4f acc[num_acc][2];

    zeroAccumulators<num_acc, 2>(acc);

    const LhsMapper lhs2 = lhs.getSubMapper(row, 0);
    vecVSXLoop<num_acc, LhsMapper, RhsMapper>(cols, lhs2, rhs, acc, extra_cols);

    addResultsVSX<num_acc>(acc);

    preduxVecResultsVSX<num_acc>(acc);

    outputVecResults<num_acc, 2>(acc, result, pAlpha);

  } while (multiIters && (num_acc <= rows - (row += num_acc)));

template <const Index num_acc, typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void colVSXVecLoopBodyExtraN(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs,
                                                 const Packet4f pAlpha, float* result) {
  if (MAX_BFLOAT16_VEC_ACC_VSX > num_acc) {
    colVSXVecLoopBody<num_acc, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);

template <typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void colVSXVecLoopBodyExtra(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs,
                                                const Packet4f pAlpha, float* result) {
  switch (rows - row) {
      colVSXVecLoopBodyExtraN<7, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      colVSXVecLoopBodyExtraN<6, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      colVSXVecLoopBodyExtraN<5, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      colVSXVecLoopBodyExtraN<4, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      colVSXVecLoopBodyExtraN<3, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      colVSXVecLoopBodyExtraN<2, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      colVSXVecLoopBodyExtraN<1, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);

template <typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void calcVSXVecLoops(Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha,
  if (rows >= MAX_BFLOAT16_VEC_ACC_VSX) {
    colVSXVecLoopBody<MAX_BFLOAT16_VEC_ACC_VSX, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
    colVSXVecLoopBodyExtra<LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);

template <typename LhsMapper, typename RhsMapper>
EIGEN_STRONG_INLINE void gemv_bfloat16_row(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs,
                                           bfloat16* res, Index resIncr, bfloat16 alpha) {
  typedef typename RhsMapper::LinearMapper LinearMapper;

  LinearMapper rhs2 = rhs.getLinearMapper(0, 0);

  eigen_internal_assert(rhs.stride() == 1);

  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
  const Packet4f pAlpha = pset1<Packet4f>(falpha);

  ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);

  convertArrayPointerBF16toF32(result, 1, rows, res);
  convertArrayPointerBF16toF32<true>(result, 1, rows, res, resIncr);

  calcVSXVecLoops<LhsMapper, LinearMapper>(cols, rows, lhs, rhs2, pAlpha, result);

  convertArrayPointerF32toBF16VSX(result, rows, res);
  convertArrayPointerF32toBF16VSX<true>(result, rows, res, resIncr);

#undef MAX_BFLOAT16_VEC_ACC_VSX
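// Permute and XOR masks used by the complex GEMV kernels; the two groups below are the
// big-endian and little-endian byte orderings of the same constants.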
const Packet16uc p16uc_COMPLEX32_XORFLIP = {0x44, 0x55, 0x66, 0x77, 0x00, 0x11, 0x22, 0x33,
                                            0xcc, 0xdd, 0xee, 0xff, 0x88, 0x99, 0xaa, 0xbb};
const Packet16uc p16uc_COMPLEX64_XORFLIP = {0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff,
                                            0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77};

const Packet16uc p16uc_COMPLEX32_CONJ_XOR = {0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00,
                                             0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00};
const Packet16uc p16uc_COMPLEX64_CONJ_XOR = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                                             0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
const Packet16uc p16uc_COMPLEX32_CONJ_XOR2 = {0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                                              0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
const Packet16uc p16uc_COMPLEX64_CONJ_XOR2 = {0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                                              0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
const Packet16uc p16uc_COMPLEX32_NEGATE = {0x80, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00,
                                           0x80, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00};
const Packet16uc p16uc_COMPLEX64_NEGATE = {0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                                           0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};

const Packet16uc p16uc_COMPLEX32_CONJ_XOR = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80,
                                             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80};
const Packet16uc p16uc_COMPLEX64_CONJ_XOR = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                                             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80};
const Packet16uc p16uc_COMPLEX32_CONJ_XOR2 = {0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00,
                                              0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00};
const Packet16uc p16uc_COMPLEX64_CONJ_XOR2 = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80,
                                              0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
const Packet16uc p16uc_COMPLEX32_NEGATE = {0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x80,
                                           0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x80};
const Packet16uc p16uc_COMPLEX64_NEGATE = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80,
                                           0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80};

#define COMPLEX_DELTA 0
#define COMPLEX_DELTA 2
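// Small complex helpers: conjugation, negation, and real/imaginary swaps built from the XOR
// and permute masks above.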
EIGEN_ALWAYS_INLINE Packet2cf pconj2(const Packet2cf& a) {
  return Packet2cf(pxor(a.v, reinterpret_cast<Packet4f>(p16uc_COMPLEX32_CONJ_XOR)));

EIGEN_ALWAYS_INLINE Packet1cd pconj2(const Packet1cd& a) {
  return Packet1cd(pxor(a.v, reinterpret_cast<Packet2d>(p16uc_COMPLEX64_CONJ_XOR)));

EIGEN_ALWAYS_INLINE Packet2cf pconjinv(const Packet2cf& a) {
#ifdef __POWER8_VECTOR__
  return Packet2cf(Packet4f(vec_neg(Packet2d(a.v))));
  return Packet2cf(pxor(a.v, reinterpret_cast<Packet4f>(p16uc_COMPLEX32_CONJ_XOR2)));

EIGEN_ALWAYS_INLINE Packet1cd pconjinv(const Packet1cd& a) {
  return Packet1cd(pxor(a.v, reinterpret_cast<Packet2d>(p16uc_COMPLEX64_CONJ_XOR2)));

#if defined(_ARCH_PWR8) && (!EIGEN_COMP_LLVM || __clang_major__ >= 12)

EIGEN_ALWAYS_INLINE Packet2cf pcplxflipconj(Packet2cf a) {
  return Packet2cf(Packet4f(vec_permxor(Packet16uc(a.v), p16uc_COMPLEX32_CONJ_XOR, p16uc_COMPLEX32_XORFLIP)));
  return pcplxflip(pconj2(a));

EIGEN_ALWAYS_INLINE Packet1cd pcplxflipconj(Packet1cd a) {
  return Packet1cd(Packet2d(vec_permxor(Packet16uc(a.v), p16uc_COMPLEX64_CONJ_XOR, p16uc_COMPLEX64_XORFLIP)));
  return pcplxflip(pconj2(a));

EIGEN_ALWAYS_INLINE Packet2cf pcplxconjflip(Packet2cf a) {
  return Packet2cf(Packet4f(vec_permxor(Packet16uc(a.v), p16uc_COMPLEX32_CONJ_XOR2, p16uc_COMPLEX32_XORFLIP)));
  return pconj2(pcplxflip(a));

EIGEN_ALWAYS_INLINE Packet1cd pcplxconjflip(Packet1cd a) {
  return Packet1cd(Packet2d(vec_permxor(Packet16uc(a.v), p16uc_COMPLEX64_CONJ_XOR2, p16uc_COMPLEX64_XORFLIP)));
  return pconj2(pcplxflip(a));

EIGEN_ALWAYS_INLINE Packet2cf pnegate2(Packet2cf a) {
#ifdef __POWER8_VECTOR__
  return Packet2cf(vec_neg(a.v));
  return Packet2cf(pxor(a.v, reinterpret_cast<Packet4f>(p16uc_COMPLEX32_NEGATE)));

EIGEN_ALWAYS_INLINE Packet1cd pnegate2(Packet1cd a) {
#ifdef __POWER8_VECTOR__
  return Packet1cd(vec_neg(a.v));
  return Packet1cd(pxor(a.v, reinterpret_cast<Packet2d>(p16uc_COMPLEX64_NEGATE)));

EIGEN_ALWAYS_INLINE Packet2cf pcplxflipnegate(Packet2cf a) {
  return Packet2cf(Packet4f(vec_permxor(Packet16uc(a.v), p16uc_COMPLEX32_NEGATE, p16uc_COMPLEX32_XORFLIP)));
  return pcplxflip(pnegate2(a));

EIGEN_ALWAYS_INLINE Packet1cd pcplxflipnegate(Packet1cd a) {
  return Packet1cd(Packet2d(vec_permxor(Packet16uc(a.v), p16uc_COMPLEX64_NEGATE, p16uc_COMPLEX64_XORFLIP)));
  return pcplxflip(pnegate2(a));

EIGEN_ALWAYS_INLINE Packet2cf pcplxflip2(Packet2cf a) {
  return Packet2cf(Packet4f(vec_perm(Packet16uc(a.v), Packet16uc(a.v), p16uc_COMPLEX32_XORFLIP)));

EIGEN_ALWAYS_INLINE Packet1cd pcplxflip2(Packet1cd a) {
#ifdef EIGEN_VECTORIZE_VSX
  return Packet1cd(__builtin_vsx_xxpermdi(a.v, a.v, 2));
  return Packet1cd(Packet2d(vec_perm(Packet16uc(a.v), Packet16uc(a.v), p16uc_COMPLEX64_XORFLIP)));
EIGEN_ALWAYS_INLINE Packet4f pload_complex_half(std::complex<float>* src) {
#ifdef EIGEN_VECTORIZE_VSX
  __asm__("lxsdx %x0,%y1" : "=wa"(t) : "Z"(*src));
  *reinterpret_cast<std::complex<float>*>(reinterpret_cast<float*>(&t) + COMPLEX_DELTA) = *src;

template <typename RhsScalar>
EIGEN_ALWAYS_INLINE void pload_realimag(RhsScalar* src, Packet4f& r, Packet4f& i) {
  __asm__("lxvwsx %x0,%y1" : "=wa"(r) : "Z"(*(reinterpret_cast<float*>(src) + 0)));
  __asm__("lxvwsx %x0,%y1" : "=wa"(i) : "Z"(*(reinterpret_cast<float*>(src) + 1)));
  Packet4f t = pload_complex_half(src);
  r = vec_splat(t, COMPLEX_DELTA + 0);
  i = vec_splat(t, COMPLEX_DELTA + 1);

template <typename RhsScalar>
EIGEN_ALWAYS_INLINE void pload_realimag(RhsScalar* src, Packet2d& r, Packet2d& i) {
#ifdef EIGEN_VECTORIZE_VSX
  __asm__("lxvdsx %x0,%y1" : "=wa"(r) : "Z"(*(reinterpret_cast<double*>(src) + 0)));
  __asm__("lxvdsx %x0,%y1" : "=wa"(i) : "Z"(*(reinterpret_cast<double*>(src) + 1)));
  Packet2d t = ploadu<Packet2d>(reinterpret_cast<double*>(src));
  r = vec_splat(t, 0);
  i = vec_splat(t, 1);

#ifndef __POWER8_VECTOR__
const Packet16uc p16uc_MERGEE = {0x00, 0x01, 0x02, 0x03, 0x10, 0x11, 0x12, 0x13,
                                 0x08, 0x09, 0x0A, 0x0B, 0x18, 0x19, 0x1A, 0x1B};

const Packet16uc p16uc_MERGEO = {0x04, 0x05, 0x06, 0x07, 0x14, 0x15, 0x16, 0x17,
                                 0x0C, 0x0D, 0x0E, 0x0F, 0x1C, 0x1D, 0x1E, 0x1F};

template <typename RhsScalar>
EIGEN_ALWAYS_INLINE void pload_realimag_row(RhsScalar* src, Packet4f& r, Packet4f& i) {
  Packet4f t = ploadu<Packet4f>(reinterpret_cast<float*>(src));
#ifdef __POWER8_VECTOR__
  r = vec_mergee(t, t);
  i = vec_mergeo(t, t);
  r = vec_perm(t, t, p16uc_MERGEE);
  i = vec_perm(t, t, p16uc_MERGEO);

template <typename RhsScalar>
EIGEN_ALWAYS_INLINE void pload_realimag_row(RhsScalar* src, Packet2d& r, Packet2d& i) {
  return pload_realimag(src, r, i);

EIGEN_ALWAYS_INLINE Packet4f pload_realimag_combine(std::complex<float>* src) {
#ifdef EIGEN_VECTORIZE_VSX
  __asm__("lxvdsx %x0,%y1" : "=wa"(ret) : "Z"(*(reinterpret_cast<double*>(src) + 0)));
  return Packet4f(ploaddup<Packet2d>(reinterpret_cast<double*>(src)));

EIGEN_ALWAYS_INLINE Packet2d pload_realimag_combine(std::complex<double>* src) { return ploadu<Packet1cd>(src).v; }

EIGEN_ALWAYS_INLINE Packet4f pload_realimag_combine_row(std::complex<float>* src) { return ploadu<Packet2cf>(src).v; }

EIGEN_ALWAYS_INLINE Packet2d pload_realimag_combine_row(std::complex<double>* src) { return ploadu<Packet1cd>(src).v; }
template <typename ResPacket>
EIGEN_ALWAYS_INLINE Packet4f pload_complex(std::complex<float>* src) {
  if (GEMV_IS_SCALAR) {
    return pload_complex_half(src);
    return ploadu<Packet4f>(reinterpret_cast<float*>(src));

template <typename ResPacket>
EIGEN_ALWAYS_INLINE Packet2d pload_complex(std::complex<double>* src) {
  return ploadu<Packet2d>(reinterpret_cast<double*>(src));

template <typename ResPacket>
EIGEN_ALWAYS_INLINE Packet4f pload_complex(Packet2cf* src) {

template <typename ResPacket>
EIGEN_ALWAYS_INLINE Packet2d pload_complex(Packet1cd* src) {

EIGEN_ALWAYS_INLINE Packet4f pload_complex_full(std::complex<float>* src) {
  return Packet4f(ploaddup<Packet2d>(reinterpret_cast<double*>(src)));

EIGEN_ALWAYS_INLINE Packet2d pload_complex_full(std::complex<double>* src) { return ploadu<Packet1cd>(src).v; }

EIGEN_ALWAYS_INLINE Packet4f pload_complex_full_row(std::complex<float>* src) { return ploadu<Packet2cf>(src).v; }

EIGEN_ALWAYS_INLINE Packet2d pload_complex_full_row(std::complex<double>* src) { return pload_complex_full(src); }

EIGEN_ALWAYS_INLINE Packet4f pload_real(float* src) { return pset1<Packet4f>(*src); }

EIGEN_ALWAYS_INLINE Packet2d pload_real(double* src) { return pset1<Packet2d>(*src); }

EIGEN_ALWAYS_INLINE Packet4f pload_real(Packet4f& src) { return src; }

EIGEN_ALWAYS_INLINE Packet2d pload_real(Packet2d& src) { return src; }

EIGEN_ALWAYS_INLINE Packet4f pload_real_full(float* src) {
  Packet4f ret = ploadu<Packet4f>(src);
  return vec_mergeh(ret, ret);

EIGEN_ALWAYS_INLINE Packet2d pload_real_full(double* src) { return pload_real(src); }

EIGEN_ALWAYS_INLINE Packet4f pload_real_full(std::complex<float>* src) {
  return pload_complex_full(src);

EIGEN_ALWAYS_INLINE Packet2d pload_real_full(std::complex<double>* src) {
  return pload_complex_full(src);

template <typename ResPacket>
EIGEN_ALWAYS_INLINE Packet4f pload_real_row(float* src) {
  if (GEMV_IS_SCALAR) {
    return pload_real_full(src);
    return ploadu<Packet4f>(src);

template <typename ResPacket>
EIGEN_ALWAYS_INLINE Packet2d pload_real_row(double* src) {
  return pload_real(src);
EIGEN_ALWAYS_INLINE Packet2cf padd(Packet2cf& a, std::complex<float>& b) {
  EIGEN_UNUSED_VARIABLE(b);

EIGEN_ALWAYS_INLINE Packet1cd padd(Packet1cd& a, std::complex<double>& b) {
  EIGEN_UNUSED_VARIABLE(b);

template <typename Scalar, typename ResScalar>
EIGEN_ALWAYS_INLINE Scalar pset1_realimag(ResScalar& alpha, int which, int conj) {
  return (which) ? ((conj) ? -alpha.real() : alpha.real()) : ((conj) ? -alpha.imag() : alpha.imag());

template <typename Scalar, typename ResScalar, typename ResPacket, int which>
EIGEN_ALWAYS_INLINE Packet2cf pset1_complex(std::complex<float>& alpha) {
  ret.v[COMPLEX_DELTA + 0] = pset1_realimag<Scalar, ResScalar>(alpha, (which & 0x01), (which & 0x04));
  ret.v[COMPLEX_DELTA + 1] = pset1_realimag<Scalar, ResScalar>(alpha, (which & 0x02), (which & 0x08));
  ret.v[2 - COMPLEX_DELTA] = ret.v[COMPLEX_DELTA + 0];
  ret.v[3 - COMPLEX_DELTA] = ret.v[COMPLEX_DELTA + 1];

template <typename Scalar, typename ResScalar, typename ResPacket, int which>
EIGEN_ALWAYS_INLINE Packet1cd pset1_complex(std::complex<double>& alpha) {
  ret.v[0] = pset1_realimag<Scalar, ResScalar>(alpha, (which & 0x01), (which & 0x04));
  ret.v[1] = pset1_realimag<Scalar, ResScalar>(alpha, (which & 0x02), (which & 0x08));

template <typename Packet>
EIGEN_ALWAYS_INLINE Packet pset_zero() {
  return pset1<Packet>(__UNPACK_TYPE__(Packet)(0));

EIGEN_ALWAYS_INLINE Packet2cf pset_zero<Packet2cf>() {
  return Packet2cf(pset1<Packet4f>(float(0)));

EIGEN_ALWAYS_INLINE Packet1cd pset_zero<Packet1cd>() {
  return Packet1cd(pset1<Packet2d>(double(0)));

template <typename Packet, typename LhsPacket, typename RhsPacket>
EIGEN_ALWAYS_INLINE Packet pset_init(Packet& c1) {
  if (GEMV_IS_COMPLEX_COMPLEX) {
    EIGEN_UNUSED_VARIABLE(c1);
    return pset_zero<Packet>();

template <typename PResPacket, typename ResPacket, typename ResScalar, typename Scalar>
  alpha_store(ResScalar& alpha) {
    separate.r = pset1_complex<Scalar, ResScalar, ResPacket, 0x3>(alpha);
    separate.i = pset1_complex<Scalar, ResScalar, ResPacket, 0x0>(alpha);

template <typename ScalarPacket, typename AlphaData>
EIGEN_ALWAYS_INLINE ScalarPacket pmadd_complex(ScalarPacket& c0, ScalarPacket& c2, ScalarPacket& c4, AlphaData& b0) {
  return pmadd(c2, b0.separate.i.v, pmadd(c0, b0.separate.r.v, c4));
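// Combine the accumulated real/imaginary parts with alpha and add them into the result buffer,
// one packet (or one __vector_pair on Power10) at a time.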
template <typename Scalar, typename ScalarPacket, typename PResPacket, typename ResPacket, typename ResScalar,
EIGEN_ALWAYS_INLINE void pstoreu_pmadd_complex(PResPacket& c0, AlphaData& b0, ResScalar* res) {
  PResPacket c2 = pcplxflipconj(c0);
  if (GEMV_IS_SCALAR) {
    ScalarPacket c4 = ploadu<ScalarPacket>(reinterpret_cast<Scalar*>(res));
    ScalarPacket c3 = pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c4, b0);
    pstoreu(reinterpret_cast<Scalar*>(res), c3);
    ScalarPacket c4 = pload_complex<ResPacket>(res);
    PResPacket c3 = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c4, b0));

template <typename ScalarPacket, typename PResPacket, typename ResPacket, typename ResScalar, typename AlphaData,
          Index ResPacketSize, Index iter2>
EIGEN_ALWAYS_INLINE void pstoreu_pmadd_complex(PResPacket& c0, PResPacket& c1, AlphaData& b0, ResScalar* res) {
  PResPacket c2 = pcplxflipconj(c0);
  PResPacket c3 = pcplxflipconj(c1);
#if !defined(_ARCH_PWR10)
  ScalarPacket c4 = pload_complex<ResPacket>(res + (iter2 * ResPacketSize));
  ScalarPacket c5 = pload_complex<ResPacket>(res + ((iter2 + 1) * ResPacketSize));
  PResPacket c6 = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c4, b0));
  PResPacket c7 = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c1.v, c3.v, c5, b0));
  pstoreu(res + (iter2 * ResPacketSize), c6);
  pstoreu(res + ((iter2 + 1) * ResPacketSize), c7);
  __vector_pair a = *reinterpret_cast<__vector_pair*>(res + (iter2 * ResPacketSize));
  __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(c6), &a);
  c6[0] = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c6[0].v, b0));
  c6[1] = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c1.v, c3.v, c6[1].v, b0));
  GEMV_BUILDPAIR_MMA(a, c6[0].v, c6[1].v);
  if (GEMV_IS_COMPLEX_FLOAT) {
    __asm__("xvmaddasp %L0,%x1,%x2\n\txvmaddasp %0,%x1,%x3" : "+&d"(a) : "wa"(b0.separate.r.v), "wa"(c0.v), "wa"(c1.v));
    __asm__("xvmaddasp %L0,%x1,%x2\n\txvmaddasp %0,%x1,%x3" : "+&d"(a) : "wa"(b0.separate.i.v), "wa"(c2.v), "wa"(c3.v));
    __asm__("xvmaddadp %L0,%x1,%x2\n\txvmaddadp %0,%x1,%x3" : "+&d"(a) : "wa"(b0.separate.r.v), "wa"(c0.v), "wa"(c1.v));
    __asm__("xvmaddadp %L0,%x1,%x2\n\txvmaddadp %0,%x1,%x3" : "+&d"(a) : "wa"(b0.separate.i.v), "wa"(c2.v), "wa"(c3.v));
  *reinterpret_cast<__vector_pair*>(res + (iter2 * ResPacketSize)) = a;
template <typename Scalar, typename LhsScalar, typename LhsMapper, typename LhsPacket>
EIGEN_ALWAYS_INLINE LhsPacket loadLhsPacket(LhsMapper& lhs, Index i, Index j) {
  if (sizeof(Scalar) == sizeof(LhsScalar)) {
    const LhsScalar& src = lhs(i + 0, j);
    return LhsPacket(pload_real_full(const_cast<LhsScalar*>(&src)));
  return lhs.template load<LhsPacket, Unaligned>(i + 0, j);

template <typename ComplexPacket, typename RealPacket, bool ConjugateLhs, bool ConjugateRhs, bool Negate>
EIGEN_ALWAYS_INLINE RealPacket pmadd_complex_complex(RealPacket& a, RealPacket& b, RealPacket& c) {
  if (ConjugateLhs && ConjugateRhs) {
    return vec_madd(a, pconj2(ComplexPacket(b)).v, c);
  } else if (Negate && !ConjugateLhs && ConjugateRhs) {
    return vec_nmsub(a, b, c);
    return vec_madd(a, b, c);

template <typename ComplexPacket, typename RealPacket, bool Conjugate>
EIGEN_ALWAYS_INLINE RealPacket pmadd_complex_real(RealPacket& a, RealPacket& b, RealPacket& c) {
    return vec_madd(a, pconj2(ComplexPacket(b)).v, c);
    return vec_madd(a, b, c);

template <typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, bool ConjugateLhs,
          bool ConjugateRhs, int StorageOrder>
EIGEN_ALWAYS_INLINE void gemv_mult_generic(LhsPacket& a0, RhsScalar* b, PResPacket& c0) {
  conj_helper<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs> pcj;
  if (StorageOrder == ColMajor) {
    b0 = pset1<RhsPacket>(*b);
    b0 = ploadu<RhsPacket>(b);
  c0 = pcj.pmadd(a0, b0, c0);

template <typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket,
          typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
EIGEN_ALWAYS_INLINE void gemv_mult_complex_complex(LhsPacket& a0, RhsScalar* b, PResPacket& c0, ResPacket& c1) {
  ScalarPacket br, bi;
  if (StorageOrder == ColMajor) {
    pload_realimag<RhsScalar>(b, br, bi);
    pload_realimag_row<RhsScalar>(b, br, bi);
  if (ConjugateLhs && !ConjugateRhs) a0 = pconj2(a0);
  LhsPacket a1 = pcplxflipconj(a0);
  ScalarPacket cr = pmadd_complex_complex<LhsPacket, ScalarPacket, ConjugateLhs, ConjugateRhs, false>(a0.v, br, c0.v);
  ScalarPacket ci = pmadd_complex_complex<LhsPacket, ScalarPacket, ConjugateLhs, ConjugateRhs, true>(a1.v, bi, c1.v);
  c0 = PResPacket(cr);

template <typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket,
          typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
EIGEN_ALWAYS_INLINE void gemv_mult_real_complex(LhsPacket& a0, RhsScalar* b, PResPacket& c0) {
  if (StorageOrder == ColMajor) {
    b0 = pload_complex_full(b);
    b0 = pload_complex_full_row(b);
  ScalarPacket cri = pmadd_complex_real<PResPacket, ScalarPacket, ConjugateRhs>(a0, b0, c0.v);
  c0 = PResPacket(cri);

template <typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket,
          typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
EIGEN_ALWAYS_INLINE void gemv_mult_complex_real(LhsPacket& a0, RhsScalar* b, PResPacket& c0) {
  ScalarPacket a1 = pload_complex<ResPacket>(&a0);
  if (StorageOrder == ColMajor) {
    b0 = pload_real_row<ResPacket>(b);
  ScalarPacket cri = pmadd_complex_real<PResPacket, ScalarPacket, ConjugateLhs>(a1, b0, c0.v);
  c0 = PResPacket(cri);
#define GEMV_MULT_COMPLEX_COMPLEX(LhsType, RhsType, ResType) \
  template <typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, \
            typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder> \
  EIGEN_ALWAYS_INLINE void gemv_mult_complex(LhsType& a0, RhsType* b, ResType& c0, ResType& c1) { \
    gemv_mult_complex_complex<ScalarPacket, LhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs, \
                              ConjugateRhs, StorageOrder>(a0, b, c0, c1); \

GEMV_MULT_COMPLEX_COMPLEX(Packet2cf, std::complex<float>, Packet2cf)
GEMV_MULT_COMPLEX_COMPLEX(Packet1cd, std::complex<double>, Packet1cd)

#define GEMV_MULT_REAL_COMPLEX(LhsType, RhsType, ResType) \
  template <typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, \
            typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder> \
  EIGEN_ALWAYS_INLINE void gemv_mult_complex(LhsType& a0, RhsType* b, ResType& c0, RhsType&) { \
    gemv_mult_real_complex<ScalarPacket, LhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs, \
                           ConjugateRhs, StorageOrder>(a0, b, c0); \

GEMV_MULT_REAL_COMPLEX(float, std::complex<float>, Packet2cf)
GEMV_MULT_REAL_COMPLEX(double, std::complex<double>, Packet1cd)
GEMV_MULT_REAL_COMPLEX(Packet4f, std::complex<float>, Packet2cf)
GEMV_MULT_REAL_COMPLEX(Packet2d, std::complex<double>, Packet1cd)

#define GEMV_MULT_COMPLEX_REAL(LhsType, RhsType, ResType1, ResType2) \
  template <typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, \
            typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder> \
  EIGEN_ALWAYS_INLINE void gemv_mult_complex(LhsType& a0, RhsType* b, ResType1& c0, ResType2&) { \
    gemv_mult_complex_real<ScalarPacket, LhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs, \
                           ConjugateRhs, StorageOrder>(a0, b, c0); \

GEMV_MULT_COMPLEX_REAL(Packet2cf, float, Packet2cf, std::complex<float>)
GEMV_MULT_COMPLEX_REAL(Packet1cd, double, Packet1cd, std::complex<double>)
GEMV_MULT_COMPLEX_REAL(std::complex<float>, float, Packet2cf, std::complex<float>)
GEMV_MULT_COMPLEX_REAL(std::complex<double>, double, Packet1cd, std::complex<double>)
template <typename T>
EIGEN_ALWAYS_INLINE T convertReal(T a) {

EIGEN_ALWAYS_INLINE Packet4f convertReal(Packet2cf a) { return a.v; }

EIGEN_ALWAYS_INLINE Packet2d convertReal(Packet1cd a) { return a.v; }

template <typename T>
EIGEN_ALWAYS_INLINE T convertComplex(T a) {

EIGEN_ALWAYS_INLINE Packet2cf convertComplex(Packet4f a) { return Packet2cf(a); }

EIGEN_ALWAYS_INLINE Packet1cd convertComplex(Packet2d a) { return Packet1cd(a); }

template <typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename ResPacket>
EIGEN_ALWAYS_INLINE void pload_complex_MMA(SLhsPacket& a) {
  a = SLhsPacket(pload_complex<ResPacket>(&a));

template <typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename ResPacket>
EIGEN_ALWAYS_INLINE void pload_complex_MMA(__vector_pair&) {
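// Thin wrappers around the MMA "ger" builtins, selecting the positive or negative accumulate form.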
template <typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pger_vecMMA(__vector_quad* acc, RhsPacket& a, LhsPacket& b) {
  if (NegativeAccumulate) {
    __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
    __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);

template <typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pger_vecMMA(__vector_quad* acc, __vector_pair& a, Packet2d& b) {
  if (NegativeAccumulate) {
    __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b);
    __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b);

template <typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pger_vecMMA(__vector_quad*, __vector_pair&, Packet4f&) {

template <typename RealPacket, typename LhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool Negate>
EIGEN_ALWAYS_INLINE void pmadd_complex_complex_MMA(LhsPacket& a, RealPacket& b, __vector_quad* c) {
  if (ConjugateLhs && ConjugateRhs) {
    RealPacket b2 = pconj2(convertComplex(b)).v;
    return pger_vecMMA<RealPacket, RealPacket, false>(c, b2, a.v);
  } else if (Negate && !ConjugateLhs && ConjugateRhs) {
    return pger_vecMMA<RealPacket, RealPacket, true>(c, b, a.v);
    return pger_vecMMA<RealPacket, RealPacket, false>(c, b, a.v);

template <typename RealPacket, typename LhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool Negate>
EIGEN_ALWAYS_INLINE void pmadd_complex_complex_MMA(__vector_pair& a, RealPacket& b, __vector_quad* c) {
  if (ConjugateLhs && ConjugateRhs) {
    RealPacket b2 = pconj2(convertComplex(b)).v;
    return pger_vecMMA<RealPacket, __vector_pair, false>(c, a, b2);
  } else if (Negate && !ConjugateLhs && ConjugateRhs) {
    return pger_vecMMA<RealPacket, __vector_pair, true>(c, a, b);
    return pger_vecMMA<RealPacket, __vector_pair, false>(c, a, b);

template <typename RealPacket, typename LhsPacket, bool Conjugate, int StorageOrder>
EIGEN_ALWAYS_INLINE void pmadd_complex_real_MMA(LhsPacket& a, RealPacket& b, __vector_quad* c) {
  RealPacket a2 = convertReal(a);
    RealPacket b2 = pconj2(convertComplex(b)).v;
    if (StorageOrder == ColMajor) {
      return pger_vecMMA<RealPacket, RealPacket, false>(c, b2, a2);
      return pger_vecMMA<RealPacket, RealPacket, false>(c, a2, b2);
    if (StorageOrder == ColMajor) {
      return pger_vecMMA<RealPacket, RealPacket, false>(c, b, a2);
      return pger_vecMMA<RealPacket, RealPacket, false>(c, a2, b);

template <typename RealPacket, typename LhsPacket, bool Conjugate, int StorageOrder>
EIGEN_ALWAYS_INLINE void pmadd_complex_real_MMA(__vector_pair& a, RealPacket& b, __vector_quad* c) {
    RealPacket b2 = pconj2(convertComplex(b)).v;
    return pger_vecMMA<RealPacket, __vector_pair, false>(c, a, b2);
    return pger_vecMMA<RealPacket, __vector_pair, false>(c, a, b);
1649template <
typename ScalarPacket,
typename LhsPacket,
typename SLhsPacket,
typename RhsScalar,
typename ResPacket,
1650 bool ConjugateLhs,
bool ConjugateRhs,
int StorageOrder>
1651EIGEN_ALWAYS_INLINE
void gemv_mult_complex_complex_MMA(SLhsPacket& a0, RhsScalar* b, __vector_quad* c0) {
1653 if (StorageOrder == ColMajor) {
1654 b0 = pload_realimag_combine(b);
1656 b0 = pload_realimag_combine_row(b);
1658 pmadd_complex_complex_MMA<ScalarPacket, LhsPacket, ConjugateLhs, ConjugateRhs, false>(a0, b0, c0);
1662template <
typename ScalarPacket,
typename LhsPacket,
typename SLhsPacket,
typename RhsScalar,
typename ResPacket,
1663 bool ConjugateLhs,
bool ConjugateRhs,
int StorageOrder>
1664EIGEN_ALWAYS_INLINE
void gemv_mult_complex_real_MMA(SLhsPacket& a0, RhsScalar* b, __vector_quad* c0) {
1665 pload_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, ResPacket>(a0);
1667 if (StorageOrder == ColMajor) {
1670 b0 = pload_real_row<ResPacket>(b);
1672 pmadd_complex_real_MMA<ScalarPacket, LhsPacket, ConjugateLhs, ColMajor>(a0, b0, c0);
1676template <
typename ScalarPacket,
typename LhsPacket,
typename SLhsPacket,
typename RhsScalar,
typename ResPacket,
1677 bool ConjugateLhs,
bool ConjugateRhs,
int StorageOrder>
1678EIGEN_ALWAYS_INLINE
void gemv_mult_real_complex_MMA(SLhsPacket& a0, RhsScalar* b, __vector_quad* c0) {
1680 if (StorageOrder == ColMajor) {
1681 b0 = pload_complex_full(b);
1683 b0 = pload_complex_full_row(b);
1685 pmadd_complex_real_MMA<ScalarPacket, LhsPacket, ConjugateRhs,
1686 (
sizeof(RhsScalar) ==
sizeof(std::complex<float>)) ? StorageOrder :
ColMajor>(a0, b0, c0);
1689#define GEMV_MULT_COMPLEX_COMPLEX_MMA(LhsType, RhsType) \
1690 template <typename ScalarPacket, typename LhsScalar, typename LhsPacket, typename SLhsPacket, typename RhsScalar, \
1691 typename RhsPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder> \
1692 EIGEN_ALWAYS_INLINE void gemv_mult_complex_MMA(LhsType& a0, RhsType* b, __vector_quad* c0) { \
1693 gemv_mult_complex_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs, \
1694 ConjugateRhs, StorageOrder>(a0, b, c0); \
1697GEMV_MULT_COMPLEX_COMPLEX_MMA(Packet2cf, std::complex<float>)
1698GEMV_MULT_COMPLEX_COMPLEX_MMA(__vector_pair, std::complex<float>)
1699GEMV_MULT_COMPLEX_COMPLEX_MMA(Packet1cd, std::complex<double>)
1702template <
typename ScalarPacket,
typename LhsScalar,
typename LhsPacket,
typename SLhsPacket,
typename RhsScalar,
1703 typename RhsPacket,
typename ResPacket,
bool ConjugateLhs,
bool ConjugateRhs,
int StorageOrder>
1704EIGEN_ALWAYS_INLINE
void gemv_mult_complex_MMA(__vector_pair& a0, std::complex<double>* b, __vector_quad* c0) {
1705 if (
sizeof(LhsScalar) == 16) {
1706 gemv_mult_complex_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs, ConjugateRhs,
1707 StorageOrder>(a0, b, c0);
1709 gemv_mult_real_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs, ConjugateRhs,
1710 StorageOrder>(a0, b, c0);
#define GEMV_MULT_REAL_COMPLEX_MMA(LhsType, RhsType) \
  template <typename ScalarPacket, typename LhsScalar, typename LhsPacket, typename SLhsPacket, typename RhsScalar, \
            typename RhsPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder> \
  EIGEN_ALWAYS_INLINE void gemv_mult_complex_MMA(LhsType& a0, RhsType* b, __vector_quad* c0) { \
    gemv_mult_real_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs, ConjugateRhs, \
                               StorageOrder>(a0, b, c0); \
  }

GEMV_MULT_REAL_COMPLEX_MMA(Packet4f, std::complex<float>)
GEMV_MULT_REAL_COMPLEX_MMA(Packet2d, std::complex<double>)

#define GEMV_MULT_COMPLEX_REAL_MMA(LhsType, RhsType) \
  template <typename ScalarPacket, typename LhsScalar, typename LhsPacket, typename SLhsPacket, typename RhsScalar, \
            typename RhsPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder> \
  EIGEN_ALWAYS_INLINE void gemv_mult_complex_MMA(LhsType& a0, RhsType* b, __vector_quad* c0) { \
    gemv_mult_complex_real_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs, ConjugateRhs, \
                               StorageOrder>(a0, b, c0); \
  }

GEMV_MULT_COMPLEX_REAL_MMA(Packet2cf, float)
GEMV_MULT_COMPLEX_REAL_MMA(Packet1cd, double)
GEMV_MULT_COMPLEX_REAL_MMA(__vector_pair, float)
GEMV_MULT_COMPLEX_REAL_MMA(__vector_pair, double)
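/** Unpack an MMA accumulator into a PacketBlock for the double-precision (non-float) result path, merging the partial
 * products and applying the conjugation fix-ups selected by ConjugateLhs/ConjugateRhs. */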
template <typename Scalar, typename ScalarPacket, typename LhsPacket, typename RhsPacket, bool ConjugateLhs,
          bool ConjugateRhs>
EIGEN_ALWAYS_INLINE void disassembleResults2(__vector_quad* c0, PacketBlock<ScalarPacket, 4>& result0) {
  __builtin_mma_disassemble_acc(&result0.packet, c0);
  if (sizeof(LhsPacket) == 16) {
    if (sizeof(RhsPacket) == 16) {
      ScalarPacket tmp0, tmp2;
      tmp2 = vec_mergeh(result0.packet[2], result0.packet[3]);
      tmp0 = vec_mergeh(result0.packet[0], result0.packet[1]);
      result0.packet[3] = vec_mergel(result0.packet[3], result0.packet[2]);
      result0.packet[1] = vec_mergel(result0.packet[1], result0.packet[0]);
      result0.packet[2] = tmp2;
      result0.packet[0] = tmp0;
      if (ConjugateLhs) {
        result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
        result0.packet[2] = pconj2(convertComplex(result0.packet[2])).v;
      } else if (ConjugateRhs) {
        result0.packet[1] = pconj2(convertComplex(result0.packet[1])).v;
        result0.packet[3] = pconj2(convertComplex(result0.packet[3])).v;
      } else {
        result0.packet[1] = pconjinv(convertComplex(result0.packet[1])).v;
        result0.packet[3] = pconjinv(convertComplex(result0.packet[3])).v;
      }
      result0.packet[0] = vec_add(result0.packet[0], result0.packet[1]);
      result0.packet[2] = vec_add(result0.packet[2], result0.packet[3]);
    } else {
      result0.packet[0][1] = result0.packet[1][1];
      result0.packet[2][1] = result0.packet[3][1];
    }
  }
}
template <typename Scalar, typename ScalarPacket, typename LhsPacket, typename RhsPacket, bool ConjugateLhs,
          bool ConjugateRhs>
EIGEN_ALWAYS_INLINE void disassembleResults4(__vector_quad* c0, PacketBlock<ScalarPacket, 4>& result0) {
  __builtin_mma_disassemble_acc(&result0.packet, c0);
  if (GEMV_IS_COMPLEX_COMPLEX) {
    if (ConjugateLhs) {
      result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
      result0.packet[1] = pcplxflip2(convertComplex(result0.packet[1])).v;
    } else if (ConjugateRhs) {
      result0.packet[1] = pcplxconjflip(convertComplex(result0.packet[1])).v;
    } else {
      result0.packet[1] = pcplxflipconj(convertComplex(result0.packet[1])).v;
    }
    result0.packet[0] = vec_add(result0.packet[0], result0.packet[1]);
  } else if (sizeof(LhsPacket) == sizeof(std::complex<float>)) {
    if (ConjugateLhs) {
      result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
    }
  } else {
    result0.packet[0] = vec_mergee(result0.packet[0], result0.packet[1]);
  }
}
template <typename Scalar, typename ScalarPacket, int ResPacketSize, typename LhsPacket, typename RhsPacket,
          bool ConjugateLhs, bool ConjugateRhs>
EIGEN_ALWAYS_INLINE void disassembleResults(__vector_quad* c0, PacketBlock<ScalarPacket, 4>& result0) {
  if (!GEMV_IS_COMPLEX_FLOAT) {
    disassembleResults2<Scalar, ScalarPacket, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(c0, result0);
  } else {
    disassembleResults4<Scalar, ScalarPacket, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(c0, result0);
  }
}
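// Macros implementing the column-major complex GEMV kernel (MMA and plain VSX variants).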
#define GEMV_GETN_COMPLEX(N) (((N) * ResPacketSize) >> 1)

#define GEMV_LOADPACKET_COL_COMPLEX(iter) \
  loadLhsPacket<Scalar, LhsScalar, LhsMapper, PLhsPacket>(lhs, i + ((iter) * ResPacketSize), j)

#define GEMV_LOADPACKET_COL_COMPLEX_DATA(iter) convertReal(GEMV_LOADPACKET_COL_COMPLEX(iter))
#ifdef USE_GEMV_MMA
#define GEMV_INIT_COL_COMPLEX_MMA(iter, N) \
  if (GEMV_GETN_COMPLEX(N) > iter) { \
    __builtin_mma_xxsetaccz(&e0##iter); \
  }
#if EIGEN_COMP_LLVM
#define GEMV_LOADPAIR_COL_COMPLEX_MMA(iter1, iter2) \
  GEMV_BUILDPAIR_MMA(a##iter1, GEMV_LOADPACKET_COL_COMPLEX_DATA(iter2), \
                     GEMV_LOADPACKET_COL_COMPLEX_DATA((iter2) + 1)); \
  EIGEN_UNUSED_VARIABLE(f##iter1);
#else
#define GEMV_LOADPAIR_COL_COMPLEX_MMA(iter1, iter2) \
  if (sizeof(LhsPacket) == 16) { \
    const LhsScalar& src = lhs(i + ((32 * iter1) / sizeof(LhsScalar)), j); \
    a##iter1 = *reinterpret_cast<__vector_pair*>(const_cast<LhsScalar*>(&src)); \
    EIGEN_UNUSED_VARIABLE(f##iter1); \
  } else { \
    f##iter1 = lhs.template load<PLhsPacket, Unaligned>(i + ((iter2) * ResPacketSize), j); \
    GEMV_BUILDPAIR_MMA(a##iter1, vec_splat(convertReal(f##iter1), 0), vec_splat(convertReal(f##iter1), 1)); \
  }
#endif
#define GEMV_LOAD1_COL_COMPLEX_MMA(iter, N) \
  if (GEMV_GETN_COMPLEX(N) > iter) { \
    if (GEMV_IS_COMPLEX_FLOAT) { \
      f##iter = GEMV_LOADPACKET_COL_COMPLEX(iter); \
      EIGEN_UNUSED_VARIABLE(a##iter); \
    } else { \
      GEMV_LOADPAIR_COL_COMPLEX_MMA(iter, iter << 1) \
    } \
  } else { \
    EIGEN_UNUSED_VARIABLE(a##iter); \
    EIGEN_UNUSED_VARIABLE(f##iter); \
  }

#define GEMV_WORK1_COL_COMPLEX_MMA(iter, N) \
  if (GEMV_GETN_COMPLEX(N) > iter) { \
    if (GEMV_IS_COMPLEX_FLOAT) { \
      gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, PLhsPacket, RhsScalar, RhsPacket, ResPacket, \
                            ConjugateLhs, ConjugateRhs, ColMajor>(f##iter, b, &e0##iter); \
    } else { \
      gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, __vector_pair, RhsScalar, RhsPacket, ResPacket, \
                            ConjugateLhs, ConjugateRhs, ColMajor>(a##iter, b, &e0##iter); \
    } \
  }
#define GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter1, iter2) \
  GEMV_BUILDPAIR_MMA(a##iter1, GEMV_LOADPACKET_COL_COMPLEX_DATA(iter2), GEMV_LOADPACKET_COL_COMPLEX_DATA((iter2) + 1));

#define GEMV_LOAD2_COL_COMPLEX_MMA(iter1, iter2, iter3, N) \
  if (GEMV_GETN_COMPLEX(N) > iter1) { \
    if (GEMV_IS_COMPLEX_FLOAT) { \
      GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter2, iter2); \
      EIGEN_UNUSED_VARIABLE(a##iter3) \
    } else { \
      GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter2, iter2 << 1); \
      GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter3, iter3 << 1); \
    } \
  } else { \
    EIGEN_UNUSED_VARIABLE(a##iter2); \
    EIGEN_UNUSED_VARIABLE(a##iter3); \
  } \
  EIGEN_UNUSED_VARIABLE(f##iter2); \
  EIGEN_UNUSED_VARIABLE(f##iter3);
#define GEMV_WORK2_COL_COMPLEX_MMA(iter1, iter2, iter3, N) \
  if (GEMV_GETN_COMPLEX(N) > iter1) { \
    if (GEMV_IS_COMPLEX_FLOAT) { \
      PLhsPacket g[2]; \
      __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(g), &a##iter2); \
      gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, PLhsPacket, RhsScalar, RhsPacket, ResPacket, \
                            ConjugateLhs, ConjugateRhs, ColMajor>(g[0], b, &e0##iter2); \
      gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, PLhsPacket, RhsScalar, RhsPacket, ResPacket, \
                            ConjugateLhs, ConjugateRhs, ColMajor>(g[1], b, &e0##iter3); \
    } else { \
      gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, __vector_pair, RhsScalar, RhsPacket, ResPacket, \
                            ConjugateLhs, ConjugateRhs, ColMajor>(a##iter2, b, &e0##iter2); \
      gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, __vector_pair, RhsScalar, RhsPacket, ResPacket, \
                            ConjugateLhs, ConjugateRhs, ColMajor>(a##iter3, b, &e0##iter3); \
    } \
  }
#ifndef GCC_ONE_VECTORPAIR_BUG
#define GEMV_LOAD_COL_COMPLEX_MMA(N) \
  if (GEMV_GETN_COMPLEX(N) > 1) { \
    GEMV_UNROLL_HALF(GEMV_LOAD2_COL_COMPLEX_MMA, (N >> 1)) \
  } else { \
    GEMV_UNROLL(GEMV_LOAD1_COL_COMPLEX_MMA, N) \
  }

#define GEMV_WORK_COL_COMPLEX_MMA(N) \
  if (GEMV_GETN_COMPLEX(N) > 1) { \
    GEMV_UNROLL_HALF(GEMV_WORK2_COL_COMPLEX_MMA, (N >> 1)) \
  } else { \
    GEMV_UNROLL(GEMV_WORK1_COL_COMPLEX_MMA, N) \
  }
#else
#define GEMV_LOAD_COL_COMPLEX_MMA(N) GEMV_UNROLL(GEMV_LOAD1_COL_COMPLEX_MMA, N)

#define GEMV_WORK_COL_COMPLEX_MMA(N) GEMV_UNROLL(GEMV_WORK1_COL_COMPLEX_MMA, N)
#endif
#define GEMV_DISASSEMBLE_COMPLEX_MMA(iter) \
  disassembleResults<Scalar, ScalarPacket, ResPacketSize, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>( \
      &e0##iter, result0##iter);

#define GEMV_STORE_COL_COMPLEX_MMA(iter, N) \
  if (GEMV_GETN_COMPLEX(N) > iter) { \
    GEMV_DISASSEMBLE_COMPLEX_MMA(iter); \
    c0##iter = PResPacket(result0##iter.packet[0]); \
    if (GEMV_IS_COMPLEX_FLOAT) { \
      pstoreu_pmadd_complex<Scalar, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData>( \
          c0##iter, alpha_data, res + i + (iter * ResPacketSize)); \
    } else { \
      pstoreu_pmadd_complex<Scalar, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData>( \
          c0##iter, alpha_data, res + i + ((iter << 1) * ResPacketSize)); \
      c0##iter = PResPacket(result0##iter.packet[2]); \
      pstoreu_pmadd_complex<Scalar, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData>( \
          c0##iter, alpha_data, res + i + (((iter << 1) + 1) * ResPacketSize)); \
    } \
  }
#define GEMV_STORE2_COL_COMPLEX_MMA(iter1, iter2, iter3, N) \
  if (GEMV_GETN_COMPLEX(N) > iter1) { \
    GEMV_DISASSEMBLE_COMPLEX_MMA(iter2); \
    GEMV_DISASSEMBLE_COMPLEX_MMA(iter3); \
    c0##iter2 = PResPacket(result0##iter2.packet[0]); \
    if (GEMV_IS_COMPLEX_FLOAT) { \
      c0##iter3 = PResPacket(result0##iter3.packet[0]); \
      pstoreu_pmadd_complex<ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData, ResPacketSize, iter2>( \
          c0##iter2, c0##iter3, alpha_data, res + i); \
    } else { \
      c0##iter3 = PResPacket(result0##iter2.packet[2]); \
      pstoreu_pmadd_complex<ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData, ResPacketSize, iter2 << 1>( \
          c0##iter2, c0##iter3, alpha_data, res + i); \
      c0##iter2 = PResPacket(result0##iter3.packet[0]); \
      c0##iter3 = PResPacket(result0##iter3.packet[2]); \
      pstoreu_pmadd_complex<ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData, ResPacketSize, iter3 << 1>( \
          c0##iter2, c0##iter3, alpha_data, res + i); \
    } \
  }
#define GEMV_PROCESS_COL_COMPLEX_ONE_MMA(N) \
  GEMV_UNROLL(GEMV_INIT_COL_COMPLEX_MMA, N) \
  Index j = j2; \
  do { \
    const RhsScalar& b1 = rhs2(j, 0); \
    RhsScalar* b = const_cast<RhsScalar*>(&b1); \
    GEMV_UNROLL(GEMV_PREFETCH, N) \
    GEMV_LOAD_COL_COMPLEX_MMA(N) \
    GEMV_WORK_COL_COMPLEX_MMA(N) \
  } while (++j < jend); \
  if (GEMV_GETN(N) <= 2) { \
    GEMV_UNROLL(GEMV_STORE_COL_COMPLEX_MMA, N) \
  } else { \
    GEMV_UNROLL_HALF(GEMV_STORE2_COL_COMPLEX_MMA, (N >> 1)) \
  } \
  i += (ResPacketSize * N);
#endif
#define GEMV_INIT_COMPLEX(iter, N) \
  if (N > iter) { \
    c0##iter = pset_zero<PResPacket>(); \
    c1##iter = pset_init<ResPacket, LhsPacket, RhsPacket>(c1##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(c0##iter); \
    EIGEN_UNUSED_VARIABLE(c1##iter); \
  }

#define GEMV_WORK_COL_COMPLEX(iter, N) \
  if (N > iter) { \
    f##iter = GEMV_LOADPACKET_COL_COMPLEX(iter); \
    gemv_mult_complex<ScalarPacket, PLhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs, \
                      ConjugateRhs, ColMajor>(f##iter, b, c0##iter, c1##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(f##iter); \
  }

#define GEMV_STORE_COL_COMPLEX(iter, N) \
  if (N > iter) { \
    if (GEMV_IS_COMPLEX_COMPLEX) { \
      c0##iter = padd(c0##iter, c1##iter); \
    } \
    pstoreu_pmadd_complex<Scalar, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData>( \
        c0##iter, alpha_data, res + i + (iter * ResPacketSize)); \
  }

#define GEMV_PROCESS_COL_COMPLEX_ONE(N) \
  GEMV_UNROLL(GEMV_INIT_COMPLEX, N) \
  Index j = j2; \
  do { \
    const RhsScalar& b1 = rhs2(j, 0); \
    RhsScalar* b = const_cast<RhsScalar*>(&b1); \
    GEMV_UNROLL(GEMV_PREFETCH, N) \
    GEMV_UNROLL(GEMV_WORK_COL_COMPLEX, N) \
  } while (++j < jend); \
  GEMV_UNROLL(GEMV_STORE_COL_COMPLEX, N) \
  i += (ResPacketSize * N);
#if defined(USE_GEMV_MMA) && (EIGEN_COMP_LLVM || defined(USE_SLOWER_GEMV_MMA))
#define USE_GEMV_COL_COMPLEX_MMA
#endif

#ifdef USE_GEMV_COL_COMPLEX_MMA
#define GEMV_PROCESS_COL_COMPLEX(N) GEMV_PROCESS_COL_COMPLEX_ONE_MMA(N)
#else
#if defined(USE_GEMV_MMA) && (__GNUC__ > 10)
#define GEMV_PROCESS_COL_COMPLEX(N) \
  if (sizeof(Scalar) != sizeof(LhsPacket)) { \
    GEMV_PROCESS_COL_COMPLEX_ONE_MMA(N) \
  } else { \
    GEMV_PROCESS_COL_COMPLEX_ONE(N) \
  }
#else
#define GEMV_PROCESS_COL_COMPLEX(N) GEMV_PROCESS_COL_COMPLEX_ONE(N)
#endif
#endif
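/** Column-major GEMV for complex scalars: the matrix is processed in blocks of columns, with 8/4/2/1 result packets
 * handled per pass and a scalar loop for the remaining rows. */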
template <typename Scalar, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, bool LhsIsReal,
          typename RhsScalar, typename RhsMapper, bool ConjugateRhs, bool RhsIsReal, typename ResScalar>
EIGEN_STRONG_INLINE void gemv_complex_col(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs,
                                          ResScalar* res, Index resIncr, ResScalar alpha) {
  typedef gemv_traits<LhsScalar, RhsScalar> Traits;

  typedef typename Traits::LhsPacket LhsPacket;
  typedef typename Traits::RhsPacket RhsPacket;
  typedef typename Traits::ResPacket ResPacket;

  typedef typename packet_traits<Scalar>::type ScalarPacket;
  typedef typename packet_traits<LhsScalar>::type PLhsPacket;
  typedef typename packet_traits<ResScalar>::type PResPacket;
  typedef gemv_traits<ResPacket, ResPacket> PTraits;

  EIGEN_UNUSED_VARIABLE(resIncr);
  eigen_internal_assert(resIncr == 1);

  // The following copy tells the compiler that lhs's attributes are not modified outside this function.
  // This helps GCC to generate proper code.
  LhsMapper lhs(alhs);
  RhsMapper rhs2(rhs);

  conj_helper<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs> cj;

  const Index lhsStride = lhs.stride();

  enum {
    ResPacketSize = PTraits::ResPacketSize,
    LhsPacketSize = PTraits::LhsPacketSize,
    RhsPacketSize = PTraits::RhsPacketSize,
  };

#ifdef EIGEN_POWER_USE_GEMV_PREFETCH
  const Index prefetch_dist = 64 * LhsPacketSize;
#endif

#ifndef GCC_ONE_VECTORPAIR_BUG
  const Index n8 = rows - 8 * ResPacketSize + 1;
  const Index n4 = rows - 4 * ResPacketSize + 1;
  const Index n2 = rows - 2 * ResPacketSize + 1;
#endif
  const Index n1 = rows - 1 * ResPacketSize + 1;

  const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(LhsScalar) < 16000 ? 16 : 8);

  typedef alpha_store<PResPacket, ResPacket, ResScalar, Scalar> AlphaData;
  AlphaData alpha_data(alpha);
  for (Index j2 = 0; j2 < cols; j2 += block_cols) {
    Index jend = numext::mini(j2 + block_cols, cols);
    Index i = 0;
    PResPacket c00, c01, c02, c03, c04, c05, c06, c07;
    ResPacket c10, c11, c12, c13, c14, c15, c16, c17;
    PLhsPacket f0, f1, f2, f3, f4, f5, f6, f7;
#ifdef USE_GEMV_MMA
    __vector_quad e00, e01, e02, e03, e04, e05, e06, e07;
    __vector_pair a0, a1, a2, a3, a4, a5, a6, a7;
    PacketBlock<ScalarPacket, 4> result00, result01, result02, result03, result04, result05, result06, result07;
    GEMV_UNUSED(8, e0)
    GEMV_UNUSED(8, result0)
    GEMV_UNUSED(8, a)
#endif
#if !defined(GCC_ONE_VECTORPAIR_BUG) && defined(USE_GEMV_COL_COMPLEX_MMA)
    if (GEMV_IS_COMPLEX_COMPLEX || !GEMV_IS_COMPLEX_FLOAT)
#endif
    {
#ifndef GCC_ONE_VECTORPAIR_BUG
      while (i < n8) {
        GEMV_PROCESS_COL_COMPLEX(8)
      }
#endif
      while (i < n4) {
        GEMV_PROCESS_COL_COMPLEX(4)
      }
      while (i < n2) {
        GEMV_PROCESS_COL_COMPLEX(2)
      }
    }
    while (i < n1) {
      GEMV_PROCESS_COL_COMPLEX_ONE(1)
    }
    for (; i < rows; ++i) {
      ResScalar d0(0);
      Index j = j2;
      do {
        d0 += cj.pmul(lhs(i, j), rhs2(j, 0));
      } while (++j < jend);
      res[i] += alpha * d0;
    }
  }
}
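// Helpers that reduce GEMV accumulators down to two scalar results, used by the row-major kernels below.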
template <typename Scalar, int N>
struct ScalarBlock {
  Scalar scalar[N];
};

#ifdef USE_GEMV_MMA
static Packet16uc p16uc_ELEMENT_3 = {0x0c, 0x0d, 0x0e, 0x0f, 0x1c, 0x1d, 0x1e, 0x1f,
                                     0x0c, 0x0d, 0x0e, 0x0f, 0x1c, 0x1d, 0x1e, 0x1f};
template <typename ResScalar, typename ResPacket>
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_real(__vector_quad* acc0, __vector_quad* acc1) {
  PacketBlock<ResPacket, 4> result0, result1;
  __builtin_mma_disassemble_acc(&result0.packet, acc0);
  __builtin_mma_disassemble_acc(&result1.packet, acc1);
  result0.packet[0] = vec_mergeh(result0.packet[0], result1.packet[0]);
  result0.packet[1] = vec_mergeo(result0.packet[1], result1.packet[1]);
  result0.packet[2] = vec_mergel(result0.packet[2], result1.packet[2]);
  result0.packet[3] = vec_perm(result0.packet[3], result1.packet[3], p16uc_ELEMENT_3);
  result0.packet[0] =
      vec_add(vec_add(result0.packet[0], result0.packet[2]), vec_add(result0.packet[1], result0.packet[3]));
  return *reinterpret_cast<ScalarBlock<ResScalar, 2>*>(&result0.packet[0]);
}
template <>
EIGEN_ALWAYS_INLINE ScalarBlock<double, 2> predux_real<double, Packet2d>(__vector_quad* acc0, __vector_quad* acc1) {
  PacketBlock<Packet2d, 4> result0, result1;
  __builtin_mma_disassemble_acc(&result0.packet, acc0);
  __builtin_mma_disassemble_acc(&result1.packet, acc1);
  result0.packet[0] =
      vec_add(vec_mergeh(result0.packet[0], result1.packet[0]), vec_mergel(result0.packet[1], result1.packet[1]));
  return *reinterpret_cast<ScalarBlock<double, 2>*>(&result0.packet[0]);
}
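/** Combine two PacketBlocks of partial complex products into two std::complex<float> results, applying the
 * conjugations requested by ConjugateLhs/ConjugateRhs. */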
template <typename LhsPacket, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs>
EIGEN_ALWAYS_INLINE ScalarBlock<std::complex<float>, 2> addComplexResults(PacketBlock<Packet4f, 4>& result0,
                                                                          PacketBlock<Packet4f, 4>& result1) {
  ScalarBlock<std::complex<float>, 2> cc0;
  result0.packet[0] = reinterpret_cast<Packet4f>(
      vec_mergeh(reinterpret_cast<Packet2d>(result0.packet[0]), reinterpret_cast<Packet2d>(result1.packet[0])));
  result0.packet[2] = reinterpret_cast<Packet4f>(
      vec_mergel(reinterpret_cast<Packet2d>(result0.packet[2]), reinterpret_cast<Packet2d>(result1.packet[2])));
  result0.packet[0] = vec_add(result0.packet[0], result0.packet[2]);
  if (GEMV_IS_COMPLEX_COMPLEX) {
    result0.packet[1] = reinterpret_cast<Packet4f>(
        vec_mergeh(reinterpret_cast<Packet2d>(result0.packet[1]), reinterpret_cast<Packet2d>(result1.packet[1])));
    result0.packet[3] = reinterpret_cast<Packet4f>(
        vec_mergel(reinterpret_cast<Packet2d>(result0.packet[3]), reinterpret_cast<Packet2d>(result1.packet[3])));
    result0.packet[1] = vec_add(result0.packet[1], result0.packet[3]);
    if (ConjugateLhs) {
      result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
      result0.packet[1] = pcplxflip2(convertComplex(result0.packet[1])).v;
    } else if (ConjugateRhs) {
      result0.packet[1] = pcplxconjflip(convertComplex(result0.packet[1])).v;
    } else {
      result0.packet[1] = pcplxflipconj(convertComplex(result0.packet[1])).v;
    }
    result0.packet[0] = vec_add(result0.packet[0], result0.packet[1]);
  }
  if (ConjugateLhs && (sizeof(LhsPacket) == sizeof(std::complex<float>))) {
    result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
  }
  cc0.scalar[0].real(result0.packet[0][0]);
  cc0.scalar[0].imag(result0.packet[0][1]);
  cc0.scalar[1].real(result0.packet[0][2]);
  cc0.scalar[1].imag(result0.packet[0][3]);
  return cc0;
}
template <typename LhsPacket, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs>
EIGEN_ALWAYS_INLINE ScalarBlock<std::complex<double>, 2> addComplexResults(PacketBlock<Packet2d, 4>&,
                                                                           PacketBlock<Packet2d, 4>&) {
  ScalarBlock<std::complex<double>, 2> cc0;
  EIGEN_UNUSED_VARIABLE(cc0);
  return cc0;
}
template <typename ResScalar, typename ResPacket, typename LhsPacket, typename RhsPacket, bool ConjugateLhs,
          bool ConjugateRhs>
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(__vector_quad* acc0, __vector_quad* acc1) {
  PacketBlock<ResPacket, 4> result0, result1;
  __builtin_mma_disassemble_acc(&result0.packet, acc0);
  __builtin_mma_disassemble_acc(&result1.packet, acc1);
  return addComplexResults<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(result0, result1);
}
template <typename ResScalar, typename ResPacket>
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_real(__vector_quad* acc0) {
  PacketBlock<ResPacket, 4> result0;
  __builtin_mma_disassemble_acc(&result0.packet, acc0);
  result0.packet[0] =
      vec_add(vec_mergeh(result0.packet[0], result0.packet[2]), vec_mergel(result0.packet[1], result0.packet[3]));
  return *reinterpret_cast<ScalarBlock<ResScalar, 2>*>(&result0.packet[0]);
}
template <typename ResScalar, typename ResPacket, typename LhsPacket, typename RhsPacket, bool ConjugateLhs,
          bool ConjugateRhs>
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(__vector_quad* acc0) {
  ScalarBlock<ResScalar, 2> cc0;
  PacketBlock<ResPacket, 4> result0;
  __builtin_mma_disassemble_acc(&result0.packet, acc0);
  if (GEMV_IS_COMPLEX_COMPLEX) {
    if (ConjugateLhs) {
      result0.packet[1] = pconjinv(convertComplex(result0.packet[1])).v;
      result0.packet[3] = pconjinv(convertComplex(result0.packet[3])).v;
    } else if (ConjugateRhs) {
      result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
      result0.packet[2] = pconj2(convertComplex(result0.packet[2])).v;
    } else {
      result0.packet[1] = pconj2(convertComplex(result0.packet[1])).v;
      result0.packet[3] = pconj2(convertComplex(result0.packet[3])).v;
    }
    result0.packet[0] = vec_add(result0.packet[0], __builtin_vsx_xxpermdi(result0.packet[1], result0.packet[1], 2));
    result0.packet[2] = vec_add(result0.packet[2], __builtin_vsx_xxpermdi(result0.packet[3], result0.packet[3], 2));
  } else {
    result0.packet[0] = __builtin_vsx_xxpermdi(result0.packet[0], result0.packet[1], 1);
    result0.packet[2] = __builtin_vsx_xxpermdi(result0.packet[2], result0.packet[3], 1);
  }
  cc0.scalar[0].real(result0.packet[0][0]);
  cc0.scalar[0].imag(result0.packet[0][1]);
  cc0.scalar[1].real(result0.packet[2][0]);
  cc0.scalar[1].imag(result0.packet[2][1]);
  return cc0;
}
#endif

template <typename ResScalar, typename ResPacket>
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_real(ResPacket& a, ResPacket& b) {
  ScalarBlock<ResScalar, 2> cc0;
  cc0.scalar[0] = predux(a);
  cc0.scalar[1] = predux(b);
  return cc0;
}

template <typename ResScalar, typename ResPacket>
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(ResPacket& a, ResPacket& b) {
  return predux_real<ResScalar, ResPacket>(a, b);
}
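// Macros implementing the row-major real GEMV kernel.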
#define GEMV_UNROLL_ROW(func, N) func(0, N) func(1, N) func(2, N) func(3, N) func(4, N) func(5, N) func(6, N) func(7, N)

#define GEMV_UNROLL_ROW_HALF(func, N) func(0, 0, 1, N) func(1, 2, 3, N) func(2, 4, 5, N) func(3, 6, 7, N)

#define GEMV_LOADPACKET_ROW(iter) lhs.template load<LhsPacket, Unaligned>(i + (iter), j)

#define GEMV_UNROLL3_ROW(func, N, which) \
  func(0, N, which) func(1, N, which) func(2, N, which) func(3, N, which) func(4, N, which) func(5, N, which) \
  func(6, N, which) func(7, N, which)

#define GEMV_UNUSED_ROW(N, which) GEMV_UNROLL3_ROW(GEMV_UNUSED_VAR, N, which)
#ifdef USE_GEMV_MMA
#define GEMV_INIT_ROW(iter, N) \
  if (GEMV_GETN(N) > iter) { \
    __builtin_mma_xxsetaccz(&c##iter); \
  }

#define GEMV_LOADPAIR_ROW(iter1, iter2) \
  GEMV_BUILDPAIR_MMA(b##iter1, GEMV_LOADPACKET_ROW(iter2), GEMV_LOADPACKET_ROW((iter2) + 1));

#define GEMV_WORK_ROW(iter, N) \
  if (GEMV_GETN(N) > iter) { \
    if (GEMV_IS_FLOAT) { \
      pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&c##iter, a0, GEMV_LOADPACKET_ROW(iter)); \
    } else { \
      __vector_pair b##iter; \
      GEMV_LOADPAIR_ROW(iter, iter << 1) \
      pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&c##iter, b##iter, a0); \
    } \
  }

#define GEMV_PREDUX2(iter1, iter2, iter3, N) \
  if (N > iter1) { \
    if (GEMV_IS_FLOAT) { \
      cc##iter1 = predux_real<ResScalar, ResPacket>(&c##iter2, &c##iter3); \
    } else { \
      cc##iter1 = predux_real<ResScalar, ResPacket>(&c##iter1); \
    } \
  } else { \
    EIGEN_UNUSED_VARIABLE(cc##iter1); \
  }
#else
#define GEMV_INIT_ROW(iter, N) \
  if (N > iter) { \
    c##iter = pset1<ResPacket>(ResScalar(0)); \
  } else { \
    EIGEN_UNUSED_VARIABLE(c##iter); \
  }

#define GEMV_WORK_ROW(iter, N) \
  if (N > iter) { \
    c##iter = pcj.pmadd(GEMV_LOADPACKET_ROW(iter), a0, c##iter); \
  }

#define GEMV_PREDUX2(iter1, iter2, iter3, N) \
  if (N > iter1) { \
    cc##iter1 = predux_real<ResScalar, ResPacket>(c##iter2, c##iter3); \
  } else { \
    EIGEN_UNUSED_VARIABLE(cc##iter1); \
  }
#endif

#define GEMV_MULT(iter1, iter2, iter3, N) \
  if (N > iter1) { \
    cc##iter1.scalar[0] += cj.pmul(lhs(i + iter2, j), a0); \
    cc##iter1.scalar[1] += cj.pmul(lhs(i + iter3, j), a0); \
  }

#define GEMV_STORE_ROW(iter1, iter2, iter3, N) \
  if (N > iter1) { \
    storeMaddData<ResScalar>(res + ((i + iter2) * resIncr), alpha, cc##iter1.scalar[0]); \
    storeMaddData<ResScalar>(res + ((i + iter3) * resIncr), alpha, cc##iter1.scalar[1]); \
  }
#define GEMV_PROCESS_ROW(N) \
  for (; i < n##N; i += N) { \
    GEMV_UNROLL_ROW(GEMV_INIT_ROW, N) \
    Index j = 0; \
    for (; j + LhsPacketSize <= cols; j += LhsPacketSize) { \
      RhsPacket a0 = rhs2.template load<RhsPacket, Unaligned>(j); \
      GEMV_UNROLL_ROW(GEMV_WORK_ROW, N) \
    } \
    GEMV_UNROLL_ROW_HALF(GEMV_PREDUX2, (N >> 1)) \
    for (; j < cols; ++j) { \
      RhsScalar a0 = rhs2(j); \
      GEMV_UNROLL_ROW_HALF(GEMV_MULT, (N >> 1)) \
    } \
    GEMV_UNROLL_ROW_HALF(GEMV_STORE_ROW, (N >> 1)) \
  }
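/** Row-major GEMV for real scalars: 8/4/2 rows are processed per pass and reduced with predux_real, then a scalar
 * loop handles the remaining rows and columns. */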
template <typename LhsScalar, typename LhsMapper, typename RhsScalar, typename RhsMapper, typename ResScalar>
EIGEN_STRONG_INLINE void gemv_row(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs, ResScalar* res,
                                  Index resIncr, ResScalar alpha) {
  typedef gemv_traits<LhsScalar, RhsScalar> Traits;

  typedef typename Traits::LhsPacket LhsPacket;
  typedef typename Traits::RhsPacket RhsPacket;
  typedef typename Traits::ResPacket ResPacket;

  // The following copy tells the compiler that lhs's attributes are not modified outside this function.
  // This helps GCC to generate proper code.
  LhsMapper lhs(alhs);
  typename RhsMapper::LinearMapper rhs2 = rhs.getLinearMapper(0, 0);

  eigen_internal_assert(rhs.stride() == 1);
  conj_helper<LhsScalar, RhsScalar, false, false> cj;
  conj_helper<LhsPacket, RhsPacket, false, false> pcj;
#ifndef GCC_ONE_VECTORPAIR_BUG
  const Index n8 = lhs.stride() * sizeof(LhsScalar) > 32000 ? (rows - 7) : (rows - 7);
  const Index n4 = rows - 3;
  const Index n2 = rows - 1;
#endif

  enum {
    LhsAlignment = Unaligned,
    ResPacketSize = Traits::ResPacketSize,
    LhsPacketSize = Traits::LhsPacketSize,
    RhsPacketSize = Traits::RhsPacketSize,
  };

  Index i = 0;
#ifdef USE_GEMV_MMA
  __vector_quad c0, c1, c2, c3, c4, c5, c6, c7;
  GEMV_UNUSED_ROW(8, c)
#else
  ResPacket c0, c1, c2, c3, c4, c5, c6, c7;
#endif
#ifndef GCC_ONE_VECTORPAIR_BUG
  ScalarBlock<ResScalar, 2> cc0, cc1, cc2, cc3;
  GEMV_PROCESS_ROW(8)
  GEMV_PROCESS_ROW(4)
  GEMV_PROCESS_ROW(2)
#endif
  for (; i < rows; ++i) {
    ResPacket d0 = pset1<ResPacket>(ResScalar(0));
    Index j = 0;
    for (; j + LhsPacketSize <= cols; j += LhsPacketSize) {
      RhsPacket b0 = rhs2.template load<RhsPacket, Unaligned>(j);

      d0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 0, j), b0, d0);
    }
    ResScalar dd0 = predux(d0);
    for (; j < cols; ++j) {
      dd0 += cj.pmul(lhs(i, j), rhs2(j));
    }
    res[i * resIncr] += alpha * dd0;
  }
}
#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(Scalar) \
  template <typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
  struct general_matrix_vector_product<Index, Scalar, LhsMapper, ColMajor, ConjugateLhs, Scalar, RhsMapper, \
                                       ConjugateRhs, Version> { \
    typedef typename ScalarBinaryOpTraits<Scalar, Scalar>::ReturnType ResScalar; \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs, \
                                                        const RhsMapper& rhs, ResScalar* res, Index resIncr, \
                                                        ResScalar alpha) { \
      gemv_col<Scalar, LhsMapper, Scalar, RhsMapper, ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
    } \
  };

#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(Scalar) \
  template <typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
  struct general_matrix_vector_product<Index, Scalar, LhsMapper, RowMajor, ConjugateLhs, Scalar, RhsMapper, \
                                       ConjugateRhs, Version> { \
    typedef typename ScalarBinaryOpTraits<Scalar, Scalar>::ReturnType ResScalar; \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs, \
                                                        const RhsMapper& rhs, ResScalar* res, Index resIncr, \
                                                        ResScalar alpha) { \
      gemv_row<Scalar, LhsMapper, Scalar, RhsMapper, ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
    } \
  };
EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(float)
EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(double)
EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(float)
EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(double)
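// bfloat16 GEMV forwards to the bfloat16 kernels (MMA or VSX) defined elsewhere in this backend.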
#ifdef USE_GEMV_MMA
#define gemv_bf16_col gemvMMA_bfloat16_col
#define gemv_bf16_row gemvMMA_bfloat16_row
#else
#define gemv_bf16_col gemv_bfloat16_col
#define gemv_bf16_row gemv_bfloat16_row
#endif
#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL_BFLOAT16() \
  template <typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
  struct general_matrix_vector_product<Index, bfloat16, LhsMapper, ColMajor, ConjugateLhs, bfloat16, RhsMapper, \
                                       ConjugateRhs, Version> { \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs, \
                                                        const RhsMapper& rhs, bfloat16* res, Index resIncr, \
                                                        bfloat16 alpha) { \
      gemv_bf16_col<LhsMapper, RhsMapper>(rows, cols, lhs, rhs, res, resIncr, alpha); \
    } \
  };

#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW_BFLOAT16() \
  template <typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
  struct general_matrix_vector_product<Index, bfloat16, LhsMapper, RowMajor, ConjugateLhs, bfloat16, RhsMapper, \
                                       ConjugateRhs, Version> { \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs, \
                                                        const RhsMapper& rhs, bfloat16* res, Index resIncr, \
                                                        bfloat16 alpha) { \
      gemv_bf16_row<LhsMapper, RhsMapper>(rows, cols, lhs, rhs, res, resIncr, alpha); \
    } \
  };

EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL_BFLOAT16()
EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW_BFLOAT16()
template <typename ResScalar, typename PResPacket, typename ResPacket, typename LhsPacket, typename RhsPacket>
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(PResPacket& a0, PResPacket& b0, ResPacket& a1,
                                                             ResPacket& b1) {
  if (GEMV_IS_COMPLEX_COMPLEX) {
    a0 = padd(a0, a1);
    b0 = padd(b0, b1);
  }
  return predux_complex<ResScalar, PResPacket>(a0, b0);
}
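// Macros implementing the row-major complex GEMV kernel (MMA, newer VSX and older VSX code paths).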
#define GEMV_LOADPACKET_ROW_COMPLEX(iter) loadLhsPacket<Scalar, LhsScalar, LhsMapper, PLhsPacket>(lhs, i + (iter), j)

#define GEMV_LOADPACKET_ROW_COMPLEX_DATA(iter) convertReal(GEMV_LOADPACKET_ROW_COMPLEX(iter))

#define GEMV_PROCESS_ROW_COMPLEX_SINGLE_WORK(which, N) \
  j = 0; \
  for (; j + LhsPacketSize <= cols; j += LhsPacketSize) { \
    const RhsScalar& b1 = rhs2(j); \
    RhsScalar* b = const_cast<RhsScalar*>(&b1); \
    GEMV_UNROLL_ROW(which, N) \
  }

#define GEMV_PROCESS_END_ROW_COMPLEX(N) \
  for (; j < cols; ++j) { \
    RhsScalar b0 = rhs2(j); \
    GEMV_UNROLL_ROW_HALF(GEMV_MULT_COMPLEX, (N >> 1)) \
  } \
  GEMV_UNROLL_ROW_HALF(GEMV_STORE_ROW_COMPLEX, (N >> 1))
#ifdef USE_GEMV_MMA
#define GEMV_INIT_ROW_COMPLEX_MMA(iter, N) \
  if (GEMV_GETN_COMPLEX(N) > iter) { \
    __builtin_mma_xxsetaccz(&e0##iter); \
  }

#define GEMV_LOADPAIR_ROW_COMPLEX_MMA(iter1, iter2) \
  GEMV_BUILDPAIR_MMA(a##iter1, GEMV_LOADPACKET_ROW_COMPLEX_DATA(iter2), GEMV_LOADPACKET_ROW_COMPLEX_DATA((iter2) + 1));
#define GEMV_WORK_ROW_COMPLEX_MMA(iter, N) \
  if (GEMV_GETN_COMPLEX(N) > iter) { \
    if (GEMV_IS_COMPLEX_FLOAT) { \
      PLhsPacket a##iter = GEMV_LOADPACKET_ROW_COMPLEX(iter); \
      gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, PLhsPacket, RhsScalar, RhsPacket, ResPacket, \
                            ConjugateLhs, ConjugateRhs, RowMajor>(a##iter, b, &e0##iter); \
    } else { \
      __vector_pair a##iter; \
      GEMV_LOADPAIR_ROW_COMPLEX_MMA(iter, iter << 1) \
      gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, __vector_pair, RhsScalar, RhsPacket, ResPacket, \
                            ConjugateLhs, ConjugateRhs, RowMajor>(a##iter, b, &e0##iter); \
    } \
  }

#define GEMV_PREDUX4_COMPLEX_MMA(iter1, iter2, iter3, N) \
  if (N > iter1) { \
    if (GEMV_IS_COMPLEX_FLOAT) { \
      cc##iter1 = predux_complex<ResScalar, ScalarPacket, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>( \
          &e0##iter2, &e0##iter3); \
    } else { \
      cc##iter1 = \
          predux_complex<ResScalar, ScalarPacket, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(&e0##iter1); \
    } \
  } else { \
    EIGEN_UNUSED_VARIABLE(cc##iter1); \
  }
#define GEMV_PROCESS_ROW_COMPLEX_SINGLE_MMA(N) \
  GEMV_UNROLL_ROW(GEMV_INIT_ROW_COMPLEX_MMA, N) \
  GEMV_PROCESS_ROW_COMPLEX_SINGLE_WORK(GEMV_WORK_ROW_COMPLEX_MMA, N)

#define GEMV_PROCESS_ROW_COMPLEX_ONE_MMA(N) \
  for (; i < n##N; i += N) { \
    GEMV_PROCESS_ROW_COMPLEX_SINGLE_MMA(N) \
    GEMV_UNROLL_ROW_HALF(GEMV_PREDUX4_COMPLEX_MMA, (N >> 1)) \
    GEMV_PROCESS_END_ROW_COMPLEX(N); \
  }
#endif
#define GEMV_WORK_ROW_COMPLEX(iter, N) \
  if (N > iter) { \
    PLhsPacket a##iter = GEMV_LOADPACKET_ROW_COMPLEX(iter); \
    gemv_mult_complex<ScalarPacket, PLhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs, \
                      ConjugateRhs, RowMajor>(a##iter, b, c0##iter, c1##iter); \
  }

#define GEMV_PREDUX4_COMPLEX(iter1, iter2, iter3, N) \
  if (N > iter1) { \
    cc##iter1 = predux_complex<ResScalar, PResPacket, ResPacket, LhsPacket, RhsPacket>(c0##iter2, c0##iter3, \
                                                                                       c1##iter2, c1##iter3); \
  } else { \
    EIGEN_UNUSED_VARIABLE(cc##iter1); \
  }

#define GEMV_MULT_COMPLEX(iter1, iter2, iter3, N) \
  if (N > iter1) { \
    cc##iter1.scalar[0] += cj.pmul(lhs(i + iter2, j), b0); \
    cc##iter1.scalar[1] += cj.pmul(lhs(i + iter3, j), b0); \
  }

#define GEMV_STORE_ROW_COMPLEX(iter1, iter2, iter3, N) \
  if (N > iter1) { \
    storeMaddData<ResScalar>(res + ((i + iter2) * resIncr), alpha, cc##iter1.scalar[0]); \
    storeMaddData<ResScalar>(res + ((i + iter3) * resIncr), alpha, cc##iter1.scalar[1]); \
  }
#define GEMV_PROCESS_ROW_COMPLEX_SINGLE_NEW(N) \
  GEMV_UNROLL_ROW(GEMV_INIT_COMPLEX, N) \
  GEMV_PROCESS_ROW_COMPLEX_SINGLE_WORK(GEMV_WORK_ROW_COMPLEX, N)

#define GEMV_PROCESS_ROW_COMPLEX_ONE_NEW(N) \
  for (; i < n##N; i += N) { \
    GEMV_PROCESS_ROW_COMPLEX_SINGLE_NEW(N) \
    GEMV_UNROLL_ROW_HALF(GEMV_PREDUX4_COMPLEX, (N >> 1)) \
    GEMV_PROCESS_END_ROW_COMPLEX(N); \
  }

#define GEMV_PROCESS_ROW_COMPLEX_PREDUX_NEW(iter) \
  if (GEMV_IS_COMPLEX_COMPLEX) { \
    c0##iter = padd(c0##iter, c1##iter); \
  } \
  dd0 = predux(c0##iter);
#if EIGEN_COMP_LLVM
#define GEMV_PROCESS_ROW_COMPLEX_SINGLE(N) GEMV_PROCESS_ROW_COMPLEX_SINGLE_NEW(N)

#define GEMV_PROCESS_ROW_COMPLEX_ONE(N) GEMV_PROCESS_ROW_COMPLEX_ONE_NEW(N)

#define GEMV_PROCESS_ROW_COMPLEX_PREDUX(iter) GEMV_PROCESS_ROW_COMPLEX_PREDUX_NEW(iter)
#else
#define GEMV_LOADPACKET_ROW_COMPLEX_OLD(iter) lhs.template load<LhsPacket, LhsAlignment>(i + (iter), j)

#define GEMV_INIT_COMPLEX_OLD(iter, N) \
  EIGEN_UNUSED_VARIABLE(c0##iter); \
  if (N > iter) { \
    c1##iter = pset_zero<ResPacket>(); \
  } else { \
    EIGEN_UNUSED_VARIABLE(c1##iter); \
  }

#define GEMV_WORK_ROW_COMPLEX_OLD(iter, N) \
  if (N > iter) { \
    LhsPacket a##iter = GEMV_LOADPACKET_ROW_COMPLEX_OLD(iter); \
    c1##iter = pcj.pmadd(a##iter, b0, c1##iter); \
  }

#define GEMV_PREDUX4_COMPLEX_OLD(iter1, iter2, iter3, N) \
  if (N > iter1) { \
    cc##iter1.scalar[0] = predux(c1##iter2); \
    cc##iter1.scalar[1] = predux(c1##iter3); \
  } else { \
    EIGEN_UNUSED_VARIABLE(cc##iter1); \
  }

#define GEMV_PROCESS_ROW_COMPLEX_SINGLE_OLD(N) \
  GEMV_UNROLL_ROW(GEMV_INIT_COMPLEX_OLD, N) \
  j = 0; \
  for (; j + LhsPacketSize <= cols; j += LhsPacketSize) { \
    RhsPacket b0 = rhs2.template load<RhsPacket, Unaligned>(j); \
    GEMV_UNROLL_ROW(GEMV_WORK_ROW_COMPLEX_OLD, N) \
  }

#define GEMV_PROCESS_ROW_COMPLEX_ONE_OLD(N) \
  for (; i < n##N; i += N) { \
    GEMV_PROCESS_ROW_COMPLEX_SINGLE_OLD(N) \
    GEMV_UNROLL_ROW_HALF(GEMV_PREDUX4_COMPLEX_OLD, (N >> 1)) \
    GEMV_PROCESS_END_ROW_COMPLEX(N) \
  }
#define GEMV_PROCESS_ROW_COMPLEX_PREDUX_OLD(iter) dd0 = predux(c1##iter);

#ifdef USE_GEMV_MMA
#define GEMV_PROCESS_ROW_COMPLEX_IS_NEW 1
#else
#define GEMV_PROCESS_ROW_COMPLEX_IS_NEW (sizeof(Scalar) == sizeof(float)) || GEMV_IS_COMPLEX_COMPLEX
#endif
#define GEMV_PROCESS_ROW_COMPLEX_SINGLE(N) \
  if (GEMV_PROCESS_ROW_COMPLEX_IS_NEW) { \
    GEMV_PROCESS_ROW_COMPLEX_SINGLE_NEW(N) \
  } else { \
    GEMV_PROCESS_ROW_COMPLEX_SINGLE_OLD(N) \
  }

#define GEMV_PROCESS_ROW_COMPLEX_ONE(N) \
  if (GEMV_PROCESS_ROW_COMPLEX_IS_NEW) { \
    GEMV_PROCESS_ROW_COMPLEX_ONE_NEW(N) \
  } else { \
    GEMV_PROCESS_ROW_COMPLEX_ONE_OLD(N) \
  }
#define GEMV_PROCESS_ROW_COMPLEX_PREDUX(iter) \
  if (GEMV_PROCESS_ROW_COMPLEX_IS_NEW) { \
    GEMV_PROCESS_ROW_COMPLEX_PREDUX_NEW(iter) \
  } else { \
    GEMV_PROCESS_ROW_COMPLEX_PREDUX_OLD(iter) \
  }
#endif
#ifdef USE_GEMV_MMA
#define GEMV_PROCESS_ROW_COMPLEX(N) GEMV_PROCESS_ROW_COMPLEX_ONE_MMA(N)
#else
#define GEMV_PROCESS_ROW_COMPLEX(N) GEMV_PROCESS_ROW_COMPLEX_ONE(N)
#endif
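/** Row-major GEMV for complex scalars: 8/4/2 rows are processed per pass with their accumulators reduced to scalar
 * pairs, then a per-row loop with a scalar column tail handles the remainder. */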
template <typename Scalar, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, bool LhsIsReal,
          typename RhsScalar, typename RhsMapper, bool ConjugateRhs, bool RhsIsReal, typename ResScalar>
EIGEN_STRONG_INLINE void gemv_complex_row(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs,
                                          ResScalar* res, Index resIncr, ResScalar alpha) {
  typedef gemv_traits<LhsScalar, RhsScalar> Traits;

  typedef typename Traits::LhsPacket LhsPacket;
  typedef typename Traits::RhsPacket RhsPacket;
  typedef typename Traits::ResPacket ResPacket;

  typedef typename packet_traits<Scalar>::type ScalarPacket;
  typedef typename packet_traits<LhsScalar>::type PLhsPacket;
  typedef typename packet_traits<ResScalar>::type PResPacket;
  typedef gemv_traits<ResPacket, ResPacket> PTraits;

  // The following copy tells the compiler that lhs's attributes are not modified outside this function.
  // This helps GCC to generate proper code.
  LhsMapper lhs(alhs);
  typename RhsMapper::LinearMapper rhs2 = rhs.getLinearMapper(0, 0);

  eigen_internal_assert(rhs.stride() == 1);
  conj_helper<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs> cj;
  conj_helper<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs> pcj;
#ifndef GCC_ONE_VECTORPAIR_BUG
  const Index n8 = lhs.stride() * sizeof(LhsScalar) > 32000 ? (rows - 7) : (rows - 7);
  const Index n4 = rows - 3;
  const Index n2 = rows - 1;
#endif

  enum {
    LhsAlignment = Unaligned,
    ResPacketSize = PTraits::ResPacketSize,
    LhsPacketSize = PTraits::LhsPacketSize,
    RhsPacketSize = PTraits::RhsPacketSize,
  };

  Index i = 0, j;
  ResScalar dd0;
  PResPacket c00, c01, c02, c03, c04, c05, c06, c07;
  ResPacket c10, c11, c12, c13, c14, c15, c16, c17;
#ifdef USE_GEMV_MMA
  __vector_quad e00, e01, e02, e03, e04, e05, e06, e07;
  GEMV_UNUSED_ROW(8, e0)
  GEMV_UNUSED_EXTRA(1, c0)
  GEMV_UNUSED_EXTRA(1, c1)
#endif
#ifndef GCC_ONE_VECTORPAIR_BUG
  ScalarBlock<ResScalar, 2> cc0, cc1, cc2, cc3;
#ifdef USE_GEMV_MMA
  if (!GEMV_IS_COMPLEX_COMPLEX)
#endif
  {
    GEMV_PROCESS_ROW_COMPLEX(8)
  }
  GEMV_PROCESS_ROW_COMPLEX(4)
  GEMV_PROCESS_ROW_COMPLEX(2)
#endif
  for (; i < rows; ++i) {
    GEMV_PROCESS_ROW_COMPLEX_SINGLE(1)
    GEMV_PROCESS_ROW_COMPLEX_PREDUX(0)
    for (; j < cols; ++j) {
      dd0 += cj.pmul(lhs(i, j), rhs2(j));
    }
    res[i * resIncr] += alpha * dd0;
  }
}
#define EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(Scalar, LhsScalar, RhsScalar) \
  template <typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
  struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, ConjugateLhs, RhsScalar, RhsMapper, \
                                       ConjugateRhs, Version> { \
    typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs, \
                                                        const RhsMapper& rhs, ResScalar* res, Index resIncr, \
                                                        ResScalar alpha) { \
      gemv_complex_col<Scalar, LhsScalar, LhsMapper, ConjugateLhs, sizeof(Scalar) == sizeof(LhsScalar), RhsScalar, \
                       RhsMapper, ConjugateRhs, sizeof(Scalar) == sizeof(RhsScalar), ResScalar>(rows, cols, lhs, rhs, \
                                                                                                res, resIncr, alpha); \
    } \
  };

#define EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(Scalar, LhsScalar, RhsScalar) \
  template <typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
  struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLhs, RhsScalar, RhsMapper, \
                                       ConjugateRhs, Version> { \
    typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs, \
                                                        const RhsMapper& rhs, ResScalar* res, Index resIncr, \
                                                        ResScalar alpha) { \
      gemv_complex_row<Scalar, LhsScalar, LhsMapper, ConjugateLhs, sizeof(Scalar) == sizeof(LhsScalar), RhsScalar, \
                       RhsMapper, ConjugateRhs, sizeof(Scalar) == sizeof(RhsScalar), ResScalar>(rows, cols, lhs, rhs, \
                                                                                                res, resIncr, alpha); \
    } \
  };
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(float, float, std::complex<float>)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(float, std::complex<float>, float)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(float, std::complex<float>, std::complex<float>)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(double, double, std::complex<double>)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(double, std::complex<double>, double)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(double, std::complex<double>, std::complex<double>)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(float, float, std::complex<float>)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(float, std::complex<float>, float)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(float, std::complex<float>, std::complex<float>)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(double, double, std::complex<double>)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(double, std::complex<double>, double)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(double, std::complex<double>, std::complex<double>)