Epoch Counter with TensorFlow Dataset API

I am migrating my TensorFlow code from the old queue interface to the new Dataset API. In my old code, I tracked the epoch count with a tf.Variable that was incremented every time a new input tensor was accessed and processed through the queue. I would like to have the same epoch counter with the new Dataset API, but I am having trouble making it work.

Since I produce a variable number of data elements in the pre-processing stage, this is not just a matter of incrementing a (Python) counter in the training loop - I need to compute the epoch counter with respect to the input of the queue or dataset.

I emulated what I had before with the old queue system, and here is what I got for the Dataset API (simplified example):

    with tf.Graph().as_default():
        data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
        input_tensors = (data,)

        epoch_counter = tf.Variable(initial_value=0.0, dtype=tf.float32,
                                    trainable=False)

        def pre_processing_func(data_):
            data_size = tf.constant(0.1, dtype=tf.float32)
            epoch_counter_op = tf.assign_add(epoch_counter, data_size)
            with tf.control_dependencies([epoch_counter_op]):
                # normally I would do data augmentation here
                results = (tf.expand_dims(data_, axis=0),)
                return tf.data.Dataset.from_tensor_slices(results)

        dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
        dataset = dataset_source.flat_map(pre_processing_func)
        dataset = dataset.repeat()

        # ... do something with 'dataset' and print
        # the value of 'epoch_counter' every once in a while

However, this does not work. It crashes with a cryptic error message:

  TypeError: In op 'AssignAdd', input types ([tf.float32, tf.float32]) are not compatible with expected types ([tf.float32_ref, tf.float32]) 

A closer check reveals that the epoch_counter variable may not be accessible inside pre_processing_func at all. Does it perhaps live in a different graph?

Any idea how to fix the above example? Or how to get an epoch counter (with decimal points, e.g. 0.4 or 2.9) by some other means?

1 answer

TL;DR: replace the definition of epoch_counter with the following:

    epoch_counter = tf.get_variable("epoch_counter", initializer=0.0,
                                    trainable=False, use_resource=True)

There are some restrictions on using TensorFlow variables inside tf.data.Dataset transformations. The main one is that all variables must be "resource variables" rather than the older "reference variables"; unfortunately, tf.Variable still creates reference variables for backwards-compatibility reasons.
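For completeness, here is a minimal end-to-end sketch of the question's example with that one-line change applied (TF 1.x graph mode). The iterator and session plumbing is my illustration, not part of the original example; in particular, an initializable iterator is used because a one-shot iterator rejects datasets that capture stateful objects such as variables:

    import tensorflow as tf

    with tf.Graph().as_default():
        data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")

        # Resource variable: usable inside tf.data transformations.
        epoch_counter = tf.get_variable("epoch_counter", initializer=0.0,
                                        trainable=False, use_resource=True)

        def pre_processing_func(data_):
            # 10 source elements per epoch, so each element advances
            # the counter by 0.1 of an epoch.
            epoch_counter_op = tf.assign_add(epoch_counter, 0.1)
            with tf.control_dependencies([epoch_counter_op]):
                results = (tf.expand_dims(data_, axis=0),)
                return tf.data.Dataset.from_tensor_slices(results)

        dataset = (tf.data.Dataset.from_tensor_slices((data,))
                   .flat_map(pre_processing_func)
                   .repeat())

        # The dataset captures a variable, so it needs an initializable
        # iterator rather than a one-shot iterator.
        iterator = dataset.make_initializable_iterator()
        next_element = iterator.get_next()

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(iterator.initializer)
            for _ in range(25):
                sess.run(next_element)
            print(sess.run(epoch_counter))  # expect roughly 2.5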

Generally speaking, I would not recommend using variables in the tf.data pipeline if you can avoid it. For example, you can use Dataset.range() to define an epoch counter, and then do something like:

    epoch_counter = tf.data.Dataset.range(NUM_EPOCHS)
    dataset = epoch_counter.flat_map(lambda i: tf.data.Dataset.zip(
        (pre_processing_func(data), tf.data.Dataset.from_tensors(i).repeat())))

The above snippet attaches the current epoch index to each value as its second component.
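Continuing that snippet, the pairs can be consumed like this (the iterator and session plumbing is again illustrative; pre_processing_func must be stateless here for a one-shot iterator to work):

    element, epoch = dataset.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
        while True:
            try:
                _, e = sess.run([element, epoch])
                print("current epoch:", e)  # 0, 0, ..., 1, 1, ..., NUM_EPOCHS - 1
            except tf.errors.OutOfRangeError:
                break

Note that this yields an integer epoch index rather than a fractional one; if the number of elements per epoch is known, a fractional value like 0.4 or 2.9 can be recovered by also keeping a running element count in Python and dividing.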



