import tensorflow as tf
import numpy as np
tf.__version__
'0.9.0'
char_rdic = ['h', 'e', 'l', 'o'] # id -> char
char_dic = {w : i for i, w in enumerate(char_rdic)} # char -> id
print char_dic
{'h': 0, 'e': 1, 'l': 2, 'o': 3}
ground_truth = [char_dic[c] for c in 'hello']
print ground_truth
[0, 1, 2, 2, 3]
x_data = np.array([[1,0,0,0], # h
[0,1,0,0], # e
[0,0,1,0], # l
[0,0,1,0]], # l
dtype = 'float32')
print x_data.shape, x_data.dtype
(4, 4) float32
tf.one_hot(indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None)
session = tf.InteractiveSession()
session.run(tf.initialize_all_variables())
Exception AssertionError: AssertionError("Nesting violated for default stack of <type 'weakref'> objects",) in <bound method InteractiveSession.__del__ of <tensorflow.python.client.session.InteractiveSession object at 0x11b850b90>> ignored
print ground_truth[:], ground_truth[:-1]
[0, 1, 2, 2, 3] [0, 1, 2, 2]
x_data = tf.one_hot(ground_truth[:-1], depth = len(char_dic), on_value = 1.0, off_value = 0.0)
print x_data.eval()
[[ 1. 0. 0. 0.] [ 0. 1. 0. 0.] [ 0. 0. 1. 0.] [ 0. 0. 1. 0.]]
# Configuration
rnn_size = len(char_dic) # 4
batch_size = 1
output_size = 4
# RNN Model
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units = rnn_size)
#initial_state = rnn_cell.zero_state(batch_size, tf.float32)
initial_state = tf.zeros([batch_size, rnn_cell.state_size]) # 위 코드와 같은 결과
print(initial_state)
print initial_state.eval()
Tensor("zeros_12:0", shape=(1, 4), dtype=float32) [[ 0. 0. 0. 0.]]
tf.split(split_dim, num_split, value, name='split')
print x_data.eval()
print
x_split = tf.split(split_dim = 0, num_split = len(char_dic), value = x_data) # dimension=0 을 기준으로 4개로 split
print type(x_split)
for t in x_split:
print t.eval()
[[ 1. 0. 0. 0.] [ 0. 1. 0. 0.] [ 0. 0. 1. 0.] [ 0. 0. 1. 0.]] <type 'list'> [[ 1. 0. 0. 0.]] [[ 0. 1. 0. 0.]] [[ 0. 0. 1. 0.]] [[ 0. 0. 1. 0.]]
with tf.variable_scope('forward'):
outputs, state = tf.nn.rnn(cell = rnn_cell, inputs = x_split, initial_state = initial_state)
print type(outputs)
print
for t in outputs:
print t.get_shape()
print
print state.get_shape()
<type 'list'> (1, 4) (1, 4) (1, 4) (1, 4) (1, 4)
result_outputs = tf.reshape(tf.concat(1, outputs), # shape = 1 x 16
[-1, rnn_size]) # shape = 4 x 4
print result_outputs.get_shape()
(4, 4)
print ground_truth[1:]
targets = tf.constant(ground_truth[1:], tf.int32) # a shape of [-1] flattens into 1-D
print targets.eval()
[1, 2, 2, 3] [1 2 2 3]
weights = tf.ones([len(char_dic) * batch_size]) # tf.ones([4])
print weights.eval()
[ 1. 1. 1. 1.]
loss = tf.nn.seq2seq.sequence_loss_by_example([result_outputs], [targets], [weights])
cost = tf.reduce_sum(loss) / batch_size
train_op = tf.train.RMSPropOptimizer(0.01, 0.9).minimize(cost)
# Launch the graph in a session
with tf.Session() as sess:
tf.initialize_all_variables().run()
for i in range(100):
sess.run(train_op)
result = sess.run(tf.argmax(result_outputs, 1))
print(result, [char_rdic[t] for t in result])
(array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 2]), ['e', 'l', 'l', 'l']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o']) (array([1, 2, 2, 3]), ['e', 'l', 'l', 'o'])