It seems that tf.argmax works like np.argmax (according to the test): it returns the first index when the maximum value occurs more than once. You can use tf.argmax(tf.cast(tf.equal(m, val), tf.int32), axis=1) to get what you want. However, at present the behavior of tf.argmax is officially undefined when the maximum value occurs more than once.
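Because the answer relies on tf.argmax mirroring np.argmax, here is a minimal sketch of the same mask-then-argmax trick in NumPy. It assumes NumPy as a stand-in for TensorFlow and reuses the example matrix from the code further down. np.argmax is documented to return the first index among ties, so each result is the column of the first occurrence of val in that row.

```python
import numpy as np

val = 3
m = np.array([[0, 0, val, 0, val],
              [val, 0, val, val, 0],
              [0, val, 0, 0, 0]])

# 0/1 mask of the positions equal to val (same idea as tf.cast(tf.equal(m, val), tf.int32))
mask = (m == val).astype(np.int32)

# argmax along each row picks the first 1, i.e. the first occurrence of val
first_idx = np.argmax(mask, axis=1)
print(first_idx)  # [2 0 1]
```

One caveat of this trick: a row that contains no val produces an all-zero mask, and argmax then silently returns 0 for it.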
If you are concerned about the undefined behavior, you can instead apply tf.segment_min to the return value of tf.where, as @Igor Tsvetkov suggested. For instance:
```python
# test with tensorflow r1.0
import tensorflow as tf

val = 3
m = tf.placeholder(tf.int32)
m_feed = [[0,   0,   val, 0,   val],
          [val, 0,   val, val, 0],
          [0,   val, 0,   0,   0]]

# (row, col) coordinates of every element equal to val, in row-major order
tmp_indices = tf.where(tf.equal(m, val))
# treat each row id as a segment and keep its smallest column index
result = tf.segment_min(tmp_indices[:, 1], tmp_indices[:, 0])

with tf.Session() as sess:
    print(sess.run(result, feed_dict={m: m_feed}))  # [2 0 1]
```
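To see what tf.where followed by tf.segment_min computes, here is a rough NumPy analogue (an illustration under the assumption that NumPy is available, not the TensorFlow implementation). np.argwhere plays the role of tf.where, and np.minimum.reduceat over runs of equal row ids plays the role of the segment minimum.

```python
import numpy as np

val = 3
m = np.array([[0, 0, val, 0, val],
              [val, 0, val, val, 0],
              [0, val, 0, 0, 0]])

# Like tf.where: (row, col) of every match, in row-major order,
# so the row ids (segment ids) come out sorted
coords = np.argwhere(m == val)
rows, cols = coords[:, 0], coords[:, 1]

# Start offset of each run of equal row ids
starts = np.flatnonzero(np.r_[True, rows[1:] != rows[:-1]])

# Segment minimum: smallest column index within each run
segment_min = np.minimum.reduceat(cols, starts)
print(segment_min)  # [2 0 1]
```

Unlike the TensorFlow version, this sketch does not raise for a row without val; that row simply produces no segment, so the output gets shorter and its positions no longer line up with the rows.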
Note that tf.segment_min will raise an InvalidArgumentError when some row contains no occurrence of val. Likewise, your row_elems.index(val) code will throw an exception if row_elems does not contain val.
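If rows without val are possible, one option is to detect them explicitly. The sketch below is an assumption, not part of the original answer. It uses NumPy and a hypothetical sentinel of -1 for "not found", and first shows that Python's list.index raises ValueError in that case. The same pattern should carry over to TensorFlow with tf.reduce_any and the three-argument tf.where, though that variant is untested here.

```python
import numpy as np

val = 3
# The last row deliberately contains no occurrence of val
m = np.array([[0, 0, val, 0, val],
              [val, 0, val, val, 0],
              [0, 0, 0, 0, 0]])

# Plain Python: list.index raises ValueError when the value is missing
try:
    [0, 0, 0, 0, 0].index(val)
except ValueError:
    print("list.index(val) raised ValueError")

mask = m == val
first_idx = np.argmax(mask, axis=1)    # first True per row (0 if the row has none)
has_val = mask.any(axis=1)             # does the row contain val at all?
safe_idx = np.where(has_val, first_idx, -1)  # hypothetical -1 sentinel for "not found"
print(safe_idx)  # [ 2  0 -1]
```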
Jenny