from google.colab import drive
drive.mount('/content/gdrive')
import os
os.chdir('/content/gdrive/My Drive/finch/tensorflow2/semantic_parsing/tree_slu/main')
%tensorflow_version 2.x
!pip install -U tensorflow-addons
from modified_beam_search_decoder import BeamSearchDecoder
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
import pprint
import logging
import time
import nltk
print("TensorFlow Version", tf.__version__)
print('GPU Enabled:', tf.test.is_gpu_available())
# stream data from text files
def data_generator(f_path, params):
  with open(f_path) as f:
    print('Reading', f_path)
    for line in f:
      text_raw, text_tokenized, label = line.split('\t')
      text_tokenized = text_tokenized.lower().split()
      label = label.replace('[', '[ ').lower().split()
      source = [params['tgt2idx'].get(w, len(params['tgt2idx'])) for w in text_tokenized]
      target = [params['tgt2idx'].get(w, len(params['tgt2idx'])) for w in label]
      target_in = [1] + target
      target_out = target + [2]
      yield (source, target_in, target_out)
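# Illustrative only -- a hypothetical line in the spirit of the TOP annotation
# format (the actual train.tsv / test.tsv contents ship with the dataset):
#   raw text <TAB> tokenized text <TAB> bracketed parse, e.g.
#   What's playing<TAB>what 's playing<TAB>[IN:GET_EVENT What 's playing ]
# Source tokens are indexed with params['tgt2idx'] as well, so the pointer
# layer below can scatter copy probabilities directly into the target vocab.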
def dataset(is_training, params):
  _shapes = ([None], [None], [None])
  _types = (tf.int32, tf.int32, tf.int32)
  _pads = (0, 0, 0)
  if is_training:
    ds = tf.data.Dataset.from_generator(
      lambda: data_generator(params['train_path'], params),
      output_shapes = _shapes,
      output_types = _types,)
    ds = ds.shuffle(params['buffer_size'])
    ds = ds.padded_batch(params['train_batch_size'], _shapes, _pads)
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
  else:
    ds = tf.data.Dataset.from_generator(
      lambda: data_generator(params['test_path'], params),
      output_shapes = _shapes,
      output_types = _types,)
    ds = ds.padded_batch(params['eval_batch_size'], _shapes, _pads)
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
  return ds
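# Optional sanity check (a sketch -- assumes `params` has already been filled
# in further below): pull one evaluation batch and print the padded shapes.
#
# for src, tgt_in, tgt_out in dataset(is_training=False, params=params).take(1):
#     print(src.shape, tgt_in.shape, tgt_out.shape)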
class Embed(tf.keras.Model):
  def __init__(self):
    super().__init__()
    self.embedding = tf.Variable(np.load('../vocab/word.npy'),
                                 dtype=tf.float32,
                                 name='pretrained_embedding')

  def call(self, inputs):
    if inputs.dtype != tf.int32:
      inputs = tf.cast(inputs, tf.int32)
    x = tf.nn.embedding_lookup(self.embedding, inputs)
    return x
class Encoder(tf.keras.Model):
  def __init__(self, params):
    super().__init__()
    self.dropout = tf.keras.layers.Dropout(params['dropout_rate'])
    self.birnn = tf.keras.layers.Bidirectional(tf.keras.layers.GRU(
      params['rnn_units'], return_state=True, return_sequences=True, zero_output_for_mask=True))
    self.state_fc = tf.keras.layers.Dense(params['rnn_units'], params['activation'], name='state_fc')
    self.out_fc = tf.keras.layers.Dense(params['rnn_units'], params['activation'], name='out_fc')

  def call(self, inputs, mask, training):
    if mask.dtype != tf.bool:
      mask = tf.cast(mask, tf.bool)
    x = self.dropout(inputs, training=training)
    encoder_o, state_fw, state_bw = self.birnn(x, mask=mask)
    encoder_s = self.state_fc(tf.concat((state_fw, state_bw), -1))
    return encoder_o, tuple([encoder_s])
class TiedDense(tf.keras.layers.Layer):
  def __init__(self, tied_embed, out_dim):
    super().__init__()
    self.tied_embed = tied_embed
    self.out_dim = out_dim

  def build(self, input_shape):
    self.bias = self.add_weight(name='bias',
                                shape=[self.out_dim],
                                trainable=True)
    super().build(input_shape)

  def call(self, inputs):
    # project with the transposed tied embedding matrix and return softmaxed
    # probabilities so they can be mixed with the copy distribution below
    x = tf.matmul(inputs, self.tied_embed, transpose_b=True)
    x = tf.nn.bias_add(x, self.bias)
    return tf.nn.softmax(x)

  def compute_output_shape(self, input_shape):
    return input_shape[:-1].concatenate(self.out_dim)
class Pointer(tf.keras.layers.Layer):
  def __init__(self, vocab_size):
    super().__init__()
    self.encoder_ids = None
    self.encoder_out = None
    self.vocab_size = vocab_size
    self.is_beam_search = None

  def call(self, inputs):
    # note: reads the global `params` dict for beam_width and rnn_units
    _max_len = tf.shape(self.encoder_ids)[1]
    _batch_size_ori = tf.shape(inputs)[0]
    if self.is_beam_search:
      _batch_size = _batch_size_ori * params['beam_width']
    else:
      _batch_size = _batch_size_ori
    inputs = tf.reshape(inputs, (_batch_size, params['rnn_units']))
    # attention of the decoder state over the encoder outputs
    attn_weights = tf.matmul(self.encoder_out, tf.expand_dims(inputs, -1))
    attn_weights = tf.squeeze(attn_weights, -1)
    updates = tf.nn.softmax(attn_weights)
    # scatter the attention weights onto the vocab ids of the source tokens
    batch_nums = tf.range(0, _batch_size)
    batch_nums = tf.expand_dims(batch_nums, axis=1)
    batch_nums = tf.tile(batch_nums, [1, _max_len])
    indices = tf.stack([batch_nums, self.encoder_ids], axis=2)
    x = tf.scatter_nd(indices, updates, (_batch_size, self.vocab_size))
    if self.is_beam_search:
      return tf.reshape(x, (_batch_size_ori, params['beam_width'], self.vocab_size))
    return x

  def compute_output_shape(self, input_shape):
    return input_shape[:-1].concatenate(self.vocab_size)
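# A minimal sketch (made-up numbers) of how the scatter_nd call above turns
# softmaxed attention weights over source positions into a distribution over
# the target vocabulary -- weights land on the vocab ids of the source tokens,
# and repeated ids have their weights summed:
#
#   ids   = tf.constant([[4, 7, 4]])            # source token ids, batch of 1
#   attn  = tf.constant([[0.2, 0.5, 0.3]])      # softmaxed attention weights
#   batch = tf.tile(tf.expand_dims(tf.range(1), 1), [1, 3])
#   copy  = tf.scatter_nd(tf.stack([batch, ids], axis=2), attn, (1, 10))
#   # copy[0, 4] == 0.5 (0.2 + 0.3), copy[0, 7] == 0.5, everything else 0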
class OutputProj(tf.keras.layers.Layer):
  def __init__(self, tied_embed, vocab_size):
    super().__init__()
    self.generator = TiedDense(tied_embed, vocab_size)
    self.pointer = Pointer(vocab_size)
    self.vocab_size = vocab_size

  def build(self, input_shape):
    self.gate_fc = tf.keras.layers.Dense(1, tf.sigmoid, use_bias=False)
    super().build(input_shape)

  def call(self, inputs):
    gen_dist = self.generator(inputs)
    copy_dist = self.pointer(inputs)
    gate = self.gate_fc(inputs)
    return gate * gen_dist + (1 - gate) * copy_dist

  def compute_output_shape(self, input_shape):
    return input_shape[:-1].concatenate(self.vocab_size)
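# With made-up numbers: if the gate outputs 0.7 for a decoding step, the final
# distribution is 0.7 * gen_dist + 0.3 * copy_dist, so a token the generator
# assigns little mass to can still be predicted when the copy distribution
# puts enough weight on it (i.e. it appears in the source utterance).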
class Model(tf.keras.Model):
  def __init__(self, params):
    super().__init__()
    self.embed = Embed()
    self.encoder = Encoder(params)
    self.dropout = tf.keras.layers.Dropout(params['dropout_rate'])
    self.attn = tfa.seq2seq.BahdanauAttention(params['rnn_units'])
    self.decoder_cell = tfa.seq2seq.AttentionWrapper(
      tf.keras.layers.StackedRNNCells([tf.keras.layers.GRUCell(params['rnn_units'])]),
      self.attn,
      attention_layer_size=params['rnn_units'])
    self.proj_layer = OutputProj(self.embed.embedding, len(params['tgt2idx'])+1)
    # teacher-forcing decoder used during training
    self.teach_forcing = tfa.seq2seq.BasicDecoder(
      self.decoder_cell,
      tfa.seq2seq.sampler.TrainingSampler(),
      output_layer = self.proj_layer)
    # beam search decoder used during inference
    self.beam_search = BeamSearchDecoder(
      self.decoder_cell,
      beam_width = params['beam_width'],
      embedding_fn = lambda x: self.embed(x),
      output_layer = self.proj_layer,
      maximum_iterations = 64,)

  def call(self, inputs, training=True):
    if training:
      source, target_in = inputs
    else:
      source = inputs
    source = tf.cast(source, tf.int32)
    batch_sz = tf.shape(source)[0]
    encoder_o, encoder_s = self.encoder(self.embed(source), mask=tf.sign(source), training=training)

    if training:
      self.attn.setup_memory(encoder_o, tf.math.count_nonzero(source, 1))
      attn_state = self.decoder_cell.get_initial_state(batch_size=batch_sz, dtype=tf.float32)
      attn_state = attn_state.clone(cell_state=encoder_s)
      self.proj_layer.pointer.encoder_ids = source
      self.proj_layer.pointer.encoder_out = self.encoder.out_fc(encoder_o)
      self.proj_layer.pointer.is_beam_search = False
      decoder_o, _, _ = self.teach_forcing(
        inputs = self.dropout(self.embed(target_in), training=training),
        initial_state = attn_state,
        sequence_length = tf.math.count_nonzero(target_in, 1, dtype=tf.int32))
      logits_or_ids = decoder_o.rnn_output
    else:
      # tile the encoder results across the beam before decoding
      encoder_o_t = tfa.seq2seq.tile_batch(encoder_o, params['beam_width'])
      encoder_len_t = tfa.seq2seq.tile_batch(tf.math.count_nonzero(source, 1), params['beam_width'])
      encoder_s_t = tfa.seq2seq.tile_batch(encoder_s, params['beam_width'])
      self.attn.setup_memory(encoder_o_t, encoder_len_t)
      attn_state = self.decoder_cell.get_initial_state(batch_size=batch_sz*params['beam_width'], dtype=tf.float32)
      attn_state = attn_state.clone(cell_state=encoder_s_t)
      self.proj_layer.pointer.encoder_ids = tfa.seq2seq.tile_batch(source, params['beam_width'])
      self.proj_layer.pointer.encoder_out = self.encoder.out_fc(encoder_o_t)
      self.proj_layer.pointer.is_beam_search = True
      decoder_o, _, _ = self.beam_search(
        embedding = None,
        start_tokens = tf.tile(tf.constant([1], tf.int32), [batch_sz]),
        end_token = 2,
        initial_state = attn_state,)
      # keep only the top beam
      logits_or_ids = decoder_o.predicted_ids[:, :, 0]

    return logits_or_ids
def get_vocab(f_path):
  word2idx = {}
  with open(f_path) as f:
    for i, line in enumerate(f):
      line = line.rstrip()
      word2idx[line] = i
  return word2idx
params = {
  'train_path': '../data/train.tsv',
  'test_path': '../data/test.tsv',
  'vocab_src_path': '../vocab/source.txt',
  'vocab_tgt_path': '../vocab/target.txt',
  'model_path': '../model/',
  'dropout_rate': .2,
  'rnn_units': 300,
  'embed_dim': 300,
  'activation': tf.nn.swish,
  'beam_width': 10,
  'init_lr': 1e-4,
  'max_lr': 8e-4,
  'clip_norm': .1,
  'buffer_size': 31279,
  'train_batch_size': 32,
  'eval_batch_size': 128,
  'num_patience': 6,
}
params['tgt2idx'] = get_vocab(params['vocab_tgt_path'])
params['idx2tgt'] = {idx: tgt for tgt, idx in params['tgt2idx'].items()}
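# Id conventions assumed throughout this script: 0 = padding, 1 = decoder
# start token, 2 = decoder end token; any word missing from target.txt maps to
# len(tgt2idx), which is why the output layer is sized len(params['tgt2idx']) + 1.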
model = Model(params)
model.build(input_shape=[[None, None], [None, None]])
pprint.pprint([(v.name, v.shape) for v in model.trainable_variables])
decay_lr = tfa.optimizers.Triangular2CyclicalLearningRate(
  initial_learning_rate = params['init_lr'],
  maximal_learning_rate = params['max_lr'],
  step_size = 4*params['buffer_size']//params['train_batch_size'],)
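# Sketch of the schedule's intent: the learning rate ramps from init_lr up to
# max_lr and back down over 2 * step_size steps, with the peak halved on each
# new cycle (the "triangular2" policy). Assuming buffer_size matches the number
# of training examples, one half-cycle of 4 * buffer_size // train_batch_size
# steps spans roughly four epochs.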
optim = tf.optimizers.Adam(params['init_lr'])
global_step = 0
t0 = time.time()
logger = logging.getLogger('tensorflow')
logger.propagate = False
logger.setLevel(logging.INFO)
best_acc = .0
count = 0
def unit_test(model, params):
  test_str = ['what', 'times', 'are', 'the', 'nutcracker', 'show', 'playing', 'near', 'me']
  test_arr = tf.convert_to_tensor([[params['tgt2idx'][w] for w in test_str]])
  generated = model(inputs=test_arr, training=False)
  print('-'*12)
  print('unit test')
  print('utterance:', ' '.join(test_str))
  parsed = ' '.join([params['idx2tgt'][idx] for idx in generated[0].numpy() if (idx != 0 and idx != 2)])
  print('parsed:', parsed)
  print()
  try:
    nltk.tree.Tree.fromstring(parsed.replace('[ ', '(').replace(' ]', ')')).pretty_print()
  except Exception:
    pass
  print('-'*12)
def cross_entropy_loss(logits, labels, vocab_size, smoothing):
  # label smoothing over the one-hot targets
  soft_targets = tf.one_hot(tf.cast(labels, tf.int32), depth=vocab_size)
  soft_targets = ((1-smoothing) * soft_targets) + (smoothing / vocab_size)
  # `logits` is already a probability distribution (the output layer mixes two
  # softmaxed distributions), so take the log directly with a small epsilon
  logits = tf.math.minimum(1., logits + 1e-6)
  log_probs = tf.math.log(logits)
  xentropy = - tf.reduce_sum(soft_targets * log_probs, axis=-1)
  # mask out padded positions
  weights = tf.cast(tf.math.not_equal(labels, 0), tf.float32)
  xentropy *= weights
  return tf.reduce_sum(xentropy) / tf.reduce_sum(weights)
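# A small worked example (made-up numbers) of the label smoothing above: with
# vocab_size = 4 and smoothing = 0.2, a hard label 2 becomes
#   0.8 * [0, 0, 1, 0] + 0.2 / 4 = [0.05, 0.05, 0.85, 0.05]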
while True:
  # TRAINING
  is_training = True
  for i, (source, target_in, target_out) in enumerate(dataset(is_training=is_training, params=params)):
    with tf.GradientTape() as tape:
      logits_or_ids = model((source, target_in), training=is_training)
      loss = cross_entropy_loss(logits_or_ids, target_out, len(params['tgt2idx'])+1, .2)
    variables = model.trainable_variables
    optim.lr.assign(decay_lr(global_step))
    grads = tape.gradient(loss, variables)
    grads, _ = tf.clip_by_global_norm(grads, params['clip_norm'])
    optim.apply_gradients(zip(grads, variables))
    if global_step % 50 == 0:
      logger.info("Step {} | Loss: {:.4f} | Spent: {:.1f} secs | LR: {:.6f}".format(
        global_step, loss.numpy().item(), time.time()-t0, optim.lr.numpy().item()))
      t0 = time.time()
    global_step += 1

  # EVALUATION
  is_training = False
  unit_test(model, params)
  m = tf.keras.metrics.Mean()
  parse_fn = lambda x: [e for e in x if (e != 0 and e != 2)]
  for i, (source, target_in, target_out) in enumerate(dataset(is_training=is_training, params=params)):
    generated = model(inputs=source, training=is_training)
    for pred, tgt in zip(generated.numpy(), target_out.numpy()):
      matched = np.array_equal(parse_fn(pred), parse_fn(tgt))
      m.update_state(int(matched))
  acc = m.result().numpy()
  logger.info("Evaluation: Testing EM: {:.3f}".format(acc))
  if acc > best_acc:
    best_acc = acc
    count = 0
    model.save_weights('../model/gru_pointer_clr')
  else:
    count += 1
  logger.info("Best EM: {:.3f}".format(best_acc))

  if count == params['num_patience']:
    print("Best result not improved for", params['num_patience'], "evaluations, stopping training")
    break