#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This tutorial provides an example of how to load CSV data from a file into a tf.data.Dataset.
The data used in this tutorial are taken from the Titanic passenger list. The model will predict the likelihood a passenger survived based on characteristics like age, gender, ticket class, and whether the person was traveling alone.
import functools
import numpy as np
import tensorflow as tf
# Remote locations of the Titanic training and evaluation CSV files.
TRAIN_DATA_URL = "https://storage.googleapis.com/tf-datasets/titanic/train.csv"
TEST_DATA_URL = "https://storage.googleapis.com/tf-datasets/titanic/eval.csv"
# Fetch both files; the returned values are used below as local file paths.
train_file_path = tf.keras.utils.get_file("train.csv", TRAIN_DATA_URL)
test_file_path = tf.keras.utils.get_file("eval.csv", TEST_DATA_URL)
# Make numpy values easier to read.
np.set_printoptions(precision=3, suppress=True)
To start, let's look at the top of the CSV file to see how it is formatted.
# IPython/Jupyter shell escape (not plain Python): show the first lines of the training CSV.
!head {train_file_path}
You can load this using pandas, and pass the NumPy arrays to TensorFlow. If you need to scale up to a large set of files, or need a loader that integrates with TensorFlow and tf.data, then use the tf.data.experimental.make_csv_dataset function:
The only column you need to identify explicitly is the one with the value that the model is intended to predict.
# Name of the CSV column holding the value the model is trained to predict.
LABEL_COLUMN = 'survived'
# Possible label values. NOTE(review): not referenced anywhere else in the visible code.
LABELS = [0, 1]
Now read the CSV data from the file and create a dataset.
(For the full documentation, see tf.data.experimental.make_csv_dataset.)
def get_dataset(file_path, **kwargs):
    """Create a batched `tf.data.Dataset` from CSV data.

    Args:
        file_path: Location of the CSV data to read.
        **kwargs: Extra keyword arguments forwarded unchanged to
            `tf.data.experimental.make_csv_dataset` (for example
            `column_names`, `select_columns` or `column_defaults`).

    Returns:
        The dataset produced by `make_csv_dataset`, with `LABEL_COLUMN`
        used as the label column.
    """
    return tf.data.experimental.make_csv_dataset(
        file_path,
        batch_size=5,  # Artificially small to make examples easier to show.
        label_name=LABEL_COLUMN,
        na_value="?",
        num_epochs=1,
        ignore_errors=True,
        **kwargs)
# Build the unprocessed training and evaluation datasets from the downloaded files.
raw_train_data = get_dataset(train_file_path)
raw_test_data = get_dataset(test_file_path)
def show_batch(dataset):
    """Print each feature column of the first batch in `dataset`.

    Args:
        dataset: Object whose `take(1)` yields `(features, label)` pairs,
            where `features` maps column names to values exposing
            `.numpy()`. The label is ignored.
    """
    row_template = "{:20s}: {}"
    for columns, _ in dataset.take(1):
        for name, tensor in columns.items():
            print(row_template.format(name, tensor.numpy()))
Each item in the dataset is a batch, represented as a tuple of (many examples, many labels). The data from the examples is organized in column-based tensors (rather than row-based tensors), each with as many elements as the batch size (5 in this case).
It might help to see this yourself.
# Print the feature columns of one batch of the raw training data.
show_batch(raw_train_data)
As you can see, the columns in the CSV are named. The dataset constructor will pick these names up automatically. If the file you are working with does not contain the column names in the first line, pass them in a list of strings to the column_names argument in the make_csv_dataset function.
# Explicit list of the CSV column names.
CSV_COLUMNS = ['survived', 'sex', 'age', 'n_siblings_spouses', 'parch', 'fare', 'class', 'deck', 'embark_town', 'alone']
# Pass the names explicitly through get_dataset's keyword pass-through.
temp_dataset = get_dataset(train_file_path, column_names=CSV_COLUMNS)
show_batch(temp_dataset)
This example is going to use all the available columns. If you need to omit some columns from the dataset, create a list of just the columns you plan to use, and pass it into the (optional) select_columns argument of the constructor.
# Subset of columns to keep; it includes the label column 'survived'.
SELECT_COLUMNS = ['survived', 'age', 'n_siblings_spouses', 'class', 'deck', 'alone']
# Rebuild the temporary dataset with only the selected columns and print one batch.
temp_dataset = get_dataset(train_file_path, select_columns=SELECT_COLUMNS)
show_batch(temp_dataset)
A CSV file can contain a variety of data types. Typically you want to convert from those mixed types to a fixed length vector before feeding the data into your model.
TensorFlow has a built-in system for describing common input conversions: tf.feature_column. See this tutorial for details.
You can preprocess your data using any tool you like (like nltk or sklearn), and just pass the processed output to TensorFlow.
The primary advantage of doing the preprocessing inside your model is that when you export the model it includes the preprocessing. This way you can pass the raw data directly to your model.
If your data is already in an appropriate numeric format, you can pack the data into a vector before passing it off to the model:
# Keep only the label and the numeric columns.
SELECT_COLUMNS = ['survived', 'age', 'n_siblings_spouses', 'parch', 'fare']
# One default per selected column, in the same order.
# NOTE(review): presumably these also fix each column's dtype (int label, float features) — confirm against the make_csv_dataset docs.
DEFAULTS = [0, 0.0, 0.0, 0.0, 0.0]
temp_dataset = get_dataset(train_file_path,
select_columns=SELECT_COLUMNS,
column_defaults = DEFAULTS)
show_batch(temp_dataset)
# Pull a single (features, labels) batch for inspection.
example_batch, labels_batch = next(iter(temp_dataset))
Here's a simple function that will pack together all the columns:
def pack(features, label):
    """Stack every feature column into a single tensor along the last axis.

    Args:
        features: Mapping of column name to tensor; the values are stacked
            in the mapping's iteration order.
        label: Returned unchanged.

    Returns:
        A `(stacked_features, label)` tuple.
    """
    columns = list(features.values())
    stacked = tf.stack(columns, axis=-1)
    return stacked, label
Apply this to each element of the dataset:
# Pack every element of the numeric-only dataset into one feature tensor.
packed_dataset = temp_dataset.map(pack)

# Show the packed features and the labels of the first batch, separated by a blank line.
for features, labels in packed_dataset.take(1):
    print(features.numpy(), "", labels.numpy(), sep="\n")
If you have mixed datatypes you may want to separate out these simple-numeric fields. The tf.feature_column API can handle them, but this incurs some overhead and should be avoided unless really necessary. Switch back to the mixed dataset:
# Back to the mixed-type dataset: print one batch of it ...
show_batch(raw_train_data)
# ... and take the example batch from that same dataset, not from the
# numeric-only temp_dataset, so it contains the mixed-type columns too.
example_batch, labels_batch = next(iter(raw_train_data))
So define a more general preprocessor that selects a list of numeric features and packs them into a single column:
class PackNumericFeatures:
    """Callable that packs selected feature columns into one 'numeric' column.

    Instances have the `(features, labels) -> (features, labels)` signature
    and are passed to `Dataset.map` in this file.
    """

    def __init__(self, names):
        """Store the keys of the feature columns to pack.

        Args:
            names: Iterable of feature keys, in the order they should appear
                along the last axis of the packed tensor.
        """
        self.names = names

    def __call__(self, features, labels):
        """Replace the named columns with a single stacked float32 column.

        Args:
            features: Mapping of column name to tensor. It must contain every
                key in `self.names` and support `.copy()` (as `dict` and
                `OrderedDict` do).
            labels: Returned unchanged.

        Returns:
            A `(features, labels)` tuple where `features` is a copy of the
            input without the named columns and with an added 'numeric'
            entry holding those columns cast to float32 and stacked on the
            last axis.

        Raises:
            KeyError: If one of `self.names` is missing from `features`.
        """
        # Work on a shallow copy so the caller's mapping is not mutated;
        # .copy() keeps the mapping type (and therefore the key order).
        features = features.copy()
        numeric_features = [
            tf.cast(features.pop(name), tf.float32) for name in self.names
        ]
        features['numeric'] = tf.stack(numeric_features, axis=-1)
        return features, labels
# Keys of the feature columns that get packed into the single 'numeric' column.
NUMERIC_FEATURES = ['age','n_siblings_spouses','parch', 'fare']
# Apply the same packing to the training and the test datasets.
packed_train_data = raw_train_data.map(
PackNumericFeatures(NUMERIC_FEATURES))
packed_test_data = raw_test_data.map(
PackNumericFeatures(NUMERIC_FEATURES))
show_batch(packed_train_data)
# Keep one packed batch; the feature-column cells below are demonstrated on it.
example_batch, labels_batch = next(iter(packed_train_data))
Continuous data should always be normalized.
import pandas as pd
# Summary statistics of the numeric columns, computed from the training CSV.
desc = pd.read_csv(train_file_path)[NUMERIC_FEATURES].describe()
desc
# Per-column mean and standard deviation as arrays, in NUMERIC_FEATURES order.
MEAN = np.array(desc.T['mean'])
STD = np.array(desc.T['std'])
def normalize_numeric_data(data, mean, std):
    """Standardize `data` by subtracting `mean` and dividing by `std`.

    Args:
        data: Values to normalize.
        mean: Value(s) subtracted from `data` (centering).
        std: Value(s) the centered data is divided by (scaling).

    Returns:
        `(data - mean) / std`, using the operands' own arithmetic.
    """
    centered = data - mean
    return centered / std
Now create a numeric column. The tf.feature_column.numeric_column API accepts a normalizer_fn argument, which will be run on each batch.
Bind the MEAN and STD to the normalizer fn using functools.partial.
# Bind the precomputed statistics so the normalizer only needs the data argument.
normalizer = functools.partial(normalize_numeric_data, mean=MEAN, std=STD)
# One feature column for the packed 'numeric' entry, with one slot per numeric feature.
numeric_column = tf.feature_column.numeric_column('numeric', normalizer_fn=normalizer, shape=[len(NUMERIC_FEATURES)])
numeric_columns = [numeric_column]
# See what you just created.
numeric_column
When you train the model, include this feature column to select and center this block of numeric data:
# The packed numeric values of the example batch, before normalization.
example_batch['numeric']
# Run the same batch through a DenseFeatures layer built from the numeric column
# (which carries the normalizer_fn defined above).
numeric_layer = tf.keras.layers.DenseFeatures(numeric_columns)
numeric_layer(example_batch).numpy()
The mean based normalization used here requires knowing the means of each column ahead of time.
Some of the columns in the CSV data are categorical columns. That is, the content should be one of a limited set of options.
Use the tf.feature_column API to create a collection with a tf.feature_column.indicator_column for each categorical column.
# Vocabulary for each categorical column. The strings have to match the values
# stored in the CSV exactly, otherwise that value is never recognised by the
# vocabulary lookup (the town is spelled 'Southampton', with a single 'h').
CATEGORIES = {
    'sex': ['male', 'female'],
    'class': ['First', 'Second', 'Third'],
    'deck': ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'],
    'embark_town': ['Cherbourg', 'Southampton', 'Queenstown'],
    'alone': ['y', 'n'],
}
# Build one indicator column per categorical feature, using its vocabulary list.
categorical_columns = []
for feature, vocab in CATEGORIES.items():
    cat_col = tf.feature_column.categorical_column_with_vocabulary_list(
        key=feature,
        vocabulary_list=vocab)
    indicator = tf.feature_column.indicator_column(cat_col)
    categorical_columns.append(indicator)

# See what you just created.
categorical_columns

# Encode the example batch with the categorical columns and print its first row.
categorical_layer = tf.keras.layers.DenseFeatures(categorical_columns)
print(categorical_layer(example_batch).numpy()[0])
This will become part of a data processing input later when you build the model.
Add the two feature column collections and pass them to a tf.keras.layers.DenseFeatures to create an input layer that will extract and preprocess both input types:
# One input layer handling both the categorical and the numeric feature columns.
preprocessing_layer = tf.keras.layers.DenseFeatures(categorical_columns+numeric_columns)
# Print the first row of the preprocessed example batch.
print(preprocessing_layer(example_batch).numpy()[0])
Build a tf.keras.Sequential, starting with the preprocessing_layer.
# Preprocessing first, then two 128-unit ReLU layers, then a single output unit.
model = tf.keras.Sequential([
preprocessing_layer,
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(1),
])
# The last Dense layer has no activation, so its output is treated as a logit:
# the loss is configured with from_logits=True to match.
model.compile(
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
optimizer='adam',
metrics=['accuracy'])
Now the model can be instantiated and trained.
# Shuffle the packed training data with a buffer of 500; use the packed test data as-is.
train_data = packed_train_data.shuffle(500)
test_data = packed_test_data
# Train for 20 epochs on the shuffled training data.
model.fit(train_data, epochs=20)
Once the model is trained, you can check its accuracy on the test_data set.
# evaluate() yields the loss followed by the metric given to compile() ('accuracy').
test_loss, test_accuracy = model.evaluate(test_data)
print('\n\nTest Loss {}, Test Accuracy {}'.format(test_loss, test_accuracy))
Use tf.keras.Model.predict to infer labels on a batch or a dataset of batches.
# Predict logits for the whole test set.
predictions = model.predict(test_data)

# Show some results.
# make_csv_dataset shuffles by default, so two separate passes over test_data
# do not visit the rows in the same order. Pairing model.predict(test_data)
# with labels from a second pass would match predictions to the wrong
# passengers. Take ONE batch and score exactly that batch instead (this also
# avoids materialising the whole dataset just to read its first batch).
shown_features, shown_labels = next(iter(test_data))
shown_logits = model(shown_features, training=False)
for prediction, survived in zip(shown_logits[:10], shown_labels[:10]):
    # The model outputs logits; the sigmoid turns them into probabilities.
    prediction = tf.sigmoid(prediction).numpy()
    print("Predicted survival: {:.2%}".format(prediction[0]),
          " | Actual outcome: ",
          ("SURVIVED" if bool(survived) else "DIED"))