"Fossies" - the Fresh Open Source Software Archive

Member "tesseract-5.2.0/src/lstm/tfnetwork.h" (6 Jul 2022, 3737 Bytes) of package /linux/misc/tesseract-5.2.0.tar.gz:


As a special service, "Fossies" has tried to format the requested source page into HTML using (guessed) C and C++ source code syntax highlighting (style: standard), with prefixed line numbers and a code folding option. Alternatively, you can view or download the uninterpreted source code file here. For more information about "tfnetwork.h", see the Fossies "Dox" file reference documentation and the latest Fossies "Diffs" side-by-side code changes report: 4.1.3_vs_5.0.0.

    1 ///////////////////////////////////////////////////////////////////////
    2 // File:        tfnetwork.h
    3 // Description: Encapsulation of an entire tensorflow graph as a
    4 //              Tesseract Network.
    5 // Author:      Ray Smith
    6 //
    7 // (C) Copyright 2016, Google Inc.
    8 // Licensed under the Apache License, Version 2.0 (the "License");
    9 // you may not use this file except in compliance with the License.
   10 // You may obtain a copy of the License at
   11 // http://www.apache.org/licenses/LICENSE-2.0
   12 // Unless required by applicable law or agreed to in writing, software
   13 // distributed under the License is distributed on an "AS IS" BASIS,
   14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   15 // See the License for the specific language governing permissions and
   16 // limitations under the License.
   17 ///////////////////////////////////////////////////////////////////////
   18 
   19 #ifndef TESSERACT_LSTM_TFNETWORK_H_
   20 #define TESSERACT_LSTM_TFNETWORK_H_
   21 
   22 #ifdef INCLUDE_TENSORFLOW
   23 
   24 #  include <memory>
   25 #  include <string>
   26 
   27 #  include "network.h"
   28 #  include "static_shape.h"
   29 #  include "tensorflow/core/framework/graph.pb.h"
   30 #  include "tensorflow/core/public/session.h"
   31 #  include "tfnetwork.pb.h"
   32 
   33 namespace tesseract {
   34 
   35 class TFNetwork : public Network {
   36 public:
   37   explicit TFNetwork(const char *name);
   38   virtual ~TFNetwork() = default;
   39 
   40   // Returns the required shape input to the network.
   41   StaticShape InputShape() const override {
   42     return input_shape_;
   43   }
   44   // Returns the shape output from the network given an input shape (which may
   45   // be partially unknown ie zero).
   46   StaticShape OutputShape(const StaticShape &input_shape) const override {
   47     return output_shape_;
   48   }
   49 
   50   std::string spec() const override {
   51     return spec_;
   52   }
   53 
   54   // Deserializes *this from a serialized TFNetwork proto. Returns 0 if failed,
   55   // otherwise the global step of the serialized graph.
   56   int InitFromProtoStr(const std::string &proto_str);
   57   // The number of classes in this network should be equal to those in the
   58   // recoder_ in LSTMRecognizer.
   59   int num_classes() const {
   60     return output_shape_.depth();
   61   }
   62 
   63   // Writes to the given file. Returns false in case of error.
   64   // Should be overridden by subclasses, but called by their Serialize.
   65   bool Serialize(TFile *fp) const override;
   66   // Reads from the given file. Returns false in case of error.
   67   // Should be overridden by subclasses, but NOT called by their DeSerialize.
   68   bool DeSerialize(TFile *fp) override;
   69 
   70   // Runs forward propagation of activations on the input line.
   71   // See Network for a detailed discussion of the arguments.
   72   void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose,
   73                NetworkScratch *scratch, NetworkIO *output) override;
   74 
   75 private:
   76   // Runs backward propagation of errors on the deltas line.
   77   // See Network for a detailed discussion of the arguments.
   78   bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch,
   79                 NetworkIO *back_deltas) override {
   80     tprintf("Must override Network::Backward for type %d\n", type_);
   81     return false;
   82   }
   83 
   84   void DebugWeights() override {
   85     tprintf("Must override Network::DebugWeights for type %d\n", type_);
   86   }
   87 
   88   int InitFromProto();
   89 
   90   // The original network definition for reference.
   91   std::string spec_;
   92   // Input tensor parameters.
   93   StaticShape input_shape_;
   94   // Output tensor parameters.
   95   StaticShape output_shape_;
   96   // The tensor flow graph is contained in here.
   97   std::unique_ptr<tensorflow::Session> session_;
   98   // The serialized graph is also contained in here.
   99   TFNetworkModel model_proto_;
  100 };
  101 
  102 } // namespace tesseract.
  103 
  104 #endif // ifdef INCLUDE_TENSORFLOW
  105 
  106 #endif // TESSERACT_LSTM_TFNETWORK_H_