Skip to content

Commit 5deeb19

Browse files
committed
bug #1517: fix triangular product with unit diagonal and nested scaling factor: (s*A).triangularView<UpperUnit>()*B
1 parent 12efc7d commit 5deeb19

File tree

3 files changed

+59
-15
lines changed

3 files changed

+59
-15
lines changed

Eigen/src/Core/products/TriangularMatrixMatrix.h

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,9 @@ struct triangular_product_impl<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
400400
{
401401
template<typename Dest> static void run(Dest& dst, const Lhs &a_lhs, const Rhs &a_rhs, const typename Dest::Scalar& alpha)
402402
{
403-
typedef typename Dest::Scalar Scalar;
403+
typedef typename Lhs::Scalar LhsScalar;
404+
typedef typename Rhs::Scalar RhsScalar;
405+
typedef typename Dest::Scalar Scalar;
404406

405407
typedef internal::blas_traits<Lhs> LhsBlasTraits;
406408
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
@@ -412,8 +414,9 @@ struct triangular_product_impl<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
412414
typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
413415
typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
414416

415-
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
416-
* RhsBlasTraits::extractScalarFactor(a_rhs);
417+
LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(a_lhs);
418+
RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(a_rhs);
419+
Scalar actualAlpha = alpha * lhs_alpha * rhs_alpha;
417420

418421
typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,Scalar,Scalar,
419422
Lhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxColsAtCompileTime,4> BlockingType;
@@ -438,6 +441,21 @@ struct triangular_product_impl<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
438441
&dst.coeffRef(0,0), dst.outerStride(), // result info
439442
actualAlpha, blocking
440443
);
444+
445+
// Apply correction if the diagonal is unit and a scalar factor was nested:
446+
if ((Mode&UnitDiag)==UnitDiag)
447+
{
448+
if (LhsIsTriangular && lhs_alpha!=LhsScalar(1))
449+
{
450+
Index diagSize = (std::min)(lhs.rows(),lhs.cols());
451+
dst.topRows(diagSize) -= ((lhs_alpha-LhsScalar(1))*a_rhs).topRows(diagSize);
452+
}
453+
else if ((!LhsIsTriangular) && rhs_alpha!=RhsScalar(1))
454+
{
455+
Index diagSize = (std::min)(rhs.rows(),rhs.cols());
456+
dst.leftCols(diagSize) -= (rhs_alpha-RhsScalar(1))*a_lhs.leftCols(diagSize);
457+
}
458+
}
441459
}
442460
};
443461

Eigen/src/Core/products/TriangularMatrixVector.h

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,9 @@ template<int Mode> struct trmv_selector<Mode,ColMajor>
221221
typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
222222
typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
223223

224-
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
225-
* RhsBlasTraits::extractScalarFactor(rhs);
224+
LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
225+
RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
226+
ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha;
226227

227228
enum {
228229
// FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
@@ -274,6 +275,12 @@ template<int Mode> struct trmv_selector<Mode,ColMajor>
274275
else
275276
dest = MappedDest(actualDestPtr, dest.size());
276277
}
278+
279+
if ( ((Mode&UnitDiag)==UnitDiag) && (lhs_alpha!=LhsScalar(1)) )
280+
{
281+
Index diagSize = (std::min)(lhs.rows(),lhs.cols());
282+
dest.head(diagSize) -= (lhs_alpha-LhsScalar(1))*rhs.head(diagSize);
283+
}
277284
}
278285
};
279286

@@ -295,8 +302,9 @@ template<int Mode> struct trmv_selector<Mode,RowMajor>
295302
typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
296303
typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
297304

298-
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
299-
* RhsBlasTraits::extractScalarFactor(rhs);
305+
LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
306+
RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
307+
ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha;
300308

301309
enum {
302310
DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1
@@ -326,6 +334,12 @@ template<int Mode> struct trmv_selector<Mode,RowMajor>
326334
actualRhsPtr,1,
327335
dest.data(),dest.innerStride(),
328336
actualAlpha);
337+
338+
if ( ((Mode&UnitDiag)==UnitDiag) && (lhs_alpha!=LhsScalar(1)) )
339+
{
340+
Index diagSize = (std::min)(lhs.rows(),lhs.cols());
341+
dest.head(diagSize) -= (lhs_alpha-LhsScalar(1))*rhs.head(diagSize);
342+
}
329343
}
330344
};
331345

test/product_trmm.cpp

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ void trmm(int rows=get_random_size<Scalar>(),
2929
typedef Matrix<Scalar,Dynamic,OtherCols,OtherCols==1?ColMajor:ResOrder> ResXS;
3030
typedef Matrix<Scalar,OtherCols,Dynamic,OtherCols==1?RowMajor:ResOrder> ResSX;
3131

32-
TriMatrix mat(rows,cols), tri(rows,cols), triTr(cols,rows);
32+
TriMatrix mat(rows,cols), tri(rows,cols), triTr(cols,rows), s1tri(rows,cols), s1triTr(cols,rows);
3333

3434
OnTheRight ge_right(cols,otherCols);
3535
OnTheLeft ge_left(otherCols,rows);
@@ -42,6 +42,8 @@ void trmm(int rows=get_random_size<Scalar>(),
4242
mat.setRandom();
4343
tri = mat.template triangularView<Mode>();
4444
triTr = mat.transpose().template triangularView<Mode>();
45+
s1tri = (s1*mat).template triangularView<Mode>();
46+
s1triTr = (s1*mat).transpose().template triangularView<Mode>();
4547
ge_right.setRandom();
4648
ge_left.setRandom();
4749

@@ -51,19 +53,29 @@ void trmm(int rows=get_random_size<Scalar>(),
5153
VERIFY_IS_APPROX( ge_xs.noalias() = mat.template triangularView<Mode>() * ge_right, tri * ge_right);
5254
VERIFY_IS_APPROX( ge_sx.noalias() = ge_left * mat.template triangularView<Mode>(), ge_left * tri);
5355

54-
VERIFY_IS_APPROX( ge_xs.noalias() = (s1*mat.adjoint()).template triangularView<Mode>() * (s2*ge_left.transpose()), s1*triTr.conjugate() * (s2*ge_left.transpose()));
55-
VERIFY_IS_APPROX( ge_sx.noalias() = ge_right.transpose() * mat.adjoint().template triangularView<Mode>(), ge_right.transpose() * triTr.conjugate());
56+
if((Mode&UnitDiag)==0)
57+
VERIFY_IS_APPROX( ge_xs.noalias() = (s1*mat.adjoint()).template triangularView<Mode>() * (s2*ge_left.transpose()), s1*triTr.conjugate() * (s2*ge_left.transpose()));
5658

57-
VERIFY_IS_APPROX( ge_xs.noalias() = (s1*mat.adjoint()).template triangularView<Mode>() * (s2*ge_left.adjoint()), s1*triTr.conjugate() * (s2*ge_left.adjoint()));
58-
VERIFY_IS_APPROX( ge_sx.noalias() = ge_right.adjoint() * mat.adjoint().template triangularView<Mode>(), ge_right.adjoint() * triTr.conjugate());
59+
VERIFY_IS_APPROX( ge_xs.noalias() = (s1*mat.transpose()).template triangularView<Mode>() * (s2*ge_left.transpose()), s1triTr * (s2*ge_left.transpose()));
60+
VERIFY_IS_APPROX( ge_sx.noalias() = (s2*ge_left) * (s1*mat).template triangularView<Mode>(), (s2*ge_left)*s1tri);
5961

62+
VERIFY_IS_APPROX( ge_sx.noalias() = ge_right.transpose() * mat.adjoint().template triangularView<Mode>(), ge_right.transpose() * triTr.conjugate());
63+
VERIFY_IS_APPROX( ge_sx.noalias() = ge_right.adjoint() * mat.adjoint().template triangularView<Mode>(), ge_right.adjoint() * triTr.conjugate());
64+
65+
ge_xs_save = ge_xs;
66+
if((Mode&UnitDiag)==0)
67+
VERIFY_IS_APPROX( (ge_xs_save + s1*triTr.conjugate() * (s2*ge_left.adjoint())).eval(), ge_xs.noalias() += (s1*mat.adjoint()).template triangularView<Mode>() * (s2*ge_left.adjoint()) );
6068
ge_xs_save = ge_xs;
61-
VERIFY_IS_APPROX( (ge_xs_save + s1*triTr.conjugate() * (s2*ge_left.adjoint())).eval(), ge_xs.noalias() += (s1*mat.adjoint()).template triangularView<Mode>() * (s2*ge_left.adjoint()) );
69+
VERIFY_IS_APPROX( (ge_xs_save + s1triTr * (s2*ge_left.adjoint())).eval(), ge_xs.noalias() += (s1*mat.transpose()).template triangularView<Mode>() * (s2*ge_left.adjoint()) );
6270
ge_sx.setRandom();
6371
ge_sx_save = ge_sx;
64-
VERIFY_IS_APPROX( ge_sx_save - (ge_right.adjoint() * (-s1 * triTr).conjugate()).eval(), ge_sx.noalias() -= (ge_right.adjoint() * (-s1 * mat).adjoint().template triangularView<Mode>()).eval());
72+
if((Mode&UnitDiag)==0)
73+
VERIFY_IS_APPROX( ge_sx_save - (ge_right.adjoint() * (-s1 * triTr).conjugate()).eval(), ge_sx.noalias() -= (ge_right.adjoint() * (-s1 * mat).adjoint().template triangularView<Mode>()).eval());
6574

66-
VERIFY_IS_APPROX( ge_xs = (s1*mat).adjoint().template triangularView<Mode>() * ge_left.adjoint(), numext::conj(s1) * triTr.conjugate() * ge_left.adjoint());
75+
if((Mode&UnitDiag)==0)
76+
VERIFY_IS_APPROX( ge_xs = (s1*mat).adjoint().template triangularView<Mode>() * ge_left.adjoint(), numext::conj(s1) * triTr.conjugate() * ge_left.adjoint());
77+
VERIFY_IS_APPROX( ge_xs = (s1*mat).transpose().template triangularView<Mode>() * ge_left.adjoint(), s1triTr * ge_left.adjoint());
78+
6779

6880
// TODO check with sub-matrix expressions ?
6981
}

0 commit comments

Comments
 (0)