I am trying to update a two-dimensional tensor inside a nested tf.while_loop(). However, when I pass the variable to the second (inner) loop, I cannot update it with tf.assign(), because it raises this error:
ValueError: Sliced assignment is only supported for variables
It works fine if I create the variable outside of while_loop and use it only in the first loop.
How can I update my two-dimensional tf.Variable inside the second while loop?
(I am using Python 2.7 and TensorFlow 1.2.)
My code is:
import tensorflow as tf
import numpy as np

tf.reset_default_graph()

BATCH_SIZE = 10
LENGTH_MAX_OUTPUT = 31

it_batch_nr = tf.constant(0)
it_row_nr = tf.Variable(0, dtype=tf.int32)
it_col_nr = tf.constant(0)
cost = tf.constant(0)

it_batch_end = lambda it_batch_nr, cost: tf.less(it_batch_nr, BATCH_SIZE)
it_row_end = lambda it_row_nr, cost_matrix: tf.less(it_row_nr, LENGTH_MAX_OUTPUT+1)

def iterate_batch(it_batch_nr, cost):
    cost_matrix = tf.Variable(np.ones((LENGTH_MAX_OUTPUT+1, LENGTH_MAX_OUTPUT+1)), dtype=tf.float32)
    it_rows, cost_matrix = tf.while_loop(it_row_end, iterate_row, [it_row_nr, cost_matrix])
    cost = cost_matrix[0,0]
    return tf.add(it_batch_nr,1), cost

def iterate_row(it_row_nr, cost_matrix):
    cost_matrix[0,0].assign(100.0)
    return tf.add(it_row_nr,1), cost_matrix

it_batch = tf.while_loop(it_batch_end, iterate_batch, [it_batch_nr, cost])

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
out = sess.run(it_batch)
print(out)
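To make clearer what I am after: as far as I understand, the values handed to a tf.while_loop body are plain tensors rather than variables, which would explain why the sliced assign() in iterate_row is rejected. What I think I need instead is a functional update, where the body builds a new matrix and returns it as the next loop value. Below is a minimal sketch of that pattern in plain NumPy, not TensorFlow. The helper set_element is a hypothetical name I made up for illustration; whether ops such as tf.where or tf.scatter_nd can play the same role in TensorFlow 1.2 is an assumption I have not tested.

```python
import numpy as np

LENGTH_MAX_OUTPUT = 31


def set_element(matrix, row, col, value):
    """Hypothetical helper: return a copy of matrix with one entry replaced."""
    # Boolean mask that is True only at the position to overwrite.
    mask = np.zeros(matrix.shape, dtype=bool)
    mask[row, col] = True
    # Pick `value` where the mask is True and the old entry elsewhere.
    return np.where(mask, value, matrix)


def iterate_row(it_row_nr, cost_matrix):
    # Functional update: build a new matrix and return it as the next
    # loop value instead of assigning into the loop argument.
    new_cost_matrix = set_element(cost_matrix, 0, 0, 100.0)
    return it_row_nr + 1, new_cost_matrix


# Plain Python stand-in for the inner while_loop.
it_row_nr = 0
cost_matrix = np.ones((LENGTH_MAX_OUTPUT + 1, LENGTH_MAX_OUTPUT + 1), dtype=np.float32)
while it_row_nr < LENGTH_MAX_OUTPUT + 1:
    it_row_nr, cost_matrix = iterate_row(it_row_nr, cost_matrix)

cost = cost_matrix[0, 0]
print(cost)
```

The point of the sketch is only that the loop carries the updated matrix forward through its return value, so no in-place assignment is needed inside the body.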