CheckpointManager

A quick note on how to combine CheckpointManager with a Keras Callback to periodically save the N most recent checkpoints during training.

import tensorflow as tf

N = 5  # number of most recent checkpoints to keep


# Build the model
class Model(tf.keras.Model):
    def __init__(self, **kwargs):
        super(Model, self).__init__(**kwargs)
        self.d = tf.keras.layers.Dense(1, kernel_initializer=tf.keras.initializers.ones())

    @tf.function
    def call(self, x, training=True, mask=None):
        return self.d(x)


# Define the callback: save a checkpoint at the end of every training batch
class Save_Callbacks(tf.keras.callbacks.Callback):
    def __init__(self, checkpoint_manager):
        super().__init__()
        self.checkpoint_manager = checkpoint_manager

    def on_train_batch_end(self, batch, logs=None):
        super().on_train_batch_end(batch, logs)
        # CheckpointManager.save() writes a new checkpoint and automatically
        # deletes the oldest ones, keeping at most `max_to_keep` of them
        self.checkpoint_manager.save()


model = Model()
model.compile(loss=tf.keras.losses.binary_crossentropy,
              optimizer='SGD')
checkpoint = tf.train.Checkpoint(model=model, optimizer=model.optimizer)
checkpoint_manager = tf.train.CheckpointManager(checkpoint, 'save', max_to_keep=N)

model.fit(x=tf.ones((100, 3)), y=tf.ones((100, 1)), batch_size=2,
          callbacks=[Save_Callbacks(checkpoint_manager)])
model.reset_metrics()

# Restore the model from the most recently saved checkpoint
ckpt = tf.train.Checkpoint(model=model)
ckpt.restore(tf.train.latest_checkpoint('save'))
print(model(tf.ones((2, 3))))
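
The callback above saves after every batch. If you want a coarser interval, the same CheckpointManager can simply be driven less often. Below is a minimal sketch of that idea; `save_freq` is a made-up parameter for illustration, not something from the code above.

# A variant sketch (assumption: `save_freq` is a hypothetical parameter, not
# part of the original example) that saves every `save_freq` batches instead
# of after every single batch.
class Periodic_Save_Callback(tf.keras.callbacks.Callback):
    def __init__(self, checkpoint_manager, save_freq=10):
        super().__init__()
        self.checkpoint_manager = checkpoint_manager
        self.save_freq = save_freq

    def on_train_batch_end(self, batch, logs=None):
        # `batch` is the 0-based batch index within the current epoch
        if (batch + 1) % self.save_freq == 0:
            self.checkpoint_manager.save()

    def on_epoch_end(self, epoch, logs=None):
        # Also save at epoch boundaries so the final state of each epoch is kept
        self.checkpoint_manager.save()


# Usage with the model and manager defined above:
# model.fit(x=tf.ones((100, 3)), y=tf.ones((100, 1)), batch_size=2,
#           callbacks=[Periodic_Save_Callback(checkpoint_manager, save_freq=10)])

Either way, checkpoint_manager.checkpoints lists the paths of the checkpoints currently retained, which will never exceed max_to_keep=N.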