"Fossies" - the Fresh Open Source Software Archive  

Source code changes of the file "include/armadillo_bits/glue_times_meat.hpp" between
armadillo-10.8.2.tar.xz and armadillo-11.0.0.tar.xz

About: Armadillo is a C++ linear algebra library (matrix maths) aiming towards a good balance between speed and ease of use.

glue_times_meat.hpp  (armadillo-10.8.2.tar.xz):glue_times_meat.hpp  (armadillo-11.0.0.tar.xz)
skipping to change at line 81 skipping to change at line 81
template<typename T1, typename T2> template<typename T1, typename T2>
arma_hot arma_hot
inline inline
void void
glue_times_redirect2_helper<true>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X) glue_times_redirect2_helper<true>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
{ {
arma_extra_debug_sigprint(); arma_extra_debug_sigprint();
typedef typename T1::elem_type eT; typedef typename T1::elem_type eT;
if(strip_inv<T1>::do_inv) if(arma_config::optimise_invexpr && (strip_inv<T1>::do_inv_gen || strip_inv<T1 >::do_inv_spd))
{ {
// replace inv(A)*B with solve(A,B) // replace inv(A)*B with solve(A,B)
arma_extra_debug_print("glue_times_redirect<2>::apply(): detected inv(A)*B") ; arma_extra_debug_print("glue_times_redirect<2>::apply(): detected inv(A)*B") ;
const strip_inv<T1> A_strip(X.A); const strip_inv<T1> A_strip(X.A);
Mat<eT> A = A_strip.M; Mat<eT> A = A_strip.M;
arma_debug_check( (A.is_square() == false), "inv(): given matrix must be squ are sized" ); arma_debug_check( (A.is_square() == false), "inv(): given matrix must be squ are sized" );
if(strip_inv<T1>::do_inv_sympd) if( (strip_inv<T1>::do_inv_spd) && (arma_config::debug) && (auxlib::rudiment ary_sym_check(A) == false) )
{ {
// if(auxlib::rudimentary_sym_check(A) == false) if(is_cx<eT>::no ) { arma_debug_warn_level(1, "inv_sympd(): given matrix
// { is not symmetric"); }
// if(is_cx<eT>::no ) { arma_debug_warn_level(1, "inv_sympd(): given ma if(is_cx<eT>::yes) { arma_debug_warn_level(1, "inv_sympd(): given matrix
trix is not symmetric"); } is not hermitian"); }
// if(is_cx<eT>::yes) { arma_debug_warn_level(1, "inv_sympd(): given ma
trix is not hermitian"); }
//
// out.soft_reset();
// arma_stop_runtime_error("matrix multiplication: problem with matrix i
nverse; suggest to use solve() instead");
//
// return;
// }
if( (arma_config::debug) && (auxlib::rudimentary_sym_check(A) == false) )
{
if(is_cx<eT>::no ) { arma_debug_warn_level(1, "inv_sympd(): given matri
x is not symmetric"); }
if(is_cx<eT>::yes) { arma_debug_warn_level(1, "inv_sympd(): given matri
x is not hermitian"); }
}
} }
const unwrap_check<T2> B_tmp(X.B, out); const unwrap_check<T2> B_tmp(X.B, out);
const Mat<eT>& B = B_tmp.M; const Mat<eT>& B = B_tmp.M;
arma_debug_assert_mul_size(A, B, "matrix multiplication"); arma_debug_assert_mul_size(A, B, "matrix multiplication");
// TODO: detect sympd via sympd_helper::guess_sympd(A) ? const bool status = (strip_inv<T1>::do_inv_spd) ? auxlib::solve_sympd_fast(o
ut, A, B) : auxlib::solve_square_fast(out, A, B);
#if defined(ARMA_OPTIMISE_SYMPD)
const bool status = (strip_inv<T1>::do_inv_sympd) ? auxlib::solve_sympd_fa
st(out, A, B) : auxlib::solve_square_fast(out, A, B);
#else
const bool status = auxlib::solve_square_fast(out, A, B);
#endif
if(status == false) if(status == false)
{ {
out.soft_reset(); out.soft_reset();
arma_stop_runtime_error("matrix multiplication: problem with matrix invers e; suggest to use solve() instead"); arma_stop_runtime_error("matrix multiplication: problem with matrix invers e; suggest to use solve() instead");
} }
return; return;
} }
#if defined(ARMA_OPTIMISE_SYMPD) if(arma_config::optimise_invexpr && strip_inv<T2>::do_inv_spd)
{ {
if(strip_inv<T2>::do_inv_sympd) // replace A*inv_sympd(B) with trans( solve(trans(B),trans(A)) )
{ // transpose of B is avoided as B is explicitly marked as symmetric
// replace A*inv_sympd(B) with trans( solve(trans(B),trans(A)) )
// transpose of B is avoided as B is explicitly marked as symmetric
arma_extra_debug_print("glue_times_redirect<2>::apply(): detected A*inv_sy mpd(B)"); arma_extra_debug_print("glue_times_redirect<2>::apply(): detected A*inv_symp d(B)");
const Mat<eT> At = trans(X.A); const Mat<eT> At = trans(X.A);
const strip_inv<T2> B_strip(X.B); const strip_inv<T2> B_strip(X.B);
Mat<eT> B = B_strip.M; Mat<eT> B = B_strip.M;
arma_debug_check( (B.is_square() == false), "inv_sympd(): given matrix mus
t be square sized" );
// if(auxlib::rudimentary_sym_check(B) == false) arma_debug_check( (B.is_square() == false), "inv_sympd(): given matrix must
// { be square sized" );
// if(is_cx<eT>::no ) { arma_debug_warn_level(1, "inv_sympd(): given ma
trix is not symmetric"); }
// if(is_cx<eT>::yes) { arma_debug_warn_level(1, "inv_sympd(): given ma
trix is not hermitian"); }
//
// out.soft_reset();
// arma_stop_runtime_error("matrix multiplication: problem with matrix i
nverse; suggest to use solve() instead");
//
// return;
// }
if( (arma_config::debug) && (auxlib::rudimentary_sym_check(B) == false) ) if( (arma_config::debug) && (auxlib::rudimentary_sym_check(B) == false) )
{ {
if(is_cx<eT>::no ) { arma_debug_warn_level(1, "inv_sympd(): given matri if(is_cx<eT>::no ) { arma_debug_warn_level(1, "inv_sympd(): given matrix
x is not symmetric"); } is not symmetric"); }
if(is_cx<eT>::yes) { arma_debug_warn_level(1, "inv_sympd(): given matri if(is_cx<eT>::yes) { arma_debug_warn_level(1, "inv_sympd(): given matrix
x is not hermitian"); } is not hermitian"); }
} }
arma_debug_assert_mul_size(At.n_cols, At.n_rows, B.n_rows, B.n_cols, "matr ix multiplication"); arma_debug_assert_mul_size(At.n_cols, At.n_rows, B.n_rows, B.n_cols, "matrix multiplication");
const bool status = auxlib::solve_sympd_fast(out, B, At); const bool status = auxlib::solve_sympd_fast(out, B, At);
if(status == false) if(status == false)
{ {
out.soft_reset(); out.soft_reset();
arma_stop_runtime_error("matrix multiplication: problem with matrix inve arma_stop_runtime_error("matrix multiplication: problem with matrix invers
rse; suggest to use solve() instead"); e; suggest to use solve() instead");
} }
out = trans(out); out = trans(out);
return; return;
}
} }
#endif
glue_times_redirect2_helper<false>::apply(out, X); glue_times_redirect2_helper<false>::apply(out, X);
} }
template<bool do_inv_detect> template<bool do_inv_detect>
template<typename T1, typename T2, typename T3> template<typename T1, typename T2, typename T3>
arma_hot arma_hot
inline inline
void void
glue_times_redirect3_helper<do_inv_detect>::apply(Mat<typename T1::elem_type>& o ut, const Glue< Glue<T1,T2,glue_times>, T3, glue_times>& X) glue_times_redirect3_helper<do_inv_detect>::apply(Mat<typename T1::elem_type>& o ut, const Glue< Glue<T1,T2,glue_times>, T3, glue_times>& X)
skipping to change at line 256 skipping to change at line 221
template<typename T1, typename T2, typename T3> template<typename T1, typename T2, typename T3>
arma_hot arma_hot
inline inline
void void
glue_times_redirect3_helper<true>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue<T1,T2,glue_times>, T3, glue_times>& X) glue_times_redirect3_helper<true>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue<T1,T2,glue_times>, T3, glue_times>& X)
{ {
arma_extra_debug_sigprint(); arma_extra_debug_sigprint();
typedef typename T1::elem_type eT; typedef typename T1::elem_type eT;
if(strip_inv<T1>::do_inv) if(arma_config::optimise_invexpr && (strip_inv<T1>::do_inv_gen || strip_inv<T1 >::do_inv_spd))
{ {
// replace inv(A)*B*C with solve(A,B*C); // replace inv(A)*B*C with solve(A,B*C);
arma_extra_debug_print("glue_times_redirect<3>::apply(): detected inv(A)*B*C "); arma_extra_debug_print("glue_times_redirect<3>::apply(): detected inv(A)*B*C ");
const strip_inv<T1> A_strip(X.A.A); const strip_inv<T1> A_strip(X.A.A);
Mat<eT> A = A_strip.M; Mat<eT> A = A_strip.M;
arma_debug_check( (A.is_square() == false), "inv(): given matrix must be squ are sized" ); arma_debug_check( (A.is_square() == false), "inv(): given matrix must be squ are sized" );
skipping to change at line 290 skipping to change at line 255
< <
eT, eT,
partial_unwrap<T2>::do_trans, partial_unwrap<T2>::do_trans,
partial_unwrap<T3>::do_trans, partial_unwrap<T3>::do_trans,
(partial_unwrap<T2>::do_times || partial_unwrap<T3>::do_times) (partial_unwrap<T2>::do_times || partial_unwrap<T3>::do_times)
> >
(BC, B, C, alpha); (BC, B, C, alpha);
arma_debug_assert_mul_size(A, BC, "matrix multiplication"); arma_debug_assert_mul_size(A, BC, "matrix multiplication");
// TODO: detect sympd via sympd_helper::guess_sympd(A) ? if( (strip_inv<T1>::do_inv_spd) && (arma_config::debug) && (auxlib::rudiment
ary_sym_check(A) == false) )
{
if(is_cx<eT>::no ) { arma_debug_warn_level(1, "inv_sympd(): given matrix
is not symmetric"); }
if(is_cx<eT>::yes) { arma_debug_warn_level(1, "inv_sympd(): given matrix
is not hermitian"); }
}
#if defined(ARMA_OPTIMISE_SYMPD) const bool status = (strip_inv<T1>::do_inv_spd) ? auxlib::solve_sympd_fast(o
const bool status = (strip_inv<T1>::do_inv_sympd) ? auxlib::solve_sympd_fa ut, A, BC) : auxlib::solve_square_fast(out, A, BC);
st(out, A, BC) : auxlib::solve_square_fast(out, A, BC);
#else
const bool status = auxlib::solve_square_fast(out, A, BC);
#endif
if(status == false) if(status == false)
{ {
out.soft_reset(); out.soft_reset();
arma_stop_runtime_error("matrix multiplication: problem with matrix invers e; suggest to use solve() instead"); arma_stop_runtime_error("matrix multiplication: problem with matrix invers e; suggest to use solve() instead");
} }
return; return;
} }
if(strip_inv<T2>::do_inv) if(arma_config::optimise_invexpr && (strip_inv<T2>::do_inv_gen || strip_inv<T2 >::do_inv_spd))
{ {
// replace A*inv(B)*C with A*solve(B,C) // replace A*inv(B)*C with A*solve(B,C)
arma_extra_debug_print("glue_times_redirect<3>::apply(): detected A*inv(B)*C "); arma_extra_debug_print("glue_times_redirect<3>::apply(): detected A*inv(B)*C ");
const strip_inv<T2> B_strip(X.A.B); const strip_inv<T2> B_strip(X.A.B);
Mat<eT> B = B_strip.M; Mat<eT> B = B_strip.M;
arma_debug_check( (B.is_square() == false), "inv(): given matrix must be squ are sized" ); arma_debug_check( (B.is_square() == false), "inv(): given matrix must be squ are sized" );
const unwrap<T3> C_tmp(X.B); const unwrap<T3> C_tmp(X.B);
const Mat<eT>& C = C_tmp.M; const Mat<eT>& C = C_tmp.M;
arma_debug_assert_mul_size(B, C, "matrix multiplication"); arma_debug_assert_mul_size(B, C, "matrix multiplication");
if( (strip_inv<T2>::do_inv_spd) && (arma_config::debug) && (auxlib::rudiment
ary_sym_check(B) == false) )
{
if(is_cx<eT>::no ) { arma_debug_warn_level(1, "inv_sympd(): given matrix
is not symmetric"); }
if(is_cx<eT>::yes) { arma_debug_warn_level(1, "inv_sympd(): given matrix
is not hermitian"); }
}
Mat<eT> solve_result; Mat<eT> solve_result;
#if defined(ARMA_OPTIMISE_SYMPD) const bool status = (strip_inv<T2>::do_inv_spd) ? auxlib::solve_sympd_fast(s
const bool status = (strip_inv<T2>::do_inv_sympd) ? auxlib::solve_sympd_fa olve_result, B, C) : auxlib::solve_square_fast(solve_result, B, C);
st(solve_result, B, C) : auxlib::solve_square_fast(solve_result, B, C);
#else
const bool status = auxlib::solve_square_fast(solve_result, B, C);
#endif
if(status == false) if(status == false)
{ {
out.soft_reset(); out.soft_reset();
arma_stop_runtime_error("matrix multiplication: problem with matrix invers e; suggest to use solve() instead"); arma_stop_runtime_error("matrix multiplication: problem with matrix invers e; suggest to use solve() instead");
return; return;
} }
const partial_unwrap_check<T1> tmp1(X.A.A, out); const partial_unwrap_check<T1> tmp1(X.A.A, out);
skipping to change at line 534 skipping to change at line 501
arma_hot arma_hot
inline inline
void void
glue_times::apply_inplace_plus(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times>& X, const sword sign) glue_times::apply_inplace_plus(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times>& X, const sword sign)
{ {
arma_extra_debug_sigprint(); arma_extra_debug_sigprint();
typedef typename T1::elem_type eT; typedef typename T1::elem_type eT;
typedef typename get_pod_type<eT>::result T; typedef typename get_pod_type<eT>::result T;
if( (is_outer_product<T1>::value) || (has_op_inv<T1>::value) || (has_op_inv<T2 >::value) || (has_op_inv_sympd<T1>::value) || (has_op_inv_sympd<T2>::value) ) if( (is_outer_product<T1>::value) || (has_op_inv_any<T1>::value) || (has_op_in v_any<T2>::value) )
{ {
// partial workaround for corner cases // partial workaround for corner cases
const Mat<eT> tmp(X); const Mat<eT> tmp(X);
if(sign > sword(0)) { out += tmp; } else { out -= tmp; } if(sign > sword(0)) { out += tmp; } else { out -= tmp; }
return; return;
} }
skipping to change at line 568 skipping to change at line 535
const eT alpha = use_alpha ? ( tmp1.get_val() * tmp2.get_val() * ( (sign > sword(0)) ? eT(1) : eT(-1) ) ) : eT(0); const eT alpha = use_alpha ? ( tmp1.get_val() * tmp2.get_val() * ( (sign > sword(0)) ? eT(1) : eT(-1) ) ) : eT(0);
arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplicatio n"); arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplicatio n");
const uword result_n_rows = (do_trans_A == false) ? (TA::is_row ? 1 : A.n_rows ) : (TA::is_col ? 1 : A.n_cols); const uword result_n_rows = (do_trans_A == false) ? (TA::is_row ? 1 : A.n_rows ) : (TA::is_col ? 1 : A.n_cols);
const uword result_n_cols = (do_trans_B == false) ? (TB::is_col ? 1 : B.n_cols ) : (TB::is_row ? 1 : B.n_rows); const uword result_n_cols = (do_trans_B == false) ? (TB::is_col ? 1 : B.n_cols ) : (TB::is_row ? 1 : B.n_rows);
arma_debug_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_co ls, ( (sign > sword(0)) ? "addition" : "subtraction" ) ); arma_debug_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_co ls, ( (sign > sword(0)) ? "addition" : "subtraction" ) );
if(out.n_elem == 0) if(out.n_elem == 0) { return; }
{
return;
}
if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) ) if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
{ {
if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx<eT>::no) ) { gemv<true , false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); } if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx<eT>::no) ) { gemv<true , false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); }
else if( (B.n_cols == 1) || (TB::is_col) ) { gemv<fals e, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); } else if( (B.n_cols == 1) || (TB::is_col) ) { gemv<fals e, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); }
else { gemm<fals e, false, false, true>::apply(out, A, B, alpha, eT(1)); } else { gemm<fals e, false, false, true>::apply(out, A, B, alpha, eT(1)); }
} }
else else
if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) ) if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
{ {
skipping to change at line 677 skipping to change at line 641
arma_extra_debug_sigprint(); arma_extra_debug_sigprint();
//arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplicat ion"); //arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplicat ion");
arma_debug_assert_trans_mul_size<do_trans_A, do_trans_B>(A.n_rows, A.n_cols, B .n_rows, B.n_cols, "matrix multiplication"); arma_debug_assert_trans_mul_size<do_trans_A, do_trans_B>(A.n_rows, A.n_cols, B .n_rows, B.n_cols, "matrix multiplication");
const uword final_n_rows = (do_trans_A == false) ? (TA::is_row ? 1 : A.n_rows) : (TA::is_col ? 1 : A.n_cols); const uword final_n_rows = (do_trans_A == false) ? (TA::is_row ? 1 : A.n_rows) : (TA::is_col ? 1 : A.n_cols);
const uword final_n_cols = (do_trans_B == false) ? (TB::is_col ? 1 : B.n_cols) : (TB::is_row ? 1 : B.n_rows); const uword final_n_cols = (do_trans_B == false) ? (TB::is_col ? 1 : B.n_cols) : (TB::is_row ? 1 : B.n_rows);
out.set_size(final_n_rows, final_n_cols); out.set_size(final_n_rows, final_n_cols);
if( (A.n_elem == 0) || (B.n_elem == 0) ) if( (A.n_elem == 0) || (B.n_elem == 0) ) { out.zeros(); return; }
{
out.zeros();
return;
}
if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) ) if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
{ {
if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx<eT>::no) ) { gemv<true , false, false>::apply(out.memptr(), B, A.memptr()); } if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx<eT>::no) ) { gemv<true , false, false>::apply(out.memptr(), B, A.memptr()); }
else if( (B.n_cols == 1) || (TB::is_col) ) { gemv<fals e, false, false>::apply(out.memptr(), A, B.memptr()); } else if( (B.n_cols == 1) || (TB::is_col) ) { gemv<fals e, false, false>::apply(out.memptr(), A, B.memptr()); }
else { gemm<fals e, false, false, false>::apply(out, A, B ); } else { gemm<fals e, false, false, false>::apply(out, A, B ); }
} }
else else
if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) ) if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
{ {
 End of changes. 27 change blocks. 
100 lines changed or deleted 60 lines changed or added

Home  |  About  |  Features  |  All  |  Newest  |  Dox  |  Diffs  |  RSS Feeds  |  Screenshots  |  Comments  |  Imprint  |  Privacy  |  HTTP(S)