WPILibC++ 2023.4.3-108-ge5452e3
TriangularMatrixVector.h
Go to the documentation of this file.
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
11#define EIGEN_TRIANGULARMATRIXVECTOR_H
12
13namespace Eigen {
14
15namespace internal {
16
// Declaration of the triangular matrix * vector kernel template. Only the
// ColMajor and RowMajor partial specializations that follow are defined in
// this file; the Version parameter defaults to Specialized.
// NOTE(review): the embedded numbering jumps from 17 to 19, so source line 18
// (the declarator completing this template, presumably
// `struct triangular_matrix_vector_product;` -- the symbol index at the end of
// this page lists a definition at TriangularMatrixVector.h:18) was dropped by
// this extraction; confirm against the upstream Eigen header.
17template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder, int Version=Specialized>
19
// Column-major specialization of the kernel. Mode is decoded into the flags
// below: which triangle is used (IsLower) and whether the diagonal is implicit
// and taken as one (HasUnitDiag) or as zero (HasZeroDiag) instead of being
// read from lhs.
// NOTE(review): source line 23 is missing from this extraction (numbering jumps
// 22 -> 24); it should declare the ResScalar type used by run(), which the
// symbol index at the end of this page lists at line 23 as
// ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType.
20template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
21struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
22{
24 enum {
25 IsLower = ((Mode&Lower)==Lower),
26 HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
27 HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
28 };
// Adds alpha times the product of the triangular part of the _rows x _cols
// column-major lhs (outer stride lhsStride) with the rhs vector (increment
// rhsIncr) into _res; see the out-of-class definition below.
29 static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
30 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const RhsScalar& alpha);
31};
32
// Out-of-class definition of the column-major kernel's run(). It adds
// alpha * tri(lhs) * rhs into res. It walks the diagonal in panels of PanelWidth
// columns. Inside a panel, each column i adds a scaled slice of itself to res
// (axpy-style). The rectangular block beside the panel's triangle, and for Upper
// any columns beyond the square part, are handed to a general matrix-vector
// product call.
// NOTE(review): the embedded numbering shows that these source lines are
// missing from this extraction; confirm against the upstream header:
//   - 34: return type and qualified class name;
//   - 43 and 47: the LhsMap/RhsMap typedefs;
//   - 54-55: the LhsMapper/RhsMapper typedefs;
//   - 74 and 83: the callee of the "(r, actualPanelWidth, ...)" and
//     "(rows, cols-size, ...)" calls.
// NOTE(review): res is mapped contiguously (ResMap is built without resIncr),
// while the delegated calls receive resIncr. This is only consistent for
// resIncr == 1, which is what the column-major trmv_selector in this file
// passes.
33template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
35 ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
36 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const RhsScalar& alpha)
37 {
38 static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
39 Index size = (std::min)(_rows,_cols);
// Only the part of lhs that can hold the triangle is mapped. Lower keeps all
// rows but at most min(rows,cols) columns; Upper keeps at most min(rows,cols)
// rows but all columns.
40 Index rows = IsLower ? _rows : (std::min)(_rows,_cols);
41 Index cols = IsLower ? (std::min)(_rows,_cols) : _cols;
42
44 const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
45 typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
46
48 const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
49 typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
50
51 typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
52 ResMap res(_res,rows);
53
56
57 for (Index pi=0; pi<size; pi+=PanelWidth)
58 {
59 Index actualPanelWidth = (std::min)(PanelWidth, size-pi);
60 for (Index k=0; k<actualPanelWidth; ++k)
61 {
62 Index i = pi + k;
// Rows of column i that lie inside both the triangle and this panel. For
// Lower they start at the diagonal (one below it if the diagonal is
// implicit). For Upper they start at the panel's first row and end at the
// diagonal.
63 Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi;
64 Index r = IsLower ? actualPanelWidth-k : k+1;
// With an implicit (unit or zero) diagonal, r is shortened by one so the
// diagonal entry is never read; the update is skipped if nothing is left.
65 if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
66 res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
// The implicit unit diagonal contributes alpha * rhs(i) directly.
67 if (HasUnitDiag)
68 res.coeffRef(i) += alpha * cjRhs.coeff(i);
69 }
// Rectangular block next to this panel's triangle: the rows below the panel
// for Lower, or the rows above it for Upper.
70 Index r = IsLower ? rows - pi - actualPanelWidth : pi;
71 if (r>0)
72 {
73 Index s = IsLower ? pi+actualPanelWidth : 0;
75 r, actualPanelWidth,
76 LhsMapper(&lhs.coeffRef(s,pi), lhsStride),
77 RhsMapper(&rhs.coeffRef(pi), rhsIncr),
78 &res.coeffRef(s), resIncr, alpha);
79 }
80 }
// Upper with more columns than rows: the columns to the right of the square
// triangle form a plain rectangular product.
81 if((!IsLower) && cols>size)
82 {
84 rows, cols-size,
85 LhsMapper(&lhs.coeffRef(0,size), lhsStride),
86 RhsMapper(&rhs.coeffRef(size), rhsIncr),
87 _res, resIncr, alpha);
88 }
89 }
90
// Row-major specialization of the kernel. Mode is decoded into the flags below:
// which triangle is used (IsLower) and whether the diagonal is implicit and
// taken as one (HasUnitDiag) or as zero (HasZeroDiag).
// NOTE(review): source line 94 is missing from this extraction (numbering jumps
// 93 -> 95). Per the symbol index at the end of this page, it declares ResScalar
// as ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType.
91template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
92struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
93{
95 enum {
96 IsLower = ((Mode&Lower)==Lower),
97 HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
98 HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
99 };
// Adds alpha times the product of the triangular part of the _rows x _cols
// row-major lhs with the rhs vector into _res (increment resIncr).
// Unlike the column-major specialization, alpha is taken as ResScalar here.
100 static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
101 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha);
102};
103
// Out-of-class definition of the row-major kernel's run(). It adds
// alpha * tri(lhs) * rhs into res. It walks the diagonal in panels of PanelWidth
// rows. Inside a panel, each row i adds alpha times the dot product of its
// triangular slice with rhs. The rectangular block beside the panel's triangle,
// and for Lower any rows below the square part, are handed to a general
// matrix-vector product call.
// NOTE(review): the embedded numbering shows that these source lines are
// missing from this extraction; confirm against the upstream header:
//   - 105: return type and qualified class name;
//   - 114, 118 and 122: the LhsMap/RhsMap/ResMap typedefs;
//   - 125-126: the mapper typedefs;
//   - 145 and 154: the callee of the two delegated calls.
// NOTE(review): rhs is mapped contiguously (RhsMap is built without rhsIncr),
// while the delegated calls receive rhsIncr. This is only consistent for
// rhsIncr == 1, which is what the row-major trmv_selector in this file passes.
104template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
106 ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
107 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha)
108 {
109 static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
// Only the part of lhs that can hold the triangle is mapped. Lower keeps all
// rows but at most diagSize columns; Upper keeps at most diagSize rows but all
// columns.
110 Index diagSize = (std::min)(_rows,_cols);
111 Index rows = IsLower ? _rows : diagSize;
112 Index cols = IsLower ? diagSize : _cols;
113
115 const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
116 typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
117
119 const RhsMap rhs(_rhs,cols);
120 typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
121
123 ResMap res(_res,rows,InnerStride<>(resIncr));
124
127
128 for (Index pi=0; pi<diagSize; pi+=PanelWidth)
129 {
130 Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi);
131 for (Index k=0; k<actualPanelWidth; ++k)
132 {
133 Index i = pi + k;
// Columns of row i that lie inside both the triangle and this panel. For
// Lower they start at the panel's first column and end at the diagonal. For
// Upper they start at the diagonal (one past it if the diagonal is implicit).
134 Index s = IsLower ? pi : ((HasUnitDiag||HasZeroDiag) ? i+1 : i);
135 Index r = IsLower ? k+1 : actualPanelWidth-k;
// With an implicit (unit or zero) diagonal, r is shortened by one so the
// diagonal entry is never read; the update is skipped if nothing is left.
136 if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
137 res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
// The implicit unit diagonal contributes alpha * rhs(i) directly.
138 if (HasUnitDiag)
139 res.coeffRef(i) += alpha * cjRhs.coeff(i);
140 }
// Rectangular block next to this panel's triangle: the columns left of the
// panel for Lower, or the columns right of it for Upper.
141 Index r = IsLower ? pi : cols - pi - actualPanelWidth;
142 if (r>0)
143 {
144 Index s = IsLower ? 0 : pi + actualPanelWidth;
146 actualPanelWidth, r,
147 LhsMapper(&lhs.coeffRef(pi,s), lhsStride),
148 RhsMapper(&rhs.coeffRef(s), rhsIncr),
149 &res.coeffRef(pi), resIncr, alpha);
150 }
151 }
// Lower with more rows than columns: the rows below the square triangle form
// a plain rectangular product.
152 if(IsLower && rows>diagSize)
153 {
155 rows-diagSize, cols,
156 LhsMapper(&lhs.coeffRef(diagSize,0), lhsStride),
157 RhsMapper(&rhs.coeffRef(0), rhsIncr),
158 &res.coeffRef(diagSize), resIncr, alpha);
159 }
160 }
161
162/***************************************************************************
163* Wrappers dispatching triangular matrix * vector products to trmv_selector
164***************************************************************************/
165
// Declaration of the trmv dispatcher, selected by Mode and the storage order of
// the triangular operand. ColMajor and RowMajor specializations appear further
// down in this file.
// NOTE(review): source line 167 (presumably `struct trmv_selector;`; the symbol
// index at the end of this page lists a definition at
// TriangularMatrixVector.h:167) is missing from this extraction.
166template<int Mode,int StorageOrder>
168
169} // end namespace internal
170
171namespace internal {
172
173template<int Mode, typename Lhs, typename Rhs>
174struct triangular_product_impl<Mode,true,Lhs,false,Rhs,true>
175{
176 template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha)
177 {
178 eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols());
179
180 internal::trmv_selector<Mode,(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(lhs, rhs, dst, alpha);
181 }
182};
183
// Vector lhs times triangular rhs. This is evaluated as the transposed product
// dst^T = rhs^T * lhs^T, so the triangular operand handed to trmv_selector is
// rhs.transpose(). Transposing swaps Lower and Upper, while the UnitDiag and
// ZeroDiag bits of Mode are kept.
// NOTE(review): source line 193 (the second template argument of trmv_selector,
// presumably the storage order derived from the traits of Rhs) is missing from
// this extraction; confirm against the upstream header.
184template<int Mode, typename Lhs, typename Rhs>
185struct triangular_product_impl<Mode,false,Lhs,true,Rhs,false>
186{
187 template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha)
188 {
189 eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols());
190
// Write through a transposed view of dst so the column-vector kernels apply.
191 Transpose<Dest> dstT(dst);
192 internal::trmv_selector<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),
194 ::run(rhs.transpose(),lhs.transpose(), dstT, alpha);
195 }
196};
197
198} // end namespace internal
199
200namespace internal {
201
202// TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.
203
// Column-major dispatcher. It adds alpha * tri(lhs) * rhs into dest by running
// the column-major kernel on the underlying data of lhs and rhs. Scalar factors
// folded into lhs/rhs, as reported by blas_traits::extractScalarFactor, are
// pulled out and merged into actualAlpha.
// NOTE(review): these source lines are missing from this extraction:
//   - 219: the MappedDest typedef;
//   - 232: the ComplexByReal enumerator;
//   - 236: the declaration of static_dest;
//   - 261: the kernel template name preceding "<Index,Mode,...".
204template<int Mode> struct trmv_selector<Mode,ColMajor>
205{
206 template<typename Lhs, typename Rhs, typename Dest>
207 static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
208 {
209 typedef typename Lhs::Scalar LhsScalar;
210 typedef typename Rhs::Scalar RhsScalar;
211 typedef typename Dest::Scalar ResScalar;
212 typedef typename Dest::RealScalar RealScalar;
213
214 typedef internal::blas_traits<Lhs> LhsBlasTraits;
215 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
216 typedef internal::blas_traits<Rhs> RhsBlasTraits;
217 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
218
220
// Underlying operands with any scalar factor stripped off.
221 typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
222 typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
223
// The stripped factors, folded together with the caller's alpha.
224 LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
225 RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
226 ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha;
227
228 enum {
229 // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
230 // on the other hand, it is good for the cache to pack the vector anyway...
231 EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
233 MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
234 };
235
237
// alphaIsCompatible is false only when ComplexByReal holds and actualAlpha has
// a nonzero imaginary part. Converting it to RhsScalar for the kernel (via
// get_factor below) could then presumably lose information. In that case the
// kernel runs with alpha = 1 into a zeroed temporary, and actualAlpha is applied
// afterwards.
238 bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
239 bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
240
241 RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
242
// actualDestPtr is dest's own storage when it can be written directly (unit
// inner stride and a compatible alpha). Otherwise it is a temporary buffer set
// up by ei_declare_aligned_stack_constructed_variable.
243 ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
244 evalToDest ? dest.data() : static_dest.data());
245
// Prepare the temporary. It is zeroed when the scaling is deferred; otherwise it
// holds a copy of dest so the kernel can accumulate into it.
246 if(!evalToDest)
247 {
248 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
249 Index size = dest.size();
250 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
251 #endif
252 if(!alphaIsCompatible)
253 {
254 MappedDest(actualDestPtr, dest.size()).setZero();
255 compatibleAlpha = RhsScalar(1);
256 }
257 else
258 MappedDest(actualDestPtr, dest.size()) = dest;
259 }
260
// Column-major kernel on the extracted data. The result increment is 1 because
// actualDestPtr is either dest itself (inner stride 1) or a packed temporary.
262 <Index,Mode,
263 LhsScalar, LhsBlasTraits::NeedToConjugate,
264 RhsScalar, RhsBlasTraits::NeedToConjugate,
265 ColMajor>
266 ::run(actualLhs.rows(),actualLhs.cols(),
267 actualLhs.data(),actualLhs.outerStride(),
268 actualRhs.data(),actualRhs.innerStride(),
269 actualDestPtr,1,compatibleAlpha);
270
// Copy the temporary back, or add it scaled by the deferred factor.
271 if (!evalToDest)
272 {
273 if(!alphaIsCompatible)
274 dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
275 else
276 dest = MappedDest(actualDestPtr, dest.size());
277 }
278
// The kernel scaled the implicit unit diagonal like the rest of lhs, i.e. by
// lhs_alpha. Subtract the excess on the diagonal part.
// NOTE(review): the kernel's diagonal term is actualAlpha * actualRhs(i). That
// equals alpha * lhs_alpha * rhs(i), assuming rhs == rhs_alpha * actualRhs as
// the extract/extractScalarFactor split suggests. The exact excess is therefore
// alpha * (lhs_alpha - 1) * rhs(i), but the correction below omits alpha and
// looks wrong whenever alpha != 1. Confirm against upstream.
279 if ( ((Mode&UnitDiag)==UnitDiag) && (lhs_alpha!=LhsScalar(1)) )
280 {
281 Index diagSize = (std::min)(lhs.rows(),lhs.cols());
282 dest.head(diagSize) -= (lhs_alpha-LhsScalar(1))*rhs.head(diagSize);
283 }
284 }
285};
286
// Row-major dispatcher. It adds alpha * tri(lhs) * rhs into dest by running the
// row-major kernel on the underlying data of lhs and rhs, writing directly into
// dest with dest's inner stride. Scalar factors folded into lhs/rhs are pulled
// out and merged into actualAlpha. A strided rhs is first packed into a
// contiguous temporary.
// NOTE(review): these source lines are missing from this extraction:
//   - 313: the declaration of static_rhs;
//   - 327: the kernel template name preceding "<Index,Mode,...".
287template<int Mode> struct trmv_selector<Mode,RowMajor>
288{
289 template<typename Lhs, typename Rhs, typename Dest>
290 static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
291 {
292 typedef typename Lhs::Scalar LhsScalar;
293 typedef typename Rhs::Scalar RhsScalar;
294 typedef typename Dest::Scalar ResScalar;
295
296 typedef internal::blas_traits<Lhs> LhsBlasTraits;
297 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
298 typedef internal::blas_traits<Rhs> RhsBlasTraits;
299 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
300 typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
301
// Underlying operands with any scalar factor stripped off.
302 typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
303 typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
304
// The stripped factors, folded together with the caller's alpha.
305 LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
306 RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
307 ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha;
308
309 enum {
310 DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1
311 };
312
314
// actualRhsPtr points at rhs's own data when it is contiguous. Otherwise it
// points at a temporary buffer that is filled with a packed copy below.
315 ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
316 DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
317
318 if(!DirectlyUseRhs)
319 {
320 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
321 Index size = actualRhs.size();
322 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
323 #endif
324 Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
325 }
326
// Row-major kernel. rhs is now contiguous (increment 1); dest is updated in
// place using its inner stride.
328 <Index,Mode,
329 LhsScalar, LhsBlasTraits::NeedToConjugate,
330 RhsScalar, RhsBlasTraits::NeedToConjugate,
331 RowMajor>
332 ::run(actualLhs.rows(),actualLhs.cols(),
333 actualLhs.data(),actualLhs.outerStride(),
334 actualRhsPtr,1,
335 dest.data(),dest.innerStride(),
336 actualAlpha);
337
// The kernel scaled the implicit unit diagonal like the rest of lhs, i.e. by
// lhs_alpha. Subtract the excess on the diagonal part.
// NOTE(review): the kernel's diagonal term is actualAlpha * actualRhs(i). That
// equals alpha * lhs_alpha * rhs(i), assuming rhs == rhs_alpha * actualRhs as
// the extract/extractScalarFactor split suggests. The exact excess is therefore
// alpha * (lhs_alpha - 1) * rhs(i), but the correction below omits alpha and
// looks wrong whenever alpha != 1. Confirm against upstream.
338 if ( ((Mode&UnitDiag)==UnitDiag) && (lhs_alpha!=LhsScalar(1)) )
339 {
340 Index diagSize = (std::min)(lhs.rows(),lhs.cols());
341 dest.head(diagSize) -= (lhs_alpha-LhsScalar(1))*rhs.head(diagSize);
342 }
343 }
344};
345
346} // end namespace internal
347
348} // end namespace Eigen
349
350#endif // EIGEN_TRIANGULARMATRIXVECTOR_H
EIGEN_DEVICE_FUNC const ImagReturnType imag() const
Definition: CommonCwiseUnaryOps.h:109
#define EIGEN_PLAIN_ENUM_MIN(a, b)
Definition: Macros.h:1298
#define EIGEN_DONT_INLINE
Definition: Macros.h:950
#define eigen_assert(x)
Definition: Macros.h:1047
#define ei_declare_aligned_stack_constructed_variable(TYPE, NAME, SIZE, BUFFER)
Definition: Memory.h:768
#define EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH
Defines the maximal width of the blocks used in the triangular product and solver for vectors (level 2 BLAS xTRMV and xTRSV). The default is 8.
Definition: Settings.h:38
Convenience specialization of Stride to specify only an inner stride. See class Map for some examples.
Definition: Stride.h:96
A matrix or vector expression mapping an existing array of data.
Definition: Map.h:96
Convenience specialization of Stride to specify only an outer stride. See class Map for some examples.
Definition: Stride.h:107
Expression of the transpose of a matrix.
Definition: Transpose.h:54
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT
Definition: Transpose.h:69
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT
Definition: Transpose.h:71
Definition: BlasUtil.h:389
@ UnitDiag
Matrix has ones on the diagonal; to be used in combination with Lower or Upper.
Definition: Constants.h:213
@ ZeroDiag
Matrix has zeros on the diagonal; to be used in combination with Lower or Upper.
Definition: Constants.h:215
@ Lower
View matrix as a lower triangular matrix.
Definition: Constants.h:209
@ Upper
View matrix as an upper triangular matrix.
Definition: Constants.h:211
@ AlignedMax
Definition: Constants.h:252
@ ColMajor
Storage order is column major (see TopicStorageOrders).
Definition: Constants.h:319
@ RowMajor
Storage order is row major (see TopicStorageOrders).
Definition: Constants.h:321
const unsigned int RowMajorBit
For a matrix, this means that the storage order is row-major.
Definition: Constants.h:66
constexpr common_t< T1, T2 > min(const T1 x, const T2 y) noexcept
Compile-time pairwise minimum function.
Definition: min.hpp:35
EIGEN_CONSTEXPR Index size(const T &x)
Definition: Meta.h:479
Namespace containing all symbols from the Eigen library.
Definition: Core:141
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:74
Definition: Eigen_Colamd.h:50
static constexpr const unit_t< PI > pi(1)
Ratio of a circle's circumference to its diameter.
Holds information about the various numeric (i.e. scalar) types allowed by Eigen.
Definition: NumTraits.h:233
Determines whether the given binary operation of two numeric types is allowed and what the scalar return type is.
Definition: XprHelper.h:806
const T type
Definition: Meta.h:214
Definition: BlasUtil.h:403
Definition: GeneralProduct.h:161
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE To run(const From &x)
Definition: BlasUtil.h:43
Definition: GenericPacketMath.h:107
T type
Definition: Meta.h:126
Definition: ForwardDeclarations.h:17
static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar *_lhs, Index lhsStride, const RhsScalar *_rhs, Index rhsIncr, ResScalar *_res, Index resIncr, const ResScalar &alpha)
Definition: TriangularMatrixVector.h:106
ScalarBinaryOpTraits< LhsScalar, RhsScalar >::ReturnType ResScalar
Definition: TriangularMatrixVector.h:94
static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar *_lhs, Index lhsStride, const RhsScalar *_rhs, Index rhsIncr, ResScalar *_res, Index resIncr, const RhsScalar &alpha)
Definition: TriangularMatrixVector.h:35
ScalarBinaryOpTraits< LhsScalar, RhsScalar >::ReturnType ResScalar
Definition: TriangularMatrixVector.h:23
Definition: TriangularMatrixVector.h:18
static void run(Dest &dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar &alpha)
Definition: TriangularMatrixVector.h:187
static void run(Dest &dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar &alpha)
Definition: TriangularMatrixVector.h:176
Definition: ProductEvaluators.h:758
static void run(const Lhs &lhs, const Rhs &rhs, Dest &dest, const typename Dest::Scalar &alpha)
Definition: TriangularMatrixVector.h:207
static void run(const Lhs &lhs, const Rhs &rhs, Dest &dest, const typename Dest::Scalar &alpha)
Definition: TriangularMatrixVector.h:290
Definition: TriangularMatrixVector.h:167