TensorFlow 0.12 Model Files

I train the model and save it using:

 saver = tf.train.Saver()
 saver.save(session, './my_model_name')

Besides the checkpoint file, which simply contains pointers to the most recent model checkpoints, it creates the following 3 files in the current path:

  • my_model_name.meta
  • my_model_name.index
  • my_model_name.data-00000-of-00001

I am curious to know what each of these files contains.

I want to load this model in C++ and run inference. The label_image example loads a model from a single .pb file using ReadBinaryProto(). I wonder how I can load it from these 3 files. What is the C++ equivalent of the following?

 new_saver = tf.train.import_meta_graph('./my_model_name.meta')
 new_saver.restore(session, './my_model_name')
2 answers

I am currently struggling with this myself; I have found it is not very straightforward to do yet. The two most commonly cited tutorials on the subject are: https://medium.com/jim-fleming/loading-a-tensorflow-graph-with-the-c-api-4caaff88463f#.goxwm1e5j as well as https://medium.com/@hamedmp/exporting-trained-tensorflow-models-to-c-the-right-way-cf24b609d183#.g1gak956i

The equivalent of

 new_saver = tf.train.import_meta_graph('./my_model_name.meta')
 new_saver.restore(session, './my_model_name')

is simply

 Status load_graph_status = LoadGraph(graph_path, &session); 

This assumes you have "frozen the graph" (using a script that combines the graph file with the checkpoint values), as shown in the sketch below. Also see the discussion here: Tensorflow Different ways to export and run a graph in C++
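
For reference, a LoadGraph-style helper boils down to reading the frozen GraphDef and creating a session from it. This is a minimal sketch modeled on the label_image example; the function and variable names here are illustrative rather than copied from that example, and details may differ slightly between TF versions:

 #include <memory>
 #include <string>
 #include <tensorflow/core/framework/graph.pb.h>
 #include <tensorflow/core/platform/env.h>
 #include <tensorflow/core/public/session.h>

 // Read a frozen .pb file into a GraphDef and create a Session from it.
 // A frozen graph already contains the weights as Const nodes, so no
 // separate checkpoint restore step is needed.
 tensorflow::Status LoadGraph(const std::string& graph_file_name,
                              std::unique_ptr<tensorflow::Session>* session) {
   tensorflow::GraphDef graph_def;
   tensorflow::Status load_graph_status = tensorflow::ReadBinaryProto(
       tensorflow::Env::Default(), graph_file_name, &graph_def);
   if (!load_graph_status.ok()) {
     return load_graph_status;
   }
   session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
   return (*session)->Create(graph_def);
 }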


What your Saver creates is called a "V2 checkpoint", which was introduced in TF 0.12.

I managed to get this working reasonably well (although the C++ docs are sparse, so it took me a whole day to figure it out). Some people suggest converting all variables to constants or freezing the graph, but neither is actually needed.

Python part (save)

 with tf.Session() as sess:
     tf.train.Saver(tf.trainable_variables()).save(sess, 'models/my-model')

If you create the Saver with tf.trainable_variables(), you can save yourself some headache and storage space. But some more complicated models may need all of their data saved; in that case drop this argument to Saver, and just make sure you create the Saver after your graph has been built. It is also wise to give all variables/layers unique names, otherwise you may run into various problems.

C++ part (inference)

Please note that checkpointPath is not a path to any of the existing files, but simply their common prefix. If you mistakenly put in the path to the .index file, TF will not tell you that this is wrong, but it will die during inference because of uninitialized variables.

 #include <tensorflow/core/public/session.h>
 #include <tensorflow/core/protobuf/meta_graph.pb.h>

 using namespace std;
 using namespace tensorflow;

 ...
 // set up your input paths
 const string pathToGraph = "models/my-model.meta";
 const string checkpointPath = "models/my-model";
 ...

 auto session = NewSession(SessionOptions());
 if (session == nullptr) {
     throw runtime_error("Could not create Tensorflow session.");
 }

 Status status;

 // Read in the protobuf graph we exported
 MetaGraphDef graph_def;
 status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);
 if (!status.ok()) {
     throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
 }

 // Add the graph to the session
 status = session->Create(graph_def.graph_def());
 if (!status.ok()) {
     throw runtime_error("Error creating graph: " + status.ToString());
 }

 // Read the weights from the saved checkpoint
 Tensor checkpointPathTensor(DT_STRING, TensorShape());
 checkpointPathTensor.scalar<std::string>()() = checkpointPath;
 status = session->Run(
         {{ graph_def.saver_def().filename_tensor_name(), checkpointPathTensor }},
         {},
         {graph_def.saver_def().restore_op_name()},
         nullptr);
 if (!status.ok()) {
     throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
 }

 // ... and run the inference to your liking
 auto feedDict = ...
 auto outputOps = ...
 std::vector<tensorflow::Tensor> outputTensors;
 status = session->Run(feedDict, outputOps, {}, &outputTensors);
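
To make the last few lines concrete, here is a minimal sketch of how the elided feedDict and outputOps might be filled in, continuing from the snippet above (session and status are the variables defined there). The tensor names "input:0" and "output:0" and the 1x784 shape are purely hypothetical placeholders; substitute the names and shapes from your own graph:

 // Hypothetical input: one 1x784 float tensor fed to a placeholder
 // named "input:0"; we fetch a single tensor named "output:0".
 tensorflow::Tensor input(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, 784}));
 auto input_data = input.flat<float>();
 for (int i = 0; i < input_data.size(); ++i) {
     input_data(i) = 0.0f;  // fill with your real input values here
 }

 std::vector<std::pair<std::string, tensorflow::Tensor>> feedDict = {
     {"input:0", input}};
 std::vector<std::string> outputOps = {"output:0"};

 std::vector<tensorflow::Tensor> outputTensors;
 status = session->Run(feedDict, outputOps, {}, &outputTensors);
 if (!status.ok()) {
     throw std::runtime_error("Error running inference: " + status.ToString());
 }

 // outputTensors[0] now holds the values produced by "output:0".
 auto result = outputTensors[0].flat<float>();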

For completeness, here is the Python equivalent:

Python part (inference)

 with tf.Session() as sess:
     saver = tf.train.import_meta_graph('models/my-model.meta')
     saver.restore(sess, tf.train.latest_checkpoint('models/'))
     outputTensors = sess.run(outputOps, feed_dict=feedDict)

Source: https://habr.com/ru/post/1268793/

