This issue has been resolved in #14451. I'm just posting the answer here to make it more visible to other developers.
The code example below oversamples low-frequency classes and undersamples frequent ones, where class_target_prob is just a uniform distribution in my case. I wanted to check some conclusions from a recent manuscript, A systematic study of the class imbalance problem in convolutional neural networks.
Oversampling of certain classes is done by calling:
dataset = dataset.flat_map(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(oversample_classes(x))
)
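The repeat count returned by oversample_classes is a stochastic rounding of the probability ratio: the integer part is repeated deterministically, and one extra copy is added with probability equal to the fractional remainder, so the expected number of copies equals the ratio. Here is a minimal NumPy sketch of that rounding idea outside TensorFlow (stochastic_repeat_count is an illustrative name, not part of the snippet below):

```python
import numpy as np

def stochastic_repeat_count(prob_ratio, rng):
    """Floor the ratio, then add one extra copy with probability
    equal to the fractional remainder, so E[count] == prob_ratio."""
    prob_ratio = max(prob_ratio, 1.0)  # oversampling never drops an example
    base = np.floor(prob_ratio)
    residual = prob_ratio - base
    return int(base) + int(rng.uniform() <= residual)

rng = np.random.default_rng(0)
counts = [stochastic_repeat_count(2.3, rng) for _ in range(100_000)]
# Each count is either 2 or 3; the empirical mean should be close to 2.3.
print(np.mean(counts))
```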
Here is the complete snippet that does everything:
oversampling_coef = 0.9
undersampling_coef = 0.5
def oversample_classes(example):
    """
    Returns the number of copies of given example
    """
    class_prob = example['class_prob']
    class_target_prob = example['class_target_prob']
    prob_ratio = tf.cast(class_target_prob/class_prob, dtype=tf.float32)
    # soften/widen oversampling probability
    prob_ratio = prob_ratio ** oversampling_coef
    # for classes with probability higher than class_target_prob we
    # want to return 1
    prob_ratio = tf.maximum(prob_ratio, 1)
    # for low-probability classes this number will be very large
    repeat_count = tf.floor(prob_ratio)
    # prob_ratio can be, e.g., 1.9 which means that there is a 90%
    # probability that we should return 2 instead of 1
    repeat_residual = prob_ratio - repeat_count  # a number between 0 and 1
    residual_acceptance = tf.less_equal(
        tf.random_uniform([], dtype=tf.float32), repeat_residual
    )
    residual_acceptance = tf.cast(residual_acceptance, tf.int64)
    repeat_count = tf.cast(repeat_count, dtype=tf.int64)
    return repeat_count + residual_acceptance
def undersampling_filter(example):
    """
    Computes if given example is rejected or not.
    """
    class_prob = example['class_prob']
    class_target_prob = example['class_target_prob']
    prob_ratio = tf.cast(class_target_prob/class_prob, dtype=tf.float32)
    prob_ratio = prob_ratio ** undersampling_coef
    prob_ratio = tf.minimum(prob_ratio, 1.0)
    acceptance = tf.less_equal(tf.random_uniform([], dtype=tf.float32), prob_ratio)
    return acceptance
dataset = dataset.flat_map(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(oversample_classes(x))
)
dataset = dataset.filter(undersampling_filter)
dataset = dataset.repeat(-1)
dataset = dataset.shuffle(2048)
dataset = dataset.batch(32)
sess.run(tf.global_variables_initializer())
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
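To sanity-check the combined pipeline without running TensorFlow, the same over/undersampling logic can be simulated in plain NumPy. With both coefficients set to 1.0, the resampled class frequencies should land near the uniform target. The class probabilities below are made up for illustration, and resample is a hypothetical helper mirroring the flat_map + filter stages above:

```python
import numpy as np

rng = np.random.default_rng(42)
class_probs = np.array([0.7, 0.2, 0.1])   # imbalanced source distribution
target_probs = np.array([1/3, 1/3, 1/3])  # uniform target distribution

def resample(labels, coef_over=1.0, coef_under=1.0):
    out = []
    for y in labels:
        ratio = target_probs[y] / class_probs[y]
        # Oversampling stage: stochastically repeat rare-class examples.
        r = max(ratio ** coef_over, 1.0)
        count = int(np.floor(r)) + int(rng.uniform() <= r - np.floor(r))
        # Undersampling stage: randomly reject frequent-class copies.
        keep_p = min(ratio ** coef_under, 1.0)
        for _ in range(count):
            if rng.uniform() <= keep_p:
                out.append(y)
    return np.array(out)

labels = rng.choice(3, size=200_000, p=class_probs)
freqs = np.bincount(resample(labels), minlength=3) / len(resample(labels))
# Each entry should be close to the uniform target of 1/3.
print(freqs)
```

With coefficients below 1.0 the resampling is softened: the resulting distribution lands between the original and the target, which is the trade-off the oversampling_coef and undersampling_coef knobs control.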
Update #1
Here is a simple jupyter notebook that implements the above oversampling/undersampling on a toy problem.