# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Graph Neural Networks (GNNs) are a powerful tool for deep learning on relational data. This tutorial introduces the two main tools required to train GNNs at scale:
This tutorial is intended for ML practitioners with a basic idea of GNNs.
!pip install -q tensorflow-gnn || echo "Ignoring package errors..."
import functools
import itertools
import os
import re
from typing import Mapping
import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn import runner
from tensorflow_gnn.experimental import sampler
from tensorflow_gnn.models import mt_albis
tf.get_logger().setLevel('ERROR')
print(f"Running TF-GNN {tfgnn.__version__} under TensorFlow {tf.__version__}.")
NUM_TRAINING_SAMPLES = 629571
NUM_VALIDATION_SAMPLES = 64879
Running TF-GNN 0.6.0 under TensorFlow 2.12.0.
OGBN-MAG is Open Graph Benchmark's Node classification task on a subset of the Microsoft Academic Graph.
The OGBN-MAG dataset is one big heterogeneous graph. The graph has four sets (or types) of nodes.
The graph has four sets (or types) of directed edges, with no associated features on any of them.
The task is to predict the venue (journal or conference) at which each of the papers has been published. There are 349 distinct venues, not represented in the graph itself. The benchmark metric is the accuracy of the predicted venue.
Results for this benchmark confirm that the graph structure provides a lot of relevant but "latent" information. Baseline models that only use the one explicit input feature (the word2vec embedding of a paper's title and abstract) perform less well.
OGBN-MAG defines a split of node set "paper" into train, validation and test nodes, based on its "year" feature: year<=2017, year==2018, and year==2019, respectively.
However, under OGB rules, training may happen on the full graph, just restricted to predictions on the "train" nodes. We follow that for consistency in benchmarking. Users working on their own datasets, though, may wish to validate and test with a more realistic separation between training data from the past and evaluation data representative of future inputs for prediction.
OGBN-MAG asks to classify each of the "paper" nodes. The number of nodes is on the order of a million, and we intuit that the most informative other nodes are found just a few hops away (cited papers, papers with overlapping authors, etc.).
Therefore, and to stay scalable for even bigger datasets, we approach this task with graph sampling: Each "paper" node becomes one training example, expressed by a subgraph that has the node to be classified as its root and stores a sample of its neighborhood in the original graph. The sample is taken by going out a fixed number of steps along specific edge sets, and randomly downsampling the edges in each step if they are too numerous.
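To make the idea of rooted, downsampled neighborhoods concrete, here is a minimal pure-Python sketch. It is not the TF-GNN sampler: the toy graph `TOY_GRAPH`, the function `sample_subgraph` and its parameters are hypothetical names chosen for illustration, and the graph is homogeneous for brevity.

```python
import random
from typing import Dict, List, Set, Tuple

# Hypothetical toy graph: adjacency lists keyed by node id.
TOY_GRAPH: Dict[int, List[int]] = {
    0: [1, 2, 3, 4],
    1: [5, 6],
    2: [6, 7, 8],
    3: [],
    4: [9],
    5: [], 6: [], 7: [], 8: [], 9: [],
}


def sample_subgraph(
    graph: Dict[int, List[int]],
    root: int,
    num_hops: int,
    max_edges_per_node: int,
    rng: random.Random,
) -> Tuple[Set[int], List[Tuple[int, int]]]:
  """Returns the nodes and edges reached from `root` by capped sampling."""
  nodes = {root}
  edges = []  # Stored as (source, target) in the direction of discovery.
  frontier = [root]
  for _ in range(num_hops):
    next_frontier = []
    for node in frontier:
      neighbors = graph.get(node, [])
      # Downsample uniformly at random if the edges are too numerous.
      if len(neighbors) > max_edges_per_node:
        neighbors = rng.sample(neighbors, max_edges_per_node)
      for neighbor in neighbors:
        edges.append((node, neighbor))
        if neighbor not in nodes:
          nodes.add(neighbor)
          next_frontier.append(neighbor)
    frontier = next_frontier
  return nodes, edges


nodes, edges = sample_subgraph(
    TOY_GRAPH, root=0, num_hops=2, max_edges_per_node=2, rng=random.Random(0))
print(sorted(nodes), edges)
```

Each call yields one training example: the root node plus a bounded sample of its neighborhood, regardless of how large the full graph is.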
The actual TensorFlow model runs on batches of these sampled subgraphs, applies a Graph Neural Network to propagate information from related nodes towards the root node of each batch, and then applies a softmax classifier to predict one of 349 classes (each venue is a class).
The exponential fan-out of graph sampling quickly gets expensive. Sampling and model should be designed together to make the most of the available information in carefully sampled subgraphs.
We provide the entire OGBN-MAG graph data, cast as a TF-GNN graph tensor, as input to the graph sampler. The command below loads the entire OGBN-MAG graph as a single graph tensor from an already-saved serialized TensorFlow Example message (subject to this license). Additionally, it loads the supporting OGBN-MAG graph schema.
GRAPH_TENSOR_FILE = 'gs://download.tensorflow.org/data/ogbn-mag/sampled/v2/graph_tensor.example.pb'
SCHEMA_FILE = 'gs://download.tensorflow.org/data/ogbn-mag/sampled/v2/graph_schema.pbtxt'
graph_schema = tfgnn.read_schema(SCHEMA_FILE)
serialized_ogbn_mag_graph_tensor_string = tf.io.read_file(GRAPH_TENSOR_FILE)
full_ogbn_mag_graph_tensor = tfgnn.parse_single_example(
    tfgnn.create_graph_spec_from_schema_pb(graph_schema, indices_dtype=tf.int64),
    serialized_ogbn_mag_graph_tensor_string)
Because the OGBN-MAG graph is huge, we sample from it to facilitate training on batches of subgraphs.
The sampling we have chosen for OGBN-MAG proceeds as follows:
Below, we spell out the above sampling strategy in easy-to-read Python code.
train_sampling_sizes = {
    "cites": 8,
    "rev_writes": 8,
    "writes": 8,
    "affiliated_with": 8,
    "has_topic": 8,
}
validation_sample_sizes = train_sampling_sizes.copy()

def create_sampling_model(
    full_graph_tensor: tfgnn.GraphTensor, sizes: Mapping[str, int]
) -> tf.keras.Model:

  def edge_sampler(sampling_op: tfgnn.sampler.SamplingOp):
    edge_set_name = sampling_op.edge_set_name
    sample_size = sizes[edge_set_name]
    return sampler.InMemUniformEdgesSampler.from_graph_tensor(
        full_graph_tensor, edge_set_name, sample_size=sample_size
    )

  def get_features(node_set_name: tfgnn.NodeSetName):
    return sampler.InMemIndexToFeaturesAccessor.from_graph_tensor(
        full_graph_tensor, node_set_name
    )

  # Spell out the sampling procedure in python
  sampling_spec_builder = tfgnn.sampler.SamplingSpecBuilder(graph_schema)
  seed = sampling_spec_builder.seed("paper")
  papers_cited_from_seed = seed.sample(sizes["cites"], "cites")
  authors_of_papers = papers_cited_from_seed.join([seed]).sample(sizes["rev_writes"], "rev_writes")
  papers_by_authors = authors_of_papers.sample(sizes["writes"], "writes")
  institutions = authors_of_papers.sample(sizes["affiliated_with"], "affiliated_with")
  fields_of_study = (seed.join([papers_cited_from_seed, papers_by_authors]).sample(sizes["has_topic"], "has_topic"))
  sampling_spec = sampling_spec_builder.build()

  model = sampler.create_sampling_model_from_spec(
      graph_schema, sampling_spec, edge_sampler, get_features,
      seed_node_dtype=tf.int64)
  return model
Notice how our sampler allows sampling edge sets in the reverse direction by setting add_reverse_edge_sets=True while loading full_ogbn_mag_graph_tensor. The rev_writes edge set is derived from the writes edge set of the original OGBN-MAG graph and goes in the opposite direction, from node set "paper" to node set "author".
The sampling output contains all nodes and edges traversed by sampling, in their respective node/edge sets and with their associated features. An edge between two sampled nodes that exists in the input graph but has not been traversed by sampling is not included in the sampled output. For example, we get the cites edges followed in step 2, but no edges for citations between the papers discovered in step 4.
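As a rough check on how large a sampled subgraph can get, the following sketch multiplies out the fan-out of the sampling spec defined in create_sampling_model. It assumes that every sampling step draws at most the configured number of edges per source node and that no two steps reach the same node, so the results are upper bounds rather than typical sizes; the variable names are illustrative only.

```python
# Sampling sizes as in `train_sampling_sizes`.
sizes = {
    "cites": 8,
    "rev_writes": 8,
    "writes": 8,
    "affiliated_with": 8,
    "has_topic": 8,
}

seed_papers = 1
# seed paper -> cited papers
cited_papers = seed_papers * sizes["cites"]
# (seed + cited papers) -> authors
authors = (seed_papers + cited_papers) * sizes["rev_writes"]
# authors -> their other papers
papers_by_authors = authors * sizes["writes"]
# authors -> institutions
institutions = authors * sizes["affiliated_with"]
# (seed + cited + authors' papers) -> fields of study
all_papers = seed_papers + cited_papers + papers_by_authors
fields_of_study = all_papers * sizes["has_topic"]

max_nodes = {
    "paper": all_papers,
    "author": authors,
    "institution": institutions,
    "field_of_study": fields_of_study,
}
max_edges = {
    "cites": cited_papers,
    "rev_writes": authors,
    "writes": papers_by_authors,
    "affiliated_with": institutions,
    "has_topic": fields_of_study,
}
print(max_nodes)
print(max_edges, "total:", sum(max_edges.values()))
```

Under these assumptions a single subgraph holds at most 585 papers and 5,912 edges, most of them "has_topic" edges from the last sampling step, which illustrates how quickly the fan-out grows.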
Under OGB rules, we can sample subgraphs for the training, validation and test dataset from the full graph, just with different seed nodes, selected by the year of publication. We define the seed_dataset responsible for providing the seeds for the different splits. (Models for production systems should probably use separate validation and test data, to prevent leakage of their seed nodes into the sampled subgraphs of other splits.)
def seed_dataset(years: tf.Tensor, split_name: str) -> tf.data.Dataset:
  """Seed dataset as indices of papers within split years."""
  if split_name == "train":
    mask = years <= 2017  # 629,571 examples
  elif split_name == "validation":
    mask = years == 2018  # 64,879 examples
  elif split_name == "test":
    mask = years == 2019  # 41,939 examples
  else:
    raise ValueError(f"Unknown split_name: '{split_name}'")
  seed_indices = tf.squeeze(tf.where(mask), axis=-1)
  return tf.data.Dataset.from_tensor_slices(seed_indices)
Next, we combine the seed_dataset with the sampling model to obtain the SubgraphDatasetProvider.
class SubgraphDatasetProvider(runner.DatasetProvider):
  """Dataset Provider based on Sampler V2."""

  def __init__(self,
               full_graph_tensor: tfgnn.GraphTensor,
               sizes: Mapping[str, int],
               split_name: str):
    super().__init__()
    # Extract years of publication of all papers for determining seeds.
    self._years = tf.squeeze(full_graph_tensor.node_sets["paper"]["year"], axis=-1)
    self._sampling_model = create_sampling_model(full_graph_tensor, sizes)
    self._split_name = split_name
    self.input_graph_spec = self._sampling_model.output.spec

  def get_dataset(self, context: tf.distribute.InputContext) -> tf.data.Dataset:
    """Creates TF dataset."""
    self._seed_dataset = seed_dataset(self._years, self._split_name)
    ds = self._seed_dataset.shard(
        num_shards=context.num_input_pipelines, index=context.input_pipeline_id)
    if self._split_name == "train":
      ds = ds.shuffle(NUM_TRAINING_SAMPLES).repeat()
    # Samples 128 subgraphs in parallel. Larger is better, but could cause OOM.
    ds = ds.batch(128)
    ds = ds.map(
        functools.partial(self.sample),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=False,
    )
    return ds.unbatch().prefetch(tf.data.AUTOTUNE)

  def sample(self, seeds: tf.Tensor) -> tfgnn.GraphTensor:
    seeds = tf.cast(seeds, tf.int64)
    batch_size = tf.size(seeds)
    # Samples subgraphs for each seed independently as [[seed1], [seed2], ...]
    seeds_ragged = tf.RaggedTensor.from_row_lengths(
        seeds, tf.ones([batch_size], tf.int64),
    )
    return self._sampling_model(seeds_ragged)

train_ds_provider = SubgraphDatasetProvider(full_ogbn_mag_graph_tensor, train_sampling_sizes, "train")
valid_ds_provider = SubgraphDatasetProvider(full_ogbn_mag_graph_tensor, validation_sample_sizes, "validation")
example_input_graph_spec = train_ds_provider.input_graph_spec._unbatch()
We use TensorFlow's Distribution Strategy API to write a model that can run on multiple TPUs, multiple GPUs, or maybe just locally on CPU.
try:
  tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  print("Running on TPU ", tpu_resolver.cluster_spec().as_dict()["worker"])
except:
  tpu_resolver = None

if tpu_resolver:
  print("Using TPUStrategy")
  min_nodes_per_component = {"paper": 1}
  strategy = runner.TPUStrategy()
  train_padding = runner.FitOrSkipPadding(example_input_graph_spec, train_ds_provider, min_nodes_per_component)
  valid_padding = runner.TightPadding(example_input_graph_spec, valid_ds_provider, min_nodes_per_component)
elif tf.config.list_physical_devices("GPU"):
  print(f"Using MirroredStrategy for GPUs")
  gpu_list = !nvidia-smi -L
  print("\n".join(gpu_list))
  strategy = tf.distribute.MirroredStrategy()
  train_padding = None
  valid_padding = None
else:
  print(f"Using default strategy")
  strategy = tf.distribute.get_strategy()
  train_padding = None
  valid_padding = None
print(f"Found {strategy.num_replicas_in_sync} replicas in sync")
Running on TPU ['10.31.81.194:8470']
Using TPUStrategy
Found 8 replicas in sync
As you might have noticed above, we need to provide a padding strategy when we want to train on TPUs. Next, we explain the need for paddings on TPU and the different padding strategies employed during training and validation.
Training on TPUs involves just-in-time compilation of a TensorFlow model to TPU code, and requires fixed shapes for all Tensors involved. To achieve that for graph data with variable numbers of nodes and edges, we need to pad each input Tensor to some fixed maximum size. For training on GPUs or CPU, this extra step is not necessary.
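The general idea of padding to a fixed size can be sketched in a few lines of plain Python. This is only an illustration of the principle, not how TF-GNN pads a GraphTensor; the helper `pad_to_fixed_size` and its arguments are hypothetical.

```python
from typing import List, Tuple


def pad_to_fixed_size(
    values: List[float], max_size: int, pad_value: float = 0.0
) -> Tuple[List[float], List[bool]]:
  """Pads `values` to exactly `max_size` items and returns a validity mask."""
  if len(values) > max_size:
    raise ValueError(f"Input of size {len(values)} exceeds max_size={max_size}")
  num_padding = max_size - len(values)
  padded = values + [pad_value] * num_padding
  # The mask lets downstream code ignore the padding items.
  mask = [True] * len(values) + [False] * num_padding
  return padded, mask


padded, mask = pad_to_fixed_size([1.0, 2.0], max_size=4)
print(padded, mask)
```

Every input ends up with the same shape, at the cost of some wasted computation on padding items and the need to reject (or resize for) inputs that do not fit.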
For the validation dataset, we need to make sure that every batch of examples fits within the fixed size, no matter how the parallelism in the input pipeline ends up combining examples into batches. Therefore, we use a rather generous estimate, basically scaling each Tensor's observed maximum size by a factor of batch_size. If that were to run into limitations of accelerator memory, we'd rather shrink the batch size than lose examples.
The dataset in this example is not too big, so we can scan it within a few minutes to determine constraints large enough for all inputs. (For huge datasets under your control, it may be worth inferring an upper bound from the sampling spec instead.)
For the training dataset, TF-GNN allows you to optimize more aggressively for large batch sizes: size constraints satisfied by 100% of the inputs have to accommodate the rare combination of many large examples in one batch.
Instead, we use size constraints that will fit close to 100% of the randomly drawn training batches. This is not covered by the theory supporting stochastic gradient descent (which calls for examples drawn independently at random), but in practice, it often works, and allows larger batch sizes within the limits of accelerator memory, and hence faster convergence of the training.
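The trade-off between the two kinds of size constraints can be illustrated with a small simulation. This is a sketch under made-up assumptions (synthetic example sizes, a 99% target, hypothetical variable names) and does not reproduce how runner.TightPadding or runner.FitOrSkipPadding compute their constraints.

```python
import random

rng = random.Random(42)
batch_size = 8
# Hypothetical per-example node counts: mostly small, occasionally large.
example_sizes = [rng.choice([10, 20, 30, 40, 400]) for _ in range(4000)]

# Generous bound: room for a batch made up entirely of the largest example.
tight_limit = batch_size * max(example_sizes)

# Aggressive bound: large enough for about 99% of randomly drawn batches.
batches = [example_sizes[i:i + batch_size]
           for i in range(0, len(example_sizes), batch_size)]
batch_totals = sorted(sum(batch) for batch in batches)
fit_limit = batch_totals[int(0.99 * (len(batch_totals) - 1))]

# Batches above the aggressive bound would have to be skipped.
skipped = sum(1 for total in batch_totals if total > fit_limit)
print(f"generous limit: {tight_limit}, aggressive limit: {fit_limit}, "
      f"skipped batches: {skipped} of {len(batches)}")
```

The aggressive bound is never larger than the generous one and skips only a small fraction of batches, which is the memory that can instead be spent on a larger batch size.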
We build a model on sampled subgraphs that predicts one of 349 classes (venues) for the subgraph's root node. We use a Graph Neural Network (GNN) to propagate information along edge sets towards the subgraph's root node.
Observe how the various node sets play different roles:
For node sets "paper" and "author", we follow the standard GNN approach to maintain a hidden state for each node and update it several times with information from the inbound edges. Notice how sampling has equipped each "paper" or "author" adjacent to the root node with a 1-hop neighborhood of its own. Our model does 4 rounds of updates, which covers the longest possible path in a sampled subgraph: a seed paper "cites" a paper that was written by ("rev_writes") an author who "writes" another paper that "has_topic" in some field of study.
For node sets "field_of_study" and "institution", a GNN on the full graph could produce meaningful hidden states for their few elements in the same way. However, in the sampled approach, it seems wasteful to do that from scratch for every subgraph. Instead, our model reads hidden states for them out of an embedding table. This way, the GNN can treat them as read-only nodes with outgoing edges only; the writing happens implicitly by gradient updates to their embeddings. (We choose to maintain a single embedding shared between the rounds of GNN updates.) – Notice how this modeling decision directly influences the sampling spec.
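The "read hidden states out of an embedding table" idea boils down to hashing a node id into a bucket and looking up that bucket's row. Here is a minimal pure-Python sketch; the table sizes are deliberately tiny, the SHA-256 hash is a stand-in for whatever hash function the Keras Hashing layer uses, and all names are hypothetical.

```python
import hashlib
import random
from typing import List

NUM_BINS = 16      # Hypothetical; the tutorial's code uses 50_000 for "field_of_study".
EMBEDDING_DIM = 4  # Hypothetical; the tutorial's code uses 32 for "field_of_study".

rng = random.Random(0)
# One row per hash bucket. In a real model these rows are trainable weights.
embedding_table = [
    [rng.uniform(-0.05, 0.05) for _ in range(EMBEDDING_DIM)]
    for _ in range(NUM_BINS)
]


def hashed_id(node_id: str, num_bins: int = NUM_BINS) -> int:
  """Maps an arbitrary node id to a stable bucket index in [0, num_bins)."""
  digest = hashlib.sha256(node_id.encode("utf-8")).digest()
  return int.from_bytes(digest[:8], "big") % num_bins


def initial_state(node_id: str) -> List[float]:
  """Looks up the (shared) hidden state for a node by its hashed id."""
  return embedding_table[hashed_id(node_id)]


print(hashed_id("field_123"), initial_state("field_123"))
```

Because the lookup depends only on the node id, the same node gets the same state in every sampled subgraph, and gradient updates to that row accumulate across all subgraphs in which the node appears.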
Usually in TensorFlow, the non-trainable transformations of the input features are split off into a Dataset.map() call while the main model consists of the trainable and accelerator-compatible parts. However, even this non-trainable part is put into a Keras model, which is a convenient way to track resources (such as lookup tables) for exporting to a SavedModel.
Typically, feature preprocessing happens locally on nodes and edges. TF-GNN strives to reuse standard Keras implementations for this. The tfgnn.keras.layers.MapFeatures layer lets you express feature transformations on the graph as a collection of feature transformations for the various graph pieces (node sets, edge sets, and context).
At this stage, the eventual training label is still a feature on the GraphTensor. If necessary, it could also be preprocessed (e.g., turn a string-valued class label into a numeric id), but that's not the case here.
The training Task (defined below) splits the label out of the GraphTensor.
# For nodes
def process_node_features(node_set: tfgnn.NodeSet, node_set_name: str):
  if node_set_name == "field_of_study":
    return {"hashed_id": tf.keras.layers.Hashing(50_000)(node_set["#id"])}
  if node_set_name == "institution":
    return {"hashed_id": tf.keras.layers.Hashing(6_500)(node_set["#id"])}
  if node_set_name == "paper":
    # Keep `labels` for eventual extraction.
    return {"feat": node_set["feat"], "labels": node_set["label"]}
  if node_set_name == "author":
    return {"empty_state": tfgnn.keras.layers.MakeEmptyFeature()(node_set)}
  raise KeyError(f"Unexpected node_set_name='{node_set_name}'")

# For context and edges, in this example, we drop all features.
def drop_all_features(_, **unused_kwargs):
  return {}

# The combined feature mapping of context, edges and nodes
# is all the preprocessing we need for this dataset.
feature_processors = [
    tfgnn.keras.layers.MapFeatures(context_fn=drop_all_features,
                                   node_sets_fn=process_node_features,
                                   edge_sets_fn=drop_all_features),
]
Typically, a model with a GNN architecture at its core consists of three parts:
We are going to use one model for training, validation, and export for inference, so we need to build it from an input type spec with generic tensor shapes. (For TPUs, it's good enough to use it on a dataset with fixed-size elements.) Before defining the base Graph Neural Network, we show how to initialize the hidden states of all the necessary components (nodes, edges and context) given the pre-processed features.
The hidden states on nodes are created by mapping a dict of (preprocessed) features to fixed-size hidden states for nodes. Similarly to feature preprocessing, the tfgnn.keras.layers.MapFeatures layer lets you specify such a transformation as a callback function that transforms feature dicts, with GraphTensor mechanics taken off your shoulders.
# Hyperparameters
node_state_dim = 128
def set_initial_node_states(node_set: tfgnn.NodeSet, node_set_name: str):
  if node_set_name == "field_of_study":
    return tf.keras.layers.Embedding(50_000, 32)(node_set["hashed_id"])
  if node_set_name == "institution":
    return tf.keras.layers.Embedding(6_500, 16)(node_set["hashed_id"])
  if node_set_name == "paper":
    return tf.keras.layers.Dense(node_state_dim)(node_set["feat"])
  if node_set_name == "author":
    return node_set["empty_state"]
  raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
It is important to understand the distinction between feature preprocessing and hidden state initialization, even though both steps are defined using tfgnn.keras.layers.MapFeatures. The feature preprocessing step is non-trainable and runs asynchronously to the training loop. Hidden state initialization, on the other hand, is trainable and runs on the accelerator.
After the hidden states have been initialized, we pass the graph through the base Graph Neural Network, which is a sequence of GraphUpdates. Each GraphUpdate inputs a GraphTensor and returns a GraphTensor with the same graph structure, but the hidden states of nodes have been updated using the information of the neighbor nodes. In our example, the input examples are sampled subgraphs with up to 4 hops, so we perform 4 rounds of graph updates which suffice to bring all information into the root node.
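Why four rounds are enough can be checked on a toy example. The sketch below tracks, for the longest sampled path, which nodes have contributed information to each node after a given number of synchronous update rounds; the node names and the helper `run_rounds` are hypothetical, and sets of node names stand in for real hidden states.

```python
# Longest path in a sampled subgraph, stored as (source, target) edges
# in the direction of discovery by the sampler.
edges = [
    ("paper_0", "paper_1"),    # cites
    ("paper_1", "author_1"),   # rev_writes
    ("author_1", "paper_2"),   # writes
    ("paper_2", "field_1"),    # has_topic
]
nodes = sorted({n for edge in edges for n in edge})


def run_rounds(num_rounds: int) -> dict:
  """Returns, per node, the set of nodes whose information has reached it."""
  known = {n: {n} for n in nodes}
  for _ in range(num_rounds):
    # Synchronous update: every source reads its targets' previous states.
    updated = {n: set(s) for n, s in known.items()}
    for source, target in edges:
      updated[source] |= known[target]
    known = updated
  return known


print("after 3 rounds:", sorted(run_rounds(3)["paper_0"]))
print("after 4 rounds:", sorted(run_rounds(4)["paper_0"]))
```

After three rounds the root "paper_0" has not yet heard from "field_1", which sits four hops away; the fourth round closes that gap.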
Here, we use TF-GNN's Model Template version A, code-named MtAlbis. It provides a curated shortlist of modeling options, and we invite our users to try this one before exploring the other choices offered in tensorflow_gnn/models or building their own as described in the Modeling guide.
# Hyperparameters
num_graph_updates = 4
message_dim = 128
state_dropout_rate = 0.2
l2_regularization = 1e-5
def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
  graph = inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
  graph = tfgnn.keras.layers.MapFeatures(
      node_sets_fn=set_initial_node_states)(graph)
  for i in range(num_graph_updates):
    graph = mt_albis.MtAlbisGraphUpdate(
        units=node_state_dim,
        message_dim=message_dim,
        receiver_tag=tfgnn.SOURCE,
        node_set_names=None if i < num_graph_updates-1 else ["paper"],
        simple_conv_reduce_type="mean|sum",
        state_dropout_rate=state_dropout_rate,
        l2_regularization=l2_regularization,
        normalization_type="layer",
        next_state_type="residual",
    )(graph)
  return tf.keras.Model(inputs, graph)
An important parameter to set in the GraphUpdate layer is the receiver_tag. To determine this tag, it is important to understand the difference between tfgnn.SOURCE and tfgnn.TARGET. Source indicates the node from which an edge originates, while Target indicates the node to which an edge points.
The graph sampler starts sampling from the root node (one can think of the root node as the main source of the subgraph) and stores edges in the direction of their discovery while sampling. Given this construct, the GNN needs to send information in the reverse direction towards the root. In other words, the information needs to be propagated towards the SOURCE of each edge, so that it can reach and update the hidden state of the root. Thus, we set the receiver_tag to be tfgnn.SOURCE.
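The effect of the receiver choice can be seen in a toy mean-pooling step over an edge list. This is a pure-Python illustration, not the MtAlbis implementation; the function `pool_mean`, the scalar states and the string values "source" and "target" are hypothetical stand-ins for tfgnn.SOURCE and tfgnn.TARGET.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# Hypothetical sampled edges as (source, target), with node 0 as the root.
edges: List[Tuple[int, int]] = [(0, 1), (0, 2), (1, 3)]
states: Dict[int, float] = {0: 1.0, 1: 2.0, 2: 4.0, 3: 8.0}


def pool_mean(edges, states, receiver: str) -> Dict[int, float]:
  """Averages the states sent along each edge at the receiving endpoint."""
  received = defaultdict(list)
  for source, target in edges:
    if receiver == "source":
      received[source].append(states[target])  # Message flows target -> source.
    else:
      received[target].append(states[source])  # Message flows source -> target.
  return {node: sum(values) / len(values) for node, values in received.items()}


to_source = pool_mean(edges, states, receiver="source")
to_target = pool_mean(edges, states, receiver="target")
print(to_source)  # {0: 3.0, 1: 8.0}
print(to_target)  # {1: 1.0, 2: 1.0, 3: 2.0}
```

Only with the source as receiver does the root (node 0) receive anything at all, because the sampler stores every edge pointing away from the root.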
An interesting observation arising from the fact that receiver_tag=tfgnn.SOURCE is that since the node sets "field_of_study" and "institution" have no outgoing edge sets, the MtAlbisGraphUpdate does not change their hidden states: these remain the embedding tables from node state initialization. The other node sets have their hidden states computed in a GraphUpdate: "paper" in all four rounds, "author" in all rounds but the last (because that hidden state has no opportunity to influence the final state of "paper").
A Task collects the ancillary pieces for training a Keras model with the graph learning objective. It also provides losses and metrics for that objective. Common implementations for classification and regression (by graph or root node) are provided in the TF-GNN library.
label_fn = runner.RootNodeLabelFn(node_set_name="paper", feature_name="labels")
task = runner.RootNodeMulticlassClassification(
    node_set_name="paper",
    num_classes=349,
    label_fn=label_fn)
A Trainer provides the training and validation loops. These may be uses of tf.keras.Model.fit or arbitrary custom training loops. The Trainer provides accessors to training properties (like its tf.distribute.Strategy and model_dir) and is expected to return a trained tf.keras.Model.
# Hyperparameters
global_batch_size = 128
epochs = 10
initial_learning_rate = 0.001
if tpu_resolver:
  # Training on TPU takes ~90 secs / epoch, so we train for the entire epoch.
  epoch_divisor = 1
else:
  # Training on GPU / CPU is slower, so we train for 1/100th of a true epoch.
  # Feel free to edit the `epoch_divisor` according to your patience and ambition. ;-)
  epoch_divisor = 100

steps_per_epoch = NUM_TRAINING_SAMPLES // global_batch_size // epoch_divisor
validation_steps = NUM_VALIDATION_SAMPLES // global_batch_size // epoch_divisor
learning_rate = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate, steps_per_epoch*epochs)
optimizer_fn = functools.partial(tf.keras.optimizers.Adam,
                                 learning_rate=learning_rate)
# Trainer
trainer = runner.KerasTrainer(
    strategy=strategy,
    model_dir="/tmp/gnn_model/",
    callbacks=None,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    restore_best_weights=False,
    checkpoint_every_n_steps="never",
    summarize_every_n_steps="never",
    backup_and_restore=False,
)
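The step counts and the learning-rate schedule configured above can be checked by hand. The sketch below recomputes them in plain Python for the TPU setting (epoch_divisor = 1) and evaluates the cosine decay formula with its default of decaying to zero; the helper `cosine_decay` is a hand-written stand-in for tf.keras.optimizers.schedules.CosineDecay, not the library code.

```python
import math

NUM_TRAINING_SAMPLES = 629571
NUM_VALIDATION_SAMPLES = 64879
global_batch_size = 128
epochs = 10
initial_learning_rate = 0.001
epoch_divisor = 1  # TPU setting; use 100 for the GPU / CPU setting.

steps_per_epoch = NUM_TRAINING_SAMPLES // global_batch_size // epoch_divisor
validation_steps = NUM_VALIDATION_SAMPLES // global_batch_size // epoch_divisor
decay_steps = steps_per_epoch * epochs


def cosine_decay(step: int) -> float:
  """Learning rate after `step` training steps, decaying to zero."""
  step = min(step, decay_steps)
  return initial_learning_rate * 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))


print(steps_per_epoch, validation_steps, decay_steps)
for step in (0, decay_steps // 2, decay_steps):
  print(step, cosine_decay(step))
```

With epoch_divisor = 1 this gives 4918 training steps per epoch, matching the progress bars in the training log of this tutorial, and a learning rate that falls from 0.001 through 0.0005 at the halfway point to 0 at the end of the last epoch.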
For inference, a SavedModel must be exported by the runner at the end of training. C++ inference environments like TF Serving do not support input of extension types like GraphTensor, so the KerasModelExporter exports the model with a SavedModel Signature that accepts a batch of serialized tf.Examples and preprocesses them like training did.
Note: After connecting this Colab to a TPU worker, explicit device placements are necessary to do the test on the Colab host (which has the /tmp/gnn_model directory).
save_options = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")
model_exporter = runner.KerasModelExporter(options=save_options)
Orchestration (a term for the composition, wiring and execution of the above abstractions) happens via a single run method, called as shown below.
Training for 10 epochs of sampled subgraphs takes a few hours on a free Colab with one GPU (T4) and should achieve an accuracy above 0.50. Training with the free Cloud TPU runtime is much faster and completes the entire training within 20 minutes.
You can drive accuracy even higher by training a bigger model for longer: setting node_state_dim = 256; message_dim = 256; epochs = 20 should take your val_sparse_categorical_accuracy above 0.52.
NOTE: On TPU, it takes about 4 minutes before training starts, because the optimal TPU padding constraints have to be learned first.
runner.run(
    train_ds_provider=train_ds_provider,
    train_padding=train_padding,
    model_fn=model_fn,
    optimizer_fn=optimizer_fn,
    epochs=epochs,
    trainer=trainer,
    task=task,
    gtspec=example_input_graph_spec,
    global_batch_size=global_batch_size,
    model_exporters=[model_exporter],
    feature_processors=feature_processors,
    valid_ds_provider=valid_ds_provider,  # <<< Remove if not training for real.
    valid_padding=valid_padding)
Epoch 1/10
4918/4918 [==============================] - 141s 29ms/step - loss: 2.5981 - sparse_categorical_accuracy: 0.3266 - sparse_categorical_crossentropy: 2.7027 - val_loss: 2.1224 - val_sparse_categorical_accuracy: 0.4123 - val_sparse_categorical_crossentropy: 2.1855
Epoch 2/10
4918/4918 [==============================] - 90s 18ms/step - loss: 2.0927 - sparse_categorical_accuracy: 0.4256 - sparse_categorical_crossentropy: 2.1475 - val_loss: 1.9497 - val_sparse_categorical_accuracy: 0.4474 - val_sparse_categorical_crossentropy: 1.9906
Epoch 3/10
4918/4918 [==============================] - 89s 18ms/step - loss: 1.9615 - sparse_categorical_accuracy: 0.4556 - sparse_categorical_crossentropy: 1.9998 - val_loss: 1.9174 - val_sparse_categorical_accuracy: 0.4524 - val_sparse_categorical_crossentropy: 1.9509
Epoch 4/10
4918/4918 [==============================] - 89s 18ms/step - loss: 1.8757 - sparse_categorical_accuracy: 0.4744 - sparse_categorical_crossentropy: 1.9056 - val_loss: 1.8417 - val_sparse_categorical_accuracy: 0.4782 - val_sparse_categorical_crossentropy: 1.8690
Epoch 5/10
4918/4918 [==============================] - 89s 18ms/step - loss: 1.8065 - sparse_categorical_accuracy: 0.4908 - sparse_categorical_crossentropy: 1.8320 - val_loss: 1.8294 - val_sparse_categorical_accuracy: 0.4797 - val_sparse_categorical_crossentropy: 1.8571
Epoch 6/10
4918/4918 [==============================] - 89s 18ms/step - loss: 1.7430 - sparse_categorical_accuracy: 0.5035 - sparse_categorical_crossentropy: 1.7666 - val_loss: 1.7843 - val_sparse_categorical_accuracy: 0.4924 - val_sparse_categorical_crossentropy: 1.8119
Epoch 7/10
4918/4918 [==============================] - 89s 18ms/step - loss: 1.6875 - sparse_categorical_accuracy: 0.5165 - sparse_categorical_crossentropy: 1.7104 - val_loss: 1.7536 - val_sparse_categorical_accuracy: 0.4978 - val_sparse_categorical_crossentropy: 1.7819
Epoch 8/10
4918/4918 [==============================] - 91s 18ms/step - loss: 1.6454 - sparse_categorical_accuracy: 0.5260 - sparse_categorical_crossentropy: 1.6680 - val_loss: 1.7462 - val_sparse_categorical_accuracy: 0.4992 - val_sparse_categorical_crossentropy: 1.7761
Epoch 9/10
4918/4918 [==============================] - 90s 18ms/step - loss: 1.6120 - sparse_categorical_accuracy: 0.5341 - sparse_categorical_crossentropy: 1.6340 - val_loss: 1.7419 - val_sparse_categorical_accuracy: 0.4998 - val_sparse_categorical_crossentropy: 1.7724
Epoch 10/10
4918/4918 [==============================] - 89s 18ms/step - loss: 1.5973 - sparse_categorical_accuracy: 0.5373 - sparse_categorical_crossentropy: 1.6189 - val_loss: 1.7394 - val_sparse_categorical_accuracy: 0.5018 - val_sparse_categorical_crossentropy: 1.7699
WARNING:absl:Found untraced functions such as _update_step_xla, node_set_update_layer_call_fn, node_set_update_layer_call_and_return_conditional_losses, node_set_update_1_layer_call_fn, node_set_update_1_layer_call_and_return_conditional_losses while saving (showing 5 of 143). These functions will not be directly callable after loading. /usr/local/lib/python3.10/dist-packages/tensorflow/python/saved_model/nested_structure_coder.py:497: UserWarning: Encoding a StructuredValue with type tensorflow_gnn.GraphTensorSpec; loading this StructuredValue will require that this type be imported and registered. warnings.warn("Encoding a StructuredValue with type %s; loading this " /usr/local/lib/python3.10/dist-packages/tensorflow/python/saved_model/nested_structure_coder.py:497: UserWarning: Encoding a StructuredValue with type tensorflow_gnn.ContextSpec.v2; loading this StructuredValue will require that this type be imported and registered. warnings.warn("Encoding a StructuredValue with type %s; loading this " /usr/local/lib/python3.10/dist-packages/tensorflow/python/saved_model/nested_structure_coder.py:497: UserWarning: Encoding a StructuredValue with type tensorflow_gnn.NodeSetSpec; loading this StructuredValue will require that this type be imported and registered. warnings.warn("Encoding a StructuredValue with type %s; loading this " /usr/local/lib/python3.10/dist-packages/tensorflow/python/saved_model/nested_structure_coder.py:497: UserWarning: Encoding a StructuredValue with type tensorflow_gnn.EdgeSetSpec; loading this StructuredValue will require that this type be imported and registered. warnings.warn("Encoding a StructuredValue with type %s; loading this " /usr/local/lib/python3.10/dist-packages/tensorflow/python/saved_model/nested_structure_coder.py:497: UserWarning: Encoding a StructuredValue with type tensorflow_gnn.AdjacencySpec; loading this StructuredValue will require that this type be imported and registered. 
warnings.warn("Encoding a StructuredValue with type %s; loading this "
RunResult(preprocess_model=<keras.engine.functional.Functional object at 0x7df54cbbe6b0>, base_model=<keras.engine.sequential.Sequential object at 0x7df54bebe920>, trained_model=<keras.engine.functional.Functional object at 0x7df54bbaff10>)
At the end of training, a SavedModel is exported by the Runner for inference. For demonstration, let's call the exported model on the validation dataset from above, but without labels. We load it as a SavedModel, like TF Serving would. Analogous to the SaveOptions above, LoadOptions with a device placement are necessary when connecting this Colab to a TPU worker.
NOTE: TF Serving usually expects examples in the form of serialized strings, so we explicitly convert the graph tensors to serialized strings and pass them to the loaded model.
load_options = tf.saved_model.LoadOptions(experimental_io_device="/job:localhost")
saved_model = tf.saved_model.load(os.path.join(trainer.model_dir, "export"),
                                  options=load_options)

def _clean_example_for_serving(graph_tensor):
  serialized_example = tfgnn.write_example(graph_tensor)
  return serialized_example.SerializeToString()
# Convert 10 examples to serialized string format.
num_examples = 10
demo_ds = valid_ds_provider.get_dataset(tf.distribute.InputContext())
serialized_examples = [_clean_example_for_serving(gt) for gt in itertools.islice(demo_ds, num_examples)]
# Inference on 10 examples
ds = tf.data.Dataset.from_tensor_slices(serialized_examples)
kwargs = {"examples": next(iter(ds.batch(10)))}
output = saved_model.signatures["serving_default"](**kwargs)
# Outputs are in the form of logits
logits = next(iter(output.values()))
probabilities = tf.math.softmax(logits).numpy()
classes = probabilities.argmax(axis=1)
# Print the predicted classes
for i, c in enumerate(classes):
  print(f"The predicted class for input {i} is {c:3} "
        f"with predicted probability {probabilities[i, c]:.4}")
The predicted class for input 0 is 289 with predicted probability 0.3882
The predicted class for input 1 is 281 with predicted probability 0.4415
The predicted class for input 2 is 189 with predicted probability 0.3256
The predicted class for input 3 is 158 with predicted probability 0.7522
The predicted class for input 4 is  82 with predicted probability 0.2598
The predicted class for input 5 is 247 with predicted probability 0.8446
The predicted class for input 6 is 209 with predicted probability 0.4843
The predicted class for input 7 is 247 with predicted probability 0.672
The predicted class for input 8 is 192 with predicted probability 0.5376
The predicted class for input 9 is 311 with predicted probability 0.8969
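The last step of the inference code, turning logits into a predicted class with a probability, can be reproduced without TensorFlow. This is a minimal sketch with made-up logits for a three-class problem; the `softmax` helper is hand-written for illustration and is not the tf.math.softmax implementation.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
  """Converts logits to probabilities that sum to 1."""
  largest = max(logits)
  # Subtracting the maximum keeps exp() from overflowing without changing the result.
  exps = [math.exp(x - largest) for x in logits]
  total = sum(exps)
  return [e / total for e in exps]


logits = [2.0, 0.5, -1.0]  # Hypothetical model outputs for one input.
probabilities = softmax(logits)
predicted_class = max(range(len(probabilities)), key=probabilities.__getitem__)
print(f"The predicted class is {predicted_class} "
      f"with predicted probability {probabilities[predicted_class]:.4}")
```

Since softmax preserves the ordering of the logits, the predicted class is simply the index of the largest logit; the softmax is only needed to report a probability alongside it.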
This tutorial has shown how to solve a node classification problem in a large graph with TF-GNN using
The Data Preparation and Sampling guide describes how you can create training data for other datasets.
The Colab notebook An in-depth look at TF-GNN solves OGBN-MAG again, but without the abstractions provided by the Runner and the ready-to-use MtAlbis model. Take a look if you would like to know more, or want more control in designing GNNs for your own task.
For more complete documentation, please check out the TF-GNN documentation.