Resuming a Training Process With Keras

Handling nontrivial cases where custom callbacks are used

Eyal Zakkay · Jan 2

In this post I will present a use case of the Keras API in which resuming a training process from a loaded checkpoint needs to be handled differently than usual.

TL;DR: If you are using custom callbacks that have internal variables which change during training, you need to account for this when resuming by initializing those callbacks differently.

When training a deep learning model with Keras, we usually save checkpoints of the model's state so that we can recover an interrupted training process and restart it from where we left off. Usually this is done with the ModelCheckpoint callback. According to the Keras documentation, a saved model (saved with model.save(filepath)) contains:

- The architecture of the model, allowing to re-create the model
- The weights of the model
- The training configuration (loss, optimizer)
- The state of the optimizer, allowing to resume training exactly where you left off

In certain use cases, this last part isn't exactly true.

Example

Let's say you are training a model with a custom learning rate scheduler callback which updates the learning rate after each batch; a minimal sketch of such a callback is given in the code examples at the end of this post. Its counter variable is initialized to zero when the callback is created and keeps track of the global batch index (the batch argument passed to on_batch_end only holds the batch index within the current epoch).

Now suppose we want to resume a training process from a checkpoint. The usual way would be to reload the model and recreate the callbacks from scratch, as in the "wrong way" sketch below.

The wrong way to do it

Notice that the LearningRateSchedulerPerBatch callback is initialized with counter=0 even when resuming. When training resumes, this does not recreate the conditions that were in place when the checkpoint was saved: the learning rate restarts from its initial value, which is probably not what we want.

The correct way to do it

We saw an example of how a wrong initialization of callbacks can lead to unwanted results when resuming. There are several ways to approach this; here I will describe two.

Solution 1: Updating the variables with correct values

When dealing with simple cases, in which the callback has only a handful of updating variables, it is rather simple to overwrite the values of these variables before resuming. In our example, if we want to resume with the correct value of counter, we can overwrite it before calling fit, as shown in the Solution 1 sketch below.

Solution 2: Saving and loading callbacks with pickle

When our custom callbacks have many updating variables or include complex behaviors, safely overwriting each variable might be difficult. An alternative solution is to pickle the callback instance every time we save a checkpoint; when resuming, we load this pickle and reconstruct the original callback with all of its correct values.

Note: in order to pickle your callbacks, they must not hold any non-serializable elements. Also, in Keras versions earlier than 2.2.3 the model itself was not serializable, which prevented pickling any callback, since each callback holds a reference to the model.

In this case, resuming will look something like the Solution 2 sketch below. You might notice that it uses a modified checkpoint callback called ModelCheckpointEnhanced. This is because the pickling method requires the ModelCheckpoint callback to also save pickles of the relevant callbacks.
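The per-batch scheduler from the example could look something like the sketch below. The class name matches the one mentioned in the post, but the schedule function, the counter attribute name, and the implementation details are my own illustrative choices, not the post's original code.

```python
import keras.backend as K
from keras.callbacks import Callback

class LearningRateSchedulerPerBatch(Callback):
    """Update the learning rate after every batch, based on a global batch counter."""

    def __init__(self, schedule, verbose=0):
        super(LearningRateSchedulerPerBatch, self).__init__()
        self.schedule = schedule  # function mapping global batch index -> learning rate
        self.verbose = verbose
        self.counter = 0          # global batch index; `batch` below resets every epoch

    def on_batch_end(self, batch, logs=None):
        self.counter += 1
        lr = self.schedule(self.counter)
        K.set_value(self.model.optimizer.lr, lr)
        if self.verbose:
            print('Batch %d: setting learning rate to %f' % (self.counter, lr))
```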
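The wrong way to resume, sketched under the same assumptions: the model is reloaded from the checkpoint, but the callbacks are recreated from scratch. File names, the schedule function, and the epoch/data variables are placeholders.

```python
from keras.models import load_model
from keras.callbacks import ModelCheckpoint

model = load_model('model_checkpoint.h5')  # restores weights, architecture and optimizer state

callbacks = [
    LearningRateSchedulerPerBatch(schedule),   # counter starts back at 0!
    ModelCheckpoint('model_checkpoint.h5'),
]

model.fit(x_train, y_train,
          epochs=total_epochs,
          initial_epoch=resume_epoch,  # continue the epoch numbering from the checkpoint
          callbacks=callbacks)
```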
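Solution 1 in code: recreate the callback, then overwrite its internal counter with the value it had when the checkpoint was saved. Here `batches_per_epoch` and `resume_epoch` are placeholders you would compute from your own training setup.

```python
model = load_model('model_checkpoint.h5')

lr_scheduler = LearningRateSchedulerPerBatch(schedule)
lr_scheduler.counter = resume_epoch * batches_per_epoch  # restore the global batch index

model.fit(x_train, y_train,
          epochs=total_epochs,
          initial_epoch=resume_epoch,
          callbacks=[lr_scheduler, ModelCheckpoint('model_checkpoint.h5')])
```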
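One possible sketch of the ModelCheckpointEnhanced callback mentioned above: a ModelCheckpoint that also pickles the given callbacks whenever it runs. The extra argument names (callbacks_to_save, callbacks_filepath) are my own assumptions; the post's original implementation may differ.

```python
import pickle
from keras.callbacks import ModelCheckpoint

class ModelCheckpointEnhanced(ModelCheckpoint):
    """ModelCheckpoint that additionally pickles a list of callbacks at each save."""

    def __init__(self, *args, callbacks_to_save=None, callbacks_filepath=None, **kwargs):
        super(ModelCheckpointEnhanced, self).__init__(*args, **kwargs)
        self.callbacks_to_save = callbacks_to_save or []
        self.callbacks_filepath = callbacks_filepath

    def on_epoch_end(self, epoch, logs=None):
        # Let ModelCheckpoint save the model as usual, then pickle the callbacks
        # alongside it (requires Keras >= 2.2.3, see the note above).
        super(ModelCheckpointEnhanced, self).on_epoch_end(epoch, logs)
        with open(self.callbacks_filepath, 'wb') as f:
            pickle.dump(self.callbacks_to_save, f)
```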
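Solution 2 in code, assuming the ModelCheckpointEnhanced sketch above: when resuming, load the pickled callbacks instead of creating new ones, so all their internal variables come back with their last values. File names are placeholders and must match the ones used when checkpointing.

```python
import pickle
from keras.models import load_model

model = load_model('model_checkpoint.h5')

with open('callbacks_checkpoint.pkl', 'rb') as f:
    lr_scheduler = pickle.load(f)[0]  # the scheduler, with its counter intact

checkpoint = ModelCheckpointEnhanced('model_checkpoint.h5',
                                     callbacks_to_save=[lr_scheduler],
                                     callbacks_filepath='callbacks_checkpoint.pkl')

model.fit(x_train, y_train,
          epochs=total_epochs,
          initial_epoch=resume_epoch,
          callbacks=[lr_scheduler, checkpoint])
```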
