MatrixProductMMAbfloat16.h
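// This header implements the bfloat16 GEMM and GEMV kernels for PowerPC using
// the Power10 Matrix-Multiply Assist (MMA) instructions. Products are
// accumulated in float through 4x4 __vector_quad accumulators and converted
// back to bfloat16 once a full result block is ready.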
#ifndef EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H

#if EIGEN_COMP_LLVM
#define BFLOAT16_UNROLL _Pragma("unroll 8")
#else
#define BFLOAT16_UNROLL _Pragma("GCC unroll(8)")
#endif

namespace Eigen {

namespace internal {

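// Load eight bfloat16 values. When zero is set, the loaded values are
// interleaved with zeros (vec_mergeh) so that each 32-bit word holds a
// (value, 0) pair; the MMA instruction's two-term dot product per word then
// reduces to a single product, which is how an odd depth is handled.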
template <bool zero>
EIGEN_ALWAYS_INLINE Packet8bf loadBfloat16(const bfloat16* indexA) {
  Packet8bf lhs1 = ploadu<Packet8bf>(indexA);
  if (zero) {
    Packet8bf lhs2 = pset1<Packet8bf>(Eigen::bfloat16(0));
    return vec_mergeh(lhs1.m_val, lhs2.m_val);
  } else {
    return lhs1;
  }
}

template <bool zero>
EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16(const bfloat16* blockB, Index strideB, Index i) {
  return loadBfloat16<zero>(blockB + strideB * i);
}

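// One k-step of the GEMM kernel: load the RHS and LHS packets for depth index
// k and issue one xvbf16ger2pp rank-2 update per accumulator. Each call
// normally consumes two depth values (each MMA word multiplies a pair of
// bfloat16); with zero set, only one depth value is consumed and the other
// lane is zero-padded.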
template <Index num_acc, Index num_packets, bool zero, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs,
          Index num_lhs>
EIGEN_ALWAYS_INLINE void KLoop(const bfloat16* indexA, const bfloat16* indexB, __vector_quad (&quad_acc)[num_acc],
                               Index strideB, Index k, Index offsetB, Index extra_cols, Index extra_rows) {
  Packet8bf lhs[num_lhs], rhs[num_rhs];

  BFLOAT16_UNROLL
  for (Index i = 0; i < (num_rhs - (rhsExtraCols ? 1 : 0)); i++) {
    rhs[i] = loadRhsBfloat16<zero>(indexB + k * 4, strideB, i);
  }
  if (rhsExtraCols) {
    rhs[num_rhs - 1] = loadRhsBfloat16<zero>(indexB + k * extra_cols - offsetB, strideB, num_rhs - 1);
  }

  indexA += k * (lhsExtraRows ? extra_rows : num_packets);
  if (num_lhs == 1) {
    lhs[0] = loadBfloat16<zero>(indexA);
  } else {
    BFLOAT16_UNROLL
    for (Index j = 0; j < num_lhs; j += 2) {
      Packet8bf lhs1 = ploadu<Packet8bf>(indexA + (j + 0) * (zero ? 4 : 8));
      if (zero) {
        Packet8bf lhs2 = pset1<Packet8bf>(Eigen::bfloat16(0));
        lhs[j + 0] = vec_mergeh(lhs1.m_val, lhs2.m_val);
        lhs[j + 1] = vec_mergel(lhs1.m_val, lhs2.m_val);
      } else {
        lhs[j + 0] = lhs1;
        lhs[j + 1] = ploadu<Packet8bf>(indexA + (j + 1) * 8);
      }
    }
  }

  BFLOAT16_UNROLL
  for (Index i = 0, x = 0; i < num_rhs; i++) {
    BFLOAT16_UNROLL
    for (Index j = 0; j < num_lhs; j++, x++) {
      __builtin_mma_xvbf16ger2pp(&(quad_acc[x]), reinterpret_cast<Packet16uc>(rhs[i].m_val),
                                 reinterpret_cast<Packet16uc>(lhs[j].m_val));
    }
  }
}

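// Clear all MMA accumulators (xxsetaccz) before starting a depth loop.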
template <Index num_acc>
EIGEN_ALWAYS_INLINE void zeroAccumulators(__vector_quad (&quad_acc)[num_acc]) {
  BFLOAT16_UNROLL
  for (Index k = 0; k < num_acc; k++) __builtin_mma_xxsetaccz(&(quad_acc[k]));
}

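// Copy each 4x4 float accumulator out of its __vector_quad into four Packet4f
// rows so the results can be scaled and stored with regular VSX code.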
template <Index num_acc>
EIGEN_ALWAYS_INLINE void disassembleAccumulators(__vector_quad (&quad_acc)[num_acc], Packet4f (&acc)[num_acc][4]) {
  BFLOAT16_UNROLL
  for (Index k = 0; k < num_acc; k++) __builtin_mma_disassemble_acc((void*)acc[k], &(quad_acc[k]));
}

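// Scale the accumulated 4x4 tiles by alpha and write them to the float result
// buffer, one accumulator per 4-column block; storeResults handles the
// partial rows/columns at the edges.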
template <Index num_acc, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs, Index num_lhs>
EIGEN_ALWAYS_INLINE void outputResults(Packet4f (&acc)[num_acc][4], Index rows, const Packet4f pAlpha, float* result,
                                       const Index extra_cols, Index extra_rows) {
  BFLOAT16_UNROLL
  for (Index i = 0, k = 0; i < num_rhs - (rhsExtraCols ? 1 : 0); i++, result += 4 * rows) {
    BFLOAT16_UNROLL
    for (Index j = 0; j < num_lhs; j++, k++) {
      storeResults<false, lhsExtraRows>(acc[k], rows, pAlpha, result + j * 4, extra_cols, extra_rows);
    }
  }
  if (rhsExtraCols) {
    storeResults<rhsExtraCols, lhsExtraRows>(acc[num_acc - 1], rows, pAlpha, result, extra_cols, extra_rows);
  }
}

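// Inner kernel for one block of columns: run the depth loop two k-values at a
// time, finish an odd depth with a zero-padded KLoop, then write the
// accumulators out. With multiIter set, all row packets of the block are
// processed in a single pass (num_lhs > 1).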
template <const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows, bool multiIter = false>
EIGEN_ALWAYS_INLINE void colLoopBodyIter(Index depth, Index rows, const Packet4f pAlpha, const bfloat16* indexA,
                                         const bfloat16* indexB, Index strideB, Index offsetB, float* result,
                                         const Index extra_cols, const Index extra_rows) {
  constexpr Index num_lhs = multiIter ? (num_packets / 4) : 1;
  constexpr Index num_rhs = (num_acc + num_lhs - 1) / num_lhs;

  for (Index offset_row = 0; offset_row < num_packets; offset_row += 4, indexA += (multiIter ? 0 : 8),
             indexB += (multiIter ? (num_rhs * strideB) : 0), result += (multiIter ? (4 * rows * num_rhs) : 4)) {
    Packet4f acc[num_acc][4];
    __vector_quad quad_acc[num_acc];

    zeroAccumulators<num_acc>(quad_acc);

    Index k;
    for (k = 0; k + 2 <= depth; k += 2) {
      KLoop<num_acc, num_packets, false, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(
          indexA, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
    }
    if (depth & 1) {
      KLoop<num_acc, num_packets, true, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(
          indexA - (multiIter ? 0 : offset_row), indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
    }

    disassembleAccumulators<num_acc>(quad_acc, acc);

    outputResults<num_acc, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(acc, rows, pAlpha, result, extra_cols,
                                                                         extra_rows);
  }
}

#define MAX_BFLOAT16_ACC 8

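// Process columns in steps of num_acc accumulators (4 columns each). With the
// maximum accumulator count and no column remainder (multiIters), the loop
// keeps consuming further column blocks instead of returning to the caller.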
template <const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
void colLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA,
                 const bfloat16* indexB, Index strideB, Index offsetB, float* result) {
  constexpr Index step = (num_acc * 4);  // each accumulator has 4 elements
  const Index extra_cols = (rhsExtraCols) ? (cols & 3) : 0;
  const Index extra_rows = (lhsExtraRows) ? (rows & 3) : 0;
  constexpr bool multiIters = !rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC);
  constexpr bool normIters = multiIters && ((num_acc % (num_packets / 4)) == 0);

  do {
    colLoopBodyIter<num_acc, num_packets, rhsExtraCols, lhsExtraRows, normIters>(
        depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows);

    indexB += strideB * num_acc;
    result += rows * step;
  } while (multiIters && (step <= cols - (col += step)));
}

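// Tail handling for the column loop: dispatch on the number of remaining
// 4-column blocks, adding one extra accumulator for the partial column block
// when rhsExtraCols is set.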
template <const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void colLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha,
                                           const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB,
                                           float* result) {
  if (MAX_BFLOAT16_ACC > num_acc) {
    colLoopBody<num_acc + (rhsExtraCols ? 1 : 0), num_packets, rhsExtraCols, lhsExtraRows>(
        col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
  }
}

template <const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
void colLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA,
                      const bfloat16* blockB, Index strideB, Index offsetB, float* result) {
  switch ((cols - col) >> 2) {
    case 7:
      colLoopBodyExtraN<7, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB,
                                                                    strideB, offsetB, result);
      break;
    case 6:
      colLoopBodyExtraN<6, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB,
                                                                    strideB, offsetB, result);
      break;
    case 5:
      colLoopBodyExtraN<5, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB,
                                                                    strideB, offsetB, result);
      break;
    case 4:
      colLoopBodyExtraN<4, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB,
                                                                    strideB, offsetB, result);
      break;
    case 3:
      colLoopBodyExtraN<3, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB,
                                                                    strideB, offsetB, result);
      break;
    case 2:
      colLoopBodyExtraN<2, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB,
                                                                    strideB, offsetB, result);
      break;
    case 1:
      colLoopBodyExtraN<1, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB,
                                                                    strideB, offsetB, result);
      break;
    default:
      if (rhsExtraCols) {
        colLoopBody<1, num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB,
                                                        offsetB, result);
      }
      break;
  }
}

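// Column driver: full blocks of MAX_BFLOAT16_ACC accumulators (32 columns)
// first, then the remainder via colLoopBodyExtra.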
template <const Index num_packets, bool lhsExtraRows = false>
EIGEN_ALWAYS_INLINE void colLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA,
                                  const bfloat16* blockB, Index strideB, Index offsetB, float* result) {
  Index col = 0;
  if (cols >= (MAX_BFLOAT16_ACC * 4)) {
    colLoopBody<MAX_BFLOAT16_ACC, num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB,
                                                                    strideB, 0, result);
    blockB += (strideB >> 2) * col;
    result += rows * col;
  }
  if (cols & 3) {
    colLoopBodyExtra<num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB,
                                                      result);
  } else {
    colLoopBodyExtra<num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0,
                                                       result);
  }
}

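// Convert 8 floats to 8 bfloat16 values: load a 256-bit vector pair, convert
// each half with xvcvspbf16 and pack the results into a single Packet8bf.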
EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16(const float* res) {
  Packet16uc fp16[2];
  __vector_pair fp16_vp = *reinterpret_cast<__vector_pair*>(const_cast<float*>(res));
  __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(fp16), &fp16_vp);
  fp16[0] = __builtin_vsx_xvcvspbf16(fp16[0]);
  fp16[1] = __builtin_vsx_xvcvspbf16(fp16[1]);
  return vec_pack(reinterpret_cast<Packet4ui>(fp16[0]), reinterpret_cast<Packet4ui>(fp16[1]));
}

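// Convert `size` columns of the float result buffer back to bfloat16 and
// store them through the DataMapper, 8 rows at a time with a partial store
// for the remaining rows.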
template <typename DataMapper, const Index size>
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16Col(float* result, Index col, Index rows, const DataMapper& res) {
  const DataMapper res2 = res.getSubMapper(0, col);
  Index row;
  float* result2 = result + col * rows;
  for (row = 0; row + 8 <= rows; row += 8, result2 += 8) {
    // get and save block
    PacketBlock<Packet8bf, size> block;
    BFLOAT16_UNROLL
    for (Index j = 0; j < size; j++) {
      block.packet[j] = convertF32toBF16(result2 + j * rows);
    }
    res2.template storePacketBlock<Packet8bf, size>(row, 0, block);
  }
  // extra rows
  if (row < rows) {
    BFLOAT16_UNROLL
    for (Index j = 0; j < size; j++) {
      Packet8bf fp16 = convertF32toBF16(result2 + j * rows);
      res2.template storePacketPartial<Packet8bf>(row, j, fp16, rows & 7);
    }
  }
}

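// Convert a contiguous run of floats to bfloat16, `size` elements per
// iteration (only the size == 32 instantiation loops); the size == 1
// instantiation performs the final partial store of up to 7 elements.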
template <const Index size, bool non_unit_stride = false>
EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float* result, Index rows, bfloat16*& dst,
                                                 Index resInc = 1) {
  constexpr Index extra = ((size < 8) ? 8 : size);
  while (i + size <= rows) {
    PacketBlock<Packet8bf, (size + 7) / 8> r32;
    r32.packet[0] = convertF32toBF16(result + i + 0);
    if (size >= 16) {
      r32.packet[1] = convertF32toBF16(result + i + 8);
    }
    if (size >= 32) {
      r32.packet[2] = convertF32toBF16(result + i + 16);
      r32.packet[3] = convertF32toBF16(result + i + 24);
    }
    storeBF16fromResult<size, non_unit_stride, 0>(dst, r32.packet[0], resInc, rows & 7);
    if (size >= 16) {
      storeBF16fromResult<size, non_unit_stride, 8>(dst, r32.packet[1], resInc);
    }
    if (size >= 32) {
      storeBF16fromResult<size, non_unit_stride, 16>(dst, r32.packet[2], resInc);
      storeBF16fromResult<size, non_unit_stride, 24>(dst, r32.packet[3], resInc);
    }
    i += extra;
    dst += extra * resInc;
    if (size != 32) break;
  }
}

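// Convert a float vector of length `rows` to bfloat16 in descending chunk
// sizes (32, 16, 8, then the remainder).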
template <bool non_unit_stride = false>
EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16(float* result, Index rows, bfloat16* dst, Index resInc = 1) {
  Index i = 0;
  convertPointerF32toBF16<32, non_unit_stride>(i, result, rows, dst, resInc);
  convertPointerF32toBF16<16, non_unit_stride>(i, result, rows, dst, resInc);
  convertPointerF32toBF16<8, non_unit_stride>(i, result, rows, dst, resInc);
  convertPointerF32toBF16<1, non_unit_stride>(i, result, rows, dst, resInc);
}

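// Convert the whole float result array back to bfloat16, four columns at a
// time, then the leftover columns.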
template <typename DataMapper>
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16(float* result, Index cols, Index rows, const DataMapper& res) {
  Index col;
  for (col = 0; col + 4 <= cols; col += 4) {
    convertArrayF32toBF16Col<DataMapper, 4>(result, col, rows, res);
  }
  // extra cols
  switch (cols - col) {
    case 1:
      convertArrayF32toBF16Col<DataMapper, 1>(result, col, rows, res);
      break;
    case 2:
      convertArrayF32toBF16Col<DataMapper, 2>(result, col, rows, res);
      break;
    case 3:
      convertArrayF32toBF16Col<DataMapper, 3>(result, col, rows, res);
      break;
  }
}

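// Run the column loops for one packed LHS row-block size (16, 8 or 4 rows)
// when the row count requires it, advancing indexA past the consumed block.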
template <Index size>
EIGEN_ALWAYS_INLINE void calcColLoops(const bfloat16*& indexA, Index& row, Index depth, Index cols, Index rows,
                                      const Packet4f pAlpha, const bfloat16* indexB, Index strideB, Index offsetA,
                                      Index offsetB, Index bigSuffix, float* result) {
  if ((size == 16) || (rows & size)) {
    indexA += size * offsetA;
    colLoops<size>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
    row += size;
    indexA += bigSuffix * size / 16;
  }
}

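// bfloat16 GEMM entry point: widen the existing result to float, accumulate
// alpha * A * B with the MMA kernels over progressively smaller row blocks,
// then narrow everything back to bfloat16.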
template <typename DataMapper>
void gemmMMAbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat16* indexB, Index rows, Index depth,
                     Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
  const Packet4f pAlpha = pset1<Packet4f>(falpha);
  ei_declare_aligned_stack_constructed_variable(float, result, cols * rows, 0);

  convertArrayBF16toF32<DataMapper>(result, cols, rows, res);

  if (strideA == -1) strideA = depth;
  if (strideB == -1) strideB = depth;
  // Packing is done in blocks.
  // There are 4 possible block sizes, all 8 columns wide:
  // - 8x16: blocks of 8 columns with 16 rows
  // - 8x8: blocks of 8 columns with 8 rows, used when 16 > rows >= 8
  // - 8x4: blocks of 8 columns with 4 rows, used when 8 > rows >= 4
  // - blocks of 8 columns with fewer than 4 rows, for the remaining rows

  // Loop for LHS standard block (8x16)
  Index bigSuffix = (2 * 8) * (strideA - offsetA);
  indexB += 4 * offsetB;
  strideB *= 4;
  offsetB *= 3;

  Index row = 0;
  while (row + 16 <= rows) {
    calcColLoops<16>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
  }
  // LHS (8x8) block
  calcColLoops<8>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
  // LHS (8x4) block
  calcColLoops<4>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
  // extra rows
  if (rows & 3) {
    // indexA already points to the beginning of the remaining rows block.
    colLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
  }

  // Convert back to bfloat16
  convertArrayF32toBF16<DataMapper>(result, cols, rows, res);
}

#undef MAX_BFLOAT16_ACC

#if !EIGEN_ALTIVEC_DISABLE_MMA
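// Load LHS packets for the column-major GEMV kernel: rows of two adjacent
// columns are interleaved (vec_mergeh/vec_mergel) so that each MMA word holds
// one element from each column.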
template <Index num_acc, typename LhsMapper, bool zero>
EIGEN_ALWAYS_INLINE void loadVecLoop(Index k, LhsMapper& lhs, Packet8bf (&a0)[num_acc], Packet8bf b1) {
  a0[k + 0] = lhs.template loadPacket<Packet8bf>(k * 4, 0);
  if (!zero) {
    b1 = lhs.template loadPacket<Packet8bf>(k * 4, 1);
  }
  if (num_acc > (k + 1)) {
    a0[k + 1] = vec_mergel(a0[k + 0].m_val, b1.m_val);
  }
  a0[k + 0] = vec_mergeh(a0[k + 0].m_val, b1.m_val);
}

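// Issue one bf16 rank-2 MMA update per accumulator, sharing the RHS packet b0.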
template <Index num_acc>
EIGEN_ALWAYS_INLINE void multVec(__vector_quad (&quad_acc)[num_acc], Packet8bf (&a0)[num_acc], Packet8bf b0) {
  BFLOAT16_UNROLL
  for (Index k = 0; k < num_acc; k++) {
    __builtin_mma_xvbf16ger2pp(&(quad_acc[k]), reinterpret_cast<Packet16uc>(b0.m_val),
                               reinterpret_cast<Packet16uc>(a0[k].m_val));
  }
}

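// One column step of the column-major GEMV: pair the RHS values for columns j
// and j+1 (zero-padding an odd trailing column) and accumulate them against
// the interleaved LHS packets.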
template <Index num_acc, typename LhsMapper, typename RhsMapper, bool zero, bool linear>
EIGEN_ALWAYS_INLINE void vecColLoop(Index j, LhsMapper& lhs, RhsMapper& rhs, __vector_quad (&quad_acc)[num_acc]) {
  Packet8bf a0[num_acc];
  Packet8bf b1 = pset1<Packet8bf>(Eigen::bfloat16(0));
  Packet8bf b0 = loadColData<RhsMapper, linear>(rhs, j);

  if (zero) {
    b0 = vec_mergeh(b0.m_val, b1.m_val);
  }

  using LhsSubMapper = typename LhsMapper::SubMapper;

  LhsSubMapper lhs2 = lhs.getSubMapper(0, j);
  BFLOAT16_UNROLL
  for (Index k = 0; k < num_acc; k += 2) {
    loadVecLoop<num_acc, LhsSubMapper, zero>(k, lhs2, a0, b1);
  }

  multVec<num_acc>(quad_acc, a0, b0);
}

#define MAX_BFLOAT16_VEC_ACC 8

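// Column-major GEMV body: accumulate num_acc * 4 rows of the result over all
// columns of the block, then write the scaled accumulators out.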
template <const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha,
                       float* result) {
  constexpr Index step = (num_acc * 4);
  const Index extra_rows = (extraRows) ? (rows & 3) : 0;
  constexpr bool multiIters = !extraRows && (num_acc == MAX_BFLOAT16_VEC_ACC);

  do {
    Packet4f acc[num_acc][4];
    __vector_quad quad_acc[num_acc];

    zeroAccumulators<num_acc>(quad_acc);

    using LhsSubMapper = typename LhsMapper::SubMapper;

    LhsSubMapper lhs2 = lhs.getSubMapper(row, 0);
    for (Index j = 0; j + 2 <= cend; j += 2) {
      vecColLoop<num_acc, LhsSubMapper, RhsMapper, false, linear>(j, lhs2, rhs, quad_acc);
    }
    if (cend & 1) {
      vecColLoop<num_acc, LhsSubMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, quad_acc);
    }

    disassembleAccumulators<num_acc>(quad_acc, acc);

    outputVecColResults<num_acc, extraRows>(acc, result, pAlpha, extra_rows);

    result += step;
  } while (multiIters && (step <= rows - (row += step)));
}

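// Row-remainder dispatch for the column-major GEMV, mirroring the GEMM column
// tail: one accumulator per remaining group of 4 rows, plus one for a partial
// group.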
template <const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtraN(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs,
                                                 const Packet4f pAlpha, float* result) {
  if (MAX_BFLOAT16_VEC_ACC > num_acc) {
    colVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs,
                                                                                              pAlpha, result);
  }
}

template <typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtra(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs,
                                                const Packet4f pAlpha, float* result) {
  switch ((rows - row) >> 2) {
    case 7:
      colVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 6:
      colVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 5:
      colVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 4:
      colVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 3:
      colVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 2:
      colVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 1:
      colVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    default:
      if (extraRows) {
        colVecColLoopBody<1, LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      }
      break;
  }
}

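// Row driver for the column-major GEMV: full blocks of
// MAX_BFLOAT16_VEC_ACC * 4 rows first, then the remainder.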
template <typename LhsMapper, typename RhsMapper, bool linear>
EIGEN_ALWAYS_INLINE void calcVecColLoops(Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha,
                                         float* result) {
  Index row = 0;
  if (rows >= (MAX_BFLOAT16_VEC_ACC * 4)) {
    colVecColLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha,
                                                                                 result);
    result += row;
  }
  if (rows & 3) {
    colVecColLoopBodyExtra<LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
  } else {
    colVecColLoopBodyExtra<LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
  }
}

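// Compile-time detection of whether RhsMapper has a stride() member function.
// When it does and the stride is 1, the RHS can be loaded with the faster
// linear (contiguous) path; otherwise the generic path is used.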
template <typename RhsMapper, typename LhsMapper, typename = void>
struct UseMMAStride : std::false_type {
  static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper& lhs, RhsMapper& rhs, Packet4f pAlpha,
                                      float* result) {
    using RhsSubMapper = typename RhsMapper::SubMapper;

    RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
    calcVecColLoops<LhsMapper, RhsSubMapper, false>(jend - j2, rows, lhs, rhs2, pAlpha, result);
  }
};

template <typename RhsMapper, typename LhsMapper>
struct UseMMAStride<RhsMapper, LhsMapper,
                    std::enable_if_t<std::is_member_function_pointer<decltype(&RhsMapper::stride)>::value>>
    : std::true_type {
  static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper& lhs, RhsMapper& rhs, Packet4f pAlpha,
                                      float* result) {
    using RhsSubMapper = typename RhsMapper::SubMapper;

    RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
    if (rhs.stride() == 1) {
      calcVecColLoops<LhsMapper, RhsSubMapper, true>(jend - j2, rows, lhs, rhs2, pAlpha, result);
    } else {
      calcVecColLoops<LhsMapper, RhsSubMapper, false>(jend - j2, rows, lhs, rhs2, pAlpha, result);
    }
  }
};

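// bfloat16 GEMV with a column-major LHS: widen res to float, walk the columns
// in cache-sized blocks, then narrow back to bfloat16.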
template <typename LhsMapper, typename RhsMapper>
void gemvMMA_bfloat16_col(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs, bfloat16* res,
                          Index resIncr, bfloat16 alpha) {
  EIGEN_UNUSED_VARIABLE(resIncr);
  eigen_internal_assert(resIncr == 1);

  // The following copy tells the compiler that lhs's attributes are not modified outside this function
  // This helps GCC to generate proper code.
  LhsMapper lhs(alhs);
  RhsMapper rhs2(rhs);

  const Index lhsStride = lhs.stride();

  // TODO: improve the following heuristic:
  const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(bfloat16) < 16000 ? 16 : 8);
  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
  Packet4f pAlpha = pset1<Packet4f>(falpha);

  ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);

  convertArrayPointerBF16toF32(result, 1, rows, res);

  for (Index j2 = 0; j2 < cols; j2 += block_cols) {
    Index jend = numext::mini(j2 + block_cols, cols);

    using LhsSubMapper = typename LhsMapper::SubMapper;

    LhsSubMapper lhs2 = lhs.getSubMapper(0, j2);
    UseMMAStride<RhsMapper, LhsSubMapper>::run(j2, jend, rows, lhs2, rhs2, pAlpha, result);
  }

  convertArrayPointerF32toBF16(result, rows, res);
}

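// Reduction helpers for the row-major GEMV: each dot product accumulates on
// the diagonal of a 4x4 accumulator, so the diagonal elements are gathered
// (p16uc_ELEMENT_VEC3 selects element 3 of each input) and summed. The paired
// variant packs the diagonal sums of two accumulators into one vector.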
static Packet16uc p16uc_ELEMENT_VEC3 = {0x0c, 0x0d, 0x0e, 0x0f, 0x1c, 0x1d, 0x1e, 0x1f,
                                        0x0c, 0x0d, 0x0e, 0x0f, 0x1c, 0x1d, 0x1e, 0x1f};

template <Index num_acc>
EIGEN_ALWAYS_INLINE void preduxVecResults2(Packet4f (&acc)[num_acc][4], Index k) {
  if (num_acc > (k + 1)) {
    acc[k][0] = vec_mergeh(acc[k][0], acc[k + 1][0]);
    acc[k][1] = vec_mergeo(acc[k][1], acc[k + 1][1]);
    acc[k][2] = vec_mergel(acc[k][2], acc[k + 1][2]);
    acc[k][3] = vec_perm(acc[k][3], acc[k + 1][3], p16uc_ELEMENT_VEC3);

    acc[k][0] = (acc[k][0] + acc[k][2]) + (acc[k][1] + acc[k][3]);
  } else {
    acc[k][0] = vec_mergeh(acc[k][0], acc[k][1]);
    acc[k][0] += vec_mergel(acc[k][2], acc[k][3]);
#ifdef _BIG_ENDIAN
    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 12);
#else
    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 4);
#endif
  }
}

template <Index num_acc>
EIGEN_ALWAYS_INLINE void preduxVecResults(Packet4f (&acc)[num_acc][4]) {
  BFLOAT16_UNROLL
  for (Index k = 0; k < num_acc; k += 4) {
    preduxVecResults2<num_acc>(acc, k + 0);
    if (num_acc > (k + 2)) {
      preduxVecResults2<num_acc>(acc, k + 2);
      acc[k + 0][0] = reinterpret_cast<Packet4f>(
          vec_mergeh(reinterpret_cast<Packet2ul>(acc[k + 0][0]), reinterpret_cast<Packet2ul>(acc[k + 2][0])));
    }
  }
}

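// Row-major GEMV inner step: load one RHS packet and num_acc LHS row packets
// (with partial loads at the column tail) and update the accumulators.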
template <Index num_acc, typename LhsMapper, typename RhsMapper, bool extra>
EIGEN_ALWAYS_INLINE void multVecLoop(__vector_quad (&quad_acc)[num_acc], const LhsMapper& lhs, RhsMapper& rhs, Index j,
                                     Index extra_cols) {
  Packet8bf a0[num_acc], b0;

  if (extra) {
    b0 = rhs.template loadPacketPartial<Packet8bf>(j, extra_cols);
  } else {
    b0 = rhs.template loadPacket<Packet8bf>(j);
  }

  const LhsMapper lhs2 = lhs.getSubMapper(0, j);
  BFLOAT16_UNROLL
  for (Index k = 0; k < num_acc; k++) {
    if (extra) {
      a0[k] = lhs2.template loadPacketPartial<Packet8bf>(k, 0, extra_cols);
    } else {
      a0[k] = lhs2.template loadPacket<Packet8bf>(k, 0);
    }
  }

  multVec<num_acc>(quad_acc, a0, b0);
}

template <Index num_acc, typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void vecLoop(Index cols, const LhsMapper& lhs, RhsMapper& rhs, __vector_quad (&quad_acc)[num_acc],
                                 Index extra_cols) {
  Index j = 0;
  for (; j + 8 <= cols; j += 8) {
    multVecLoop<num_acc, LhsMapper, RhsMapper, false>(quad_acc, lhs, rhs, j, extra_cols);
  }

  if (extra_cols) {
    multVecLoop<num_acc, LhsMapper, RhsMapper, true>(quad_acc, lhs, rhs, j, extra_cols);
  }
}

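// Row-major GEMV body: one accumulator per output row. Accumulate the dot
// products across all columns, reduce the diagonals, then scale and store.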
template <const Index num_acc, typename LhsMapper, typename RhsMapper>
void colVecLoopBody(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha,
                    float* result) {
  constexpr bool multiIters = (num_acc == MAX_BFLOAT16_VEC_ACC);
  const Index extra_cols = (cols & 7);

  do {
    Packet4f acc[num_acc][4];
    __vector_quad quad_acc[num_acc];

    zeroAccumulators<num_acc>(quad_acc);

    const LhsMapper lhs2 = lhs.getSubMapper(row, 0);
    vecLoop<num_acc, LhsMapper, RhsMapper>(cols, lhs2, rhs, quad_acc, extra_cols);

    disassembleAccumulators<num_acc>(quad_acc, acc);

    preduxVecResults<num_acc>(acc);

    outputVecResults<num_acc>(acc, result, pAlpha);

    result += num_acc;
  } while (multiIters && (num_acc <= rows - (row += num_acc)));
}

template <const Index num_acc, typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void colVecLoopBodyExtraN(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs,
                                              const Packet4f pAlpha, float* result) {
  if (MAX_BFLOAT16_VEC_ACC > num_acc) {
    colVecLoopBody<num_acc, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
  }
}

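// Tail dispatch over the remaining rows (fewer than MAX_BFLOAT16_VEC_ACC).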
template <typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void colVecLoopBodyExtra(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs,
                                             const Packet4f pAlpha, float* result) {
  switch (rows - row) {
    case 7:
      colVecLoopBodyExtraN<7, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 6:
      colVecLoopBodyExtraN<6, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 5:
      colVecLoopBodyExtraN<5, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 4:
      colVecLoopBodyExtraN<4, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 3:
      colVecLoopBodyExtraN<3, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 2:
      colVecLoopBodyExtraN<2, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 1:
      colVecLoopBodyExtraN<1, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
  }
}

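// Row driver for the row-major GEMV: blocks of MAX_BFLOAT16_VEC_ACC rows,
// then the remainder.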
template <typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void calcVecLoops(Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha,
                                      float* result) {
  Index row = 0;
  if (rows >= MAX_BFLOAT16_VEC_ACC) {
    colVecLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
    result += row;
  }
  colVecLoopBodyExtra<LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
}

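// bfloat16 GEMV with a row-major LHS (a series of dot products): widen res to
// float (honoring a non-unit result increment), compute, and narrow back.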
template <typename LhsMapper, typename RhsMapper>
EIGEN_STRONG_INLINE void gemvMMA_bfloat16_row(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs,
                                              bfloat16* res, Index resIncr, bfloat16 alpha) {
  typedef typename RhsMapper::LinearMapper LinearMapper;

  // The following copy tells the compiler that lhs's attributes are not modified outside this function
  // This helps GCC to generate proper code.
  LhsMapper lhs(alhs);
  LinearMapper rhs2 = rhs.getLinearMapper(0, 0);

  eigen_internal_assert(rhs.stride() == 1);

  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
  const Packet4f pAlpha = pset1<Packet4f>(falpha);

  ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);
  if (resIncr == 1) {
    convertArrayPointerBF16toF32(result, 1, rows, res);
  } else {
    convertArrayPointerBF16toF32<true>(result, 1, rows, res, resIncr);
  }
  calcVecLoops<LhsMapper, LinearMapper>(cols, rows, lhs, rhs2, pAlpha, result);
  if (resIncr == 1) {
    convertArrayPointerF32toBF16(result, rows, res);
  } else {
    convertArrayPointerF32toBF16<true>(result, rows, res, resIncr);
  }
}
#endif

#undef MAX_BFLOAT16_VEC_ACC
#undef BFLOAT16_UNROLL

}  // namespace internal
}  // namespace Eigen
#endif  // EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H