// Eigen 3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)
// File: MatrixProductCommon.h
// Prefetching in the gemm routines is opt-in. Define EIGEN_POWER_USE_PREFETCH before this point (for example by
// uncommenting the line below) to make EIGEN_POWER_PREFETCH(p) expand to prefetch(p); otherwise it expands to nothing.
// #define EIGEN_POWER_USE_PREFETCH
#if !defined(EIGEN_POWER_USE_PREFETCH)
#define EIGEN_POWER_PREFETCH(p)
#else
#define EIGEN_POWER_PREFETCH(p) prefetch(p)
#endif

// USE_PARTIAL_PACKETS switches on the code guarded by it in this header (bload_partial, bstore_partial and the
// partial-packet variants of the MICRO_* helpers). It is defined when _ARCH_PWR9 (POWER9 target) or
// EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH is defined.
#if defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH) || defined(_ARCH_PWR9)
#define USE_PARTIAL_PACKETS
#endif
11
12// IWYU pragma: private
13#include "../../InternalHeaderCheck.h"
14
15namespace Eigen {
16
17namespace internal {
18
19template <typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols>
20EIGEN_ALWAYS_INLINE void gemm_extra_row(const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base,
21 Index depth, Index strideA, Index offsetA, Index strideB, Index row, Index rows,
22 Index remaining_rows, const Packet& pAlpha, const Packet& pMask);
23
24template <typename Scalar, typename Packet, typename DataMapper, const Index accCols>
25EIGEN_ALWAYS_INLINE void gemm_extra_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index depth,
26 Index strideA, Index offsetA, Index strideB, Index offsetB, Index col,
27 Index rows, Index cols, Index remaining_rows, const Packet& pAlpha,
28 const Packet& pMask);
29
30template <typename Packet>
31EIGEN_ALWAYS_INLINE Packet bmask(const Index remaining_rows);
32
33template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows,
34 const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
35EIGEN_ALWAYS_INLINE void gemm_complex_extra_row(const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base,
36 Index depth, Index strideA, Index offsetA, Index strideB, Index row,
37 Index rows, Index remaining_rows, const Packet& pAlphaReal,
38 const Packet& pAlphaImag, const Packet& pMask);
39
40template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accCols,
41 bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
42EIGEN_ALWAYS_INLINE void gemm_complex_extra_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
43 Index depth, Index strideA, Index offsetA, Index strideB,
44 Index offsetB, Index col, Index rows, Index cols, Index remaining_rows,
45 const Packet& pAlphaReal, const Packet& pAlphaImag,
46 const Packet& pMask);
47
48template <typename DataMapper>
49EIGEN_ALWAYS_INLINE void convertArrayBF16toF32(float* result, Index cols, Index rows, const DataMapper& src);
50
51template <const Index size, bool non_unit_stride, Index delta>
52EIGEN_ALWAYS_INLINE void storeBF16fromResult(bfloat16* dst, Packet8bf data, Index resInc, Index extra = 0);
53
54template <bool non_unit_stride = false>
55EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float* result, Index cols, Index rows, bfloat16* src,
56 Index resInc = 1);
57
58template <bool rhsExtraCols, bool lhsExtraRows>
59EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result,
60 Index extra_cols, Index extra_rows);
61
62template <Index num_acc, bool extraRows, Index size = 4>
63EIGEN_ALWAYS_INLINE void outputVecColResults(Packet4f (&acc)[num_acc][size], float* result, Packet4f pAlpha,
64 Index extra_rows);
65
66template <Index num_acc, Index size = 4>
67EIGEN_ALWAYS_INLINE void outputVecResults(Packet4f (&acc)[num_acc][size], float* result, Packet4f pAlpha);
68
69template <typename RhsMapper, bool linear>
70EIGEN_ALWAYS_INLINE Packet8bf loadColData(RhsMapper& rhs, Index j);
71
72template <typename Packet>
73EIGEN_ALWAYS_INLINE Packet ploadLhs(const __UNPACK_TYPE__(Packet) * lhs);
74
75template <typename DataMapper, typename Packet, const Index accCols, int StorageOrder, bool Complex, int N,
76 bool full = true>
77EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet, N*(Complex ? 2 : 1)>& acc, const DataMapper& res, Index row,
78 Index col);
79
80template <typename DataMapper, typename Packet, int N>
81EIGEN_ALWAYS_INLINE void bstore(PacketBlock<Packet, N>& acc, const DataMapper& res, Index row);
82
83#ifdef USE_PARTIAL_PACKETS
84template <typename DataMapper, typename Packet, const Index accCols, bool Complex, Index N, bool full = true>
85EIGEN_ALWAYS_INLINE void bload_partial(PacketBlock<Packet, N*(Complex ? 2 : 1)>& acc, const DataMapper& res, Index row,
86 Index elements);
87
88template <typename DataMapper, typename Packet, Index N>
89EIGEN_ALWAYS_INLINE void bstore_partial(PacketBlock<Packet, N>& acc, const DataMapper& res, Index row, Index elements);
90#endif
91
92template <typename Packet, int N>
93EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet, N>& acc, PacketBlock<Packet, N>& accZ, const Packet& pAlpha);
94
95template <typename Packet, int N, bool mask>
96EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet, N>& acc, PacketBlock<Packet, N>& accZ, const Packet& pAlpha,
97 const Packet& pMask);
98
99template <typename Packet, int N, bool mask>
100EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet, N>& aReal, PacketBlock<Packet, N>& aImag, const Packet& bReal,
101 const Packet& bImag, PacketBlock<Packet, N>& cReal, PacketBlock<Packet, N>& cImag,
102 const Packet& pMask);
103
104template <typename Packet, typename Packetc, int N, bool full>
105EIGEN_ALWAYS_INLINE void bcouple(PacketBlock<Packet, N>& taccReal, PacketBlock<Packet, N>& taccImag,
106 PacketBlock<Packetc, N * 2>& tRes, PacketBlock<Packetc, N>& acc1,
107 PacketBlock<Packetc, N>& acc2);
108
// The MICRO_* helpers below expand inside the unrolled gemm kernels. They refer to names that must be in scope at
// the expansion site: accCols, accCols2, unroll_factor, remaining_rows, Scalar and, for MICRO_NORMAL_PARTIAL, full.
// Value-producing macros parenthesize both their arguments and their whole expansion. This lets them compose with
// operators at the use site: e.g. `!MICRO_NORMAL(i)` must negate the whole condition, not just its first operand.

// True unless `iter` is the last unrolled iteration (iter + 1 == unroll_factor) of a block whose width accCols2
// differs from accCols.
#define MICRO_NORMAL(iter) ((accCols == accCols2) || (unroll_factor != ((iter) + 1)))

// Expands `func(N, r)` for the tail size selected by remaining_rows:
//   1 -> func(N, 1);
//   2 / 3 -> func(N, 2) / func(N, 3), executed only when sizeof(Scalar) == sizeof(float);
//   any other value, including 0 -> func(N, 0).
// `func` must expand to complete statement(s), because `break` follows it directly.
#define MICRO_UNROLL_ITER1(func, N)          \
  switch (remaining_rows) {                  \
    default:                                 \
      func(N, 0) break;                      \
    case 1:                                  \
      func(N, 1) break;                      \
    case 2:                                  \
      if (sizeof(Scalar) == sizeof(float)) { \
        func(N, 2)                           \
      }                                      \
      break;                                 \
    case 3:                                  \
      if (sizeof(Scalar) == sizeof(float)) { \
        func(N, 3)                           \
      }                                      \
      break;                                 \
  }

#ifdef USE_PARTIAL_PACKETS
// With partial packets, one runtime-sized path handles any tail: func(N, true) when remaining_rows is non-zero,
// otherwise func(N, false).
#define MICRO_UNROLL_ITER(func, N) \
  if (remaining_rows) {            \
    func(N, true);                 \
  } else {                         \
    func(N, false);                \
  }

// Partial-packet counterpart of MICRO_NORMAL: true when `full` is set or `iter` is not the last unrolled iteration.
#define MICRO_NORMAL_PARTIAL(iter) ((full) || (unroll_factor != ((iter) + 1)))
#else
#define MICRO_UNROLL_ITER(func, N) MICRO_UNROLL_ITER1(func, N)
#endif

// The complex kernels always use the per-tail-size dispatch.
#define MICRO_COMPLEX_UNROLL_ITER(func, N) MICRO_UNROLL_ITER1(func, N)

// Yields `a` for a normal iteration (see MICRO_NORMAL) and `b` for the narrower last iteration.
#define MICRO_NORMAL_COLS(iter, a, b) ((MICRO_NORMAL(iter)) ? (a) : (b))
145
// For unrolled iteration `iter` below unroll_factor, loads lhsV##iter from lhs_ptr##iter with ploadLhs<Packet>.
// It then advances the pointer by accCols on a normal iteration or accCols2 on the narrower last one.
// Iterations at or beyond unroll_factor only mark lhsV##iter as unused.
#define MICRO_LOAD1(lhs_ptr, iter)                               \
  if (iter < unroll_factor) {                                    \
    lhsV##iter = ploadLhs<Packet>(lhs_ptr##iter);                \
    lhs_ptr##iter += MICRO_NORMAL_COLS(iter, accCols, accCols2); \
  } else {                                                       \
    EIGEN_UNUSED_VARIABLE(lhsV##iter);                           \
  }

// Real-kernel entry point: loads through lhs_ptr##iter.
#define MICRO_LOAD_ONE(iter) MICRO_LOAD1(lhs_ptr, iter)

// Complex-kernel entry point. When the LHS is not real, lhsVi##iter is read first. The read is at an offset of
// imag_delta (imag_delta2 on the narrower last iteration) past lhs_ptr_real##iter. It must happen before
// MICRO_LOAD1, which advances lhs_ptr_real##iter. Otherwise lhsVi##iter is only marked unused.
#define MICRO_COMPLEX_LOAD_ONE(iter)                                                                       \
  if ((iter < unroll_factor) && !LhsIsReal) {                                                              \
    lhsVi##iter = ploadLhs<Packet>(lhs_ptr_real##iter + MICRO_NORMAL_COLS(iter, imag_delta, imag_delta2)); \
  } else {                                                                                                 \
    EIGEN_UNUSED_VARIABLE(lhsVi##iter);                                                                    \
  }                                                                                                        \
  MICRO_LOAD1(lhs_ptr_real, iter)
163
// Initializes lhs_ptr##iter for unrolled iteration `iter` to
//   lhs_base + (row + iter * accCols) * strideA * advRows.
// On the narrower last iteration the pointer is moved back by (accCols - accCols2) * offsetA.
// Iterations at or beyond unroll_factor only mark the pointer as unused.
#define MICRO_SRC_PTR1(lhs_ptr, advRows, iter)                                  \
  if (iter < unroll_factor) {                                                   \
    lhs_ptr##iter = lhs_base + (row + (iter * accCols)) * strideA * advRows -   \
                    MICRO_NORMAL_COLS(iter, 0, (accCols - accCols2) * offsetA); \
  } else {                                                                      \
    EIGEN_UNUSED_VARIABLE(lhs_ptr##iter);                                       \
  }

// Real kernels: one row step per LHS row.
#define MICRO_SRC_PTR_ONE(iter) MICRO_SRC_PTR1(lhs_ptr, 1, iter)

// Complex kernels: row offsets are scaled by advanceRows instead of 1.
#define MICRO_COMPLEX_SRC_PTR_ONE(iter) MICRO_SRC_PTR1(lhs_ptr_real, advanceRows, iter)

// Applies EIGEN_POWER_PREFETCH to lhs_ptr##iter for iterations below unroll_factor. This is a no-op unless
// EIGEN_POWER_USE_PREFETCH is defined.
#define MICRO_PREFETCH1(lhs_ptr, iter)   \
  if (iter < unroll_factor) {            \
    EIGEN_POWER_PREFETCH(lhs_ptr##iter); \
  }

#define MICRO_PREFETCH_ONE(iter) MICRO_PREFETCH1(lhs_ptr, iter)

#define MICRO_COMPLEX_PREFETCH_ONE(iter) MICRO_PREFETCH1(lhs_ptr_real, iter)
184
// Without partial packets, MICRO_UPDATE explicitly marks pMask as unused. With USE_PARTIAL_PACKETS defined it
// expands to nothing.
#ifndef USE_PARTIAL_PACKETS
#define MICRO_UPDATE_MASK EIGEN_UNUSED_VARIABLE(pMask);
#else
#define MICRO_UPDATE_MASK
#endif

// When the block is full width (accCols2 == accCols), this silences pMask (see MICRO_UPDATE_MASK) and offsetA,
// then advances `row` by unroll_factor * accCols. A narrower block leaves `row` unchanged.
#define MICRO_UPDATE                \
  if (accCols2 == accCols) {        \
    MICRO_UPDATE_MASK               \
    EIGEN_UNUSED_VARIABLE(offsetA); \
    row += unroll_factor * accCols; \
  }

// Complex variant: performs MICRO_UPDATE, then marks imag_delta2 unused in the cases where MICRO_COMPLEX_LOAD_ONE
// never reads it: a real LHS, or a full-width block.
#define MICRO_COMPLEX_UPDATE                \
  MICRO_UPDATE                              \
  if ((accCols == accCols2) || LhsIsReal) { \
    EIGEN_UNUSED_VARIABLE(imag_delta2);     \
  }
203
204} // end namespace internal
205} // end namespace Eigen
// Eigen: the namespace containing all symbols of the Eigen library (defined in Eigen/Core, line 137).