Alternative answer
Many of the other solutions use clipping to avoid an undefined gradient. Depending on your problem, clipping introduces bias and may not be acceptable in all cases. As the following code demonstrates, we need only handle the point of discontinuity, not the region near it.
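For contrast, here is roughly what a clipping-based version looks like (a sketch of my own, not taken from any of the other answers; the `eps` floor is an arbitrary choice):

```python
import tensorflow as tf

def clipped_entropy(x, axis=-1, eps=1e-8):
    # Clipping keeps tf.log and its gradient finite, but every value of x
    # in (0, eps) is silently evaluated as eps, so both the entropy and
    # its gradient are biased for small probabilities.
    return -tf.reduce_sum(x * tf.log(tf.clip_by_value(x, eps, 1.)), axis)
```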
Specific answer
```python
import tensorflow as tf

def cross_entropy(x, y, axis=-1):
    # Where x == 0, the log term is multiplied by 0 anyway,
    # so feed tf.log a harmless 1 instead of 0.
    safe_y = tf.where(tf.equal(x, 0.), tf.ones_like(y), y)
    return -tf.reduce_sum(x * tf.log(safe_y), axis)

def entropy(x, axis=-1):
    return cross_entropy(x, x, axis)
```
But did it work?
```python
x = tf.constant([0.1, 0.2, 0., 0.7])
e = entropy(x)
# ==> 0.80181855
g = tf.gradients(e, x)[0]
# ==> array([ 1.30258512,  0.60943794,  0.        , -0.64332503], dtype=float32)
```

Yay! No NaN.
General recipe
Use an inner tf.where to ensure the function has no asymptote. That is, alter the input to the inf-generating function so that no inf can be created. Then use a second tf.where to always select the valid code path. That is, implement the mathematical condition as you "normally" would, i.e., the "naive" implementation.
In Python code, the recipe is:
Instead of this:
```python
tf.where(x_ok, f(x), safe_f(x))
```
do this:
```python
# safe_value is any input f can handle, e.g. tf.ones_like(x).
safe_x = tf.where(x_ok, x, safe_value)
tf.where(x_ok, f(safe_x), safe_f(x))
```
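If you apply the pattern in several places, it can be packaged as a small helper. This wrapper (and its name) is my own sketch, not part of the recipe above:

```python
def double_where(x_ok, f, safe_f, x, safe_value):
    # Inner where: feed f only inputs it can handle, so no inf is ever created.
    safe_x = tf.where(x_ok, x, safe_value)
    # Outer where: select the mathematically correct branch.
    return tf.where(x_ok, f(safe_x), safe_f(x))
```

The 1/x example below would then read: double_where(tf.not_equal(x, 0.), lambda x: 1. / x, tf.zeros_like, x, tf.ones_like(x)).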
Example
Suppose you want to calculate:
```
f(x) = { 1/x, x != 0
       { 0,   x = 0
```
A naive implementation results in NaNs in the gradient:
```python
def f(x):
    x_ok = tf.not_equal(x, 0.)
    f = lambda x: 1. / x
    safe_f = tf.zeros_like
    return tf.where(x_ok, f(x), safe_f(x))
```
Does it work?
```python
x = tf.constant([-1., 0, 1])
tf.gradients(f(x), x)[0].eval()
# ==> array([ -1.,  nan,  -1.], dtype=float32)
# ...a NaN at the asymptote, even though the forward pass is fine.
```
The basic pattern for avoiding NaN gradients when using tf.where is to call tf.where twice. The innermost tf.where ensures that the result f(x) is always finite. The outermost tf.where ensures that the correct result is chosen. For the running example, the trick plays out like this:
```python
def safe_f(x):
    x_ok = tf.not_equal(x, 0.)
    f = lambda x: 1. / x
    safe_f = tf.zeros_like
    # Replace the bad inputs with a value f can handle anywhere.
    safe_x = tf.where(x_ok, x, tf.ones_like(x))
    return tf.where(x_ok, f(safe_x), safe_f(x))
```
But did it work?
```python
x = tf.constant([-1., 0, 1])
tf.gradients(safe_f(x), x)[0].eval()
# ==> array([-1.,  0., -1.], dtype=float32)
# ...no NaN, and the gradient at x = 0 is 0, matching the definition of f.
```
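The recipe also applies when the asymptote is only in the derivative. For instance, tf.sqrt is finite at 0 but its gradient 1/(2*sqrt(x)) is infinite there, so a single tf.where still produces 0 * inf = NaN in the backward pass. Here is a sketch of my own, following the recipe above:

```python
def safe_sqrt(x):
    # f(x) = sqrt(x) for x > 0, and 0 otherwise.
    x_ok = tf.greater(x, 0.)
    # Inner where: sqrt never sees a value whose gradient is infinite.
    safe_x = tf.where(x_ok, x, tf.ones_like(x))
    # Outer where: return 0 where x <= 0.
    return tf.where(x_ok, tf.sqrt(safe_x), tf.zeros_like(x))
```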