The get_new_centers() procedure below uses scoped variables center/sums and center/cts and applies scatter-add updates to them. These variables are then used to calculate and return the class centers from the updated values.
The loop simply executes get_new_centers() repeatedly and shows that over time the centers converge to the expected average embedding for each class.
Please note that the alpha term used in the original paper is not included here, but it should be easy to add if necessary.
ndims = 2
nclass = 4
nbatch = 100

# Running per-class accumulators: sum of embeddings and count of samples.
# Non-trainable: they are updated via scatter_add, not by an optimizer.
with tf.variable_scope('center'):
    center_sums = tf.get_variable("sums", [nclass, ndims], dtype=tf.float32,
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    center_cts = tf.get_variable("cts", [nclass], dtype=tf.float32,
                                 initializer=tf.constant_initializer(0),
                                 trainable=False)


def get_new_centers(embeddings, indices):
    '''
    Update the per-class sum/count accumulators with `embeddings` (when
    given) and return the average embedding (center) for each entry of
    `indices`, duplicates included.

    Pass embeddings=None to read the current centers without updating.
    '''
    with tf.variable_scope('center', reuse=True):
        center_sums = tf.get_variable("sums")
        center_cts = tf.get_variable("cts")

    # Accumulate embedding sums and per-class sample counts.
    if embeddings is not None:
        ones = tf.ones_like(indices, tf.float32)
        center_sums = tf.scatter_add(center_sums, indices, embeddings, name='sa1')
        center_cts = tf.scatter_add(center_cts, indices, ones, name='sa2')

    # Updated centers = sum / count, gathered for the requested indices.
    num = tf.gather(center_sums, indices)
    denom = tf.reshape(tf.gather(center_cts, indices), [-1, 1])
    return tf.div(num, denom)


with tf.Session() as sess:
    labels_ph = tf.placeholder(tf.int32)
    embeddings_ph = tf.placeholder(tf.float32)

    # unique + gather reconstructs the label sequence as int32 indices.
    unq_labels, ul_idxs = tf.unique(labels_ph)
    indices = tf.gather(unq_labels, ul_idxs)

    # Two graph paths: one that updates the accumulators, one read-only.
    new_centers_with_update = get_new_centers(embeddings_ph, indices)
    new_centers = get_new_centers(None, indices)

    # Fix: tf.initialize_all_variables() is deprecated; use the
    # replacement tf.global_variables_initializer().
    sess.run(tf.global_variables_initializer())
    tf.get_default_graph().finalize()

    for i in range(100001):
        embeddings = 100 * np.random.randn(nbatch, ndims)
        labels = np.random.randint(0, nclass, nbatch)
        feed_dict = {embeddings_ph: embeddings, labels_ph: labels}
        rval = sess.run([new_centers_with_update], feed_dict)

        if i % 1000 == 0:
            # Read-only query of every class center (no accumulator update).
            # Fix: feed a concrete list rather than a Python 3 range object.
            feed_dict = {labels_ph: list(range(nclass))}
            rval = sess.run(new_centers, feed_dict)
            print('\nFor step ', i)
            for iclass in range(nclass):
                print('Class %d, center: %s' % (iclass, str(rval[iclass])))
Typical result in step 0:
For step 0 Class 0, center: [-1.7618252 -0.30574229] Class 1, center: [ -4.50493908 10.12403965] Class 2, center: [ 3.6156714 -9.94263649] Class 3, center: [-4.20281982 -8.28845882]
and the output at step 10000 shows convergence:
For step 10000 Class 0, center: [ 0.00313433 -0.00757505] Class 1, center: [-0.03476512 0.04682625] Class 2, center: [-0.03865958 0.06585111] Class 3, center: [-0.02502561 -0.03370816]