op_pinv_meat.hpp (armadillo-10.8.2.tar.xz) | : | op_pinv_meat.hpp (armadillo-11.0.0.tar.xz) | ||
---|---|---|---|---|
skipping to change at line 60 | skipping to change at line 60 | |||
typedef typename T1::pod_type T; | typedef typename T1::pod_type T; | |||
arma_debug_check((tol < T(0)), "pinv(): tolerance must be >= 0"); | arma_debug_check((tol < T(0)), "pinv(): tolerance must be >= 0"); | |||
// method_id = 0 -> default setting | // method_id = 0 -> default setting | |||
// method_id = 1 -> use standard algorithm | // method_id = 1 -> use standard algorithm | |||
// method_id = 2 -> use divide and conquer algorithm | // method_id = 2 -> use divide and conquer algorithm | |||
Mat<eT> A(expr.get_ref()); | Mat<eT> A(expr.get_ref()); | |||
const uword n_rows = A.n_rows; | if(A.is_empty()) { out.set_size(A.n_cols,A.n_rows); return true; } | |||
const uword n_cols = A.n_cols; | ||||
if(A.is_empty()) { out.set_size(n_cols,n_rows); return true; } | ||||
if(is_op_diagmat<T1>::value || A.is_diagmat()) | if(is_op_diagmat<T1>::value || A.is_diagmat()) | |||
{ | { | |||
arma_extra_debug_print("op_pinv: detected diagonal matrix"); | arma_extra_debug_print("op_pinv: detected diagonal matrix"); | |||
return op_pinv::apply_diag(out, A, tol); | return op_pinv::apply_diag(out, A, tol); | |||
} | } | |||
#if defined(ARMA_OPTIMISE_SYMPD) | bool do_sym = false; | |||
bool do_sym = false; | bool do_sympd = false; | |||
bool do_sympd = false; | ||||
const bool is_sym_size_ok = (n_rows > (is_cx<eT>::yes ? uword(20) : uword(40 | const bool is_sym_size_ok = (A.n_rows > (is_cx<eT>::yes ? uword(20) : uword(40 | |||
))); | ))); | |||
const bool is_arg_default = ((tol == T(0)) && (method_id == uword(0))); | const bool is_arg_default = ((tol == T(0)) && (method_id == uword(0))); | |||
if( (auxlib::crippled_lapack(A) == false) && (is_arg_default || is_sym_size_ | if( (arma_config::optimise_sympd) && (auxlib::crippled_lapack(A) == false) && | |||
ok) ) | (is_arg_default || is_sym_size_ok) ) | |||
{ | { | |||
bool is_approx_sym = false; | bool is_approx_sym = false; | |||
bool is_approx_sympd = false; | bool is_approx_sympd = false; | |||
sympd_helper::analyse_matrix(is_approx_sym, is_approx_sympd, A); | sympd_helper::analyse_matrix(is_approx_sym, is_approx_sympd, A); | |||
do_sym = is_sym_size_ok && ((is_cx<eT>::no) ? (is_approx_sym) : (is_appr | do_sym = is_sym_size_ok && ((is_cx<eT>::no) ? (is_approx_sym) : (is_approx | |||
ox_sym && is_approx_sympd)); | _sym && is_approx_sympd)); | |||
do_sympd = is_arg_default && is_approx_sympd; | do_sympd = is_arg_default && is_approx_sympd; | |||
} | } | |||
#else | ||||
const bool do_sym = false; | ||||
const bool do_sympd = false; | ||||
#endif | ||||
if(do_sympd) | if(do_sympd) | |||
{ | { | |||
arma_extra_debug_print("op_pinv: attempting sympd optimisation"); | arma_extra_debug_print("op_pinv: attempting sympd optimisation"); | |||
out = A; | out = A; | |||
const T rcond_threshold = T((std::max)(uword(100), uword(A.n_rows))) * std:: | bool is_sympd_junk = false; | |||
numeric_limits<T>::epsilon(); | T rcond_calc = T(0); | |||
const T rcond_threshold = T((std::max)(uword(100), uword(A.n_rows))) * st | ||||
d::numeric_limits<T>::epsilon(); | ||||
const bool status = auxlib::inv_sympd_rcond(out, rcond_threshold); | const bool status = auxlib::inv_sympd_rcond(out, is_sympd_junk, rcond_calc, rcond_threshold); | |||
if(status) { return true; } | if(status && arma_isfinite(rcond_calc)) { return true; } | |||
arma_extra_debug_print("op_pinv: sympd optimisation failed"); | arma_extra_debug_print("op_pinv: sympd optimisation failed"); | |||
// auxlib::inv_sympd_rcond() will fail if A isn't really positive definite o r its rcond is below rcond_threshold | // auxlib::inv_sympd_rcond() will fail if A isn't really positive definite o r its rcond is below rcond_threshold | |||
} | } | |||
if(do_sym) | if(do_sym) | |||
{ | { | |||
arma_extra_debug_print("op_pinv: symmetric/hermitian optimisation"); | arma_extra_debug_print("op_pinv: symmetric/hermitian optimisation"); | |||
return op_pinv::apply_sym(out, A, tol, method_id); | return op_pinv::apply_sym(out, A, tol, method_id); | |||
} | } | |||
// economical SVD decomposition | return op_pinv::apply_gen(out, A, tol, method_id); | |||
Mat<eT> U; | ||||
Col< T> s; | ||||
Mat<eT> V; | ||||
if(n_cols > n_rows) { A = trans(A); } | ||||
const bool status = ((method_id == uword(0)) || (method_id == uword(2))) ? aux | ||||
lib::svd_dc_econ(U, s, V, A) : auxlib::svd_econ(U, s, V, A, 'b'); | ||||
if(status == false) { return false; } | ||||
// set tolerance to default if it hasn't been specified | ||||
if( (tol == T(0)) && (s.n_elem > 0) ) { tol = (std::max)(n_rows, n_cols) * s[ | ||||
0] * std::numeric_limits<T>::epsilon(); } | ||||
uword count = 0; | ||||
for(uword i=0; i < s.n_elem; ++i) { count += (s[i] >= tol) ? uword(1) : uword | ||||
(0); } | ||||
if(count == 0) { out.zeros(n_cols, n_rows); return true; } | ||||
Col<T> s2(count, arma_nozeros_indicator()); | ||||
uword count2 = 0; | ||||
for(uword i=0; i < s.n_elem; ++i) | ||||
{ | ||||
const T val = s[i]; | ||||
if(val >= tol) { s2[count2] = (val > T(0)) ? T(T(1) / val) : T(0); ++count2 | ||||
; } | ||||
} | ||||
const Mat<eT> U_use(U.memptr(), U.n_rows, count, false); | ||||
const Mat<eT> V_use(V.memptr(), V.n_rows, count, false); | ||||
Mat<eT> tmp; | ||||
if(n_rows >= n_cols) | ||||
{ | ||||
// out = ( (V.n_cols > count) ? V.cols(0,count-1) : V ) * diagmat(s2) * tran | ||||
s( (U.n_cols > count) ? U.cols(0,count-1) : U ); | ||||
tmp = V_use * diagmat(s2); | ||||
out = tmp * trans(U_use); | ||||
} | ||||
else | ||||
{ | ||||
// out = ( (U.n_cols > count) ? U.cols(0,count-1) : U ) * diagmat(s2) * tran | ||||
s( (V.n_cols > count) ? V.cols(0,count-1) : V ); | ||||
tmp = U_use * diagmat(s2); | ||||
out = tmp * trans(V_use); | ||||
} | ||||
return true; | ||||
} | } | |||
template<typename eT> | template<typename eT> | |||
inline | inline | |||
bool | bool | |||
op_pinv::apply_diag(Mat<eT>& out, const Mat<eT>& A, typename get_pod_type<eT>::r esult tol) | op_pinv::apply_diag(Mat<eT>& out, const Mat<eT>& A, typename get_pod_type<eT>::r esult tol) | |||
{ | { | |||
arma_extra_debug_sigprint(); | arma_extra_debug_sigprint(); | |||
typedef typename get_pod_type<eT>::result T; | typedef typename get_pod_type<eT>::result T; | |||
skipping to change at line 271 | skipping to change at line 212 | |||
if(abs_val >= tol) { eigval2[count2] = (val != T(0)) ? T(T(1) / val) : T(0) ; ++count2; } | if(abs_val >= tol) { eigval2[count2] = (val != T(0)) ? T(T(1) / val) : T(0) ; ++count2; } | |||
} | } | |||
const Mat<eT> eigvec_use(eigvec.memptr(), eigvec.n_rows, count, false); | const Mat<eT> eigvec_use(eigvec.memptr(), eigvec.n_rows, count, false); | |||
out = (eigvec_use * diagmat(eigval2)).eval() * eigvec_use.t(); | out = (eigvec_use * diagmat(eigval2)).eval() * eigvec_use.t(); | |||
return true; | return true; | |||
} | } | |||
template<typename eT> | ||||
inline | ||||
bool | ||||
op_pinv::apply_gen(Mat<eT>& out, Mat<eT>& A, typename get_pod_type<eT>::result t | ||||
ol, const uword method_id) | ||||
{ | ||||
arma_extra_debug_sigprint(); | ||||
typedef typename get_pod_type<eT>::result T; | ||||
const uword n_rows = A.n_rows; | ||||
const uword n_cols = A.n_cols; | ||||
// economical SVD decomposition | ||||
Mat<eT> U; | ||||
Col< T> s; | ||||
Mat<eT> V; | ||||
if(n_cols > n_rows) { A = trans(A); } | ||||
const bool status = ((method_id == uword(0)) || (method_id == uword(2))) ? aux | ||||
lib::svd_dc_econ(U, s, V, A) : auxlib::svd_econ(U, s, V, A, 'b'); | ||||
if(status == false) { return false; } | ||||
// set tolerance to default if it hasn't been specified | ||||
if( (tol == T(0)) && (s.n_elem > 0) ) { tol = (std::max)(n_rows, n_cols) * s[ | ||||
0] * std::numeric_limits<T>::epsilon(); } | ||||
uword count = 0; | ||||
for(uword i=0; i < s.n_elem; ++i) { count += (s[i] >= tol) ? uword(1) : uword | ||||
(0); } | ||||
if(count == 0) { out.zeros(n_cols, n_rows); return true; } | ||||
Col<T> s2(count, arma_nozeros_indicator()); | ||||
uword count2 = 0; | ||||
for(uword i=0; i < s.n_elem; ++i) | ||||
{ | ||||
const T val = s[i]; | ||||
if(val >= tol) { s2[count2] = (val > T(0)) ? T(T(1) / val) : T(0); ++count2 | ||||
; } | ||||
} | ||||
const Mat<eT> U_use(U.memptr(), U.n_rows, count, false); | ||||
const Mat<eT> V_use(V.memptr(), V.n_rows, count, false); | ||||
Mat<eT> tmp; | ||||
if(n_rows >= n_cols) | ||||
{ | ||||
// out = ( (V.n_cols > count) ? V.cols(0,count-1) : V ) * diagmat(s2) * tran | ||||
s( (U.n_cols > count) ? U.cols(0,count-1) : U ); | ||||
tmp = V_use * diagmat(s2); | ||||
out = tmp * trans(U_use); | ||||
} | ||||
else | ||||
{ | ||||
// out = ( (U.n_cols > count) ? U.cols(0,count-1) : U ) * diagmat(s2) * tran | ||||
s( (V.n_cols > count) ? V.cols(0,count-1) : V ); | ||||
tmp = U_use * diagmat(s2); | ||||
out = tmp * trans(V_use); | ||||
} | ||||
return true; | ||||
} | ||||
//! @} | //! @} | |||
End of changes. 11 change blocks. | ||||
88 lines changed or deleted | 98 lines changed or added |