How to map a function with an additional parameter using the new Dataset API in TF 1.3?

I'm playing with the Dataset API in TensorFlow v1.3. It's great. You can map a function over the dataset, as described here. I'd like to know how to pass a function that takes an additional argument, such as arg1:

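```python
import tensorflow as tf

def _parse_function(example_proto, arg1):
    # parse_single_example supports only tf.string, tf.int64 and tf.float32,
    # so the label is declared as tf.int64 rather than tf.int32.
    features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
                "label": tf.FixedLenFeature((), tf.int64, default_value=0)}
    parsed_features = tf.parse_single_example(example_proto, features)
    # arg1 would be used here to further process the parsed features.
    return parsed_features["image"], parsed_features["label"]
```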

Of course,

```python
dataset = dataset.map(_parse_function)
```

will not work, because there is no way to supply arg1.

1 answer

Here is an example of using a lambda expression to wrap the function to which we want to pass an extra argument:

```python
import tensorflow as tf

def fun(x, arg):
    return x * arg

my_arg = tf.constant(2, dtype=tf.int64)
ds = tf.data.Dataset.range(5)
# The lambda takes one element x and forwards it, together with my_arg, to fun.
ds = ds.map(lambda x: fun(x, my_arg))
```

Note that the signature of the function passed to map must match the structure of the dataset's elements, so we must write the lambda so that its parameters match that structure. This is simple here, because each element of the dataset is a single value, x, in the range 0 to 4.
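When each element is a tuple, such as the (image, label) pairs returned by _parse_function, Dataset.map passes the tuple's components to the mapped function as separate positional arguments, so the lambda needs one parameter per component plus whatever it captures from outside. The sketch below mimics that calling convention in plain Python with itertools.starmap so it runs without TensorFlow; scale_label, pairs and factor are hypothetical names used only for illustration.

```python
from itertools import starmap

def scale_label(image, label, factor):
    # Hypothetical function: two inputs come from each element, one is external.
    return image, label * factor

# Stand-in for a dataset whose elements are (image, label) tuples.
pairs = [("img0", 1), ("img1", 2), ("img2", 3)]

factor = 10  # external argument, captured by the lambda below

# starmap unpacks each tuple into positional arguments, as Dataset.map does
# for tuple elements, so the lambda takes one parameter per component.
mapped = list(starmap(lambda image, label: scale_label(image, label, factor), pairs))
print(mapped)  # [('img0', 10), ('img1', 20), ('img2', 30)]
```

With a real dataset of such pairs, the equivalent call would be written as ds.map(lambda image, label: scale_label(image, label, factor)).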

If necessary, you can pass any number of external arguments from outside the dataset: ds = ds.map(lambda x: my_other_fun(x, arg1, arg2, arg3)), and so on.
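The lambda is simply a closure that binds the extra arguments, and functools.partial from the standard library can do the same job. A minimal plain-Python sketch, with the built-in map standing in for Dataset.map and with my_other_fun, arg1, arg2 and arg3 given hypothetical definitions and values:

```python
from functools import partial

def my_other_fun(x, arg1, arg2, arg3):
    # Hypothetical function: one per-element input plus three extra arguments.
    return x * arg1 + arg2 - arg3

arg1, arg2, arg3 = 2, 10, 1
elements = range(5)  # stand-in for tf.data.Dataset.range(5)

# Option 1: a lambda that closes over the extra arguments.
via_lambda = list(map(lambda x: my_other_fun(x, arg1, arg2, arg3), elements))

# Option 2: partial binds the extra arguments by keyword,
# leaving a callable that takes only x.
via_partial = list(map(partial(my_other_fun, arg1=arg1, arg2=arg2, arg3=arg3), elements))

print(via_lambda)  # [9, 11, 13, 15, 17]
assert via_lambda == via_partial
```

One difference between the two: partial captures the argument values when it is created, while the lambda looks up arg1, arg2 and arg3 when it is called. Passing the partial object straight to Dataset.map should work since map only needs a callable, but treat that as an assumption and check it against the TensorFlow version you use.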

To confirm that this works, we can check that the mapping really does multiply each element of the dataset by two:

```python
iterator = ds.make_initializable_iterator()
next_x = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)
    # Pull elements until the iterator signals the end of the dataset.
    while True:
        try:
            print(sess.run(next_x))
        except tf.errors.OutOfRangeError:
            break
```

Output:

```
0
2
4
6
8
```

Source: https://habr.com/ru/post/1271883/

