Eigen  3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)
 
Loading...
Searching...
No Matches
BlasUtil.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2009-2010 Gael Guennebaud <[email protected]>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_BLASUTIL_H
11#define EIGEN_BLASUTIL_H
12
13// This file contains many lightweight helper classes used to
14// implement and control fast level 2 and level 3 BLAS-like routines.
15
16// IWYU pragma: private
17#include "../InternalHeaderCheck.h"
18
19namespace Eigen {
20
21namespace internal {
22
23// forward declarations
24template <typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr,
25 bool ConjugateLhs = false, bool ConjugateRhs = false>
26struct gebp_kernel;
27
28template <typename Scalar, typename Index, typename DataMapper, int nr, int StorageOrder, bool Conjugate = false,
29 bool PanelMode = false>
30struct gemm_pack_rhs;
31
32template <typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, int StorageOrder,
33 bool Conjugate = false, bool PanelMode = false>
34struct gemm_pack_lhs;
35
36template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar,
37 int RhsStorageOrder, bool ConjugateRhs, int ResStorageOrder, int ResInnerStride>
38struct general_matrix_matrix_product;
39
40template <typename Index, typename LhsScalar, typename LhsMapper, int LhsStorageOrder, bool ConjugateLhs,
41 typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version = Specialized>
42struct general_matrix_vector_product;
43
44template <typename From, typename To>
45struct get_factor {
46 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE To run(const From& x) { return To(x); }
47};
48
49template <typename Scalar>
50struct get_factor<Scalar, typename NumTraits<Scalar>::Real> {
51 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE typename NumTraits<Scalar>::Real run(const Scalar& x) {
52 return numext::real(x);
53 }
54};
55
56template <typename Scalar, typename Index>
57class BlasVectorMapper {
58 public:
59 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasVectorMapper(Scalar* data) : m_data(data) {}
60
61 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const { return m_data[i]; }
62 template <typename Packet, int AlignmentType>
63 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet load(Index i) const {
64 return ploadt<Packet, AlignmentType>(m_data + i);
65 }
66
67 template <typename Packet>
68 EIGEN_DEVICE_FUNC bool aligned(Index i) const {
69 return (std::uintptr_t(m_data + i) % sizeof(Packet)) == 0;
70 }
71
72 protected:
73 Scalar* m_data;
74};
75
76template <typename Scalar, typename Index, int AlignmentType, int Incr = 1>
77class BlasLinearMapper;
78
79template <typename Scalar, typename Index, int AlignmentType>
80class BlasLinearMapper<Scalar, Index, AlignmentType> {
81 public:
82 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar* data, Index incr = 1) : m_data(data) {
83 EIGEN_ONLY_USED_FOR_DEBUG(incr);
84 eigen_assert(incr == 1);
85 }
86
87 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(Index i) const { internal::prefetch(&operator()(i)); }
88
89 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i) const { return m_data[i]; }
90
91 template <typename PacketType>
92 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacket(Index i) const {
93 return ploadt<PacketType, AlignmentType>(m_data + i);
94 }
95
96 template <typename PacketType>
97 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacketPartial(Index i, Index n, Index offset = 0) const {
98 return ploadt_partial<PacketType, AlignmentType>(m_data + i, n, offset);
99 }
100
101 template <typename PacketType, int AlignmentT>
102 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType load(Index i) const {
103 return ploadt<PacketType, AlignmentT>(m_data + i);
104 }
105
106 template <typename PacketType>
107 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const PacketType& p) const {
108 pstoret<Scalar, PacketType, AlignmentType>(m_data + i, p);
109 }
110
111 template <typename PacketType>
112 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacketPartial(Index i, const PacketType& p, Index n,
113 Index offset = 0) const {
114 pstoret_partial<Scalar, PacketType, AlignmentType>(m_data + i, p, n, offset);
115 }
116
117 protected:
118 Scalar* m_data;
119};
120
121// Lightweight helper class to access matrix coefficients.
122template <typename Scalar, typename Index, int StorageOrder, int AlignmentType = Unaligned, int Incr = 1>
123class blas_data_mapper;
124
125// TMP to help PacketBlock store implementation.
126// There's currently no known use case for PacketBlock load.
127// The default implementation assumes ColMajor order.
128// It always store each packet sequentially one `stride` apart.
129template <typename Index, typename Scalar, typename Packet, int n, int idx, int StorageOrder>
130struct PacketBlockManagement {
131 PacketBlockManagement<Index, Scalar, Packet, n, idx - 1, StorageOrder> pbm;
132 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(Scalar* to, const Index stride, Index i, Index j,
133 const PacketBlock<Packet, n>& block) const {
134 pbm.store(to, stride, i, j, block);
135 pstoreu<Scalar>(to + i + (j + idx) * stride, block.packet[idx]);
136 }
137};
138
139// PacketBlockManagement specialization to take care of RowMajor order without ifs.
140template <typename Index, typename Scalar, typename Packet, int n, int idx>
141struct PacketBlockManagement<Index, Scalar, Packet, n, idx, RowMajor> {
142 PacketBlockManagement<Index, Scalar, Packet, n, idx - 1, RowMajor> pbm;
143 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(Scalar* to, const Index stride, Index i, Index j,
144 const PacketBlock<Packet, n>& block) const {
145 pbm.store(to, stride, i, j, block);
146 pstoreu<Scalar>(to + j + (i + idx) * stride, block.packet[idx]);
147 }
148};
149
150template <typename Index, typename Scalar, typename Packet, int n, int StorageOrder>
151struct PacketBlockManagement<Index, Scalar, Packet, n, -1, StorageOrder> {
152 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(Scalar* to, const Index stride, Index i, Index j,
153 const PacketBlock<Packet, n>& block) const {
154 EIGEN_UNUSED_VARIABLE(to);
155 EIGEN_UNUSED_VARIABLE(stride);
156 EIGEN_UNUSED_VARIABLE(i);
157 EIGEN_UNUSED_VARIABLE(j);
158 EIGEN_UNUSED_VARIABLE(block);
159 }
160};
161
162template <typename Index, typename Scalar, typename Packet, int n>
163struct PacketBlockManagement<Index, Scalar, Packet, n, -1, RowMajor> {
164 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(Scalar* to, const Index stride, Index i, Index j,
165 const PacketBlock<Packet, n>& block) const {
166 EIGEN_UNUSED_VARIABLE(to);
167 EIGEN_UNUSED_VARIABLE(stride);
168 EIGEN_UNUSED_VARIABLE(i);
169 EIGEN_UNUSED_VARIABLE(j);
170 EIGEN_UNUSED_VARIABLE(block);
171 }
172};
173
174template <typename Scalar, typename Index, int StorageOrder, int AlignmentType>
175class blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, 1> {
176 public:
177 typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
178 typedef blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType> SubMapper;
179 typedef BlasVectorMapper<Scalar, Index> VectorMapper;
180
181 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr = 1)
182 : m_data(data), m_stride(stride) {
183 EIGEN_ONLY_USED_FOR_DEBUG(incr);
184 eigen_assert(incr == 1);
185 }
186
187 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubMapper getSubMapper(Index i, Index j) const {
188 return SubMapper(&operator()(i, j), m_stride);
189 }
190
191 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
192 return LinearMapper(&operator()(i, j));
193 }
194
195 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
196 return VectorMapper(&operator()(i, j));
197 }
198
199 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(Index i, Index j) const { internal::prefetch(&operator()(i, j)); }
200
201 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j) const {
202 return m_data[StorageOrder == RowMajor ? j + i * m_stride : i + j * m_stride];
203 }
204
205 template <typename PacketType>
206 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacket(Index i, Index j) const {
207 return ploadt<PacketType, AlignmentType>(&operator()(i, j));
208 }
209
210 template <typename PacketType>
211 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacketPartial(Index i, Index j, Index n,
212 Index offset = 0) const {
213 return ploadt_partial<PacketType, AlignmentType>(&operator()(i, j), n, offset);
214 }
215
216 template <typename PacketT, int AlignmentT>
217 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i, Index j) const {
218 return ploadt<PacketT, AlignmentT>(&operator()(i, j));
219 }
220
221 template <typename PacketType>
222 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Index j, const PacketType& p) const {
223 pstoret<Scalar, PacketType, AlignmentType>(&operator()(i, j), p);
224 }
225
226 template <typename PacketType>
227 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacketPartial(Index i, Index j, const PacketType& p, Index n,
228 Index offset = 0) const {
229 pstoret_partial<Scalar, PacketType, AlignmentType>(&operator()(i, j), p, n, offset);
230 }
231
232 template <typename SubPacket>
233 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket& p) const {
234 pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride);
235 }
236
237 template <typename SubPacket>
238 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(Index i, Index j) const {
239 return pgather<Scalar, SubPacket>(&operator()(i, j), m_stride);
240 }
241
242 EIGEN_DEVICE_FUNC const Index stride() const { return m_stride; }
243 EIGEN_DEVICE_FUNC const Index incr() const { return 1; }
244 EIGEN_DEVICE_FUNC const Scalar* data() const { return m_data; }
245
246 EIGEN_DEVICE_FUNC Index firstAligned(Index size) const {
247 if (std::uintptr_t(m_data) % sizeof(Scalar)) {
248 return -1;
249 }
250 return internal::first_default_aligned(m_data, size);
251 }
252
253 template <typename SubPacket, int n>
254 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacketBlock(Index i, Index j,
255 const PacketBlock<SubPacket, n>& block) const {
256 PacketBlockManagement<Index, Scalar, SubPacket, n, n - 1, StorageOrder> pbm;
257 pbm.store(m_data, m_stride, i, j, block);
258 }
259
260 protected:
261 Scalar* EIGEN_RESTRICT m_data;
262 const Index m_stride;
263};
264
265// Implementation of non-natural increment (i.e. inner-stride != 1)
266// The exposed API is not complete yet compared to the Incr==1 case
267// because some features makes less sense in this case.
268template <typename Scalar, typename Index, int AlignmentType, int Incr>
269class BlasLinearMapper {
270 public:
271 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar* data, Index incr) : m_data(data), m_incr(incr) {}
272
273 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(int i) const { internal::prefetch(&operator()(i)); }
274
275 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i) const { return m_data[i * m_incr.value()]; }
276
277 template <typename PacketType>
278 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacket(Index i) const {
279 return pgather<Scalar, PacketType>(m_data + i * m_incr.value(), m_incr.value());
280 }
281
282 template <typename PacketType>
283 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacketPartial(Index i, Index n, Index /*offset*/ = 0) const {
284 return pgather_partial<Scalar, PacketType>(m_data + i * m_incr.value(), m_incr.value(), n);
285 }
286
287 template <typename PacketType>
288 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const PacketType& p) const {
289 pscatter<Scalar, PacketType>(m_data + i * m_incr.value(), p, m_incr.value());
290 }
291
292 template <typename PacketType>
293 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacketPartial(Index i, const PacketType& p, Index n,
294 Index /*offset*/ = 0) const {
295 pscatter_partial<Scalar, PacketType>(m_data + i * m_incr.value(), p, m_incr.value(), n);
296 }
297
298 protected:
299 Scalar* m_data;
300 const internal::variable_if_dynamic<Index, Incr> m_incr;
301};
302
303template <typename Scalar, typename Index, int StorageOrder, int AlignmentType, int Incr>
304class blas_data_mapper {
305 public:
306 typedef BlasLinearMapper<Scalar, Index, AlignmentType, Incr> LinearMapper;
307 typedef blas_data_mapper SubMapper;
308
309 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr)
310 : m_data(data), m_stride(stride), m_incr(incr) {}
311
312 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubMapper getSubMapper(Index i, Index j) const {
313 return SubMapper(&operator()(i, j), m_stride, m_incr.value());
314 }
315
316 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
317 return LinearMapper(&operator()(i, j), m_incr.value());
318 }
319
320 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(Index i, Index j) const { internal::prefetch(&operator()(i, j)); }
321
322 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j) const {
323 return m_data[StorageOrder == RowMajor ? j * m_incr.value() + i * m_stride : i * m_incr.value() + j * m_stride];
324 }
325
326 template <typename PacketType>
327 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacket(Index i, Index j) const {
328 return pgather<Scalar, PacketType>(&operator()(i, j), m_incr.value());
329 }
330
331 template <typename PacketType>
332 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacketPartial(Index i, Index j, Index n,
333 Index /*offset*/ = 0) const {
334 return pgather_partial<Scalar, PacketType>(&operator()(i, j), m_incr.value(), n);
335 }
336
337 template <typename PacketT, int AlignmentT>
338 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i, Index j) const {
339 return pgather<Scalar, PacketT>(&operator()(i, j), m_incr.value());
340 }
341
342 template <typename PacketType>
343 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Index j, const PacketType& p) const {
344 pscatter<Scalar, PacketType>(&operator()(i, j), p, m_incr.value());
345 }
346
347 template <typename PacketType>
348 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacketPartial(Index i, Index j, const PacketType& p, Index n,
349 Index /*offset*/ = 0) const {
350 pscatter_partial<Scalar, PacketType>(&operator()(i, j), p, m_incr.value(), n);
351 }
352
353 template <typename SubPacket>
354 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket& p) const {
355 pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride);
356 }
357
358 template <typename SubPacket>
359 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(Index i, Index j) const {
360 return pgather<Scalar, SubPacket>(&operator()(i, j), m_stride);
361 }
362
363 // storePacketBlock_helper defines a way to access values inside the PacketBlock, this is essentially required by the
364 // Complex types.
365 template <typename SubPacket, typename Scalar_, int n, int idx>
366 struct storePacketBlock_helper {
367 storePacketBlock_helper<SubPacket, Scalar_, n, idx - 1> spbh;
368 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(
369 const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>* sup, Index i, Index j,
370 const PacketBlock<SubPacket, n>& block) const {
371 spbh.store(sup, i, j, block);
372 sup->template storePacket<SubPacket>(i, j + idx, block.packet[idx]);
373 }
374 };
375
376 template <typename SubPacket, int n, int idx>
377 struct storePacketBlock_helper<SubPacket, std::complex<float>, n, idx> {
378 storePacketBlock_helper<SubPacket, std::complex<float>, n, idx - 1> spbh;
379 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(
380 const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>* sup, Index i, Index j,
381 const PacketBlock<SubPacket, n>& block) const {
382 spbh.store(sup, i, j, block);
383 sup->template storePacket<SubPacket>(i, j + idx, block.packet[idx]);
384 }
385 };
386
387 template <typename SubPacket, int n, int idx>
388 struct storePacketBlock_helper<SubPacket, std::complex<double>, n, idx> {
389 storePacketBlock_helper<SubPacket, std::complex<double>, n, idx - 1> spbh;
390 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(
391 const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>* sup, Index i, Index j,
392 const PacketBlock<SubPacket, n>& block) const {
393 spbh.store(sup, i, j, block);
394 for (int l = 0; l < unpacket_traits<SubPacket>::size; l++) {
395 std::complex<double>* v = &sup->operator()(i + l, j + idx);
396 v->real(block.packet[idx].v[2 * l + 0]);
397 v->imag(block.packet[idx].v[2 * l + 1]);
398 }
399 }
400 };
401
402 template <typename SubPacket, typename Scalar_, int n>
403 struct storePacketBlock_helper<SubPacket, Scalar_, n, -1> {
404 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(
405 const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>*, Index, Index,
406 const PacketBlock<SubPacket, n>&) const {}
407 };
408
409 template <typename SubPacket, int n>
410 struct storePacketBlock_helper<SubPacket, std::complex<float>, n, -1> {
411 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(
412 const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>*, Index, Index,
413 const PacketBlock<SubPacket, n>&) const {}
414 };
415
416 template <typename SubPacket, int n>
417 struct storePacketBlock_helper<SubPacket, std::complex<double>, n, -1> {
418 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(
419 const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>*, Index, Index,
420 const PacketBlock<SubPacket, n>&) const {}
421 };
422 // This function stores a PacketBlock on m_data, this approach is really quite slow compare to Incr=1 and should be
423 // avoided when possible.
424 template <typename SubPacket, int n>
425 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacketBlock(Index i, Index j,
426 const PacketBlock<SubPacket, n>& block) const {
427 storePacketBlock_helper<SubPacket, Scalar, n, n - 1> spb;
428 spb.store(this, i, j, block);
429 }
430
431 EIGEN_DEVICE_FUNC const Index stride() const { return m_stride; }
432 EIGEN_DEVICE_FUNC const Index incr() const { return m_incr.value(); }
433 EIGEN_DEVICE_FUNC Scalar* data() const { return m_data; }
434
435 protected:
436 Scalar* EIGEN_RESTRICT m_data;
437 const Index m_stride;
438 const internal::variable_if_dynamic<Index, Incr> m_incr;
439};
440
441// lightweight helper class to access matrix coefficients (const version)
442template <typename Scalar, typename Index, int StorageOrder>
443class const_blas_data_mapper : public blas_data_mapper<const Scalar, Index, StorageOrder> {
444 public:
445 typedef const_blas_data_mapper<Scalar, Index, StorageOrder> SubMapper;
446
447 EIGEN_ALWAYS_INLINE const_blas_data_mapper(const Scalar* data, Index stride)
448 : blas_data_mapper<const Scalar, Index, StorageOrder>(data, stride) {}
449
450 EIGEN_ALWAYS_INLINE SubMapper getSubMapper(Index i, Index j) const {
451 return SubMapper(&(this->operator()(i, j)), this->m_stride);
452 }
453};
454
455/* Helper class to analyze the factors of a Product expression.
456 * In particular it allows to pop out operator-, scalar multiples,
457 * and conjugate */
458template <typename XprType>
459struct blas_traits {
460 typedef typename traits<XprType>::Scalar Scalar;
461 typedef const XprType& ExtractType;
462 typedef XprType ExtractType_;
463 enum {
464 IsComplex = NumTraits<Scalar>::IsComplex,
465 IsTransposed = false,
466 NeedToConjugate = false,
467 HasUsableDirectAccess =
468 ((int(XprType::Flags) & DirectAccessBit) &&
469 (bool(XprType::IsVectorAtCompileTime) || int(inner_stride_at_compile_time<XprType>::ret) == 1))
470 ? 1
471 : 0,
472 HasScalarFactor = false
473 };
474 typedef std::conditional_t<bool(HasUsableDirectAccess), ExtractType, typename ExtractType_::PlainObject>
475 DirectLinearAccessType;
476 EIGEN_DEVICE_FUNC static inline EIGEN_DEVICE_FUNC ExtractType extract(const XprType& x) { return x; }
477 EIGEN_DEVICE_FUNC static inline EIGEN_DEVICE_FUNC const Scalar extractScalarFactor(const XprType&) {
478 return Scalar(1);
479 }
480};
481
482// pop conjugate
483template <typename Scalar, typename NestedXpr>
484struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> > : blas_traits<NestedXpr> {
485 typedef blas_traits<NestedXpr> Base;
486 typedef CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> XprType;
487 typedef typename Base::ExtractType ExtractType;
488
489 enum { IsComplex = NumTraits<Scalar>::IsComplex, NeedToConjugate = Base::NeedToConjugate ? 0 : IsComplex };
490 EIGEN_DEVICE_FUNC static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
491 EIGEN_DEVICE_FUNC static inline Scalar extractScalarFactor(const XprType& x) {
492 return conj(Base::extractScalarFactor(x.nestedExpression()));
493 }
494};
495
496// pop scalar multiple
497template <typename Scalar, typename NestedXpr, typename Plain>
498struct blas_traits<
499 CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>, Plain>, NestedXpr> >
500 : blas_traits<NestedXpr> {
501 enum { HasScalarFactor = true };
502 typedef blas_traits<NestedXpr> Base;
503 typedef CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>, Plain>, NestedXpr>
504 XprType;
505 typedef typename Base::ExtractType ExtractType;
506 EIGEN_DEVICE_FUNC static inline EIGEN_DEVICE_FUNC ExtractType extract(const XprType& x) {
507 return Base::extract(x.rhs());
508 }
509 EIGEN_DEVICE_FUNC static inline EIGEN_DEVICE_FUNC Scalar extractScalarFactor(const XprType& x) {
510 return x.lhs().functor().m_other * Base::extractScalarFactor(x.rhs());
511 }
512};
513template <typename Scalar, typename NestedXpr, typename Plain>
514struct blas_traits<
515 CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>, Plain> > >
516 : blas_traits<NestedXpr> {
517 enum { HasScalarFactor = true };
518 typedef blas_traits<NestedXpr> Base;
519 typedef CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>, Plain> >
520 XprType;
521 typedef typename Base::ExtractType ExtractType;
522 EIGEN_DEVICE_FUNC static inline ExtractType extract(const XprType& x) { return Base::extract(x.lhs()); }
523 EIGEN_DEVICE_FUNC static inline Scalar extractScalarFactor(const XprType& x) {
524 return Base::extractScalarFactor(x.lhs()) * x.rhs().functor().m_other;
525 }
526};
527template <typename Scalar, typename Plain1, typename Plain2>
528struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>, Plain1>,
529 const CwiseNullaryOp<scalar_constant_op<Scalar>, Plain2> > >
530 : blas_traits<CwiseNullaryOp<scalar_constant_op<Scalar>, Plain1> > {};
531
532// pop opposite
533template <typename Scalar, typename NestedXpr>
534struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> > : blas_traits<NestedXpr> {
535 enum { HasScalarFactor = true };
536 typedef blas_traits<NestedXpr> Base;
537 typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType;
538 typedef typename Base::ExtractType ExtractType;
539 EIGEN_DEVICE_FUNC static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
540 EIGEN_DEVICE_FUNC static inline Scalar extractScalarFactor(const XprType& x) {
541 return -Base::extractScalarFactor(x.nestedExpression());
542 }
543};
544
545// pop/push transpose
546template <typename NestedXpr>
547struct blas_traits<Transpose<NestedXpr> > : blas_traits<NestedXpr> {
548 typedef typename NestedXpr::Scalar Scalar;
549 typedef blas_traits<NestedXpr> Base;
550 typedef Transpose<NestedXpr> XprType;
551 typedef Transpose<const typename Base::ExtractType_>
552 ExtractType; // const to get rid of a compile error; anyway blas traits are only used on the RHS
553 typedef Transpose<const typename Base::ExtractType_> ExtractType_;
554 typedef std::conditional_t<bool(Base::HasUsableDirectAccess), ExtractType, typename ExtractType::PlainObject>
555 DirectLinearAccessType;
556 enum { IsTransposed = Base::IsTransposed ? 0 : 1 };
557 EIGEN_DEVICE_FUNC static inline ExtractType extract(const XprType& x) {
558 return ExtractType(Base::extract(x.nestedExpression()));
559 }
560 EIGEN_DEVICE_FUNC static inline Scalar extractScalarFactor(const XprType& x) {
561 return Base::extractScalarFactor(x.nestedExpression());
562 }
563};
564
565template <typename T>
566struct blas_traits<const T> : blas_traits<T> {};
567
568template <typename T, bool HasUsableDirectAccess = blas_traits<T>::HasUsableDirectAccess>
569struct extract_data_selector {
570 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static const typename T::Scalar* run(const T& m) {
571 return blas_traits<T>::extract(m).data();
572 }
573};
574
575template <typename T>
576struct extract_data_selector<T, false> {
577 EIGEN_DEVICE_FUNC static typename T::Scalar* run(const T&) { return 0; }
578};
579
580template <typename T>
581EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename T::Scalar* extract_data(const T& m) {
582 return extract_data_selector<T>::run(m);
583}
584
589template <typename ResScalar, typename Lhs, typename Rhs>
590struct combine_scalar_factors_impl {
591 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static ResScalar run(const Lhs& lhs, const Rhs& rhs) {
592 return blas_traits<Lhs>::extractScalarFactor(lhs) * blas_traits<Rhs>::extractScalarFactor(rhs);
593 }
594 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static ResScalar run(const ResScalar& alpha, const Lhs& lhs, const Rhs& rhs) {
595 return alpha * blas_traits<Lhs>::extractScalarFactor(lhs) * blas_traits<Rhs>::extractScalarFactor(rhs);
596 }
597};
598template <typename Lhs, typename Rhs>
599struct combine_scalar_factors_impl<bool, Lhs, Rhs> {
600 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(const Lhs& lhs, const Rhs& rhs) {
601 return blas_traits<Lhs>::extractScalarFactor(lhs) && blas_traits<Rhs>::extractScalarFactor(rhs);
602 }
603 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(const bool& alpha, const Lhs& lhs, const Rhs& rhs) {
604 return alpha && blas_traits<Lhs>::extractScalarFactor(lhs) && blas_traits<Rhs>::extractScalarFactor(rhs);
605 }
606};
607
608template <typename ResScalar, typename Lhs, typename Rhs>
609EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const ResScalar& alpha, const Lhs& lhs,
610 const Rhs& rhs) {
611 return combine_scalar_factors_impl<ResScalar, Lhs, Rhs>::run(alpha, lhs, rhs);
612}
613template <typename ResScalar, typename Lhs, typename Rhs>
614EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const Lhs& lhs, const Rhs& rhs) {
615 return combine_scalar_factors_impl<ResScalar, Lhs, Rhs>::run(lhs, rhs);
616}
617
618} // end namespace internal
619
620} // end namespace Eigen
621
622#endif // EIGEN_BLASUTIL_H
AlignmentType
Definition Constants.h:234
@ RowMajor
Definition Constants.h:320
const unsigned int DirectAccessBit
Definition Constants.h:159
Namespace containing all symbols from the Eigen library.
Definition Core:137
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_conjugate_op< typename Derived::Scalar >, const Derived > conj(const Eigen::ArrayBase< Derived > &x)
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:83