TensorFlow Estimator: slow prediction

I trained a tf.estimator.LinearClassifier. Training and evaluating the model take a reasonable amount of time for my data size (~60 seconds), but prediction takes far longer (~1 hour).

The prediction code is as follows:

 predictionResult = estimator.predict(input_fn=lambda: my_input_fn2(predictionValidationFile, False, 1))
 predictionList = [prediction for prediction in predictionResult]

with:

 def my_input_fn2(file_path, perform_shuffle=False, repeat_count=1):
     def _parse_function(example_proto):
         keys_to_features = {"xslm": tf.FixedLenFeature([10000], tf.float32),
                             "xrnn": tf.FixedLenFeature([10000], tf.float32),
                             "target": tf.FixedLenFeature([10000], tf.float32)}
         parsed_features = tf.parse_single_example(example_proto, keys_to_features)
         myfeatures = {'xrnn': parsed_features['xrnn'], 'xslm': parsed_features['xslm']}
         return myfeatures, parsed_features['target']

     dataset = (tf.data.TFRecordDataset(file_path)
                .map(_parse_function))
     dataset = dataset.repeat(repeat_count)
     dataset = dataset.batch(1)
     iterator = dataset.make_one_shot_iterator()
     batch_feature, batch_labels = iterator.get_next()
     xs = tf.reshape(batch_feature['xslm'], [-1, 1])
     xr = tf.reshape(batch_feature['xrnn'], [-1, 1])
     x = {'xrnn': xr, 'xslm': xs}
     y = tf.reshape(batch_labels, [-1, 1])
     return x, y

The second line of the prediction code takes 0.8 seconds to run over 10,000 samples (which corresponds to one batch). With 50,000,000 samples, prediction takes more than one hour.
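As a rough sanity check of these figures, assuming the 0.8 seconds per 10,000 samples stays constant over the whole file (an assumption, not something I measured), the total does come out at a little over an hour:

```python
# Figures reported above; a constant time per batch is an assumption.
total_samples = 50_000_000
samples_per_batch = 10_000
seconds_per_batch = 0.8

num_batches = total_samples // samples_per_batch
total_seconds = num_batches * seconds_per_batch
total_minutes = total_seconds / 60

print(f"{num_batches} batches, about {total_minutes:.0f} minutes")
```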

My assumption at this point is that the slow performance is caused by the fact that the estimator's predict() function returns a Python generator instead of the actual prediction results. For each batch, the generator ends up making 10,000 function calls to produce 10,000 prediction results. This seems inefficient.

Are there any options to speed things up?

1 answer

You are right about the reason it is slow. It makes a function call for each element, since the batch size in your input function is hard-coded to 1.

You should pass the batch size to the input function as a parameter and replace

 dataset = dataset.batch(1) 

with

 dataset = dataset.batch(batch_size) 
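To get a feel for how much this reduces the number of batches the estimator has to fetch, here is a minimal pure-Python sketch that does not use TensorFlow. The helper `batched` is hypothetical and only imitates the default grouping of dataset.batch(batch_size), which keeps the final, possibly smaller, batch. The record count assumes one record per 10,000 samples, as in the question, and the batch sizes 64 and 512 are arbitrary illustrations rather than recommended values.

```python
from itertools import islice


def batched(records, batch_size):
    """Group an iterable into lists of up to batch_size items.

    Hypothetical helper that imitates the default grouping of
    dataset.batch(batch_size): the last, possibly smaller, batch is kept.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    iterator = iter(records)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch


# Assumption from the question: each record holds 10,000 samples.
num_records = 50_000_000 // 10_000

for batch_size in (1, 64, 512):
    num_batches = sum(1 for _ in batched(range(num_records), batch_size))
    print(f"batch_size={batch_size}: {num_batches} batches")
```

Fewer batches means fewer round trips per prediction run. The actual speed-up depends on the model and the hardware, so it is worth timing a few batch sizes on your own data.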

Source: https://habr.com/ru/post/1274043/

