#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This tutorial provides an example of how to use tf.data.TextLineDataset to load examples from text files. TextLineDataset is designed to create a dataset from a text file, in which each example is a line of text from the original file. This is potentially useful for any text data that is primarily line-based (for example, poetry or error logs).

In this tutorial, we'll use three different English translations of the same work, Homer's Iliad, and train a model to identify the translator given a single line of text.
from __future__ import absolute_import, division, print_function, unicode_literals
try:
  # The !pip shell command only works in notebook environments such as Colab;
  # this tutorial uses the nightly TensorFlow build.
  !pip install tf-nightly
except Exception:
  pass
import tensorflow as tf
import tensorflow_datasets as tfds
import os
The texts of the three translations are by William Cowper, Edward, Earl of Derby, and Samuel Butler.
The text files used in this tutorial have undergone some typical preprocessing: the document header and footer, line numbers, and chapter titles have been removed. Download these lightly preprocessed files locally.
DIRECTORY_URL = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'
FILE_NAMES = ['cowper.txt', 'derby.txt', 'butler.txt']
for name in FILE_NAMES:
  text_dir = tf.keras.utils.get_file(name, origin=DIRECTORY_URL + name)

parent_dir = os.path.dirname(text_dir)

parent_dir
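As an optional sanity check (this step is not part of the original pipeline, and the exact lines you see will depend on the file), you can read the first few lines of one downloaded file with plain Python to confirm that each example really is a single line of text:

# Optional: peek at the first few lines of one of the downloaded files.
with open(os.path.join(parent_dir, FILE_NAMES[0])) as f:
  for _ in range(3):
    print(repr(f.readline()))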
Iterate through the files, loading each one into its own dataset. Each example needs to be individually labeled, so use tf.data.Dataset.map to apply a labeler function to each one. This will iterate over every example in the dataset, returning (example, label) pairs.
def labeler(example, index):
  return example, tf.cast(index, tf.int64)
labeled_data_sets = []
for i, file_name in enumerate(FILE_NAMES):
  lines_dataset = tf.data.TextLineDataset(os.path.join(parent_dir, file_name))
  labeled_dataset = lines_dataset.map(lambda ex: labeler(ex, i))
  labeled_data_sets.append(labeled_dataset)
Combine these labeled datasets into a single dataset, and shuffle it.
BUFFER_SIZE = 50000
BATCH_SIZE = 64
TAKE_SIZE = 5000
all_labeled_data = labeled_data_sets[0]
for labeled_dataset in labeled_data_sets[1:]:
  all_labeled_data = all_labeled_data.concatenate(labeled_dataset)

all_labeled_data = all_labeled_data.shuffle(
    BUFFER_SIZE, reshuffle_each_iteration=False)
You can use tf.data.Dataset.take and print to see what the (example, label) pairs look like. The numpy property shows each Tensor's value.
for ex in all_labeled_data.take(5):
  print(ex)
Machine learning models work on numbers, not words, so the string values need to be converted into lists of numbers. To do that, map each unique word to a unique integer.
First, build a vocabulary by tokenizing the text into a collection of individual unique words. There are a few ways to do this in both TensorFlow and Python. For this tutorial:
1. Iterate over each example's numpy value.
2. Use tfds.features.text.Tokenizer to split it into tokens.
3. Collect these tokens into a Python set, to remove duplicates.
4. Get the size of the vocabulary for later use.

tokenizer = tfds.features.text.Tokenizer()
vocabulary_set = set()
for text_tensor, _ in all_labeled_data:
  some_tokens = tokenizer.tokenize(text_tensor.numpy())
  vocabulary_set.update(some_tokens)
vocab_size = len(vocabulary_set)
vocab_size
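To get a feel for what the tokenizer does, you can run it on a single hand-written string (the sample sentence below is made up for illustration, not taken from the dataset):

# Tokenize a made-up sentence to see the tokenizer's output.
print(tokenizer.tokenize("Sing O goddess the anger of Achilles"))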
Create an encoder by passing the vocabulary_set to tfds.features.text.TokenTextEncoder. The encoder's encode method takes in a string of text and returns a list of integers.
encoder = tfds.features.text.TokenTextEncoder(vocabulary_set)
You can try this on a single line to see what the output looks like.
example_text = next(iter(all_labeled_data))[0].numpy()
print(example_text)
encoded_example = encoder.encode(example_text)
print(encoded_example)
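The encoder also provides a decode method that maps the integers back to tokens. A quick round trip (an optional check, not part of the original tutorial) should recover the text, minus any punctuation the tokenizer dropped:

# Map the integers back to tokens. Punctuation removed by the tokenizer
# is not restored.
print(encoder.decode(encoded_example))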
Now run the encoder on the dataset by wrapping it in tf.py_function and passing that to the dataset's map method.
def encode(text_tensor, label):
  encoded_text = encoder.encode(text_tensor.numpy())
  return encoded_text, label
You want to use Dataset.map to apply this function to each element of the dataset, but Dataset.map runs in graph mode, and graph tensors do not have a value that Python code can read. So you can't .map this function directly: you need to wrap it in a tf.py_function. The tf.py_function will pass regular eager tensors (with a value and a .numpy() method to access it) to the wrapped Python function.
def encode_map_fn(text, label):
  # py_func doesn't set the shape of the returned tensors.
  encoded_text, label = tf.py_function(encode,
                                       inp=[text, label],
                                       Tout=(tf.int64, tf.int64))

  # `tf.data.Datasets` work best if all components have a shape set
  # so set the shapes manually:
  encoded_text.set_shape([None])
  label.set_shape([])

  return encoded_text, label
all_encoded_data = all_labeled_data.map(encode_map_fn)
Use tf.data.Dataset.take and tf.data.Dataset.skip to create a small test dataset and a larger training set.

Before being passed into the model, the datasets need to be batched. Typically, the examples inside of a batch need to be the same size and shape, but the examples in these datasets are not all the same size: each line of text has a different number of words. So use tf.data.Dataset.padded_batch (instead of batch) to pad the examples to the same size.
train_data = all_encoded_data.skip(TAKE_SIZE).shuffle(BUFFER_SIZE)
train_data = train_data.padded_batch(BATCH_SIZE)
test_data = all_encoded_data.take(TAKE_SIZE)
test_data = test_data.padded_batch(BATCH_SIZE)
Now, test_data and train_data are not collections of (example, label) pairs, but collections of batches. Each batch is a pair of (many examples, many labels) represented as arrays.
To illustrate:
sample_text, sample_labels = next(iter(test_data))
sample_text[0], sample_labels[0]
Since we have introduced a new token encoding (the zero used for padding), the vocabulary size has increased by one.
vocab_size += 1
model = tf.keras.Sequential()
The first layer converts integer representations to dense vector embeddings. See the word embeddings tutorial for more details.
model.add(tf.keras.layers.Embedding(vocab_size, 64))
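As a quick aside (a throwaway layer with made-up sizes, separate from the model being built), an Embedding layer simply looks up a dense float vector for each integer ID:

# A standalone embedding with a 10-word vocabulary and 4-dimensional vectors:
# three input IDs produce three 4-dimensional vectors.
demo_embedding = tf.keras.layers.Embedding(input_dim=10, output_dim=4)
print(demo_embedding(tf.constant([1, 2, 3])).shape)  # (3, 4)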
The next layer is a Long Short-Term Memory (LSTM) layer, which lets the model understand words in their context with other words. A bidirectional wrapper on the LSTM helps the model learn about each datapoint in relation to the datapoints that came before and after it.
model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)))
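Roughly speaking, the Bidirectional wrapper runs the LSTM over the sequence forwards and backwards and concatenates the two results, so a 64-unit LSTM produces a 128-dimensional vector per sequence. A standalone sketch with made-up shapes:

# Dummy input: a batch of 1 sequence with 10 timesteps and 64 features each.
demo_lstm = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64))
print(demo_lstm(tf.zeros([1, 10, 64])).shape)  # (1, 128)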
Finally we'll have a series of one or more densely connected layers, with the last one being the output layer. The output layer produces a probability for each label; the label with the highest probability is the model's prediction for an example.
# One or more dense layers.
# Edit the list in the `for` line to experiment with layer sizes.
for units in [64, 64]:
  model.add(tf.keras.layers.Dense(units, activation='relu'))
# Output layer. The first argument is the number of labels.
model.add(tf.keras.layers.Dense(3, activation='softmax'))
Finally, compile the model. For a softmax categorization model, use sparse_categorical_crossentropy as the loss function. You can try other optimizers, but adam is very common.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
This model running on this data produces decent results (about 83% accuracy).
model.fit(train_data, epochs=3, validation_data=test_data)
eval_loss, eval_acc = model.evaluate(test_data)
print('\nEval loss: {:.3f}, Eval accuracy: {:.3f}'.format(eval_loss, eval_acc))
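As a possible follow-up (not part of the original tutorial), you can run one test batch through model.predict and take the argmax over the three label probabilities to see which translator the model picks for each line:

# Predict translator labels for one batch of test examples.
sample_text, sample_labels = next(iter(test_data))
predictions = model.predict(sample_text)
predicted_labels = tf.argmax(predictions, axis=1)
print('Predicted:', predicted_labels[:10].numpy())
print('Actual:   ', sample_labels[:10].numpy())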