#!/usr/bin/env python
# coding: utf-8

# # CS 20 : TensorFlow for Deep Learning Research
# ## Lecture 07 : ConvNet in TensorFlow
# Same contents as [Lec07_ConvNet mnist by high-level.ipynb](https://nbviewer.jupyter.org/github/aisolab/CS20/blob/master/Lec07_ConvNet%20in%20Tensorflow/Lec07_ConvNet%20mnist%20by%20high-level.ipynb), but written in a different style
# 
# ### ConvNet mnist by high-level
# - Creating the **data pipeline** with `tf.data`
# - Using `tf.keras`, alias `keras`
# - Creating the model as **Class** by subclassing `tf.keras.Model`
# - Training the model with the **Dropout** technique via `tf.keras.layers.Dropout`
# - Using tensorboard

# ### Setup

# In[1]:


from __future__ import absolute_import, division, print_function
import os, sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
# notebook magic: render matplotlib figures inline (only valid inside IPython/Jupyter)
get_ipython().run_line_magic('matplotlib', 'inline')

# confirm the TF version the notebook was run against
print(tf.__version__)


# ### Load and Pre-process data

# In[2]:


# Load MNIST, scale pixel values into [0, 1], and reshape the images to
# NHWC float32 tensors (28x28, single channel) for the Conv2D layers.
(x_train, y_train), (x_tst, y_tst) = tf.keras.datasets.mnist.load_data()

x_train = (x_train / 255).reshape(-1, 28, 28, 1).astype(np.float32)
x_tst = (x_tst / 255).reshape(-1, 28, 28, 1).astype(np.float32)
# sparse categorical crossentropy expects integer class ids
y_tst = y_tst.astype(np.int32)


# In[3]:


# Randomly hold out 5,000 of the 60,000 training images for validation.
tr_indices = np.random.choice(range(x_train.shape[0]), size=55000, replace=False)

x_tr, y_tr = x_train[tr_indices], y_train[tr_indices].astype(np.int32)

# everything that was not sampled above becomes the validation set
x_val = np.delete(arr=x_train, obj=tr_indices, axis=0)
y_val = np.delete(arr=y_train, obj=tr_indices, axis=0).astype(np.int32)

print(x_tr.shape, y_tr.shape)
print(x_val.shape, y_val.shape)


# ### Define SimpleCNN class by high-level api

# In[4]:


class SimpleCNN(keras.Model):
    """LeNet-style ConvNet built by subclassing `keras.Model`.

    Architecture: conv5x5(32) -> maxpool -> conv5x5(64) -> maxpool
                  -> flatten -> dense(1024) -> dropout(.5) -> dense(softmax)
    """

    def __init__(self, num_classes):
        """Create the layers.

        Args:
            num_classes: number of output classes (units of the softmax layer).
        """
        super(SimpleCNN, self).__init__()
        tn = keras.initializers.truncated_normal  # shorthand; fresh instance per use
        self.__conv1 = keras.layers.Conv2D(filters=32, kernel_size=[5, 5], padding='same',
                                           kernel_initializer=tn(),
                                           bias_initializer=tn(),
                                           activation=tf.nn.relu)
        self.__conv2 = keras.layers.Conv2D(filters=64, kernel_size=[5, 5], padding='same',
                                           kernel_initializer=tn(),
                                           bias_initializer=tn(),
                                           activation=tf.nn.relu)
        self.__pool = keras.layers.MaxPooling2D()  # stateless, shared by both stages
        self.__flatten = keras.layers.Flatten()
        self.__dropout = keras.layers.Dropout(rate=.5)
        self.__dense1 = keras.layers.Dense(units=1024, activation=tf.nn.relu,
                                           kernel_initializer=tn(),
                                           bias_initializer=tn())
        self.__dense2 = keras.layers.Dense(units=num_classes,
                                           kernel_initializer=tn(),
                                           bias_initializer=tn(),
                                           activation='softmax')

    def call(self, inputs, training=False):
        """Forward pass; dropout is applied only when `training` is True."""
        net = self.__pool(self.__conv1(inputs))
        net = self.__pool(self.__conv2(net))
        net = self.__dense1(self.__flatten(net))
        if training:
            net = self.__dropout(net, training=training)
        return self.__dense2(net)


# ### Create a model of SimpleCNN

# In[5]:


# hyper-parameters
lr = .001         # Adam learning rate
epochs = 10
batch_size = 100
total_step = x_tr.shape[0] // batch_size  # mini-batches per epoch
print(total_step)


# In[6]:


## input pipelines built with tf.data
# training set: repeat indefinitely; fit() bounds each epoch via steps_per_epoch
tr_dataset = tf.data.Dataset.from_tensor_slices((x_tr, y_tr)).batch(batch_size).repeat()
print(tr_dataset)

# validation set: same batching, also repeated for multiple epochs
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size).repeat()
print(val_dataset)

# test set: a single pass, no repeat
tst_dataset = tf.data.Dataset.from_tensor_slices((x_tst, y_tst)).batch(100)
print(tst_dataset)


# In[7]:


## create model
cnn = SimpleCNN(num_classes=10)

# creating callbacks for tensorboard
# (writes the graph and layer-weight images so they can be inspected in TensorBoard;
#  the list is meant to be passed to fit())
callbacks = [keras.callbacks.TensorBoard(log_dir='../graphs/lecture07/convnet_mnist_high_kd/',
                                         write_graph=True, write_images=True)]


# In[8]:


# compile: Adam optimizer + sparse categorical crossentropy (integer labels).
# NOTE: `callbacks` is not a valid argument of Model.compile(); callbacks are
# passed to fit() instead, so the invalid kwarg has been removed here.
cnn.compile(optimizer=tf.train.AdamOptimizer(learning_rate=lr),
            loss=keras.losses.sparse_categorical_crossentropy)


# ### Train a model

# In[9]:


# Train the model. The TensorBoard callbacks belong here (fit), not in compile();
# without this argument the callback list defined earlier is never used.
# validation_steps is derived from the actual validation-set size instead of
# the hard-coded 5000//100.
cnn.fit(tr_dataset, epochs=epochs, steps_per_epoch=total_step,
        validation_data=val_dataset,
        validation_steps=x_val.shape[0] // batch_size,
        callbacks=callbacks)


# ### Calculate accuracy

# In[10]:


# Score the test images with dropout disabled; `yhat` is a symbolic
# (num_test, 10) softmax tensor that still needs a session run to evaluate.
sess = keras.backend.get_session()
yhat = cnn(tf.convert_to_tensor(x_tst), training=False)
print(yhat)


# In[11]:


# Evaluate the symbolic predictions and report top-1 accuracy on the test set.
yhat = sess.run(yhat)
acc = (np.argmax(yhat, axis=-1) == y_tst).mean()
print('tst acc : {:.2%}'.format(acc))