10#ifndef EIGEN_GENERAL_MATRIX_VECTOR_H
11#define EIGEN_GENERAL_MATRIX_VECTOR_H
23template <
int N,
typename T1,
typename T2,
typename T3>
26template <
typename T1,
typename T2,
typename T3>
29template <
typename T1,
typename T2,
typename T3>
32template<
typename LhsScalar,
typename RhsScalar,
int _PacketSize=GEMVPacketFull>
// PACKET_DECL_COND_PREFIX(prefix, name, packet_size)
// Expands to a typedef called <prefix><name>Packet (token-pasted), without a
// trailing semicolon -- the use site supplies it.
// The aliased type is gemv_packet_cond<packet_size, A, B, C>::type, where
//   A = packet_traits<nameScalar>::type                         (the packet type for <name>Scalar)
//   B = packet_traits<nameScalar>::half                         (its half packet)
//   C = unpacket_traits<packet_traits<nameScalar>::half>::half  (the half of that half packet)
// Which of A/B/C is picked for a given packet_size is decided by
// gemv_packet_cond, whose body is not visible in this excerpt.
// NOTE(review): C is presumably the quarter-size packet -- confirm against gemv_packet_cond.
// The macro builds the identifier <name>Scalar by token pasting, so a type with
// that name must be visible wherever the macro is expanded.
37#define PACKET_DECL_COND_PREFIX(prefix, name, packet_size) \
38 typedef typename gemv_packet_cond<packet_size, \
39 typename packet_traits<name ## Scalar>::type, \
40 typename packet_traits<name ## Scalar>::half, \
41 typename unpacket_traits<typename packet_traits<name ## Scalar>::half>::half>::type \
42 prefix ## name ## Packet
// Declare _LhsPacket, _RhsPacket and _ResPacket for the requested _PacketSize;
// these expansions refer to LhsScalar, RhsScalar and ResScalar respectively.
44 PACKET_DECL_COND_PREFIX(_, Lhs, _PacketSize);
45 PACKET_DECL_COND_PREFIX(_, Rhs, _PacketSize);
46 PACKET_DECL_COND_PREFIX(_, Res, _PacketSize);
// The helper macro is only needed for the three typedefs above, so drop it
// again to avoid leaking it to code that follows.
47#undef PACKET_DECL_COND_PREFIX
78template<
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version>
101 const LhsMapper& lhs,
102 const RhsMapper& rhs,
107template<
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version>
110 const LhsMapper& alhs,
111 const RhsMapper& rhs,
127 const Index lhsStride = lhs.stride();
130 ResPacketSize = Traits::ResPacketSize,
131 ResPacketSizeHalf = HalfTraits::ResPacketSize,
132 ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
133 LhsPacketSize = Traits::LhsPacketSize,
134 HasHalf = (int)ResPacketSizeHalf < (
int)ResPacketSize,
135 HasQuarter = (int)ResPacketSizeQuarter < (
int)ResPacketSizeHalf
138 const Index n8 = rows-8*ResPacketSize+1;
139 const Index n4 = rows-4*ResPacketSize+1;
140 const Index n3 = rows-3*ResPacketSize+1;
141 const Index n2 = rows-2*ResPacketSize+1;
142 const Index n1 = rows-1*ResPacketSize+1;
143 const Index n_half = rows-1*ResPacketSizeHalf+1;
144 const Index n_quarter = rows-1*ResPacketSizeQuarter+1;
147 const Index block_cols = cols<128 ? cols : (lhsStride*
sizeof(LhsScalar)<32000?16:4);
148 ResPacket palpha = pset1<ResPacket>(alpha);
152 for(
Index j2=0; j2<cols; j2+=block_cols)
156 for(; i<n8; i+=ResPacketSize*8)
167 for(
Index j=j2; j<jend; j+=1)
169 RhsPacket b0 = pset1<RhsPacket>(rhs(j,0));
170 c0 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*0,j),b0,c0);
171 c1 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*1,j),b0,c1);
172 c2 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*2,j),b0,c2);
173 c3 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*3,j),b0,c3);
174 c4 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*4,j),b0,c4);
175 c5 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*5,j),b0,c5);
176 c6 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*6,j),b0,c6);
177 c7 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*7,j),b0,c7);
179 pstoreu(res+i+ResPacketSize*0,
pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
180 pstoreu(res+i+ResPacketSize*1,
pmadd(c1,palpha,ploadu<ResPacket>(res+i+ResPacketSize*1)));
181 pstoreu(res+i+ResPacketSize*2,
pmadd(c2,palpha,ploadu<ResPacket>(res+i+ResPacketSize*2)));
182 pstoreu(res+i+ResPacketSize*3,
pmadd(c3,palpha,ploadu<ResPacket>(res+i+ResPacketSize*3)));
183 pstoreu(res+i+ResPacketSize*4,
pmadd(c4,palpha,ploadu<ResPacket>(res+i+ResPacketSize*4)));
184 pstoreu(res+i+ResPacketSize*5,
pmadd(c5,palpha,ploadu<ResPacket>(res+i+ResPacketSize*5)));
185 pstoreu(res+i+ResPacketSize*6,
pmadd(c6,palpha,ploadu<ResPacket>(res+i+ResPacketSize*6)));
186 pstoreu(res+i+ResPacketSize*7,
pmadd(c7,palpha,ploadu<ResPacket>(res+i+ResPacketSize*7)));
195 for(
Index j=j2; j<jend; j+=1)
197 RhsPacket b0 = pset1<RhsPacket>(rhs(j,0));
198 c0 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*0,j),b0,c0);
199 c1 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*1,j),b0,c1);
200 c2 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*2,j),b0,c2);
201 c3 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*3,j),b0,c3);
203 pstoreu(res+i+ResPacketSize*0,
pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
204 pstoreu(res+i+ResPacketSize*1,
pmadd(c1,palpha,ploadu<ResPacket>(res+i+ResPacketSize*1)));
205 pstoreu(res+i+ResPacketSize*2,
pmadd(c2,palpha,ploadu<ResPacket>(res+i+ResPacketSize*2)));
206 pstoreu(res+i+ResPacketSize*3,
pmadd(c3,palpha,ploadu<ResPacket>(res+i+ResPacketSize*3)));
216 for(
Index j=j2; j<jend; j+=1)
218 RhsPacket b0 = pset1<RhsPacket>(rhs(j,0));
219 c0 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*0,j),b0,c0);
220 c1 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*1,j),b0,c1);
221 c2 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*2,j),b0,c2);
223 pstoreu(res+i+ResPacketSize*0,
pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
224 pstoreu(res+i+ResPacketSize*1,
pmadd(c1,palpha,ploadu<ResPacket>(res+i+ResPacketSize*1)));
225 pstoreu(res+i+ResPacketSize*2,
pmadd(c2,palpha,ploadu<ResPacket>(res+i+ResPacketSize*2)));
234 for(
Index j=j2; j<jend; j+=1)
236 RhsPacket b0 = pset1<RhsPacket>(rhs(j,0));
237 c0 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*0,j),b0,c0);
238 c1 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*1,j),b0,c1);
240 pstoreu(res+i+ResPacketSize*0,
pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
241 pstoreu(res+i+ResPacketSize*1,
pmadd(c1,palpha,ploadu<ResPacket>(res+i+ResPacketSize*1)));
247 for(
Index j=j2; j<jend; j+=1)
249 RhsPacket b0 = pset1<RhsPacket>(rhs(j,0));
250 c0 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
252 pstoreu(res+i+ResPacketSize*0,
pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
255 if(HasHalf && i<n_half)
258 for(
Index j=j2; j<jend; j+=1)
261 c0 = pcj_half.
pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(i+0,j),b0,c0);
263 pstoreu(res+i+ResPacketSizeHalf*0,
pmadd(c0,palpha_half,ploadu<ResPacketHalf>(res+i+ResPacketSizeHalf*0)));
264 i+=ResPacketSizeHalf;
266 if(HasQuarter && i<n_quarter)
269 for(
Index j=j2; j<jend; j+=1)
272 c0 = pcj_quarter.
pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(i+0,j),b0,c0);
274 pstoreu(res+i+ResPacketSizeQuarter*0,
pmadd(c0,palpha_quarter,ploadu<ResPacketQuarter>(res+i+ResPacketSizeQuarter*0)));
275 i+=ResPacketSizeQuarter;
280 for(
Index j=j2; j<jend; j+=1)
281 c0 += cj.
pmul(lhs(i,j), rhs(j,0));
297template<
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version>
320 const LhsMapper& lhs,
321 const RhsMapper& rhs,
326template<
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version>
329 const LhsMapper& alhs,
330 const RhsMapper& rhs,
346 const Index n8 = lhs.stride()*
sizeof(LhsScalar)>32000 ? 0 : rows-7;
347 const Index n4 = rows-3;
348 const Index n2 = rows-1;
352 ResPacketSize = Traits::ResPacketSize,
353 ResPacketSizeHalf = HalfTraits::ResPacketSize,
354 ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
355 LhsPacketSize = Traits::LhsPacketSize,
356 LhsPacketSizeHalf = HalfTraits::LhsPacketSize,
357 LhsPacketSizeQuarter = QuarterTraits::LhsPacketSize,
358 HasHalf = (int)ResPacketSizeHalf < (
int)ResPacketSize,
359 HasQuarter = (int)ResPacketSizeQuarter < (
int)ResPacketSizeHalf
375 for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
377 RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j,0);
379 c0 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
380 c1 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1);
381 c2 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+2,j),b0,c2);
382 c3 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+3,j),b0,c3);
383 c4 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+4,j),b0,c4);
384 c5 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+5,j),b0,c5);
385 c6 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+6,j),b0,c6);
386 c7 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+7,j),b0,c7);
398 RhsScalar b0 = rhs(j,0);
400 cc0 += cj.
pmul(lhs(i+0,j), b0);
401 cc1 += cj.
pmul(lhs(i+1,j), b0);
402 cc2 += cj.
pmul(lhs(i+2,j), b0);
403 cc3 += cj.
pmul(lhs(i+3,j), b0);
404 cc4 += cj.
pmul(lhs(i+4,j), b0);
405 cc5 += cj.
pmul(lhs(i+5,j), b0);
406 cc6 += cj.
pmul(lhs(i+6,j), b0);
407 cc7 += cj.
pmul(lhs(i+7,j), b0);
409 res[(i+0)*resIncr] += alpha*cc0;
410 res[(i+1)*resIncr] += alpha*cc1;
411 res[(i+2)*resIncr] += alpha*cc2;
412 res[(i+3)*resIncr] += alpha*cc3;
413 res[(i+4)*resIncr] += alpha*cc4;
414 res[(i+5)*resIncr] += alpha*cc5;
415 res[(i+6)*resIncr] += alpha*cc6;
416 res[(i+7)*resIncr] += alpha*cc7;
426 for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
428 RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j,0);
430 c0 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
431 c1 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1);
432 c2 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+2,j),b0,c2);
433 c3 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+3,j),b0,c3);
441 RhsScalar b0 = rhs(j,0);
443 cc0 += cj.
pmul(lhs(i+0,j), b0);
444 cc1 += cj.
pmul(lhs(i+1,j), b0);
445 cc2 += cj.
pmul(lhs(i+2,j), b0);
446 cc3 += cj.
pmul(lhs(i+3,j), b0);
448 res[(i+0)*resIncr] += alpha*cc0;
449 res[(i+1)*resIncr] += alpha*cc1;
450 res[(i+2)*resIncr] += alpha*cc2;
451 res[(i+3)*resIncr] += alpha*cc3;
459 for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
461 RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j,0);
463 c0 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
464 c1 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1);
470 RhsScalar b0 = rhs(j,0);
472 cc0 += cj.
pmul(lhs(i+0,j), b0);
473 cc1 += cj.
pmul(lhs(i+1,j), b0);
475 res[(i+0)*resIncr] += alpha*cc0;
476 res[(i+1)*resIncr] += alpha*cc1;
484 for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
486 RhsPacket b0 = rhs.template load<RhsPacket,Unaligned>(j,0);
487 c0 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(i,j),b0,c0);
491 for(; j+LhsPacketSizeHalf<=cols; j+=LhsPacketSizeHalf)
493 RhsPacketHalf b0 = rhs.template load<RhsPacketHalf,Unaligned>(j,0);
494 c0_h = pcj_half.
pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(i,j),b0,c0_h);
499 for(; j+LhsPacketSizeQuarter<=cols; j+=LhsPacketSizeQuarter)
502 c0_q = pcj_quarter.
pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(i,j),b0,c0_q);
508 cc0 += cj.
pmul(lhs(i,j), rhs(j,0));
510 res[i*resIncr] += alpha*cc0;
#define eigen_internal_assert(x)
Definition: Macros.h:1053
#define EIGEN_UNUSED_VARIABLE(var)
Definition: Macros.h:1086
#define EIGEN_DEVICE_FUNC
Definition: Macros.h:986
#define EIGEN_DONT_INLINE
Definition: Macros.h:950
Definition: GeneralMatrixVector.h:34
conditional< Vectorizable, _LhsPacket, LhsScalar >::type LhsPacket
Definition: GeneralMatrixVector.h:59
conditional< Vectorizable, _ResPacket, ResScalar >::type ResPacket
Definition: GeneralMatrixVector.h:61
conditional< Vectorizable, _RhsPacket, RhsScalar >::type RhsPacket
Definition: GeneralMatrixVector.h:60
@ Vectorizable
Definition: GeneralMatrixVector.h:51
@ ResPacketSize
Definition: GeneralMatrixVector.h:56
@ RhsPacketSize
Definition: GeneralMatrixVector.h:55
@ LhsPacketSize
Definition: GeneralMatrixVector.h:54
@ Unaligned
Data pointer has no specific alignment.
Definition: Constants.h:233
@ ColMajor
Storage order is column major (see TopicStorageOrders).
Definition: Constants.h:319
@ RowMajor
Storage order is row major (see TopicStorageOrders).
Definition: Constants.h:321
EIGEN_DEVICE_FUNC unpacket_traits< Packet >::type predux(const Packet &a)
Definition: GenericPacketMath.h:875
EIGEN_STRONG_INLINE Packet4i pmadd(const Packet4i &a, const Packet4i &b, const Packet4i &c)
Definition: PacketMath.h:370
GEMVPacketSizeType
Definition: GeneralMatrixVector.h:17
@ GEMVPacketFull
Definition: GeneralMatrixVector.h:18
@ GEMVPacketHalf
Definition: GeneralMatrixVector.h:19
@ GEMVPacketQuarter
Definition: GeneralMatrixVector.h:20
EIGEN_DEVICE_FUNC void pstoreu(Scalar *to, const Packet &from)
Definition: GenericPacketMath.h:700
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T mini(const T &x, const T &y)
Definition: MathFunctions.h:1083
Namespace containing all symbols from the Eigen library.
Definition: Core:141
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:74
Definition: Eigen_Colamd.h:50
Determines whether the given binary operation of two numeric types is allowed and what the scalar return type is.
Definition: XprHelper.h:806
Definition: ConjHelper.h:63
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType pmadd(const LhsType &x, const RhsType &y, const ResultType &c) const
Definition: ConjHelper.h:67
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType pmul(const LhsType &x, const RhsType &y) const
Definition: ConjHelper.h:71
T1 type
Definition: GeneralMatrixVector.h:27
T2 type
Definition: GeneralMatrixVector.h:30
Definition: GeneralMatrixVector.h:24
T3 type
Definition: GeneralMatrixVector.h:24
QuarterTraits::RhsPacket RhsPacketQuarter
Definition: GeneralMatrixVector.h:315
gemv_traits< LhsScalar, RhsScalar > Traits
Definition: GeneralMatrixVector.h:300
Traits::RhsPacket RhsPacket
Definition: GeneralMatrixVector.h:307
HalfTraits::RhsPacket RhsPacketHalf
Definition: GeneralMatrixVector.h:311
Traits::ResPacket ResPacket
Definition: GeneralMatrixVector.h:308
gemv_traits< LhsScalar, RhsScalar, GEMVPacketHalf > HalfTraits
Definition: GeneralMatrixVector.h:301
Traits::LhsPacket LhsPacket
Definition: GeneralMatrixVector.h:306
HalfTraits::LhsPacket LhsPacketHalf
Definition: GeneralMatrixVector.h:310
HalfTraits::ResPacket ResPacketHalf
Definition: GeneralMatrixVector.h:312
gemv_traits< LhsScalar, RhsScalar, GEMVPacketQuarter > QuarterTraits
Definition: GeneralMatrixVector.h:302
QuarterTraits::ResPacket ResPacketQuarter
Definition: GeneralMatrixVector.h:316
ScalarBinaryOpTraits< LhsScalar, RhsScalar >::ReturnType ResScalar
Definition: GeneralMatrixVector.h:304
QuarterTraits::LhsPacket LhsPacketQuarter
Definition: GeneralMatrixVector.h:314
gemv_traits< LhsScalar, RhsScalar > Traits
Definition: GeneralMatrixVector.h:81
Traits::LhsPacket LhsPacket
Definition: GeneralMatrixVector.h:87
Traits::RhsPacket RhsPacket
Definition: GeneralMatrixVector.h:88
QuarterTraits::LhsPacket LhsPacketQuarter
Definition: GeneralMatrixVector.h:95
Traits::ResPacket ResPacket
Definition: GeneralMatrixVector.h:89
HalfTraits::ResPacket ResPacketHalf
Definition: GeneralMatrixVector.h:93
QuarterTraits::RhsPacket RhsPacketQuarter
Definition: GeneralMatrixVector.h:96
QuarterTraits::ResPacket ResPacketQuarter
Definition: GeneralMatrixVector.h:97
HalfTraits::LhsPacket LhsPacketHalf
Definition: GeneralMatrixVector.h:91
gemv_traits< LhsScalar, RhsScalar, GEMVPacketHalf > HalfTraits
Definition: GeneralMatrixVector.h:82
ScalarBinaryOpTraits< LhsScalar, RhsScalar >::ReturnType ResScalar
Definition: GeneralMatrixVector.h:85
HalfTraits::RhsPacket RhsPacketHalf
Definition: GeneralMatrixVector.h:92
gemv_traits< LhsScalar, RhsScalar, GEMVPacketQuarter > QuarterTraits
Definition: GeneralMatrixVector.h:83
Definition: BlasUtil.h:40
Definition: GenericPacketMath.h:133
@ vectorizable
Definition: GenericPacketMath.h:140