How to sort a multidimensional tensor using the indices returned by tf.nn.top_k?

I have two multidimensional tensors, a and b, and I want to sort them by the values of a.

I found that tf.nn.top_k can sort a tensor along its last dimension and return the indices used to sort the input. How can I use the indices returned by tf.nn.top_k(a, k=2) to sort b?

For instance,

    import tensorflow as tf

    a = tf.reshape(tf.range(30), (2, 5, 3))
    b = tf.reshape(tf.range(210), (2, 5, 3, 7))
    k = 2
    sorted_a, indices = tf.nn.top_k(a, k)

    # How to sort b into
    # sorted_b[0, 0, 0, :] = b[0, 0, indices[0, 0, 0], :]
    # sorted_b[0, 0, 1, :] = b[0, 0, indices[0, 0, 1], :]
    # sorted_b[0, 1, 0, :] = b[0, 1, indices[0, 1, 0], :]
    # ...
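To pin down exactly what sorted_b should contain, here is a small reference sketch in NumPy rather than TensorFlow (NumPy is an assumption, not part of the question, and np.take_along_axis needs NumPy 1.15 or later). It emulates tf.nn.top_k with a stable argsort of -a, which gives descending order and resolves ties toward the lower index, and then gathers b along axis 2 with the resulting indices.

```python
import numpy as np

a = np.arange(30).reshape(2, 5, 3)
b = np.arange(210).reshape(2, 5, 3, 7)
k = 2

# Emulate tf.nn.top_k along the last axis: a stable argsort of -a yields
# descending order, with equal values keeping the lower index first.
indices = np.argsort(-a, axis=-1, kind="stable")[..., :k]  # shape (2, 5, k)
sorted_a = np.take_along_axis(a, indices, axis=-1)         # shape (2, 5, k)

# Goal: sorted_b[i, j, m, :] = b[i, j, indices[i, j, m], :]
# A trailing length-1 axis lets the indices broadcast over b's last axis.
sorted_b = np.take_along_axis(b, indices[..., np.newaxis], axis=2)  # shape (2, 5, k, 7)

print(sorted_b.shape)
```

Because every row of a here is increasing, the indices are [2, 1] at every position, so sorted_b[:, :, 0, :] equals b[:, :, 2, :]. This makes a convenient check for any TensorFlow solution.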

Update

One possible solution is to combine tf.gather_nd with tf.meshgrid. For example, the following code was tested with Python 3.5 and TensorFlow 1.0.0-rc0:

    import tensorflow as tf

    a = tf.reshape(tf.range(30), (2, 5, 3))
    b = tf.reshape(tf.range(210), (2, 5, 3, 7))
    k = 2
    sorted_a, indices = tf.nn.top_k(a, k)

    # One index grid per axis of a except the last, plus a grid of length k
    shape_a = tf.shape(a)
    auxiliary_indices = tf.meshgrid(
        *[tf.range(d) for d in (tf.unstack(shape_a[:(a.get_shape().ndims - 1)]) + [k])],
        indexing='ij')
    # Replace the last grid with the top_k indices, then gather slices of b
    sorted_b = tf.gather_nd(
        b, tf.stack(auxiliary_indices[:-1] + [indices], axis=-1))

However, I am wondering whether there is a more readable solution that does not require creating auxiliary_indices as above.
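A possibly more readable alternative, assuming a newer TensorFlow than the 1.0.0-rc0 used above: later releases (roughly 1.14 onward; check the documentation for your version) added a batch_dims argument to tf.gather and tf.gather_nd. With batch_dims equal to rank(a) - 1, the leading axes of indices are matched to the leading axes of b, so no meshgrid is needed. This is a sketch written for TensorFlow 2.x in eager mode.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

a = tf.reshape(tf.range(30), (2, 5, 3))
b = tf.reshape(tf.range(210), (2, 5, 3, 7))
k = 2

sorted_a, indices = tf.nn.top_k(a, k)  # indices: shape (2, 5, k)

# Treat every axis of `a` except the last as a batch axis shared with `b`.
batch_dims = a.shape.rank - 1  # 2 here

# tf.gather gathers along axis `batch_dims` (axis 2 of b) within each batch:
# sorted_b[i, j, m, :] = b[i, j, indices[i, j, m], :]
sorted_b = tf.gather(b, indices, batch_dims=batch_dims)  # shape (2, 5, k, 7)

# Same result with tf.gather_nd: each index vector has length 1 and
# addresses the first non-batch axis of b.
sorted_b_nd = tf.gather_nd(b, indices[..., tf.newaxis], batch_dims=batch_dims)

print(sorted_b.shape)
```

Computing batch_dims from the rank of a keeps the sketch independent of how many leading axes the tensors have, as long as b shares those leading axes with a.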

1 answer

Your code has problems:

    b = tf.reshape(tf.range(60), (2, 5, 3, 7))

TensorFlow cannot reshape a tensor with 60 elements into shape [2, 5, 3, 7], which requires 210 elements. Also, you cannot sort a rank-4 tensor (b) using rank-3 indices.
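The element-count rule behind the reshape error can be checked with plain Python: a reshape is only valid when the product of the target dimensions equals the number of input elements. The helper reshape_ok below is hypothetical, written only for illustration, and math.prod assumes Python 3.8 or later.

```python
from math import prod  # Python 3.8+


def reshape_ok(num_elements, shape):
    """Hypothetical helper: a reshape is valid only if the element counts match."""
    return num_elements == prod(shape)


target = (2, 5, 3, 7)
print(prod(target))            # 210 = 2 * 5 * 3 * 7
print(reshape_ok(60, target))  # False: tf.range(60) cannot take this shape
print(reshape_ok(210, target)) # True: tf.range(210) fits
```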


Source: https://habr.com/ru/post/1263455/

