How do I find the names of the variables stored in a TensorFlow checkpoint?

I want to see the variables stored in a TensorFlow checkpoint along with their values. How do I find the names of the variables stored in the checkpoint?

EDIT:

I used tf.train.NewCheckpointReader, which is explained here. But it is not covered in the TensorFlow documentation. Is there another way?

    import os
    import tensorflow as tf

    v0 = tf.Variable([[1, 2, 3], [4, 5, 6]], dtype=tf.float32, name="v0")
    v1 = tf.Variable([[[1], [2]], [[3], [4]], [[5], [6]]], dtype=tf.float32, name="v1")
    init_all_op = tf.initialize_all_variables()
    save = tf.train.Saver({"v0": v0, "v1": v1})
    checkpoint_path = os.path.join(model_dir, "model.ckpt")  # model_dir must be an existing directory

    with tf.Session() as sess:
        sess.run(init_all_op)
        # Saves a checkpoint.
        save.save(sess, checkpoint_path)

        # Creates a reader.
        reader = tf.train.NewCheckpointReader(checkpoint_path)
        print('reader:\n', reader)

        # Verifies that the tensors exist.
        print('does v0 exist?', reader.has_tensor("v0"))
        print('does v1 exist?', reader.has_tensor("v1"))

        # The debug string lists every variable in the checkpoint.
        debug_string = reader.debug_string()
        print('\n All variables: \n', debug_string)

        # get_variable_to_shape_map() maps variable names to shapes.
        var_map = reader.get_variable_to_shape_map()
        print('\n All variable information:\n', var_map)

        # get_tensor() returns the tensor's value.
        v0_tensor = reader.get_tensor("v0")
        v1_tensor = reader.get_tensor("v1")
        print('\n v0 tensor value:\n', v0_tensor)
        print('\n v1 tensor value:\n', v1_tensor)

3 answers

You can use the inspect_checkpoint.py tool.


Using an example:

    from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

    checkpoint_path = os.path.join(model_dir, "model.ckpt")

    # List ALL tensors.
    # Example output: v0/Adam (DT_FLOAT) [3,3,1,80]
    print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='')

    # List contents of the v0 tensor.
    # Example output: tensor_name: v0 [[[[ 9.27958265e-02 7.40226209e-02 4.52989563e-02 3.15700471e-02
    print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0')

    # List contents of the v1 tensor.
    print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1')

Update: the all_tensors argument was added to print_tensors_in_checkpoint_file in TensorFlow 0.12.0-rc0, so you may need to pass all_tensors=False or all_tensors=True as appropriate.
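If you need to support both old and new TensorFlow versions in one script, you can gate the keyword on the version string. The helper below is a hypothetical pure-Python sketch (its name and parsing are assumptions, not TensorFlow API):

```python
# Hypothetical helper: decide whether this TensorFlow version accepts
# the all_tensors keyword (it was added in 0.12.0-rc0).
def supports_all_tensors(tf_version):
    # Compare only the major and minor components; suffixes like "-rc0"
    # on the patch component are ignored.
    major, minor = tf_version.split(".")[:2]
    return (int(major), int(minor)) >= (0, 12)

print(supports_all_tensors("0.11.0"))      # False
print(supports_all_tensors("0.12.0-rc0"))  # True
print(supports_all_tensors("1.4.0"))       # True
```

In practice you would check `tf.__version__` and only add `all_tensors=True` to the call when the helper returns True.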

Alternative method:

    from tensorflow.python import pywrap_tensorflow

    checkpoint_path = os.path.join(model_dir, "model.ckpt")
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()

    for key in var_to_shape_map:
        print("tensor_name: ", key)
        print(reader.get_tensor(key))  # Remove this line if you want to print only variable names

Hope this helps.


Adding to the previous answer:

If the model is saved in V2 format

    model-10000.data-00000-of-00001
    model-10000.index
    model-10000.meta

then the checkpoint name you pass in should be only the prefix:

    print_tensors_in_checkpoint_file(file_name='/home/RNN/models/model_10000',
                                     tensor_name='', all_tensors=True)
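As a pure-Python illustration of that prefix rule, you can recover the expected checkpoint name from any of the three files by stripping the format-specific suffix. The helper below is hypothetical, not part of TensorFlow:

```python
# Hypothetical helper: recover the checkpoint prefix expected by the
# V2 format from any of the files the Saver writes.
def checkpoint_prefix(filename):
    for suffix in (".index", ".meta"):
        if filename.endswith(suffix):
            return filename[:-len(suffix)]
    # Data shards are named "<prefix>.data-NNNNN-of-NNNNN".
    if ".data-" in filename:
        return filename.split(".data-")[0]
    return filename  # already a bare prefix

print(checkpoint_prefix("model-10000.data-00000-of-00001"))  # model-10000
print(checkpoint_prefix("model-10000.index"))                # model-10000
print(checkpoint_prefix("model-10000.meta"))                 # model-10000
```

All three inputs map to the same prefix, which is what you pass as file_name.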

Source: @LingjiaDeng at https://github.com/tensorflow/tensorflow/issues/7696


Source: https://habr.com/ru/post/1014441/
