#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Continually saving the "best" model or model weights/parameters has many benefits. These include being able to track the training progress and load saved models from different saved states.
In TensorFlow 1, to configure checkpoint saving during training/validation with the tf.estimator.Estimator APIs, you specify a schedule in tf.estimator.RunConfig or use tf.estimator.CheckpointSaverHook. This guide demonstrates how to migrate from this workflow to TensorFlow 2 Keras APIs.
In TensorFlow 2, you can configure tf.keras.callbacks.ModelCheckpoint in a number of ways:
- Save the "best" version according to a monitored metric by using the save_best_only=True parameter, where monitor can be, for example, 'loss', 'val_loss', 'accuracy', or 'val_accuracy'.
- Save continually at a certain frequency (using the save_freq argument).
- Save only the weights/parameters instead of the whole model by setting save_weights_only to True.
For more details, refer to the tf.keras.callbacks.ModelCheckpoint API docs and the Save checkpoints during training section in the Save and load models tutorial. Learn more about the Checkpoint format in the TF Checkpoint format section in the Save and load Keras models guide. In addition, to add fault tolerance, you can use tf.keras.callbacks.BackupAndRestore or tf.train.Checkpoint for manual checkpointing. Learn more in the Fault tolerance migration guide.
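A minimal sketch of the ModelCheckpoint options listed above, assuming synthetic data (x, y) and a small hypothetical model instead of MNIST. The filenames use the ".weights.h5" suffix, which Keras 3 requires for weights-only checkpoints (older tf.keras releases treat a ".h5" suffix as HDF5). One callback keeps only the best epoch by validation accuracy; a second writes one weights file per epoch through an {epoch} placeholder in the filename.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical synthetic data standing in for MNIST.
rng = np.random.default_rng(0)
x = rng.random((256, 28, 28)).astype("float32")
y = rng.integers(0, 10, size=(256,))

model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

ckpt_dir = tempfile.mkdtemp()
best_path = os.path.join(ckpt_dir, "best.weights.h5")

# Keep only the best epoch according to the monitored metric.
best_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=best_path,
    monitor="val_accuracy",   # metric that defines "best"
    mode="max",               # higher validation accuracy is better
    save_best_only=True,      # write only when the monitored metric improves
    save_weights_only=True,   # store weights/parameters, not the whole model
)

# Save every epoch (the default save_freq="epoch"), one file per epoch.
periodic_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(ckpt_dir, "epoch_{epoch:02d}.weights.h5"),
    save_weights_only=True,
    save_freq="epoch",
)

model.fit(x, y, epochs=3, validation_split=0.25,
          callbacks=[best_checkpoint, periodic_checkpoint], verbose=0)
print(sorted(os.listdir(ckpt_dir)))
```

The two callbacks are independent, so a single Model.fit call can both track the best state and keep a per-epoch history.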
Keras callbacks are objects that are called at different points during training/evaluation/prediction in the built-in Keras Model.fit/Model.evaluate/Model.predict APIs. Learn more in the Next steps section at the end of the guide.
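To illustrate what "called at different points" means, here is a minimal sketch of a custom callback that subclasses tf.keras.callbacks.Callback and overrides on_epoch_end, one of the hooks Model.fit invokes. EpochLossRecorder is a hypothetical name (not part of Keras), and the synthetic regression data is also an assumption.

```python
import numpy as np
import tensorflow as tf


class EpochLossRecorder(tf.keras.callbacks.Callback):
    """Hypothetical callback that records the training loss after every epoch."""

    def __init__(self):
        super().__init__()
        self.epoch_losses = []

    def on_epoch_end(self, epoch, logs=None):
        # Model.fit passes the epoch's aggregated metrics in `logs`.
        logs = logs or {}
        self.epoch_losses.append(float(logs["loss"]))


# Hypothetical synthetic regression data.
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

recorder = EpochLossRecorder()
model.fit(x, y, epochs=3, batch_size=16, callbacks=[recorder], verbose=0)
print(recorder.epoch_losses)  # one loss value per epoch
```

ModelCheckpoint works the same way: it is a Callback whose hooks decide when to write a checkpoint.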
Start with imports and a simple dataset for demonstration purposes:
import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
This TensorFlow 1 example shows how to configure tf.estimator.RunConfig to save checkpoints at every step during training/evaluation with the tf.estimator.Estimator APIs:
feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
config = tf1.estimator.RunConfig(save_summary_steps=1,
save_checkpoints_steps=1)
path = tempfile.mkdtemp()
classifier = tf1.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[256, 32],
optimizer=tf1.train.AdamOptimizer(0.001),
n_classes=10,
dropout=0.2,
model_dir=path,
config = config
)
train_input_fn = tf1.estimator.inputs.numpy_input_fn(
x={"x": x_train},
y=y_train.astype(np.int32),
num_epochs=10,
batch_size=50,
shuffle=True,
)
test_input_fn = tf1.estimator.inputs.numpy_input_fn(
x={"x": x_test},
y=y_test.astype(np.int32),
num_epochs=10,
shuffle=False
)
train_spec = tf1.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10)
eval_spec = tf1.estimator.EvalSpec(input_fn=test_input_fn,
steps=10,
throttle_secs=0)
tf1.estimator.train_and_evaluate(estimator=classifier,
train_spec=train_spec,
eval_spec=eval_spec)
%ls {classifier.model_dir}
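For comparison, the closest Keras counterpart of save_checkpoints_steps is passing an integer to the save_freq argument of tf.keras.callbacks.ModelCheckpoint, which then counts training batches instead of epochs. A minimal sketch, assuming synthetic data and a hypothetical one-layer model; the fixed filename is overwritten on every save.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical synthetic data: 64 samples with batch_size=16 gives 4 batches per epoch.
x = np.random.rand(64, 28, 28).astype("float32")
y = np.random.randint(0, 10, size=(64,))

model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

step_path = os.path.join(tempfile.mkdtemp(), "latest.weights.h5")

# An integer save_freq writes a checkpoint every N training batches
# (here every 2 batches), similar in spirit to save_checkpoints_steps.
step_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=step_path,
    save_weights_only=True,
    save_freq=2,
)

model.fit(x, y, epochs=2, batch_size=16, callbacks=[step_checkpoint], verbose=0)
print(os.path.exists(step_path))
```

Saving on every batch (save_freq=1) mirrors save_checkpoints_steps=1 above, but it slows training considerably, so larger values are usually preferable.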
In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate) for training/evaluation, you can configure tf.keras.callbacks.ModelCheckpoint and then pass it to the callbacks parameter of Model.fit (or Model.evaluate). (Learn more in the API docs and the Using callbacks section in the Training and evaluation with the built-in methods guide.)
In the example below, you will use a tf.keras.callbacks.ModelCheckpoint callback to store checkpoints in a temporary directory:
def create_model():
    return tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
model = create_model()
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'],
steps_per_execution=10)
log_dir = tempfile.mkdtemp()
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=log_dir)
model.fit(x=x_train,
y=y_train,
epochs=10,
validation_data=(x_test, y_test),
callbacks=[model_checkpoint_callback])
%ls {model_checkpoint_callback.filepath}
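Saved states can later be loaded back, which is one of the benefits of checkpointing mentioned at the start of this guide. A minimal sketch, assuming synthetic data, a hypothetical build_model helper, and a weights-only ".weights.h5" file (rather than the directory used above): train briefly with ModelCheckpoint, build a fresh model with the same architecture, and restore the weights with Model.load_weights.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical helper: same architecture every time, so saved weights fit it.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(28, 28)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])


# Hypothetical synthetic data standing in for MNIST.
x = np.random.rand(64, 28, 28).astype("float32")
y = np.random.randint(0, 10, size=(64,))

trained = build_model()
trained.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
ckpt_path = os.path.join(tempfile.mkdtemp(), "final.weights.h5")
trained.fit(x, y, epochs=2, verbose=0,
            callbacks=[tf.keras.callbacks.ModelCheckpoint(
                filepath=ckpt_path, save_weights_only=True)])

# A freshly initialized model differs until the checkpoint is loaded into it.
restored = build_model()
restored.load_weights(ckpt_path)

same = all(np.allclose(a, b)
           for a, b in zip(trained.get_weights(), restored.get_weights()))
print(same)
```

Because the checkpoint is written at the end of the last epoch, the restored weights match the trained model exactly.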
Learn more about checkpointing in:
- tf.keras.callbacks.ModelCheckpoint
Learn more about callbacks in:
- tf.keras.callbacks.Callback
You may also find the following migration-related resources useful:
- tf.keras.callbacks.BackupAndRestore for Model.fit, or tf.train.Checkpoint and tf.train.CheckpointManager APIs for a custom training loop
- tf.keras.callbacks.EarlyStopping is a built-in early stopping callback
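For the custom-training-loop route mentioned above, a minimal sketch of tf.train.Checkpoint combined with tf.train.CheckpointManager follows. The model, optimizer, random data, five training steps, and max_to_keep=3 are illustrative assumptions; the manager numbers checkpoints by the step counter and deletes all but the three most recent.

```python
import tempfile

import tensorflow as tf

# Tiny hypothetical model and optimizer for a custom training loop.
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

# Track a step counter alongside the model and optimizer state.
step = tf.Variable(0, dtype=tf.int64)
checkpoint = tf.train.Checkpoint(step=step, model=model, optimizer=optimizer)
manager = tf.train.CheckpointManager(
    checkpoint, directory=tempfile.mkdtemp(), max_to_keep=3)

x = tf.random.normal((32, 4))
y = tf.random.normal((32, 1))

for _ in range(5):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x, training=True) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    step.assign_add(1)
    manager.save(checkpoint_number=step)  # writes ckpt-1, ckpt-2, ...

print(manager.latest_checkpoint)  # path ending in "ckpt-5"
print(len(manager.checkpoints))   # only the 3 most recent are kept

# Restoring overwrites the tracked variables with the saved values.
step.assign(0)
checkpoint.restore(manager.latest_checkpoint)
print(int(step.numpy()))
```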