Keras callbacks overview
All callbacks subclass the keras.callbacks.Callback class.
A list of callbacks can be passed (via the callbacks keyword argument) to the following methods:
- keras.Model.fit()
- keras.Model.evaluate()
- keras.Model.predict()
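A minimal sketch of subclassing keras.callbacks.Callback and passing the same callback list to all three methods, assuming TensorFlow 2.x with its bundled Keras. The ModeRecorder class, the layer sizes, and the random data are hypothetical and only serve to trigger the hooks.

```python
import numpy as np
from tensorflow import keras


class ModeRecorder(keras.callbacks.Callback):
    """Hypothetical callback that records which global *_begin hook fired."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_test_begin(self, logs=None):
        self.events.append("test_begin")

    def on_predict_begin(self, logs=None):
        self.events.append("predict_begin")


# Tiny model and random data, only used to drive the hooks.
model = keras.Sequential([keras.Input(shape=(784,)), keras.layers.Dense(1)])
model.compile(optimizer="rmsprop", loss="mean_squared_error")
x = np.random.random((32, 784)).astype("float32")
y = np.random.random((32, 1)).astype("float32")

recorder = ModeRecorder()
# The same callbacks= keyword is accepted by fit, evaluate, and predict.
model.fit(x, y, batch_size=16, epochs=1, verbose=0, callbacks=[recorder])
model.evaluate(x, y, batch_size=16, verbose=0, callbacks=[recorder])
model.predict(x, batch_size=16, verbose=0, callbacks=[recorder])

print(recorder.events)
```

Since fit() is called without validation data here, only the training hooks fire during fit(); the test hooks fire in evaluate().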
An overview of callback methods
Global methods
on_(train|test|predict)_begin(self, logs=None)
Called at the beginning of fit/evaluate/predict.
on_(train|test|predict)_end(self, logs=None)
Called at the end of fit/evaluate/predict.
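Because each *_begin hook has a matching *_end hook, the pair can bracket an entire run. The sketch below times a whole fit() call with on_train_begin and on_train_end; TrainingTimer, the model, and the random data are hypothetical.

```python
import time

import numpy as np
from tensorflow import keras


class TrainingTimer(keras.callbacks.Callback):
    """Hypothetical callback: wall-clock duration of one fit() call."""

    def on_train_begin(self, logs=None):
        # Runs once, before the first epoch.
        self.start = time.perf_counter()
        self.elapsed = None

    def on_train_end(self, logs=None):
        # Runs once, after the last epoch.
        self.elapsed = time.perf_counter() - self.start
        print(f"fit() took {self.elapsed:.3f} s")


model = keras.Sequential([keras.Input(shape=(784,)), keras.layers.Dense(1)])
model.compile(optimizer="rmsprop", loss="mean_squared_error")
x = np.random.random((32, 784)).astype("float32")
y = np.random.random((32, 1)).astype("float32")

timer = TrainingTimer()
model.fit(x, y, batch_size=16, epochs=2, verbose=0, callbacks=[timer])
```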
Batch-level methods for training/testing/predicting
on_(train|test|predict)_batch_begin(self, batch, logs=None)
Called right before processing a batch during training/testing/predicting.
on_(train|test|predict)_batch_end(self, batch, logs=None)
Called at the end of training/testing/predicting a batch. Within this method, logs is a dict containing the metrics results.
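A sketch of reading logs from a batch-level hook: the hypothetical RunningLossLogger collects logs["loss"] after every training batch. With 32 samples and batch_size=16 there are 2 batches per epoch, so 2 epochs give 4 hook calls. In recent Keras versions the batch-level loss is a running average over the current epoch rather than the loss of that single batch; treat that detail as version-dependent.

```python
import numpy as np
from tensorflow import keras


class RunningLossLogger(keras.callbacks.Callback):
    """Hypothetical callback: collects logs["loss"] after every training batch."""

    def __init__(self):
        super().__init__()
        self.losses = []

    def on_train_batch_end(self, batch, logs=None):
        logs = logs or {}
        # float() normalizes numpy scalars or scalar tensors to a Python float.
        self.losses.append(float(logs["loss"]))


model = keras.Sequential([keras.Input(shape=(784,)), keras.layers.Dense(1)])
model.compile(optimizer="rmsprop", loss="mean_squared_error")
x = np.random.random((32, 784)).astype("float32")
y = np.random.random((32, 1)).astype("float32")

logger = RunningLossLogger()
# 32 samples / batch_size 16 = 2 batches per epoch; 2 epochs -> 4 hook calls.
model.fit(x, y, batch_size=16, epochs=2, verbose=0, callbacks=[logger])
print(logger.losses)
```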
Epoch-level methods (training only)
on_epoch_begin(self, epoch, logs=None)
Called at the beginning of an epoch during training.
on_epoch_end(self, epoch, logs=None)
Called at the end of an epoch during training.
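The epoch-level hooks fire only during fit(). In the sketch below (EpochTracker, the model, and the data are hypothetical), the epoch argument is 0-based, and a later evaluate() call leaves the recorded epochs unchanged.

```python
import numpy as np
from tensorflow import keras


class EpochTracker(keras.callbacks.Callback):
    """Hypothetical callback: records epoch indices and end-of-epoch logs."""

    def __init__(self):
        super().__init__()
        self.started = []
        self.finished = {}

    def on_epoch_begin(self, epoch, logs=None):
        self.started.append(epoch)  # epoch is 0-based

    def on_epoch_end(self, epoch, logs=None):
        # Copy the dict so the stored record is independent of Keras internals.
        self.finished[epoch] = dict(logs or {})


model = keras.Sequential([keras.Input(shape=(784,)), keras.layers.Dense(1)])
model.compile(optimizer="rmsprop", loss="mean_squared_error")
x = np.random.random((32, 784)).astype("float32")
y = np.random.random((32, 1)).astype("float32")

tracker = EpochTracker()
model.fit(x, y, batch_size=16, epochs=3, verbose=0, callbacks=[tracker])
print(tracker.started)  # [0, 1, 2]

# evaluate() has no epochs, so the epoch-level hooks are not called again.
model.evaluate(x, y, verbose=0, callbacks=[tracker])
print(tracker.started)
```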
A basic example
Let’s take a look at a concrete example. To get started, let’s import TensorFlow and define a simple Sequential Keras model:
def get_model():
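Only the signature of get_model() appears here. A minimal sketch of such a model follows, assuming TensorFlow 2.x; the single Dense unit on a 784-dimensional input, the RMSprop learning rate, and the loss and metric choices are assumptions for illustration.

```python
import numpy as np
from tensorflow import keras


def get_model():
    """Sketch of a simple Sequential regression model (architecture is assumed)."""
    model = keras.Sequential(
        [
            keras.Input(shape=(784,)),  # e.g. flattened 28x28 images
            keras.layers.Dense(1),  # single linear output unit
        ]
    )
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
        loss="mean_squared_error",
        metrics=["mean_absolute_error"],
    )
    return model


model = get_model()
print(model.count_params())  # 784 weights + 1 bias = 785
preds = model.predict(np.zeros((4, 784), dtype="float32"), verbose=0)
print(preds.shape)  # (4, 1)
```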
Define a custom callback that logs:
- When fit/evaluate/predict starts & ends
- When each epoch starts & ends
- When each training batch starts & ends
- When each evaluation (test) batch starts & ends
- When each inference (prediction) batch starts & ends
class CustomCallback(keras.callbacks.Callback):
model = get_model()
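A sketch of a CustomCallback that overrides every hook listed above, prints one line per call with the log keys it receives, and records the event names for later inspection. The _log helper, the events attribute, the model, and the random data are hypothetical; the exact set of log keys varies between Keras versions.

```python
import numpy as np
from tensorflow import keras


class CustomCallback(keras.callbacks.Callback):
    """Prints and records each hook Keras calls, plus the log keys it gets."""

    def __init__(self):
        super().__init__()
        self.events = []

    def _log(self, event, logs):
        keys = sorted((logs or {}).keys())
        self.events.append(event)
        print(f"{event}; log keys: {keys}")

    # Global methods
    def on_train_begin(self, logs=None): self._log("train_begin", logs)
    def on_train_end(self, logs=None): self._log("train_end", logs)
    def on_test_begin(self, logs=None): self._log("test_begin", logs)
    def on_test_end(self, logs=None): self._log("test_end", logs)
    def on_predict_begin(self, logs=None): self._log("predict_begin", logs)
    def on_predict_end(self, logs=None): self._log("predict_end", logs)

    # Epoch-level methods (training only)
    def on_epoch_begin(self, epoch, logs=None): self._log(f"epoch_begin {epoch}", logs)
    def on_epoch_end(self, epoch, logs=None): self._log(f"epoch_end {epoch}", logs)

    # Batch-level methods
    def on_train_batch_begin(self, batch, logs=None): self._log(f"train_batch_begin {batch}", logs)
    def on_train_batch_end(self, batch, logs=None): self._log(f"train_batch_end {batch}", logs)
    def on_test_batch_begin(self, batch, logs=None): self._log(f"test_batch_begin {batch}", logs)
    def on_test_batch_end(self, batch, logs=None): self._log(f"test_batch_end {batch}", logs)
    def on_predict_batch_begin(self, batch, logs=None): self._log(f"predict_batch_begin {batch}", logs)
    def on_predict_batch_end(self, batch, logs=None): self._log(f"predict_batch_end {batch}", logs)


model = keras.Sequential([keras.Input(shape=(784,)), keras.layers.Dense(1)])
model.compile(optimizer="rmsprop", loss="mean_squared_error",
              metrics=["mean_absolute_error"])
x = np.random.random((32, 784)).astype("float32")
y = np.random.random((32, 1)).astype("float32")

callback = CustomCallback()
model.fit(x, y, batch_size=16, epochs=1, verbose=0, callbacks=[callback])
model.evaluate(x, y, batch_size=16, verbose=0, callbacks=[callback])
model.predict(x, batch_size=16, verbose=0, callbacks=[callback])
```

Because fit() is called without validation data, no test_* hooks fire during training; they fire only in the evaluate() call.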
Usage of logs dict
The logs dict contains the loss value and all the metrics at the end of each batch or epoch. Example:
if default_keys:
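A sketch of reading values out of logs: the batch hooks print the running loss, and on_epoch_end also reads the metric registered in compile(), whose key matches the string passed to metrics=. The class name, the epoch_mae attribute, the model, and the random data are hypothetical, assuming TensorFlow 2.x.

```python
import numpy as np
from tensorflow import keras


class LossAndErrorPrintingCallback(keras.callbacks.Callback):
    """Prints values read from the logs dict and keeps the per-epoch MAE."""

    def __init__(self):
        super().__init__()
        self.epoch_mae = []

    def on_train_batch_end(self, batch, logs=None):
        print(f"Up to batch {batch}, the average loss is {float(logs['loss']):7.2f}.")

    def on_test_batch_end(self, batch, logs=None):
        print(f"Up to batch {batch}, the average loss is {float(logs['loss']):7.2f}.")

    def on_epoch_end(self, epoch, logs=None):
        mae = float(logs["mean_absolute_error"])
        self.epoch_mae.append(mae)
        print(
            f"The average loss for epoch {epoch} is {float(logs['loss']):7.2f} "
            f"and mean absolute error is {mae:7.2f}."
        )


model = keras.Sequential([keras.Input(shape=(784,)), keras.layers.Dense(1)])
model.compile(
    optimizer="rmsprop",
    loss="mean_squared_error",
    metrics=["mean_absolute_error"],  # appears in logs under this key
)
x = np.random.random((32, 784)).astype("float32")
y = np.random.random((32, 1)).astype("float32")

printer = LossAndErrorPrintingCallback()
model.fit(x, y, batch_size=16, epochs=2, verbose=0, callbacks=[printer])
model.evaluate(x, y, batch_size=16, verbose=0, callbacks=[printer])
```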