Restore a subset of variables in TensorFlow

I am training a Generative Adversarial Network (GAN) in TensorFlow, where basically we have two different networks, each with its own optimizer.

    self.G, self.layer = self.generator(self.inputCT, batch_size_tf)
    self.D, self.D_logits = self.discriminator(self.GT_1hot)

    ...

    self.g_optim = tf.train.MomentumOptimizer(self.learning_rate_tensor, 0.9).minimize(self.g_loss, global_step=self.global_step)

    self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5) \
                      .minimize(self.d_loss, var_list=self.d_vars)

The problem is that I first train one of the networks (g), and then I want to train g and d together. However, when I call the load function:

    self.sess.run(tf.initialize_all_variables())
    self.sess.graph.finalize()
    self.load(self.checkpoint_dir)

    def load(self, checkpoint_dir):
        print(" [*] Reading checkpoints...")
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            return True
        else:
            return False

I get an error similar to this (with a long traceback):

 Tensor name "beta2_power" not found in checkpoint files checkpoint/MR2CT.model-96000 

I can restore the g network and continue training with this function, but when I want to start d from scratch and g from a saved model, I get this error.

3 answers

To restore a subset of variables, you must create a new tf.train.Saver and pass it the specific list of variables to restore in the optional var_list argument.

By default, tf.train.Saver creates ops that (i) save every variable in your graph when you call saver.save() and (ii) look up (by name) every variable in the given checkpoint when you call saver.restore() . While this works for most common scenarios, you have to provide more information to work with specific subsets of the variables:

  • If you only want to restore a subset of the variables, you can get a list of those variables by calling tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=G_NETWORK_PREFIX) , assuming that you put the "g" network in a common with tf.name_scope(G_NETWORK_PREFIX): or with tf.variable_scope(G_NETWORK_PREFIX): block. You can then pass this list to the tf.train.Saver constructor.

  • If you want to restore a subset of the variables and/or the variables in the checkpoint have different names, you can pass a dictionary as the var_list argument. By default, each variable in a checkpoint is associated with a key, which is the value of its tf.Variable.name property. If the name is different in the target graph (e.g. because you added a scope prefix), you can specify a dictionary that maps string keys (in the checkpoint file) to tf.Variable objects (in the target graph). A short sketch of the first approach follows this list.
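Here is a minimal sketch of the scope-based approach, assuming a TF 1.x graph whose g variables live under a hypothetical "generator" scope (substitute your own prefix) and the checkpoint path from the question:

    import tensorflow as tf

    G_NETWORK_PREFIX = "generator"  # hypothetical scope name; use your own prefix

    with tf.variable_scope(G_NETWORK_PREFIX):
        w = tf.get_variable("w", shape=[10, 10])  # stands in for the real g variables

    # Collect only the variables whose names start with the prefix.
    g_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=G_NETWORK_PREFIX)

    # This saver only looks up g_vars in the checkpoint and leaves every
    # other variable (e.g. d and its optimizer slots) untouched.
    g_saver = tf.train.Saver(var_list=g_vars)

    with tf.Session() as sess:
        # Initialize everything first, so d starts from scratch...
        sess.run(tf.global_variables_initializer())
        # ...then overwrite only g with the checkpointed values.
        g_saver.restore(sess, "checkpoint/MR2CT.model-96000")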


You can create a separate instance of tf.train.Saver() with the var_list argument set to the variables you want to restore, and another instance for saving all of the variables.
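A minimal sketch of that idea, again assuming the g variables live under a hypothetical "generator" scope:

    import tensorflow as tf

    # One saver restores only the g subset, another saves the whole model.
    g_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="generator")  # hypothetical scope
    restore_saver = tf.train.Saver(var_list=g_vars)  # use restore_saver.restore(sess, ckpt_path)
    save_saver = tf.train.Saver()                    # use save_saver.save(sess, save_path)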


Inspired by @mrry, I offer a solution to this problem. To make it clear, I formulate the problem as restoring a subset of variables from a checkpoint when the model is built on top of a pre-trained model. First, we should use the print_tensors_in_checkpoint_file function from the inspect_checkpoint library, or simply extract this function:

    from tensorflow.python import pywrap_tensorflow

    def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
        varlist = []
        reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        if all_tensors:
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in sorted(var_to_shape_map):
                varlist.append(key)
        return varlist

    varlist = print_tensors_in_checkpoint_file(file_name=ckpt_path,  # the path of the ckpt file
                                               all_tensors=True, tensor_name=None)

Then we use tf.get_collection(), just as @mrry said:

 variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 

Finally, we can create the saver:

    saver = tf.train.Saver(variables[:len(varlist)])
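For completeness, a hedged usage sketch (assuming the same ckpt_path as above): initialize everything first, then restore only the matched subset:

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  # new variables start from scratch
        saver.restore(sess, ckpt_path)               # old variables come from the checkpoint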

The full version can be found on my GitHub: https://github.com/pobingwanghai/tensorflow_trick/blob/master/restore_from_checkpoint.py

In my situation, the new variables are added at the end of the model, so I can simply slice with [:len(varlist)] to pick out the variables I need; in a more complex situation you may have to do some sorting/alignment, or write a simple string-matching function to determine the required variables, as sketched below.
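A sketch of that string-matching alternative (my own assumption, not part of the original answer): keep only the graph variables whose names appear among the checkpoint keys. Note that checkpoint keys omit the ":0" suffix of tf.Variable.name:

    ckpt_names = set(varlist)
    restore_vars = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
                    if v.name.split(':')[0] in ckpt_names]
    saver = tf.train.Saver(restore_vars)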


Source: https://habr.com/ru/post/1014032/
