How to invalidate all entries except argmax?

Assuming I have a matrix / array / list, such as a=[1,2,3,4,5], and I want to zero out all entries except the max, so the result would be a=[0,0,0,0,5].

I use b = [val if idx == np.argmax(a) else 0 for idx,val in enumerate(a)], but is there a better (and faster) way, especially for arrays with more than one dimension?
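As an aside, the comprehension above re-evaluates np.argmax(a) for every element, and on a plain list each call first converts the list to an array, so the loop does quadratic work. A minimal sketch of the cheapest fix, assuming you want to stay with Python lists, is to compute the index once:

```python
import numpy as np

a = [1, 2, 3, 4, 5]

# Compute the argmax once instead of once per element
i_max = int(np.argmax(a))
b = [val if idx == i_max else 0 for idx, val in enumerate(a)]
print(b)  # [0, 0, 0, 0, 5]
```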

You can use numpy to do this in place. Note that the method below keeps every entry equal to the maximum, so if the max occurs more than once, all of those occurrences survive.

import numpy as np

a = np.array([1,2,3,4,5])

a[np.where(a != a.max())] = 0

# array([0, 0, 0, 0, 5])

For a unique maximum, see @cᴏʟᴅsᴘᴇᴇᴅ's solution.
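The same idea works without np.where: a boolean mask can index the array directly. If you would rather not modify a, the three-argument form np.where(cond, x, y) builds a new array instead. A short sketch (the 2-D input is just an illustration; both forms compare against the global max and keep ties):

```python
import numpy as np

a = np.array([[1, 9, 3],
              [9, 2, 4]])

# Out-of-place: keep entries equal to the global max, zero the rest
b = np.where(a == a.max(), a, 0)
print(b)
# [[0 9 0]
#  [9 0 0]]

# In-place: the boolean mask indexes a directly, no np.where needed
a[a != a.max()] = 0
print(a)
# [[0 9 0]
#  [9 0 0]]
```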

Instead of masking, you can create an array of zeros and set the right index accordingly.

1-D (optimized) solution

(Setup) Convert to a 1-D array: a = np.array([1,2,3,4,5]).

  • To keep just one instance of the max

    b = np.zeros_like(a)
    i = np.argmax(a)
    b[i] = a[i]
    
  • To keep all instances of the max

    b = np.zeros_like(a)
    m = a == a.max()
    b[m] = a[m]
    

N-D solution

np.random.seed(0)
a = np.random.randn(5, 5)

b = np.zeros_like(a)
m = a == a.max(1, keepdims=True)
b[m] = a[m]

b
array([[0.        , 0.        , 0.        , 2.2408932 , 0.        ],
       [0.        , 0.95008842, 0.        , 0.        , 0.        ],
       [0.        , 1.45427351, 0.        , 0.        , 0.        ],
       [0.        , 1.49407907, 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 2.26975462]])

This keeps all instances of the max in each row.
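If you want exactly one survivor per row even when there are ties, one possible approach (a sketch, not part of the answer above, assuming NumPy 1.15+ for the *_along_axis functions) combines the row-wise argmax with np.take_along_axis and np.put_along_axis. For a single survivor across a whole N-D array, np.unravel_index turns the flat argmax into a tuple index. The helper names keep_row_argmax and keep_global_argmax are hypothetical:

```python
import numpy as np

def keep_row_argmax(a):
    """Zero every entry except the first maximum in each row of a 2-D array."""
    b = np.zeros_like(a)
    # Column index of the first max in each row, shaped (n_rows, 1)
    idx = a.argmax(axis=1)[:, None]
    # Copy a[row, idx[row]] into b at the same positions
    np.put_along_axis(b, idx, np.take_along_axis(a, idx, axis=1), axis=1)
    return b

def keep_global_argmax(a):
    """Zero every entry except the first (flattened-order) maximum of an N-D array."""
    b = np.zeros_like(a)
    # Convert the flat argmax into a tuple index for a's shape
    i = np.unravel_index(a.argmax(), a.shape)
    b[i] = a[i]
    return b

a = np.array([[1, 7, 7],
              [4, 2, 4]])
print(keep_row_argmax(a))
# [[0 7 0]
#  [4 0 0]]
print(keep_global_argmax(a))
# [[0 7 0]
#  [0 0 0]]
```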

Source: https://habr.com/ru/post/1695167/