"""
Google Colab setup: mount Google Drive and switch into the project directory.

This cell is only needed when running on Google Colab; skip it when running
the notebook on a local machine.
"""
import os

from google.colab import drive

drive.mount('/content/gdrive')
os.chdir('/content/gdrive/My Drive/finch/tensorflow2/semantic_parsing/tree_slu/main')
Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
%tensorflow_version 2.x
!pip install tensorflow-addons
TensorFlow 2.x selected. Requirement already satisfied: tensorflow-addons in /usr/local/lib/python3.6/dist-packages (0.6.0) Requirement already satisfied: tensorflow-gpu==2.0.0 in /tensorflow-2.0.0/python3.6 (from tensorflow-addons) (2.0.0) Requirement already satisfied: six>=1.10.0 in /tensorflow-2.0.0/python3.6 (from tensorflow-addons) (1.13.0) Requirement already satisfied: gast==0.2.2 in /tensorflow-2.0.0/python3.6 (from tensorflow-gpu==2.0.0->tensorflow-addons) (0.2.2) Requirement already satisfied: grpcio>=1.8.6 in /tensorflow-2.0.0/python3.6 (from tensorflow-gpu==2.0.0->tensorflow-addons) (1.25.0) Requirement already satisfied: tensorboard<2.1.0,>=2.0.0 in /tensorflow-2.0.0/python3.6 (from tensorflow-gpu==2.0.0->tensorflow-addons) (2.0.1) Requirement already satisfied: protobuf>=3.6.1 in /tensorflow-2.0.0/python3.6 (from tensorflow-gpu==2.0.0->tensorflow-addons) (3.10.0) Requirement already satisfied: keras-preprocessing>=1.0.5 in /tensorflow-2.0.0/python3.6 (from tensorflow-gpu==2.0.0->tensorflow-addons) (1.1.0) Requirement already satisfied: numpy<2.0,>=1.16.0 in /tensorflow-2.0.0/python3.6 (from tensorflow-gpu==2.0.0->tensorflow-addons) (1.17.4) Requirement already satisfied: wrapt>=1.11.1 in /tensorflow-2.0.0/python3.6 (from tensorflow-gpu==2.0.0->tensorflow-addons) (1.11.2) Requirement already satisfied: wheel>=0.26 in /tensorflow-2.0.0/python3.6 (from tensorflow-gpu==2.0.0->tensorflow-addons) (0.33.6) Requirement already satisfied: tensorflow-estimator<2.1.0,>=2.0.0 in /tensorflow-2.0.0/python3.6 (from tensorflow-gpu==2.0.0->tensorflow-addons) (2.0.1) Requirement already satisfied: google-pasta>=0.1.6 in /tensorflow-2.0.0/python3.6 (from tensorflow-gpu==2.0.0->tensorflow-addons) (0.1.8) Requirement already satisfied: absl-py>=0.7.0 in /tensorflow-2.0.0/python3.6 (from tensorflow-gpu==2.0.0->tensorflow-addons) (0.8.1) Requirement already satisfied: astor>=0.6.0 in /tensorflow-2.0.0/python3.6 (from tensorflow-gpu==2.0.0->tensorflow-addons) (0.8.0) 
Requirement already satisfied: opt-einsum>=2.3.2 in /tensorflow-2.0.0/python3.6 (from tensorflow-gpu==2.0.0->tensorflow-addons) (3.1.0) Requirement already satisfied: keras-applications>=1.0.8 in /tensorflow-2.0.0/python3.6 (from tensorflow-gpu==2.0.0->tensorflow-addons) (1.0.8) Requirement already satisfied: termcolor>=1.1.0 in /tensorflow-2.0.0/python3.6 (from tensorflow-gpu==2.0.0->tensorflow-addons) (1.1.0) Requirement already satisfied: markdown>=2.6.8 in /tensorflow-2.0.0/python3.6 (from tensorboard<2.1.0,>=2.0.0->tensorflow-gpu==2.0.0->tensorflow-addons) (3.1.1) Requirement already satisfied: werkzeug>=0.11.15 in /tensorflow-2.0.0/python3.6 (from tensorboard<2.1.0,>=2.0.0->tensorflow-gpu==2.0.0->tensorflow-addons) (0.16.0) Requirement already satisfied: setuptools>=41.0.0 in /tensorflow-2.0.0/python3.6 (from tensorboard<2.1.0,>=2.0.0->tensorflow-gpu==2.0.0->tensorflow-addons) (41.6.0) Requirement already satisfied: google-auth<2,>=1.6.3 in /tensorflow-2.0.0/python3.6 (from tensorboard<2.1.0,>=2.0.0->tensorflow-gpu==2.0.0->tensorflow-addons) (1.7.0) Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /tensorflow-2.0.0/python3.6 (from tensorboard<2.1.0,>=2.0.0->tensorflow-gpu==2.0.0->tensorflow-addons) (0.4.1) Requirement already satisfied: h5py in /tensorflow-2.0.0/python3.6 (from keras-applications>=1.0.8->tensorflow-gpu==2.0.0->tensorflow-addons) (2.10.0) Requirement already satisfied: cachetools<3.2,>=2.0.0 in /tensorflow-2.0.0/python3.6 (from google-auth<2,>=1.6.3->tensorboard<2.1.0,>=2.0.0->tensorflow-gpu==2.0.0->tensorflow-addons) (3.1.1) Requirement already satisfied: pyasn1-modules>=0.2.1 in /tensorflow-2.0.0/python3.6 (from google-auth<2,>=1.6.3->tensorboard<2.1.0,>=2.0.0->tensorflow-gpu==2.0.0->tensorflow-addons) (0.2.7) Requirement already satisfied: rsa<4.1,>=3.1.4 in /tensorflow-2.0.0/python3.6 (from google-auth<2,>=1.6.3->tensorboard<2.1.0,>=2.0.0->tensorflow-gpu==2.0.0->tensorflow-addons) (4.0) Requirement already satisfied: 
requests-oauthlib>=0.7.0 in /tensorflow-2.0.0/python3.6 (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.1.0,>=2.0.0->tensorflow-gpu==2.0.0->tensorflow-addons) (1.3.0) Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /tensorflow-2.0.0/python3.6 (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard<2.1.0,>=2.0.0->tensorflow-gpu==2.0.0->tensorflow-addons) (0.4.7) Requirement already satisfied: oauthlib>=3.0.0 in /tensorflow-2.0.0/python3.6 (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.1.0,>=2.0.0->tensorflow-gpu==2.0.0->tensorflow-addons) (3.1.0) Requirement already satisfied: requests>=2.0.0 in /tensorflow-2.0.0/python3.6 (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.1.0,>=2.0.0->tensorflow-gpu==2.0.0->tensorflow-addons) (2.22.0) Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /tensorflow-2.0.0/python3.6 (from requests>=2.0.0->requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.1.0,>=2.0.0->tensorflow-gpu==2.0.0->tensorflow-addons) (3.0.4) Requirement already satisfied: certifi>=2017.4.17 in /tensorflow-2.0.0/python3.6 (from requests>=2.0.0->requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.1.0,>=2.0.0->tensorflow-gpu==2.0.0->tensorflow-addons) (2019.9.11) Requirement already satisfied: idna<2.9,>=2.5 in /tensorflow-2.0.0/python3.6 (from requests>=2.0.0->requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.1.0,>=2.0.0->tensorflow-gpu==2.0.0->tensorflow-addons) (2.8) Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /tensorflow-2.0.0/python3.6 (from requests>=2.0.0->requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.1.0,>=2.0.0->tensorflow-gpu==2.0.0->tensorflow-addons) (1.25.7)
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
import pprint
import logging
import time
import nltk
# Report the TensorFlow version and whether a GPU is visible.
tf_version = tf.__version__
print("TensorFlow Version", tf_version)
gpu_enabled = tf.test.is_gpu_available()
print('GPU Enabled:', gpu_enabled)
TensorFlow Version 2.0.0 GPU Enabled: True
# stream data from text files
def data_generator(f_path, params):
    """Yield ``(source, target_in, target_out)`` id lists from a TSV file.

    Each non-blank line must hold three tab-separated fields: raw text,
    whitespace-tokenized text, and the bracketed parse label. Blank lines
    (e.g. a trailing empty line) are skipped.

    Both the source tokens and the label tokens are mapped through
    ``params['tgt2idx']``; tokens missing from it get the id
    ``len(params['tgt2idx'])``.

    ``target_in`` is the label ids prefixed with 1 and ``target_out`` is the
    label ids suffixed with 2 (the start / end ids the decoder uses).

    Args:
        f_path: path of the TSV file to read (UTF-8).
        params: dict holding ``'tgt2idx'`` (token -> id).

    Yields:
        A tuple of three lists of ints.
    """
    vocab = params['tgt2idx']
    unk_id = len(vocab)  # id for tokens not present in the vocabulary
    with open(f_path, encoding='utf-8') as f:
        print('Reading', f_path)
        for line in f:
            # A blank line would otherwise fail the three-field unpack below.
            if not line.strip():
                continue
            _text_raw, text_tokenized, label = line.split('\t')
            text_tokenized = text_tokenized.lower().split()
            # Detach '[' from the intent/slot name so it becomes its own token.
            label = label.replace('[', '[ ').lower().split()
            source = [vocab.get(w, unk_id) for w in text_tokenized]
            target = [vocab.get(w, unk_id) for w in label]
            target_in = [1] + target
            target_out = target + [2]
            yield (source, target_in, target_out)
def dataset(is_training, params):
    """Build a padded, prefetched ``tf.data.Dataset`` for training or evaluation.

    Training reads ``params['train_path']``, shuffles with a
    ``params['buffer_size']`` buffer and batches by
    ``params['train_batch_size']``. Evaluation reads ``params['test_path']``
    without shuffling and batches by ``params['eval_batch_size']``.
    The three int32 sequences of each example are padded with 0 to the
    longest sequence in the batch.
    """
    if is_training:
        path_key, batch_key = 'train_path', 'train_batch_size'
    else:
        path_key, batch_key = 'test_path', 'eval_batch_size'

    seq_shapes = ([None], [None], [None])
    seq_types = (tf.int32, tf.int32, tf.int32)
    pad_values = (0, 0, 0)

    ds = tf.data.Dataset.from_generator(
        lambda: data_generator(params[path_key], params),
        output_shapes=seq_shapes,
        output_types=seq_types,
    )
    if is_training:
        ds = ds.shuffle(params['buffer_size'])
    ds = ds.padded_batch(params[batch_key], seq_shapes, pad_values)
    return ds.prefetch(tf.data.experimental.AUTOTUNE)
class Embed(tf.keras.Model):
    """Trainable embedding table initialised from a saved NumPy array.

    Args:
        embedding_path: path of the ``.npy`` file holding the initial
            embedding matrix. Defaults to the project's
            ``'../vocab/word.npy'``.
    """

    def __init__(self, embedding_path='../vocab/word.npy'):
        super().__init__()
        self.embedding = tf.Variable(np.load(embedding_path),
                                     dtype=tf.float32,
                                     name='pretrained_embedding')

    def call(self, inputs):
        """Return the embedding rows for the integer ids in ``inputs``."""
        if inputs.dtype != tf.int32:
            # Normalise the id dtype to int32 before the lookup.
            inputs = tf.cast(inputs, tf.int32)
        return tf.nn.embedding_lookup(self.embedding, inputs)
class Encoder(tf.keras.Model):
    """Bidirectional-LSTM encoder.

    ``call`` returns the per-step outputs and an initial decoder state
    ``[h, c]``. Each of ``h`` and ``c`` is the forward and backward final
    state concatenated and passed through one shared dense layer of
    ``params['rnn_units']`` units.
    """

    def __init__(self, params):
        super().__init__()
        units = params['rnn_units']
        self.dropout = tf.keras.layers.Dropout(params['dropout_rate'])
        lstm = tf.keras.layers.LSTM(units, return_state=True, return_sequences=True)
        self.encoder = tf.keras.layers.Bidirectional(lstm)
        self.state_fc = tf.keras.layers.Dense(units, params['activation'], name='state_fc')

    def call(self, inputs, training):
        dropped = self.dropout(inputs, training=training)
        outputs, fw_h, fw_c, bw_h, bw_c = self.encoder(dropped)
        # The same dense layer projects both the hidden and the cell state.
        state_h = self.state_fc(tf.concat([fw_h, bw_h], axis=-1))
        state_c = self.state_fc(tf.concat([fw_c, bw_c], axis=-1))
        return outputs, [state_h, state_c]
class TiedDense(tf.keras.layers.Layer):
    """Output projection whose kernel is a shared (tied) embedding matrix.

    Computes ``inputs @ tied_embed^T + bias``. The layer owns only the bias;
    the kernel is the variable passed in, so it is shared rather than copied.

    Args:
        tied_embed: embedding variable; its number of rows is expected to
            equal ``out_dim`` so the bias broadcasts (assumption — confirm).
        out_dim: size of the output (logit) dimension.
    """

    def __init__(self, tied_embed, out_dim):
        super().__init__()
        self.tied_embed = tied_embed
        self.out_dim = out_dim

    def build(self, input_shape):
        self.bias = self.add_weight(name='bias',
                                    shape=[self.out_dim],
                                    trainable=True)
        super().build(input_shape)

    @tf.function
    def call(self, inputs):
        x = tf.matmul(inputs, self.tied_embed, transpose_b=True)
        x += self.bias
        return x

    def compute_output_shape(self, input_shape):
        # Same leading dims as the input; last dim replaced by out_dim.
        return tf.TensorShape(input_shape)[:-1].concatenate([self.out_dim])
class Model(tf.keras.Model):
    """Seq2seq parser: BiLSTM encoder plus Bahdanau-attention LSTM decoder.

    With ``training=True`` the input is ``(source, target_in)`` and the model
    returns teacher-forced decoder logits. With ``training=False`` the input
    is ``source`` alone and the model returns beam-search ids of beam 0.

    ``params`` keys read here: ``'beam_width'``, ``'dropout_rate'``,
    ``'rnn_units'`` and ``'tgt2idx'``, plus the optional
    ``'max_decode_len'`` (default 80). ``Encoder`` reads further keys.
    """

    def __init__(self, params):
        super().__init__()
        # Kept on the instance so call() does not depend on a global `params`.
        self.beam_width = params['beam_width']
        self.embed = Embed()
        self.encoder = Encoder(params)
        self.dropout = tf.keras.layers.Dropout(params['dropout_rate'])
        self.attn = tfa.seq2seq.BahdanauAttention(params['rnn_units'])
        self.decoder_cell = tfa.seq2seq.AttentionWrapper(
            tf.keras.layers.LSTMCell(params['rnn_units']),
            self.attn,
            attention_layer_size=params['rnn_units'])
        # Logits over the vocabulary plus one extra id for out-of-vocabulary tokens.
        self.proj_layer = TiedDense(self.embed.embedding, len(params['tgt2idx']) + 1)
        self.teach_forcing = tfa.seq2seq.BasicDecoder(
            self.decoder_cell,
            tfa.seq2seq.sampler.TrainingSampler(),
            output_layer=self.proj_layer)
        self.beam_search = tfa.seq2seq.BeamSearchDecoder(
            self.decoder_cell,
            beam_width=self.beam_width,
            embedding_fn=lambda x: self.embed(x),
            output_layer=self.proj_layer,
            maximum_iterations=params.get('max_decode_len', 80))

    def call(self, inputs, training=True):
        """Run teacher forcing (training) or beam search (inference).

        Args:
            inputs: ``(source, target_in)`` when ``training`` is True, else
                ``source``; id tensors padded with 0.
            training: selects the decoding mode and enables dropout.

        Returns:
            Decoder logits when training; otherwise predicted ids of beam 0.
        """
        if training:
            source, target_in = inputs
        else:
            source = inputs
        batch_sz = tf.shape(source)[0]
        encoder_o, encoder_s = self.encoder(self.embed(source), training=training)
        # Non-zero ids are real tokens; 0 is padding.
        source_len = tf.math.count_nonzero(source, 1)
        if training:
            self.attn([encoder_o, source_len], setup_memory=True)
            attn_state = self.decoder_cell.get_initial_state(batch_size=batch_sz, dtype=tf.float32)
            attn_state = attn_state.clone(cell_state=encoder_s)
            decoder_o, _, _ = self.teach_forcing(
                inputs=self.dropout(self.embed(target_in), training=training),
                initial_state=attn_state,
                sequence_length=tf.math.count_nonzero(target_in, 1, dtype=tf.int32))
            logits_or_ids = decoder_o.rnn_output
        else:
            # Beam search needs every encoder tensor repeated beam_width times.
            encoder_o_t = tfa.seq2seq.tile_batch(encoder_o, self.beam_width)
            encoder_len_t = tfa.seq2seq.tile_batch(source_len, self.beam_width)
            encoder_s_t = tfa.seq2seq.tile_batch(encoder_s, self.beam_width)
            self.attn([encoder_o_t, encoder_len_t], setup_memory=True)
            attn_state = self.decoder_cell.get_initial_state(
                batch_size=batch_sz * self.beam_width, dtype=tf.float32)
            attn_state = attn_state.clone(cell_state=encoder_s_t)
            decoder_o, _, _ = self.beam_search(
                None,
                start_tokens=tf.tile(tf.constant([1], tf.int32), [batch_sz]),
                end_token=2,
                initial_state=attn_state)
            logits_or_ids = decoder_o.predicted_ids[:, :, 0]
        return logits_or_ids
def get_vocab(f_path):
    """Map each line of a vocabulary file to its 0-based line number.

    Trailing whitespace (including the newline) is stripped from each line.
    If the same token appears on several lines, the last line number wins.
    """
    with open(f_path) as vocab_file:
        return {token.rstrip(): idx for idx, token in enumerate(vocab_file)}
def is_descending(history: list, num_patience=None) -> bool:
    """Return True if the last ``num_patience + 1`` values of *history* strictly decrease.

    If *history* is shorter than that window, the whole history is checked.
    That check is trivially True for 0 or 1 entries.

    Args:
        history: metric values, oldest first.
        num_patience: window size minus one. When omitted, it falls back to
            the global ``params['num_patience']`` for backward compatibility.
    """
    if num_patience is None:
        num_patience = params['num_patience']
    window = history[-(num_patience + 1):]
    return all(prev > cur for prev, cur in zip(window, window[1:]))
# Hyper-parameters and file locations; vocabulary lookups are added below.
params = {
    'train_path': '../data/train.tsv',
    'test_path': '../data/test.tsv',
    'vocab_src_path': '../vocab/source.txt',
    'vocab_tgt_path': '../vocab/target.txt',
    'model_path': '../model/',
    'dropout_rate': 0.2,
    'rnn_units': 300,
    'embed_dim': 300,
    'activation': tf.nn.relu,
    'beam_width': 5,
    'lr': 4e-4,
    'clip_norm': .1,
    'buffer_size': 31279,
    'train_batch_size': 32,
    'eval_batch_size': 128,
    'num_patience': 5,
}
vocab_tgt = get_vocab(params['vocab_tgt_path'])
params['tgt2idx'] = vocab_tgt
params['idx2tgt'] = dict((idx, token) for token, idx in vocab_tgt.items())

model = Model(params)
# Build with (source, target_in) input shapes, both (batch, time).
model.build(input_shape=[[None, None], [None, None]])
trainable = [(var.name, var.shape) for var in model.trainable_variables]
pprint.pprint(trainable)
[('pretrained_embedding:0', TensorShape([8692, 300])), ('encoder/bidirectional/forward_lstm/kernel:0', TensorShape([300, 1200])), ('encoder/bidirectional/forward_lstm/recurrent_kernel:0', TensorShape([300, 1200])), ('encoder/bidirectional/forward_lstm/bias:0', TensorShape([1200])), ('encoder/bidirectional/backward_lstm/kernel:0', TensorShape([300, 1200])), ('encoder/bidirectional/backward_lstm/recurrent_kernel:0', TensorShape([300, 1200])), ('encoder/bidirectional/backward_lstm/bias:0', TensorShape([1200])), ('encoder/state_fc/kernel:0', TensorShape([600, 300])), ('encoder/state_fc/bias:0', TensorShape([300])), ('BahdanauAttention/attention_v:0', TensorShape([300])), ('attention_wrapper/BahdanauAttention/kernel:0', TensorShape([300, 300])), ('BahdanauAttention/kernel:0', TensorShape([600, 300])), ('attention_wrapper/attention_layer/kernel:0', TensorShape([900, 300])), ('attention_wrapper/lstm_cell_3/kernel:0', TensorShape([600, 1200])), ('attention_wrapper/lstm_cell_3/recurrent_kernel:0', TensorShape([300, 1200])), ('attention_wrapper/lstm_cell_3/bias:0', TensorShape([1200])), ('tied_dense/bias:0', TensorShape([8692]))]
# Adam optimiser plus an exponential schedule (x0.99 every 1000 steps); the
# training loop writes decay_lr(global_step) into the optimiser's lr each step.
decay_lr = tf.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=params['lr'], decay_steps=1000, decay_rate=0.99)
optim = tf.optimizers.Adam(learning_rate=params['lr'])

# Training-loop bookkeeping.
global_step = 0
history_acc = []
best_acc = .0
t0 = time.time()

# INFO-level logging on the 'tensorflow' logger, not propagated to the root logger.
logger = logging.getLogger('tensorflow')
logger.setLevel(logging.INFO)
logger.propagate = False
def minimal_test(model, params):
    """Parse one fixed utterance with beam search and print the result.

    The predicted ids are cut at the first ``<end>`` token and joined into a
    bracketed string. When that string is well formed, it is also drawn as a
    tree with NLTK. Drawing is best-effort: any failure is reported and the
    function returns normally.
    """
    test_str = ['what', 'times', 'are', 'the', 'nutcracker', 'show', 'playing', 'near', 'me']
    vocab = params['tgt2idx']
    unk_id = len(vocab)  # same OOV id that data_generator uses
    test_arr = tf.convert_to_tensor([[vocab.get(w, unk_id) for w in test_str]])
    generated = model(inputs=test_arr, training=False)
    print('-'*12)
    print('minimal test')
    print('utterance:', ' '.join(test_str))
    tokens = []
    for idx in generated[0].numpy():
        # The OOV id has no entry in idx2tgt; show it as a placeholder.
        token = params['idx2tgt'].get(idx, '<unk>')
        if token == '<end>':
            break
        tokens.append(token)
    parsed = ' '.join(tokens)
    print('parsed:', parsed)
    print()
    try:
        nltk.tree.Tree.fromstring(parsed.replace('[ ', '(').replace(' ]', ')')).pretty_print()
    except Exception as err:  # display only: never abort training, but say why
        print('could not draw tree:', err)
    print('-'*12)
def is_descending(history: list, num_patience=None) -> bool:
    """Return True if the last ``num_patience + 1`` values of *history* strictly decrease.

    If *history* is shorter than that window, the whole history is checked.
    That check is trivially True for 0 or 1 entries.

    Args:
        history: metric values, oldest first.
        num_patience: window size minus one. When omitted, it falls back to
            the global ``params['num_patience']`` for backward compatibility.
    """
    if num_patience is None:
        num_patience = params['num_patience']
    window = history[-(num_patience + 1):]
    return all(prev > cur for prev, cur in zip(window, window[1:]))
while True:
    # TRAINING: one pass over the training set with teacher forcing.
    is_training = True
    for source, target_in, target_out in dataset(is_training=is_training, params=params):
        with tf.GradientTape() as tape:
            logits_or_ids = model((source, target_in), training=is_training)
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=target_out, logits=logits_or_ids)
            # Exclude padded positions (id 0) from the mean loss.
            weights = tf.cast(tf.sign(target_in), tf.float32)
            loss = tf.reduce_sum(loss * weights) / tf.reduce_sum(weights)
        variables = model.trainable_variables
        optim.lr.assign(decay_lr(global_step))
        grads = tape.gradient(loss, variables)
        grads, _ = tf.clip_by_global_norm(grads, params['clip_norm'])
        optim.apply_gradients(zip(grads, variables))
        if global_step % 50 == 0:
            logger.info("Step {} | Loss: {:.4f} | Spent: {:.1f} secs | LR: {:.6f}".format(
                global_step, loss.numpy().item(), time.time()-t0, optim.lr.numpy().item()))
            t0 = time.time()
        global_step += 1
    # EVALUATION: exact match of the beam-search output up to and including the end id.
    is_training = False
    minimal_test(model, params)
    m = tf.keras.metrics.Mean()
    for source, target_in, target_out in dataset(is_training=is_training, params=params):
        generated = model(inputs=source, training=is_training)
        # Index of the first end id (2) in each reference sequence.
        seq_lens = tf.argmax(tf.cast(tf.equal(target_out, 2), tf.int32), axis=1)
        for pred, tgt, seq_len in zip(generated.numpy(), target_out.numpy(), seq_lens.numpy()):
            n = seq_len + 1
            # array_equal returns False when the prediction is shorter than the
            # reference, instead of the failed elementwise `==` that NumPy
            # warns about (and will eventually reject).
            matched = np.array_equal(pred[:n], tgt[:n])
            m.update_state(int(matched))
    acc = m.result().numpy()
    logger.info("Evaluation: Testing Exact Match Accuracy: {:.3f}".format(acc))
    history_acc.append(acc)
    if acc > best_acc:
        best_acc = acc
    logger.info("Best Accuracy: {:.3f}".format(best_acc))
    if len(history_acc) > params['num_patience'] and is_descending(history_acc):
        logger.info("Testing Accuracy not improved over {} epochs, Early Stop".format(params['num_patience']))
        break
Reading ../data/train.tsv INFO:tensorflow:Step 0 | Loss: 9.2826 | Spent: 6.0 secs | LR: 0.000400 INFO:tensorflow:Step 50 | Loss: 3.4936 | Spent: 31.0 secs | LR: 0.000400 INFO:tensorflow:Step 100 | Loss: 2.5489 | Spent: 31.4 secs | LR: 0.000400 INFO:tensorflow:Step 150 | Loss: 1.8968 | Spent: 30.4 secs | LR: 0.000399 INFO:tensorflow:Step 200 | Loss: 1.5225 | Spent: 31.3 secs | LR: 0.000399 INFO:tensorflow:Step 250 | Loss: 1.5356 | Spent: 31.6 secs | LR: 0.000399 INFO:tensorflow:Step 300 | Loss: 1.1140 | Spent: 34.1 secs | LR: 0.000399 INFO:tensorflow:Step 350 | Loss: 0.9405 | Spent: 32.0 secs | LR: 0.000399 INFO:tensorflow:Step 400 | Loss: 0.6781 | Spent: 31.5 secs | LR: 0.000398 INFO:tensorflow:Step 450 | Loss: 0.6539 | Spent: 31.1 secs | LR: 0.000398 INFO:tensorflow:Step 500 | Loss: 0.7228 | Spent: 31.4 secs | LR: 0.000398 INFO:tensorflow:Step 550 | Loss: 0.5562 | Spent: 32.1 secs | LR: 0.000398 INFO:tensorflow:Step 600 | Loss: 0.5013 | Spent: 31.4 secs | LR: 0.000398 INFO:tensorflow:Step 650 | Loss: 0.3475 | Spent: 30.2 secs | LR: 0.000397 INFO:tensorflow:Step 700 | Loss: 0.4110 | Spent: 31.2 secs | LR: 0.000397 INFO:tensorflow:Step 750 | Loss: 0.4447 | Spent: 33.5 secs | LR: 0.000397 INFO:tensorflow:Step 800 | Loss: 0.6801 | Spent: 32.9 secs | LR: 0.000397 INFO:tensorflow:Step 850 | Loss: 1.3601 | Spent: 30.7 secs | LR: 0.000397 INFO:tensorflow:Step 900 | Loss: 0.3804 | Spent: 30.7 secs | LR: 0.000396 INFO:tensorflow:Step 950 | Loss: 0.1635 | Spent: 30.6 secs | LR: 0.000396 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are the nutcracker show playing near me ] in:get_event ______________________|__________________________ what times are the nutcracker show playing near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.383 INFO:tensorflow:Best Accuracy: 0.383 Reading ../data/train.tsv INFO:tensorflow:Step 1000 | Loss: 0.3383 | Spent: 167.3 secs | LR: 0.000396 INFO:tensorflow:Step 1050 | Loss: 0.3575 | Spent: 31.2 secs | LR: 0.000396 INFO:tensorflow:Step 1100 | Loss: 0.3602 | Spent: 33.0 secs | LR: 0.000396 INFO:tensorflow:Step 1150 | Loss: 0.2549 | Spent: 32.3 secs | LR: 0.000395 INFO:tensorflow:Step 1200 | Loss: 0.2287 | Spent: 31.8 secs | LR: 0.000395 INFO:tensorflow:Step 1250 | Loss: 0.3000 | Spent: 32.4 secs | LR: 0.000395 INFO:tensorflow:Step 1300 | Loss: 0.2924 | Spent: 31.4 secs | LR: 0.000395 INFO:tensorflow:Step 1350 | Loss: 0.1161 | Spent: 31.1 secs | LR: 0.000395 INFO:tensorflow:Step 1400 | Loss: 0.1842 | Spent: 32.1 secs | LR: 0.000394 INFO:tensorflow:Step 1450 | Loss: 0.2146 | Spent: 31.8 secs | LR: 0.000394 INFO:tensorflow:Step 1500 | Loss: 0.2259 | Spent: 30.4 secs | LR: 0.000394 INFO:tensorflow:Step 1550 | Loss: 0.1540 | Spent: 33.3 secs | LR: 0.000394 INFO:tensorflow:Step 1600 | Loss: 0.1853 | Spent: 31.5 secs | LR: 0.000394 INFO:tensorflow:Step 1650 | Loss: 0.1268 | Spent: 31.9 secs | LR: 0.000393 INFO:tensorflow:Step 1700 | Loss: 0.1582 | Spent: 31.1 secs | LR: 0.000393 INFO:tensorflow:Step 1750 | Loss: 0.1391 | Spent: 30.1 secs | LR: 0.000393 INFO:tensorflow:Step 1800 | Loss: 0.0759 | Spent: 31.9 secs | LR: 0.000393 INFO:tensorflow:Step 1850 | Loss: 0.0602 | Spent: 31.0 secs | LR: 0.000393 INFO:tensorflow:Step 1900 | Loss: 0.2162 | Spent: 32.2 secs | LR: 0.000392 INFO:tensorflow:Step 1950 | Loss: 0.0843 | Spent: 31.0 secs | LR: 0.000392 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event ________________________|______________________________________________ | | | | sl:location | 
| | | | | | | | in:get_location | | | | ________________|_______________ | | | sl:category_even sl:search_radius sl:location_user | | | t | | | | | _________|_________________ | | what times are the nutcracker show playing near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.597 INFO:tensorflow:Best Accuracy: 0.597 Reading ../data/train.tsv INFO:tensorflow:Step 2000 | Loss: 0.2686 | Spent: 186.5 secs | LR: 0.000392 INFO:tensorflow:Step 2050 | Loss: 0.1794 | Spent: 31.5 secs | LR: 0.000392 INFO:tensorflow:Step 2100 | Loss: 0.0995 | Spent: 30.3 secs | LR: 0.000392 INFO:tensorflow:Step 2150 | Loss: 0.2179 | Spent: 31.6 secs | LR: 0.000391 INFO:tensorflow:Step 2200 | Loss: 0.1103 | Spent: 30.6 secs | LR: 0.000391 INFO:tensorflow:Step 2250 | Loss: 0.0681 | Spent: 31.2 secs | LR: 0.000391 INFO:tensorflow:Step 2300 | Loss: 0.1054 | Spent: 30.5 secs | LR: 0.000391 INFO:tensorflow:Step 2350 | Loss: 0.1282 | Spent: 30.8 secs | LR: 0.000391 INFO:tensorflow:Step 2400 | Loss: 0.0614 | Spent: 32.4 secs | LR: 0.000390 INFO:tensorflow:Step 2450 | Loss: 0.0933 | Spent: 32.1 secs | LR: 0.000390 INFO:tensorflow:Step 2500 | Loss: 0.0860 | Spent: 32.1 secs | LR: 0.000390 INFO:tensorflow:Step 2550 | Loss: 0.1360 | Spent: 31.7 secs | LR: 0.000390 INFO:tensorflow:Step 2600 | Loss: 0.0816 | Spent: 31.0 secs | LR: 0.000390 INFO:tensorflow:Step 2650 | Loss: 0.1040 | Spent: 31.2 secs | LR: 0.000389 INFO:tensorflow:Step 2700 | Loss: 0.1258 | Spent: 30.6 secs | LR: 0.000389 INFO:tensorflow:Step 2750 | Loss: 0.0744 | Spent: 30.1 secs | LR: 0.000389 INFO:tensorflow:Step 2800 | Loss: 0.1592 | Spent: 32.1 secs | LR: 0.000389 INFO:tensorflow:Step 2850 | Loss: 0.1176 | Spent: 32.1 secs | LR: 0.000389 INFO:tensorflow:Step 2900 | Loss: 0.1056 | Spent: 31.9 secs | LR: 0.000389 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event ________________________|______________________________________________ | | | | sl:location | | | | | | | | | in:get_location | | | | ________________|_______________ | 
| | sl:category_even sl:search_radius sl:location_user | | | t | | | | | _________|_________________ | | what times are the nutcracker show playing near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.651 INFO:tensorflow:Best Accuracy: 0.651 Reading ../data/train.tsv INFO:tensorflow:Step 2950 | Loss: 0.0584 | Spent: 185.4 secs | LR: 0.000388 INFO:tensorflow:Step 3000 | Loss: 0.0577 | Spent: 32.3 secs | LR: 0.000388 INFO:tensorflow:Step 3050 | Loss: 0.0769 | Spent: 32.0 secs | LR: 0.000388 INFO:tensorflow:Step 3100 | Loss: 0.0682 | Spent: 32.6 secs | LR: 0.000388 INFO:tensorflow:Step 3150 | Loss: 0.0645 | Spent: 33.1 secs | LR: 0.000388 INFO:tensorflow:Step 3200 | Loss: 0.0889 | Spent: 30.9 secs | LR: 0.000387 INFO:tensorflow:Step 3250 | Loss: 0.0837 | Spent: 31.7 secs | LR: 0.000387 INFO:tensorflow:Step 3300 | Loss: 0.0779 | Spent: 30.9 secs | LR: 0.000387 INFO:tensorflow:Step 3350 | Loss: 0.0581 | Spent: 30.4 secs | LR: 0.000387 INFO:tensorflow:Step 3400 | Loss: 0.0877 | Spent: 31.9 secs | LR: 0.000387 INFO:tensorflow:Step 3450 | Loss: 0.0830 | Spent: 32.4 secs | LR: 0.000386 INFO:tensorflow:Step 3500 | Loss: 0.0820 | Spent: 30.9 secs | LR: 0.000386 INFO:tensorflow:Step 3550 | Loss: 0.0557 | Spent: 31.5 secs | LR: 0.000386 INFO:tensorflow:Step 3600 | Loss: 0.0908 | Spent: 32.1 secs | LR: 0.000386 INFO:tensorflow:Step 3650 | Loss: 0.0867 | Spent: 31.6 secs | LR: 0.000386 INFO:tensorflow:Step 3700 | Loss: 0.0269 | Spent: 32.9 secs | LR: 0.000385 INFO:tensorflow:Step 3750 | Loss: 0.0349 | Spent: 31.9 secs | LR: 0.000385 INFO:tensorflow:Step 3800 | Loss: 0.0666 | Spent: 31.4 secs | LR: 0.000385 INFO:tensorflow:Step 3850 | Loss: 0.0340 | Spent: 30.7 secs | LR: 0.000385 INFO:tensorflow:Step 3900 | Loss: 0.0743 | Spent: 31.5 secs | LR: 0.000385 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event ________________________|______________________________________________ | | | | sl:location | 
| | | | | | | | in:get_location | | | | ________________|_______________ | | | sl:category_even sl:search_radius sl:location_user | | | t | | | | | _________|_________________ | | what times are the nutcracker show playing near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.669 INFO:tensorflow:Best Accuracy: 0.669 Reading ../data/train.tsv INFO:tensorflow:Step 3950 | Loss: 0.0279 | Spent: 193.0 secs | LR: 0.000384 INFO:tensorflow:Step 4000 | Loss: 0.0257 | Spent: 31.7 secs | LR: 0.000384 INFO:tensorflow:Step 4050 | Loss: 0.0623 | Spent: 29.8 secs | LR: 0.000384 INFO:tensorflow:Step 4100 | Loss: 0.0587 | Spent: 31.0 secs | LR: 0.000384 INFO:tensorflow:Step 4150 | Loss: 0.0527 | Spent: 31.1 secs | LR: 0.000384 INFO:tensorflow:Step 4200 | Loss: 0.0429 | Spent: 31.4 secs | LR: 0.000383 INFO:tensorflow:Step 4250 | Loss: 0.0460 | Spent: 32.6 secs | LR: 0.000383 INFO:tensorflow:Step 4300 | Loss: 0.0385 | Spent: 30.1 secs | LR: 0.000383 INFO:tensorflow:Step 4350 | Loss: 0.0462 | Spent: 31.6 secs | LR: 0.000383 INFO:tensorflow:Step 4400 | Loss: 0.0607 | Spent: 31.6 secs | LR: 0.000383 INFO:tensorflow:Step 4450 | Loss: 0.0660 | Spent: 32.9 secs | LR: 0.000383 INFO:tensorflow:Step 4500 | Loss: 0.0995 | Spent: 31.2 secs | LR: 0.000382 INFO:tensorflow:Step 4550 | Loss: 0.0563 | Spent: 30.7 secs | LR: 0.000382 INFO:tensorflow:Step 4600 | Loss: 0.0716 | Spent: 30.6 secs | LR: 0.000382 INFO:tensorflow:Step 4650 | Loss: 0.0909 | Spent: 30.8 secs | LR: 0.000382 INFO:tensorflow:Step 4700 | Loss: 0.0768 | Spent: 30.1 secs | LR: 0.000382 INFO:tensorflow:Step 4750 | Loss: 0.0677 | Spent: 29.6 secs | LR: 0.000381 INFO:tensorflow:Step 4800 | Loss: 0.1531 | Spent: 29.8 secs | LR: 0.000381 INFO:tensorflow:Step 4850 | Loss: 0.2851 | Spent: 31.2 secs | LR: 0.000381 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show ] playing [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event __________________________|_____________________________________________________ | | | | | sl:location | | | | | | | | | | | in:get_location | | | | | 
________________|_______________ | | | | sl:category_even sl:search_radius sl:location_user | | | | t | | | | | | ______________|__________ | | what times are playing the nutcracker show near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.694 INFO:tensorflow:Best Accuracy: 0.694 Reading ../data/train.tsv INFO:tensorflow:Step 4900 | Loss: 0.0242 | Spent: 184.8 secs | LR: 0.000381 INFO:tensorflow:Step 4950 | Loss: 0.0757 | Spent: 30.1 secs | LR: 0.000381 INFO:tensorflow:Step 5000 | Loss: 0.0745 | Spent: 32.5 secs | LR: 0.000380 INFO:tensorflow:Step 5050 | Loss: 0.0513 | Spent: 31.0 secs | LR: 0.000380 INFO:tensorflow:Step 5100 | Loss: 0.0487 | Spent: 30.4 secs | LR: 0.000380 INFO:tensorflow:Step 5150 | Loss: 0.0282 | Spent: 30.4 secs | LR: 0.000380 INFO:tensorflow:Step 5200 | Loss: 0.0369 | Spent: 31.3 secs | LR: 0.000380 INFO:tensorflow:Step 5250 | Loss: 0.2437 | Spent: 30.2 secs | LR: 0.000379 INFO:tensorflow:Step 5300 | Loss: 0.0572 | Spent: 31.2 secs | LR: 0.000379 INFO:tensorflow:Step 5350 | Loss: 0.0353 | Spent: 30.4 secs | LR: 0.000379 INFO:tensorflow:Step 5400 | Loss: 0.0466 | Spent: 30.6 secs | LR: 0.000379 INFO:tensorflow:Step 5450 | Loss: 0.0322 | Spent: 30.7 secs | LR: 0.000379 INFO:tensorflow:Step 5500 | Loss: 0.0708 | Spent: 29.4 secs | LR: 0.000378 INFO:tensorflow:Step 5550 | Loss: 0.0366 | Spent: 30.6 secs | LR: 0.000378 INFO:tensorflow:Step 5600 | Loss: 0.0157 | Spent: 29.2 secs | LR: 0.000378 INFO:tensorflow:Step 5650 | Loss: 0.1201 | Spent: 30.7 secs | LR: 0.000378 INFO:tensorflow:Step 5700 | Loss: 0.0399 | Spent: 30.3 secs | LR: 0.000378 INFO:tensorflow:Step 5750 | Loss: 0.0717 | Spent: 30.8 secs | LR: 0.000378 INFO:tensorflow:Step 5800 | Loss: 0.0491 | Spent: 29.5 secs | LR: 0.000377 INFO:tensorflow:Step 5850 | Loss: 0.0098 | Spent: 30.0 secs | LR: 0.000377 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event ________________________|______________________________________________ | | | | sl:location | 
| | | | | | | | in:get_location | | | | ________________|_______________ | | | sl:category_even sl:search_radius sl:location_user | | | t | | | | | _________|_________________ | | what times are the nutcracker show playing near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.696 INFO:tensorflow:Best Accuracy: 0.696 Reading ../data/train.tsv INFO:tensorflow:Step 5900 | Loss: 0.0494 | Spent: 181.6 secs | LR: 0.000377 INFO:tensorflow:Step 5950 | Loss: 0.0392 | Spent: 31.5 secs | LR: 0.000377 INFO:tensorflow:Step 6000 | Loss: 0.0581 | Spent: 32.5 secs | LR: 0.000377 INFO:tensorflow:Step 6050 | Loss: 0.0291 | Spent: 31.1 secs | LR: 0.000376 INFO:tensorflow:Step 6100 | Loss: 0.0311 | Spent: 30.8 secs | LR: 0.000376 INFO:tensorflow:Step 6150 | Loss: 0.0325 | Spent: 31.3 secs | LR: 0.000376 INFO:tensorflow:Step 6200 | Loss: 0.0229 | Spent: 30.3 secs | LR: 0.000376 INFO:tensorflow:Step 6250 | Loss: 0.0252 | Spent: 29.7 secs | LR: 0.000376 INFO:tensorflow:Step 6300 | Loss: 0.0349 | Spent: 30.8 secs | LR: 0.000375 INFO:tensorflow:Step 6350 | Loss: 0.0262 | Spent: 31.1 secs | LR: 0.000375 INFO:tensorflow:Step 6400 | Loss: 0.0604 | Spent: 29.7 secs | LR: 0.000375 INFO:tensorflow:Step 6450 | Loss: 0.0282 | Spent: 30.4 secs | LR: 0.000375 INFO:tensorflow:Step 6500 | Loss: 0.1250 | Spent: 31.0 secs | LR: 0.000375 INFO:tensorflow:Step 6550 | Loss: 0.0328 | Spent: 30.3 secs | LR: 0.000375 INFO:tensorflow:Step 6600 | Loss: 0.0304 | Spent: 29.6 secs | LR: 0.000374 INFO:tensorflow:Step 6650 | Loss: 0.0247 | Spent: 30.3 secs | LR: 0.000374 INFO:tensorflow:Step 6700 | Loss: 0.0445 | Spent: 31.4 secs | LR: 0.000374 INFO:tensorflow:Step 6750 | Loss: 0.0437 | Spent: 31.1 secs | LR: 0.000374 INFO:tensorflow:Step 6800 | Loss: 0.0240 | Spent: 29.9 secs | LR: 0.000374 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event ________________________|______________________________________________ | | | | sl:location | | | | | | | | | in:get_location | | | | ________________|_______________ | 
| | sl:category_even sl:search_radius sl:location_user | | | t | | | | | _________|_________________ | | what times are the nutcracker show playing near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.694 INFO:tensorflow:Best Accuracy: 0.696 Reading ../data/train.tsv INFO:tensorflow:Step 6850 | Loss: 0.0291 | Spent: 180.0 secs | LR: 0.000373 INFO:tensorflow:Step 6900 | Loss: 0.0290 | Spent: 31.3 secs | LR: 0.000373 INFO:tensorflow:Step 6950 | Loss: 0.0082 | Spent: 30.7 secs | LR: 0.000373 INFO:tensorflow:Step 7000 | Loss: 0.0399 | Spent: 29.8 secs | LR: 0.000373 INFO:tensorflow:Step 7050 | Loss: 0.0423 | Spent: 32.2 secs | LR: 0.000373 INFO:tensorflow:Step 7100 | Loss: 0.0260 | Spent: 32.1 secs | LR: 0.000372 INFO:tensorflow:Step 7150 | Loss: 0.0113 | Spent: 29.4 secs | LR: 0.000372 INFO:tensorflow:Step 7200 | Loss: 0.0420 | Spent: 30.9 secs | LR: 0.000372 INFO:tensorflow:Step 7250 | Loss: 0.0782 | Spent: 30.4 secs | LR: 0.000372 INFO:tensorflow:Step 7300 | Loss: 0.0452 | Spent: 29.8 secs | LR: 0.000372 INFO:tensorflow:Step 7350 | Loss: 0.0217 | Spent: 29.4 secs | LR: 0.000372 INFO:tensorflow:Step 7400 | Loss: 0.0375 | Spent: 29.3 secs | LR: 0.000371 INFO:tensorflow:Step 7450 | Loss: 0.0291 | Spent: 30.0 secs | LR: 0.000371 INFO:tensorflow:Step 7500 | Loss: 0.0309 | Spent: 30.2 secs | LR: 0.000371 INFO:tensorflow:Step 7550 | Loss: 0.0585 | Spent: 29.5 secs | LR: 0.000371 INFO:tensorflow:Step 7600 | Loss: 0.0419 | Spent: 30.7 secs | LR: 0.000371 INFO:tensorflow:Step 7650 | Loss: 0.0409 | Spent: 30.6 secs | LR: 0.000370 INFO:tensorflow:Step 7700 | Loss: 0.0502 | Spent: 31.1 secs | LR: 0.000370 INFO:tensorflow:Step 7750 | Loss: 0.0372 | Spent: 29.3 secs | LR: 0.000370 INFO:tensorflow:Step 7800 | Loss: 0.0281 | Spent: 31.2 secs | LR: 0.000370 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event ________________________|______________________________________________ | | | | sl:location | 
| | | | | | | | in:get_location | | | | ________________|_______________ | | | sl:category_even sl:search_radius sl:location_user | | | t | | | | | _________|_________________ | | what times are the nutcracker show playing near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.707 INFO:tensorflow:Best Accuracy: 0.707 Reading ../data/train.tsv INFO:tensorflow:Step 7850 | Loss: 0.0284 | Spent: 185.7 secs | LR: 0.000370 INFO:tensorflow:Step 7900 | Loss: 0.2402 | Spent: 31.2 secs | LR: 0.000369 INFO:tensorflow:Step 7950 | Loss: 0.0243 | Spent: 30.9 secs | LR: 0.000369 INFO:tensorflow:Step 8000 | Loss: 0.0085 | Spent: 30.1 secs | LR: 0.000369 INFO:tensorflow:Step 8050 | Loss: 0.0190 | Spent: 30.2 secs | LR: 0.000369 INFO:tensorflow:Step 8100 | Loss: 0.0125 | Spent: 30.2 secs | LR: 0.000369 INFO:tensorflow:Step 8150 | Loss: 0.0215 | Spent: 30.6 secs | LR: 0.000369 INFO:tensorflow:Step 8200 | Loss: 0.0342 | Spent: 30.1 secs | LR: 0.000368 INFO:tensorflow:Step 8250 | Loss: 0.0319 | Spent: 29.8 secs | LR: 0.000368 INFO:tensorflow:Step 8300 | Loss: 0.0132 | Spent: 29.3 secs | LR: 0.000368 INFO:tensorflow:Step 8350 | Loss: 0.0261 | Spent: 31.4 secs | LR: 0.000368 INFO:tensorflow:Step 8400 | Loss: 0.0229 | Spent: 30.1 secs | LR: 0.000368 INFO:tensorflow:Step 8450 | Loss: 0.0257 | Spent: 30.9 secs | LR: 0.000367 INFO:tensorflow:Step 8500 | Loss: 0.0112 | Spent: 31.9 secs | LR: 0.000367 INFO:tensorflow:Step 8550 | Loss: 0.0272 | Spent: 30.6 secs | LR: 0.000367 INFO:tensorflow:Step 8600 | Loss: 0.0285 | Spent: 30.0 secs | LR: 0.000367 INFO:tensorflow:Step 8650 | Loss: 0.0459 | Spent: 29.7 secs | LR: 0.000367 INFO:tensorflow:Step 8700 | Loss: 0.0195 | Spent: 30.6 secs | LR: 0.000367 INFO:tensorflow:Step 8750 | Loss: 0.0387 | Spent: 31.2 secs | LR: 0.000366 INFO:tensorflow:Step 8800 | Loss: 0.0096 | Spent: 30.6 secs | LR: 0.000366 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show ] playing [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event __________________________|_____________________________________________________ | | | | | 
sl:location | | | | | | | | | | | in:get_location | | | | | ________________|_______________ | | | | sl:category_even sl:search_radius sl:location_user | | | | t | | | | | | ______________|__________ | | what times are playing the nutcracker show near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.715 INFO:tensorflow:Best Accuracy: 0.715 Reading ../data/train.tsv INFO:tensorflow:Step 8850 | Loss: 0.0392 | Spent: 180.2 secs | LR: 0.000366 INFO:tensorflow:Step 8900 | Loss: 0.0419 | Spent: 30.9 secs | LR: 0.000366 INFO:tensorflow:Step 8950 | Loss: 0.0287 | Spent: 30.0 secs | LR: 0.000366 INFO:tensorflow:Step 9000 | Loss: 0.0359 | Spent: 30.4 secs | LR: 0.000365 INFO:tensorflow:Step 9050 | Loss: 0.0210 | Spent: 30.6 secs | LR: 0.000365 INFO:tensorflow:Step 9100 | Loss: 0.0197 | Spent: 32.8 secs | LR: 0.000365 INFO:tensorflow:Step 9150 | Loss: 0.0146 | Spent: 30.5 secs | LR: 0.000365 INFO:tensorflow:Step 9200 | Loss: 0.0270 | Spent: 30.5 secs | LR: 0.000365 INFO:tensorflow:Step 9250 | Loss: 0.0840 | Spent: 30.0 secs | LR: 0.000364 INFO:tensorflow:Step 9300 | Loss: 0.0144 | Spent: 29.8 secs | LR: 0.000364 INFO:tensorflow:Step 9350 | Loss: 0.0508 | Spent: 28.8 secs | LR: 0.000364 INFO:tensorflow:Step 9400 | Loss: 0.0365 | Spent: 31.1 secs | LR: 0.000364 INFO:tensorflow:Step 9450 | Loss: 0.0178 | Spent: 31.1 secs | LR: 0.000364 INFO:tensorflow:Step 9500 | Loss: 0.0140 | Spent: 31.3 secs | LR: 0.000364 INFO:tensorflow:Step 9550 | Loss: 0.0136 | Spent: 29.7 secs | LR: 0.000363 INFO:tensorflow:Step 9600 | Loss: 0.0190 | Spent: 31.6 secs | LR: 0.000363 INFO:tensorflow:Step 9650 | Loss: 0.0471 | Spent: 31.1 secs | LR: 0.000363 INFO:tensorflow:Step 9700 | Loss: 0.0232 | Spent: 30.3 secs | LR: 0.000363 INFO:tensorflow:Step 9750 | Loss: 0.0291 | Spent: 30.8 secs | LR: 0.000363 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event ________________________|______________________________________________ | | | | sl:location | | | | | | | | | in:get_location | | | | ________________|_______________ | 
| | sl:category_even sl:search_radius sl:location_user | | | t | | | | | _________|_________________ | | what times are the nutcracker show playing near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.722 INFO:tensorflow:Best Accuracy: 0.722 Reading ../data/train.tsv INFO:tensorflow:Step 9800 | Loss: 0.0443 | Spent: 179.5 secs | LR: 0.000362 INFO:tensorflow:Step 9850 | Loss: 0.2110 | Spent: 31.5 secs | LR: 0.000362 INFO:tensorflow:Step 9900 | Loss: 0.0121 | Spent: 31.3 secs | LR: 0.000362 INFO:tensorflow:Step 9950 | Loss: 0.0367 | Spent: 30.9 secs | LR: 0.000362 INFO:tensorflow:Step 10000 | Loss: 0.0099 | Spent: 30.6 secs | LR: 0.000362 INFO:tensorflow:Step 10050 | Loss: 0.0136 | Spent: 29.3 secs | LR: 0.000362 INFO:tensorflow:Step 10100 | Loss: 0.0188 | Spent: 29.4 secs | LR: 0.000361 INFO:tensorflow:Step 10150 | Loss: 0.0197 | Spent: 31.0 secs | LR: 0.000361 INFO:tensorflow:Step 10200 | Loss: 0.0188 | Spent: 30.3 secs | LR: 0.000361 INFO:tensorflow:Step 10250 | Loss: 0.0223 | Spent: 30.6 secs | LR: 0.000361 INFO:tensorflow:Step 10300 | Loss: 0.0133 | Spent: 30.4 secs | LR: 0.000361 INFO:tensorflow:Step 10350 | Loss: 0.2084 | Spent: 30.3 secs | LR: 0.000360 INFO:tensorflow:Step 10400 | Loss: 0.0243 | Spent: 32.2 secs | LR: 0.000360 INFO:tensorflow:Step 10450 | Loss: 0.0163 | Spent: 30.4 secs | LR: 0.000360 INFO:tensorflow:Step 10500 | Loss: 0.0265 | Spent: 31.0 secs | LR: 0.000360 INFO:tensorflow:Step 10550 | Loss: 0.0506 | Spent: 30.1 secs | LR: 0.000360 INFO:tensorflow:Step 10600 | Loss: 0.0205 | Spent: 29.9 secs | LR: 0.000360 INFO:tensorflow:Step 10650 | Loss: 0.0157 | Spent: 31.4 secs | LR: 0.000359 INFO:tensorflow:Step 10700 | Loss: 0.0201 | Spent: 29.3 secs | LR: 0.000359 INFO:tensorflow:Step 10750 | Loss: 0.0189 | Spent: 30.2 secs | LR: 0.000359 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show ] playing [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event 
__________________________|_____________________________________________________ | | | | | sl:location | | | | | | | | | | | in:get_location | | | | | ________________|_______________ | | | | sl:category_even sl:search_radius sl:location_user | | | | t | | | | | | ______________|__________ | | what times are playing the nutcracker show near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.719 INFO:tensorflow:Best Accuracy: 0.722 Reading ../data/train.tsv INFO:tensorflow:Step 10800 | Loss: 0.0141 | Spent: 180.5 secs | LR: 0.000359 INFO:tensorflow:Step 10850 | Loss: 0.0065 | Spent: 30.8 secs | LR: 0.000359 INFO:tensorflow:Step 10900 | Loss: 0.0356 | Spent: 30.9 secs | LR: 0.000358 INFO:tensorflow:Step 10950 | Loss: 0.0200 | Spent: 28.9 secs | LR: 0.000358 INFO:tensorflow:Step 11000 | Loss: 0.0101 | Spent: 30.9 secs | LR: 0.000358 INFO:tensorflow:Step 11050 | Loss: 0.0244 | Spent: 30.2 secs | LR: 0.000358 INFO:tensorflow:Step 11100 | Loss: 0.0254 | Spent: 30.9 secs | LR: 0.000358 INFO:tensorflow:Step 11150 | Loss: 0.0087 | Spent: 31.5 secs | LR: 0.000358 INFO:tensorflow:Step 11200 | Loss: 0.0269 | Spent: 31.4 secs | LR: 0.000357 INFO:tensorflow:Step 11250 | Loss: 0.0260 | Spent: 31.6 secs | LR: 0.000357 INFO:tensorflow:Step 11300 | Loss: 0.0085 | Spent: 30.8 secs | LR: 0.000357 INFO:tensorflow:Step 11350 | Loss: 0.0240 | Spent: 31.4 secs | LR: 0.000357 INFO:tensorflow:Step 11400 | Loss: 0.0235 | Spent: 32.0 secs | LR: 0.000357 INFO:tensorflow:Step 11450 | Loss: 0.0097 | Spent: 30.5 secs | LR: 0.000357 INFO:tensorflow:Step 11500 | Loss: 0.0152 | Spent: 30.3 secs | LR: 0.000356 INFO:tensorflow:Step 11550 | Loss: 0.0363 | Spent: 31.2 secs | LR: 0.000356 INFO:tensorflow:Step 11600 | Loss: 0.0155 | Spent: 30.4 secs | LR: 0.000356 INFO:tensorflow:Step 11650 | Loss: 0.0215 | Spent: 30.7 secs | LR: 0.000356 INFO:tensorflow:Step 11700 | Loss: 0.0132 | Spent: 30.8 secs | LR: 0.000356 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event ________________________|______________________________________________ | | | | sl:location | | | | | | | | | in:get_location | | | | 
________________|_______________ | | | sl:category_even sl:search_radius sl:location_user | | | t | | | | | _________|_________________ | | what times are the nutcracker show playing near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.719 INFO:tensorflow:Best Accuracy: 0.722 Reading ../data/train.tsv INFO:tensorflow:Step 11750 | Loss: 0.0434 | Spent: 180.4 secs | LR: 0.000355 INFO:tensorflow:Step 11800 | Loss: 0.0107 | Spent: 30.1 secs | LR: 0.000355 INFO:tensorflow:Step 11850 | Loss: 0.0130 | Spent: 31.0 secs | LR: 0.000355 INFO:tensorflow:Step 11900 | Loss: 0.0168 | Spent: 30.7 secs | LR: 0.000355 INFO:tensorflow:Step 11950 | Loss: 0.0335 | Spent: 32.7 secs | LR: 0.000355 INFO:tensorflow:Step 12000 | Loss: 0.0314 | Spent: 31.3 secs | LR: 0.000355 INFO:tensorflow:Step 12050 | Loss: 0.0053 | Spent: 31.1 secs | LR: 0.000354 INFO:tensorflow:Step 12100 | Loss: 0.0244 | Spent: 30.4 secs | LR: 0.000354 INFO:tensorflow:Step 12150 | Loss: 0.0124 | Spent: 31.4 secs | LR: 0.000354 INFO:tensorflow:Step 12200 | Loss: 0.0064 | Spent: 31.1 secs | LR: 0.000354 INFO:tensorflow:Step 12250 | Loss: 0.0221 | Spent: 31.0 secs | LR: 0.000354 INFO:tensorflow:Step 12300 | Loss: 0.0239 | Spent: 30.8 secs | LR: 0.000353 INFO:tensorflow:Step 12350 | Loss: 0.0133 | Spent: 30.8 secs | LR: 0.000353 INFO:tensorflow:Step 12400 | Loss: 0.0172 | Spent: 29.7 secs | LR: 0.000353 INFO:tensorflow:Step 12450 | Loss: 0.0098 | Spent: 33.2 secs | LR: 0.000353 INFO:tensorflow:Step 12500 | Loss: 0.0102 | Spent: 31.6 secs | LR: 0.000353 INFO:tensorflow:Step 12550 | Loss: 0.0205 | Spent: 31.2 secs | LR: 0.000353 INFO:tensorflow:Step 12600 | Loss: 0.0142 | Spent: 30.5 secs | LR: 0.000352 INFO:tensorflow:Step 12650 | Loss: 0.0143 | Spent: 30.6 secs | LR: 0.000352 INFO:tensorflow:Step 12700 | Loss: 0.0146 | Spent: 30.9 secs | LR: 0.000352 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event ________________________|______________________________________________ | 
| | | sl:location | | | | | | | | | in:get_location | | | | ________________|_______________ | | | sl:category_even sl:search_radius sl:location_user | | | t | | | | | _________|_________________ | | what times are the nutcracker show playing near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.715 INFO:tensorflow:Best Accuracy: 0.722 Reading ../data/train.tsv INFO:tensorflow:Step 12750 | Loss: 0.0041 | Spent: 183.5 secs | LR: 0.000352 INFO:tensorflow:Step 12800 | Loss: 0.0149 | Spent: 31.3 secs | LR: 0.000352 INFO:tensorflow:Step 12850 | Loss: 0.0112 | Spent: 31.0 secs | LR: 0.000352 INFO:tensorflow:Step 12900 | Loss: 0.0069 | Spent: 30.2 secs | LR: 0.000351 INFO:tensorflow:Step 12950 | Loss: 0.0141 | Spent: 31.2 secs | LR: 0.000351 INFO:tensorflow:Step 13000 | Loss: 0.0068 | Spent: 30.9 secs | LR: 0.000351 INFO:tensorflow:Step 13050 | Loss: 0.0153 | Spent: 31.6 secs | LR: 0.000351 INFO:tensorflow:Step 13100 | Loss: 0.0293 | Spent: 31.3 secs | LR: 0.000351 INFO:tensorflow:Step 13150 | Loss: 0.0079 | Spent: 31.2 secs | LR: 0.000350 INFO:tensorflow:Step 13200 | Loss: 0.0236 | Spent: 32.3 secs | LR: 0.000350 INFO:tensorflow:Step 13250 | Loss: 0.0205 | Spent: 31.0 secs | LR: 0.000350 INFO:tensorflow:Step 13300 | Loss: 0.0134 | Spent: 30.2 secs | LR: 0.000350 INFO:tensorflow:Step 13350 | Loss: 0.0164 | Spent: 29.9 secs | LR: 0.000350 INFO:tensorflow:Step 13400 | Loss: 0.0160 | Spent: 32.4 secs | LR: 0.000350 INFO:tensorflow:Step 13450 | Loss: 0.0268 | Spent: 30.8 secs | LR: 0.000349 INFO:tensorflow:Step 13500 | Loss: 0.0066 | Spent: 31.6 secs | LR: 0.000349 INFO:tensorflow:Step 13550 | Loss: 0.0459 | Spent: 30.4 secs | LR: 0.000349 INFO:tensorflow:Step 13600 | Loss: 0.0256 | Spent: 31.5 secs | LR: 0.000349 INFO:tensorflow:Step 13650 | Loss: 0.0090 | Spent: 32.2 secs | LR: 0.000349 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show ] playing [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event __________________________|_____________________________________________________ | | | | | sl:location | | | | | | | | | | | in:get_location | | | | | 
________________|_______________ | | | | sl:category_even sl:search_radius sl:location_user | | | | t | | | | | | ______________|__________ | | what times are playing the nutcracker show near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.724 INFO:tensorflow:Best Accuracy: 0.724 Reading ../data/train.tsv INFO:tensorflow:Step 13700 | Loss: 0.0142 | Spent: 183.8 secs | LR: 0.000349 INFO:tensorflow:Step 13750 | Loss: 0.0087 | Spent: 31.9 secs | LR: 0.000348 INFO:tensorflow:Step 13800 | Loss: 0.0045 | Spent: 29.7 secs | LR: 0.000348 INFO:tensorflow:Step 13850 | Loss: 0.0086 | Spent: 30.5 secs | LR: 0.000348 INFO:tensorflow:Step 13900 | Loss: 0.0121 | Spent: 29.8 secs | LR: 0.000348 INFO:tensorflow:Step 13950 | Loss: 0.0083 | Spent: 30.9 secs | LR: 0.000348 INFO:tensorflow:Step 14000 | Loss: 0.0027 | Spent: 31.2 secs | LR: 0.000347 INFO:tensorflow:Step 14050 | Loss: 0.0100 | Spent: 31.7 secs | LR: 0.000347 INFO:tensorflow:Step 14100 | Loss: 0.0038 | Spent: 32.3 secs | LR: 0.000347 INFO:tensorflow:Step 14150 | Loss: 0.0281 | Spent: 29.9 secs | LR: 0.000347 INFO:tensorflow:Step 14200 | Loss: 0.0158 | Spent: 30.1 secs | LR: 0.000347 INFO:tensorflow:Step 14250 | Loss: 0.0080 | Spent: 29.9 secs | LR: 0.000347 INFO:tensorflow:Step 14300 | Loss: 0.0239 | Spent: 30.7 secs | LR: 0.000346 INFO:tensorflow:Step 14350 | Loss: 0.0146 | Spent: 30.7 secs | LR: 0.000346 INFO:tensorflow:Step 14400 | Loss: 0.0193 | Spent: 30.8 secs | LR: 0.000346 INFO:tensorflow:Step 14450 | Loss: 0.0111 | Spent: 31.6 secs | LR: 0.000346 INFO:tensorflow:Step 14500 | Loss: 0.0244 | Spent: 30.3 secs | LR: 0.000346 INFO:tensorflow:Step 14550 | Loss: 0.0327 | Spent: 31.1 secs | LR: 0.000346 INFO:tensorflow:Step 14600 | Loss: 0.0278 | Spent: 31.8 secs | LR: 0.000345 INFO:tensorflow:Step 14650 | Loss: 0.0062 | Spent: 31.0 secs | LR: 0.000345 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event ________________________|______________________________________________ | 
| | | sl:location | | | | | | | | | in:get_location | | | | ________________|_______________ | | | sl:category_even sl:search_radius sl:location_user | | | t | | | | | _________|_________________ | | what times are the nutcracker show playing near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.713 INFO:tensorflow:Best Accuracy: 0.724 Reading ../data/train.tsv INFO:tensorflow:Step 14700 | Loss: 0.0083 | Spent: 182.0 secs | LR: 0.000345 INFO:tensorflow:Step 14750 | Loss: 0.0227 | Spent: 30.9 secs | LR: 0.000345 INFO:tensorflow:Step 14800 | Loss: 0.0045 | Spent: 31.4 secs | LR: 0.000345 INFO:tensorflow:Step 14850 | Loss: 0.0096 | Spent: 30.8 secs | LR: 0.000345 INFO:tensorflow:Step 14900 | Loss: 0.0030 | Spent: 30.7 secs | LR: 0.000344 INFO:tensorflow:Step 14950 | Loss: 0.0253 | Spent: 30.2 secs | LR: 0.000344 INFO:tensorflow:Step 15000 | Loss: 0.0164 | Spent: 30.4 secs | LR: 0.000344 INFO:tensorflow:Step 15050 | Loss: 0.0220 | Spent: 31.6 secs | LR: 0.000344 INFO:tensorflow:Step 15100 | Loss: 0.0074 | Spent: 31.1 secs | LR: 0.000344 INFO:tensorflow:Step 15150 | Loss: 0.0032 | Spent: 31.9 secs | LR: 0.000344 INFO:tensorflow:Step 15200 | Loss: 0.0133 | Spent: 31.4 secs | LR: 0.000343 INFO:tensorflow:Step 15250 | Loss: 0.0167 | Spent: 31.1 secs | LR: 0.000343 INFO:tensorflow:Step 15300 | Loss: 0.0218 | Spent: 30.4 secs | LR: 0.000343 INFO:tensorflow:Step 15350 | Loss: 0.0165 | Spent: 30.8 secs | LR: 0.000343 INFO:tensorflow:Step 15400 | Loss: 0.0042 | Spent: 31.0 secs | LR: 0.000343 INFO:tensorflow:Step 15450 | Loss: 0.0182 | Spent: 30.7 secs | LR: 0.000342 INFO:tensorflow:Step 15500 | Loss: 0.0028 | Spent: 29.3 secs | LR: 0.000342 INFO:tensorflow:Step 15550 | Loss: 0.0086 | Spent: 31.6 secs | LR: 0.000342 INFO:tensorflow:Step 15600 | Loss: 0.0059 | Spent: 30.8 secs | LR: 0.000342 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event ________________________|______________________________________________ | | | | sl:location | | | | | | | | | in:get_location | | | | 
________________|_______________ | | | sl:category_even sl:search_radius sl:location_user | | | t | | | | | _________|_________________ | | what times are the nutcracker show playing near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.709 INFO:tensorflow:Best Accuracy: 0.724 Reading ../data/train.tsv INFO:tensorflow:Step 15650 | Loss: 0.0188 | Spent: 178.6 secs | LR: 0.000342 INFO:tensorflow:Step 15700 | Loss: 0.0221 | Spent: 30.9 secs | LR: 0.000342 INFO:tensorflow:Step 15750 | Loss: 0.0097 | Spent: 31.3 secs | LR: 0.000341 INFO:tensorflow:Step 15800 | Loss: 0.0039 | Spent: 30.5 secs | LR: 0.000341 INFO:tensorflow:Step 15850 | Loss: 0.0028 | Spent: 30.3 secs | LR: 0.000341 INFO:tensorflow:Step 15900 | Loss: 0.0107 | Spent: 30.2 secs | LR: 0.000341 INFO:tensorflow:Step 15950 | Loss: 0.0087 | Spent: 31.9 secs | LR: 0.000341 INFO:tensorflow:Step 16000 | Loss: 0.0018 | Spent: 31.2 secs | LR: 0.000341 INFO:tensorflow:Step 16050 | Loss: 0.0071 | Spent: 30.1 secs | LR: 0.000340 INFO:tensorflow:Step 16100 | Loss: 0.0075 | Spent: 29.9 secs | LR: 0.000340 INFO:tensorflow:Step 16150 | Loss: 0.0049 | Spent: 30.4 secs | LR: 0.000340 INFO:tensorflow:Step 16200 | Loss: 0.0121 | Spent: 31.3 secs | LR: 0.000340 INFO:tensorflow:Step 16250 | Loss: 0.0035 | Spent: 30.7 secs | LR: 0.000340 INFO:tensorflow:Step 16300 | Loss: 0.0050 | Spent: 29.8 secs | LR: 0.000340 INFO:tensorflow:Step 16350 | Loss: 0.0230 | Spent: 31.5 secs | LR: 0.000339 INFO:tensorflow:Step 16400 | Loss: 0.0044 | Spent: 30.6 secs | LR: 0.000339 INFO:tensorflow:Step 16450 | Loss: 0.0176 | Spent: 30.6 secs | LR: 0.000339 INFO:tensorflow:Step 16500 | Loss: 0.0058 | Spent: 30.9 secs | LR: 0.000339 INFO:tensorflow:Step 16550 | Loss: 0.0127 | Spent: 31.6 secs | LR: 0.000339 INFO:tensorflow:Step 16600 | Loss: 0.0123 | Spent: 31.5 secs | LR: 0.000339 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event ________________________|______________________________________________ | 
| | | sl:location | | | | | | | | | in:get_location | | | | ________________|_______________ | | | sl:category_even sl:search_radius sl:location_user | | | t | | | | | _________|_________________ | | what times are the nutcracker show playing near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.723 INFO:tensorflow:Best Accuracy: 0.724 Reading ../data/train.tsv INFO:tensorflow:Step 16650 | Loss: 0.0065 | Spent: 176.7 secs | LR: 0.000338 INFO:tensorflow:Step 16700 | Loss: 0.0133 | Spent: 32.2 secs | LR: 0.000338 INFO:tensorflow:Step 16750 | Loss: 0.0109 | Spent: 31.4 secs | LR: 0.000338 INFO:tensorflow:Step 16800 | Loss: 0.0043 | Spent: 30.8 secs | LR: 0.000338 INFO:tensorflow:Step 16850 | Loss: 0.0042 | Spent: 30.1 secs | LR: 0.000338 INFO:tensorflow:Step 16900 | Loss: 0.0082 | Spent: 31.3 secs | LR: 0.000338 INFO:tensorflow:Step 16950 | Loss: 0.0130 | Spent: 30.3 secs | LR: 0.000337 INFO:tensorflow:Step 17000 | Loss: 0.0147 | Spent: 30.5 secs | LR: 0.000337 INFO:tensorflow:Step 17050 | Loss: 0.0127 | Spent: 31.4 secs | LR: 0.000337 INFO:tensorflow:Step 17100 | Loss: 0.0150 | Spent: 31.2 secs | LR: 0.000337 INFO:tensorflow:Step 17150 | Loss: 0.0726 | Spent: 30.8 secs | LR: 0.000337 INFO:tensorflow:Step 17200 | Loss: 0.0172 | Spent: 29.7 secs | LR: 0.000337 INFO:tensorflow:Step 17250 | Loss: 0.0100 | Spent: 31.3 secs | LR: 0.000336 INFO:tensorflow:Step 17300 | Loss: 0.0103 | Spent: 31.1 secs | LR: 0.000336 INFO:tensorflow:Step 17350 | Loss: 0.0116 | Spent: 31.4 secs | LR: 0.000336 INFO:tensorflow:Step 17400 | Loss: 0.0183 | Spent: 30.5 secs | LR: 0.000336 INFO:tensorflow:Step 17450 | Loss: 0.0169 | Spent: 31.2 secs | LR: 0.000336 INFO:tensorflow:Step 17500 | Loss: 0.0164 | Spent: 29.3 secs | LR: 0.000335 INFO:tensorflow:Step 17550 | Loss: 0.0094 | Spent: 31.8 secs | LR: 0.000335 INFO:tensorflow:Step 17600 | Loss: 0.0093 | Spent: 30.8 secs | LR: 0.000335 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event ________________________|______________________________________________ | 
| | | sl:location | | | | | | | | | in:get_location | | | | ________________|_______________ | | | sl:category_even sl:search_radius sl:location_user | | | t | | | | | _________|_________________ | | what times are the nutcracker show playing near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.723 INFO:tensorflow:Best Accuracy: 0.724 Reading ../data/train.tsv INFO:tensorflow:Step 17650 | Loss: 0.0252 | Spent: 179.8 secs | LR: 0.000335 INFO:tensorflow:Step 17700 | Loss: 0.0091 | Spent: 30.5 secs | LR: 0.000335 INFO:tensorflow:Step 17750 | Loss: 0.0106 | Spent: 31.5 secs | LR: 0.000335 INFO:tensorflow:Step 17800 | Loss: 0.0212 | Spent: 31.4 secs | LR: 0.000334 INFO:tensorflow:Step 17850 | Loss: 0.0012 | Spent: 30.0 secs | LR: 0.000334 INFO:tensorflow:Step 17900 | Loss: 0.0031 | Spent: 30.9 secs | LR: 0.000334 INFO:tensorflow:Step 17950 | Loss: 0.0098 | Spent: 30.9 secs | LR: 0.000334 INFO:tensorflow:Step 18000 | Loss: 0.0094 | Spent: 30.5 secs | LR: 0.000334 INFO:tensorflow:Step 18050 | Loss: 0.0186 | Spent: 30.7 secs | LR: 0.000334 INFO:tensorflow:Step 18100 | Loss: 0.0020 | Spent: 30.2 secs | LR: 0.000333 INFO:tensorflow:Step 18150 | Loss: 0.0070 | Spent: 32.1 secs | LR: 0.000333 INFO:tensorflow:Step 18200 | Loss: 0.0046 | Spent: 29.9 secs | LR: 0.000333 INFO:tensorflow:Step 18250 | Loss: 0.0037 | Spent: 29.8 secs | LR: 0.000333 INFO:tensorflow:Step 18300 | Loss: 0.0100 | Spent: 31.4 secs | LR: 0.000333 INFO:tensorflow:Step 18350 | Loss: 0.0274 | Spent: 29.8 secs | LR: 0.000333 INFO:tensorflow:Step 18400 | Loss: 0.0027 | Spent: 30.7 secs | LR: 0.000332 INFO:tensorflow:Step 18450 | Loss: 0.0068 | Spent: 30.7 secs | LR: 0.000332 INFO:tensorflow:Step 18500 | Loss: 0.0025 | Spent: 31.5 secs | LR: 0.000332 INFO:tensorflow:Step 18550 | Loss: 0.0111 | Spent: 30.4 secs | LR: 0.000332 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show ] playing [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event __________________________|_____________________________________________________ | | | | | sl:location | | | | | | | | | | | in:get_location | | | | | 
________________|_______________ | | | | sl:category_even sl:search_radius sl:location_user | | | | t | | | | | | ______________|__________ | | what times are playing the nutcracker show near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.721 INFO:tensorflow:Best Accuracy: 0.724 Reading ../data/train.tsv INFO:tensorflow:Step 18600 | Loss: 0.0054 | Spent: 177.4 secs | LR: 0.000332 INFO:tensorflow:Step 18650 | Loss: 0.0150 | Spent: 31.5 secs | LR: 0.000332 INFO:tensorflow:Step 18700 | Loss: 0.0035 | Spent: 30.4 secs | LR: 0.000331 INFO:tensorflow:Step 18750 | Loss: 0.0106 | Spent: 30.6 secs | LR: 0.000331 INFO:tensorflow:Step 18800 | Loss: 0.0042 | Spent: 30.9 secs | LR: 0.000331 INFO:tensorflow:Step 18850 | Loss: 0.0183 | Spent: 31.1 secs | LR: 0.000331 INFO:tensorflow:Step 18900 | Loss: 0.0109 | Spent: 29.3 secs | LR: 0.000331 INFO:tensorflow:Step 18950 | Loss: 0.0037 | Spent: 30.1 secs | LR: 0.000331 INFO:tensorflow:Step 19000 | Loss: 0.0145 | Spent: 29.6 secs | LR: 0.000330 INFO:tensorflow:Step 19050 | Loss: 0.0035 | Spent: 31.0 secs | LR: 0.000330 INFO:tensorflow:Step 19100 | Loss: 0.0082 | Spent: 31.4 secs | LR: 0.000330 INFO:tensorflow:Step 19150 | Loss: 0.0039 | Spent: 31.9 secs | LR: 0.000330 INFO:tensorflow:Step 19200 | Loss: 0.0009 | Spent: 30.2 secs | LR: 0.000330 INFO:tensorflow:Step 19250 | Loss: 0.0057 | Spent: 29.7 secs | LR: 0.000330 INFO:tensorflow:Step 19300 | Loss: 0.0010 | Spent: 32.2 secs | LR: 0.000329 INFO:tensorflow:Step 19350 | Loss: 0.0009 | Spent: 30.2 secs | LR: 0.000329 INFO:tensorflow:Step 19400 | Loss: 0.0038 | Spent: 31.5 secs | LR: 0.000329 INFO:tensorflow:Step 19450 | Loss: 0.0027 | Spent: 30.9 secs | LR: 0.000329 INFO:tensorflow:Step 19500 | Loss: 0.0069 | Spent: 31.0 secs | LR: 0.000329 INFO:tensorflow:Step 19550 | Loss: 0.0815 | Spent: 30.1 secs | LR: 0.000329 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event ________________________|______________________________________________ | 
| | | sl:location | | | | | | | | | in:get_location | | | | ________________|_______________ | | | sl:category_even sl:search_radius sl:location_user | | | t | | | | | _________|_________________ | | what times are the nutcracker show playing near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.717 INFO:tensorflow:Best Accuracy: 0.724 Reading ../data/train.tsv INFO:tensorflow:Step 19600 | Loss: 0.0060 | Spent: 177.3 secs | LR: 0.000328 INFO:tensorflow:Step 19650 | Loss: 0.0195 | Spent: 32.1 secs | LR: 0.000328 INFO:tensorflow:Step 19700 | Loss: 0.0222 | Spent: 29.6 secs | LR: 0.000328 INFO:tensorflow:Step 19750 | Loss: 0.0011 | Spent: 30.9 secs | LR: 0.000328 INFO:tensorflow:Step 19800 | Loss: 0.0099 | Spent: 31.2 secs | LR: 0.000328 INFO:tensorflow:Step 19850 | Loss: 0.0065 | Spent: 30.4 secs | LR: 0.000328 INFO:tensorflow:Step 19900 | Loss: 0.0067 | Spent: 31.1 secs | LR: 0.000327 INFO:tensorflow:Step 19950 | Loss: 0.0016 | Spent: 29.8 secs | LR: 0.000327 INFO:tensorflow:Step 20000 | Loss: 0.0087 | Spent: 31.9 secs | LR: 0.000327 INFO:tensorflow:Step 20050 | Loss: 0.0118 | Spent: 31.3 secs | LR: 0.000327 INFO:tensorflow:Step 20100 | Loss: 0.0134 | Spent: 31.7 secs | LR: 0.000327 INFO:tensorflow:Step 20150 | Loss: 0.0194 | Spent: 31.3 secs | LR: 0.000327 INFO:tensorflow:Step 20200 | Loss: 0.0150 | Spent: 30.9 secs | LR: 0.000327 INFO:tensorflow:Step 20250 | Loss: 0.0112 | Spent: 31.0 secs | LR: 0.000326 INFO:tensorflow:Step 20300 | Loss: 0.0114 | Spent: 29.7 secs | LR: 0.000326 INFO:tensorflow:Step 20350 | Loss: 0.0073 | Spent: 30.6 secs | LR: 0.000326 INFO:tensorflow:Step 20400 | Loss: 0.0132 | Spent: 30.2 secs | LR: 0.000326 INFO:tensorflow:Step 20450 | Loss: 0.0009 | Spent: 29.5 secs | LR: 0.000326 INFO:tensorflow:Step 20500 | Loss: 0.0025 | Spent: 30.7 secs | LR: 0.000326 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event ________________________|______________________________________________ | | | | sl:location | | | | | | | | | in:get_location | | | | 
________________|_______________ | | | sl:category_even sl:search_radius sl:location_user | | | t | | | | | _________|_________________ | | what times are the nutcracker show playing near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.717 INFO:tensorflow:Best Accuracy: 0.724 Reading ../data/train.tsv INFO:tensorflow:Step 20550 | Loss: 0.0024 | Spent: 180.2 secs | LR: 0.000325 INFO:tensorflow:Step 20600 | Loss: 0.0014 | Spent: 30.7 secs | LR: 0.000325 INFO:tensorflow:Step 20650 | Loss: 0.0160 | Spent: 31.7 secs | LR: 0.000325 INFO:tensorflow:Step 20700 | Loss: 0.0024 | Spent: 31.2 secs | LR: 0.000325 INFO:tensorflow:Step 20750 | Loss: 0.0045 | Spent: 30.4 secs | LR: 0.000325 INFO:tensorflow:Step 20800 | Loss: 0.0030 | Spent: 31.0 secs | LR: 0.000325 INFO:tensorflow:Step 20850 | Loss: 0.0114 | Spent: 31.7 secs | LR: 0.000324 INFO:tensorflow:Step 20900 | Loss: 0.0020 | Spent: 30.7 secs | LR: 0.000324 INFO:tensorflow:Step 20950 | Loss: 0.0013 | Spent: 30.8 secs | LR: 0.000324 INFO:tensorflow:Step 21000 | Loss: 0.0038 | Spent: 29.8 secs | LR: 0.000324 INFO:tensorflow:Step 21050 | Loss: 0.0235 | Spent: 30.6 secs | LR: 0.000324 INFO:tensorflow:Step 21100 | Loss: 0.0065 | Spent: 31.0 secs | LR: 0.000324 INFO:tensorflow:Step 21150 | Loss: 0.0049 | Spent: 31.1 secs | LR: 0.000323 INFO:tensorflow:Step 21200 | Loss: 0.0185 | Spent: 31.0 secs | LR: 0.000323 INFO:tensorflow:Step 21250 | Loss: 0.0041 | Spent: 31.5 secs | LR: 0.000323 INFO:tensorflow:Step 21300 | Loss: 0.0112 | Spent: 31.7 secs | LR: 0.000323 INFO:tensorflow:Step 21350 | Loss: 0.0110 | Spent: 31.3 secs | LR: 0.000323 INFO:tensorflow:Step 21400 | Loss: 0.0123 | Spent: 31.1 secs | LR: 0.000323 INFO:tensorflow:Step 21450 | Loss: 0.0036 | Spent: 31.1 secs | LR: 0.000322 INFO:tensorflow:Step 21500 | Loss: 0.0055 | Spent: 31.0 secs | LR: 0.000322 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event ________________________|______________________________________________ | 
| | | sl:location | | | | | | | | | in:get_location | | | | ________________|_______________ | | | sl:category_even sl:search_radius sl:location_user | | | t | | | | | _________|_________________ | | what times are the nutcracker show playing near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.724 INFO:tensorflow:Best Accuracy: 0.724 Reading ../data/train.tsv INFO:tensorflow:Step 21550 | Loss: 0.0101 | Spent: 181.9 secs | LR: 0.000322 INFO:tensorflow:Step 21600 | Loss: 0.0104 | Spent: 31.7 secs | LR: 0.000322 INFO:tensorflow:Step 21650 | Loss: 0.0003 | Spent: 30.2 secs | LR: 0.000322 INFO:tensorflow:Step 21700 | Loss: 0.0045 | Spent: 31.4 secs | LR: 0.000322 INFO:tensorflow:Step 21750 | Loss: 0.0156 | Spent: 30.7 secs | LR: 0.000321 INFO:tensorflow:Step 21800 | Loss: 0.0074 | Spent: 30.6 secs | LR: 0.000321 INFO:tensorflow:Step 21850 | Loss: 0.0155 | Spent: 30.8 secs | LR: 0.000321 INFO:tensorflow:Step 21900 | Loss: 0.0025 | Spent: 31.5 secs | LR: 0.000321 INFO:tensorflow:Step 21950 | Loss: 0.0035 | Spent: 30.9 secs | LR: 0.000321 INFO:tensorflow:Step 22000 | Loss: 0.0066 | Spent: 32.5 secs | LR: 0.000321 INFO:tensorflow:Step 22050 | Loss: 0.0020 | Spent: 32.0 secs | LR: 0.000320 INFO:tensorflow:Step 22100 | Loss: 0.0069 | Spent: 30.7 secs | LR: 0.000320 INFO:tensorflow:Step 22150 | Loss: 0.0026 | Spent: 30.6 secs | LR: 0.000320 INFO:tensorflow:Step 22200 | Loss: 0.0019 | Spent: 30.1 secs | LR: 0.000320 INFO:tensorflow:Step 22250 | Loss: 0.0093 | Spent: 31.4 secs | LR: 0.000320 INFO:tensorflow:Step 22300 | Loss: 0.0222 | Spent: 31.3 secs | LR: 0.000320 INFO:tensorflow:Step 22350 | Loss: 0.0047 | Spent: 30.8 secs | LR: 0.000320 INFO:tensorflow:Step 22400 | Loss: 0.0026 | Spent: 30.5 secs | LR: 0.000319 INFO:tensorflow:Step 22450 | Loss: 0.0104 | Spent: 32.7 secs | LR: 0.000319 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show ] playing [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event __________________________|_____________________________________________________ | | | | | sl:location | | | | | | | | | | | in:get_location | | | | | 
________________|_______________ | | | | sl:category_even sl:search_radius sl:location_user | | | | t | | | | | | ______________|__________ | | what times are playing the nutcracker show near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.724 INFO:tensorflow:Best Accuracy: 0.724 Reading ../data/train.tsv INFO:tensorflow:Step 22500 | Loss: 0.0114 | Spent: 181.2 secs | LR: 0.000319 INFO:tensorflow:Step 22550 | Loss: 0.0113 | Spent: 31.4 secs | LR: 0.000319 INFO:tensorflow:Step 22600 | Loss: 0.0219 | Spent: 32.7 secs | LR: 0.000319 INFO:tensorflow:Step 22650 | Loss: 0.0267 | Spent: 31.4 secs | LR: 0.000319 INFO:tensorflow:Step 22700 | Loss: 0.0020 | Spent: 31.2 secs | LR: 0.000318 INFO:tensorflow:Step 22750 | Loss: 0.0018 | Spent: 31.6 secs | LR: 0.000318 INFO:tensorflow:Step 22800 | Loss: 0.0011 | Spent: 29.9 secs | LR: 0.000318 INFO:tensorflow:Step 22850 | Loss: 0.0020 | Spent: 30.9 secs | LR: 0.000318 INFO:tensorflow:Step 22900 | Loss: 0.0208 | Spent: 31.2 secs | LR: 0.000318 INFO:tensorflow:Step 22950 | Loss: 0.0025 | Spent: 30.6 secs | LR: 0.000318 INFO:tensorflow:Step 23000 | Loss: 0.0074 | Spent: 31.5 secs | LR: 0.000317 INFO:tensorflow:Step 23050 | Loss: 0.0108 | Spent: 32.6 secs | LR: 0.000317 INFO:tensorflow:Step 23100 | Loss: 0.0125 | Spent: 31.1 secs | LR: 0.000317 INFO:tensorflow:Step 23150 | Loss: 0.0352 | Spent: 32.2 secs | LR: 0.000317 INFO:tensorflow:Step 23200 | Loss: 0.0019 | Spent: 31.1 secs | LR: 0.000317 INFO:tensorflow:Step 23250 | Loss: 0.0013 | Spent: 31.6 secs | LR: 0.000317 INFO:tensorflow:Step 23300 | Loss: 0.0131 | Spent: 30.7 secs | LR: 0.000316 INFO:tensorflow:Step 23350 | Loss: 0.0126 | Spent: 31.3 secs | LR: 0.000316 INFO:tensorflow:Step 23400 | Loss: 0.0131 | Spent: 32.1 secs | LR: 0.000316 INFO:tensorflow:Step 23450 | Loss: 0.0084 | Spent: 31.2 secs | LR: 0.000316 ------------ minimal test utterance: what times are the nutcracker show playing near me parsed: [ in:get_event what times are [ sl:category_event the nutcracker show ] playing [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ] in:get_event 
__________________________|_____________________________________________________ | | | | | sl:location | | | | | | | | | | | in:get_location | | | | | ________________|_______________ | | | | sl:category_even sl:search_radius sl:location_user | | | | t | | | | | | ______________|__________ | | what times are playing the nutcracker show near me ------------ Reading ../data/test.tsv
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
INFO:tensorflow:Evaluation: Testing Exact Match Accuracy: 0.724 INFO:tensorflow:Best Accuracy: 0.724 Reading ../data/train.tsv INFO:tensorflow:Step 23500 | Loss: 0.0013 | Spent: 179.2 secs | LR: 0.000316 INFO:tensorflow:Step 23550 | Loss: 0.0033 | Spent: 30.2 secs | LR: 0.000316 INFO:tensorflow:Step 23600 | Loss: 0.0050 | Spent: 31.8 secs | LR: 0.000316 INFO:tensorflow:Step 23650 | Loss: 0.0073 | Spent: 31.1 secs | LR: 0.000315 INFO:tensorflow:Step 23700 | Loss: 0.0042 | Spent: 31.7 secs | LR: 0.000315 INFO:tensorflow:Step 23750 | Loss: 0.0014 | Spent: 31.2 secs | LR: 0.000315 INFO:tensorflow:Step 23800 | Loss: 0.0015 | Spent: 30.6 secs | LR: 0.000315 INFO:tensorflow:Step 23850 | Loss: 0.0025 | Spent: 31.9 secs | LR: 0.000315 INFO:tensorflow:Step 23900 | Loss: 0.0112 | Spent: 30.5 secs | LR: 0.000315 INFO:tensorflow:Step 23950 | Loss: 0.0053 | Spent: 31.2 secs | LR: 0.000314 INFO:tensorflow:Step 24000 | Loss: 0.0061 | Spent: 32.0 secs | LR: 0.000314 INFO:tensorflow:Step 24050 | Loss: 0.0014 | Spent: 31.3 secs | LR: 0.000314 INFO:tensorflow:Step 24100 | Loss: 0.0055 | Spent: 31.8 secs | LR: 0.000314 INFO:tensorflow:Step 24150 | Loss: 0.0032 | Spent: 30.1 secs | LR: 0.000314 INFO:tensorflow:Step 24200 | Loss: 0.0020 | Spent: 32.0 secs | LR: 0.000314 INFO:tensorflow:Step 24250 | Loss: 0.0053 | Spent: 30.8 secs | LR: 0.000313 INFO:tensorflow:Step 24300 | Loss: 0.0049 | Spent: 30.4 secs | LR: 0.000313 INFO:tensorflow:Step 24350 | Loss: 0.0291 | Spent: 32.0 secs | LR: 0.000313