Tesseract  3.02
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
errorcounter.cpp
Go to the documentation of this file.
1 // Copyright 2011 Google Inc. All Rights Reserved.
2 // Author: rays@google.com (Ray Smith)
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 //
15 #include <ctime>
16 
17 #include "errorcounter.h"
18 
19 #include "fontinfo.h"
20 #include "ndminx.h"
21 #include "sampleiterator.h"
22 #include "shapeclassifier.h"
23 #include "shapetable.h"
24 #include "trainingsample.h"
25 #include "trainingsampleset.h"
26 #include "unicity_table.h"
27 
28 namespace tesseract {
29 
30 // Tests a classifier, computing its error rate.
31 // See errorcounter.h for description of arguments.
32 // Iterates over the samples, calling the classifier in normal/silent mode.
33 // If the classifier makes a CT_UNICHAR_TOPN_ERR error, and the appropriate
34 // report_level is set (4 or greater), it will then call the classifier again
35 // with a debug flag and a keep_this argument to find out what is going on.
37  int report_level, CountTypes boosting_mode,
38  const UnicityTable<FontInfo>& fontinfo_table,
39  const GenericVector<Pix*>& page_images, SampleIterator* it,
40  double* unichar_error, double* scaled_error, STRING* fonts_report) {
41  int charsetsize = it->shape_table()->unicharset().size();
42  int shapesize = it->CompactCharsetSize();
43  int fontsize = it->sample_set()->NumFonts();
44  ErrorCounter counter(charsetsize, shapesize, fontsize);
46 
47  clock_t start = clock();
48  int total_samples = 0;
49  double unscaled_error = 0.0;
50  // Set a number of samples on which to run the classify debug mode.
51  int error_samples = report_level > 3 ? report_level * report_level : 0;
52  // Iterate over all the samples, accumulating errors.
53  for (it->Begin(); !it->AtEnd(); it->Next()) {
54  TrainingSample* mutable_sample = it->MutableSample();
55  int page_index = mutable_sample->page_num();
56  Pix* page_pix = 0 <= page_index && page_index < page_images.size()
57  ? page_images[page_index] : NULL;
58  // No debug, no keep this.
59  classifier->ClassifySample(*mutable_sample, page_pix, 0, INVALID_UNICHAR_ID,
60  &results);
61  if (mutable_sample->class_id() == 0) {
62  // This is junk so use the special counter.
63  counter.AccumulateJunk(*it->shape_table(), results, mutable_sample);
64  } else if (counter.AccumulateErrors(report_level > 3, boosting_mode,
65  fontinfo_table, *it->shape_table(),
66  results, mutable_sample) &&
67  error_samples > 0) {
68  // Running debug, keep the correct answer, and debug the classifier.
69  tprintf("Error on sample %d: Classifier debug output:\n",
70  it->GlobalSampleIndex());
71  int keep_this = it->GetSparseClassID();
72  classifier->ClassifySample(*mutable_sample, page_pix, 1, keep_this,
73  &results);
74  --error_samples;
75  }
76  ++total_samples;
77  }
78  double total_time = 1.0 * (clock() - start) / CLOCKS_PER_SEC;
79  // Create the appropriate error report.
80  unscaled_error = counter.ReportErrors(report_level, boosting_mode,
81  fontinfo_table,
82  *it, unichar_error, fonts_report);
83  if (scaled_error != NULL) *scaled_error = counter.scaled_error_;
84  if (report_level > 1) {
85  // It is useful to know the time in microseconds/char.
86  tprintf("Errors computed in %.2fs at %.1f μs/char\n",
87  total_time, 1000000.0 * total_time / total_samples);
88  }
89  return unscaled_error;
90 }
91 
92 // Constructor is private. Only anticipated use of ErrorCounter is via
93 // the static ComputeErrorRate.
94 ErrorCounter::ErrorCounter(int charsetsize, int shapesize, int fontsize)
95  : scaled_error_(0.0), unichar_counts_(charsetsize, shapesize, 0) {
96  Counts empty_counts;
97  font_counts_.init_to_size(fontsize, empty_counts);
98 }
99 ErrorCounter::~ErrorCounter() {
100 }
101 
102 // Accumulates the errors from the classifier results on a single sample.
103 // Returns true if debug is true and a CT_UNICHAR_TOPN_ERR error occurred.
104 // boosting_mode selects the type of error to be used for boosting and the
105 // is_error_ member of sample is set according to whether the required type
106 // of error occurred. The font_table provides access to font properties
107 // for error counting and shape_table is used to understand the relationship
108 // between unichar_ids and shape_ids in the results
109 bool ErrorCounter::AccumulateErrors(bool debug, CountTypes boosting_mode,
110  const UnicityTable<FontInfo>& font_table,
111  const ShapeTable& shape_table,
112  const GenericVector<ShapeRating>& results,
113  TrainingSample* sample) {
114  int num_results = results.size();
115  int res_index = 0;
116  bool debug_it = false;
117  int font_id = sample->font_id();
118  int unichar_id = sample->class_id();
119  sample->set_is_error(false);
120  if (num_results == 0) {
121  // Reject. We count rejects as a separate category, but still mark the
122  // sample as an error in case any training module wants to use that to
123  // improve the classifier.
124  sample->set_is_error(true);
125  ++font_counts_[font_id].n[CT_REJECT];
126  } else if (shape_table.GetShape(results[0].shape_id).
127  ContainsUnicharAndFont(unichar_id, font_id)) {
128  ++font_counts_[font_id].n[CT_SHAPE_TOP_CORRECT];
129  // Unichar and font OK, but count if multiple unichars.
130  if (shape_table.GetShape(results[0].shape_id).size() > 1)
131  ++font_counts_[font_id].n[CT_OK_MULTI_UNICHAR];
132  } else {
133  // This is a top shape error.
134  ++font_counts_[font_id].n[CT_SHAPE_TOP_ERR];
135  // Check to see if any font in the top choice has attributes that match.
136  bool attributes_match = false;
137  uinT32 font_props = font_table.get(font_id).properties;
138  const Shape& shape = shape_table.GetShape(results[0].shape_id);
139  for (int c = 0; c < shape.size() && !attributes_match; ++c) {
140  for (int f = 0; f < shape[c].font_ids.size(); ++f) {
141  if (font_table.get(shape[c].font_ids[f]).properties == font_props) {
142  attributes_match = true;
143  break;
144  }
145  }
146  }
147  // TODO(rays) It is easy to add counters for individual font attributes
148  // here if we want them.
149  if (!attributes_match)
150  ++font_counts_[font_id].n[CT_FONT_ATTR_ERR];
151  if (boosting_mode == CT_SHAPE_TOP_ERR) sample->set_is_error(true);
152  // Find rank of correct unichar answer. (Ignoring the font.)
153  while (res_index < num_results &&
154  !shape_table.GetShape(results[res_index].shape_id).
155  ContainsUnichar(unichar_id)) {
156  ++res_index;
157  }
158  if (res_index == 0) {
159  // Unichar OK, but count if multiple unichars.
160  if (shape_table.GetShape(results[res_index].shape_id).size() > 1) {
161  ++font_counts_[font_id].n[CT_OK_MULTI_UNICHAR];
162  }
163  } else {
164  // Count maps from unichar id to shape id.
165  if (num_results > 0)
166  ++unichar_counts_(unichar_id, results[0].shape_id);
167  // This is a unichar error.
168  ++font_counts_[font_id].n[CT_UNICHAR_TOP1_ERR];
169  if (boosting_mode == CT_UNICHAR_TOP1_ERR) sample->set_is_error(true);
170  if (res_index >= MIN(2, num_results)) {
171  // It is also a 2nd choice unichar error.
172  ++font_counts_[font_id].n[CT_UNICHAR_TOP2_ERR];
173  if (boosting_mode == CT_UNICHAR_TOP2_ERR) sample->set_is_error(true);
174  }
175  if (res_index >= num_results) {
176  // It is also a top-n choice unichar error.
177  ++font_counts_[font_id].n[CT_UNICHAR_TOPN_ERR];
178  if (boosting_mode == CT_UNICHAR_TOPN_ERR) sample->set_is_error(true);
179  debug_it = debug;
180  }
181  }
182  }
183  // Compute mean number of return values and mean rank of correct answer.
184  font_counts_[font_id].n[CT_NUM_RESULTS] += num_results;
185  font_counts_[font_id].n[CT_RANK] += res_index;
186  // If it was an error for boosting then sum the weight.
187  if (sample->is_error()) {
188  scaled_error_ += sample->weight();
189  }
190  if (debug_it) {
191  tprintf("%d results for char %s font %d :",
192  num_results, shape_table.unicharset().id_to_unichar(unichar_id),
193  font_id);
194  for (int i = 0; i < num_results; ++i) {
195  tprintf(" %.3f/%.3f:%s",
196  results[i].rating, results[i].font,
197  shape_table.DebugStr(results[i].shape_id).string());
198  }
199  tprintf("\n");
200  return true;
201  }
202  return false;
203 }
204 
205 // Accumulates counts for junk. Counts only whether the junk was correctly
206 // rejected or not.
207 void ErrorCounter::AccumulateJunk(const ShapeTable& shape_table,
208  const GenericVector<ShapeRating>& results,
209  TrainingSample* sample) {
210  // For junk we accept no answer, or an explicit shape answer matching the
211  // class id of the sample.
212  int num_results = results.size();
213  int font_id = sample->font_id();
214  int unichar_id = sample->class_id();
215  if (num_results > 0 &&
216  !shape_table.GetShape(results[0].shape_id).ContainsUnichar(unichar_id)) {
217  // This is a junk error.
218  ++font_counts_[font_id].n[CT_ACCEPTED_JUNK];
219  sample->set_is_error(true);
220  // It counts as an error for boosting too so sum the weight.
221  scaled_error_ += sample->weight();
222  } else {
223  // Correctly rejected.
224  ++font_counts_[font_id].n[CT_REJECTED_JUNK];
225  sample->set_is_error(false);
226  }
227 }
228 
229 // Creates a report of the error rate. The report_level controls the detail
230 // that is reported to stderr via tprintf:
231 // 0 -> no output.
232 // >=1 -> bottom-line error rate.
233 // >=3 -> font-level error rate.
234 // boosting_mode determines the return value. It selects which (un-weighted)
235 // error rate to return.
236 // The fontinfo_table from MasterTrainer provides the names of fonts.
237 // The it determines the current subset of the training samples.
238 // If not NULL, the top-choice unichar error rate is saved in unichar_error.
239 // If not NULL, the report string is saved in fonts_report.
240 // (Ignoring report_level).
241 double ErrorCounter::ReportErrors(int report_level, CountTypes boosting_mode,
242  const UnicityTable<FontInfo>& fontinfo_table,
243  const SampleIterator& it,
244  double* unichar_error,
245  STRING* fonts_report) {
246  // Compute totals over all the fonts and report individual font results
247  // when required.
248  Counts totals;
249  int fontsize = font_counts_.size();
250  for (int f = 0; f < fontsize; ++f) {
251  // Accumulate counts over fonts.
252  totals += font_counts_[f];
253  STRING font_report;
254  if (ReportString(font_counts_[f], &font_report)) {
255  if (fonts_report != NULL) {
256  *fonts_report += fontinfo_table.get(f).name;
257  *fonts_report += ": ";
258  *fonts_report += font_report;
259  *fonts_report += "\n";
260  }
261  if (report_level > 2) {
262  // Report individual font error rates.
263  tprintf("%s: %s\n", fontinfo_table.get(f).name, font_report.string());
264  }
265  }
266  }
267  if (report_level > 0) {
268  // Report the totals.
269  STRING total_report;
270  if (ReportString(totals, &total_report)) {
271  tprintf("TOTAL Scaled Err=%.4g%%, %s\n",
272  scaled_error_ * 100.0, total_report.string());
273  }
274  // Report the worst substitution error only for now.
275  if (totals.n[CT_UNICHAR_TOP1_ERR] > 0) {
276  const UNICHARSET& unicharset = it.shape_table()->unicharset();
277  int charsetsize = unicharset.size();
278  int shapesize = it.CompactCharsetSize();
279  int worst_uni_id = 0;
280  int worst_shape_id = 0;
281  int worst_err = 0;
282  for (int u = 0; u < charsetsize; ++u) {
283  for (int s = 0; s < shapesize; ++s) {
284  if (unichar_counts_(u, s) > worst_err) {
285  worst_err = unichar_counts_(u, s);
286  worst_uni_id = u;
287  worst_shape_id = s;
288  }
289  }
290  }
291  if (worst_err > 0) {
292  tprintf("Worst error = %d:%s -> %s with %d/%d=%.2f%% errors\n",
293  worst_uni_id, unicharset.id_to_unichar(worst_uni_id),
294  it.shape_table()->DebugStr(worst_shape_id).string(),
295  worst_err, totals.n[CT_UNICHAR_TOP1_ERR],
296  100.0 * worst_err / totals.n[CT_UNICHAR_TOP1_ERR]);
297  }
298  }
299  }
300  double rates[CT_SIZE];
301  if (!ComputeRates(totals, rates))
302  return 0.0;
303  // Set output values if asked for.
304  if (unichar_error != NULL)
305  *unichar_error = rates[CT_UNICHAR_TOP1_ERR];
306  return rates[boosting_mode];
307 }
308 
309 // Sets the report string to a combined human and machine-readable report
310 // string of the error rates.
311 // Returns false if there is no data, leaving report unchanged.
312 bool ErrorCounter::ReportString(const Counts& counts, STRING* report) {
313  // Compute the error rates.
314  double rates[CT_SIZE];
315  if (!ComputeRates(counts, rates))
316  return false;
317  // Using %.4g%%, the length of the output string should exactly match the
318  // length of the format string, but in case of overflow, allow for +eddd
319  // on each number.
320  const int kMaxExtraLength = 5; // Length of +eddd.
321  // Keep this format string and the snprintf in sync with the CountTypes enum.
322  const char* format_str = "ShapeErr=%.4g%%, FontAttr=%.4g%%, "
323  "Unichar=%.4g%%[1], %.4g%%[2], %.4g%%[n], "
324  "Multi=%.4g%%, Rej=%.4g%%, "
325  "Answers=%.3g, Rank=%.3g, "
326  "OKjunk=%.4g%%, Badjunk=%.4g%%";
327  int max_str_len = strlen(format_str) + kMaxExtraLength * (CT_SIZE - 1) + 1;
328  char* formatted_str = new char[max_str_len];
329  snprintf(formatted_str, max_str_len, format_str,
330  rates[CT_SHAPE_TOP_ERR] * 100.0,
331  rates[CT_FONT_ATTR_ERR] * 100.0,
332  rates[CT_UNICHAR_TOP1_ERR] * 100.0,
333  rates[CT_UNICHAR_TOP2_ERR] * 100.0,
334  rates[CT_UNICHAR_TOPN_ERR] * 100.0,
335  rates[CT_OK_MULTI_UNICHAR] * 100.0,
336  rates[CT_REJECT] * 100.0,
337  rates[CT_NUM_RESULTS],
338  rates[CT_RANK],
339  100.0 * rates[CT_REJECTED_JUNK],
340  100.0 * rates[CT_ACCEPTED_JUNK]);
341  *report = formatted_str;
342  delete [] formatted_str;
343  // Now append each field of counts with a tab in front so the result can
344  // be loaded into a spreadsheet.
345  for (int ct = 0; ct < CT_SIZE; ++ct)
346  report->add_str_int("\t", counts.n[ct]);
347  return true;
348 }
349 
350 // Computes the error rates and returns in rates which is an array of size
351 // CT_SIZE. Returns false if there is no data, leaving rates unchanged.
352 bool ErrorCounter::ComputeRates(const Counts& counts, double rates[CT_SIZE]) {
353  int ok_samples = counts.n[CT_SHAPE_TOP_CORRECT] + counts.n[CT_SHAPE_TOP_ERR] +
354  counts.n[CT_REJECT];
355  int junk_samples = counts.n[CT_REJECTED_JUNK] + counts.n[CT_ACCEPTED_JUNK];
356  if (ok_samples == 0 && junk_samples == 0) {
357  // There is no data.
358  return false;
359  }
360  // Compute rates for normal chars.
361  double denominator = static_cast<double>(MAX(ok_samples, 1));
362  for (int ct = 0; ct <= CT_RANK; ++ct)
363  rates[ct] = counts.n[ct] / denominator;
364  // Compute rates for junk.
365  denominator = static_cast<double>(MAX(junk_samples, 1));
366  for (int ct = CT_REJECTED_JUNK; ct <= CT_ACCEPTED_JUNK; ++ct)
367  rates[ct] = counts.n[ct] / denominator;
368  return true;
369 }
370 
371 ErrorCounter::Counts::Counts() {
372  memset(n, 0, sizeof(n[0]) * CT_SIZE);
373 }
374 // Adds other into this for computing totals.
375 void ErrorCounter::Counts::operator+=(const Counts& other) {
376  for (int ct = 0; ct < CT_SIZE; ++ct)
377  n[ct] += other.n[ct];
378 }
379 
380 
381 } // namespace tesseract.
382 
383 
384 
385 
386