10#ifndef EIGEN_GENERAL_MATRIX_MATRIX_H
11#define EIGEN_GENERAL_MATRIX_MATRIX_H
17template<
typename _LhsScalar,
typename _RhsScalar>
class level3_blocking;
22 typename LhsScalar,
int LhsStorageOrder,
bool ConjugateLhs,
23 typename RhsScalar,
int RhsStorageOrder,
bool ConjugateRhs,
32 const LhsScalar* lhs,
Index lhsStride,
33 const RhsScalar* rhs,
Index rhsStride,
44 ::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resIncr,resStride,alpha,blocking,info);
52 typename LhsScalar,
int LhsStorageOrder,
bool ConjugateLhs,
53 typename RhsScalar,
int RhsStorageOrder,
bool ConjugateRhs,
62 const LhsScalar* _lhs,
Index lhsStride,
63 const RhsScalar* _rhs,
Index rhsStride,
72 LhsMapper lhs(_lhs, lhsStride);
73 RhsMapper rhs(_rhs, rhsStride);
74 ResMapper res(_res, resStride, resIncr);
84#ifdef EIGEN_HAS_OPENMP
88 int tid = omp_get_thread_num();
89 int threads = omp_get_num_threads();
91 LhsScalar* blockA = blocking.
blockA();
94 std::size_t sizeB = kc*nc;
98 for(
Index k=0; k<depth; k+=kc)
104 pack_rhs(blockB, rhs.getSubMapper(k,0), actual_kc, nc);
112 while(info[tid].users!=0) {}
113 info[tid].users = threads;
115 pack_lhs(blockA+info[tid].lhs_start*actual_kc, lhs.getSubMapper(info[tid].lhs_start,k), actual_kc, info[tid].lhs_length);
121 for(
int shift=0; shift<threads; ++shift)
123 int i = (tid+shift)%threads;
129 while(info[i].sync!=k) {
133 gebp(res.getSubMapper(info[i].lhs_start, 0), blockA+info[i].lhs_start*actual_kc, blockB, info[i].lhs_length, actual_kc, nc, alpha);
137 for(
Index j=nc; j<cols; j+=nc)
142 pack_rhs(blockB, rhs.getSubMapper(k,j), actual_kc, actual_nc);
145 gebp(res.getSubMapper(0, j), blockA, blockB, rows, actual_kc, actual_nc, alpha);
150 for(
Index i=0; i<threads; ++i)
163 std::size_t sizeA = kc*mc;
164 std::size_t sizeB = kc*nc;
169 const bool pack_rhs_once = mc!=rows && kc==depth && nc==cols;
172 for(
Index i2=0; i2<rows; i2+=mc)
184 pack_lhs(blockA, lhs.getSubMapper(i2,
k2), actual_kc, actual_mc);
187 for(
Index j2=0; j2<cols; j2+=nc)
194 if((!pack_rhs_once) || i2==0)
195 pack_rhs(blockB, rhs.getSubMapper(
k2,j2), actual_kc, actual_nc);
198 gebp(res.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, alpha);
212template<
typename Scalar,
typename Index,
typename Gemm,
typename Lhs,
typename Rhs,
typename Dest,
typename BlockingType>
215 gemm_functor(
const Lhs& lhs,
const Rhs& rhs, Dest& dest,
const Scalar& actualAlpha, BlockingType& blocking)
230 Gemm::run(rows, cols,
m_lhs.cols(),
247template<
int StorageOrder,
typename LhsScalar,
typename RhsScalar,
int MaxRows,
int MaxCols,
int MaxDepth,
int KcFactor=1,
250template<
typename _LhsScalar,
typename _RhsScalar>
253 typedef _LhsScalar LhsScalar;
254 typedef _RhsScalar RhsScalar;
278template<
int StorageOrder,
typename _LhsScalar,
typename _RhsScalar,
int MaxRows,
int MaxCols,
int MaxDepth,
int KcFactor>
281 typename conditional<StorageOrder==RowMajor,_RhsScalar,_LhsScalar>::type,
282 typename conditional<StorageOrder==RowMajor,_LhsScalar,_RhsScalar>::type>
286 ActualRows =
Transpose ? MaxCols : MaxRows,
287 ActualCols =
Transpose ? MaxRows : MaxCols
293 SizeA = ActualRows * MaxDepth,
294 SizeB = ActualCols * MaxDepth
297#if EIGEN_MAX_STATIC_ALIGN_BYTES >= EIGEN_DEFAULT_ALIGN_BYTES
309 this->m_mc = ActualRows;
310 this->m_nc = ActualCols;
311 this->m_kc = MaxDepth;
312#if EIGEN_MAX_STATIC_ALIGN_BYTES >= EIGEN_DEFAULT_ALIGN_BYTES
313 this->m_blockA = m_staticA;
314 this->m_blockB = m_staticB;
329template<
int StorageOrder,
typename _LhsScalar,
typename _RhsScalar,
int MaxRows,
int MaxCols,
int MaxDepth,
int KcFactor>
332 typename conditional<StorageOrder==RowMajor,_RhsScalar,_LhsScalar>::type,
333 typename conditional<StorageOrder==RowMajor,_LhsScalar,_RhsScalar>::type>
355 computeProductBlockingSizes<LhsScalar,RhsScalar,KcFactor>(this->m_kc, this->m_mc, this->m_nc, num_threads);
359 Index n = this->m_nc;
360 computeProductBlockingSizes<LhsScalar,RhsScalar,KcFactor>(this->m_kc, this->m_mc, n, num_threads);
363 m_sizeA = this->m_mc * this->m_kc;
364 m_sizeB = this->m_kc * this->m_nc;
374 Index m = this->m_mc;
375 computeProductBlockingSizes<LhsScalar,RhsScalar,KcFactor>(this->m_kc, m, this->m_nc, num_threads);
376 m_sizeA = this->m_mc * this->m_kc;
377 m_sizeB = this->m_kc * this->m_nc;
382 if(this->m_blockA==0)
383 this->m_blockA = aligned_new<LhsScalar>(m_sizeA);
388 if(this->m_blockB==0)
389 this->m_blockB = aligned_new<RhsScalar>(m_sizeB);
409template<
typename Lhs,
typename Rhs>
431 template<
typename Dst>
432 static void evalTo(Dst& dst,
const Lhs& lhs,
const Rhs& rhs)
445 scaleAndAddTo(dst, lhs, rhs,
Scalar(1));
449 template<
typename Dst>
450 static void addTo(Dst& dst,
const Lhs& lhs,
const Rhs& rhs)
455 scaleAndAddTo(dst,lhs, rhs,
Scalar(1));
458 template<
typename Dst>
459 static void subTo(Dst& dst,
const Lhs& lhs,
const Rhs& rhs)
464 scaleAndAddTo(dst, lhs, rhs,
Scalar(-1));
467 template<
typename Dest>
470 eigen_assert(dst.rows()==a_lhs.rows() && dst.cols()==a_rhs.cols());
471 if(a_lhs.cols()==0 || a_lhs.rows()==0 || a_rhs.cols()==0)
481 else if (dst.rows() == 1)
495 Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType;
504 Dest::InnerStrideAtCompileTime>,
507 BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1,
true);
509 (GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(), a_rhs.cols(), a_lhs.cols(), Dest::Flags&
RowMajorBit);
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ColXpr col(Index i)
This is the const version of col().
Definition: BlockMethods.h:1097
Block< Derived, 1, internal::traits< Derived >::ColsAtCompileTime, IsRowMajor > RowXpr
Definition: BlockMethods.h:17
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE RowXpr row(Index i)
This is the const version of row(). */.
Definition: BlockMethods.h:1118
Block< Derived, internal::traits< Derived >::RowsAtCompileTime, 1, !IsRowMajor > ColXpr
Definition: BlockMethods.h:14
#define EIGEN_GEMM_TO_COEFFBASED_THRESHOLD
Definition: GeneralProduct.h:28
#define eigen_internal_assert(x)
Definition: Macros.h:1053
#define EIGEN_UNUSED_VARIABLE(var)
Definition: Macros.h:1086
#define EIGEN_HAS_CXX11_ATOMIC
Definition: Macros.h:841
#define eigen_assert(x)
Definition: Macros.h:1047
#define EIGEN_STRONG_INLINE
Definition: Macros.h:927
#define EIGEN_SIZE_MIN_PREFER_FIXED(a, b)
Definition: Macros.h:1312
#define ei_declare_aligned_stack_constructed_variable(TYPE, NAME, SIZE, BUFFER)
Definition: Memory.h:768
Expression of the product of two arbitrary matrices or vectors.
Definition: Product.h:75
Expression of the transpose of a matrix.
Definition: Transpose.h:54
Definition: BlasUtil.h:270
Definition: BlasUtil.h:389
Definition: GeneralBlockPanelKernel.h:419
gemm_blocking_space(Index rows, Index cols, Index depth, Index num_threads, bool l3_blocking)
Definition: GeneralMatrixMatrix.h:347
void allocateB()
Definition: GeneralMatrixMatrix.h:386
void allocateAll()
Definition: GeneralMatrixMatrix.h:392
~gemm_blocking_space()
Definition: GeneralMatrixMatrix.h:398
void allocateA()
Definition: GeneralMatrixMatrix.h:380
void initParallel(Index rows, Index cols, Index depth, Index num_threads)
Definition: GeneralMatrixMatrix.h:367
gemm_blocking_space(Index, Index, Index, Index, bool)
Definition: GeneralMatrixMatrix.h:307
void allocateA()
Definition: GeneralMatrixMatrix.h:324
void allocateAll()
Definition: GeneralMatrixMatrix.h:326
void initParallel(Index, Index, Index, Index)
Definition: GeneralMatrixMatrix.h:321
void allocateB()
Definition: GeneralMatrixMatrix.h:325
Definition: GeneralMatrixMatrix.h:248
Definition: GeneralMatrixMatrix.h:252
RhsScalar * blockB()
Definition: GeneralMatrixMatrix.h:275
Index kc() const
Definition: GeneralMatrixMatrix.h:272
Index mc() const
Definition: GeneralMatrixMatrix.h:270
Index m_mc
Definition: GeneralMatrixMatrix.h:260
level3_blocking()
Definition: GeneralMatrixMatrix.h:266
RhsScalar * m_blockB
Definition: GeneralMatrixMatrix.h:258
Index nc() const
Definition: GeneralMatrixMatrix.h:271
Index m_kc
Definition: GeneralMatrixMatrix.h:262
LhsScalar * blockA()
Definition: GeneralMatrixMatrix.h:274
Index m_nc
Definition: GeneralMatrixMatrix.h:261
LhsScalar * m_blockA
Definition: GeneralMatrixMatrix.h:257
@ ColMajor
Storage order is column major (see TopicStorageOrders).
Definition: Constants.h:319
@ RowMajor
Storage order is row major (see TopicStorageOrders).
Definition: Constants.h:321
const unsigned int RowMajorBit
for a matrix, this means that the storage order is row-major.
Definition: Constants.h:66
constexpr common_t< T1, T2 > min(const T1 x, const T2 y) noexcept
Compile-time pairwise minimum function.
Definition: min.hpp:35
std::size_t UIntPtr
Definition: Meta.h:92
EIGEN_DEVICE_FUNC void aligned_delete(T *ptr, std::size_t size)
Definition: Memory.h:361
void parallelize_gemm(const Functor &func, Index rows, Index cols, Index depth, bool transpose)
Definition: Parallelizer.h:100
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const ResScalar &alpha, const Lhs &lhs, const Rhs &rhs)
Definition: BlasUtil.h:568
Namespace containing all symbols from the Eigen library.
Definition: Core:141
@ GemmProduct
Definition: Constants.h:500
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:74
const int Dynamic
This value means that a positive quantity (e.g., a size) is not known at compile-time,...
Definition: Constants.h:22
Definition: Eigen_Colamd.h:50
Definition: BFloat16.h:88
static constexpr uint64_t k2
Definition: Hashing.h:172
Definition: Constants.h:528
Determines whether the given binary operation of two numeric types is allowed and what the scalar ret...
Definition: XprHelper.h:806
Definition: Parallelizer.h:80
Definition: AssignmentFunctors.h:46
const T type
Definition: Meta.h:214
Definition: AssignmentFunctors.h:21
Definition: BlasUtil.h:403
Then type
Definition: Meta.h:109
Definition: GeneralBlockPanelKernel.h:1058
Definition: GeneralMatrixMatrix.h:214
Dest & m_dest
Definition: GeneralMatrixMatrix.h:242
void initParallelSession(Index num_threads) const
Definition: GeneralMatrixMatrix.h:219
const Rhs & m_rhs
Definition: GeneralMatrixMatrix.h:241
Gemm::Traits Traits
Definition: GeneralMatrixMatrix.h:237
Scalar m_actualAlpha
Definition: GeneralMatrixMatrix.h:243
BlockingType & m_blocking
Definition: GeneralMatrixMatrix.h:244
gemm_functor(const Lhs &lhs, const Rhs &rhs, Dest &dest, const Scalar &actualAlpha, BlockingType &blocking)
Definition: GeneralMatrixMatrix.h:215
const Lhs & m_lhs
Definition: GeneralMatrixMatrix.h:240
void operator()(Index row, Index rows, Index col=0, Index cols=-1, GemmParallelInfo< Index > *info=0) const
Definition: GeneralMatrixMatrix.h:225
Definition: BlasUtil.h:28
Definition: BlasUtil.h:25
gebp_traits< LhsScalar, RhsScalar > Traits
Definition: GeneralMatrixMatrix.h:58
static void run(Index rows, Index cols, Index depth, const LhsScalar *_lhs, Index lhsStride, const RhsScalar *_rhs, Index rhsStride, ResScalar *_res, Index resIncr, Index resStride, ResScalar alpha, level3_blocking< LhsScalar, RhsScalar > &blocking, GemmParallelInfo< Index > *info=0)
Definition: GeneralMatrixMatrix.h:61
ScalarBinaryOpTraits< LhsScalar, RhsScalar >::ReturnType ResScalar
Definition: GeneralMatrixMatrix.h:60
static EIGEN_STRONG_INLINE void run(Index rows, Index cols, Index depth, const LhsScalar *lhs, Index lhsStride, const RhsScalar *rhs, Index rhsStride, ResScalar *res, Index resIncr, Index resStride, ResScalar alpha, level3_blocking< RhsScalar, LhsScalar > &blocking, GemmParallelInfo< Index > *info=0)
Definition: GeneralMatrixMatrix.h:30
ScalarBinaryOpTraits< LhsScalar, RhsScalar >::ReturnType ResScalar
Definition: GeneralMatrixMatrix.h:29
gebp_traits< RhsScalar, LhsScalar > Traits
Definition: GeneralMatrixMatrix.h:27
Definition: BlasUtil.h:35
Definition: ProductEvaluators.h:394
Lhs::Scalar LhsScalar
Definition: GeneralMatrixMatrix.h:414
internal::blas_traits< Lhs > LhsBlasTraits
Definition: GeneralMatrixMatrix.h:417
RhsBlasTraits::DirectLinearAccessType ActualRhsType
Definition: GeneralMatrixMatrix.h:422
LhsBlasTraits::DirectLinearAccessType ActualLhsType
Definition: GeneralMatrixMatrix.h:418
internal::remove_all< ActualRhsType >::type ActualRhsTypeCleaned
Definition: GeneralMatrixMatrix.h:423
static void subTo(Dst &dst, const Lhs &lhs, const Rhs &rhs)
Definition: GeneralMatrixMatrix.h:459
Product< Lhs, Rhs >::Scalar Scalar
Definition: GeneralMatrixMatrix.h:413
Rhs::Scalar RhsScalar
Definition: GeneralMatrixMatrix.h:415
static void evalTo(Dst &dst, const Lhs &lhs, const Rhs &rhs)
Definition: GeneralMatrixMatrix.h:432
internal::blas_traits< Rhs > RhsBlasTraits
Definition: GeneralMatrixMatrix.h:421
static void addTo(Dst &dst, const Lhs &lhs, const Rhs &rhs)
Definition: GeneralMatrixMatrix.h:450
generic_product_impl< Lhs, Rhs, DenseShape, DenseShape, CoeffBasedProductMode > lazyproduct
Definition: GeneralMatrixMatrix.h:429
internal::remove_all< ActualLhsType >::type ActualLhsTypeCleaned
Definition: GeneralMatrixMatrix.h:419
static void scaleAndAddTo(Dest &dst, const Lhs &a_lhs, const Rhs &a_rhs, const Scalar &alpha)
Definition: GeneralMatrixMatrix.h:468
Definition: ProductEvaluators.h:344
Definition: ProductEvaluators.h:86
T type
Definition: Meta.h:126
Definition: AssignmentFunctors.h:67