The link provided in Roi's own comment was really helpful. Since I struggled with the same question for some time, I would like to generalize the answer given in the link for reference:
    import tensorflow as tf

    def batched_input_fn(dataset_x, dataset_y, batch_size):
        def _input_fn():
            # Hold the full dataset in constant tensors.
            all_x = tf.constant(dataset_x, shape=dataset_x.shape, dtype=tf.float32)
            all_y = tf.constant(dataset_y, shape=dataset_y.shape, dtype=tf.float32)
            # Queue up single (x, y) slices and assemble them into batches.
            sliced_input = tf.train.slice_input_producer([all_x, all_y])
            return tf.train.batch(sliced_input, batch_size=batch_size)
        return _input_fn
It can then be used like this (using TensorFlow v1.1):
    model = CustomModel(FLAGS.learning_rate)
    estimator = tf.estimator.Estimator(model_fn=model.build(),
                                       params=model.params())

    estimator.train(input_fn=batched_input_fn(train.features, train.labels,
                                              FLAGS.batch_size),
                    steps=FLAGS.train_steps)
Unfortunately, this approach is about 10x slower than manual feeding (using the low-level TensorFlow API) or than training on the entire dataset at once (train.shape[0] == batch_size), i.e. without tf.train.slice_input_producer() and tf.train.batch() at all. At least on my machine (CPU only). I'm really curious why this approach is so slow. Any ideas?
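For reference, this is roughly what I mean by manual feeding: a plain feed_dict training loop against the low-level graph API. The tiny model and random data below are only stand-ins for illustration, not my actual model_fn:

    import numpy as np
    import tensorflow as tf

    # Toy data, just to make the sketch self-contained.
    dataset_x = np.random.rand(1000, 10).astype(np.float32)
    dataset_y = np.random.rand(1000, 1).astype(np.float32)
    batch_size = 32

    x = tf.placeholder(tf.float32, shape=[None, 10])
    y = tf.placeholder(tf.float32, shape=[None, 1])

    # Trivial stand-in for the real model.
    pred = tf.layers.dense(x, 1)
    loss = tf.losses.mean_squared_error(y, pred)
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for step in range(100):
            # Pick a random batch and feed it directly via feed_dict.
            idx = np.random.randint(0, dataset_x.shape[0], batch_size)
            sess.run(train_op, feed_dict={x: dataset_x[idx], y: dataset_y[idx]})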
Edit:
I could speed it up a bit by passing num_threads > 1 to tf.train.batch() (see the sketch below). On a virtual machine with two CPUs, this doubles the throughput of the batching mechanism compared to the default num_threads=1, but it is still about 5x slower than manual feeding. The results may well differ on your own system, or on one that uses all CPU cores for the input pipeline and a GPU for the model computation, so it would be great if someone could post their results in the comments.
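In case it helps, here is the only change needed for the multi-threaded variant. The default of num_threads=2 below is simply what I used on my two-CPU VM, not a general recommendation:

    import tensorflow as tf

    def batched_input_fn(dataset_x, dataset_y, batch_size, num_threads=2):
        def _input_fn():
            all_x = tf.constant(dataset_x, shape=dataset_x.shape, dtype=tf.float32)
            all_y = tf.constant(dataset_y, shape=dataset_y.shape, dtype=tf.float32)
            sliced_input = tf.train.slice_input_producer([all_x, all_y])
            # num_threads > 1 lets several threads fill the batch queue in parallel.
            return tf.train.batch(sliced_input, batch_size=batch_size,
                                  num_threads=num_threads)
        return _input_fn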