How to find the index of the first matching element in TensorFlow

I am looking for a TensorFlow method that does something like Python's list.index() function.

Given a matrix and a value to search for, I want the index of the first occurrence of that value in each row of the matrix.

For instance,

    # m is a <batch_size, 100> matrix of integers
    val = 23
    result = [0] * batch_size
    for i, row_elems in enumerate(m):
        result[i] = row_elems.index(val)
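For reference, the loop above can be written as a small runnable function. This is only a plain-Python sketch of the desired semantics, not a TensorFlow solution. The name `first_index_per_row` and the `missing` sentinel are hypothetical, and returning a sentinel instead of raising `ValueError` is an assumption that goes beyond the original loop.

```python
def first_index_per_row(rows, val, missing=-1):
    """Return the index of the first occurrence of val in each row.

    Rows that do not contain val get `missing` (an assumed convention;
    a bare list.index call would raise ValueError instead).
    """
    result = []
    for row in rows:
        try:
            result.append(list(row).index(val))
        except ValueError:
            result.append(missing)
    return result


m = [[0, 0, 23, 0, 23],
     [23, 0, 23, 23, 0],
     [0, 23, 0, 0, 0],
     [0, 0, 0, 0, 0]]
print(first_index_per_row(m, 23))  # [2, 0, 1, -1]
```

The last row shows the case the TensorFlow version also has to deal with: a row in which the value never appears.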

I cannot assume that val appears only once in each row; otherwise I would just use tf.argmax(m == val). In my case it is important to get the index of the first occurrence of val, not of an arbitrary one.

2 answers

It seems that tf.argmax works like np.argmax (according to the test), which returns the first index when the maximum value occurs several times. You can use tf.argmax(tf.cast(tf.equal(m, val), tf.int32), axis=1) to get what you want. However, at present the behavior of tf.argmax is undefined when the maximum value occurs more than once, so this relies on what the implementation happens to do.
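As a sanity check of the first-occurrence behavior on the NumPy side, the sketch below applies the same mask-then-argmax idea with np.argmax, whose documentation states that ties resolve to the first index. It says nothing about what tf.argmax guarantees. The helper name `first_index_argmax` and the -1 sentinel are hypothetical; the guard is there because argmax over an all-False mask returns 0, which cannot be told apart from a real match in column 0.

```python
import numpy as np


def first_index_argmax(m, val, missing=-1):
    """First column index equal to val in each row, or `missing` if absent."""
    mask = (np.asarray(m) == val)    # boolean matrix, True where m == val
    first = mask.argmax(axis=1)      # first True per row (0 if the row has none)
    found = mask.any(axis=1)         # rows that contain val at all
    return np.where(found, first, missing)


m = [[0, 0, 3, 0, 3],
     [3, 0, 3, 3, 0],
     [0, 3, 0, 0, 0],
     [0, 0, 0, 0, 0]]
print(first_index_argmax(m, 3).tolist())  # [2, 0, 1, -1]
```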

If you are concerned about the undefined behavior, you can instead take the minimum over the indices returned by tf.where, along the lines @Igor Tsvetkov suggested. For instance,

    # test with tensorflow r1.0
    import tensorflow as tf

    val = 3
    m = tf.placeholder(tf.int32)
    m_feed = [[0,   0,   val, 0,   val],
              [val, 0,   val, val, 0],
              [0,   val, 0,   0,   0]]

    # (row, column) pairs of every position where m == val
    tmp_indices = tf.where(tf.equal(m, val))
    # smallest column index within each row
    result = tf.segment_min(tmp_indices[:, 1], tmp_indices[:, 0])

    with tf.Session() as sess:
        print(sess.run(result, feed_dict={m: m_feed}))  # [2, 0, 1]

Note that tf.segment_min will raise an InvalidArgumentError when some row does not contain val. Your row_elems.index(val) code likewise raises an exception if row_elems does not contain val.
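If rows without val have to be tolerated rather than rejected, one option is to replace every non-matching position with an out-of-range column index and take the row-wise minimum. That depends neither on how tf.argmax breaks ties nor on every row having a match. A minimal sketch, assuming TensorFlow 2.x with eager execution (the code above targets r1.0 graph mode); `first_index_or_missing` and the -1 sentinel are hypothetical names and conventions, not part of the original answer.

```python
import tensorflow as tf


def first_index_or_missing(m, val, missing=-1):
    """First column index equal to val in each row, or `missing` if absent."""
    m = tf.convert_to_tensor(m)
    mask = tf.equal(m, val)                       # [batch, n_cols] booleans
    n_cols = tf.shape(m)[1]                       # scalar int32, used as sentinel
    cols = tf.broadcast_to(tf.range(n_cols), tf.shape(m))
    sentinel = tf.fill(tf.shape(m), n_cols)
    # keep the column index where m == val, the sentinel elsewhere
    first = tf.reduce_min(tf.where(mask, cols, sentinel), axis=1)
    # rows whose minimum is still the sentinel never contained val
    not_found = tf.equal(first, n_cols)
    return tf.where(not_found, tf.fill(tf.shape(first), missing), first)


m_feed = [[0, 0, 3, 0, 3],
          [3, 0, 3, 3, 0],
          [0, 3, 0, 0, 0],
          [0, 0, 0, 0, 0]]
print(first_index_or_missing(m_feed, 3).numpy().tolist())  # [2, 0, 1, -1]
```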


It looks a little ugly, but it works (provided that m and val are both tensors):

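    # tf.unpack / tf.pack were renamed tf.unstack / tf.stack in TensorFlow 1.0
    idx = []
    for t in tf.unstack(m, axis=0):
        # positions where this row equals val, reduced to the smallest one
        idx.append(tf.reduce_min(tf.where(tf.equal(t, val))))
    idx = tf.stack(idx, axis=0)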

EDIT: As Yaroslav Bulatov said, you can achieve the same result with tf.map_fn:

    def index1d(t):
        # smallest index at which the 1-D tensor t equals val
        return tf.reduce_min(tf.where(tf.equal(t, val)))

    idx = tf.map_fn(index1d, m, dtype=tf.int64)

Source: https://habr.com/ru/post/1264118/

