1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
# Number of most-recent checkpoints the CheckpointManager below retains.
N = 5
class Model(tf.keras.Model):
    """Single-unit linear model used to exercise checkpoint save/restore.

    The layer is stored on the attribute ``d``. Object-based checkpoints
    (``tf.train.Checkpoint``) key variables by attribute name, so renaming
    it would presumably stop previously written checkpoints from restoring.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # One output unit whose kernel starts as all ones.
        self.d = tf.keras.layers.Dense(1, kernel_initializer="ones")

    @tf.function
    def call(self, x, training=True, mask=None):
        """Return the dense layer applied to ``x``; ``training`` and ``mask`` are unused."""
        projected = self.d(x)
        return projected
class Save_Callbacks(tf.keras.callbacks.Callback):
    """Keras callback that writes a checkpoint via a CheckpointManager during training.

    Args:
        checkpoint_manager: object whose ``save()`` is called to write a
            checkpoint (a ``tf.train.CheckpointManager`` in this script).
        save_freq: save once every ``save_freq`` training batches. The
            default of 1 saves after every batch.

    Raises:
        ValueError: if ``save_freq`` is less than 1.
    """

    def __init__(self, checkpoint_manager, save_freq=1):
        # Initialize base Callback state (model/params are attached later by Keras).
        super().__init__()
        if save_freq < 1:
            raise ValueError(f"save_freq must be >= 1, got {save_freq!r}")
        self.checkpoint_manager = checkpoint_manager
        self.save_freq = save_freq
        # Keep our own running count because the `batch` index Keras passes
        # restarts at 0 every epoch.
        self._batches_seen = 0

    def on_train_batch_end(self, batch, logs=None):
        """Count the finished batch and save a checkpoint every ``save_freq`` batches."""
        super().on_train_batch_end(batch, logs)
        self._batches_seen += 1
        if self._batches_seen % self.save_freq == 0:
            self.checkpoint_manager.save()
model = Model()
# NOTE(review): Dense(1) has no activation, so outputs are unbounded, while
# binary_crossentropy defaults to from_logits=False (expects probabilities).
# Confirm whether from_logits=True or a sigmoid activation was intended.
model.compile(loss=tf.keras.losses.binary_crossentropy, optimizer='SGD')

# Track both the model and its optimizer so the full training state is saved.
checkpoint = tf.train.Checkpoint(model=model, optimizer=model.optimizer)
checkpoint_dir = 'save'
# Retain only the N most recent checkpoints in checkpoint_dir.
checkpoint_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=N)

model.fit(
    x=tf.ones((100, 3)),
    y=tf.ones((100, 1)),  # already a tensor; no tf.constant wrapper needed
    batch_size=2,
    callbacks=[Save_Callbacks(checkpoint_manager)],
)
model.reset_metrics()

latest = tf.train.latest_checkpoint(checkpoint_dir)
if latest is None:
    # Restoring None would silently do nothing and print un-restored weights.
    raise FileNotFoundError(f"no checkpoint found in {checkpoint_dir!r}")
ckpt = tf.train.Checkpoint(model=model)
# This Checkpoint tracks only `model`, so parts of the saved object graph
# (e.g. the root-level optimizer entry) may go unmatched; expect_partial()
# marks the partial restore as intentional and silences those warnings.
ckpt.restore(latest).expect_partial()
print(model(tf.ones((2, 3))))
|