Zheng Chu's Blog


Tensorflow2-CallBack

Posted on 2020-10-01 Edited on 2021-03-17 In Tensorflow

Keras callbacks overview

All callbacks subclass the keras.callbacks.Callback class.

A list of callbacks can be passed, via the keyword argument callbacks, to the following model methods:

  • keras.Model.fit()
  • keras.Model.evaluate()
  • keras.Model.predict()
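As a minimal sketch of passing a list of callbacks, the snippet below uses two built-in callbacks, keras.callbacks.EarlyStopping and keras.callbacks.TerminateOnNaN. The toy data (64 random samples with 4 features, target equal to the row sum), the model and the hyperparameters are assumptions chosen only for illustration.

```python
import numpy as np
from tensorflow import keras

# Hypothetical toy data: 64 samples, 4 features, target = row sum.
rng = np.random.default_rng(0)
x = rng.random((64, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# A list of built-in callbacks goes in through the `callbacks` keyword.
callbacks = [
    keras.callbacks.EarlyStopping(monitor="loss", patience=2),  # stop if training loss stalls
    keras.callbacks.TerminateOnNaN(),  # abort training if the loss becomes NaN
]
history = model.fit(x, y, batch_size=16, epochs=3, verbose=0, callbacks=callbacks)

# evaluate() and predict() accept the same keyword; these two particular
# callbacks only act during training, so here they are effectively no-ops.
loss = model.evaluate(x, y, batch_size=16, verbose=0, callbacks=callbacks)
preds = model.predict(x, batch_size=16, verbose=0, callbacks=callbacks)
```

Custom callbacks (subclasses of keras.callbacks.Callback) are passed in exactly the same way, mixed freely with the built-in ones.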

An overview of callback methods

Global methods

on_(train|test|predict)_begin(self, logs=None)

Called at the beginning of fit/evaluate/predict.

on_(train|test|predict)_end(self, logs=None)

Called at the end of fit/evaluate/predict.
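The global hooks suit one-off setup and teardown work. A minimal sketch, assuming a hypothetical PhaseTimer class and hypothetical toy data, that records which phase started and how long each fit/evaluate/predict call took, using only the *_begin/*_end hooks:

```python
import time

import numpy as np
from tensorflow import keras


class PhaseTimer(keras.callbacks.Callback):
    """Hypothetical helper: times whole fit/evaluate/predict calls via the global hooks."""

    def __init__(self):
        super().__init__()
        self.order = []      # phase names in the order they started
        self.durations = {}  # phase name -> elapsed seconds
        self._starts = {}

    def _begin(self, phase):
        self.order.append(phase)
        self._starts[phase] = time.perf_counter()

    def _end(self, phase):
        self.durations[phase] = time.perf_counter() - self._starts.pop(phase)

    def on_train_begin(self, logs=None):
        self._begin("train")

    def on_train_end(self, logs=None):
        self._end("train")

    def on_test_begin(self, logs=None):
        self._begin("test")

    def on_test_end(self, logs=None):
        self._end("test")

    def on_predict_begin(self, logs=None):
        self._begin("predict")

    def on_predict_end(self, logs=None):
        self._end("predict")


# Hypothetical toy data and model.
rng = np.random.default_rng(0)
x = rng.random((64, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

timer = PhaseTimer()
model.fit(x, y, batch_size=16, epochs=2, verbose=0, callbacks=[timer])  # no validation
model.evaluate(x, y, batch_size=16, verbose=0, callbacks=[timer])
model.predict(x, batch_size=16, verbose=0, callbacks=[timer])
```

Because fit() is called without validation data here, no test hooks fire during training, so each phase appears exactly once.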

Batch-level methods for training/testing/predicting

on_(train|test|predict)_batch_begin(self, batch, logs=None)

Called right before processing a batch during training/testing/predicting.

on_(train|test|predict)_batch_end(self, batch, logs=None)

Called at the end of training/testing/predicting a batch. Within this method, logs is a dict containing the metrics results.
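A minimal sketch of a batch-level hook, assuming a hypothetical BatchLossRecorder class and hypothetical toy data: it stores the batch index and the logged loss after every training batch. As far as I know, in recent TF 2.x releases the loss reported at batch end is the running mean over the current epoch, not the loss of that single batch, and the batch index restarts at 0 every epoch.

```python
import numpy as np
from tensorflow import keras


class BatchLossRecorder(keras.callbacks.Callback):
    """Hypothetical helper: records (batch index, logged loss) after each training batch."""

    def __init__(self):
        super().__init__()
        self.batch_indices = []
        self.losses = []

    def on_train_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.batch_indices.append(batch)
        self.losses.append(float(logs["loss"]))


# Hypothetical toy data: 64 samples / batch_size 16 = 4 batches per epoch.
rng = np.random.default_rng(0)
x = rng.random((64, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

recorder = BatchLossRecorder()
model.fit(x, y, batch_size=16, epochs=2, verbose=0, callbacks=[recorder])
```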

Epoch-level methods (training only)

on_epoch_begin(self, epoch, logs=None)

Called at the beginning of an epoch during training.

on_epoch_end(self, epoch, logs=None)

Called at the end of an epoch during training.
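A minimal sketch of the epoch-level hooks, assuming a hypothetical EpochLogRecorder class and hypothetical toy data. Epoch numbers are 0-based, and when fit() runs validation the logs passed to on_epoch_end also carry the validation metrics with a val_ prefix:

```python
import numpy as np
from tensorflow import keras


class EpochLogRecorder(keras.callbacks.Callback):
    """Hypothetical helper: remembers which epochs started and a copy of each end-of-epoch logs dict."""

    def __init__(self):
        super().__init__()
        self.started = []
        self.history = []  # list of (epoch, logs) pairs

    def on_epoch_begin(self, epoch, logs=None):
        self.started.append(epoch)

    def on_epoch_end(self, epoch, logs=None):
        # Copy the dict so later mutation by Keras cannot affect what we stored.
        self.history.append((epoch, dict(logs or {})))


# Hypothetical toy data and model.
rng = np.random.default_rng(0)
x = rng.random((64, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

recorder = EpochLogRecorder()
model.fit(x, y, batch_size=16, epochs=3, verbose=0,
          validation_split=0.25, callbacks=[recorder])
```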

A basic example

Let’s take a look at a concrete example. To get started, let’s import tensorflow and define a simple Sequential Keras model:

import tensorflow as tf
from tensorflow import keras


# Define the Keras model to add callbacks to
def get_model():
    model = keras.Sequential()
    model.add(keras.layers.Dense(1, input_dim=784))
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
        loss="mean_squared_error",
        metrics=["mean_absolute_error"],
    )
    return model


# Load example MNIST data and pre-process it
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0

# Limit the data to 1000 samples
x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:1000]
y_test = y_test[:1000]

Define a simple custom callback that logs:

  • When fit/evaluate/predict starts & ends
  • When each epoch starts & ends
  • When each training batch starts & ends
  • When each evaluation (test) batch starts & ends
  • When each inference (prediction) batch starts & ends
class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))
model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=1,
    verbose=0,
    validation_split=0.5,
    callbacks=[CustomCallback()],
)

res = model.evaluate(
    x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()]
)

res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])
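To check the call order programmatically instead of reading printed output, a variant can append hook names to a list. This is a minimal sketch assuming a hypothetical EventRecorder class and hypothetical toy data; it only records the non-batch hooks to keep the list short. With validation_split, validation runs inside the epoch, so the test hooks fire before on_epoch_end (which is why the end-of-epoch logs can contain val_ metrics).

```python
import numpy as np
from tensorflow import keras


class EventRecorder(keras.callbacks.Callback):
    """Hypothetical helper: records the order in which the non-batch hooks fire."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_train_end(self, logs=None):
        self.events.append("train_end")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append("epoch_begin")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append("epoch_end")

    def on_test_begin(self, logs=None):
        self.events.append("test_begin")

    def on_test_end(self, logs=None):
        self.events.append("test_end")


# Hypothetical toy data and model.
rng = np.random.default_rng(0)
x = rng.random((64, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

recorder = EventRecorder()
model.fit(x, y, batch_size=16, epochs=1, verbose=0,
          validation_split=0.5, callbacks=[recorder])
```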

Usage of logs dict

The logs dict contains the loss value and all the metrics at the end of each batch or epoch. Example:

if default_keys:
    assert min(default_keys) <= len(
        args), "Not enough arguments (%s, %s, %s)" % (args, default_keys,
                                                       self.arg_names)

AssertionError: Not enough arguments ((), [2], ['true', 'pred', 'num_classes'])
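A minimal sketch of reading values out of the logs dict, assuming a hypothetical LossAndMetricsPrinter class and hypothetical toy data. It prints the loss after each training batch and every key/value pair in the end-of-epoch logs, so it does not depend on the exact metric key names a given Keras version produces:

```python
import numpy as np
from tensorflow import keras


class LossAndMetricsPrinter(keras.callbacks.Callback):
    """Hypothetical helper: prints the batch loss and all end-of-epoch metrics from `logs`."""

    def __init__(self):
        super().__init__()
        self.epoch_logs = []

    def on_train_batch_end(self, batch, logs=None):
        logs = logs or {}
        print("Batch {}: loss = {:.4f}".format(batch, float(logs["loss"])))

    def on_epoch_end(self, epoch, logs=None):
        logs = dict(logs or {})
        self.epoch_logs.append(logs)
        summary = ", ".join(
            "{} = {:.4f}".format(k, float(v)) for k, v in sorted(logs.items())
        )
        print("Epoch {}: {}".format(epoch, summary))


# Hypothetical toy data; compiled with a metric so logs hold more than the loss.
rng = np.random.default_rng(0)
x = rng.random((64, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse", metrics=["mean_absolute_error"])

printer = LossAndMetricsPrinter()
model.fit(x, y, batch_size=16, epochs=2, verbose=0, callbacks=[printer])
```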
© 2021 Zheng Chu