Multithreading in TensorFlow / Keras

I would like to train several different models in parallel using model.fit() in the same Python application. The models do not necessarily have anything in common; they run in the same application at different times.

First, I run one model.fit() in a separate thread without problems, and then one in the main thread. If I now try to run a second model.fit(), I get the following error message:

Exception in thread Thread-1: tensorflow.python.framework.errors_impl.InvalidArgumentError: Node 'hidden_1/BiasAdd': Unknown input node 'hidden_1/MatMul' 

Both runs are started by a method along these lines:

    def start_learn(self):
        # this creates a new session since one doesn't exist already
        tf_session = K.get_session()
        tf_graph = tf.get_default_graph()
        # assumed: the Learn thread object is the one started on the next line
        learning_results = keras_learn_thread.Learn(learning_data, model, self.env_cont, tf_session, tf_graph)
        learning_results.start()

The called class / method is as follows:

    def run(self):
        tf_session = self.tf_session  # taken from __init__()
        tf_graph = self.tf_graph  # taken from __init__()
        with tf_session.as_default():
            with tf_graph.as_default():
                # my learn method, where model.fit() is located, is started here
                self.learn(self.learning_data, self.model, self.env_cont)

I think I somehow need to assign a new tf_session and a new tf_graph to each individual thread, but I'm not quite sure about that. I would appreciate any quick idea, as I have been stuck on this for too long.

thanks

1 answer

I don’t know if you have fixed your problem, but this looks similar to another question I recently answered.

  • You need to finish creating the graph in the main thread before starting the other threads.
  • In the case of Keras, the graph is initialized the first time the fit or predict function is called. You can force graph creation by calling some of the model's internal functions:

     model._make_predict_function()
     model._make_test_function()
     model._make_train_function()

    If this does not work, try warming up the model by calling it on dummy data.

  • Once you have finished creating the graph, call finalize() on your main graph. This makes it read-only, so it can be safely shared with other threads.

  • Finalizing the graph will also help you find other places where your graph is being inadvertently modified.
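
To make the order of the steps concrete, here is a minimal sketch in plain Python. It does not use TensorFlow or Keras: ToyGraph, build_graph, run_workers and try_late_modification are hypothetical names made up for this illustration, and the class only mimics the idea of building everything in the main thread, finalizing, and then sharing the result read-only with worker threads.

```python
import threading


class ToyGraph:
    """Hypothetical stand-in for a computation graph; not a TensorFlow class."""

    def __init__(self):
        self._nodes = {}
        self._finalized = False

    def add_node(self, name, fn):
        # Mutation is only allowed before finalize(), mimicking a read-only graph.
        if self._finalized:
            raise RuntimeError("graph is finalized; cannot add node %r" % name)
        self._nodes[name] = fn

    def finalize(self):
        self._finalized = True

    def run(self, name, value):
        # Read-only lookup; nothing is modified, so threads can call it concurrently.
        return self._nodes[name](value)


def build_graph():
    # Step 1: create every node up front, in the main thread.
    graph = ToyGraph()
    graph.add_node("double", lambda x: 2 * x)
    graph.add_node("square", lambda x: x * x)
    # Step 2: freeze the graph before any worker thread sees it.
    graph.finalize()
    return graph


def run_workers(graph, inputs):
    # Step 3: workers only read from the shared, finalized graph.
    results = {}
    lock = threading.Lock()

    def worker(value):
        out = graph.run("square", graph.run("double", value))
        with lock:
            results[value] = out

    threads = [threading.Thread(target=worker, args=(v,)) for v in inputs]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


def try_late_modification(graph):
    # A late add_node() is the kind of accidental change that finalizing exposes.
    try:
        graph.add_node("late", lambda x: x)
    except RuntimeError:
        return True
    return False
```

In the real setting, the counterpart of a late add_node() would be any call inside a worker thread that still tries to add operations to the shared graph. After finalizing, such a call should fail loudly instead of silently corrupting the graph.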

Hope this helps.


Source: https://habr.com/ru/post/1271423/

