"Fossies" - the Fresh Open Source Software Archive

Member "armadillo-9.800.3/include/armadillo_bits/gmm_diag_meat.hpp" (16 Jun 2016, 65539 Bytes) of package /linux/misc/armadillo-9.800.3.tar.xz:


As a special service, "Fossies" has tried to format the requested source page into HTML using (guessed) C and C++ source code syntax highlighting (style: standard), with prefixed line numbers and a code-folding option. Alternatively, you can view or download the uninterpreted source code file here. For more information about "gmm_diag_meat.hpp", see the Fossies "Dox" file reference documentation and the latest Fossies "Diffs" side-by-side code changes report: 9.600.6_vs_9.700.2.

    1 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
    2 // Copyright 2008-2016 National ICT Australia (NICTA)
    3 // 
    4 // Licensed under the Apache License, Version 2.0 (the "License");
    5 // you may not use this file except in compliance with the License.
    6 // You may obtain a copy of the License at
    7 // http://www.apache.org/licenses/LICENSE-2.0
    8 // 
    9 // Unless required by applicable law or agreed to in writing, software
   10 // distributed under the License is distributed on an "AS IS" BASIS,
   11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   12 // See the License for the specific language governing permissions and
   13 // limitations under the License.
   14 // ------------------------------------------------------------------------
   15 
   16 
   17 //! \addtogroup gmm_diag
   18 //! @{
   19 
   20 
   21 namespace gmm_priv
   22 {
   23 
   24 
   25 template<typename eT>
   26 inline
   27 gmm_diag<eT>::~gmm_diag()
   28   {
   29   arma_extra_debug_sigprint_this(this);
   30   
   31   arma_type_check(( (is_same_type<eT,float>::value == false) && (is_same_type<eT,double>::value == false) ));
   32   }
   33 
   34 
   35 
   36 template<typename eT>
   37 inline
   38 gmm_diag<eT>::gmm_diag()
   39   {
   40   arma_extra_debug_sigprint_this(this);
   41   }
   42 
   43 
   44 
   45 template<typename eT>
   46 inline
   47 gmm_diag<eT>::gmm_diag(const gmm_diag<eT>& x)
   48   {
   49   arma_extra_debug_sigprint_this(this);
   50   
   51   init(x);
   52   }
   53 
   54 
   55 
   56 template<typename eT>
   57 inline
   58 gmm_diag<eT>&
   59 gmm_diag<eT>::operator=(const gmm_diag<eT>& x)
   60   {
   61   arma_extra_debug_sigprint();
   62   
   63   init(x);
   64   
   65   return *this;
   66   }
   67 
   68 
   69 
   70 template<typename eT>
   71 inline
   72 gmm_diag<eT>::gmm_diag(const gmm_full<eT>& x)
   73   {
   74   arma_extra_debug_sigprint_this(this);
   75   
   76   init(x);
   77   }
   78 
   79 
   80 
   81 template<typename eT>
   82 inline
   83 gmm_diag<eT>&
   84 gmm_diag<eT>::operator=(const gmm_full<eT>& x)
   85   {
   86   arma_extra_debug_sigprint();
   87   
   88   init(x);
   89   
   90   return *this;
   91   }
   92 
   93 
   94 
   95 template<typename eT>
   96 inline
   97 gmm_diag<eT>::gmm_diag(const uword in_n_dims, const uword in_n_gaus)
   98   {
   99   arma_extra_debug_sigprint_this(this);
  100   
  101   init(in_n_dims, in_n_gaus);
  102   }
  103 
  104 
  105 
  106 template<typename eT>
  107 inline
  108 void
  109 gmm_diag<eT>::reset()
  110   {
  111   arma_extra_debug_sigprint();
  112   
  113   init(0, 0);
  114   }
  115 
  116 
  117 
  118 template<typename eT>
  119 inline
  120 void
  121 gmm_diag<eT>::reset(const uword in_n_dims, const uword in_n_gaus)
  122   {
  123   arma_extra_debug_sigprint();
  124   
  125   init(in_n_dims, in_n_gaus);
  126   }
  127 
  128 
  129 
  130 template<typename eT>
  131 template<typename T1, typename T2, typename T3>
  132 inline
  133 void
  134 gmm_diag<eT>::set_params(const Base<eT,T1>& in_means_expr, const Base<eT,T2>& in_dcovs_expr, const Base<eT,T3>& in_hefts_expr)
  135   {
  136   arma_extra_debug_sigprint();
  137   
  138   const unwrap<T1> tmp1(in_means_expr.get_ref());
  139   const unwrap<T2> tmp2(in_dcovs_expr.get_ref());
  140   const unwrap<T3> tmp3(in_hefts_expr.get_ref());
  141   
  142   const Mat<eT>& in_means = tmp1.M;
  143   const Mat<eT>& in_dcovs = tmp2.M;
  144   const Mat<eT>& in_hefts = tmp3.M;
  145   
  146   arma_debug_check
  147     (
  148     (arma::size(in_means) != arma::size(in_dcovs)) || (in_hefts.n_cols != in_means.n_cols) || (in_hefts.n_rows != 1),
  149     "gmm_diag::set_params(): given parameters have inconsistent and/or wrong sizes"
  150     );
  151   
  152   arma_debug_check( (in_means.is_finite() == false), "gmm_diag::set_params(): given means have non-finite values" );
  153   arma_debug_check( (in_dcovs.is_finite() == false), "gmm_diag::set_params(): given dcovs have non-finite values" );
  154   arma_debug_check( (in_hefts.is_finite() == false), "gmm_diag::set_params(): given hefts have non-finite values" );
  155   
  156   arma_debug_check( (any(vectorise(in_dcovs) <= eT(0))), "gmm_diag::set_params(): given dcovs have negative or zero values" );
  157   arma_debug_check( (any(vectorise(in_hefts) <  eT(0))), "gmm_diag::set_params(): given hefts have negative values"         );
  158   
  159   const eT s = accu(in_hefts);
  160   
  161   arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_diag::set_params(): sum of given hefts is not 1" );
  162   
  163   access::rw(means) = in_means;
  164   access::rw(dcovs) = in_dcovs;
  165   access::rw(hefts) = in_hefts;
  166   
  167   init_constants();
  168   }
  169 
  170 
  171 
  172 template<typename eT>
  173 template<typename T1>
  174 inline
  175 void
  176 gmm_diag<eT>::set_means(const Base<eT,T1>& in_means_expr)
  177   {
  178   arma_extra_debug_sigprint();
  179   
  180   const unwrap<T1> tmp(in_means_expr.get_ref());
  181   
  182   const Mat<eT>& in_means = tmp.M;
  183   
  184   arma_debug_check( (arma::size(in_means) != arma::size(means)), "gmm_diag::set_means(): given means have incompatible size" );
  185   arma_debug_check( (in_means.is_finite() == false),             "gmm_diag::set_means(): given means have non-finite values" );
  186   
  187   access::rw(means) = in_means;
  188   }
  189 
  190 
  191 
  192 template<typename eT>
  193 template<typename T1>
  194 inline
  195 void
  196 gmm_diag<eT>::set_dcovs(const Base<eT,T1>& in_dcovs_expr)
  197   {
  198   arma_extra_debug_sigprint();
  199   
  200   const unwrap<T1> tmp(in_dcovs_expr.get_ref());
  201   
  202   const Mat<eT>& in_dcovs = tmp.M;
  203   
  204   arma_debug_check( (arma::size(in_dcovs) != arma::size(dcovs)), "gmm_diag::set_dcovs(): given dcovs have incompatible size"       );
  205   arma_debug_check( (in_dcovs.is_finite() == false),             "gmm_diag::set_dcovs(): given dcovs have non-finite values"       );
  206   arma_debug_check( (any(vectorise(in_dcovs) <= eT(0))),         "gmm_diag::set_dcovs(): given dcovs have negative or zero values" );
  207   
  208   access::rw(dcovs) = in_dcovs;
  209   
  210   init_constants();
  211   }
  212 
  213 
  214 
  215 template<typename eT>
  216 template<typename T1>
  217 inline
  218 void
  219 gmm_diag<eT>::set_hefts(const Base<eT,T1>& in_hefts_expr)
  220   {
  221   arma_extra_debug_sigprint();
  222   
  223   const unwrap<T1> tmp(in_hefts_expr.get_ref());
  224   
  225   const Mat<eT>& in_hefts = tmp.M;
  226   
  227   arma_debug_check( (arma::size(in_hefts) != arma::size(hefts)), "gmm_diag::set_hefts(): given hefts have incompatible size" );
  228   arma_debug_check( (in_hefts.is_finite() == false),             "gmm_diag::set_hefts(): given hefts have non-finite values" );
  229   arma_debug_check( (any(vectorise(in_hefts) <  eT(0))),         "gmm_diag::set_hefts(): given hefts have negative values"   );
  230   
  231   const eT s = accu(in_hefts);
  232   
  233   arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_diag::set_hefts(): sum of given hefts is not 1" );
  234   
  235   // make sure all hefts are positive and non-zero
  236   
  237   const eT* in_hefts_mem = in_hefts.memptr();
  238         eT*    hefts_mem = access::rw(hefts).memptr();
  239   
  240   for(uword i=0; i < hefts.n_elem; ++i)
  241     {
  242     hefts_mem[i] = (std::max)( in_hefts_mem[i], std::numeric_limits<eT>::min() );
  243     }
  244   
  245   access::rw(hefts) /= accu(hefts);
  246   
  247   log_hefts = log(hefts);
  248   }
  249 
  250 
  251 
  252 template<typename eT>
  253 inline
  254 uword
  255 gmm_diag<eT>::n_dims() const
  256   {
  257   return means.n_rows;
  258   }
  259 
  260 
  261 
  262 template<typename eT>
  263 inline
  264 uword
  265 gmm_diag<eT>::n_gaus() const
  266   {
  267   return means.n_cols;
  268   }
  269 
  270 
  271 
  272 template<typename eT>
  273 inline
  274 bool
  275 gmm_diag<eT>::load(const std::string name)
  276   {
  277   arma_extra_debug_sigprint();
  278   
  279   Cube<eT> Q;
  280   
  281   bool status = Q.load(name, arma_binary);
  282   
  283   if( (status == false) || (Q.n_slices != 2) )
  284     {
  285     reset();
  286     arma_debug_warn("gmm_diag::load(): problem with loading or incompatible format");
  287     return false;
  288     }
  289   
  290   if( (Q.n_rows < 2) || (Q.n_cols < 1) )
  291     {
  292     reset();
  293     return true;
  294     }
  295   
  296   access::rw(hefts) = Q.slice(0).row(0);
  297   access::rw(means) = Q.slice(0).submat(1, 0, Q.n_rows-1, Q.n_cols-1);
  298   access::rw(dcovs) = Q.slice(1).submat(1, 0, Q.n_rows-1, Q.n_cols-1);
  299   
  300   init_constants();
  301   
  302   return true;
  303   }
  304 
  305 
  306 
  307 template<typename eT>
  308 inline
  309 bool
  310 gmm_diag<eT>::save(const std::string name) const
  311   {
  312   arma_extra_debug_sigprint();
  313   
  314   Cube<eT> Q(means.n_rows + 1, means.n_cols, 2);
  315   
  316   if(Q.n_elem > 0)
  317     {
  318     Q.slice(0).row(0) = hefts;
  319     Q.slice(1).row(0).zeros();  // reserved for future use
  320     
  321     Q.slice(0).submat(1, 0, arma::size(means)) = means;
  322     Q.slice(1).submat(1, 0, arma::size(dcovs)) = dcovs;
  323     }
  324   
  325   const bool status = Q.save(name, arma_binary);
  326   
  327   return status;
  328   }
  329 
  330 
  331 
  332 template<typename eT>
  333 inline
  334 Col<eT>
  335 gmm_diag<eT>::generate() const
  336   {
  337   arma_extra_debug_sigprint();
  338   
  339   const uword N_dims = means.n_rows;
  340   const uword N_gaus = means.n_cols;
  341   
  342   Col<eT> out( ((N_gaus > 0) ? N_dims : uword(0)), fill::randn );
  343   
  344   if(N_gaus > 0)
  345     {
  346     const double val = randu<double>();
  347     
  348     double csum    = double(0);
  349     uword  gaus_id = 0;
  350     
  351     for(uword j=0; j < N_gaus; ++j)
  352       {
  353       csum += hefts[j];
  354       
  355       if(val <= csum)  { gaus_id = j; break; }
  356       }
  357     
  358     out %= sqrt(dcovs.col(gaus_id));
  359     out += means.col(gaus_id);
  360     }
  361   
  362   return out;
  363   }
  364 
  365 
  366 
  367 template<typename eT>
  368 inline
  369 Mat<eT>
  370 gmm_diag<eT>::generate(const uword N_vec) const
  371   {
  372   arma_extra_debug_sigprint();
  373   
  374   const uword N_dims = means.n_rows;
  375   const uword N_gaus = means.n_cols;
  376   
  377   Mat<eT> out( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec, fill::randn );
  378   
  379   if(N_gaus > 0)
  380     {
  381     const eT* hefts_mem = hefts.memptr();
  382     
  383     const Mat<eT> sqrt_dcovs = sqrt(dcovs);
  384     
  385     for(uword i=0; i < N_vec; ++i)
  386       {
  387       const double val = randu<double>();
  388       
  389       double csum    = double(0);
  390       uword  gaus_id = 0;
  391       
  392       for(uword j=0; j < N_gaus; ++j)
  393         {
  394         csum += hefts_mem[j];
  395         
  396         if(val <= csum)  { gaus_id = j; break; }
  397         }
  398       
  399       subview_col<eT> out_col = out.col(i);
  400       
  401       out_col %= sqrt_dcovs.col(gaus_id);
  402       out_col += means.col(gaus_id);
  403       }
  404     }
  405   
  406   return out;
  407   }
  408 
  409 
  410 
  411 template<typename eT>
  412 template<typename T1>
  413 inline
  414 eT
  415 gmm_diag<eT>::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk2) const
  416   {
  417   arma_extra_debug_sigprint();
  418   arma_ignore(junk1);
  419   arma_ignore(junk2);
  420   
  421   const quasi_unwrap<T1> tmp(expr);
  422   
  423   arma_debug_check( (tmp.M.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" );
  424   
  425   return internal_scalar_log_p( tmp.M.memptr() );
  426   }
  427 
  428 
  429 
  430 template<typename eT>
  431 template<typename T1>
  432 inline
  433 eT
  434 gmm_diag<eT>::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk2) const
  435   {
  436   arma_extra_debug_sigprint();
  437   arma_ignore(junk2);
  438   
  439   const quasi_unwrap<T1> tmp(expr);
  440   
  441   arma_debug_check( (tmp.M.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" );
  442   
  443   arma_debug_check( (gaus_id >= means.n_cols), "gmm_diag::log_p(): specified gaussian is out of range" );
  444   
  445   return internal_scalar_log_p( tmp.M.memptr(), gaus_id );
  446   }
  447 
  448 
  449 
  450 template<typename eT>
  451 template<typename T1>
  452 inline
  453 Row<eT>
  454 gmm_diag<eT>::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk2) const
  455   {
  456   arma_extra_debug_sigprint();
  457   arma_ignore(junk1);
  458   arma_ignore(junk2);
  459   
  460   const quasi_unwrap<T1> tmp(expr);
  461   
  462   const Mat<eT>& X = tmp.M;
  463   
  464   return internal_vec_log_p(X);
  465   }
  466 
  467 
  468 
  469 template<typename eT>
  470 template<typename T1>
  471 inline
  472 Row<eT>
  473 gmm_diag<eT>::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk2) const
  474   {
  475   arma_extra_debug_sigprint();
  476   arma_ignore(junk2);
  477   
  478   const quasi_unwrap<T1> tmp(expr);
  479   
  480   const Mat<eT>& X = tmp.M;
  481   
  482   return internal_vec_log_p(X, gaus_id);
  483   }
  484 
  485 
  486 
  487 template<typename eT>
  488 template<typename T1>
  489 inline
  490 eT
  491 gmm_diag<eT>::sum_log_p(const Base<eT,T1>& expr) const
  492   {
  493   arma_extra_debug_sigprint();
  494   
  495   const quasi_unwrap<T1> tmp(expr.get_ref());
  496   
  497   const Mat<eT>& X = tmp.M;
  498   
  499   return internal_sum_log_p(X);
  500   }
  501 
  502 
  503 
  504 template<typename eT>
  505 template<typename T1>
  506 inline
  507 eT
  508 gmm_diag<eT>::sum_log_p(const Base<eT,T1>& expr, const uword gaus_id) const
  509   {
  510   arma_extra_debug_sigprint();
  511   
  512   const quasi_unwrap<T1> tmp(expr.get_ref());
  513   
  514   const Mat<eT>& X = tmp.M;
  515   
  516   return internal_sum_log_p(X, gaus_id);
  517   }
  518 
  519 
  520 
  521 template<typename eT>
  522 template<typename T1>
  523 inline
  524 eT
  525 gmm_diag<eT>::avg_log_p(const Base<eT,T1>& expr) const
  526   {
  527   arma_extra_debug_sigprint();
  528   
  529   const quasi_unwrap<T1> tmp(expr.get_ref());
  530   
  531   const Mat<eT>& X = tmp.M;
  532   
  533   return internal_avg_log_p(X);
  534   }
  535 
  536 
  537 
  538 template<typename eT>
  539 template<typename T1>
  540 inline
  541 eT
  542 gmm_diag<eT>::avg_log_p(const Base<eT,T1>& expr, const uword gaus_id) const
  543   {
  544   arma_extra_debug_sigprint();
  545   
  546   const quasi_unwrap<T1> tmp(expr.get_ref());
  547   
  548   const Mat<eT>& X = tmp.M;
  549   
  550   return internal_avg_log_p(X, gaus_id);
  551   }
  552 
  553 
  554 
  555 template<typename eT>
  556 template<typename T1>
  557 inline
  558 uword
  559 gmm_diag<eT>::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk) const
  560   {
  561   arma_extra_debug_sigprint();
  562   arma_ignore(junk);
  563   
  564   const quasi_unwrap<T1> tmp(expr);
  565   
  566   const Mat<eT>& X = tmp.M;
  567   
  568   return internal_scalar_assign(X, dist);
  569   }
  570 
  571 
  572 
  573 template<typename eT>
  574 template<typename T1>
  575 inline
  576 urowvec
  577 gmm_diag<eT>::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk) const
  578   {
  579   arma_extra_debug_sigprint();
  580   arma_ignore(junk);
  581   
  582   urowvec out;
  583   
  584   const quasi_unwrap<T1> tmp(expr);
  585   
  586   const Mat<eT>& X = tmp.M;
  587   
  588   internal_vec_assign(out, X, dist);
  589   
  590   return out;
  591   }
  592 
  593 
  594 
  595 template<typename eT>
  596 template<typename T1>
  597 inline
  598 urowvec
  599 gmm_diag<eT>::raw_hist(const Base<eT,T1>& expr, const gmm_dist_mode& dist_mode) const
  600   {
  601   arma_extra_debug_sigprint();
  602   
  603   const unwrap<T1>   tmp(expr.get_ref());
  604   const Mat<eT>& X = tmp.M;
  605   
  606   arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::raw_hist(): incompatible dimensions" );
  607   
  608   arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_diag::raw_hist(): unsupported distance mode" );
  609   
  610   urowvec hist;
  611   
  612   internal_raw_hist(hist, X, dist_mode);
  613   
  614   return hist;
  615   }
  616 
  617 
  618 
  619 template<typename eT>
  620 template<typename T1>
  621 inline
  622 Row<eT>
  623 gmm_diag<eT>::norm_hist(const Base<eT,T1>& expr, const gmm_dist_mode& dist_mode) const
  624   {
  625   arma_extra_debug_sigprint();
  626   
  627   const unwrap<T1>   tmp(expr.get_ref());
  628   const Mat<eT>& X = tmp.M;
  629   
  630   arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::norm_hist(): incompatible dimensions" );
  631   
  632   arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_diag::norm_hist(): unsupported distance mode" );
  633   
  634   urowvec hist;
  635   
  636   internal_raw_hist(hist, X, dist_mode);
  637   
  638   const uword  hist_n_elem = hist.n_elem;
  639   const uword* hist_mem    = hist.memptr();
  640   
  641   eT acc = eT(0);
  642   for(uword i=0; i<hist_n_elem; ++i)  { acc += eT(hist_mem[i]); }
  643   
  644   if(acc == eT(0))  { acc = eT(1); }
  645   
  646   Row<eT> out(hist_n_elem);
  647   
  648   eT* out_mem = out.memptr();
  649   
  650   for(uword i=0; i<hist_n_elem; ++i)  { out_mem[i] = eT(hist_mem[i]) / acc; }
  651   
  652   return out;
  653   }
  654 
  655 
  656 
  657 template<typename eT>
  658 template<typename T1>
  659 inline
  660 bool
  661 gmm_diag<eT>::learn
  662   (
  663   const Base<eT,T1>&   data,
  664   const uword          N_gaus,
  665   const gmm_dist_mode& dist_mode,
  666   const gmm_seed_mode& seed_mode,
  667   const uword          km_iter,
  668   const uword          em_iter,
  669   const eT             var_floor,
  670   const bool           print_mode
  671   )
  672   {
  673   arma_extra_debug_sigprint();
  674   
  675   const bool dist_mode_ok = (dist_mode == eucl_dist) || (dist_mode == maha_dist);
  676   
  677   const bool seed_mode_ok = \
  678        (seed_mode == keep_existing)
  679     || (seed_mode == static_subset)
  680     || (seed_mode == static_spread)
  681     || (seed_mode == random_subset)
  682     || (seed_mode == random_spread);
  683   
  684   arma_debug_check( (dist_mode_ok == false), "gmm_diag::learn(): dist_mode must be eucl_dist or maha_dist" );
  685   arma_debug_check( (seed_mode_ok == false), "gmm_diag::learn(): unknown seed_mode"                        );
  686   arma_debug_check( (var_floor < eT(0)    ), "gmm_diag::learn(): variance floor is negative"               );
  687   
  688   const unwrap<T1>   tmp_X(data.get_ref());
  689   const Mat<eT>& X = tmp_X.M;
  690   
  691   if(X.is_empty()          )  { arma_debug_warn("gmm_diag::learn(): given matrix is empty"             ); return false; }
  692   if(X.is_finite() == false)  { arma_debug_warn("gmm_diag::learn(): given matrix has non-finite values"); return false; }
  693   
  694   if(N_gaus == 0)  { reset(); return true; }
  695   
  696   if(dist_mode == maha_dist)
  697     {
  698     mah_aux = var(X,1,1);
  699     
  700     const uword mah_aux_n_elem = mah_aux.n_elem;
  701           eT*   mah_aux_mem    = mah_aux.memptr();
  702     
  703     for(uword i=0; i < mah_aux_n_elem; ++i)
  704       {
  705       const eT val = mah_aux_mem[i];
  706       
  707       mah_aux_mem[i] = ((val != eT(0)) && arma_isfinite(val)) ? eT(1) / val : eT(1);
  708       }
  709     }
  710   
  711   
  712   // copy current model, in case of failure by k-means and/or EM
  713   
  714   const gmm_diag<eT> orig = (*this);
  715   
  716   
  717   // initial means
  718   
  719   if(seed_mode == keep_existing)
  720     {
  721     if(means.is_empty()        )  { arma_debug_warn("gmm_diag::learn(): no existing means"      ); return false; }
  722     if(X.n_rows != means.n_rows)  { arma_debug_warn("gmm_diag::learn(): dimensionality mismatch"); return false; }
  723     
  724     // TODO: also check for number of vectors?
  725     }
  726   else
  727     {
  728     if(X.n_cols < N_gaus)  { arma_debug_warn("gmm_diag::learn(): number of vectors is less than number of gaussians"); return false; }
  729     
  730     reset(X.n_rows, N_gaus);
  731     
  732     if(print_mode)  { get_cout_stream() << "gmm_diag::learn(): generating initial means\n"; get_cout_stream().flush(); }
  733     
  734          if(dist_mode == eucl_dist)  { generate_initial_means<1>(X, seed_mode); }
  735     else if(dist_mode == maha_dist)  { generate_initial_means<2>(X, seed_mode); }
  736     }
  737   
  738   
  739   // k-means
  740   
  741   if(km_iter > 0)
  742     {
  743     const arma_ostream_state stream_state(get_cout_stream());
  744     
  745     bool status = false;
  746     
  747          if(dist_mode == eucl_dist)  { status = km_iterate<1>(X, km_iter, print_mode, "gmm_diag::learn(): k-means"); }
  748     else if(dist_mode == maha_dist)  { status = km_iterate<2>(X, km_iter, print_mode, "gmm_diag::learn(): k-means"); }
  749     
  750     stream_state.restore(get_cout_stream());
  751     
  752     if(status == false)  { arma_debug_warn("gmm_diag::learn(): k-means algorithm failed; not enough data, or too many gaussians requested"); init(orig); return false; }
  753     }
  754   
  755   
  756   // initial dcovs
  757   
  758   const eT var_floor_actual = (eT(var_floor) > eT(0)) ? eT(var_floor) : std::numeric_limits<eT>::min();
  759   
  760   if(seed_mode != keep_existing)
  761     {
  762     if(print_mode)  { get_cout_stream() << "gmm_diag::learn(): generating initial covariances\n"; get_cout_stream().flush(); }
  763     
  764          if(dist_mode == eucl_dist)  { generate_initial_params<1>(X, var_floor_actual); }
  765     else if(dist_mode == maha_dist)  { generate_initial_params<2>(X, var_floor_actual); }
  766     }
  767   
  768   
  769   // EM algorithm
  770   
  771   if(em_iter > 0)
  772     {
  773     const arma_ostream_state stream_state(get_cout_stream());
  774     
  775     const bool status = em_iterate(X, em_iter, var_floor_actual, print_mode);
  776     
  777     stream_state.restore(get_cout_stream());
  778     
  779     if(status == false)  { arma_debug_warn("gmm_diag::learn(): EM algorithm failed"); init(orig); return false; }
  780     }
  781   
  782   mah_aux.reset();
  783   
  784   init_constants();
  785   
  786   return true;
  787   }
  788 
  789 
  790 
  791 template<typename eT>
  792 template<typename T1>
  793 inline
  794 bool
  795 gmm_diag<eT>::kmeans_wrapper
  796   (
  797         Mat<eT>&       user_means,
  798   const Base<eT,T1>&   data,
  799   const uword          N_gaus,
  800   const gmm_seed_mode& seed_mode,
  801   const uword          km_iter,
  802   const bool           print_mode
  803   )
  804   {
  805   arma_extra_debug_sigprint();
  806   
  807   const bool seed_mode_ok = \
  808        (seed_mode == keep_existing)
  809     || (seed_mode == static_subset)
  810     || (seed_mode == static_spread)
  811     || (seed_mode == random_subset)
  812     || (seed_mode == random_spread);
  813   
  814   arma_debug_check( (seed_mode_ok == false), "kmeans(): unknown seed_mode" );
  815   
  816   const unwrap<T1>   tmp_X(data.get_ref());
  817   const Mat<eT>& X = tmp_X.M;
  818   
  819   if(X.is_empty()          )  { arma_debug_warn("kmeans(): given matrix is empty"             ); return false; }
  820   if(X.is_finite() == false)  { arma_debug_warn("kmeans(): given matrix has non-finite values"); return false; }
  821   
  822   if(N_gaus == 0)  { reset(); return true; }
  823   
  824   
  825   // initial means
  826   
  827   if(seed_mode == keep_existing)
  828     {
  829     access::rw(means) = user_means;
  830     
  831     if(means.is_empty()        )  { arma_debug_warn("kmeans(): no existing means"      ); return false; }
  832     if(X.n_rows != means.n_rows)  { arma_debug_warn("kmeans(): dimensionality mismatch"); return false; }
  833     
  834     // TODO: also check for number of vectors?
  835     }
  836   else
  837     {
  838     if(X.n_cols < N_gaus)  { arma_debug_warn("kmeans(): number of vectors is less than number of means"); return false; }
  839     
  840     access::rw(means).zeros(X.n_rows, N_gaus);
  841     
  842     if(print_mode)  { get_cout_stream() << "kmeans(): generating initial means\n"; }
  843     
  844     generate_initial_means<1>(X, seed_mode);
  845     }
  846   
  847   
  848   // k-means
  849   
  850   if(km_iter > 0)
  851     {
  852     const arma_ostream_state stream_state(get_cout_stream());
  853     
  854     bool status = false;
  855     
  856     status = km_iterate<1>(X, km_iter, print_mode, "kmeans()");
  857     
  858     stream_state.restore(get_cout_stream());
  859     
  860     if(status == false)  { arma_debug_warn("kmeans(): clustering failed; not enough data, or too many means requested"); return false; }
  861     }
  862   
  863   return true;
  864   }
  865 
  866 
  867 
  868 //
  869 //
  870 //
  871 
  872 
  873 
  874 template<typename eT>
  875 inline
  876 void
  877 gmm_diag<eT>::init(const gmm_diag<eT>& x)
  878   {
  879   arma_extra_debug_sigprint();
  880   
  881   gmm_diag<eT>& t = *this;
  882   
  883   if(&t != &x)
  884     {
  885     access::rw(t.means) = x.means;
  886     access::rw(t.dcovs) = x.dcovs;
  887     access::rw(t.hefts) = x.hefts;
  888     
  889     init_constants();
  890     }
  891   }
  892 
  893 
  894 
  895 template<typename eT>
  896 inline
  897 void
  898 gmm_diag<eT>::init(const gmm_full<eT>& x)
  899   {
  900   arma_extra_debug_sigprint();
  901   
  902   access::rw(hefts) = x.hefts;
  903   access::rw(means) = x.means;
  904   
  905   const uword N_dims = x.means.n_rows;
  906   const uword N_gaus = x.means.n_cols;
  907   
  908   access::rw(dcovs).zeros(N_dims,N_gaus);
  909   
  910   for(uword g=0; g < N_gaus; ++g)
  911     {
  912     const Mat<eT>& fcov = x.fcovs.slice(g);
  913     
  914     eT* dcov_mem = access::rw(dcovs).colptr(g);
  915     
  916     for(uword d=0; d < N_dims; ++d)
  917       {
  918       dcov_mem[d] = fcov.at(d,d);
  919       }
  920     }
  921   
  922   init_constants();
  923   }
  924 
  925 
  926 
  927 template<typename eT>
  928 inline
  929 void
  930 gmm_diag<eT>::init(const uword in_n_dims, const uword in_n_gaus)
  931   {
  932   arma_extra_debug_sigprint();
  933   
  934   access::rw(means).zeros(in_n_dims, in_n_gaus);
  935   
  936   access::rw(dcovs).ones(in_n_dims, in_n_gaus);
  937   
  938   access::rw(hefts).set_size(in_n_gaus);
  939   
  940   access::rw(hefts).fill(eT(1) / eT(in_n_gaus));
  941   
  942   init_constants();
  943   }
  944 
  945 
  946 
  947 template<typename eT>
  948 inline
  949 void
  950 gmm_diag<eT>::init_constants()
  951   {
  952   arma_extra_debug_sigprint();
  953   
  954   const uword N_dims = means.n_rows;
  955   const uword N_gaus = means.n_cols;
  956   
  957   // 
  958   
  959   inv_dcovs.copy_size(dcovs);
  960   
  961   const eT*     dcovs_mem =     dcovs.memptr();
  962         eT* inv_dcovs_mem = inv_dcovs.memptr();
  963   
  964   const uword dcovs_n_elem = dcovs.n_elem;
  965   
  966   for(uword i=0; i < dcovs_n_elem; ++i)
  967     {
  968     inv_dcovs_mem[i] = eT(1) / (std::max)( dcovs_mem[i], std::numeric_limits<eT>::min() );
  969     }
  970   
  971   //
  972   
  973   const eT tmp = (eT(N_dims)/eT(2)) * std::log(eT(2) * Datum<eT>::pi);
  974   
  975   log_det_etc.set_size(N_gaus);
  976   
  977   for(uword g=0; g < N_gaus; ++g)
  978     {
  979     const eT* dcovs_colmem = dcovs.colptr(g);
  980     
  981     eT log_det_val = eT(0);
  982     
  983     for(uword d=0; d < N_dims; ++d)
  984       {
  985       log_det_val += std::log( (std::max)( dcovs_colmem[d], std::numeric_limits<eT>::min() ) );
  986       }
  987     
  988     log_det_etc[g] = eT(-1) * ( tmp + eT(0.5) * log_det_val );
  989     }
  990   
  991   //
  992   
  993   eT* hefts_mem = access::rw(hefts).memptr();
  994   
  995   for(uword g=0; g < N_gaus; ++g)
  996     {
  997     hefts_mem[g] = (std::max)( hefts_mem[g], std::numeric_limits<eT>::min() );
  998     }
  999   
 1000   log_hefts = log(hefts);
 1001   }
 1002 
 1003 
 1004 
 1005 template<typename eT>
 1006 inline
 1007 umat
 1008 gmm_diag<eT>::internal_gen_boundaries(const uword N) const
 1009   {
 1010   arma_extra_debug_sigprint();
 1011   
 1012   #if defined(ARMA_USE_OPENMP)
 1013     const uword n_threads_avail = (omp_in_parallel()) ? uword(1) : uword(omp_get_max_threads());
 1014     const uword n_threads       = (n_threads_avail > 0) ? ( (n_threads_avail <= N) ? n_threads_avail : 1 ) : 1;
 1015   #else
 1016     static const uword n_threads = 1;
 1017   #endif
 1018   
 1019   // get_cout_stream() << "gmm_diag::internal_gen_boundaries(): n_threads: " << n_threads << '\n';
 1020   
 1021   umat boundaries(2, n_threads);
 1022   
 1023   if(N > 0)
 1024     {
 1025     const uword chunk_size = N / n_threads;
 1026     
 1027     uword count = 0;
 1028     
 1029     for(uword t=0; t<n_threads; t++)
 1030       {
 1031       boundaries.at(0,t) = count;
 1032       
 1033       count += chunk_size;
 1034       
 1035       boundaries.at(1,t) = count-1;
 1036       }
 1037     
 1038     boundaries.at(1,n_threads-1) = N - 1;
 1039     }
 1040   else
 1041     {
 1042     boundaries.zeros();
 1043     }
 1044   
 1045   // get_cout_stream() << "gmm_diag::internal_gen_boundaries(): boundaries: " << '\n' << boundaries << '\n';
 1046   
 1047   return boundaries;
 1048   }
 1049 
 1050 
 1051 
 1052 template<typename eT>
 1053 arma_hot
 1054 inline
 1055 eT
 1056 gmm_diag<eT>::internal_scalar_log_p(const eT* x) const
 1057   {
 1058   arma_extra_debug_sigprint();
 1059   
 1060   const eT* log_hefts_mem = log_hefts.mem;
 1061   
 1062   const uword N_gaus = means.n_cols;
 1063   
 1064   if(N_gaus > 0)
 1065     {
 1066     eT log_sum = internal_scalar_log_p(x, 0) + log_hefts_mem[0];
 1067     
 1068     for(uword g=1; g < N_gaus; ++g)
 1069       {
 1070       const eT tmp = internal_scalar_log_p(x, g) + log_hefts_mem[g];
 1071       
 1072       log_sum = log_add_exp(log_sum, tmp);
 1073       }
 1074     
 1075     return log_sum;
 1076     }
 1077   else
 1078     {
 1079     return -Datum<eT>::inf;
 1080     }
 1081   }
 1082 
 1083 
 1084 
 1085 template<typename eT>
 1086 arma_hot
 1087 inline
 1088 eT
 1089 gmm_diag<eT>::internal_scalar_log_p(const eT* x, const uword g) const
 1090   {
 1091   arma_extra_debug_sigprint();
 1092   
 1093   const eT*     mean =     means.colptr(g);
 1094   const eT* inv_dcov = inv_dcovs.colptr(g);
 1095   
 1096   const uword N_dims = means.n_rows;
 1097   
 1098   eT val_i = eT(0);
 1099   eT val_j = eT(0);
 1100   
 1101   uword i,j;
 1102   
 1103   for(i=0, j=1; j<N_dims; i+=2, j+=2)
 1104     {
 1105     eT tmp_i = x[i];
 1106     eT tmp_j = x[j];
 1107     
 1108     tmp_i -= mean[i];
 1109     tmp_j -= mean[j];
 1110     
 1111     val_i += (tmp_i*tmp_i) * inv_dcov[i];
 1112     val_j += (tmp_j*tmp_j) * inv_dcov[j];
 1113     }
 1114   
 1115   if(i < N_dims)
 1116     {
 1117     const eT tmp = x[i] - mean[i];
 1118     
 1119     val_i += (tmp*tmp) * inv_dcov[i];
 1120     }
 1121   
 1122   return eT(-0.5)*(val_i + val_j) + log_det_etc.mem[g];
 1123   }
 1124 
 1125 
 1126 
 1127 template<typename eT>
 1128 inline
 1129 Row<eT>
 1130 gmm_diag<eT>::internal_vec_log_p(const Mat<eT>& X) const
 1131   {
 1132   arma_extra_debug_sigprint();
 1133   
 1134   arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" );
 1135   
 1136   const uword N = X.n_cols;
 1137   
 1138   Row<eT> out(N);
 1139   
 1140   if(N > 0)
 1141     {
 1142     #if defined(ARMA_USE_OPENMP)
 1143       {
 1144       const umat boundaries = internal_gen_boundaries(N);
 1145       
 1146       const uword n_threads = boundaries.n_cols;
 1147       
 1148       #pragma omp parallel for schedule(static)
 1149       for(uword t=0; t < n_threads; ++t)
 1150         {
 1151         const uword start_index = boundaries.at(0,t);
 1152         const uword   end_index = boundaries.at(1,t);
 1153         
 1154         eT* out_mem = out.memptr();
 1155         
 1156         for(uword i=start_index; i <= end_index; ++i)
 1157           {
 1158           out_mem[i] = internal_scalar_log_p( X.colptr(i) );
 1159           }
 1160         }
 1161       }
 1162     #else
 1163       {
 1164       eT* out_mem = out.memptr();
 1165       
 1166       for(uword i=0; i < N; ++i)
 1167         {
 1168         out_mem[i] = internal_scalar_log_p( X.colptr(i) );
 1169         }
 1170       }
 1171     #endif
 1172     }
 1173   
 1174   return out;
 1175   }
 1176 
 1177 
 1178 
 1179 template<typename eT>
 1180 inline
 1181 Row<eT>
 1182 gmm_diag<eT>::internal_vec_log_p(const Mat<eT>& X, const uword gaus_id) const
 1183   {
 1184   arma_extra_debug_sigprint();
 1185   
 1186   arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" );
 1187   arma_debug_check( (gaus_id  >= means.n_cols), "gmm_diag::log_p(): specified gaussian is out of range" );
 1188   
 1189   const uword N = X.n_cols;
 1190   
 1191   Row<eT> out(N);
 1192   
 1193   if(N > 0)
 1194     {
 1195     #if defined(ARMA_USE_OPENMP)
 1196       {
 1197       const umat boundaries = internal_gen_boundaries(N);
 1198       
 1199       const uword n_threads = boundaries.n_cols;
 1200       
 1201       #pragma omp parallel for schedule(static)
 1202       for(uword t=0; t < n_threads; ++t)
 1203         {
 1204         const uword start_index = boundaries.at(0,t);
 1205         const uword   end_index = boundaries.at(1,t);
 1206         
 1207         eT* out_mem = out.memptr();
 1208         
 1209         for(uword i=start_index; i <= end_index; ++i)
 1210           {
 1211           out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id );
 1212           }
 1213         }
 1214       }
 1215     #else
 1216       {
 1217       eT* out_mem = out.memptr();
 1218       
 1219       for(uword i=0; i < N; ++i)
 1220         {
 1221         out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id );
 1222         }
 1223       }
 1224     #endif
 1225     }
 1226   
 1227   return out;
 1228   }
 1229 
 1230 
 1231 
 1232 template<typename eT>
 1233 inline
 1234 eT
 1235 gmm_diag<eT>::internal_sum_log_p(const Mat<eT>& X) const
 1236   {
 1237   arma_extra_debug_sigprint();
 1238   
 1239   arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::sum_log_p(): incompatible dimensions" );
 1240     
 1241   const uword N = X.n_cols;
 1242   
 1243   if(N == 0)  { return (-Datum<eT>::inf); }
 1244   
 1245   
 1246   #if defined(ARMA_USE_OPENMP)
 1247     {
 1248     const umat boundaries = internal_gen_boundaries(N);
 1249     
 1250     const uword n_threads = boundaries.n_cols;
 1251     
 1252     Col<eT> t_accs(n_threads, fill::zeros);
 1253     
 1254     #pragma omp parallel for schedule(static)
 1255     for(uword t=0; t < n_threads; ++t)
 1256       {
 1257       const uword start_index = boundaries.at(0,t);
 1258       const uword   end_index = boundaries.at(1,t);
 1259       
 1260       eT t_acc = eT(0);
 1261       
 1262       for(uword i=start_index; i <= end_index; ++i)
 1263         {
 1264         t_acc += internal_scalar_log_p( X.colptr(i) );
 1265         }
 1266       
 1267       t_accs[t] = t_acc;
 1268       }
 1269     
 1270     return eT(accu(t_accs));
 1271     }
 1272   #else
 1273     {
 1274     eT acc = eT(0);
 1275     
 1276     for(uword i=0; i<N; ++i)
 1277       {
 1278       acc += internal_scalar_log_p( X.colptr(i) );
 1279       }
 1280     
 1281     return acc;
 1282     }
 1283   #endif
 1284   }
 1285 
 1286 
 1287 
 1288 template<typename eT>
 1289 inline
 1290 eT
 1291 gmm_diag<eT>::internal_sum_log_p(const Mat<eT>& X, const uword gaus_id) const
 1292   {
 1293   arma_extra_debug_sigprint();
 1294   
 1295   arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::sum_log_p(): incompatible dimensions"            );
 1296   arma_debug_check( (gaus_id  >= means.n_cols), "gmm_diag::sum_log_p(): specified gaussian is out of range" );
 1297     
 1298   const uword N = X.n_cols;
 1299   
 1300   if(N == 0)  { return (-Datum<eT>::inf); }
 1301   
 1302   
 1303   #if defined(ARMA_USE_OPENMP)
 1304     {
 1305     const umat boundaries = internal_gen_boundaries(N);
 1306     
 1307     const uword n_threads = boundaries.n_cols;
 1308     
 1309     Col<eT> t_accs(n_threads, fill::zeros);
 1310     
 1311     #pragma omp parallel for schedule(static)
 1312     for(uword t=0; t < n_threads; ++t)
 1313       {
 1314       const uword start_index = boundaries.at(0,t);
 1315       const uword   end_index = boundaries.at(1,t);
 1316       
 1317       eT t_acc = eT(0);
 1318       
 1319       for(uword i=start_index; i <= end_index; ++i)
 1320         {
 1321         t_acc += internal_scalar_log_p( X.colptr(i), gaus_id );
 1322         }
 1323       
 1324       t_accs[t] = t_acc;
 1325       }
 1326     
 1327     return eT(accu(t_accs));
 1328     }
 1329   #else
 1330     {
 1331     eT acc = eT(0);
 1332     
 1333     for(uword i=0; i<N; ++i)
 1334       {
 1335       acc += internal_scalar_log_p( X.colptr(i), gaus_id );
 1336       }
 1337     
 1338     return acc;
 1339     }
 1340   #endif
 1341   }
 1342 
 1343 
 1344 
 1345 template<typename eT>
 1346 inline
 1347 eT
 1348 gmm_diag<eT>::internal_avg_log_p(const Mat<eT>& X) const
 1349   {
 1350   arma_extra_debug_sigprint();
 1351   
 1352   arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::avg_log_p(): incompatible dimensions" );
 1353     
 1354   const uword N = X.n_cols;
 1355   
 1356   if(N == 0)  { return (-Datum<eT>::inf); }
 1357   
 1358   
 1359   #if defined(ARMA_USE_OPENMP)
 1360     {
 1361     const umat boundaries = internal_gen_boundaries(N);
 1362     
 1363     const uword n_threads = boundaries.n_cols;
 1364     
 1365     field< running_mean_scalar<eT> > t_running_means(n_threads);
 1366     
 1367     
 1368     #pragma omp parallel for schedule(static)
 1369     for(uword t=0; t < n_threads; ++t)
 1370       {
 1371       const uword start_index = boundaries.at(0,t);
 1372       const uword   end_index = boundaries.at(1,t);
 1373       
 1374       running_mean_scalar<eT>& current_running_mean = t_running_means[t];
 1375       
 1376       for(uword i=start_index; i <= end_index; ++i)
 1377         {
 1378         current_running_mean( internal_scalar_log_p( X.colptr(i) ) );
 1379         }
 1380       }
 1381     
 1382     
 1383     eT avg = eT(0);
 1384     
 1385     for(uword t=0; t < n_threads; ++t)
 1386       {
 1387       running_mean_scalar<eT>& current_running_mean = t_running_means[t];
 1388       
 1389       const eT w = eT(current_running_mean.count()) / eT(N);
 1390       
 1391       avg += w * current_running_mean.mean();
 1392       }
 1393     
 1394     return avg;
 1395     }
 1396   #else
 1397     {
 1398     running_mean_scalar<eT> running_mean;
 1399     
 1400     for(uword i=0; i<N; ++i)
 1401       {
 1402       running_mean( internal_scalar_log_p( X.colptr(i) ) );
 1403       }
 1404     
 1405     return running_mean.mean();
 1406     }
 1407   #endif
 1408   }
 1409 
 1410 
 1411 
 1412 template<typename eT>
 1413 inline
 1414 eT
 1415 gmm_diag<eT>::internal_avg_log_p(const Mat<eT>& X, const uword gaus_id) const
 1416   {
 1417   arma_extra_debug_sigprint();
 1418   
 1419   arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::avg_log_p(): incompatible dimensions"            );
 1420   arma_debug_check( (gaus_id  >= means.n_cols), "gmm_diag::avg_log_p(): specified gaussian is out of range" );
 1421   
 1422   const uword N = X.n_cols;
 1423   
 1424   if(N == 0)  { return (-Datum<eT>::inf); }
 1425   
 1426   
 1427   #if defined(ARMA_USE_OPENMP)
 1428     {
 1429     const umat boundaries = internal_gen_boundaries(N);
 1430     
 1431     const uword n_threads = boundaries.n_cols;
 1432     
 1433     field< running_mean_scalar<eT> > t_running_means(n_threads);
 1434     
 1435     
 1436     #pragma omp parallel for schedule(static)
 1437     for(uword t=0; t < n_threads; ++t)
 1438       {
 1439       const uword start_index = boundaries.at(0,t);
 1440       const uword   end_index = boundaries.at(1,t);
 1441       
 1442       running_mean_scalar<eT>& current_running_mean = t_running_means[t];
 1443       
 1444       for(uword i=start_index; i <= end_index; ++i)
 1445         {
 1446         current_running_mean( internal_scalar_log_p( X.colptr(i), gaus_id) );
 1447         }
 1448       }
 1449     
 1450     
 1451     eT avg = eT(0);
 1452     
 1453     for(uword t=0; t < n_threads; ++t)
 1454       {
 1455       running_mean_scalar<eT>& current_running_mean = t_running_means[t];
 1456       
 1457       const eT w = eT(current_running_mean.count()) / eT(N);
 1458       
 1459       avg += w * current_running_mean.mean();
 1460       }
 1461     
 1462     return avg;
 1463     }
 1464   #else
 1465     {
 1466     running_mean_scalar<eT> running_mean;
 1467     
 1468     for(uword i=0; i<N; ++i)
 1469       {
 1470       running_mean( internal_scalar_log_p( X.colptr(i), gaus_id ) );
 1471       }
 1472     
 1473     return running_mean.mean();
 1474     }
 1475   #endif
 1476   }
 1477 
 1478 
 1479 
 1480 template<typename eT>
 1481 inline
 1482 uword
 1483 gmm_diag<eT>::internal_scalar_assign(const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
 1484   {
 1485   arma_extra_debug_sigprint();
 1486   
 1487   const uword N_dims = means.n_rows;
 1488   const uword N_gaus = means.n_cols;
 1489   
 1490   arma_debug_check( (X.n_rows != N_dims), "gmm_diag::assign(): incompatible dimensions" );
 1491   arma_debug_check( (N_gaus == 0),        "gmm_diag::assign(): model has no means"      );
 1492   
 1493   const eT* X_mem = X.colptr(0);
 1494   
 1495   if(dist_mode == eucl_dist)
 1496     {
 1497     eT    best_dist = Datum<eT>::inf;
 1498     uword best_g    = 0;
 1499     
 1500     for(uword g=0; g < N_gaus; ++g)
 1501       {
 1502       const eT tmp_dist = distance<eT,1>::eval(N_dims, X_mem, means.colptr(g), X_mem);
 1503       
 1504       if(tmp_dist <= best_dist)  { best_dist = tmp_dist;  best_g = g; }
 1505       }
 1506     
 1507     return best_g;
 1508     }
 1509   else
 1510   if(dist_mode == prob_dist)
 1511     {
 1512     const eT* log_hefts_mem = log_hefts.memptr();
 1513     
 1514     eT    best_p = -Datum<eT>::inf;
 1515     uword best_g = 0;
 1516     
 1517     for(uword g=0; g < N_gaus; ++g)
 1518       {
 1519       const eT tmp_p = internal_scalar_log_p(X_mem, g) + log_hefts_mem[g];
 1520       
 1521       if(tmp_p >= best_p)  { best_p = tmp_p;  best_g = g; }
 1522       }
 1523     
 1524     return best_g;
 1525     }
 1526   else
 1527     {
 1528     arma_debug_check(true, "gmm_diag::assign(): unsupported distance mode");
 1529     }
 1530   
 1531   return uword(0);
 1532   }
 1533 
 1534 
 1535 
 1536 template<typename eT>
 1537 inline
 1538 void
 1539 gmm_diag<eT>::internal_vec_assign(urowvec& out, const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
 1540   {
 1541   arma_extra_debug_sigprint();
 1542   
 1543   const uword N_dims = means.n_rows;
 1544   const uword N_gaus = means.n_cols;
 1545   
 1546   arma_debug_check( (X.n_rows != N_dims), "gmm_diag::assign(): incompatible dimensions" );
 1547   
 1548   const uword X_n_cols = (N_gaus > 0) ? X.n_cols : 0;
 1549   
 1550   out.set_size(1,X_n_cols);
 1551   
 1552   uword* out_mem = out.memptr();
 1553   
 1554   if(dist_mode == eucl_dist)
 1555     {
 1556     #if defined(ARMA_USE_OPENMP)
 1557       {
 1558       #pragma omp parallel for schedule(static)
 1559       for(uword i=0; i<X_n_cols; ++i)
 1560         {
 1561         const eT* X_colptr = X.colptr(i);
 1562         
 1563         eT    best_dist = Datum<eT>::inf;
 1564         uword best_g    = 0;
 1565         
 1566         for(uword g=0; g<N_gaus; ++g)
 1567           {
 1568           const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
 1569           
 1570           if(tmp_dist <= best_dist)  { best_dist = tmp_dist;  best_g = g; }
 1571           }
 1572         
 1573         out_mem[i] = best_g;
 1574         }
 1575       }
 1576     #else
 1577       {
 1578       for(uword i=0; i<X_n_cols; ++i)
 1579         {
 1580         const eT* X_colptr = X.colptr(i);
 1581         
 1582         eT    best_dist = Datum<eT>::inf;
 1583         uword best_g    = 0;
 1584         
 1585         for(uword g=0; g<N_gaus; ++g)
 1586           {
 1587           const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
 1588           
 1589           if(tmp_dist <= best_dist)  { best_dist = tmp_dist;  best_g = g; }
 1590           }
 1591         
 1592         out_mem[i] = best_g;
 1593         }
 1594       }
 1595     #endif
 1596     }
 1597   else
 1598   if(dist_mode == prob_dist)
 1599     {
 1600     #if defined(ARMA_USE_OPENMP)
 1601       {
 1602       const eT* log_hefts_mem = log_hefts.memptr();
 1603       
 1604       #pragma omp parallel for schedule(static)
 1605       for(uword i=0; i<X_n_cols; ++i)
 1606         {
 1607         const eT* X_colptr = X.colptr(i);
 1608         
 1609         eT    best_p = -Datum<eT>::inf;
 1610         uword best_g = 0;
 1611         
 1612         for(uword g=0; g<N_gaus; ++g)
 1613           {
 1614           const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
 1615           
 1616           if(tmp_p >= best_p)  { best_p = tmp_p;  best_g = g; }
 1617           }
 1618         
 1619         out_mem[i] = best_g;
 1620         }
 1621       }
 1622     #else
 1623       {
 1624       const eT* log_hefts_mem = log_hefts.memptr();
 1625       
 1626       for(uword i=0; i<X_n_cols; ++i)
 1627         {
 1628         const eT* X_colptr = X.colptr(i);
 1629          
 1630         eT    best_p = -Datum<eT>::inf;
 1631         uword best_g = 0;
 1632         
 1633         for(uword g=0; g<N_gaus; ++g)
 1634           {
 1635           const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
 1636           
 1637           if(tmp_p >= best_p)  { best_p = tmp_p;  best_g = g; }
 1638           }
 1639         
 1640         out_mem[i] = best_g;
 1641         }
 1642       }
 1643     #endif
 1644     }
 1645   else
 1646     {
 1647     arma_debug_check(true, "gmm_diag::assign(): unsupported distance mode");
 1648     }
 1649   }
 1650 
 1651 
 1652 
 1653 
 1654 template<typename eT>
 1655 inline
 1656 void
 1657 gmm_diag<eT>::internal_raw_hist(urowvec& hist, const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
 1658   {
 1659   arma_extra_debug_sigprint();
 1660   
 1661   const uword N_dims = means.n_rows;
 1662   const uword N_gaus = means.n_cols;
 1663   
 1664   const uword X_n_cols = X.n_cols;
 1665   
 1666   hist.zeros(N_gaus);
 1667   
 1668   if(N_gaus == 0)  { return; }
 1669   
 1670   #if defined(ARMA_USE_OPENMP)
 1671     {
 1672     const umat boundaries = internal_gen_boundaries(X_n_cols);
 1673     
 1674     const uword n_threads = boundaries.n_cols;
 1675     
 1676     field<urowvec> thread_hist(n_threads);
 1677     
 1678     for(uword t=0; t < n_threads; ++t)  { thread_hist(t).zeros(N_gaus); }
 1679     
 1680     
 1681     if(dist_mode == eucl_dist)
 1682       {
 1683       #pragma omp parallel for schedule(static)
 1684       for(uword t=0; t < n_threads; ++t)
 1685         {
 1686         uword* thread_hist_mem = thread_hist(t).memptr();
 1687         
 1688         const uword start_index = boundaries.at(0,t);
 1689         const uword   end_index = boundaries.at(1,t);
 1690         
 1691         for(uword i=start_index; i <= end_index; ++i)
 1692           {
 1693           const eT* X_colptr = X.colptr(i);
 1694           
 1695           eT    best_dist = Datum<eT>::inf;
 1696           uword best_g    = 0;
 1697           
 1698           for(uword g=0; g < N_gaus; ++g)
 1699             {
 1700             const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
 1701             
 1702             if(tmp_dist <= best_dist)  { best_dist = tmp_dist;  best_g = g; }
 1703             }
 1704           
 1705           thread_hist_mem[best_g]++;
 1706           }
 1707         }
 1708       }
 1709     else
 1710     if(dist_mode == prob_dist)
 1711       {
 1712       const eT* log_hefts_mem = log_hefts.memptr();
 1713       
 1714       #pragma omp parallel for schedule(static)
 1715       for(uword t=0; t < n_threads; ++t)
 1716         {
 1717         uword* thread_hist_mem = thread_hist(t).memptr();
 1718         
 1719         const uword start_index = boundaries.at(0,t);
 1720         const uword   end_index = boundaries.at(1,t);
 1721         
 1722         for(uword i=start_index; i <= end_index; ++i)
 1723           {
 1724           const eT* X_colptr = X.colptr(i);
 1725             
 1726           eT    best_p = -Datum<eT>::inf;
 1727           uword best_g = 0;
 1728           
 1729           for(uword g=0; g < N_gaus; ++g)
 1730             {
 1731             const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
 1732             
 1733             if(tmp_p >= best_p)  { best_p = tmp_p;  best_g = g; }
 1734             }
 1735           
 1736           thread_hist_mem[best_g]++;
 1737           }
 1738         }
 1739       }
 1740     
 1741     // reduction
 1742     hist = thread_hist(0);
 1743     
 1744     for(uword t=1; t < n_threads; ++t)
 1745       {
 1746       hist += thread_hist(t);
 1747       }
 1748     }
 1749   #else
 1750     {
 1751     uword* hist_mem = hist.memptr();
 1752     
 1753     if(dist_mode == eucl_dist)
 1754       {
 1755       for(uword i=0; i<X_n_cols; ++i)
 1756         {
 1757         const eT* X_colptr = X.colptr(i);
 1758          
 1759         eT    best_dist = Datum<eT>::inf;
 1760         uword best_g    = 0;
 1761         
 1762         for(uword g=0; g < N_gaus; ++g)
 1763           {
 1764           const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
 1765           
 1766           if(tmp_dist <= best_dist)  { best_dist = tmp_dist;  best_g = g; }
 1767           }
 1768         
 1769         hist_mem[best_g]++;
 1770         }
 1771       }
 1772     else
 1773     if(dist_mode == prob_dist)
 1774       {
 1775       const eT* log_hefts_mem = log_hefts.memptr();
 1776       
 1777       for(uword i=0; i<X_n_cols; ++i)
 1778         {
 1779         const eT* X_colptr = X.colptr(i);
 1780         
 1781         eT    best_p = -Datum<eT>::inf;
 1782         uword best_g = 0;
 1783         
 1784         for(uword g=0; g < N_gaus; ++g)
 1785           {
 1786           const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
 1787           
 1788           if(tmp_p >= best_p)  { best_p = tmp_p;  best_g = g; }
 1789           }
 1790         
 1791         hist_mem[best_g]++;
 1792         }
 1793       }
 1794     }
 1795   #endif
 1796   }
 1797 
 1798 
 1799 
 1800 template<typename eT>
 1801 template<uword dist_id>
 1802 inline
 1803 void
 1804 gmm_diag<eT>::generate_initial_means(const Mat<eT>& X, const gmm_seed_mode& seed_mode)
 1805   {
 1806   arma_extra_debug_sigprint();
 1807   
 1808   const uword N_dims = means.n_rows;
 1809   const uword N_gaus = means.n_cols;
 1810   
 1811   if( (seed_mode == static_subset) || (seed_mode == random_subset) )
 1812     {
 1813     uvec initial_indices;
 1814     
 1815          if(seed_mode == static_subset)  { initial_indices = linspace<uvec>(0, X.n_cols-1, N_gaus); }
 1816     else if(seed_mode == random_subset)  { initial_indices = randperm<uvec>(X.n_cols, N_gaus);      }
 1817     
 1818     // initial_indices.print("initial_indices:");
 1819     
 1820     access::rw(means) = X.cols(initial_indices);
 1821     }
 1822   else
 1823   if( (seed_mode == static_spread) || (seed_mode == random_spread) )
 1824     {
 1825     // going through all of the samples can be extremely time consuming;
 1826     // instead, if there are enough samples, randomly choose samples with probability 0.1
 1827     
 1828     const bool  use_sampling = ((X.n_cols/uword(100)) > N_gaus);
 1829     const uword step         = (use_sampling) ? uword(10) : uword(1);
 1830     
 1831     uword start_index = 0;
 1832     
 1833          if(seed_mode == static_spread)  { start_index = X.n_cols / 2;                                         }
 1834     else if(seed_mode == random_spread)  { start_index = as_scalar(randi<uvec>(1, distr_param(0,X.n_cols-1))); }
 1835     
 1836     access::rw(means).col(0) = X.unsafe_col(start_index);
 1837     
 1838     const eT* mah_aux_mem = mah_aux.memptr();
 1839     
 1840     running_stat<double> rs;
 1841     
 1842     for(uword g=1; g < N_gaus; ++g)
 1843       {
 1844       eT    max_dist = eT(0);
 1845       uword best_i   = uword(0);
 1846       uword start_i  = uword(0);
 1847       
 1848       if(use_sampling)
 1849         {
 1850         uword start_i_proposed = uword(0);
 1851         
 1852         if(seed_mode == static_spread)  { start_i_proposed = g % uword(10);                               }
 1853         if(seed_mode == random_spread)  { start_i_proposed = as_scalar(randi<uvec>(1, distr_param(0,9))); }
 1854         
 1855         if(start_i_proposed < X.n_cols)  { start_i = start_i_proposed; }
 1856         }
 1857       
 1858       
 1859       for(uword i=start_i; i < X.n_cols; i += step)
 1860         {
 1861         rs.reset();
 1862         
 1863         const eT* X_colptr = X.colptr(i);
 1864         
 1865         bool ignore_i = false;
 1866         
 1867         // find the average distance between sample i and the means so far
 1868         for(uword h = 0; h < g; ++h)
 1869           {
 1870           const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(h), mah_aux_mem);
 1871           
 1872           // ignore sample already selected as a mean
 1873           if(dist == eT(0))  { ignore_i = true; break; }
 1874           else               { rs(dist);               }
 1875           }
 1876         
 1877         if( (rs.mean() >= max_dist) && (ignore_i == false))
 1878           {
 1879           max_dist = eT(rs.mean()); best_i = i;
 1880           }
 1881         }
 1882       
 1883       // set the mean to the sample that is the furthest away from the means so far
 1884       access::rw(means).col(g) = X.unsafe_col(best_i);
 1885       }
 1886     }
 1887   
 1888   // get_cout_stream() << "generate_initial_means():" << '\n';
 1889   // means.print();
 1890   }
 1891 
 1892 
 1893 
 1894 template<typename eT>
 1895 template<uword dist_id>
 1896 inline
 1897 void
 1898 gmm_diag<eT>::generate_initial_params(const Mat<eT>& X, const eT var_floor)
 1899   {
 1900   arma_extra_debug_sigprint();
 1901   
 1902   const uword N_dims = means.n_rows;
 1903   const uword N_gaus = means.n_cols;
 1904   
 1905   const eT* mah_aux_mem = mah_aux.memptr();
 1906   
 1907   const uword X_n_cols = X.n_cols;
 1908   
 1909   if(X_n_cols == 0)  { return; }
 1910   
 1911   // as the covariances are calculated via accumulators,
 1912   // the means also need to be calculated via accumulators to ensure numerical consistency
 1913   
 1914   Mat<eT> acc_means(N_dims, N_gaus, fill::zeros);
 1915   Mat<eT> acc_dcovs(N_dims, N_gaus, fill::zeros);
 1916   
 1917   Row<uword> acc_hefts(N_gaus, fill::zeros);
 1918   
 1919   uword* acc_hefts_mem = acc_hefts.memptr();
 1920   
 1921   #if defined(ARMA_USE_OPENMP)
 1922     {
 1923     const umat boundaries = internal_gen_boundaries(X_n_cols);
 1924     
 1925     const uword n_threads = boundaries.n_cols;
 1926     
 1927     field< Mat<eT>    > t_acc_means(n_threads);
 1928     field< Mat<eT>    > t_acc_dcovs(n_threads);
 1929     field< Row<uword> > t_acc_hefts(n_threads);
 1930     
 1931     for(uword t=0; t < n_threads; ++t)
 1932       {
 1933       t_acc_means(t).zeros(N_dims, N_gaus);
 1934       t_acc_dcovs(t).zeros(N_dims, N_gaus);
 1935       t_acc_hefts(t).zeros(N_gaus);
 1936       }
 1937     
 1938     #pragma omp parallel for schedule(static)
 1939     for(uword t=0; t < n_threads; ++t)
 1940       {
 1941       uword* t_acc_hefts_mem = t_acc_hefts(t).memptr();
 1942       
 1943       const uword start_index = boundaries.at(0,t);
 1944       const uword   end_index = boundaries.at(1,t);
 1945       
 1946       for(uword i=start_index; i <= end_index; ++i)
 1947         {
 1948         const eT* X_colptr = X.colptr(i);
 1949         
 1950         eT     min_dist = Datum<eT>::inf;
 1951         uword  best_g   = 0;
 1952         
 1953         for(uword g=0; g<N_gaus; ++g)
 1954           {
 1955           const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem);
 1956           
 1957           if(dist < min_dist)  { min_dist = dist;  best_g = g; }
 1958           }
 1959         
 1960         eT* t_acc_mean = t_acc_means(t).colptr(best_g);
 1961         eT* t_acc_dcov = t_acc_dcovs(t).colptr(best_g);
 1962         
 1963         for(uword d=0; d<N_dims; ++d)
 1964           {
 1965           const eT x_d = X_colptr[d];
 1966           
 1967           t_acc_mean[d] += x_d;
 1968           t_acc_dcov[d] += x_d*x_d;
 1969           }
 1970         
 1971         t_acc_hefts_mem[best_g]++;
 1972         }
 1973       }
 1974     
 1975     // reduction
 1976     acc_means = t_acc_means(0);
 1977     acc_dcovs = t_acc_dcovs(0);
 1978     acc_hefts = t_acc_hefts(0);
 1979     
 1980     for(uword t=1; t < n_threads; ++t)
 1981       {
 1982       acc_means += t_acc_means(t);
 1983       acc_dcovs += t_acc_dcovs(t);
 1984       acc_hefts += t_acc_hefts(t);
 1985       }
 1986     }
 1987   #else
 1988     {
 1989     for(uword i=0; i<X_n_cols; ++i)
 1990       {
 1991       const eT* X_colptr = X.colptr(i);
 1992       
 1993       eT     min_dist = Datum<eT>::inf;
 1994       uword  best_g   = 0;
 1995       
 1996       for(uword g=0; g<N_gaus; ++g)
 1997         {
 1998         const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem);
 1999         
 2000         if(dist < min_dist)  { min_dist = dist;  best_g = g; }
 2001         }
 2002       
 2003       eT* acc_mean = acc_means.colptr(best_g);
 2004       eT* acc_dcov = acc_dcovs.colptr(best_g);
 2005       
 2006       for(uword d=0; d<N_dims; ++d)
 2007         {
 2008         const eT x_d = X_colptr[d];
 2009         
 2010         acc_mean[d] += x_d;
 2011         acc_dcov[d] += x_d*x_d;
 2012         }
 2013       
 2014       acc_hefts_mem[best_g]++;
 2015       }
 2016     }
 2017   #endif
 2018   
 2019   eT* hefts_mem = access::rw(hefts).memptr();
 2020   
 2021   for(uword g=0; g<N_gaus; ++g)
 2022     {
 2023     const eT*   acc_mean = acc_means.colptr(g);
 2024     const eT*   acc_dcov = acc_dcovs.colptr(g);
 2025     const uword acc_heft = acc_hefts_mem[g];
 2026     
 2027     eT* mean = access::rw(means).colptr(g);
 2028     eT* dcov = access::rw(dcovs).colptr(g);
 2029     
 2030     for(uword d=0; d<N_dims; ++d)
 2031       {
 2032       const eT tmp = acc_mean[d] / eT(acc_heft);
 2033       
 2034       mean[d] = (acc_heft >= 1) ? tmp : eT(0);
 2035       dcov[d] = (acc_heft >= 2) ? eT((acc_dcov[d] / eT(acc_heft)) - (tmp*tmp)) : eT(var_floor);
 2036       }
 2037     
 2038     hefts_mem[g] = eT(acc_heft) / eT(X_n_cols);
 2039     }
 2040   
 2041   em_fix_params(var_floor);
 2042   }
 2043 
 2044 
 2045 
 2046 //! multi-threaded implementation of k-means, inspired by MapReduce
 2047 template<typename eT>
 2048 template<uword dist_id>
 2049 inline
 2050 bool
 2051 gmm_diag<eT>::km_iterate(const Mat<eT>& X, const uword max_iter, const bool verbose, const char* signature)
 2052   {
 2053   arma_extra_debug_sigprint();
 2054   
 2055   if(verbose)
 2056     {
 2057     get_cout_stream().unsetf(ios::showbase);
 2058     get_cout_stream().unsetf(ios::uppercase);
 2059     get_cout_stream().unsetf(ios::showpos);
 2060     get_cout_stream().unsetf(ios::scientific);
 2061     
 2062     get_cout_stream().setf(ios::right);
 2063     get_cout_stream().setf(ios::fixed);
 2064     }
 2065   
 2066   const uword X_n_cols = X.n_cols;
 2067   
 2068   if(X_n_cols == 0)  { return true; }
 2069   
 2070   const uword N_dims = means.n_rows;
 2071   const uword N_gaus = means.n_cols;
 2072   
 2073   const eT* mah_aux_mem = mah_aux.memptr();
 2074   
 2075   Mat<eT>    acc_means(N_dims, N_gaus, fill::zeros);
 2076   Row<uword> acc_hefts(N_gaus, fill::zeros);
 2077   Row<uword> last_indx(N_gaus, fill::zeros);
 2078   
 2079   Mat<eT> new_means = means;
 2080   Mat<eT> old_means = means;
 2081   
 2082   running_mean_scalar<eT> rs_delta;
 2083   
 2084   #if defined(ARMA_USE_OPENMP)
 2085     const umat boundaries = internal_gen_boundaries(X_n_cols);
 2086     const uword n_threads = boundaries.n_cols;
 2087     
 2088     field< Mat<eT>    > t_acc_means(n_threads);
 2089     field< Row<uword> > t_acc_hefts(n_threads);
 2090     field< Row<uword> > t_last_indx(n_threads);
 2091   #else
 2092     const uword n_threads = 1;
 2093   #endif
 2094   
 2095   if(verbose)  { get_cout_stream() << signature << ": n_threads: " << n_threads  << '\n'; get_cout_stream().flush(); }
 2096   
 2097   for(uword iter=1; iter <= max_iter; ++iter)
 2098     {
 2099     #if defined(ARMA_USE_OPENMP)
 2100       {
 2101       for(uword t=0; t < n_threads; ++t)
 2102         {
 2103         t_acc_means(t).zeros(N_dims, N_gaus);
 2104         t_acc_hefts(t).zeros(N_gaus);
 2105         t_last_indx(t).zeros(N_gaus);
 2106         }
 2107       
 2108       #pragma omp parallel for schedule(static)
 2109       for(uword t=0; t < n_threads; ++t)
 2110         {
 2111         Mat<eT>& t_acc_means_t   = t_acc_means(t);
 2112         uword*   t_acc_hefts_mem = t_acc_hefts(t).memptr();
 2113         uword*   t_last_indx_mem = t_last_indx(t).memptr();
 2114         
 2115         const uword start_index = boundaries.at(0,t);
 2116         const uword   end_index = boundaries.at(1,t);
 2117         
 2118         for(uword i=start_index; i <= end_index; ++i)
 2119           {
 2120           const eT* X_colptr = X.colptr(i);
 2121           
 2122           eT     min_dist = Datum<eT>::inf;
 2123           uword  best_g   = 0;
 2124           
 2125           for(uword g=0; g<N_gaus; ++g)
 2126             {
 2127             const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem);
 2128             
 2129             if(dist < min_dist)  { min_dist = dist;  best_g = g; }
 2130             }
 2131           
 2132           eT* t_acc_mean = t_acc_means_t.colptr(best_g);
 2133           
 2134           for(uword d=0; d<N_dims; ++d)  { t_acc_mean[d] += X_colptr[d]; }
 2135           
 2136           t_acc_hefts_mem[best_g]++;
 2137           t_last_indx_mem[best_g] = i;
 2138           }
 2139         }
 2140       
 2141       // reduction
 2142       
 2143       acc_means = t_acc_means(0);
 2144       acc_hefts = t_acc_hefts(0);
 2145       
 2146       for(uword t=1; t < n_threads; ++t)
 2147         {
 2148         acc_means += t_acc_means(t);
 2149         acc_hefts += t_acc_hefts(t);
 2150         }
 2151       
 2152       for(uword g=0; g < N_gaus;    ++g)
 2153       for(uword t=0; t < n_threads; ++t)
 2154         {
 2155         if( t_acc_hefts(t)(g) >= 1 )  { last_indx(g) = t_last_indx(t)(g); }
 2156         }
 2157       }
 2158     #else
 2159       {
 2160       uword* acc_hefts_mem = acc_hefts.memptr();
 2161       uword* last_indx_mem = last_indx.memptr();
 2162       
 2163       for(uword i=0; i < X_n_cols; ++i)
 2164         {
 2165         const eT* X_colptr = X.colptr(i);
 2166         
 2167         eT     min_dist = Datum<eT>::inf;
 2168         uword  best_g   = 0;
 2169         
 2170         for(uword g=0; g<N_gaus; ++g)
 2171           {
 2172           const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem);
 2173           
 2174           if(dist < min_dist)  { min_dist = dist;  best_g = g; }
 2175           }
 2176         
 2177         eT* acc_mean = acc_means.colptr(best_g);
 2178         
 2179         for(uword d=0; d<N_dims; ++d)  { acc_mean[d] += X_colptr[d]; }
 2180         
 2181         acc_hefts_mem[best_g]++;
 2182         last_indx_mem[best_g] = i;
 2183         }
 2184       }
 2185     #endif
 2186     
 2187     // generate new means
 2188     
 2189     uword* acc_hefts_mem = acc_hefts.memptr();
 2190     
 2191     for(uword g=0; g < N_gaus; ++g)
 2192       {
 2193       const eT*   acc_mean = acc_means.colptr(g);
 2194       const uword acc_heft = acc_hefts_mem[g];
 2195       
 2196       eT* new_mean = access::rw(new_means).colptr(g);
 2197   
 2198       for(uword d=0; d<N_dims; ++d)
 2199         {
 2200         new_mean[d] = (acc_heft >= 1) ? (acc_mean[d] / eT(acc_heft)) : eT(0);
 2201         }
 2202       }
 2203     
 2204     
 2205     // heuristics to resurrect dead means
 2206     
 2207     const uvec dead_gs = find(acc_hefts == uword(0));
 2208     
 2209     if(dead_gs.n_elem > 0)
 2210       {
 2211       if(verbose)  { get_cout_stream() << signature << ": recovering from dead means\n"; get_cout_stream().flush(); }
 2212       
 2213       uword* last_indx_mem = last_indx.memptr();
 2214     
 2215       const uvec live_gs = sort( find(acc_hefts >= uword(2)), "descend" );
 2216       
 2217       if(live_gs.n_elem == 0)  { return false; }
 2218       
 2219       uword live_gs_count  = 0;
 2220       
 2221       for(uword dead_gs_count = 0; dead_gs_count < dead_gs.n_elem; ++dead_gs_count)
 2222         {
 2223         const uword dead_g_id = dead_gs(dead_gs_count);
 2224         
 2225         uword proposed_i = 0;
 2226         
 2227         if(live_gs_count < live_gs.n_elem)
 2228           {
 2229           const uword live_g_id = live_gs(live_gs_count);  ++live_gs_count;
 2230           
 2231           if(live_g_id == dead_g_id)  { return false; }
 2232           
 2233           // recover by using a sample from a known good mean
 2234           proposed_i = last_indx_mem[live_g_id];
 2235           }
 2236         else
 2237           {
 2238           // recover by using a randomly selected sample (last resort)
 2239           proposed_i = as_scalar(randi<uvec>(1, distr_param(0,X_n_cols-1)));
 2240           }
 2241         
 2242         if(proposed_i >= X_n_cols)  { return false; }
 2243         
 2244         new_means.col(dead_g_id) = X.col(proposed_i);
 2245         }
 2246       }
 2247 
 2248     rs_delta.reset();
 2249     
 2250     for(uword g=0; g < N_gaus; ++g)
 2251       {
 2252       rs_delta( distance<eT,dist_id>::eval(N_dims, old_means.colptr(g), new_means.colptr(g), mah_aux_mem) );
 2253       }
 2254     
 2255     if(verbose)
 2256       {
 2257       get_cout_stream() << signature << ": iteration: ";
 2258       get_cout_stream().unsetf(ios::scientific);
 2259       get_cout_stream().setf(ios::fixed);
 2260       get_cout_stream().width(std::streamsize(4));
 2261       get_cout_stream() << iter;
 2262       get_cout_stream() << "   delta: ";
 2263       get_cout_stream().unsetf(ios::fixed);
 2264       //get_cout_stream().setf(ios::scientific);
 2265       get_cout_stream() << rs_delta.mean() << '\n';
 2266       get_cout_stream().flush();
 2267       }
 2268     
 2269     arma::swap(old_means, new_means);
 2270     
 2271     if(rs_delta.mean() <= Datum<eT>::eps)  { break; }
 2272     }
 2273   
 2274   access::rw(means) = old_means;
 2275   
 2276   if(means.is_finite() == false)  { return false; }
 2277   
 2278   return true;
 2279   }
 2280 
 2281 
 2282 
 2283 //! multi-threaded implementation of Expectation-Maximisation, inspired by MapReduce
 2284 template<typename eT>
 2285 inline
 2286 bool
 2287 gmm_diag<eT>::em_iterate(const Mat<eT>& X, const uword max_iter, const eT var_floor, const bool verbose)
 2288   {
 2289   arma_extra_debug_sigprint();
 2290   
 2291   if(X.n_cols == 0)  { return true; }
 2292   
 2293   const uword N_dims = means.n_rows;
 2294   const uword N_gaus = means.n_cols;
 2295   
 2296   if(verbose)
 2297     {
 2298     get_cout_stream().unsetf(ios::showbase);
 2299     get_cout_stream().unsetf(ios::uppercase);
 2300     get_cout_stream().unsetf(ios::showpos);
 2301     get_cout_stream().unsetf(ios::scientific);
 2302     
 2303     get_cout_stream().setf(ios::right);
 2304     get_cout_stream().setf(ios::fixed);
 2305     }
 2306   
 2307   const umat boundaries = internal_gen_boundaries(X.n_cols);
 2308   
 2309   const uword n_threads = boundaries.n_cols;
 2310   
 2311   field< Mat<eT> > t_acc_means(n_threads); 
 2312   field< Mat<eT> > t_acc_dcovs(n_threads);
 2313   
 2314   field< Col<eT> > t_acc_norm_lhoods(n_threads);
 2315   field< Col<eT> > t_gaus_log_lhoods(n_threads);
 2316   
 2317   Col<eT>          t_progress_log_lhood(n_threads);
 2318   
 2319   for(uword t=0; t<n_threads; t++)
 2320     {
 2321     t_acc_means[t].set_size(N_dims, N_gaus);
 2322     t_acc_dcovs[t].set_size(N_dims, N_gaus);
 2323     
 2324     t_acc_norm_lhoods[t].set_size(N_gaus);
 2325     t_gaus_log_lhoods[t].set_size(N_gaus);
 2326     }
 2327   
 2328   
 2329   if(verbose)
 2330     {
 2331     get_cout_stream() << "gmm_diag::learn(): EM: n_threads: " << n_threads  << '\n';
 2332     }
 2333   
 2334   eT old_avg_log_p = -Datum<eT>::inf;
 2335   
 2336   for(uword iter=1; iter <= max_iter; ++iter)
 2337     {
 2338     init_constants();
 2339     
 2340     em_update_params(X, boundaries, t_acc_means, t_acc_dcovs, t_acc_norm_lhoods, t_gaus_log_lhoods, t_progress_log_lhood);
 2341     
 2342     em_fix_params(var_floor);
 2343     
 2344     const eT new_avg_log_p = accu(t_progress_log_lhood) / eT(t_progress_log_lhood.n_elem);
 2345     
 2346     if(verbose)
 2347       {
 2348       get_cout_stream() << "gmm_diag::learn(): EM: iteration: ";
 2349       get_cout_stream().unsetf(ios::scientific);
 2350       get_cout_stream().setf(ios::fixed);
 2351       get_cout_stream().width(std::streamsize(4));
 2352       get_cout_stream() << iter;
 2353       get_cout_stream() << "   avg_log_p: ";
 2354       get_cout_stream().unsetf(ios::fixed);
 2355       //get_cout_stream().setf(ios::scientific);
 2356       get_cout_stream() << new_avg_log_p << '\n';
 2357       get_cout_stream().flush();
 2358       }
 2359     
 2360     if(arma_isfinite(new_avg_log_p) == false)  { return false; }
 2361     
 2362     if(std::abs(old_avg_log_p - new_avg_log_p) <= Datum<eT>::eps)  { break; }
 2363     
 2364     
 2365     old_avg_log_p = new_avg_log_p;
 2366     }
 2367   
 2368   
 2369   if(any(vectorise(dcovs) <= eT(0)))  { return false; }
 2370   if(means.is_finite() == false    )  { return false; }
 2371   if(dcovs.is_finite() == false    )  { return false; }
 2372   if(hefts.is_finite() == false    )  { return false; }
 2373   
 2374   return true;
 2375   }
 2376 
 2377 
 2378 
 2379 
 2380 template<typename eT>
 2381 inline
 2382 void
 2383 gmm_diag<eT>::em_update_params
 2384   (
 2385   const Mat<eT>&          X,
 2386   const umat&             boundaries,
 2387         field< Mat<eT> >& t_acc_means,
 2388         field< Mat<eT> >& t_acc_dcovs,
 2389         field< Col<eT> >& t_acc_norm_lhoods,
 2390         field< Col<eT> >& t_gaus_log_lhoods,
 2391         Col<eT>&          t_progress_log_lhood
 2392   )
 2393   {
 2394   arma_extra_debug_sigprint();
 2395   
 2396   const uword n_threads = boundaries.n_cols;
 2397   
 2398   
 2399   // em_generate_acc() is the "map" operation, which produces partial accumulators for means, diagonal covariances and hefts
 2400     
 2401   #if defined(ARMA_USE_OPENMP)
 2402     {
 2403     #pragma omp parallel for schedule(static)
 2404     for(uword t=0; t<n_threads; t++)
 2405       {
 2406       Mat<eT>& acc_means          = t_acc_means[t];
 2407       Mat<eT>& acc_dcovs          = t_acc_dcovs[t];
 2408       Col<eT>& acc_norm_lhoods    = t_acc_norm_lhoods[t];
 2409       Col<eT>& gaus_log_lhoods    = t_gaus_log_lhoods[t];
 2410       eT&      progress_log_lhood = t_progress_log_lhood[t];
 2411       
 2412       em_generate_acc(X, boundaries.at(0,t), boundaries.at(1,t), acc_means, acc_dcovs, acc_norm_lhoods, gaus_log_lhoods, progress_log_lhood);
 2413       }
 2414     }
 2415   #else
 2416     {
 2417     em_generate_acc(X, boundaries.at(0,0), boundaries.at(1,0), t_acc_means[0], t_acc_dcovs[0], t_acc_norm_lhoods[0], t_gaus_log_lhoods[0], t_progress_log_lhood[0]);
 2418     }
 2419   #endif
 2420   
 2421   const uword N_dims = means.n_rows;
 2422   const uword N_gaus = means.n_cols;
 2423   
 2424   Mat<eT>& final_acc_means = t_acc_means[0];
 2425   Mat<eT>& final_acc_dcovs = t_acc_dcovs[0];
 2426   
 2427   Col<eT>& final_acc_norm_lhoods = t_acc_norm_lhoods[0];
 2428   
 2429   
 2430   // the "reduce" operation, which combines the partial accumulators produced by the separate threads
 2431   
 2432   for(uword t=1; t<n_threads; t++)
 2433     {
 2434     final_acc_means += t_acc_means[t];
 2435     final_acc_dcovs += t_acc_dcovs[t];
 2436     
 2437     final_acc_norm_lhoods += t_acc_norm_lhoods[t];
 2438     }
 2439   
 2440   
 2441   eT* hefts_mem = access::rw(hefts).memptr();
 2442   
 2443   
 2444   //// update each component without sanity checking
 2445   //for(uword g=0; g < N_gaus; ++g)
 2446   //  {
 2447   //  const eT acc_norm_lhood = (std::max)( final_acc_norm_lhoods[g], std::numeric_limits<eT>::min() );
 2448   //  
 2449   //  eT* mean_mem = access::rw(means).colptr(g);
 2450   //  eT* dcov_mem = access::rw(dcovs).colptr(g);
 2451   //  
 2452   //  eT* acc_mean_mem = final_acc_means.colptr(g);
 2453   //  eT* acc_dcov_mem = final_acc_dcovs.colptr(g);
 2454   //  
 2455   //  hefts_mem[g] = acc_norm_lhood / eT(X.n_cols);
 2456   //  
 2457   //  for(uword d=0; d < N_dims; ++d)
 2458   //    {
 2459   //    const eT tmp = acc_mean_mem[d] / acc_norm_lhood;
 2460   //    
 2461   //    mean_mem[d] = tmp;
 2462   //    dcov_mem[d] = acc_dcov_mem[d] / acc_norm_lhood - tmp*tmp;
 2463   //    }
 2464   //  }
 2465   
 2466   
 2467   // conditionally update each component;  if only a subset of the hefts was updated, em_fix_params() will sanitise them
 2468   for(uword g=0; g < N_gaus; ++g)
 2469     {
 2470     const eT acc_norm_lhood = (std::max)( final_acc_norm_lhoods[g], std::numeric_limits<eT>::min() );
 2471     
 2472     if(arma_isfinite(acc_norm_lhood) == false)  { continue; }
 2473     
 2474     eT* acc_mean_mem = final_acc_means.colptr(g);
 2475     eT* acc_dcov_mem = final_acc_dcovs.colptr(g);
 2476     
 2477     bool ok = true;
 2478     
 2479     for(uword d=0; d < N_dims; ++d)
 2480       {
 2481       const eT tmp1 = acc_mean_mem[d] / acc_norm_lhood;
 2482       const eT tmp2 = acc_dcov_mem[d] / acc_norm_lhood - tmp1*tmp1;
 2483       
 2484       acc_mean_mem[d] = tmp1;
 2485       acc_dcov_mem[d] = tmp2;
 2486       
 2487       if(arma_isfinite(tmp2) == false)  { ok = false; }
 2488       }
 2489     
 2490     
 2491     if(ok)
 2492       {
 2493       hefts_mem[g] = acc_norm_lhood / eT(X.n_cols);
 2494       
 2495       eT* mean_mem = access::rw(means).colptr(g);
 2496       eT* dcov_mem = access::rw(dcovs).colptr(g);
 2497       
 2498       for(uword d=0; d < N_dims; ++d)
 2499         {
 2500         mean_mem[d] = acc_mean_mem[d];
 2501         dcov_mem[d] = acc_dcov_mem[d];
 2502         }
 2503       }
 2504     }
 2505   }
 2506 
 2507 
 2508 
 2509 template<typename eT>
 2510 inline
 2511 void
 2512 gmm_diag<eT>::em_generate_acc
 2513   (
 2514   const Mat<eT>& X,
 2515   const uword    start_index,
 2516   const uword      end_index,
 2517         Mat<eT>& acc_means,
 2518         Mat<eT>& acc_dcovs,
 2519         Col<eT>& acc_norm_lhoods,
 2520         Col<eT>& gaus_log_lhoods,
 2521         eT&      progress_log_lhood
 2522   )
 2523   const
 2524   {
 2525   arma_extra_debug_sigprint();
 2526   
 2527   progress_log_lhood = eT(0);
 2528   
 2529   acc_means.zeros();
 2530   acc_dcovs.zeros();
 2531   
 2532   acc_norm_lhoods.zeros();
 2533   gaus_log_lhoods.zeros();
 2534   
 2535   const uword N_dims = means.n_rows;
 2536   const uword N_gaus = means.n_cols;
 2537   
 2538   const eT* log_hefts_mem       = log_hefts.memptr();
 2539         eT* gaus_log_lhoods_mem = gaus_log_lhoods.memptr();
 2540   
 2541   
 2542   for(uword i=start_index; i <= end_index; i++)
 2543     {
 2544     const eT* x = X.colptr(i);
 2545     
 2546     for(uword g=0; g < N_gaus; ++g)
 2547       {
 2548       gaus_log_lhoods_mem[g] = internal_scalar_log_p(x, g) + log_hefts_mem[g];
 2549       }
 2550     
 2551     eT log_lhood_sum = gaus_log_lhoods_mem[0];
 2552     
 2553     for(uword g=1; g < N_gaus; ++g)
 2554       {
 2555       log_lhood_sum = log_add_exp(log_lhood_sum, gaus_log_lhoods_mem[g]);
 2556       }
 2557     
 2558     progress_log_lhood += log_lhood_sum;
 2559     
 2560     for(uword g=0; g < N_gaus; ++g)
 2561       {
 2562       const eT norm_lhood = std::exp(gaus_log_lhoods_mem[g] - log_lhood_sum);
 2563       
 2564       acc_norm_lhoods[g] += norm_lhood;
 2565       
 2566       eT* acc_mean_mem = acc_means.colptr(g);
 2567       eT* acc_dcov_mem = acc_dcovs.colptr(g);
 2568       
 2569       for(uword d=0; d < N_dims; ++d)
 2570         {
 2571         const eT x_d = x[d];
 2572         const eT y_d = x_d * norm_lhood;
 2573         
 2574         acc_mean_mem[d] += y_d;
 2575         acc_dcov_mem[d] += y_d * x_d;  // equivalent to x_d * x_d * norm_lhood
 2576         }
 2577       }
 2578     }
 2579   
 2580   progress_log_lhood /= eT((end_index - start_index) + 1);
 2581   }
 2582 
 2583 
 2584 
 2585 template<typename eT>
 2586 inline
 2587 void
 2588 gmm_diag<eT>::em_fix_params(const eT var_floor)
 2589   {
 2590   arma_extra_debug_sigprint();
 2591   
 2592   const uword N_dims = means.n_rows;
 2593   const uword N_gaus = means.n_cols;
 2594   
 2595   const eT var_ceiling = std::numeric_limits<eT>::max();
 2596   
 2597   const uword dcovs_n_elem = dcovs.n_elem;
 2598         eT*   dcovs_mem    = access::rw(dcovs).memptr();
 2599   
 2600   for(uword i=0; i < dcovs_n_elem; ++i)
 2601     {
 2602     eT& var_val = dcovs_mem[i];
 2603     
 2604          if(var_val < var_floor  )  { var_val = var_floor;   }
 2605     else if(var_val > var_ceiling)  { var_val = var_ceiling; }
 2606     else if(arma_isnan(var_val)  )  { var_val = eT(1);       }
 2607     }
 2608   
 2609   
 2610   eT* hefts_mem = access::rw(hefts).memptr();
 2611   
 2612   for(uword g1=0; g1 < N_gaus; ++g1)
 2613     {
 2614     if(hefts_mem[g1] > eT(0))
 2615       {
 2616       const eT* means_colptr_g1 = means.colptr(g1);
 2617       
 2618       for(uword g2=(g1+1); g2 < N_gaus; ++g2)
 2619         {
 2620         if( (hefts_mem[g2] > eT(0)) && (std::abs(hefts_mem[g1] - hefts_mem[g2]) <= std::numeric_limits<eT>::epsilon()) )
 2621           {
 2622           const eT dist = distance<eT,1>::eval(N_dims, means_colptr_g1, means.colptr(g2), means_colptr_g1);
 2623           
 2624           if(dist == eT(0)) { hefts_mem[g2] = eT(0); }
 2625           }
 2626         }
 2627       }
 2628     }
 2629   
 2630   const eT heft_floor   = std::numeric_limits<eT>::min();
 2631   const eT heft_initial = eT(1) / eT(N_gaus);
 2632   
 2633   for(uword i=0; i < N_gaus; ++i)
 2634     {
 2635     eT& heft_val = hefts_mem[i];
 2636     
 2637          if(heft_val < heft_floor)  { heft_val = heft_floor;   }
 2638     else if(heft_val > eT(1)     )  { heft_val = eT(1);        }
 2639     else if(arma_isnan(heft_val) )  { heft_val = heft_initial; }
 2640     }
 2641   
 2642   const eT heft_sum = accu(hefts);
 2643   
 2644   if((heft_sum < (eT(1) - Datum<eT>::eps)) || (heft_sum > (eT(1) + Datum<eT>::eps)))  { access::rw(hefts) /= heft_sum; }
 2645   }
 2646 
 2647 
 2648 } // namespace gmm_priv
 2649 
 2650 
 2651 //! @}