What is the use of tf.select?

(Edited in response to @quirk's answer.)

I was reading some TensorFlow code online and came across this statement:

threshold = tf.select(input > RLSA_THRESHOLD, positive, negative)

source: https://github.com/Raverss/tensorflow-RLSA-NMS/blob/master/source.py#L31

positive is a tensor containing only 1s, negative is a tensor of the same size containing only 0s, and input is a heat map (tensor) of the same size (all of type tf.float32).

Judging from the snippet, it seems reasonable to assume that the authors would simply have used tf.cast(input > RLSA_THRESHOLD, tf.float32) if there were no specific reason for the tf.select(...) expression. That would also remove the need for the positive and negative variables and save memory, since they are just an expensive way of storing 0 and 1.

Is the tf.select(...) expression above equivalent to tf.cast(input > RLSA_THRESHOLD, tf.float32)? If not, why not?
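To check the premise outside TensorFlow, here is a NumPy sketch. It assumes np.where behaves like tf.select for same-shaped inputs (both pick values element-wise); the heat map values are made up. Unlike the old tf.select, which required same-shaped branches, np.where also broadcasts scalar branches, which sidesteps the memory concern.

```python
import numpy as np

# Made-up values standing in for the question's heat map and threshold.
RLSA_THRESHOLD = 0.5
heat_map = np.array([[0.9, 0.1], [0.5, 0.7]], dtype=np.float32)

# 1) The pattern from the question: full-size ones/zeros arrays as branches.
positive = np.ones_like(heat_map)
negative = np.zeros_like(heat_map)
selected = np.where(heat_map > RLSA_THRESHOLD, positive, negative)

# 2) Scalar branches: np.where broadcasts them, so no full-size arrays are stored.
scalar_form = np.where(heat_map > RLSA_THRESHOLD, np.float32(1), np.float32(0))

# 3) The cast suggested in the question (astype is NumPy's counterpart of tf.cast).
casted = (heat_map > RLSA_THRESHOLD).astype(np.float32)

print(np.array_equal(selected, casted), np.array_equal(scalar_form, casted))  # True True
```

With ones and zeros as branches, all three forms hold the same values; the difference only shows up once positive and negative contain something else.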

Note: I mostly use Keras, so apologies if this is something trivial.

2 answers

Umm, RTD (read the docs)!

tf.select picks values from positive or negative depending on the boolean value of condition.

tf.select(condition, t, e, name=None)
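Selects elements from t or e, depending on condition.
The t and e tensors must all have the same shape, and the output will also have that shape.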
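To make the quoted semantics concrete, here is a plain-Python model of the 1-D case. select is a hypothetical helper, not a TensorFlow function; it only mirrors the rule that t and e have the same shape and that the output takes that shape.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def select(condition: Sequence[bool], t: Sequence[T], e: Sequence[T]) -> list[T]:
    """Hypothetical 1-D model of tf.select: t[i] where condition[i] is True, else e[i]."""
    if not (len(condition) == len(t) == len(e)):
        raise ValueError("condition, t and e must have the same shape")
    # Element-wise choice; the result has the same length as t and e.
    return [ti if ci else ei for ci, ti, ei in zip(condition, t, e)]


print(select([True, False, True], ["a", "b", "c"], ["x", "y", "z"]))  # ['a', 'y', 'c']
```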

So, in your case:

threshold = tf.select(input > RLSA_THRESHOLD, positive, negative)

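input > RLSA_THRESHOLD is a bool tensor (think of it as 0s and 1s) that decides, element by element, whether the value is taken from positive or from negative.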

For example, say RLSA_THRESHOLD is 0.5 and input is a 4-element vector of values between 0 and 1. positive and negative are [1, 1, 1, 1] and [0, 0, 0, 0] respectively. Let input be [0.8, 0.2, 0.5, 0.6].

Then threshold will be [1, 0, 0, 1].

Note that positive and negative can be any tensors with the same shape as condition. If positive and negative were, say, [2, 4, 6, 8] and [1, 3, 5, 7] respectively, threshold would be [2, 3, 5, 8].
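Both worked examples can be reproduced with NumPy, assuming np.where as a stand-in for tf.select (TensorFlow 1.0 later renamed tf.select to tf.where). Note that 0.5 goes to the negative branch because the comparison is strict.

```python
import numpy as np

RLSA_THRESHOLD = 0.5
inputs = np.array([0.8, 0.2, 0.5, 0.6])
condition = inputs > RLSA_THRESHOLD  # [True, False, False, True]; 0.5 is not > 0.5

# Ones/zeros branches: the result looks like a cast of the condition.
threshold_01 = np.where(condition, [1, 1, 1, 1], [0, 0, 0, 0])
# Arbitrary branches: the result takes values from positive/negative, not 0/1.
threshold_custom = np.where(condition, [2, 4, 6, 8], [1, 3, 5, 7])

print(threshold_01)      # [1 0 0 1]
print(threshold_custom)  # [2 3 5 8]
```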


"...the authors would simply have used input > RLSA_THRESHOLD if there were no specific reason for tf.select."

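Not necessarily. input > RLSA_THRESHOLD only gives you a (boolean) tensor of comparison results, whereas tf.select picks values. If positive / negative held anything other than ones and zeros, threshold would contain those values instead.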


"Is the tf.select expression above equivalent to input > RLSA_THRESHOLD? If not, why not?"

No, they are not: one selects values from positive or negative, the other is just a boolean comparison.

Perhaps you meant to ask:

"Is threshold equivalent to input > RLSA_THRESHOLD? If not, why not?"

Not quite. For one thing, input > RLSA_THRESHOLD has type bool, whereas threshold has the same type as positive and negative.

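That is, (in this particular case) they hold the same values, but only up to casting, which is exactly what tf.cast takes care of.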
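The type difference can be seen in a short NumPy sketch, again with np.where standing in for tf.select and the made-up input from the worked example: the comparison yields booleans, while the selected result takes the dtype of its branches.

```python
import numpy as np

x = np.array([0.8, 0.2, 0.5, 0.6], dtype=np.float32)           # made-up input
mask = x > 0.5                                                  # comparison -> bool array
threshold = np.where(mask, np.ones_like(x), np.zeros_like(x))  # float32 branches

print(mask.dtype, threshold.dtype)  # bool float32
# The values agree; only the dtype differs, and astype() (NumPy's tf.cast) fixes that.
print(mask.astype(threshold.dtype).dtype == threshold.dtype)  # True
```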


A quick demo:

In [86]: s = tf.InteractiveSession()

In [87]: inputs = tf.random_uniform([10], 0., 1.)

In [88]: positives = tf.ones([10])

In [89]: negatives = tf.zeros([10])    

In [90]: s.run([inputs, tf.select(inputs > .5, positives, negatives)])
Out[90]: 
[array([ 0.13187623,  0.77344072,  0.29853749,  0.29245567,  0.53489852,
         0.34861541,  0.15090156,  0.40595055,  0.34910154,  0.24349082], dtype=float32),
 array([ 0.,  1.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.], dtype=float32)]

Elements of inputs greater than 0.5 become 1.; all the others become 0.

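On its own, inputs > .5 just yields a bool tensor (True where the condition holds, False otherwise). inputs is re-sampled on every run, so these values don't match the ones above: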

In [92]: s.run(inputs > .5)
Out[92]: array([ True, False,  True,  True,  True,  True,  True,  True, False,  True], dtype=bool)

Source: https://habr.com/ru/post/1665994/

