get_variable() does not work after restoring a session

I am trying to restore a session and call get_variable() to get a tf.Variable object (as suggested in this answer), but it cannot find the variable. A minimal example that reproduces the problem follows.

First, create the variables and save the session.

import tensorflow as tf

# Step 1 of the reproducer: build a graph containing one plain tf.Variable and
# one variable created through tf.get_variable(), then save a checkpoint.

# Plain variable, not created via get_variable(); it shows up as Variable:0
# in the restored graph's variable list.
var = tf.Variable(101)

with tf.Session() as sess:
    # The empty scope name keeps the variable at the root, so it is named
    # 'scoped_var' (listed as scoped_var:0 after the restore).
    with tf.variable_scope(''):
        scoped_var = tf.get_variable('scoped_var', [])

    # Re-entering the same scope with reuse=True returns the existing
    # variable instead of creating a new one.
    with tf.variable_scope('', reuse=True):
        new_scoped_var = tf.get_variable('scoped_var', [])

    assert scoped_var is new_scoped_var
    sess.run(tf.global_variables_initializer())
    # Constructed after both variables exist; with no var_list, the Saver
    # covers the variables present in the graph at this point.
    saver = tf.train.Saver()
    print(sess.run(scoped_var))
    # Writes the checkpoint with prefix 'data/sess' (including the
    # data/sess.meta graph file imported in the next snippet).
    # NOTE(review): presumably requires the data/ directory to already exist.
    saver.save(sess, 'data/sess')

Here, get_variable() inside the scope with reuse=True works fine. Then restore the session from the file and try to get the variable.

import tensorflow as tf

# Step 2 of the reproducer: load the saved graph with import_meta_graph,
# restore the values, then try to fetch the variable with get_variable().
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('data/sess.meta')
    saver.restore(sess, 'data/sess')

    # Both variables are present in the imported 'variables' collection
    # (prints Variable:0 and scoped_var:0).
    for v in tf.get_collection('variables'):
        print(v.name)

    # The internal store that get_variable() consults is not repopulated by
    # import_meta_graph (this prints []).
    print(tf.get_collection(("__variable_store",)))
    # Oops, it's empty!

    with tf.variable_scope('', reuse=True):
        # The next line fails with ValueError: the variable exists in the
        # graph, but get_variable() has no record of it.
        new_scoped_var = tf.get_variable('scoped_var', [])

    print("new_scoped_var: ", new_scoped_var)

Output:

Variable:0
scoped_var:0
[]
Traceback (most recent call last):
...
ValueError: Variable scoped_var does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope?

As we can see, get_variable() cannot find the variable. Also, the ("__variable_store",) collection, which get_variable() uses internally, is empty.

Why doesn't get_variable() work?

+4
source share
1 answer

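The ("__variable_store",) collection that get_variable() relies on is not saved in the meta graph. After import_meta_graph() the variable store is therefore empty, and get_variable() cannot find the existing variables. Instead of importing the meta graph, rebuild the graph with the same code that originally created it (the same variable scopes, get_variable() calls, etc.), and then restore the saved values into it.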

import os

import tensorflow as tf

# saver.save() writes the "checkpoint" state file into the directory of its
# save path. get_checkpoint_state() must therefore be given that directory,
# not the 'data/sess' checkpoint prefix; with the prefix it always returns
# None and nothing is ever restored.
CKPT_DIR = 'data'
CKPT_PREFIX = os.path.join(CKPT_DIR, 'sess')

with tf.Session() as sess:
    # Rebuild the graph with the same code that originally created it, so
    # get_variable() knows about 'scoped_var' again (import_meta_graph does
    # not restore get_variable()'s internal variable store).
    with tf.variable_scope(''):
        scoped_var = tf.get_variable('scoped_var', [])

    with tf.variable_scope('', reuse=True):
        new_scoped_var = tf.get_variable('scoped_var', [])

    assert scoped_var is new_scoped_var

    # Create the Saver after the variables exist so that it covers them.
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(CKPT_DIR)
    if ckpt and ckpt.model_checkpoint_path:
        # A previous run saved a checkpoint: load its values into the graph.
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        # First run: nothing to restore, so initialize from scratch.
        sess.run(tf.global_variables_initializer())

    print(sess.run(scoped_var))
    # Saver.save refuses to write into a missing parent directory.
    os.makedirs(CKPT_DIR, exist_ok=True)
    saver.save(sess, CKPT_PREFIX)

    # Now continue to use the variables as you normally would with a
    # restored model.

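Note that the graph has to be built before calling saver.restore: the Saver only restores values into variables that already exist in the graph.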

+1

Source: https://habr.com/ru/post/1672183/


All Articles