Basics of saving and restoring a TensorFlow model

I want to train my TensorFlow model, freeze a snapshot of it, and then run it in inference mode (without further training) on new input data. Questions:

  • Are tf.train.export_meta_graph and tf.train.import_meta_graph the right tools for this?
  • Do I need to include in collection_list the names of all the variables that I want to include in the snapshot? (The easiest way would be to just include everything.)
  • The TensorFlow docs say, "If no collection_list is specified, all collections in the model will be exported." Does this mean that if I do not specify any variables in collection_list, then all the variables in the model will be exported, because they are in a collection by default?
  • The TensorFlow docs say, "For a Python object to be serialized to and from MetaGraphDef, the Python class must implement to_proto() and from_proto() methods, and register them with the system using register_proto_function." Does this mean that to_proto() and from_proto() only need to be added to classes that I have defined myself and want to export? If I use only standard Python data types (int, float, list, dict), does this not matter?

Thanks in advance.

+5
1 answer

A little late, but I'll still try to answer.

  1. Are tf.train.export_meta_graph and tf.train.import_meta_graph the right tools for this?

I would say so. Note that tf.train.export_meta_graph is called implicitly when you save the model using tf.train.Saver. The gist is this:

    # create the model
    ...
    saver = tf.train.Saver()

    with tf.Session() as sess:
        ...
        # save graph and variables
        # if you are using global_step, the saver will automatically keep the n=5 latest checkpoints
        saver.save(sess, save_path, global_step)

Then, to restore:

    save_path = ...
    latest_checkpoint = tf.train.latest_checkpoint(save_path)
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta')

    with tf.Session() as sess:
        saver.restore(sess, latest_checkpoint)

Note that instead of calling tf.train.import_meta_graph you could also simply re-run the code that created the model in the first place. However, I think it is more elegant to use import_meta_graph, since it lets you restore the model even if you do not have access to the code that created it.
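To illustrate that alternative, here is a minimal sketch, assuming a hypothetical build_model() helper that recreates exactly the same graph that was used at training time (the variable names, shapes, and checkpoint directory are made up for the example):

    import tensorflow as tf

    save_path = '/tmp/my_model'  # hypothetical checkpoint directory

    # hypothetical helper that rebuilds the same graph as at training time
    def build_model():
        input_ = tf.placeholder(tf.float32, shape=[64, 64], name='input')
        w = tf.get_variable('w', shape=[64, 10])
        logits = tf.matmul(input_, w, name='logits')
        return input_, logits

    input_, logits = build_model()
    saver = tf.train.Saver()  # plain Saver; the graph structure comes from the code above

    with tf.Session() as sess:
        # only the variable values are loaded from the checkpoint
        saver.restore(sess, tf.train.latest_checkpoint(save_path))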


  2. Do I need to include in collection_list the names of all the variables that I want to include in the snapshot? (The easiest way would be to just include everything.)

No. However, the question is a bit confused: collection_list in export_meta_graph should not be a list of variables, but of collections (i.e., a list of string keys).

Collections are very convenient; for example, all trainable variables are automatically added to the tf.GraphKeys.TRAINABLE_VARIABLES collection, which you can retrieve by calling:

    tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

or

    tf.trainable_variables()  # defaults to the default graph

If, after restoring, you need access to intermediate tensors other than your trainable variables, I find it convenient to add them to a custom collection, for example:

    ...
    input_ = tf.placeholder(tf.float32, shape=[64, 64])
    ...
    tf.add_to_collection('my_custom_collection', input_)

This collection is saved automatically (unless you specifically leave the name of this collection out of the collection_list argument of export_meta_graph). That way, you can simply retrieve input_ after restoring, as follows:

    ...
    with tf.Session() as sess:
        saver.restore(sess, latest_checkpoint)
        input_ = tf.get_collection_ref('my_custom_collection')[0]
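Putting this together for the inference use case from the question, here is a sketch that feeds new data into the restored placeholder. It assumes an output tensor was also added to a collection (the name 'my_output_collection' is an assumption) before exporting:

    import numpy as np
    import tensorflow as tf

    save_path = '/tmp/my_model'  # hypothetical checkpoint directory
    latest_checkpoint = tf.train.latest_checkpoint(save_path)
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta')

    with tf.Session() as sess:
        saver.restore(sess, latest_checkpoint)
        # both tensors were added to collections before exporting (collection names are assumptions)
        input_ = tf.get_collection_ref('my_custom_collection')[0]
        output = tf.get_collection_ref('my_output_collection')[0]
        # new data must match the placeholder shape [64, 64]
        new_data = np.random.rand(64, 64).astype(np.float32)
        result = sess.run(output, feed_dict={input_: new_data})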

  3. The TensorFlow docs say, "If no collection_list is specified, all collections in the model will be exported." Does this mean that if I do not specify any variables in collection_list, then all the variables in the model will be exported, because they are in a collection by default?

Yes. Again, note the subtle detail that collection_list is a list of collections, not variables. In fact, if you only want certain variables to be saved, you can specify them when creating the tf.train.Saver object (see the sketch after the docstring below). From the tf.train.Saver.__init__ docstring:

  """Creates a `Saver`. The constructor adds ops to save and restore variables. `var_list` specifies the variables that will be saved and restored. It can be passed as a `dict` or a list: * A `dict` of names to variables: The keys are the names that will be used to save or restore the variables in the checkpoint files. * A list of variables: The variables will be keyed with their op name in the checkpoint files. 

  4. The TensorFlow docs say, "For a Python object to be serialized to and from MetaGraphDef, the Python class must implement to_proto() and from_proto() methods, and register them with the system using register_proto_function." Does this mean that to_proto() and from_proto() only need to be added to classes that I have defined myself and want to export? If I use only standard Python data types (int, float, list, dict), does this not matter?

I have never used this feature, but I would say your interpretation is correct.

+2

Source: https://habr.com/ru/post/1257555/
