#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This guide demonstrates how to migrate from TensorFlow 1's `tf.estimator.Estimator` APIs to TensorFlow 2's `tf.keras` APIs. First, you will set up and run a basic model for training and evaluation with `tf.estimator.Estimator`. Then, you will perform the equivalent steps in TensorFlow 2 with the `tf.keras` APIs. You will also learn how to customize the training step by subclassing `tf.keras.Model` and using `tf.GradientTape`.

The `tf.estimator.Estimator` APIs let you train and evaluate a model, as well as perform inference and save your model (for serving). (For migrating model/checkpoint saving workflows to TensorFlow 2, check out the SavedModel and Checkpoint migration guides.)
Start with imports and a simple dataset:
```python
import tensorflow as tf
import tensorflow.compat.v1 as tf1

features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]
eval_labels = [[0.8], [0.9], [1.]]
```
This example shows how to perform training and evaluation with `tf.estimator.Estimator` in TensorFlow 1.

Start by defining a few functions: an input function for the training data, an evaluation input function for the evaluation data, and a model function that tells the `Estimator` how the training op is defined with the features and labels:
```python
def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)

def _eval_input_fn():
  return tf1.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).batch(1)

def _model_fn(features, labels, mode):
  logits = tf1.layers.Dense(1)(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
```
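The `_model_fn` above only wires up the training op; that is sufficient for `Estimator.train`, and, because it also returns the `loss`, for `Estimator.evaluate`. Supporting inference with `Estimator.predict` additionally requires `predictions` in the `EstimatorSpec`. The following is a minimal, illustrative sketch of a mode-aware variant (`_predict_model_fn` is a hypothetical name, not part of the original example):

```python
def _predict_model_fn(features, labels, mode):
  """Illustrative variant of `_model_fn` that also handles predict mode."""
  logits = tf1.layers.Dense(1)(features)
  if mode == tf1.estimator.ModeKeys.PREDICT:
    # In predict mode there are no labels, so only return the predictions.
    return tf1.estimator.EstimatorSpec(mode, predictions=logits)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
```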
Instantiate your `Estimator`, and train the model:

```python
estimator = tf1.estimator.Estimator(model_fn=_model_fn)
estimator.train(_input_fn)
```
Evaluate the model with the evaluation set:

```python
estimator.evaluate(_eval_input_fn)
```
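`Estimator.evaluate` returns the computed metrics as a dictionary, so you can also capture and inspect the results. A small sketch (with this `model_fn`, the dictionary contains the loss along with the global step):

```python
# `Estimator.evaluate` returns a dict mapping metric names to values.
eval_metrics = estimator.evaluate(_eval_input_fn)
print(eval_metrics["loss"])  # Mean-squared error over the evaluation set.
```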
This example demonstrates how to perform training and evaluation with Keras `Model.fit` and `Model.evaluate` in TensorFlow 2. (You can learn more in the Training and evaluation with the built-in methods guide.)
Next:

- Prepare the training and evaluation data with the `tf.data.Dataset` APIs.
- Define a simple linear model with a single (`tf.keras.layers.Dense`) layer.
- Instantiate an Adagrad optimizer (`tf.keras.optimizers.Adagrad`).
- Configure the model for training by passing the `optimizer` variable and the mean-squared error (`"mse"`) loss to `Model.compile`.

```python
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
eval_dataset = tf.data.Dataset.from_tensor_slices(
    (eval_features, eval_labels)).batch(1)

model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
model.compile(optimizer=optimizer, loss="mse")
```
With that, you are ready to train the model by calling `Model.fit`:

```python
model.fit(dataset)
```
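`Model.fit` also returns a `tf.keras.callbacks.History` object that records the loss (and any metric) values per epoch. A small sketch (the epoch count here is arbitrary):

```python
# Train for a few epochs and inspect the recorded per-epoch training loss.
history = model.fit(dataset, epochs=3)
print(history.history["loss"])
```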
Finally, evaluate the model with `Model.evaluate`:

```python
model.evaluate(eval_dataset, return_dict=True)
```
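Keras also covers the inference and saving-for-serving workflows mentioned at the beginning. A minimal sketch using the built-in `Model.predict` and `Model.save` APIs (the save path below is just an example):

```python
# Run inference; with a dataset of (features, labels) pairs,
# `Model.predict` uses only the features.
predictions = model.predict(eval_dataset)

# Save the trained model for serving (the path is illustrative).
model.save("/tmp/linear_model")
```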
In TensorFlow 2, you can also write your own custom training step function with `tf.GradientTape` to perform forward and backward passes, while still taking advantage of the built-in training support, such as `tf.keras.callbacks.Callback` and `tf.distribute.Strategy`. (Learn more in Customizing what happens in Model.fit and Writing custom training loops from scratch.) A short sketch at the end of this section shows the custom model working with a built-in callback.
In this example, start by creating a custom `tf.keras.Model` by subclassing `tf.keras.Sequential` that overrides `Model.train_step`. (Learn more about subclassing `tf.keras.Model`.) Inside that class, define a custom `train_step` function that, for each batch of data, performs a forward pass and a backward pass during one training step.
```python
class CustomModel(tf.keras.Sequential):
  """A custom sequential model that overrides `Model.train_step`."""

  def train_step(self, data):
    batch_data, labels = data

    with tf.GradientTape() as tape:
      predictions = self(batch_data, training=True)
      # Compute the loss value (the loss function is configured
      # in `Model.compile`).
      loss = self.compiled_loss(labels, predictions)

    # Compute the gradients of the parameters with respect to the loss.
    gradients = tape.gradient(loss, self.trainable_variables)
    # Perform gradient descent by updating the weights/parameters.
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
    # Update the metrics (includes the metric that tracks the loss).
    self.compiled_metrics.update_state(labels, predictions)
    # Return a dict mapping metric names to the current values.
    return {m.name: m.result() for m in self.metrics}
```
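Evaluation can be customized in the same way by overriding `Model.test_step`, which controls what `Model.evaluate` does per batch. As a minimal, illustrative sketch (this subclass is an assumption for demonstration, not part of the original example):

```python
class CustomModelWithEval(CustomModel):
  """Illustrative only: also overrides `Model.test_step`."""

  def test_step(self, data):
    batch_data, labels = data
    # Forward pass in inference mode; no gradient tape is needed.
    predictions = self(batch_data, training=False)
    # Update the compiled loss and metrics, then report their values.
    self.compiled_loss(labels, predictions)
    self.compiled_metrics.update_state(labels, predictions)
    return {m.name: m.result() for m in self.metrics}
```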
Next, as before:

- Prepare the training and evaluation data with `tf.data.Dataset`.
- Define a simple linear model with a single `tf.keras.layers.Dense` layer.
- Instantiate an Adagrad optimizer (`tf.keras.optimizers.Adagrad`).
- Pass the `optimizer` variable to `Model.compile`, while using mean-squared error (`"mse"`) as the loss function.

```python
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
eval_dataset = tf.data.Dataset.from_tensor_slices(
    (eval_features, eval_labels)).batch(1)

model = CustomModel([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
model.compile(optimizer=optimizer, loss="mse")
```
Call `Model.fit` to train the model:

```python
model.fit(dataset)
```
And, finally, evaluate the model with `Model.evaluate`:

```python
model.evaluate(eval_dataset, return_dict=True)
```
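Because `CustomModel.train_step` plugs into the standard Keras training loop, the built-in training support mentioned earlier still applies. As a minimal sketch (the epoch count and patience value here are arbitrary), you can pass a built-in callback such as `tf.keras.callbacks.EarlyStopping` straight to `Model.fit`:

```python
# Train the custom model with a built-in callback: `EarlyStopping`
# halts training once the monitored training loss stops improving.
model.fit(
    dataset,
    epochs=10,
    callbacks=[tf.keras.callbacks.EarlyStopping(monitor="loss", patience=2)],
)
```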
For more detail on these APIs, refer back to the Keras guides linked above: Training and evaluation with the built-in methods, Customizing what happens in Model.fit, and Writing custom training loops from scratch. For migrating distribution strategy workflows from the `tf.estimator` APIs, check out the distribution strategy migration guides.