TensorFlow: print the names of all placeholders in a meta graph

I have a TensorFlow model for which I have the .meta and checkpoint files. I am trying to print all the placeholders that the model requires without looking at the code that created the model, so that I can build the feed_dict input without knowing how the model was built. For reference, here is the model-building code (in another file):

    def save():
        import tensorflow as tf
        v1 = tf.placeholder(tf.float32, name="v1")
        v2 = tf.placeholder(tf.float32, name="v2")
        v3 = tf.multiply(v1, v2)
        vx = tf.Variable(10.0, name="vx")
        v4 = tf.add(v3, vx, name="v4")
        saver = tf.train.Saver()
        sess = tf.Session()
        # tf.initialize_all_variables() is deprecated; use the global initializer.
        sess.run(tf.global_variables_initializer())
        sess.run(vx.assign(tf.add(vx, vx)))  # vx becomes 20.0
        result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
        print(result)
        saver.save(sess, "./model_ex1")  # writes model_ex1.meta plus checkpoint files

Now, in another file, I have the following restore code:

    def restore():
        import tensorflow as tf
        saver = tf.train.import_meta_graph("./model_ex1.meta")
        print(tf.get_default_graph().get_all_collection_keys())
        for v in tf.get_default_graph().get_collection("variables"):
            print(v)
        for v in tf.get_default_graph().get_collection("trainable_variables"):
            print(v)
        sess = tf.Session()
        saver.restore(sess, "./model_ex1")
        result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 4.0})
        print(result)

However, when I print all the variables as above, I do not see "v1:0" and "v2:0" anywhere among the variable names. How can I determine the names of the placeholder tensors without looking at the code that created the model?

Answer

The tensors v1:0 and v2:0 were created by tf.placeholder() ops, and only tf.Variable objects are added to the "variables" (or "trainable_variables") collections. There is no standard collection to which tf.placeholder() ops are added, so your options are:

  • Add the tf.placeholder() ops to a collection (using tf.add_to_collection()) when building the original graph. You may need to add more metadata to indicate how the placeholders should be used.

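  • Use [x for x in tf.get_default_graph().get_operations() if x.type in ("Placeholder", "PlaceholderV2")] to get a list of the placeholder ops after importing the meta graph. Current TensorFlow versions create ops of type "Placeholder"; "PlaceholderV2" only appears in graphs written by some early 1.x releases.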
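Both options can be sketched in one script. This is a minimal sketch, assuming TensorFlow 2.x with the TF1-style graph API reached through tf.compat.v1 (on 1.x releases without compat.v1, a plain `import tensorflow as tf` should behave the same). To stay self-contained it keeps the MetaGraphDef in memory instead of writing model_ex1.meta to disk, and the collection name "inputs" is a hypothetical choice, not a TensorFlow convention.

```python
# Assumes TensorFlow 2.x, using the TF1-style graph API via tf.compat.v1.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Op types created by tf.placeholder(); "PlaceholderV2" only appears in
# graphs written by some early 1.x releases.
PLACEHOLDER_OP_TYPES = ("Placeholder", "PlaceholderV2")


def build_meta_graph():
    """Build the example model and return its MetaGraphDef (kept in memory)."""
    graph = tf.Graph()
    with graph.as_default():
        v1 = tf.placeholder(tf.float32, name="v1")
        v2 = tf.placeholder(tf.float32, name="v2")
        vx = tf.Variable(10.0, name="vx")
        tf.add(tf.multiply(v1, v2), vx, name="v4")
        # Option 1: record the inputs in a collection. "inputs" is a
        # hypothetical key; any string works.
        for placeholder in (v1, v2):
            tf.add_to_collection("inputs", placeholder)
        # With no filename, export_meta_graph() just returns the proto.
        return tf.train.export_meta_graph()


def find_placeholders(meta_graph_or_file):
    """Import a meta graph and return its placeholder tensor names both ways."""
    graph = tf.Graph()
    with graph.as_default():
        tf.train.import_meta_graph(meta_graph_or_file)
    # Option 1: read back the collection saved with the meta graph.
    from_collection = [t.name for t in graph.get_collection("inputs")]
    # Option 2: scan every op in the graph for placeholder op types.
    from_ops = [
        op.outputs[0].name
        for op in graph.get_operations()
        if op.type in PLACEHOLDER_OP_TYPES
    ]
    return from_collection, from_ops


if __name__ == "__main__":
    from_collection, from_ops = find_placeholders(build_meta_graph())
    print(from_collection)
    print(from_ops)
```

tf.train.import_meta_graph() also accepts a filename, so find_placeholders("./model_ex1.meta") should work on the files from the question; since that model never filled an "inputs" collection, only the op scan would return names there. The type filter deliberately skips "PlaceholderWithDefault" ops (for example, the filename input a Saver typically adds), since those do not have to be fed.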


Source: https://habr.com/ru/post/1268341/

