# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, eicther express or implied. # See the License for the specific language governing permissions and # limitations under the License. !pip install -q tensorflow-gnn || echo "Ignoring package errors..." import functools import itertools import os import re from typing import Mapping import tensorflow as tf import tensorflow_gnn as tfgnn from tensorflow_gnn import runner from tensorflow_gnn.experimental import sampler from tensorflow_gnn.models import mt_albis tf.get_logger().setLevel('ERROR') print(f"Running TF-GNN {tfgnn.__version__} under TensorFlow {tf.__version__}.") NUM_TRAINING_SAMPLES = 629571 NUM_VALIDATION_SAMPLES = 64879 GRAPH_TENSOR_FILE = 'gs://download.tensorflow.org/data/ogbn-mag/sampled/v2/graph_tensor.example.pb' SCHEMA_FILE = 'gs://download.tensorflow.org/data/ogbn-mag/sampled/v2/graph_schema.pbtxt' graph_schema = tfgnn.read_schema(SCHEMA_FILE) serialized_ogbn_mag_graph_tensor_string = tf.io.read_file(GRAPH_TENSOR_FILE) full_ogbn_mag_graph_tensor = tfgnn.parse_single_example( tfgnn.create_graph_spec_from_schema_pb(graph_schema, indices_dtype=tf.int64), serialized_ogbn_mag_graph_tensor_string) train_sampling_sizes = { "cites": 8, "rev_writes": 8, "writes": 8, "affiliated_with": 8, "has_topic": 8, } validation_sample_sizes = train_sampling_sizes.copy() def create_sampling_model( full_graph_tensor: tfgnn.GraphTensor, sizes: Mapping[str, int] ) -> tf.keras.Model: def edge_sampler(sampling_op: tfgnn.sampler.SamplingOp): edge_set_name = sampling_op.edge_set_name sample_size = sizes[edge_set_name] return sampler.InMemUniformEdgesSampler.from_graph_tensor( full_graph_tensor, edge_set_name, sample_size=sample_size ) def get_features(node_set_name: tfgnn.NodeSetName): return sampler.InMemIndexToFeaturesAccessor.from_graph_tensor( full_graph_tensor, node_set_name ) # Spell out the sampling procedure in python sampling_spec_builder = tfgnn.sampler.SamplingSpecBuilder(graph_schema) seed = sampling_spec_builder.seed("paper") papers_cited_from_seed = seed.sample(sizes["cites"], "cites") authors_of_papers = papers_cited_from_seed.join([seed]).sample(sizes["rev_writes"], "rev_writes") papers_by_authors = authors_of_papers.sample(sizes["writes"], "writes") institutions = authors_of_papers.sample(sizes["affiliated_with"], "affiliated_with") fields_of_study = (seed.join([papers_cited_from_seed, papers_by_authors]).sample(sizes["has_topic"], "has_topic")) sampling_spec = sampling_spec_builder.build() model = sampler.create_sampling_model_from_spec( graph_schema, sampling_spec, edge_sampler, get_features, seed_node_dtype=tf.int64) return model def seed_dataset(years: tf.Tensor, split_name: str) -> tf.data.Dataset: """Seed dataset as indices of papers within split years.""" if split_name == "train": mask = years <= 2017 # 629,571 examples elif split_name == "validation": mask = years == 2018 # 64,879 examples elif split_name == "test": mask = years == 2019 # 41,939 examples else: raise ValueError(f"Unknown split_name: '{split_name}'") seed_indices = tf.squeeze(tf.where(mask), axis=-1) return tf.data.Dataset.from_tensor_slices(seed_indices) class SubgraphDatasetProvider(runner.DatasetProvider): """Dataset Provider based on Sampler V2.""" def __init__(self, full_graph_tensor: tfgnn.GraphTensor, sizes: Mapping[str, int], split_name: str): super().__init__() # Extract years of publication of all papers for determining seeds. self._years = tf.squeeze(full_graph_tensor.node_sets["paper"]["year"], axis=-1) self._sampling_model = create_sampling_model(full_graph_tensor, sizes) self._split_name = split_name self.input_graph_spec = self._sampling_model.output.spec def get_dataset(self, context: tf.distribute.InputContext) -> tf.data.Dataset: """Creates TF dataset.""" self._seed_dataset = seed_dataset(self._years, self._split_name) ds = self._seed_dataset.shard( num_shards=context.num_input_pipelines, index=context.input_pipeline_id) if self._split_name == "train": ds = ds.shuffle(NUM_TRAINING_SAMPLES).repeat() # samples 128 subgraphs in parallel. Larger is better, but could cause OOM. ds = ds.batch(128) ds = ds.map( functools.partial(self.sample), num_parallel_calls=tf.data.AUTOTUNE, deterministic=False, ) return ds.unbatch().prefetch(tf.data.AUTOTUNE) def sample(self, seeds: tf.Tensor) -> tfgnn.GraphTensor: seeds = tf.cast(seeds, tf.int64) batch_size = tf.size(seeds) # samples subgraphs for each seed independently as [[seed1], [seed2], ...] seeds_ragged = tf.RaggedTensor.from_row_lengths( seeds, tf.ones([batch_size], tf.int64), ) return self._sampling_model(seeds_ragged) train_ds_provider = SubgraphDatasetProvider(full_ogbn_mag_graph_tensor, train_sampling_sizes, "train") valid_ds_provider = SubgraphDatasetProvider(full_ogbn_mag_graph_tensor, validation_sample_sizes, "validation") example_input_graph_spec = train_ds_provider.input_graph_spec._unbatch() try: tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver() print("Running on TPU ", tpu_resolver.cluster_spec().as_dict()["worker"]) except: tpu_resolver = None if tpu_resolver: print("Using TPUStrategy") min_nodes_per_component = {"paper": 1} strategy = runner.TPUStrategy() train_padding = runner.FitOrSkipPadding(example_input_graph_spec, train_ds_provider, min_nodes_per_component) valid_padding = runner.TightPadding(example_input_graph_spec, valid_ds_provider, min_nodes_per_component) elif tf.config.list_physical_devices("GPU"): print(f"Using MirroredStrategy for GPUs") gpu_list = !nvidia-smi -L print("\n".join(gpu_list)) strategy = tf.distribute.MirroredStrategy() train_padding = None valid_padding = None else: print(f"Using default strategy") strategy = tf.distribute.get_strategy() train_padding = None valid_padding = None print(f"Found {strategy.num_replicas_in_sync} replicas in sync") # For nodes def process_node_features(node_set: tfgnn.NodeSet, node_set_name: str): if node_set_name == "field_of_study": return {"hashed_id": tf.keras.layers.Hashing(50_000)(node_set["#id"])} if node_set_name == "institution": return {"hashed_id": tf.keras.layers.Hashing(6_500)(node_set["#id"])} if node_set_name == "paper": # Keep `labels` for eventual extraction. return {"feat": node_set["feat"], "labels": node_set["label"]} if node_set_name == "author": return {"empty_state": tfgnn.keras.layers.MakeEmptyFeature()(node_set)} raise KeyError(f"Unexpected node_set_name='{node_set_name}'") # For context and edges, in this example, we drop all features. def drop_all_features(_, **unused_kwargs): return {} # The combined feature mapping of context, edges and nodes # is all the preprocessing we need for this dataset. feature_processors = [ tfgnn.keras.layers.MapFeatures(context_fn=drop_all_features, node_sets_fn=process_node_features, edge_sets_fn=drop_all_features), ] # Hyperparameters node_state_dim = 128 def set_initial_node_states(node_set: tfgnn.NodeSet, node_set_name: str): if node_set_name == "field_of_study": return tf.keras.layers.Embedding(50_000, 32)(node_set["hashed_id"]) if node_set_name == "institution": return tf.keras.layers.Embedding(6_500, 16)(node_set["hashed_id"]) if node_set_name == "paper": return tf.keras.layers.Dense(node_state_dim)(node_set["feat"]) if node_set_name == "author": return node_set["empty_state"] raise KeyError(f"Unexpected node_set_name='{node_set_name}'") # Hyperparameters num_graph_updates = 4 message_dim = 128 state_dropout_rate = 0.2 l2_regularization = 1e-5 def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec): graph = inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec) graph = tfgnn.keras.layers.MapFeatures( node_sets_fn=set_initial_node_states)(graph) for i in range(num_graph_updates): graph = mt_albis.MtAlbisGraphUpdate( units=node_state_dim, message_dim=message_dim, receiver_tag=tfgnn.SOURCE, node_set_names=None if i < num_graph_updates-1 else ["paper"], simple_conv_reduce_type="mean|sum", state_dropout_rate=state_dropout_rate, l2_regularization=l2_regularization, normalization_type="layer", next_state_type="residual", )(graph) return tf.keras.Model(inputs, graph) label_fn = runner.RootNodeLabelFn(node_set_name="paper", feature_name="labels") task = runner.RootNodeMulticlassClassification( node_set_name="paper", num_classes=349, label_fn=label_fn) # Hyperparameters global_batch_size = 128 epochs = 10 initial_learning_rate = 0.001 if tpu_resolver: # Training on TPU takes ~90 secs / epoch, so we train for the entire epoch. epoch_divisor = 1 else: # Training on GPU / CPU is slower, so we train for 1/100th of a true epoch. # Feel free to edit the `epoch_divisor` according to your patience and ambition. ;-) epoch_divisor = 100 steps_per_epoch = NUM_TRAINING_SAMPLES // global_batch_size // epoch_divisor validation_steps = NUM_VALIDATION_SAMPLES // global_batch_size // epoch_divisor learning_rate = tf.keras.optimizers.schedules.CosineDecay( initial_learning_rate, steps_per_epoch*epochs) optimizer_fn = functools.partial(tf.keras.optimizers.Adam, learning_rate=learning_rate) # Trainer trainer = runner.KerasTrainer( strategy=strategy, model_dir="/tmp/gnn_model/", callbacks=None, steps_per_epoch=steps_per_epoch, validation_steps=validation_steps, restore_best_weights=False, checkpoint_every_n_steps="never", summarize_every_n_steps="never", backup_and_restore=False, ) save_options = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost") model_exporter = runner.KerasModelExporter(options=save_options) runner.run( train_ds_provider=train_ds_provider, train_padding=train_padding, model_fn=model_fn, optimizer_fn=optimizer_fn, epochs=epochs, trainer=trainer, task=task, gtspec=example_input_graph_spec, global_batch_size=global_batch_size, model_exporters=[model_exporter], feature_processors=feature_processors, valid_ds_provider=valid_ds_provider, # <<< Remove if not training for real. valid_padding=valid_padding) load_options = tf.saved_model.LoadOptions(experimental_io_device="/job:localhost") saved_model = tf.saved_model.load(os.path.join(trainer.model_dir, "export"), options=load_options) def _clean_example_for_serving(graph_tensor): serialized_example = tfgnn.write_example(graph_tensor) return serialized_example.SerializeToString() # Convert 10 examples to serialized string format. num_examples = 10 demo_ds = valid_ds_provider.get_dataset(tf.distribute.InputContext()) serialized_examples = [_clean_example_for_serving(gt) for gt in itertools.islice(demo_ds, num_examples)] # Inference on 10 examples ds = tf.data.Dataset.from_tensor_slices(serialized_examples) kwargs = {"examples": next(iter(ds.batch(10)))} output = saved_model.signatures["serving_default"](**kwargs) # Outputs are in the form of logits logits = next(iter(output.values())) probabilities = tf.math.softmax(logits).numpy() classes = probabilities.argmax(axis=1) # Print the predicted classes for i, c in enumerate(classes): print(f"The predicted class for input {i} is {c:3} " f"with predicted probability {probabilities[i, c]:.4}")