Eigen  3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)
 
Loading...
Searching...
No Matches
DeviceWrapper.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2023 Charlie Schlosser <[email protected]>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_DEVICEWRAPPER_H
11#define EIGEN_DEVICEWRAPPER_H
12
13namespace Eigen {
14template <typename Derived, typename Device>
15struct DeviceWrapper {
16 using Base = EigenBase<internal::remove_all_t<Derived>>;
17 using Scalar = typename Derived::Scalar;
18
19 EIGEN_DEVICE_FUNC DeviceWrapper(Base& xpr, Device& device) : m_xpr(xpr.derived()), m_device(device) {}
20 EIGEN_DEVICE_FUNC DeviceWrapper(const Base& xpr, Device& device) : m_xpr(xpr.derived()), m_device(device) {}
21
22 template <typename OtherDerived>
23 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator=(const EigenBase<OtherDerived>& other) {
24 using AssignOp = internal::assign_op<Scalar, typename OtherDerived::Scalar>;
25 internal::call_assignment(*this, other.derived(), AssignOp());
26 return m_xpr;
27 }
28 template <typename OtherDerived>
29 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator+=(const EigenBase<OtherDerived>& other) {
30 using AddAssignOp = internal::add_assign_op<Scalar, typename OtherDerived::Scalar>;
31 internal::call_assignment(*this, other.derived(), AddAssignOp());
32 return m_xpr;
33 }
34 template <typename OtherDerived>
35 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator-=(const EigenBase<OtherDerived>& other) {
36 using SubAssignOp = internal::sub_assign_op<Scalar, typename OtherDerived::Scalar>;
37 internal::call_assignment(*this, other.derived(), SubAssignOp());
38 return m_xpr;
39 }
40
41 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& derived() { return m_xpr; }
42 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Device& device() { return m_device; }
43 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE NoAlias<DeviceWrapper, EigenBase> noalias() {
44 return NoAlias<DeviceWrapper, EigenBase>(*this);
45 }
46
47 Derived& m_xpr;
48 Device& m_device;
49};
50
51namespace internal {
52
53// this is where we differentiate between lazy assignment and specialized kernels (e.g. matrix products)
54template <typename DstXprType, typename SrcXprType, typename Functor, typename Device,
55 typename Kind = typename AssignmentKind<typename evaluator_traits<DstXprType>::Shape,
56 typename evaluator_traits<SrcXprType>::Shape>::Kind,
57 typename EnableIf = void>
58struct AssignmentWithDevice;
59
60// unless otherwise specified, use the default product implementation
61template <typename DstXprType, typename Lhs, typename Rhs, int Options, typename Functor, typename Device,
62 typename Weak>
63struct AssignmentWithDevice<DstXprType, Product<Lhs, Rhs, Options>, Functor, Device, Dense2Dense, Weak> {
64 using SrcXprType = Product<Lhs, Rhs, Options>;
65 using Base = Assignment<DstXprType, SrcXprType, Functor>;
66 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(DstXprType& dst, const SrcXprType& src, const Functor& func,
67 Device&) {
68 Base::run(dst, src, func);
69 };
70};
71
72// specialization for coeffcient-wise assignment
73template <typename DstXprType, typename SrcXprType, typename Functor, typename Device, typename Weak>
74struct AssignmentWithDevice<DstXprType, SrcXprType, Functor, Device, Dense2Dense, Weak> {
75 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(DstXprType& dst, const SrcXprType& src, const Functor& func,
76 Device& device) {
77#ifndef EIGEN_NO_DEBUG
78 internal::check_for_aliasing(dst, src);
79#endif
80
81 call_dense_assignment_loop(dst, src, func, device);
82 }
83};
84
85// this allows us to use the default evaulation scheme if it is not specialized for the device
86template <typename Kernel, typename Device, int Traversal = Kernel::AssignmentTraits::Traversal,
87 int Unrolling = Kernel::AssignmentTraits::Unrolling>
88struct dense_assignment_loop_with_device {
89 using Base = dense_assignment_loop<Kernel, Traversal, Unrolling>;
90 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR void run(Kernel& kernel, Device&) { Base::run(kernel); }
91};
92
93// entry point for a generic expression with device
94template <typename Dst, typename Src, typename Func, typename Device>
95EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR void call_assignment_no_alias(DeviceWrapper<Dst, Device> dst,
96 const Src& src, const Func& func) {
97 enum {
98 NeedToTranspose = ((int(Dst::RowsAtCompileTime) == 1 && int(Src::ColsAtCompileTime) == 1) ||
99 (int(Dst::ColsAtCompileTime) == 1 && int(Src::RowsAtCompileTime) == 1)) &&
100 int(Dst::SizeAtCompileTime) != 1
101 };
102
103 using ActualDstTypeCleaned = std::conditional_t<NeedToTranspose, Transpose<Dst>, Dst>;
104 using ActualDstType = std::conditional_t<NeedToTranspose, Transpose<Dst>, Dst&>;
105 ActualDstType actualDst(dst.derived());
106
107 // TODO check whether this is the right place to perform these checks:
108 EIGEN_STATIC_ASSERT_LVALUE(Dst)
109 EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(ActualDstTypeCleaned, Src)
110 EIGEN_CHECK_BINARY_COMPATIBILIY(Func, typename ActualDstTypeCleaned::Scalar, typename Src::Scalar);
111
112 // this provides a mechanism for specializing simple assignments, matrix products, etc
113 AssignmentWithDevice<ActualDstTypeCleaned, Src, Func, Device>::run(actualDst, src, func, dst.device());
114}
115
116// copy and pasted from AssignEvaluator except forward device to kernel
117template <typename DstXprType, typename SrcXprType, typename Functor, typename Device>
118EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR void call_dense_assignment_loop(DstXprType& dst,
119 const SrcXprType& src,
120 const Functor& func,
121 Device& device) {
122 using DstEvaluatorType = evaluator<DstXprType>;
123 using SrcEvaluatorType = evaluator<SrcXprType>;
124
125 SrcEvaluatorType srcEvaluator(src);
126
127 // NOTE To properly handle A = (A*A.transpose())/s with A rectangular,
128 // we need to resize the destination after the source evaluator has been created.
129 resize_if_allowed(dst, src, func);
130
131 DstEvaluatorType dstEvaluator(dst);
132
133 using Kernel = generic_dense_assignment_kernel<DstEvaluatorType, SrcEvaluatorType, Functor>;
134
135 Kernel kernel(dstEvaluator, srcEvaluator, func, dst.const_cast_derived());
136
137 dense_assignment_loop_with_device<Kernel, Device>::run(kernel, device);
138}
139
140} // namespace internal
141
142template <typename Derived>
143template <typename Device>
144EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DeviceWrapper<Derived, Device> EigenBase<Derived>::device(Device& device) {
145 return DeviceWrapper<Derived, Device>(derived(), device);
146}
147
148template <typename Derived>
149template <typename Device>
150EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DeviceWrapper<const Derived, Device> EigenBase<Derived>::device(
151 Device& device) const {
152 return DeviceWrapper<const Derived, Device>(derived(), device);
153}
154} // namespace Eigen
155#endif
Namespace containing all symbols from the Eigen library.
Definition Core:137