Eigen  3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)
 
Loading...
Searching...
No Matches
MathFunctionsImpl.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014 Pedro Gonnet ([email protected])
5// Copyright (C) 2016 Gael Guennebaud <[email protected]>
6//
7// This Source Code Form is subject to the terms of the Mozilla
8// Public License v. 2.0. If a copy of the MPL was not distributed
9// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10
11#ifndef EIGEN_MATHFUNCTIONSIMPL_H
12#define EIGEN_MATHFUNCTIONSIMPL_H
13
14// IWYU pragma: private
15#include "./InternalHeaderCheck.h"
16
17namespace Eigen {
18
19namespace internal {
20
35template <typename Packet, int Steps>
36struct generic_reciprocal_newton_step {
37 static_assert(Steps > 0, "Steps must be at least 1.");
38 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet run(const Packet& a, const Packet& approx_a_recip) {
39 using Scalar = typename unpacket_traits<Packet>::type;
40 const Packet two = pset1<Packet>(Scalar(2));
41 // Refine the approximation using one Newton-Raphson step:
42 // x_{i} = x_{i-1} * (2 - a * x_{i-1})
43 const Packet x = generic_reciprocal_newton_step<Packet, Steps - 1>::run(a, approx_a_recip);
44 const Packet tmp = pnmadd(a, x, two);
45 // If tmp is NaN, it means that a is either +/-0 or +/-Inf.
46 // In this case return the approximation directly.
47 const Packet is_not_nan = pcmp_eq(tmp, tmp);
48 return pselect(is_not_nan, pmul(x, tmp), x);
49 }
50};
51
52template <typename Packet>
53struct generic_reciprocal_newton_step<Packet, 0> {
54 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet run(const Packet& /*unused*/, const Packet& approx_rsqrt) {
55 return approx_rsqrt;
56 }
57};
58
74template <typename Packet, int Steps>
75struct generic_rsqrt_newton_step {
76 static_assert(Steps > 0, "Steps must be at least 1.");
77 using Scalar = typename unpacket_traits<Packet>::type;
78 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet run(const Packet& a, const Packet& approx_rsqrt) {
79 constexpr Scalar kMinusHalf = Scalar(-1) / Scalar(2);
80 const Packet cst_minus_half = pset1<Packet>(kMinusHalf);
81 const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
82
83 Packet inv_sqrt = approx_rsqrt;
84 for (int step = 0; step < Steps; ++step) {
85 // Refine the approximation using one Newton-Raphson step:
86 // h_n = (x * inv_sqrt) * inv_sqrt - 1 (so that h_n is nearly 0).
87 // inv_sqrt = inv_sqrt - 0.5 * inv_sqrt * h_n
88 Packet r2 = pmul(a, inv_sqrt);
89 Packet half_r = pmul(inv_sqrt, cst_minus_half);
90 Packet h_n = pmadd(r2, inv_sqrt, cst_minus_one);
91 inv_sqrt = pmadd(half_r, h_n, inv_sqrt);
92 }
93
94 // If x is NaN, then either:
95 // 1) the input is NaN
96 // 2) zero and infinity were multiplied
97 // In either of these cases, return approx_rsqrt
98 return pselect(pisnan(inv_sqrt), approx_rsqrt, inv_sqrt);
99 }
100};
101
102template <typename Packet>
103struct generic_rsqrt_newton_step<Packet, 0> {
104 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet run(const Packet& /*unused*/, const Packet& approx_rsqrt) {
105 return approx_rsqrt;
106 }
107};
108
124template <typename Packet, int Steps = 1>
125struct generic_sqrt_newton_step {
126 static_assert(Steps > 0, "Steps must be at least 1.");
127
128 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet run(const Packet& a, const Packet& approx_rsqrt) {
129 using Scalar = typename unpacket_traits<Packet>::type;
130 const Packet one_point_five = pset1<Packet>(Scalar(1.5));
131 const Packet minus_half = pset1<Packet>(Scalar(-0.5));
132 // If a is inf or zero, return a directly.
133 const Packet inf_mask = pcmp_eq(a, pset1<Packet>(NumTraits<Scalar>::infinity()));
134 const Packet return_a = por(pcmp_eq(a, pzero(a)), inf_mask);
135 // Do a single step of Newton's iteration for reciprocal square root:
136 // x_{n+1} = x_n * (1.5 + (-0.5 * x_n) * (a * x_n))).
137 // The Newton's step is computed this way to avoid over/under-flows.
138 Packet rsqrt = pmul(approx_rsqrt, pmadd(pmul(minus_half, approx_rsqrt), pmul(a, approx_rsqrt), one_point_five));
139 for (int step = 1; step < Steps; ++step) {
140 rsqrt = pmul(rsqrt, pmadd(pmul(minus_half, rsqrt), pmul(a, rsqrt), one_point_five));
141 }
142
143 // Return sqrt(x) = x * rsqrt(x) for non-zero finite positive arguments.
144 // Return a itself for 0 or +inf, NaN for negative arguments.
145 return pselect(return_a, a, pmul(a, rsqrt));
146 }
147};
148
149template <typename RealScalar>
150EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE RealScalar positive_real_hypot(const RealScalar& x, const RealScalar& y) {
151 // IEEE IEC 6059 special cases.
152 if ((numext::isinf)(x) || (numext::isinf)(y)) return NumTraits<RealScalar>::infinity();
153 if ((numext::isnan)(x) || (numext::isnan)(y)) return NumTraits<RealScalar>::quiet_NaN();
154
155 EIGEN_USING_STD(sqrt);
156 RealScalar p, qp;
157 p = numext::maxi(x, y);
158 if (numext::is_exactly_zero(p)) return RealScalar(0);
159 qp = numext::mini(y, x) / p;
160 return p * sqrt(RealScalar(1) + qp * qp);
161}
162
163template <typename Scalar>
164struct hypot_impl {
165 typedef typename NumTraits<Scalar>::Real RealScalar;
166 static EIGEN_DEVICE_FUNC inline RealScalar run(const Scalar& x, const Scalar& y) {
167 EIGEN_USING_STD(abs);
168 return positive_real_hypot<RealScalar>(abs(x), abs(y));
169 }
170};
171
172// Generic complex sqrt implementation that correctly handles corner cases
173// according to https://en.cppreference.com/w/cpp/numeric/complex/sqrt
174template <typename T>
175EIGEN_DEVICE_FUNC std::complex<T> complex_sqrt(const std::complex<T>& z) {
176 // Computes the principal sqrt of the input.
177 //
178 // For a complex square root of the number x + i*y. We want to find real
179 // numbers u and v such that
180 // (u + i*v)^2 = x + i*y <=>
181 // u^2 - v^2 + i*2*u*v = x + i*v.
182 // By equating the real and imaginary parts we get:
183 // u^2 - v^2 = x
184 // 2*u*v = y.
185 //
186 // For x >= 0, this has the numerically stable solution
187 // u = sqrt(0.5 * (x + sqrt(x^2 + y^2)))
188 // v = y / (2 * u)
189 // and for x < 0,
190 // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2)))
191 // u = y / (2 * v)
192 //
193 // Letting w = sqrt(0.5 * (|x| + |z|)),
194 // if x == 0: u = w, v = sign(y) * w
195 // if x > 0: u = w, v = y / (2 * w)
196 // if x < 0: u = |y| / (2 * w), v = sign(y) * w
197
198 const T x = numext::real(z);
199 const T y = numext::imag(z);
200 const T zero = T(0);
201 const T w = numext::sqrt(T(0.5) * (numext::abs(x) + numext::hypot(x, y)));
202
203 return (numext::isinf)(y) ? std::complex<T>(NumTraits<T>::infinity(), y)
204 : numext::is_exactly_zero(x) ? std::complex<T>(w, y < zero ? -w : w)
205 : x > zero ? std::complex<T>(w, y / (2 * w))
206 : std::complex<T>(numext::abs(y) / (2 * w), y < zero ? -w : w);
207}
208
209// Generic complex rsqrt implementation.
210template <typename T>
211EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& z) {
212 // Computes the principal reciprocal sqrt of the input.
213 //
214 // For a complex reciprocal square root of the number z = x + i*y. We want to
215 // find real numbers u and v such that
216 // (u + i*v)^2 = 1 / (x + i*y) <=>
217 // u^2 - v^2 + i*2*u*v = x/|z|^2 - i*v/|z|^2.
218 // By equating the real and imaginary parts we get:
219 // u^2 - v^2 = x/|z|^2
220 // 2*u*v = y/|z|^2.
221 //
222 // For x >= 0, this has the numerically stable solution
223 // u = sqrt(0.5 * (x + |z|)) / |z|
224 // v = -y / (2 * u * |z|)
225 // and for x < 0,
226 // v = -sign(y) * sqrt(0.5 * (-x + |z|)) / |z|
227 // u = -y / (2 * v * |z|)
228 //
229 // Letting w = sqrt(0.5 * (|x| + |z|)),
230 // if x == 0: u = w / |z|, v = -sign(y) * w / |z|
231 // if x > 0: u = w / |z|, v = -y / (2 * w * |z|)
232 // if x < 0: u = |y| / (2 * w * |z|), v = -sign(y) * w / |z|
233
234 const T x = numext::real(z);
235 const T y = numext::imag(z);
236 const T zero = T(0);
237
238 const T abs_z = numext::hypot(x, y);
239 const T w = numext::sqrt(T(0.5) * (numext::abs(x) + abs_z));
240 const T woz = w / abs_z;
241 // Corner cases consistent with 1/sqrt(z) on gcc/clang.
242 return numext::is_exactly_zero(abs_z) ? std::complex<T>(NumTraits<T>::infinity(), NumTraits<T>::quiet_NaN())
243 : ((numext::isinf)(x) || (numext::isinf)(y)) ? std::complex<T>(zero, zero)
244 : numext::is_exactly_zero(x) ? std::complex<T>(woz, y < zero ? woz : -woz)
245 : x > zero ? std::complex<T>(woz, -y / (2 * w * abs_z))
246 : std::complex<T>(numext::abs(y) / (2 * w * abs_z), y < zero ? woz : -woz);
247}
248
249template <typename T>
250EIGEN_DEVICE_FUNC std::complex<T> complex_log(const std::complex<T>& z) {
251 // Computes complex log.
252 T a = numext::abs(z);
253 EIGEN_USING_STD(atan2);
254 T b = atan2(z.imag(), z.real());
255 return std::complex<T>(numext::log(a), b);
256}
257
258} // end namespace internal
259
260} // end namespace Eigen
261
262#endif // EIGEN_MATHFUNCTIONSIMPL_H
Namespace containing all symbols from the Eigen library.
Definition Core:137
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_sqrt_op< typename Derived::Scalar >, const Derived > sqrt(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_rsqrt_op< typename Derived::Scalar >, const Derived > rsqrt(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_abs_op< typename Derived::Scalar >, const Derived > abs(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_isinf_op< typename Derived::Scalar >, const Derived > isinf(const Eigen::ArrayBase< Derived > &x)