Here is an example of using a lambda expression to wrap the function to which we want to pass an extra argument:
```python
import tensorflow as tf

def fun(x, arg):
    return x * arg

my_arg = tf.constant(2, dtype=tf.int64)
ds = tf.data.Dataset.range(5)
ds = ds.map(lambda x: fun(x, my_arg))
```
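The same pattern can be tried without TensorFlow. Python's built-in `map` also calls its function with one element at a time, so a lambda that captures `my_arg` behaves the same way. This is only a pure-Python analogue of `Dataset.map`, not tf.data itself; the plain integer `2` stands in for the `tf.constant`.

```python
# Pure-Python analogue of ds.map(lambda x: fun(x, my_arg)):
# the mapped function receives one element at a time, and the
# lambda closes over the extra argument.

def fun(x, arg):
    """Multiply an element by an external argument."""
    return x * arg

my_arg = 2  # plain int standing in for tf.constant(2, dtype=tf.int64)

# Built-in map() calls the lambda once per element, much as
# Dataset.map calls its function once per dataset element.
result = list(map(lambda x: fun(x, my_arg), range(5)))
print(result)  # [0, 2, 4, 6, 8]
```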
In the code above, the function given to `map` must match the structure of the dataset's elements, so we write the lambda expression to accept exactly what each element contains. That is simple here, because each element of the dataset is a single scalar, `x`, taking values in the range 0 to 4.
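When each element is a tuple (for instance a dataset built with `Dataset.from_tensor_slices((xs, ys))`), the `Dataset.map` documentation describes the tuple as being unpacked into separate positional arguments, so the lambda must take one parameter per component and can still add the external argument at the end. The sketch below imitates that behaviour in plain Python; `map_like_tf_data` and `weighted_sum` are hypothetical names used only for illustration, and the helper is a simplification of what tf.data actually does.

```python
# Hypothetical helper that mimics how tf.data hands elements to the
# mapped function: a tuple element is unpacked into positional
# arguments, any other element is passed as a single argument.
def map_like_tf_data(elements, fn):
    out = []
    for elem in elements:
        if isinstance(elem, tuple):
            out.append(fn(*elem))
        else:
            out.append(fn(elem))
    return out

def weighted_sum(x, y, weight):
    """Hypothetical function: two element components plus one external argument."""
    return (x + y) * weight

my_arg = 2
# Each element is an (x, y) pair: (0, 10), (1, 11), ..., (4, 14).
pairs = list(zip(range(5), range(10, 15)))

# The lambda accepts both components and forwards the captured my_arg.
result = map_like_tf_data(pairs, lambda x, y: weighted_sum(x, y, my_arg))
print(result)  # [20, 24, 28, 32, 36]
```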
If necessary, you can pass any number of external arguments from outside the dataset in the same way, e.g. `ds = ds.map(lambda x: my_other_fun(x, arg1, arg2, arg3))`.
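A common alternative to the lambda is `functools.partial`, which pre-binds the extra arguments and leaves a callable that takes only the element. The sketch below compares the two using Python's built-in `map`, so it does not depend on TensorFlow; `my_other_fun` and its body are hypothetical. Handing the `partial` object to `ds.map` instead should work the same way since `map` accepts a callable, but that is an assumption this sketch does not exercise.

```python
import functools

def my_other_fun(x, arg1, arg2, arg3):
    """Hypothetical mapped function with three external arguments."""
    return x * arg1 + arg2 - arg3

arg1, arg2, arg3 = 2, 10, 1
elements = range(5)

# Option 1: a lambda that forwards the element plus the captured arguments.
via_lambda = list(map(lambda x: my_other_fun(x, arg1, arg2, arg3), elements))

# Option 2: functools.partial binds the extra arguments by keyword,
# so the resulting callable only needs the element itself.
bound = functools.partial(my_other_fun, arg1=arg1, arg2=arg2, arg3=arg3)
via_partial = list(map(bound, elements))

print(via_lambda)   # [9, 11, 13, 15, 17]
print(via_partial)  # [9, 11, 13, 15, 17]
```

The lambda is more explicit about the argument order, while `partial` avoids an extra function layer and makes the bound values inspectable via `bound.keywords`.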
To check that this works, we can confirm that the mapping really multiplies each element of the dataset by two:
```python
iterator = ds.make_initializable_iterator()
next_x = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    while True:
        try:
            print(sess.run(next_x))
        except tf.errors.OutOfRangeError:
            # The iterator has no more elements.
            break
```
Output:

```
0
2
4
6
8
```
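The `while True` / `try` / `except OutOfRangeError` loop above is the graph-mode counterpart of Python's own iterator protocol, where `next()` signals exhaustion by raising `StopIteration`. The pure-Python sketch below drains an iterator the same way, with `next(iterator)` standing in for `sess.run(next_x)`; it is an analogy, not TensorFlow code. (In TensorFlow 2.x with eager execution, a dataset can instead be iterated directly with a plain `for` loop.)

```python
# Pure-Python counterpart of the session loop: next() plays the role of
# sess.run(next_x), and StopIteration plays the role of
# tf.errors.OutOfRangeError.
def fun(x, arg):
    return x * arg

my_arg = 2
iterator = iter(map(lambda x: fun(x, my_arg), range(5)))

collected = []
while True:
    try:
        value = next(iterator)
    except StopIteration:
        break  # exhausted, like OutOfRangeError in the graph loop
    print(value)
    collected.append(value)
```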