"Fossies" - the Fresh Open Source Software Archive

Member "dlib-19.18/docs/dnn_imagenet_train_ex.cpp.html" (22 Sep 2019, 52313 Bytes) of package /linux/misc/dlib-19.18.tar.bz2:


As a special service, "Fossies" has tried to format the requested source page into HTML using (guessed) HTML source code syntax highlighting (style: standard) with prefixed line numbers. Alternatively, you can view or download the uninterpreted source code file here.

    1 <html><!-- Created using the cpp_pretty_printer from the dlib C++ library.  See http://dlib.net for updates. --><head><title>dlib C++ Library - dnn_imagenet_train_ex.cpp</title></head><body bgcolor='white'><pre>
    2 <font color='#009900'>// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
    3 </font><font color='#009900'>/*
    4     This program was used to train the resnet34_1000_imagenet_classifier.dnn
    5     network used by the <a href="dnn_imagenet_ex.cpp.html">dnn_imagenet_ex.cpp</a> example program.  
    6 
    7     You should be familiar with dlib's DNN module before reading this example
    8     program.  So read <a href="dnn_introduction_ex.cpp.html">dnn_introduction_ex.cpp</a> and <a href="dnn_introduction2_ex.cpp.html">dnn_introduction2_ex.cpp</a> first.  
    9 */</font>
   10 
   11 
   12 
   13 <font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>dlib<font color='#5555FF'>/</font>dnn.h<font color='#5555FF'>&gt;</font>
   14 <font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>iostream<font color='#5555FF'>&gt;</font>
   15 <font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>dlib<font color='#5555FF'>/</font>data_io.h<font color='#5555FF'>&gt;</font>
   16 <font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>dlib<font color='#5555FF'>/</font>image_transforms.h<font color='#5555FF'>&gt;</font>
   17 <font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>dlib<font color='#5555FF'>/</font>dir_nav.h<font color='#5555FF'>&gt;</font>
   18 <font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>iterator<font color='#5555FF'>&gt;</font>
   19 <font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>thread<font color='#5555FF'>&gt;</font>
   20 
   21 <font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> std;
   22 <font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> dlib;
   23  
   24 <font color='#009900'>// ----------------------------------------------------------------------------------------
   25 </font>
   26 <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'><u>int</u></font>,<font color='#0000FF'>template</font><font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font><font color='#5555FF'>&gt;</font><font color='#0000FF'>class</font>,<font color='#0000FF'><u>int</u></font>,<font color='#0000FF'>typename</font><font color='#5555FF'>&gt;</font> <font color='#0000FF'>class</font> <b><a name='block'></a>block</b>, <font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>template</font><font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font><font color='#5555FF'>&gt;</font><font color='#0000FF'>class</font> <b><a name='BN'></a>BN</b>, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font>
   27 <font color='#0000FF'>using</font> residual <font color='#5555FF'>=</font> add_prev1<font color='#5555FF'>&lt;</font>block<font color='#5555FF'>&lt;</font>N,BN,<font color='#979000'>1</font>,tag1<font color='#5555FF'>&lt;</font>SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
   28 
   29 <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'><u>int</u></font>,<font color='#0000FF'>template</font><font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font><font color='#5555FF'>&gt;</font><font color='#0000FF'>class</font>,<font color='#0000FF'><u>int</u></font>,<font color='#0000FF'>typename</font><font color='#5555FF'>&gt;</font> <font color='#0000FF'>class</font> <b><a name='block'></a>block</b>, <font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>template</font><font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font><font color='#5555FF'>&gt;</font><font color='#0000FF'>class</font> <b><a name='BN'></a>BN</b>, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font>
   30 <font color='#0000FF'>using</font> residual_down <font color='#5555FF'>=</font> add_prev2<font color='#5555FF'>&lt;</font>avg_pool<font color='#5555FF'>&lt;</font><font color='#979000'>2</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,skip1<font color='#5555FF'>&lt;</font>tag2<font color='#5555FF'>&lt;</font>block<font color='#5555FF'>&lt;</font>N,BN,<font color='#979000'>2</font>,tag1<font color='#5555FF'>&lt;</font>SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
   31 
   32 <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font><font color='#5555FF'>&gt;</font> <font color='#0000FF'>class</font> <b><a name='BN'></a>BN</b>, <font color='#0000FF'><u>int</u></font> stride, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> 
   33 <font color='#0000FF'>using</font> block  <font color='#5555FF'>=</font> BN<font color='#5555FF'>&lt;</font>con<font color='#5555FF'>&lt;</font>N,<font color='#979000'>3</font>,<font color='#979000'>3</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>,relu<font color='#5555FF'>&lt;</font>BN<font color='#5555FF'>&lt;</font>con<font color='#5555FF'>&lt;</font>N,<font color='#979000'>3</font>,<font color='#979000'>3</font>,stride,stride,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
   34 
   35 
   36 <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> res       <font color='#5555FF'>=</font> relu<font color='#5555FF'>&lt;</font>residual<font color='#5555FF'>&lt;</font>block,N,bn_con,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
   37 <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> ares      <font color='#5555FF'>=</font> relu<font color='#5555FF'>&lt;</font>residual<font color='#5555FF'>&lt;</font>block,N,affine,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
   38 <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> res_down  <font color='#5555FF'>=</font> relu<font color='#5555FF'>&lt;</font>residual_down<font color='#5555FF'>&lt;</font>block,N,bn_con,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
   39 <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> ares_down <font color='#5555FF'>=</font> relu<font color='#5555FF'>&lt;</font>residual_down<font color='#5555FF'>&lt;</font>block,N,affine,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
   40 
   41 
   42 <font color='#009900'>// ----------------------------------------------------------------------------------------
   43 </font>
   44 <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> level1 <font color='#5555FF'>=</font> res<font color='#5555FF'>&lt;</font><font color='#979000'>512</font>,res<font color='#5555FF'>&lt;</font><font color='#979000'>512</font>,res_down<font color='#5555FF'>&lt;</font><font color='#979000'>512</font>,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
   45 <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> level2 <font color='#5555FF'>=</font> res<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,res<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,res<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,res<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,res<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,res_down<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
   46 <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> level3 <font color='#5555FF'>=</font> res<font color='#5555FF'>&lt;</font><font color='#979000'>128</font>,res<font color='#5555FF'>&lt;</font><font color='#979000'>128</font>,res<font color='#5555FF'>&lt;</font><font color='#979000'>128</font>,res_down<font color='#5555FF'>&lt;</font><font color='#979000'>128</font>,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
   47 <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> level4 <font color='#5555FF'>=</font> res<font color='#5555FF'>&lt;</font><font color='#979000'>64</font>,res<font color='#5555FF'>&lt;</font><font color='#979000'>64</font>,res<font color='#5555FF'>&lt;</font><font color='#979000'>64</font>,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
   48 
   49 <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> alevel1 <font color='#5555FF'>=</font> ares<font color='#5555FF'>&lt;</font><font color='#979000'>512</font>,ares<font color='#5555FF'>&lt;</font><font color='#979000'>512</font>,ares_down<font color='#5555FF'>&lt;</font><font color='#979000'>512</font>,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
   50 <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> alevel2 <font color='#5555FF'>=</font> ares<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,ares<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,ares<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,ares<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,ares<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,ares_down<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
   51 <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> alevel3 <font color='#5555FF'>=</font> ares<font color='#5555FF'>&lt;</font><font color='#979000'>128</font>,ares<font color='#5555FF'>&lt;</font><font color='#979000'>128</font>,ares<font color='#5555FF'>&lt;</font><font color='#979000'>128</font>,ares_down<font color='#5555FF'>&lt;</font><font color='#979000'>128</font>,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
   52 <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> alevel4 <font color='#5555FF'>=</font> ares<font color='#5555FF'>&lt;</font><font color='#979000'>64</font>,ares<font color='#5555FF'>&lt;</font><font color='#979000'>64</font>,ares<font color='#5555FF'>&lt;</font><font color='#979000'>64</font>,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
   53 
   54 <font color='#009900'>// training network type
   55 </font><font color='#0000FF'>using</font> net_type <font color='#5555FF'>=</font> loss_multiclass_log<font color='#5555FF'>&lt;</font>fc<font color='#5555FF'>&lt;</font><font color='#979000'>1000</font>,avg_pool_everything<font color='#5555FF'>&lt;</font>
   56                             level1<font color='#5555FF'>&lt;</font>
   57                             level2<font color='#5555FF'>&lt;</font>
   58                             level3<font color='#5555FF'>&lt;</font>
   59                             level4<font color='#5555FF'>&lt;</font>
   60                             max_pool<font color='#5555FF'>&lt;</font><font color='#979000'>3</font>,<font color='#979000'>3</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,relu<font color='#5555FF'>&lt;</font>bn_con<font color='#5555FF'>&lt;</font>con<font color='#5555FF'>&lt;</font><font color='#979000'>64</font>,<font color='#979000'>7</font>,<font color='#979000'>7</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,
   61                             input_rgb_image_sized<font color='#5555FF'>&lt;</font><font color='#979000'>227</font><font color='#5555FF'>&gt;</font>
   62                             <font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
   63 
   64 <font color='#009900'>// testing network type (batch normalization replaced with fixed affine transforms)
   65 </font><font color='#0000FF'>using</font> anet_type <font color='#5555FF'>=</font> loss_multiclass_log<font color='#5555FF'>&lt;</font>fc<font color='#5555FF'>&lt;</font><font color='#979000'>1000</font>,avg_pool_everything<font color='#5555FF'>&lt;</font>
   66                             alevel1<font color='#5555FF'>&lt;</font>
   67                             alevel2<font color='#5555FF'>&lt;</font>
   68                             alevel3<font color='#5555FF'>&lt;</font>
   69                             alevel4<font color='#5555FF'>&lt;</font>
   70                             max_pool<font color='#5555FF'>&lt;</font><font color='#979000'>3</font>,<font color='#979000'>3</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,relu<font color='#5555FF'>&lt;</font>affine<font color='#5555FF'>&lt;</font>con<font color='#5555FF'>&lt;</font><font color='#979000'>64</font>,<font color='#979000'>7</font>,<font color='#979000'>7</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,
   71                             input_rgb_image_sized<font color='#5555FF'>&lt;</font><font color='#979000'>227</font><font color='#5555FF'>&gt;</font>
   72                             <font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
   73 
   74 <font color='#009900'>// ----------------------------------------------------------------------------------------
   75 </font>
   76 rectangle <b><a name='make_random_cropping_rect_resnet'></a>make_random_cropping_rect_resnet</b><font face='Lucida Console'>(</font>
   77     <font color='#0000FF'>const</font> matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> img,
   78     dlib::rand<font color='#5555FF'>&amp;</font> rnd
   79 <font face='Lucida Console'>)</font>
   80 <b>{</b>
   81     <font color='#009900'>// figure out what rectangle we want to crop from the image
   82 </font>    <font color='#0000FF'><u>double</u></font> mins <font color='#5555FF'>=</font> <font color='#979000'>0.466666666</font>, maxs <font color='#5555FF'>=</font> <font color='#979000'>0.875</font>;
   83     <font color='#0000FF'>auto</font> scale <font color='#5555FF'>=</font> mins <font color='#5555FF'>+</font> rnd.<font color='#BB00BB'>get_random_double</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>*</font><font face='Lucida Console'>(</font>maxs<font color='#5555FF'>-</font>mins<font face='Lucida Console'>)</font>;
   84     <font color='#0000FF'>auto</font> size <font color='#5555FF'>=</font> scale<font color='#5555FF'>*</font>std::<font color='#BB00BB'>min</font><font face='Lucida Console'>(</font>img.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, img.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
   85     rectangle <font color='#BB00BB'>rect</font><font face='Lucida Console'>(</font>size, size<font face='Lucida Console'>)</font>;
   86     <font color='#009900'>// randomly shift the box around
   87 </font>    point <font color='#BB00BB'>offset</font><font face='Lucida Console'>(</font>rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>%</font><font face='Lucida Console'>(</font>img.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>-</font>rect.<font color='#BB00BB'>width</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>,
   88                  rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>%</font><font face='Lucida Console'>(</font>img.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>-</font>rect.<font color='#BB00BB'>height</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
   89     <font color='#0000FF'>return</font> <font color='#BB00BB'>move_rect</font><font face='Lucida Console'>(</font>rect, offset<font face='Lucida Console'>)</font>;
   90 <b>}</b>
   91 
   92 <font color='#009900'>// ----------------------------------------------------------------------------------------
   93 </font>
   94 <font color='#0000FF'><u>void</u></font> <b><a name='randomly_crop_image'></a>randomly_crop_image</b> <font face='Lucida Console'>(</font>
   95     <font color='#0000FF'>const</font> matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> img,
   96     matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> crop,
   97     dlib::rand<font color='#5555FF'>&amp;</font> rnd
   98 <font face='Lucida Console'>)</font>
   99 <b>{</b>
  100     <font color='#0000FF'>auto</font> rect <font color='#5555FF'>=</font> <font color='#BB00BB'>make_random_cropping_rect_resnet</font><font face='Lucida Console'>(</font>img, rnd<font face='Lucida Console'>)</font>;
  101 
  102     <font color='#009900'>// Now crop it out as a 227x227 image.
  103 </font>    <font color='#BB00BB'>extract_image_chip</font><font face='Lucida Console'>(</font>img, <font color='#BB00BB'>chip_details</font><font face='Lucida Console'>(</font>rect, <font color='#BB00BB'>chip_dims</font><font face='Lucida Console'>(</font><font color='#979000'>227</font>,<font color='#979000'>227</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>, crop<font face='Lucida Console'>)</font>;
  104 
  105     <font color='#009900'>// Also randomly flip the image
  106 </font>    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>rnd.<font color='#BB00BB'>get_random_double</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>0.5</font><font face='Lucida Console'>)</font>
  107         crop <font color='#5555FF'>=</font> <font color='#BB00BB'>fliplr</font><font face='Lucida Console'>(</font>crop<font face='Lucida Console'>)</font>;
  108 
  109     <font color='#009900'>// And then randomly adjust the colors.
  110 </font>    <font color='#BB00BB'>apply_random_color_offset</font><font face='Lucida Console'>(</font>crop, rnd<font face='Lucida Console'>)</font>;
  111 <b>}</b>
  112 
  113 <font color='#0000FF'><u>void</u></font> <b><a name='randomly_crop_images'></a>randomly_crop_images</b> <font face='Lucida Console'>(</font>
  114     <font color='#0000FF'>const</font> matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> img,
  115     dlib::array<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> crops,
  116     dlib::rand<font color='#5555FF'>&amp;</font> rnd,
  117     <font color='#0000FF'><u>long</u></font> num_crops
  118 <font face='Lucida Console'>)</font>
  119 <b>{</b>
  120     std::vector<font color='#5555FF'>&lt;</font>chip_details<font color='#5555FF'>&gt;</font> dets;
  121     <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'>&lt;</font> num_crops; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
  122     <b>{</b>
  123         <font color='#0000FF'>auto</font> rect <font color='#5555FF'>=</font> <font color='#BB00BB'>make_random_cropping_rect_resnet</font><font face='Lucida Console'>(</font>img, rnd<font face='Lucida Console'>)</font>;
  124         dets.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><font color='#BB00BB'>chip_details</font><font face='Lucida Console'>(</font>rect, <font color='#BB00BB'>chip_dims</font><font face='Lucida Console'>(</font><font color='#979000'>227</font>,<font color='#979000'>227</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
  125     <b>}</b>
  126 
  127     <font color='#BB00BB'>extract_image_chips</font><font face='Lucida Console'>(</font>img, dets, crops<font face='Lucida Console'>)</font>;
  128 
  129     <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> img : crops<font face='Lucida Console'>)</font>
  130     <b>{</b>
  131         <font color='#009900'>// Also randomly flip the image
  132 </font>        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>rnd.<font color='#BB00BB'>get_random_double</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>0.5</font><font face='Lucida Console'>)</font>
  133             img <font color='#5555FF'>=</font> <font color='#BB00BB'>fliplr</font><font face='Lucida Console'>(</font>img<font face='Lucida Console'>)</font>;
  134 
  135         <font color='#009900'>// And then randomly adjust the colors.
  136 </font>        <font color='#BB00BB'>apply_random_color_offset</font><font face='Lucida Console'>(</font>img, rnd<font face='Lucida Console'>)</font>;
  137     <b>}</b>
  138 <b>}</b>
  139 
  140 <font color='#009900'>// ----------------------------------------------------------------------------------------
  141 </font>
  142 <font color='#0000FF'>struct</font> <b><a name='image_info'></a>image_info</b>
  143 <b>{</b>
  144     string filename;
  145     string label;
  146     <font color='#0000FF'><u>long</u></font> numeric_label;
  147 <b>}</b>;
  148 
  149 std::vector<font color='#5555FF'>&lt;</font>image_info<font color='#5555FF'>&gt;</font> <b><a name='get_imagenet_train_listing'></a>get_imagenet_train_listing</b><font face='Lucida Console'>(</font>
  150     <font color='#0000FF'>const</font> std::string<font color='#5555FF'>&amp;</font> images_folder
  151 <font face='Lucida Console'>)</font>
  152 <b>{</b>
  153     std::vector<font color='#5555FF'>&lt;</font>image_info<font color='#5555FF'>&gt;</font> results;
  154     image_info temp;
  155     temp.numeric_label <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
  156     <font color='#009900'>// We will loop over all the label types in the dataset; each one is contained in its own subfolder.
  157 </font>    <font color='#0000FF'>auto</font> subdirs <font color='#5555FF'>=</font> <font color='#BB00BB'>directory</font><font face='Lucida Console'>(</font>images_folder<font face='Lucida Console'>)</font>.<font color='#BB00BB'>get_dirs</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
  158     <font color='#009900'>// But first, sort the subdirectories so that the numeric labels are assigned in sorted order.
  159 </font>    std::<font color='#BB00BB'>sort</font><font face='Lucida Console'>(</font>subdirs.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, subdirs.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
  160     <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font> subdir : subdirs<font face='Lucida Console'>)</font>
  161     <b>{</b>
  162         <font color='#009900'>// Now get all the images in this label type
  163 </font>        temp.label <font color='#5555FF'>=</font> subdir.<font color='#BB00BB'>name</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
  164         <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font> image_file : subdir.<font color='#BB00BB'>get_files</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
  165         <b>{</b>
  166             temp.filename <font color='#5555FF'>=</font> image_file;
  167             results.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>;
  168         <b>}</b>
  169         <font color='#5555FF'>+</font><font color='#5555FF'>+</font>temp.numeric_label;
  170     <b>}</b>
  171     <font color='#0000FF'>return</font> results;
  172 <b>}</b>
  173 
  174 std::vector<font color='#5555FF'>&lt;</font>image_info<font color='#5555FF'>&gt;</font> <b><a name='get_imagenet_val_listing'></a>get_imagenet_val_listing</b><font face='Lucida Console'>(</font>
  175     <font color='#0000FF'>const</font> std::string<font color='#5555FF'>&amp;</font> imagenet_root_dir,
  176     <font color='#0000FF'>const</font> std::string<font color='#5555FF'>&amp;</font> validation_images_file 
  177 <font face='Lucida Console'>)</font>
  178 <b>{</b>
  179     ifstream <font color='#BB00BB'>fin</font><font face='Lucida Console'>(</font>validation_images_file<font face='Lucida Console'>)</font>;
  180     string label, filename;
  181     std::vector<font color='#5555FF'>&lt;</font>image_info<font color='#5555FF'>&gt;</font> results;
  182     image_info temp;
  183     temp.numeric_label <font color='#5555FF'>=</font> <font color='#5555FF'>-</font><font color='#979000'>1</font>;
  184     <font color='#0000FF'>while</font><font face='Lucida Console'>(</font>fin <font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> label <font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> filename<font face='Lucida Console'>)</font>
  185     <b>{</b>
  186         temp.filename <font color='#5555FF'>=</font> imagenet_root_dir<font color='#5555FF'>+</font>"<font color='#CC0000'>/</font>"<font color='#5555FF'>+</font>filename;
  187         <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font><font color='#BB00BB'>file_exists</font><font face='Lucida Console'>(</font>temp.filename<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
  188         <b>{</b>
  189             cerr <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>file doesn't exist! </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> temp.filename <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
  190             <font color='#BB00BB'>exit</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>;
  191         <b>}</b>
  192         <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>label <font color='#5555FF'>!</font><font color='#5555FF'>=</font> temp.label<font face='Lucida Console'>)</font>
  193             <font color='#5555FF'>+</font><font color='#5555FF'>+</font>temp.numeric_label;
  194 
  195         temp.label <font color='#5555FF'>=</font> label;
  196         results.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>;
  197     <b>}</b>
  198 
  199     <font color='#0000FF'>return</font> results;
  200 <b>}</b>
  201 
  202 <font color='#009900'>// ----------------------------------------------------------------------------------------
  203 </font>
  204 <font color='#009900'>// Trains the network on the ILSVRC2015 training images, saves it to resnet34.dnn, then
  205 </font><font color='#009900'>// reports top-1 and top-5 accuracy on the validation images.  argv[1] is the ILSVRC2015
  206 </font><font color='#009900'>// root directory and argv[2] is the imagenet2015_validation_images.txt file.
  207 </font><font color='#0000FF'><u>int</u></font> <b><a name='main'></a>main</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> argc, <font color='#0000FF'><u>char</u></font><font color='#5555FF'>*</font><font color='#5555FF'>*</font> argv<font face='Lucida Console'>)</font> <font color='#0000FF'>try</font>
  208 <b>{</b>
  209     <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>argc <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>3</font><font face='Lucida Console'>)</font>
  210     <b>{</b>
  211         cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>To run this program you need a copy of the imagenet ILSVRC2015 dataset and</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
  212         cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>also the file http://dlib.net/files/imagenet2015_validation_images.txt.bz2</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
  213         cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
  214         cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>With those things, you call this program like this: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
  215         cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>./dnn_imagenet_train_ex /path/to/ILSVRC2015 imagenet2015_validation_images.txt</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
  216         <font color='#0000FF'>return</font> <font color='#979000'>1</font>;
  217     <b>}</b>
  218 
  219     cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\nSCANNING IMAGENET DATASET\n</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
  220 
  221     <font color='#0000FF'>auto</font> listing <font color='#5555FF'>=</font> <font color='#BB00BB'>get_imagenet_train_listing</font><font face='Lucida Console'>(</font><font color='#BB00BB'>string</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>1</font>]<font face='Lucida Console'>)</font><font color='#5555FF'>+</font>"<font color='#CC0000'>/Data/CLS-LOC/train/</font>"<font face='Lucida Console'>)</font>;
  222     cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>images in dataset: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
  223     <font color='#009900'>// Need a non-empty listing with 1000 classes.  Check emptiness first: back() is UB on an empty vector.
  224 </font>    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> listing.<font color='#BB00BB'>back</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.numeric_label<font color='#5555FF'>+</font><font color='#979000'>1</font> <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>1000</font><font face='Lucida Console'>)</font>
  225     <b>{</b>
  226         cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Didn't find the imagenet dataset. </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
  227         <font color='#0000FF'>return</font> <font color='#979000'>1</font>;
  228     <b>}</b>
  229         
  230     <font color='#BB00BB'>set_dnn_prefer_smallest_algorithms</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
  231 
  232 
  233     <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> initial_learning_rate <font color='#5555FF'>=</font> <font color='#979000'>0.1</font>;
  234     <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> weight_decay <font color='#5555FF'>=</font> <font color='#979000'>0.0001</font>;
  235     <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> momentum <font color='#5555FF'>=</font> <font color='#979000'>0.9</font>;
  236 
  237     net_type net;
  238     dnn_trainer<font color='#5555FF'>&lt;</font>net_type<font color='#5555FF'>&gt;</font> <font color='#BB00BB'>trainer</font><font face='Lucida Console'>(</font>net,<font color='#BB00BB'>sgd</font><font face='Lucida Console'>(</font>weight_decay, momentum<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
  239     trainer.<font color='#BB00BB'>be_verbose</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
  240     trainer.<font color='#BB00BB'>set_learning_rate</font><font face='Lucida Console'>(</font>initial_learning_rate<font face='Lucida Console'>)</font>;
  241     trainer.<font color='#BB00BB'>set_synchronization_file</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>imagenet_trainer_state_file.dat</font>", std::chrono::<font color='#BB00BB'>minutes</font><font face='Lucida Console'>(</font><font color='#979000'>10</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
  242     <font color='#009900'>// This threshold is probably excessively large.  You could likely get good results
  243 </font>    <font color='#009900'>// with a smaller value but if you aren't in a hurry this value will surely work well.
  244 </font>    trainer.<font color='#BB00BB'>set_iterations_without_progress_threshold</font><font face='Lucida Console'>(</font><font color='#979000'>20000</font><font face='Lucida Console'>)</font>;
  245     <font color='#009900'>// Since the progress threshold is so large might as well set the batch normalization
  246 </font>    <font color='#009900'>// stats window to something big too.
  247 </font>    <font color='#BB00BB'>set_all_bn_running_stats_window_sizes</font><font face='Lucida Console'>(</font>net, <font color='#979000'>1000</font><font face='Lucida Console'>)</font>;
  248 
  249     std::vector<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> samples;
  250     std::vector<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>&gt;</font> labels;
  251 
  252     <font color='#009900'>// Start a bunch of threads that read images from disk and pull out random crops.  It's
  253 </font>    <font color='#009900'>// important to be sure to feed the GPU fast enough to keep it busy.  Using multiple
  254 </font>    <font color='#009900'>// threads for this kind of data preparation helps us do that.  Each thread puts the
  255 </font>    <font color='#009900'>// crops into the data queue.
  256 </font>    dlib::pipe<font color='#5555FF'>&lt;</font>std::pair<font color='#5555FF'>&lt;</font>image_info,matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> <font color='#BB00BB'>data</font><font face='Lucida Console'>(</font><font color='#979000'>200</font><font face='Lucida Console'>)</font>;
  257     <font color='#0000FF'>auto</font> f <font color='#5555FF'>=</font> [<font color='#5555FF'>&amp;</font>data, <font color='#5555FF'>&amp;</font>listing]<font face='Lucida Console'>(</font>time_t seed<font face='Lucida Console'>)</font>
  258     <b>{</b>
  259         dlib::rand <font color='#BB00BB'>rnd</font><font face='Lucida Console'>(</font><font color='#BB00BB'>time</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font><font color='#5555FF'>+</font>seed<font face='Lucida Console'>)</font>;
  260         matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font> img;
  261         std::pair<font color='#5555FF'>&lt;</font>image_info, matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> temp;
  262         <font color='#0000FF'>while</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>is_enabled</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
  263         <b>{</b>
  264             temp.first <font color='#5555FF'>=</font> listing[rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>%</font>listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>];
  265             <font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>img, temp.first.filename<font face='Lucida Console'>)</font>;
  266             <font color='#BB00BB'>randomly_crop_image</font><font face='Lucida Console'>(</font>img, temp.second, rnd<font face='Lucida Console'>)</font>;
  267             data.<font color='#BB00BB'>enqueue</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>;
  268         <b>}</b>
  269     <b>}</b>;
  270     std::thread <font color='#BB00BB'>data_loader1</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;
  271     std::thread <font color='#BB00BB'>data_loader2</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>2</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;
  272     std::thread <font color='#BB00BB'>data_loader3</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>3</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;
  273     std::thread <font color='#BB00BB'>data_loader4</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>4</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;
  274 
  275     <font color='#009900'>// The main training loop.  Keep making mini-batches and giving them to the trainer.
  276 </font>    <font color='#009900'>// We will run until the learning rate has dropped by a factor of 1e-3.
  277 </font>    <font color='#0000FF'>while</font><font face='Lucida Console'>(</font>trainer.<font color='#BB00BB'>get_learning_rate</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> initial_learning_rate<font color='#5555FF'>*</font><font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>3</font><font face='Lucida Console'>)</font>
  278     <b>{</b>
  279         samples.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
  280         labels.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
  281 
  282         <font color='#009900'>// make a 160 image mini-batch
  283 </font>        std::pair<font color='#5555FF'>&lt;</font>image_info, matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> img;
  284         <font color='#0000FF'>while</font><font face='Lucida Console'>(</font>samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font> <font color='#979000'>160</font><font face='Lucida Console'>)</font>
  285         <b>{</b>
  286             data.<font color='#BB00BB'>dequeue</font><font face='Lucida Console'>(</font>img<font face='Lucida Console'>)</font>;
  287 
  288             samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>move</font><font face='Lucida Console'>(</font>img.second<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
  289             labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>img.first.numeric_label<font face='Lucida Console'>)</font>;
  290         <b>}</b>
  291 
  292         trainer.<font color='#BB00BB'>train_one_step</font><font face='Lucida Console'>(</font>samples, labels<font face='Lucida Console'>)</font>;
  293     <b>}</b>
  294 
  295     <font color='#009900'>// Training done, tell threads to stop and make sure to wait for them to finish before
  296 </font>    <font color='#009900'>// moving on.
  297 </font>    data.<font color='#BB00BB'>disable</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
  298     data_loader1.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
  299     data_loader2.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
  300     data_loader3.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
  301     data_loader4.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
  302 
  303     <font color='#009900'>// also wait for threaded processing to stop in the trainer.
  304 </font>    trainer.<font color='#BB00BB'>get_net</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
  305 
  306     net.<font color='#BB00BB'>clean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
  307     cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>saving network</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
  308     <font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>resnet34.dnn</font>"<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> net;
  309 
  310 
  311     <font color='#009900'>// Now test the network on the imagenet validation dataset.  First, make a testing
  312 </font>    <font color='#009900'>// network with softmax as the final layer.  We don't have to do this if we just wanted
  313 </font>    <font color='#009900'>// to test the "top1 accuracy" since the normal network outputs the class prediction.
  314 </font>    <font color='#009900'>// But this snet object will make getting the top5 predictions easy as it directly
  315 </font>    <font color='#009900'>// outputs the probability of each class as its final output.
  316 </font>    softmax<font color='#5555FF'>&lt;</font>anet_type::subnet_type<font color='#5555FF'>&gt;</font> snet; snet.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> net.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
  317 
  318     cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Testing network on imagenet validation dataset...</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
  319     <font color='#0000FF'><u>int</u></font> num_right <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
  320     <font color='#0000FF'><u>int</u></font> num_wrong <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
  321     <font color='#0000FF'><u>int</u></font> num_right_top1 <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
  322     <font color='#0000FF'><u>int</u></font> num_wrong_top1 <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
  323     dlib::rand <font color='#BB00BB'>rnd</font><font face='Lucida Console'>(</font><font color='#BB00BB'>time</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
  324     <font color='#009900'>// loop over all the imagenet validation images
  325 </font>    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> l : <font color='#BB00BB'>get_imagenet_val_listing</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>1</font>], argv[<font color='#979000'>2</font>]<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
  326     <b>{</b>
  327         dlib::array<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> images;
  328         matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font> img;
  329         <font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>img, l.filename<font face='Lucida Console'>)</font>;
  330         <font color='#009900'>// Grab 16 random crops from the image.  We will run all of them through the
  331 </font>        <font color='#009900'>// network and average the results.
  332 </font>        <font color='#0000FF'>const</font> <font color='#0000FF'><u>int</u></font> num_crops <font color='#5555FF'>=</font> <font color='#979000'>16</font>;
  333         <font color='#BB00BB'>randomly_crop_images</font><font face='Lucida Console'>(</font>img, images, rnd, num_crops<font face='Lucida Console'>)</font>;
  334         <font color='#009900'>// p(i) == the probability the image contains object of class i.
  335 </font>        matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>float</u></font>,<font color='#979000'>1</font>,<font color='#979000'>1000</font><font color='#5555FF'>&gt;</font> p <font color='#5555FF'>=</font> <font color='#BB00BB'>sum_rows</font><font face='Lucida Console'>(</font><font color='#BB00BB'>mat</font><font face='Lucida Console'>(</font><font color='#BB00BB'>snet</font><font face='Lucida Console'>(</font>images.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, images.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font color='#5555FF'>/</font>num_crops;
  336 
  337         <font color='#009900'>// check top 1 accuracy
  338 </font>        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>index_of_max</font><font face='Lucida Console'>(</font>p<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> l.numeric_label<font face='Lucida Console'>)</font>
  339             <font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_right_top1;
  340         <font color='#0000FF'>else</font>
  341             <font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_wrong_top1;
  342 
  343         <font color='#009900'>// check top 5 accuracy
  344 </font>        <font color='#0000FF'><u>bool</u></font> found_match <font color='#5555FF'>=</font> <font color='#979000'>false</font>;
  345         <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> k <font color='#5555FF'>=</font> <font color='#979000'>0</font>; k <font color='#5555FF'>&lt;</font> <font color='#979000'>5</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>k<font face='Lucida Console'>)</font>
  346         <b>{</b>
  347             <font color='#0000FF'><u>long</u></font> predicted_label <font color='#5555FF'>=</font> <font color='#BB00BB'>index_of_max</font><font face='Lucida Console'>(</font>p<font face='Lucida Console'>)</font>;
  348             <font color='#BB00BB'>p</font><font face='Lucida Console'>(</font>predicted_label<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
  349             <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>predicted_label <font color='#5555FF'>=</font><font color='#5555FF'>=</font> l.numeric_label<font face='Lucida Console'>)</font>
  350             <b>{</b>
  351                 found_match <font color='#5555FF'>=</font> <font color='#979000'>true</font>;
  352                 <font color='#0000FF'>break</font>;
  353             <b>}</b>
  354 
  355         <b>}</b>
  356         <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>found_match<font face='Lucida Console'>)</font>
  357             <font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_right;
  358         <font color='#0000FF'>else</font>
  359             <font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_wrong;
  360     <b>}</b>
  361     cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>val top5 accuracy:  </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> num_right<font color='#5555FF'>/</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>num_right<font color='#5555FF'>+</font>num_wrong<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
  362     cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>val top1 accuracy:  </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> num_right_top1<font color='#5555FF'>/</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>num_right_top1<font color='#5555FF'>+</font>num_wrong_top1<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
  363 <b>}</b>
  364 <font color='#0000FF'>catch</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> std::exception<font color='#5555FF'>&amp;</font> e<font face='Lucida Console'>)</font>
  365 <b>{</b>
  366     cerr <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> e.<font color='#BB00BB'>what</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
  367     <font color='#0000FF'>return</font> <font color='#979000'>1</font>; <font color='#009900'>// Report failure; falling off the end of this handler would exit with status 0.
  368 </font><b>}</b>
  369 
  370 
  371 </pre></body></html>