This works pretty much exactly the way you would do it in NumPy, for example:
tensor[tensor!=0] = 0
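For comparison, here is a minimal sketch of the NumPy idiom that line mirrors, using a made-up array (not from the original answer). The boolean mask on the left-hand side selects the positions to overwrite, and PyTorch tensors accept the same kind of indexing.

```python
import numpy as np

# Hypothetical example array, chosen only for illustration.
arr = np.array([[0, 1], [1, 3]])

# arr != 0 builds a boolean mask with the same shape as arr; assigning
# through it writes 0 into every position where the mask is True.
arr[arr != 0] = 0
print(arr)
```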
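To replace both zeros and non-zeros, chain the two assignments. Just remember to build both masks from the original tensor and write into a copy; otherwise the first assignment changes the values that the second mask is computed from: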
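def custom_replace(tensor, on_zero, on_non_zero):
    # Work on a copy so the input tensor is left unchanged.
    res = tensor.clone()
    # Both masks are computed from the original tensor, not from res.
    res[tensor == 0] = on_zero
    res[tensor != 0] = on_non_zero
    return res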
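To see why the copy matters, this sketch (with a small hypothetical 1-D tensor) chains the two assignments in place on a single tensor and compares the result with `custom_replace`. In the naive version the second mask is evaluated after the first write, so the freshly written 5s count as non-zero and get overwritten.

```python
import torch

z = torch.tensor([0, 1, 2, 0])  # hypothetical input for illustration

# Pitfall: both assignments on the SAME tensor.
naive = z.clone()
naive[naive == 0] = 5   # zeros become 5 -> [5, 1, 2, 5]
naive[naive != 0] = 0   # every element is now non-zero -> [0, 0, 0, 0]

# custom_replace from above: masks come from the untouched input,
# writes go into a separate clone.
def custom_replace(tensor, on_zero, on_non_zero):
    res = tensor.clone()
    res[tensor == 0] = on_zero
    res[tensor != 0] = on_non_zero
    return res

fixed = custom_replace(z, on_zero=5, on_non_zero=0)  # [5, 0, 0, 5]
```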
And use it like this:
>>> z
(0 ,.,.) =
0 1
1 3
(1 ,.,.) =
0 1
1 0
[torch.LongTensor of size 2x2x2]
>>> out = custom_replace(z, on_zero=5, on_non_zero=0)
>>> out
(0 ,.,.) =
5 0
0 0
(1 ,.,.) =
5 0
0 5
[torch.LongTensor of size 2x2x2]
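The session above uses the printing format of an old PyTorch release; recent versions print `tensor([...])` instead. As a different technique from the masking approach (not the one used above), `torch.where` can produce the same result in a single call. A minimal sketch, rebuilding the same 2x2x2 `z`:

```python
import torch

# The same 2x2x2 tensor z as in the session above.
z = torch.tensor([[[0, 1], [1, 3]],
                  [[0, 1], [1, 0]]])

# torch.where(cond, a, b) takes elements from a where cond is True and from
# b elsewhere, returning a new tensor. full_like builds constant tensors
# with z's shape and dtype (int64 here).
out = torch.where(z == 0, torch.full_like(z, 5), torch.full_like(z, 0))
print(out)
```

Because `torch.where` allocates a new tensor, there is no in-place ordering issue to worry about, and `z` is left unchanged.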