Multidimensional argmax of a tensor

Say I have a tensor of size BxWxHxD. I want to produce a new BxWxHxD tensor in which only the maximum element of each WxH slice is kept and all other values are zero.

In other words, I believe the best way to achieve this is to take a 2D argmax over the WxH slices, so that the resulting BxD index tensors for rows and columns can then be converted into a one-hot BxWxHxD tensor to use as a mask. How can I do this?
1 answer

As a starting point, you can use the following function. It computes the indices of the maximum element for each batch item and each channel. The result has shape (batch size, 2, number of channels), with the row index first and the column index second.

import tensorflow as tf


def argmax_2d(tensor):
    # input format: BxHxWxD
    assert rank(tensor) == 4

    # flatten the Tensor along the height and width axes
    flat_tensor = tf.reshape(tensor, (tf.shape(tensor)[0], -1, tf.shape(tensor)[3]))

    # argmax of the flat tensor
    argmax = tf.cast(tf.argmax(flat_tensor, axis=1), tf.int32)

    # convert indexes into 2D coordinates (row = index // W, column = index % W)
    argmax_x = argmax // tf.shape(tensor)[2]
    argmax_y = argmax % tf.shape(tensor)[2]

    # stack and return 2D coordinates
    return tf.stack((argmax_x, argmax_y), axis=1)


def rank(tensor):
    # return the rank of a Tensor
    return len(tensor.get_shape())
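
To get from the indices to the masked tensor the question asks for, one possible continuation is sketched below. It is written in NumPy rather than TensorFlow so it is easy to check in isolation, and the function name keep_spatial_max is made up for this illustration. It assumes the BxHxWxD layout used in the function above, and it keeps only the first maximum when a slice contains ties.

```python
import numpy as np


def keep_spatial_max(x):
    """Zero out everything except the maximum of each HxW slice.

    Hypothetical helper for illustration; expects an array of shape BxHxWxD.
    """
    b, h, w, d = x.shape

    # Flatten the two spatial axes, as in the reshape step of argmax_2d.
    flat = x.reshape(b, h * w, d)

    # Flat index of the maximum for every (batch, channel) pair: shape BxD.
    idx = flat.argmax(axis=1)

    # Build a one-hot mask along the flattened spatial axis.
    mask = np.zeros_like(flat)
    np.put_along_axis(mask, idx[:, None, :], 1, axis=1)

    # Apply the mask and restore the original BxHxWxD shape.
    return (flat * mask).reshape(b, h, w, d)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.random((2, 3, 4, 5))
    y = keep_spatial_max(x)
    # One surviving value per (batch, channel) pair.
    print(y.shape, np.count_nonzero(y))
```

Because the mask is built on the flattened axis, the separate row and column coordinates are not needed for this step; they only matter if the 2D positions themselves are of interest.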

Source: https://habr.com/ru/post/1246347/

