make_template
Example adapted from the TensorFlow template tests (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/kernel_tests/template_test.py):
# Training pairs: targets lie close to (but not exactly on) y = 2x + 1.
training_input = [1., 2., 3., 4.]
training_output = [2.8, 5.1, 7.2, 8.7]

# Held-out pairs: targets lie exactly on y = 2x + 1 (kept as ints, unlike the
# float training targets).
test_input = [5., 6., 7., 8.]
test_output = [11, 13, 15, 17]
# Fix the graph-level random seed so the truncated-normal initial values of
# the model variables are reproducible from run to run.
# NOTE(review): `tf` is never imported in this file; assumes
# `import tensorflow as tf` with the TF 1.x API — confirm.
tf.set_random_seed(1234)
def test_line(x):
    """Return ``x * w + b`` for scalar variables ``w`` and ``b``.

    Both variables are created with ``tf.get_variable`` (scalar shape,
    truncated-normal initializer), so when this function is wrapped by
    ``make_template`` the same ``w`` and ``b`` are shared across calls.
    """
    weight = tf.get_variable(
        "w",
        shape=[],
        initializer=tf.truncated_normal_initializer(),
    )
    bias = tf.get_variable(
        "b",
        shape=[],
        initializer=tf.truncated_normal_initializer(),
    )
    return x * weight + bias
# Wrap the model in a template: the first call creates the variables under
# the "line" scope and later calls reuse them, so the train and test
# predictions share the same w and b.
# The bare name `template` (tensorflow.python.ops.template) is never imported
# here; `tf.make_template` is the public alias of the same function.
line_template = tf.make_template("line", test_line)
train_prediction = line_template(training_input)
test_prediction = line_template(test_input)

# Mean squared error of each prediction against its targets.
train_loss = tf.reduce_mean(tf.square(train_prediction - training_output))
test_loss = tf.reduce_mean(tf.square(test_prediction - test_output))

# Gradient descent (learning rate 0.1) minimizes the training loss only;
# the test loss is evaluated, never optimized.
optimizer = tf.train.GradientDescentOptimizer(0.1)
train_op = optimizer.minimize(train_loss)
with tf.Session() as sess:
    # global_variables_initializer is the current name for the deprecated
    # initialize_all_variables.
    sess.run(tf.global_variables_initializer())
    initial_test_loss = sess.run(test_loss)
    sess.run(train_op)  # a single training step on the training data
    final_test_loss = sess.run(test_loss)
    # No unittest.TestCase `self` exists at module level, so the original
    # self.assertLess raised NameError. This explicit check fails under the
    # same condition (including NaN) and raises the same AssertionError type.
    if not final_test_loss < initial_test_loss:
        raise AssertionError(
            "test loss did not decrease: %r >= %r"
            % (final_test_loss, initial_test_loss))