TensorFlow: Is there a way to measure FLOPS for a model?

The closest example I could find is in this GitHub issue: https://github.com/tensorflow/tensorflow/issues/899

It includes this minimal reproducible code:

```python
import tensorflow as tf
import tensorflow.python.framework.ops as ops

g = tf.Graph()
with g.as_default():
    A = tf.Variable(tf.random_normal([25, 16]))
    B = tf.Variable(tf.random_normal([16, 9]))
    C = tf.matmul(A, B)  # shape=[25,9]

for op in g.get_operations():
    flops = ops.get_stats_for_node_def(g, op.node_def, 'flops').value
    if flops is not None:
        print('Flops should be ~', 2*25*16*9)
        print('25 x 25 x 9 would be', 2*25*25*9)  # ignores internal dim, repeats first
        print('TF stats gives', flops)
```
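For reference, the "should be" figure above follows the usual convention for matrix products: each of the M·N outputs of an [M, K] x [K, N] product takes K multiplications and K - 1 additions, which is conventionally rounded to 2·M·K·N in total. A small pure-Python sketch of that arithmetic (the function names are made up for illustration, not TensorFlow API):

```python
def matmul_flops(m: int, k: int, n: int) -> int:
    """Conventional FLOP count for an [m, k] x [k, n] matrix product."""
    # Each of the m * n outputs needs k multiplies and k - 1 adds;
    # the usual convention rounds this to 2 * k FLOPs per output.
    return 2 * m * k * n


def matmul_flops_exact(m: int, k: int, n: int) -> int:
    """Exact multiply + add count, without the rounding convention."""
    return m * n * (k + (k - 1))


print("Flops should be ~", matmul_flops(25, 16, 9))          # -> 7200
print("25 x 25 x 9 would be", matmul_flops(25, 25, 9))       # -> 11250 (wrong inner dim)
print("exact mul + add count:", matmul_flops_exact(25, 16, 9))  # -> 6975
```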

However, the returned FLOPS value is always None. Is there a way to measure FLOPS specifically, especially for a model stored in a PB file?

1 answer

A bit late, but maybe this will help some visitors in the future. For your example, I successfully tested the following snippet:

```python
# Written against the TensorFlow 1.x API; in TF 2.x these live under tf.compat.v1.
import tensorflow as tf

g = tf.Graph()
run_meta = tf.RunMetadata()
with g.as_default():
    A = tf.Variable(tf.random_normal([25, 16]))
    B = tf.Variable(tf.random_normal([16, 9]))
    C = tf.matmul(A, B)  # shape=[25,9]

    opts = tf.profiler.ProfileOptionBuilder.float_operation()
    flops = tf.profiler.profile(g, run_meta=run_meta, cmd='op', options=opts)
    if flops is not None:
        print('Flops should be ~', 2*25*16*9)
        print('25 x 25 x 9 would be', 2*25*25*9)  # ignores internal dim, repeats first
        print('TF stats gives', flops.total_float_ops)
```
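If the snippet reports more than 7,200, that is probably not an error: the profiler sums the FLOP estimates of every op in the graph, not just the `MatMul`. As far as I understand, `tf.random_normal` scales its samples with an elementwise multiply by `stddev` and an add of `mean`, and those ops get counted too, while the random op itself does not. Under that assumption (one FLOP per output element for each elementwise op), a back-of-the-envelope tally in plain Python looks like this; the labels are descriptive, not real TensorFlow op names:

```python
from math import prod


def elementwise_flops(shape):
    # Assumed rule: one FLOP per output element of an elementwise Mul / Add.
    return prod(shape)


def matmul_flops(m, k, n):
    # 2 * K FLOPs per output element of an [M, K] x [K, N] product.
    return 2 * m * k * n


def expected_profiler_total():
    # Assumption: each tf.random_normal(shape) initializer contributes a Mul
    # (by stddev) and an Add (of mean); the random op itself counts as 0.
    per_op = {
        "A init: Mul [25,16]": elementwise_flops([25, 16]),
        "A init: Add [25,16]": elementwise_flops([25, 16]),
        "B init: Mul [16,9]": elementwise_flops([16, 9]),
        "B init: Add [16,9]": elementwise_flops([16, 9]),
        "C: MatMul [25,16]x[16,9]": matmul_flops(25, 16, 9),
    }
    return per_op, sum(per_op.values())


per_op, total = expected_profiler_total()
for name, count in per_op.items():
    print(f"{name:28s} {count:6d}")
print("total", total)
```

This is only a sanity check; if it disagrees with the profiler on your TensorFlow version, the profiler output is the number to trust.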

The profiler can also be used with Keras, as in the following snippet:

```python
import tensorflow as tf
import keras.backend as K
from keras.applications.mobilenet import MobileNet

run_meta = tf.RunMetadata()
with tf.Session(graph=tf.Graph()) as sess:
    K.set_session(sess)
    # Fixed batch size of 1 so that every tensor shape is fully defined.
    net = MobileNet(alpha=.75, input_tensor=tf.placeholder('float32', shape=(1, 32, 32, 3)))

    # Count floating-point operations.
    opts = tf.profiler.ProfileOptionBuilder.float_operation()
    flops = tf.profiler.profile(sess.graph, run_meta=run_meta, cmd='op', options=opts)

    # Count trainable parameters.
    opts = tf.profiler.ProfileOptionBuilder.trainable_variables_parameter()
    params = tf.profiler.profile(sess.graph, run_meta=run_meta, cmd='op', options=opts)

    print("{:,} --- {:,}".format(flops.total_float_ops, params.total_parameters))
```
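The question also asks about a PB file. A sketch along the same lines, untested against real models and assuming TensorFlow 2.x with the `tf.compat.v1` API, parses the serialized `GraphDef`, imports it into a fresh graph, and profiles that graph. The helpers `count_flops_from_pb` and `write_demo_pb` are hypothetical names; the demo graph is just the [25,16] x [16,9] product from the question:

```python
import os
import tempfile

import tensorflow as tf


def count_flops_from_pb(pb_path):
    """Return the profiler's total float-op count for a serialized GraphDef."""
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        # name="" keeps the original node names instead of prefixing "import/".
        tf.compat.v1.import_graph_def(graph_def, name="")
        opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
        flops = tf.compat.v1.profiler.profile(
            graph=graph, run_meta=tf.compat.v1.RunMetadata(),
            cmd="op", options=opts)
        return flops.total_float_ops


def write_demo_pb(pb_path):
    """Write a tiny graph with one [25,16] x [16,9] MatMul to pb_path."""
    with tf.Graph().as_default() as graph:
        # Fully defined placeholder shapes, so the MatMul FLOPs can be inferred.
        a = tf.compat.v1.placeholder(tf.float32, shape=[25, 16], name="A")
        b = tf.compat.v1.placeholder(tf.float32, shape=[16, 9], name="B")
        tf.matmul(a, b, name="C")
    with tf.io.gfile.GFile(pb_path, "wb") as f:
        f.write(graph.as_graph_def().SerializeToString())


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "demo.pb")
    write_demo_pb(path)
    print("TF stats gives", count_flops_from_pb(path))
```

For a real model, the graph should be frozen (variables converted to constants) before it is written to the `.pb`, and its input shapes should be fully defined, for example with a fixed batch size of 1. As far as I know, the profiler skips ops whose shapes it cannot infer, so the total would otherwise come out too low.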

I hope this helps!


Source: https://habr.com/ru/post/1269805/

