I wrote a quick custom callback that tracks the best training values (and the epoch at which each was reached) so that I can reload them when resuming training. It looks like this:
class ModelState(callbacks.Callback):
    """Keras callback that persists training progress to a JSON state file.

    The state holds three entries:

    * ``epoch_count`` -- number of epochs completed so far, usable as
      ``initial_epoch`` when resuming training;
    * ``best_values`` -- best value seen so far for every key in ``logs``;
    * ``best_epoch``  -- epoch index at which each best value was seen.

    The file is rewritten at the end of every epoch so that a later run can
    reload it and continue where this one stopped.

    Args:
        state_path: Path of the JSON state file. If it already exists it is
            loaded; otherwise a fresh state is started.
        maximize: Metric name (or iterable of names) for which a larger value
            is better, e.g. ``'val_acc'``. Every other metric is treated as
            lower-is-better. The default (empty) minimizes all metrics.
    """

    def __init__(self, state_path, maximize=()):
        # Initialise the Keras base-class attributes before adding our own.
        super().__init__()
        self.state_path = state_path
        # Accept a single metric name as well as an iterable of names.
        if isinstance(maximize, str):
            maximize = (maximize,)
        self.maximize = frozenset(maximize)
        if os.path.isfile(state_path):
            print('Loading existing .json state')
            with open(state_path, 'r') as f:
                self.state = json.load(f)
        else:
            self.state = {'epoch_count': 0,
                          'best_values': {},
                          'best_epoch': {}}

    def on_train_begin(self, logs=None):
        """Announce the start of training (``logs`` is unused)."""
        print('Training commences...')

    def on_epoch_end(self, batch, logs=None):
        """Record improved metrics, advance the epoch counter and save state.

        Args:
            batch: Epoch index supplied by the caller. Unused; the persisted
                ``state['epoch_count']`` is the authoritative counter.
            logs: Mapping of metric name to value for the finished epoch.
                ``None`` is treated as an empty mapping.
        """
        epoch = self.state['epoch_count']
        for key, value in (logs or {}).items():
            value = float(value)  # JSON cannot serialise e.g. numpy scalars
            if self._is_improvement(key, value):
                self.state['best_values'][key] = value
                self.state['best_epoch'][key] = epoch
        print('Completed epoch', epoch)
        # Increment *before* saving so the stored count equals the number of
        # completed epochs; saving first made a resumed run (which uses
        # epoch_count as initial_epoch) repeat the last finished epoch.
        self.state['epoch_count'] = epoch + 1
        self._save_state()

    def _is_improvement(self, key, value):
        """Return True if ``value`` beats the stored best for metric ``key``."""
        best = self.state['best_values'].get(key)
        if best is None:
            return True
        if key in self.maximize:
            return value > best
        return value < best

    def _save_state(self):
        """Write the state to ``state_path`` via a temporary file.

        The new content is fully written to ``<state_path>.tmp`` first and then
        swapped in with ``os.replace``, so an interruption mid-write cannot
        leave a truncated JSON file that would break the next load.
        """
        tmp_path = os.fspath(self.state_path) + '.tmp'
        with open(tmp_path, 'w') as f:
            json.dump(self.state, f, indent=4)
        os.replace(tmp_path, self.state_path)
Then, in the code that calls `fit()`, do something like this:
# Metric used both by the checkpoint and when restoring its previous best.
monitor_key = 'val_loss'

model_state = ModelState(path_to_state_file)
model_checkpoint = callbacks.ModelCheckpoint(path_to_model_file,
                                             monitor=monitor_key,
                                             save_best_only=True,
                                             verbose=1,
                                             mode='min',
                                             save_weights_only=False)
# Seed the checkpoint with the best value from a previous run so a resumed
# run does not overwrite a better saved model with a worse one. ModelState
# always creates 'best_values', but it is empty on a fresh run, so test for
# the metric itself rather than the container (otherwise: KeyError).
best_values = model_state.state.get('best_values', {})
if monitor_key in best_values:
    model_checkpoint.best = best_values[monitor_key]
callback_list = [model_checkpoint,
                 model_state]
# Resume epoch numbering where the previous run stopped. With initial_epoch,
# Keras treats `epochs` as the index of the final epoch rather than a count,
# so shift it to train `epochs` additional epochs.
initial_epoch = model_state.state['epoch_count']
epochs += initial_epoch
source
share