Imports, Data

In [ ]:
import tensorflow as tf
from tensorflow.keras import layers

import matplotlib.pyplot as plt
import numpy as np
import os
import time
from glob import glob
import math
In [ ]:
print(tf.__version__)
2.4.1
In [ ]:
!nvidia-smi
Tue Mar  9 23:35:13 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.56       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   53C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
In [ ]:
!pip install tensorflow-addons
Collecting tensorflow-addons
  Downloading https://files.pythonhosted.org/packages/74/e3/56d2fe76f0bb7c88ed9b2a6a557e25e83e252aec08f13de34369cd850a0b/tensorflow_addons-0.12.1-cp37-cp37m-manylinux2010_x86_64.whl (703kB)
     |████████████████████████████████| 706kB 8.0MB/s 
Requirement already satisfied: typeguard>=2.7 in /usr/local/lib/python3.7/dist-packages (from tensorflow-addons) (2.7.1)
Installing collected packages: tensorflow-addons
Successfully installed tensorflow-addons-0.12.1
In [ ]:
import tensorflow_addons as tfa
In [ ]:
import cv2
from scipy import ndimage
In [ ]:
# Download resized shots
# Download line art shots
# Download distance map shots
# Download weights
In [ ]:
!unzip -q resizedshots.zip
In [ ]:
!unzip -q line_art_shots.zip
In [ ]:
!unzip -q distance_map_shots.zip
In [ ]:
!unzip -q weights.zip
In [ ]:
!mv weights training_checkpoints
In [ ]:
!find . -name ".DS_Store" -delete

Sketch Keras

In [ ]:
!unzip -q sketchKeras.h5.zip
In [ ]:
!unzip -q final_shots.zip
In [ ]:
def create_sketch_model(input_shape=(1024, 1024, 1)):
  inputs = layers.Input(shape=input_shape)
  model = layers.Conv2D(32, (3, 3), padding='same')(inputs)
  model = layers.BatchNormalization()(model)
  activation_1 = layers.ReLU()(model)

  model = layers.Conv2D(64, (4, 4), (2, 2), padding='same')(activation_1)
  model = layers.BatchNormalization()(model)
  model = layers.ReLU()(model)

  model = layers.Conv2D(64, (3, 3), padding='same')(model)
  model = layers.BatchNormalization()(model)
  activation_3 = layers.ReLU()(model)

  model = layers.Conv2D(128, (4, 4), (2, 2), padding='same')(activation_3)
  model = layers.BatchNormalization()(model)
  model = layers.ReLU()(model)

  model = layers.Conv2D(128, (3, 3), padding='same')(model)
  model = layers.BatchNormalization()(model)
  activation_5 = layers.ReLU()(model)

  model = layers.Conv2D(256, (4, 4), (2, 2), padding='same')(activation_5)
  model = layers.BatchNormalization()(model)
  model = layers.ReLU()(model)

  model = layers.Conv2D(256, (3, 3), padding='same')(model)
  model = layers.BatchNormalization()(model)
  activation_7 = layers.ReLU()(model)

  model = layers.Conv2D(512, (4, 4), (2, 2), padding='same')(activation_7)
  model = layers.BatchNormalization()(model)
  activation_8 = layers.ReLU()(model)

  model = layers.Conv2D(512, (3, 3), padding='same')(activation_8)
  model = layers.BatchNormalization()(model)
  activation_9 = layers.ReLU()(model)

  model = layers.concatenate([activation_8, activation_9])

  model = layers.UpSampling2D()(model)

  model = layers.Conv2D(512, (4, 4), padding='same')(model)
  model = layers.BatchNormalization()(model)
  model = layers.ReLU()(model)

  model = layers.Conv2D(256, (3, 3), padding='same')(model)
  model = layers.BatchNormalization()(model)
  activation_11 = layers.ReLU()(model)

  model = layers.concatenate([activation_7, activation_11])

  model = layers.UpSampling2D()(model)

  model = layers.Conv2D(256, (4, 4), padding='same')(model)
  model = layers.BatchNormalization()(model)
  model = layers.ReLU()(model)

  model = layers.Conv2D(128, (3, 3), padding='same')(model)
  model = layers.BatchNormalization()(model)
  activation_13 = layers.ReLU()(model)

  model = layers.concatenate([activation_5, activation_13])

  model = layers.UpSampling2D()(model)

  model = layers.Conv2D(128, (4, 4), padding='same')(model)
  model = layers.BatchNormalization()(model)
  model = layers.ReLU()(model)

  model = layers.Conv2D(64, (3, 3), padding='same')(model)
  model = layers.BatchNormalization()(model)
  activation_15 = layers.ReLU()(model)

  model = layers.concatenate([activation_3, activation_15])

  model = layers.UpSampling2D()(model)

  model = layers.Conv2D(64, (4, 4), padding='same')(model)
  model = layers.BatchNormalization()(model)
  model = layers.ReLU()(model)

  model = layers.Conv2D(32, (3, 3), padding='same')(model)
  model = layers.BatchNormalization()(model)
  activation_17 = layers.ReLU()(model)

  model = layers.concatenate([activation_1, activation_17])

  model = layers.Conv2D(1, (3, 3), padding='same')(model)

  return tf.keras.Model(inputs=inputs, outputs=[model])
In [ ]:
sketchKeras = create_sketch_model()
In [ ]:
sketchKeras.load_weights("sketchKeras.h5")
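
A quick shape sanity check (a minimal sketch, not part of the original pipeline): sketchKeras maps a 1024x1024 single-channel high-pass map to a 1024x1024 line-art response.

In [ ]:
# illustrative only: random input, just to confirm the in/out shapes
dummy = np.random.rand(1, 1024, 1024, 1).astype(np.float32)
print(sketchKeras.predict(dummy, verbose=0).shape)  # (1, 1024, 1024, 1)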
In [ ]:
def get_light_map_single(img): # Used
  # high-pass filter: subtract a Gaussian blur from the grayscale channel
  gray = img
  gray = gray[None]
  gray = gray.transpose((1,2,0))
  blur = cv2.GaussianBlur(gray, (0, 0), 3)
  gray = gray.reshape((gray.shape[0],gray.shape[1]))
  highPass = gray.astype(int) - blur.astype(int)
  highPass = highPass.astype(float)
  highPass = highPass / 128.0
  return highPass

def normalize_pic(img): # used
  # scale so the maximum value is 1
  img = img / np.max(img)
  return img

def resize_img_1024_3d(img): # used
    # pad each channel into a 1024x1024 canvas and add a batch axis
    zeros = np.zeros((1,3,1024,1024), dtype=float)
    zeros[0 , 0 : img.shape[0] , 0 : img.shape[1] , 0 : img.shape[2]] = img
    return zeros.transpose((1,2,3,0)) # BxCxHxW to CxHxWxB

def show_active_img_and_save_denoise(img, dir_path):
    # invert, rescale to [0, 255], lightly denoise, resize and write to disk
    mat = img.astype(float)
    mat = -mat + 1
    mat = mat * 255.0
    mat[mat < 0] = 0
    mat[mat > 255] = 255
    mat = mat.astype(np.uint8)
    mat = ndimage.median_filter(mat, 1)
    # mat = cv2.resize(mat, (455, 256))
    mat = cv2.resize(mat, (1024, 576))
    cv2.imwrite(dir_path, mat)
    return mat
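
A quick check of the helpers above (an illustrative sketch on random data): the light map is the grayscale image minus its Gaussian blur, scaled, then normalized so its maximum is 1.

In [ ]:
# illustrative only: synthetic grayscale input
fake_gray = np.random.randint(0, 256, (64, 64)).astype(np.uint8)
lm = normalize_pic(get_light_map_single(fake_gray))
print(lm.shape, round(float(lm.min()), 3), round(float(lm.max()), 3))  # (64, 64), max == 1.0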
In [ ]:
def get_image(img_path):
  from_mat = cv2.imread(img_path)
  width = float(from_mat.shape[1])
  height = float(from_mat.shape[0])
  new_width = 0
  new_height = 0

  if (width > height):
      from_mat = cv2.resize(from_mat, (1024, int(1024 / width * height)), interpolation=cv2.INTER_AREA)
      new_width = 1024
      new_height = int(1024 / width * height)
  else:
      # mirror the landscape branch: the canvas in resize_img_1024_3d is
      # 1024x1024, so the portrait branch must also scale to 1024 (the
      # original 512 was a leftover and never triggers on 16:9 shots)
      from_mat = cv2.resize(from_mat, (int(1024 / height * width), 1024), interpolation=cv2.INTER_AREA)
      new_width = int(1024 / height * width)
      new_height = 1024

  from_mat = from_mat.transpose((2, 0, 1)) # HxWxC to CxHxW
  light_map = np.zeros(from_mat.shape, dtype=float)

  for channel in range(3):
    light_map[channel] = get_light_map_single(from_mat[channel])

  light_map = normalize_pic(light_map)
  light_map = resize_img_1024_3d(light_map)

  return light_map, new_width, new_height
In [ ]:
def get_line_art(line_mat, dir_path, new_width, new_height):
  line_mat = line_mat.transpose((3, 1, 2, 0))[0] # BxHxWxC to CxHxWxB, keep the single output channel -> HxWxB
  line_mat = line_mat[0:int(new_height), 0:int(new_width), :]

  image = np.amax(line_mat, 2)

  show_active_img_and_save_denoise(image, dir_path)
In [ ]:
folders = glob("final_shots/*")
In [ ]:
!mkdir line_art_shots
In [ ]:
for folder in folders:
  name = folder.split("/")[1]
  new_dir = f"line_art_shots/{name}"
  os.mkdir(new_dir)
  images = glob(f"./final_shots/{name}/*")

  img_stack = np.empty((len(images) * 3, 1024, 1024, 1))

  for idx, image in enumerate(images):
    # new_width/new_height from the last frame are reused below; this
    # assumes every frame of a shot has the same resolution
    img, new_width, new_height = get_image(image)
    img_stack[idx * 3:(idx + 1) * 3,] = img

  # print('\r', len(images), end='')

  for index, image in enumerate(images):
    line_mat_stack = sketchKeras.predict(img_stack[index * 3:(index + 1) * 3,], verbose=0)
    image_name = image.split("/")[-1]
    dir_path = f"{new_dir}/{image_name}"

    get_line_art(line_mat_stack, dir_path, new_width, new_height)
In [ ]:
# from google.colab import output
# output.eval_js('new Audio("https://upload.wikimedia.org/wikipedia/commons/0/05/Beep-09.ogg").play()')
In [ ]:
line_art_folders = glob("line_art_shots/*")
In [ ]:
!mkdir distance_map_shots
In [ ]:
def binarize(sketch, threshold=127):
    return tf.where(sketch < threshold, x=tf.zeros_like(sketch), y=tf.ones_like(sketch) * 255.)
In [ ]:
def get_distance_map(image_path, dir):
  img = tf.keras.preprocessing.image.load_img(image_path, color_mode="grayscale")
  img = tf.keras.preprocessing.image.img_to_array(img)

  sketch = binarize(img)
  a = tf.cast(sketch, tf.uint8).numpy()
  a = a[:, :, 0]

  distance = ndimage.distance_transform_edt(a)
  distance = distance / tf.reduce_max(distance)

  final = (distance + (img[:,:,0] / 255.0) / 12)

  final = tf.image.resize(tf.expand_dims(final, axis=-1), [256, 455])

  tf.keras.preprocessing.image.save_img(dir, final)  
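
The same idea on a tiny synthetic "sketch" (an illustrative toy, not a real frame): pixels on the line binarize to 0, and the Euclidean distance transform then measures how far every other pixel is from the nearest line.

In [ ]:
# illustrative only: an 8x8 white canvas with one horizontal black line
toy = np.full((8, 8), 255.0)
toy[4, :] = 0.0
binary = binarize(toy[..., None]).numpy()[..., 0]
dist = ndimage.distance_transform_edt(binary)
print(np.round(dist / dist.max(), 2))  # rows grow linearly away from the line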
In [ ]:
for folder in line_art_folders:
  name = folder.split("/")[1]
  new_dir = f"distance_map_shots/{name}"
  os.mkdir(new_dir)
  images = glob(f"./line_art_shots/{name}/*")

  print('\r', len(images), end='')

  for image in images:
    image_name = image.split("/")[-1]
    dir_path = f"{new_dir}/{image_name}"

    get_distance_map(image, dir_path)
 8

Dataset

In [ ]:
name_folders = os.listdir("line_art_shots")
In [ ]:
size = len(glob("line_art_shots/*/*"))
In [ ]:
size
Out[ ]:
4527
In [ ]:
class TestDataGenerator():
  def __init__(self):
    self.batch_size = 4
    self.image_shape = (256, 455, 3)

  def getitem(self, folders):
    X = self.__get_data(folders)
    return X

  def get_random_positions(self, images_name):
    limit = len(images_name) - 4
    position = np.random.randint(0, limit)

    return [images_name[position], images_name[position + 2], images_name[position + 4]]
  
  def load_image(self, image_path, resize=False):
    img = tf.keras.preprocessing.image.load_img(image_path)
    img = tf.keras.preprocessing.image.img_to_array(img)

    if resize:
      img = tf.image.resize(img, [256, 455])

    img = img / 255.0
    
    return img

  def __get_data(self, batch):
    reference_color_images_0 = np.empty((self.batch_size, *self.image_shape), dtype=np.float32)
    reference_line_art_images_0 = np.empty((self.batch_size, *self.image_shape), dtype=np.float32)
    reference_distance_maps_0 = np.empty((self.batch_size, *self.image_shape), dtype=np.float32)

    reference_color_images_1 = np.empty((self.batch_size, *self.image_shape), dtype=np.float32)
    reference_line_art_images_1 = np.empty((self.batch_size, *self.image_shape), dtype=np.float32)
    reference_distance_maps_1 = np.empty((self.batch_size, *self.image_shape), dtype=np.float32)

    middle_color = np.empty((self.batch_size, *self.image_shape))
    middle_line_art = np.empty((self.batch_size, *self.image_shape))
    middle_distance_map = np.empty((self.batch_size, *self.image_shape))
    
    for i, shot_path in enumerate(batch):
      images_name = sorted(os.listdir(f"resizedshots/{shot_path}"))
      selected_images = self.get_random_positions(images_name)
      
      color_images_paths = [f"resizedshots/{shot_path}/{s_i}" for s_i in selected_images]
      line_art_images_paths = [f"line_art_shots/{shot_path}/{s_i}" for s_i in selected_images]
      distance_map_images_paths = [f"distance_map_shots/{shot_path}/{s_i}" for s_i in selected_images]
      
      reference_color_images_0[i,] = self.load_image(color_images_paths[0])
      reference_line_art_images_0[i,] = self.load_image(line_art_images_paths[0], resize=True)
      reference_distance_maps_0[i,] = self.load_image(distance_map_images_paths[0])

      middle_color[i,] = self.load_image(color_images_paths[1])
      middle_line_art[i,] = self.load_image(line_art_images_paths[1], resize=True)
      middle_distance_map[i,] = self.load_image(distance_map_images_paths[1])

      reference_color_images_1[i,] = self.load_image(color_images_paths[2])
      reference_line_art_images_1[i,] = self.load_image(line_art_images_paths[2], resize=True)
      reference_distance_maps_1[i,] = self.load_image(distance_map_images_paths[2])

    # concatenate and crop once, after the loop has filled the arrays
    concated_image = tf.concat([reference_color_images_0,
                                reference_line_art_images_0,
                                reference_distance_maps_0,
                                middle_color,
                                middle_line_art,
                                middle_distance_map,
                                reference_color_images_1,
                                reference_line_art_images_1,
                                reference_distance_maps_1], axis=0)

    cropped_image = tf.image.crop_to_bounding_box(concated_image, 0, 99, 256, 256)

    rf_0 = cropped_image[0:4], cropped_image[4:8], cropped_image[8:12]
    mid = cropped_image[12:16], cropped_image[16:20], cropped_image[20:24]
    rf_1 = cropped_image[24:28], cropped_image[28:32], cropped_image[32:36]

    #y_small = tf.image.resize(mid[0], [64, 64], method=tf.image.ResizeMethod.BILINEAR)

    return rf_0, mid, rf_1#, y_small
In [ ]:
class InferenceDataGenerator(tf.keras.utils.Sequence):
  def __init__(self, folders_paths, batch_size=4):
    self.image_shape = (256, 455, 3)
    self.batch_size = batch_size
    self.folders_paths = folders_paths
    # no shuffling for inference (the original assigned an undefined `shuffle`)
    self.on_epoch_end()

  def __len__(self):
    return len(self.folders_paths) // self.batch_size

  def __getitem__(self, index):
    index = self.index[index * self.batch_size:(index + 1) * self.batch_size]
    batch = [self.folders_paths[k] for k in index]

    X = self.__get_data(batch)
    return X

  def on_epoch_end(self):
    self.index = np.arange(len(self.folders_paths))
  
  def load_image(self, image_path, resize=False):
    img = tf.keras.preprocessing.image.load_img(image_path)
    img = tf.keras.preprocessing.image.img_to_array(img)
    
    if resize:
      img = tf.image.resize(img, [256, 455])

    img = img / 255.0
    
    return img

  def __get_data(self, batch):
    reference_color_images_0 = np.empty((self.batch_size, *self.image_shape), dtype=np.float32)
    reference_line_art_images_0 = np.empty((self.batch_size, *self.image_shape), dtype=np.float32)
    reference_distance_maps_0 = np.empty((self.batch_size, *self.image_shape), dtype=np.float32)

    reference_color_images_1 = np.empty((self.batch_size, *self.image_shape), dtype=np.float32)
    reference_line_art_images_1 = np.empty((self.batch_size, *self.image_shape), dtype=np.float32)
    reference_distance_maps_1 = np.empty((self.batch_size, *self.image_shape), dtype=np.float32)

    middle_line_art = np.empty((self.batch_size, *self.image_shape))
    middle_distance_map = np.empty((self.batch_size, *self.image_shape))
    
    for i, shot_path in enumerate(batch):
      selected_images = sorted(os.listdir(f"test_shots/color/{shot_path}"))[:4]
      
      color_images_paths = [f"test_shots/color/{shot_path}/{s_i}" for s_i in [selected_images[0], selected_images[-1]]]
      line_art_images_paths = [f"test_shots/line_art_shots/{shot_path}/{s_i}" for s_i in selected_images]
      distance_map_images_paths = [f"test_shots/distance_map_shots/{shot_path}/{s_i}" for s_i in selected_images]
      
      reference_color_images_0[i,] = self.load_image(color_images_paths[0])
      reference_line_art_images_0[i,] = self.load_image(line_art_images_paths[0], resize=True)
      reference_distance_maps_0[i,] = self.load_image(distance_map_images_paths[0])

      middle_line_art[i,] = self.load_image(line_art_images_paths[1], resize=True)
      middle_distance_map[i,] = self.load_image(distance_map_images_paths[1])

      reference_color_images_1[i,] = self.load_image(color_images_paths[1])
      reference_line_art_images_1[i,] = self.load_image(line_art_images_paths[2], resize=True)
      reference_distance_maps_1[i,] = self.load_image(distance_map_images_paths[2])

    # concatenate once after the loop; this generator has no middle_color
    # (the middle frame's color is what we infer), so only 8 groups are stacked
    concated_image = tf.concat([reference_color_images_0,
                                reference_line_art_images_0,
                                reference_distance_maps_0,
                                middle_line_art,
                                middle_distance_map,
                                reference_color_images_1,
                                reference_line_art_images_1,
                                reference_distance_maps_1], axis=0)

    cropped_image = tf.image.crop_to_bounding_box(concated_image, 0, 99, 256, 256)

    rf_0 = cropped_image[0:(self.batch_size)], cropped_image[(self.batch_size):(self.batch_size * 2)], cropped_image[(self.batch_size * 2):(self.batch_size * 3)]
    mid = cropped_image[(self.batch_size * 3):(self.batch_size * 4)], cropped_image[(self.batch_size * 4):(self.batch_size * 5)]
    rf_1 = cropped_image[(self.batch_size * 5):(self.batch_size * 6)], cropped_image[(self.batch_size * 6):(self.batch_size * 7)], cropped_image[(self.batch_size * 7):(self.batch_size * 8)]

    return rf_0, mid, rf_1
In [ ]:
class DataGenerator(tf.keras.utils.Sequence):
  def __init__(self, folders_paths, batch_size=4, shuffle=True):
    self.image_shape = (256, 455, 3)
    self.batch_size = batch_size
    self.folders_paths = folders_paths
    self.shuffle = shuffle
    self.on_epoch_end()

  def __len__(self):
    return len(self.folders_paths) // self.batch_size

  def __getitem__(self, index):
    index = self.index[index * self.batch_size:(index + 1) * self.batch_size]
    batch = [self.folders_paths[k] for k in index]

    X = self.__get_data(batch)
    return X

  def on_epoch_end(self):
    self.index = np.arange(len(self.folders_paths))
    if self.shuffle:
        np.random.shuffle(self.index)

  def get_random_positions(self, images_name):
    limit = len(images_name) - 4
    position = np.random.randint(0, limit)

    return [images_name[position], images_name[position + 2], images_name[position + 4]]
  
  def load_image(self, image_path, resize=False):
    img = tf.keras.preprocessing.image.load_img(image_path)
    img = tf.keras.preprocessing.image.img_to_array(img)
    
    if resize:
      img = tf.image.resize(img, [256, 455])

    img = img / 255.0
    
    return img

  def __get_data(self, batch):
    reference_color_images_0 = np.empty((self.batch_size, *self.image_shape), dtype=np.float32)
    reference_line_art_images_0 = np.empty((self.batch_size, *self.image_shape), dtype=np.float32)
    reference_distance_maps_0 = np.empty((self.batch_size, *self.image_shape), dtype=np.float32)

    reference_color_images_1 = np.empty((self.batch_size, *self.image_shape), dtype=np.float32)
    reference_line_art_images_1 = np.empty((self.batch_size, *self.image_shape), dtype=np.float32)
    reference_distance_maps_1 = np.empty((self.batch_size, *self.image_shape), dtype=np.float32)

    middle_color = np.empty((self.batch_size, *self.image_shape))
    middle_line_art = np.empty((self.batch_size, *self.image_shape))
    middle_distance_map = np.empty((self.batch_size, *self.image_shape))
    
    for i, shot_path in enumerate(batch):
      images_name = sorted(os.listdir(f"resizedshots/{shot_path}"))
      selected_images = self.get_random_positions(images_name)
      
      color_images_paths = [f"resizedshots/{shot_path}/{s_i}" for s_i in selected_images]
      line_art_images_paths = [f"line_art_shots/{shot_path}/{s_i}" for s_i in selected_images]
      distance_map_images_paths = [f"distance_map_shots/{shot_path}/{s_i}" for s_i in selected_images]
      
      reference_color_images_0[i,] = self.load_image(color_images_paths[0])
      reference_line_art_images_0[i,] = self.load_image(line_art_images_paths[0], resize=True)
      reference_distance_maps_0[i,] = self.load_image(distance_map_images_paths[0])

      middle_color[i,] = self.load_image(color_images_paths[1])
      middle_line_art[i,] = self.load_image(line_art_images_paths[1], resize=True)
      middle_distance_map[i,] = self.load_image(distance_map_images_paths[1])

      reference_color_images_1[i,] = self.load_image(color_images_paths[2])
      reference_line_art_images_1[i,] = self.load_image(line_art_images_paths[2], resize=True)
      reference_distance_maps_1[i,] = self.load_image(distance_map_images_paths[2])

    # concatenate and crop once, after the loop has filled the arrays
    concated_image = tf.concat([reference_color_images_0,
                                reference_line_art_images_0,
                                reference_distance_maps_0,
                                middle_color,
                                middle_line_art,
                                middle_distance_map,
                                reference_color_images_1,
                                reference_line_art_images_1,
                                reference_distance_maps_1], axis=0)

    # cropped_image = tf.image.random_crop(concated_image, size=[self.batch_size * 9, 256, 256, 3])

    cropped_image = tf.image.crop_to_bounding_box(concated_image, 0, 99, 256, 256)

    rf_0 = cropped_image[0:4], cropped_image[4:8], cropped_image[8:12]
    mid = cropped_image[12:16], cropped_image[16:20], cropped_image[20:24]
    rf_1 = cropped_image[24:28], cropped_image[28:32], cropped_image[32:36]

    # y_small = tf.image.resize(mid[0], [64, 64], method=tf.image.ResizeMethod.BILINEAR)

    return rf_0, mid, rf_1#, y_small
In [ ]:
train_generator = DataGenerator(name_folders)
In [ ]:
test_generator = TestDataGenerator()
In [ ]:
reference_0, middle, reference_1 = train_generator.__getitem__(44)
In [ ]:
reference_0[0][0].dtype, middle[0][0].dtype, reference_1[0][0].dtype
Out[ ]:
(tf.float32, tf.float32, tf.float32)
In [ ]:
_, ax = plt.subplots(nrows=1, ncols=3, figsize=(12, 10))
ax[0].set_title("reference_0")
ax[0].imshow(reference_0[0][0])

ax[1].set_title("middle")
ax[1].imshow(middle[0][0])

ax[2].set_title("reference_1")
ax[2].imshow(reference_1[0][0])
Out[ ]:
<matplotlib.image.AxesImage at 0x7f6c10132940>
In [ ]:
_, ax = plt.subplots(nrows=1, ncols=3, figsize=(12, 10))
ax[0].set_title("reference_0")
ax[0].imshow(reference_0[1][0])

ax[1].set_title("middle")
ax[1].imshow(middle[1][1])

ax[2].set_title("reference_1")
ax[2].imshow(reference_1[1][2])
Out[ ]:
<matplotlib.image.AxesImage at 0x7f6bc2009e10>
In [ ]:
_, ax = plt.subplots(nrows=1, ncols=3, figsize=(12, 10))
ax[0].set_title("reference_0")
ax[0].imshow(reference_0[2][0])

ax[1].set_title("middle")
ax[1].imshow(middle[2][0])

ax[2].set_title("reference_1")
ax[2].imshow(reference_1[2][0])
Out[ ]:
<matplotlib.image.AxesImage at 0x7f6bb02c5a90>

Spectral Normalization

In [ ]:
class SpectralNormalization(layers.Wrapper):
  def __init__(self, layer, iteration=1, eps=1e-12, training=True, **kwargs):
    super(SpectralNormalization, self).__init__(layer, **kwargs)
    self.iteration = iteration
    self.eps = eps
    self.do_power_iteration = training
  
  def build(self, input_shape):
    self.layer.build(input_shape)
    self.w = self.layer.kernel
    self.w_shape = self.w.shape.as_list()

    self.v = self.add_weight(shape=(1, tf.math.reduce_prod(self.w_shape[:-1])),
                                    initializer=tf.initializers.TruncatedNormal(stddev=0.02),
                                    trainable=False,
                                    name='sn_v',
                                    dtype=tf.float32)
    
    self.u = self.add_weight(shape=(1, self.w_shape[-1]),
                              initializer=tf.initializers.TruncatedNormal(stddev=0.02),
                              trainable=False,
                              name='sn_u',
                              dtype=tf.float32)
  
    super(SpectralNormalization, self).build()
  
  def call(self, layer_inputs):
    self.update_weights()
    output = self.layer(layer_inputs)
    # self.restore_weights()

    return output

  def update_weights(self):
    w_reshaped = tf.reshape(self.w, [-1, self.w_shape[-1]])
    
    u_hat = self.u
    v_hat = self.v

    if self.do_power_iteration:
      for _ in range(self.iteration):
        v_ = tf.matmul(u_hat, tf.transpose(w_reshaped))
        v_hat = v_ / (tf.reduce_sum(v_ ** 2) ** 0.5 + self.eps)

        u_ = tf.matmul(v_hat, w_reshaped)
        u_hat = u_ / (tf.reduce_sum(u_**2)**0.5 + self.eps)


    sigma = tf.matmul(tf.matmul(v_hat, w_reshaped), tf.transpose(u_hat))
    self.u.assign(u_hat)
    self.v.assign(v_hat)

    self.layer.kernel.assign(self.w / sigma)
      
  def restore_weights(self):
    self.layer.kernel.assign(self.w)
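
A minimal check of the wrapper (an illustrative sketch, not from the original notebook): repeatedly calling a spectrally normalized layer rescales its kernel so the largest singular value approaches 1.

In [ ]:
# illustrative only: wrap a small Conv2D and inspect the kernel's top singular value
sn_layer = SpectralNormalization(layers.Conv2D(8, (3, 3), padding="same"))
for _ in range(10):
  _ = sn_layer(tf.random.normal([1, 16, 16, 4]))  # each call runs one power iteration
w = tf.reshape(sn_layer.layer.kernel, [-1, 8])
print(float(tf.linalg.svd(w, compute_uv=False)[0]))  # ~1.0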

Temporal Constraint Network

LearnableTSM

In [ ]:
class LearnableTSM(tf.keras.Model):
  def __init__(self):
    super(LearnableTSM, self).__init__()
    self.shift_ratio = 0.5
    self.shift_groups = 2
    self.shift_width = 3

    pre_weights = tf.constant([0.0, 0.0, 1.0], dtype=tf.float32)
    pre_weights = tf.reshape(pre_weights, [3, 1, 1, 1, 1])

    post_weights = tf.constant([1.0, 0.0, 0.0], dtype=tf.float32)
    post_weights = tf.reshape(post_weights, [3, 1, 1, 1, 1])

    self.pre_shift_conv = layers.Conv3D(1, [3, 1, 1], use_bias=False, padding="same", weights=[pre_weights])
    self.post_shift_conv = layers.Conv3D(1, [3, 1, 1], use_bias=False, padding="same", weights=[post_weights])

    # at least 3 kernels so we obtain T=3 ?

  def apply_tsm(self, tensor, conv):
    B, T, H, W, C = tensor.shape

    tensor = tf.transpose(tensor, [0, 4, 1, 2, 3])
    tensor = conv(tf.reshape(tensor, [B * C, T, H, W, 1]))
    tensor = tf.reshape(tensor, [B, C, T, H, W])
    tensor = tf.transpose(tensor, [0, 2, 3, 4, 1])

    return tensor

  def call(self, input_tensor):
    shape = B, T, H, W, C = input_tensor.shape
    split_size = int(C * self.shift_ratio) // self.shift_groups

    split_sizes = [split_size] * self.shift_groups + [C - split_size * self.shift_groups]
    tensors = tf.split(input_tensor, split_sizes, -1)
    assert len(tensors) == self.shift_groups + 1

    # we pass all the images here(full batch) but each image only contains a part of its channels
    tensor_1 = self.apply_tsm(tensors[0], self.pre_shift_conv)
    tensor_2 = self.apply_tsm(tensors[1], self.post_shift_conv)

    final_tensor = tf.concat([tensor_1, tensor_2, tensors[2]], -1)
    final_tensor = tf.reshape(final_tensor, shape)

    # tf.keras.layers.Reshape
    
    return final_tensor
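
A toy run of the shift (an illustrative sketch with assumed shapes): with the fixed [0, 0, 1] / [1, 0, 0] kernels, the first channel group of frame t receives frame t+1 and the second group receives frame t-1, while the remaining channels pass through unchanged.

In [ ]:
# illustrative only: B=1, T=3, H=W=4, C=4, so the split sizes are [1, 1, 2]
tsm = LearnableTSM()
x = tf.random.normal([1, 3, 4, 4, 4])
y = tsm(x)
print(y.shape)  # (1, 3, 4, 4, 4)
print(np.allclose(y[0, 0, ..., 0], x[0, 1, ..., 0], atol=1e-5))  # group 0 shifted forward
print(np.allclose(y[0, 1, ..., 1], x[0, 0, ..., 1], atol=1e-5))  # group 1 shifted back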

Gated Conv

In [ ]:
class GatedConv(tf.keras.Model):
  def __init__(self, kernels, kernel_size, strides, dilation=(1, 1)):
    super(GatedConv, self).__init__()

    self.learnableTSM = LearnableTSM()
    self.feature_conv = SpectralNormalization(layers.Conv2D(kernels, kernel_size, strides=strides, padding="same", dilation_rate=dilation))

    self.gate_conv = SpectralNormalization(layers.Conv2D(kernels, kernel_size, strides=strides, padding="same", dilation_rate=dilation))

    self.activation = layers.LeakyReLU(0.2)
  
  def call(self, input_tensor):
    B, T, H, W, C = input_tensor.shape
    xs = tf.split(input_tensor, num_or_size_splits=T, axis=1)
    gating = tf.stack([self.gate_conv(tf.squeeze(x, axis=1)) for x in xs], axis=1)
    gating = tf.keras.activations.sigmoid(gating)

    feature = self.learnableTSM(input_tensor)
    # shape B, T, H, W, C

    feature = self.feature_conv(tf.reshape(feature, [B * T, H, W, C]))
    _, H_, W_, C_ = feature.shape
    feature = tf.reshape(feature, [B, T, H_, W_, C_])
    feature = self.activation(feature)

    out = gating * feature

    return out
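
A shape check (illustrative): the gated convolution keeps the temporal axis while striding the spatial axes, and the output channel count follows the `kernels` argument.

In [ ]:
# illustrative only
gc = GatedConv(8, (3, 3), (2, 2))
print(gc(tf.random.normal([1, 3, 16, 16, 4])).shape)  # (1, 3, 8, 8, 8)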
In [ ]:
class GatedDeConv(tf.keras.Model):
  def __init__(self, kernels):
    super(GatedDeConv, self).__init__()
    self.gate_conv = GatedConv(kernels, (3, 3), (1, 1))
    self.upsampling = layers.UpSampling3D(size=(1, 2, 2))
  
  def call(self, input_tensor):
    x = self.upsampling(input_tensor)
    x = self.gate_conv(x)

    return x
In [ ]:
# NEW OPTION
class TemporalConstraintNetwork(tf.keras.Model):
  def __init__(self):
    super(TemporalConstraintNetwork, self).__init__()
    self.conv_1 = layers.Conv2D(64, (3, 3), strides=(1, 1), padding="same")
    self.conv_2 = GatedConv(64, (3, 3), (1, 1))
    self.conv_3 = GatedConv(128, (3, 3), (2, 2))
    self.conv_4 = GatedConv(256, (3, 3), (2, 2))

    # the 2/4/8/16 annotations suggest increasing dilation rates,
    # but all four blocks currently use rate (2, 2)
    self.dilation_1 = GatedConv(256, (3, 3), (1, 1), (2, 2)) # 2
    self.dilation_2 = GatedConv(256, (3, 3), (1, 1), (2, 2)) # 4
    self.dilation_3 = GatedConv(256, (3, 3), (1, 1), (2, 2)) # 8
    self.dilation_4 = GatedConv(256, (3, 3), (1, 1), (2, 2)) # 16

    self.conv_5 = GatedConv(256, (3, 3), (1, 1))
    self.up_conv_1 = GatedDeConv(128)
    self.up_conv_2 = GatedDeConv(3)

    # self.activation = layers.Activation("sigmoid")
  
  def call(self, input_tensor):
    x = self.conv_1(input_tensor)
    x = self.conv_2(x) # Bx3x256x256x64
    x_1 = self.conv_3(x) # Bx3x128x128x128
    x_2 = self.conv_4(x_1) # Bx3x64x64x256

    x = self.dilation_1(x_2)
    x = self.dilation_2(x)
    x = self.dilation_3(x)
    x = self.dilation_4(x) # Bx3x64x64x256

    x = self.conv_5(x) # Bx3x64x64x256
    x = layers.concatenate([x, x_2], axis=-1) # channels-last 5D tensors, so the feature axis is -1
    x = self.up_conv_1(x)
    x = layers.concatenate([x, x_1], axis=-1) # channels-last, feature axis is -1
    x = self.up_conv_2(x)
    # x = self.activation(x)

    return x

Temporal Constraint Discriminator

In [ ]:
class TemporalConstraintDiscriminator(tf.keras.Model):
  def __init__(self):
    super(TemporalConstraintDiscriminator, self).__init__()
    self.conv_1 = SpectralNormalization(layers.Conv3D(64, (1, 3, 3), (1, 2, 2), padding='same'))
    self.conv_2 = SpectralNormalization(layers.Conv3D(128, (1, 3, 3), (1, 2, 2), padding='same'))
    self.conv_3 = SpectralNormalization(layers.Conv3D(256, (1, 3, 3), (1, 2, 2), padding='same'))
    self.conv_4 = SpectralNormalization(layers.Conv3D(512, (1, 3, 3), (1, 2, 2), padding='same'))
    self.conv_5 = SpectralNormalization(layers.Conv3D(1, (1, 3, 3), (1, 2, 2), padding='same'))
  
  def call(self, input_tensor):
    x = self.conv_1(input_tensor)
    x = self.conv_2(x)
    x = self.conv_3(x)
    x = self.conv_4(x)
    x = self.conv_5(x)

    return x

Encoders

In [ ]:
class ColorEncoder(tf.keras.Model):
  def __init__(self):
    super(ColorEncoder, self).__init__()
    self.conv_1 = SpectralNormalization(layers.Conv2D(64, (3, 3), strides=(1, 1), padding="same"))
    self.conv_2 = SpectralNormalization(layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"))
    self.conv_3 = SpectralNormalization(layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"))
    # self.conv_1 = layers.Conv2D(64, (3, 3), strides=(1, 1), padding="same")
    # self.conv_2 = layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same")
    # self.conv_3 = layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same")

    self.norm_1 = tfa.layers.InstanceNormalization()
    self.norm_2 = tfa.layers.InstanceNormalization()
    self.norm_3 = tfa.layers.InstanceNormalization()

    self.activation = layers.ReLU()
  
  def call(self, x):
    x = self.conv_1(x) # 256
    x = self.norm_1(x)
    x = self.activation(x)
    x = self.conv_2(x) # 128
    x = self.norm_2(x)
    x = self.activation(x)
    x = self.conv_3(x) # 64
    x = self.norm_3(x)
    x = self.activation(x)

    return x # output Bx64x64x256
In [ ]:
class LineArtEncoder(tf.keras.Model):
  def __init__(self):
    super(LineArtEncoder, self).__init__()
    self.conv_1 = SpectralNormalization(layers.Conv2D(64, (3, 3), strides=(1, 1), padding="same"))
    self.conv_2 = SpectralNormalization(layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"))
    self.conv_3 = SpectralNormalization(layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"))
    # self.conv_1 = layers.Conv2D(64, (3, 3), strides=(1, 1), padding="same")
    # self.conv_2 = layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same")
    # self.conv_3 = layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same")

    self.norm_1 = tfa.layers.InstanceNormalization()
    self.norm_2 = tfa.layers.InstanceNormalization()
    self.norm_3 = tfa.layers.InstanceNormalization()

    self.activation = layers.ReLU()
  
  def call(self, x):
    x = self.conv_1(x) # 256
    x = self.norm_1(x)
    x = self.activation(x)
    x = self.conv_2(x) # 128
    x = self.norm_2(x)
    x = self.activation(x)
    x = self.conv_3(x) # 64
    x = self.norm_3(x)
    x = self.activation(x)

    return x # output Bx64x64x256
In [ ]:
class DistanceMapEncoder(tf.keras.Model):
  def __init__(self):
    super(DistanceMapEncoder, self).__init__()
    self.conv_1 = SpectralNormalization(layers.Conv2D(64, (3, 3), strides=(1, 1), padding="same"))
    self.conv_2 = SpectralNormalization(layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"))
    self.conv_3 = SpectralNormalization(layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"))

    self.norm_1 = tfa.layers.InstanceNormalization()
    self.norm_2 = tfa.layers.InstanceNormalization()
    self.norm_3 = tfa.layers.InstanceNormalization()

    self.activation = layers.ReLU()
  
  def call(self, x):
    x = self.conv_1(x) # 256
    x = self.norm_1(x)
    x = self.activation(x)
    x = self.conv_2(x) # 128
    x = self.norm_2(x)
    x = self.activation(x)
    x = self.conv_3(x) # 64
    x = self.norm_3(x)
    x = self.activation(x)

    return x # output Bx64x64x256
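
A shape check (illustrative): all three encoders share the same topology, downsampling a 256x256 image twice into a 64x64x256 feature map.

In [ ]:
# illustrative only
enc = ColorEncoder()
print(enc(tf.random.normal([1, 256, 256, 3])).shape)  # (1, 64, 64, 256)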

Decoder

In [ ]:
class Decoder(tf.keras.Model):
  def __init__(self):
    super(Decoder, self).__init__()
    # SpectralNormalization In decoder
    self.conv_1 = SpectralNormalization(layers.Conv2D(256, (3, 3), strides=(1, 1), padding="same"))
    self.conv_2 = SpectralNormalization(layers.Conv2D(128, (3, 3), strides=(1, 1), padding="same"))
    self.conv_3 = SpectralNormalization(layers.Conv2D(64, (3, 3), strides=(1, 1), padding="same"))
    self.conv_4 = SpectralNormalization(layers.Conv2D(3, (3, 3), strides=(1, 1), padding="same", activation="sigmoid"))

    self.norm_1 = tfa.layers.InstanceNormalization()
    self.norm_2 = tfa.layers.InstanceNormalization()
    self.norm_3 = tfa.layers.InstanceNormalization()

    self.upsampling = layers.UpSampling2D(size=(2, 2))

    self.activation = layers.ReLU()
  
  def call(self, x):
    x = self.conv_1(x) # 256
    x = self.norm_1(x)
    x = self.activation(x)
    x = self.upsampling(x)
    x = self.conv_2(x) # 128
    x = self.norm_2(x)
    x = self.activation(x)
    x = self.upsampling(x)
    x = self.conv_3(x) # 64
    x = self.norm_3(x)
    x = self.activation(x)
    x = self.conv_4(x) # 3

    return x # output Bx256x256x3
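
A shape check (illustrative): the decoder inverts the encoders, upsampling a 64x64x256 feature map back to a 256x256 RGB image with sigmoid outputs in [0, 1].

In [ ]:
# illustrative only
dec = Decoder()
print(dec(tf.random.normal([1, 64, 64, 256])).shape)  # (1, 256, 256, 3)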

Color Transform Layer

In [ ]:
class CreateMasks(tf.keras.Model):
  def __init__(self):
    super(CreateMasks, self).__init__()
    self.conv_m = layers.Conv2D(256, (3, 3), padding="same")
    self.conv_n = layers.Conv2D(256, (3, 3), padding="same")
  
  def call(self, inputs):
    target_distance_map = inputs[0]
    reference_distance = inputs[1]
    # tensor_input = tf.concat([target_distance_map, reference_distance], -1) # 64x64x512
    tensor_input = layers.Concatenate(axis=-1)([target_distance_map, reference_distance])

    m = self.conv_m(tensor_input)
    m = tf.keras.activations.sigmoid(m)

    n = self.conv_n(tensor_input)
    n = tf.keras.activations.sigmoid(n)

    return m, n
In [ ]:
class LeftPart(tf.keras.Model):
  def __init__(self):
    super(LeftPart, self).__init__()
    kernels = 256 // 8 # integer division: Conv2D filters must be an int
    self.conv = layers.Conv2D(kernels, (1, 1), padding="same")
  
  # def call(self, target_distance_map, reference_distance_feat):
  def call(self, inputs):
    target_distance_map = inputs[0]
    reference_distance_feat = inputs[1]
    reference_distance_x = self.conv(reference_distance_feat)
    target_distance_map_x = self.conv(target_distance_map)
    
    B, H, W, C = target_distance_map_x.shape

    # reference_distance_x = tf.reshape(reference_distance_x, [B, H * W, C]) # BxHWxC. # reshape H * W * C ???
    # target_distance_map_x = tf.reshape(target_distance_map_x, [B, H * W, C]) # BxHWxC

    reference_distance_x = layers.Reshape([H * W, C])(reference_distance_x)
    target_distance_map_x = layers.Reshape([H * W, C])(target_distance_map_x)
     
    M = tf.linalg.matmul(target_distance_map_x, reference_distance_x, transpose_b=True) #BxHWxHW
    # Multiplies each batch element separately

    return M
In [ ]:
class RightPart(tf.keras.Model):
  def __init__(self):
    super(RightPart, self).__init__()
    kernels = 256 // 8 # integer division: Conv2D filters must be an int
    self.conv = layers.Conv2D(kernels, (1, 1), padding="same")
  
  def call(self, inputs):
    m = inputs[0]
    reference_color = inputs[1]
    fm = reference_color * m # like attention

    x = self.conv(fm)
    
    B, H, W, C = x.shape

    # x = tf.reshape(x, [B, H * W, C]) # BxHWxC
    x = layers.Reshape([H * W, C])(x) # BxHWxC
    x = tf.transpose(x, [0, 2, 1]) # BxCxHW

    return x, fm
In [ ]:
class ColorTransformLayer(tf.keras.Model):
  def __init__(self):
    super(ColorTransformLayer, self).__init__()
    self.lp = LeftPart()
    self.rp = RightPart()
    self.get_masks = CreateMasks()
    self.conv = layers.Conv2D(256, (1, 1), padding="same")
  
  def call(self, inputs):
    target_distance_map = inputs[0]
    reference_distance_0 = inputs[1]
    reference_distance_1 = inputs[2]
    reference_color_0 = inputs[3]
    reference_color_1 = inputs[4]

    # target_distance_map, reference_distance_0, reference_distance_1, reference_color_0, reference_color_1
    B, H, W, _ = target_distance_map.shape

    M_0 = self.lp([target_distance_map, reference_distance_0]) #HWxHW
    M_1 = self.lp([target_distance_map, reference_distance_1]) #HWxHW

    # matching_matrix = tf.concat([M_0, M_1], 1)
    matching_matrix = layers.Concatenate(axis=1)([M_0, M_1])
    matching_matrix = tf.keras.activations.softmax(matching_matrix) # Bx(K*HW)xHW with K = 2 references

    small_m_0, n_0 = self.get_masks([target_distance_map, reference_distance_0])
    small_m_1, n_1 = self.get_masks([target_distance_map, reference_distance_1])

    c_0, fm_0 = self.rp([small_m_0, reference_color_0]) #BxCxHW
    c_1, fm_1 = self.rp([small_m_1, reference_color_1]) #BxCxHW

    # reference_color_matrix = tf.concat([c_0, c_1], -1) #BxCxKHW
    reference_color_matrix = layers.Concatenate(axis=-1)([c_0, c_1])

    f_mat = tf.linalg.matmul(reference_color_matrix, matching_matrix) #BxCxHW
    _, C, _ = f_mat.shape

    # f_mat = tf.reshape(f_mat, [B, C, H, W])
    f_mat = layers.Reshape([C, H, W])(f_mat)
    f_mat = tf.transpose(f_mat, [0, 2, 3, 1])

    f_mat = self.conv(f_mat) # BxHxWxC

    f_sim_left = (fm_1 * n_1) + ((n_1 - 1) * f_mat)
    f_sim_right = (fm_0 * n_0) + ((n_0 - 1) * f_mat)

    f_sim = (f_sim_left + f_sim_right) / 2
    # compute mean for each element in the batch

    return f_sim
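
A shape check on small feature maps (an illustrative sketch; the real network feeds the 64x64x256 encoder features): the layer matches the target distance features against both references and returns a blended color feature map of the same spatial size.

In [ ]:
# illustrative only: 16x16 features keep the HWxHW matching matrices small
ctl = ColorTransformLayer()
feats = [tf.random.normal([1, 16, 16, 256]) for _ in range(5)]
print(ctl(feats).shape)  # (1, 16, 16, 256)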

Embedder

In [ ]:
class Embedder(tf.keras.Model):
  def __init__(self):
    super(Embedder, self).__init__()
    self.conv_1 = layers.Conv2D(64, (3, 3), strides=(1, 1), padding="same")
    self.conv_2 = layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same")
    self.conv_3 = layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same")
    self.conv_4 = layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same")
    self.conv_5 = layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same")
  
  def call(self, inputs):
    reference_line_art = inputs[0]
    reference_color = inputs[1]
    # images not features
    # x = tf.concat([reference_line_art, reference_color], -1) # Bx256x256x6
    x = layers.Concatenate(axis=-1)([reference_line_art, reference_color])

    x = self.conv_1(x) # 256
    x = self.conv_2(x) # 128
    x = self.conv_3(x) # 64
    x = self.conv_4(x) # 32
    x = self.conv_5(x) # 16

    x = layers.AveragePooling2D((16, 16))(x) # Bx1x1x512

    # print(x.shape, "Embedder")

    return x
In [ ]:
class SEV(tf.keras.Model):
  def __init__(self):
    super(SEV, self).__init__()
    self.embedder = Embedder()
    self.dense_1 = layers.Dense(512)
    self.dense_2 = layers.Dense(512)
  
  def call(self, inputs):
    reference_line_art_0 = inputs[0]
    reference_color_0 = inputs[1] 
    reference_line_art_1 = inputs[2]
    reference_color_1 = inputs[3]

    latent_vector_0 = self.embedder([reference_line_art_0, reference_color_0])
    latent_vector_1 = self.embedder([reference_line_art_1, reference_color_1])

    x = (latent_vector_0 + latent_vector_1) / 2
    x = self.dense_1(x)
    x = self.dense_2(x)

    # print(x.shape, "SEV")

    return x
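
A shape check (illustrative): the style vector is Bx1x1x512, which the ResBlocks below split into 256 AdaIN means and 256 AdaIN sigmas for their 256-channel features.

In [ ]:
# illustrative only: two (line art, color) reference pairs
sev = SEV()
refs = [tf.random.normal([1, 256, 256, 3]) for _ in range(4)]
print(sev(refs).shape)  # (1, 1, 1, 512)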

AdaIN Normalization

In [ ]:
content_mean, content_variance = tf.nn.moments(tf.random.normal([4, 128, 128, 32]), [0, 1, 2], keepdims=True) # Batch norm
content_mean.shape, content_variance.shape
Out[ ]:
(TensorShape([1, 1, 1, 32]), TensorShape([1, 1, 1, 32]))
In [ ]:
content_mean, content_variance = tf.nn.moments(tf.random.normal([4, 128, 128, 32]), [1, 2], keepdims=True) # Instance norm
content_mean.shape, content_variance.shape
Out[ ]:
(TensorShape([4, 1, 1, 32]), TensorShape([4, 1, 1, 32]))
In [ ]:
content_mean, content_variance = tf.nn.moments(tf.random.normal([4, 128, 128, 32]), [1, 2, 3], keepdims=True) # Layer norm
content_mean.shape, content_variance.shape
Out[ ]:
(TensorShape([4, 1, 1, 1]), TensorShape([4, 1, 1, 1]))
In [ ]:
# def ada_in(x, style_vector, epsilon=1e-5):
#   content_mean, content_variance = tf.nn.moments(x, [1, 2], keepdims=True) # Bx1x1xC
#   content_sigma = tf.sqrt(tf.add(content_variance, epsilon))

#   num_features = x.shape[-1]

#   style_mean = style_vector[:, :, :, :num_features]
#   style_sigma = style_vector[:, :, :, num_features:num_features*2]

#   out = (x - content_mean) / content_sigma
#   out = style_sigma * out + style_mean

#   return out
In [ ]:
class AdaInNormalization(tf.keras.layers.Layer):
  def __init__(self):
    super(AdaInNormalization, self).__init__()
    self.epsilon = 1e-5

  def call(self, x, style_vector):
    content_mean, content_variance = tf.nn.moments(x, [1, 2], keepdims=True) # Bx1x1xC
    content_sigma = tf.sqrt(tf.add(content_variance, self.epsilon))

    num_features = x.shape[-1]

    style_mean = style_vector[:, :, :, :num_features]
    style_sigma = style_vector[:, :, :, num_features:num_features*2]

    out = (x - content_mean) / content_sigma
    out = style_sigma * out + style_mean

    return out
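
A quick check (illustrative): AdaIN whitens x with its per-channel instance statistics and re-colors it with the mean/sigma halves of the style vector, so the style vector must be twice as wide as x's channel count.

In [ ]:
# illustrative only: C=4 features, so the style vector has 8 entries
ada = AdaInNormalization()
x = tf.random.normal([2, 8, 8, 4])
style = tf.random.normal([2, 1, 1, 8])
print(ada(x, style).shape)  # (2, 8, 8, 4)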

ResBlocks

In [ ]:
class ResBlock(tf.keras.Model):
  def __init__(self):
    super(ResBlock, self).__init__()
    self.conv_1 = layers.Conv2D(256, kernel_size=(1, 1), strides=(1, 1), padding='valid')
    self.conv_2 = layers.Conv2D(256, kernel_size=(3, 3), strides=(1, 1), padding='same')
    self.conv_3 = layers.Conv2D(256, kernel_size=(1, 1), strides=(1, 1), padding='valid')

    # new
    self.AdaInLayer = AdaInNormalization()
  
  def call(self, inputs):
    x = inputs[0]
    # ada_norm = inputs[1]
    # style_vector = inputs[2]
    style_vector = inputs[1]

    x_skip = x 

    x = self.conv_1(x)
    x = self.AdaInLayer(x, style_vector)
    # x = ada_norm(x, style_vector)
    x = layers.ReLU()(x)

    x = self.conv_2(x)
    # x = ada_norm(x, style_vector)
    x = self.AdaInLayer(x, style_vector)
    x = layers.ReLU()(x)

    x = self.conv_3(x)
    # x = ada_norm(x, style_vector)
    x = self.AdaInLayer(x, style_vector)

    x = layers.add([x, x_skip])
    x = layers.ReLU()(x)

    # print(x.shape, "RESNET")

    return x

Color Transform Network

In [ ]:
class ColorTransformNetwork(tf.keras.Model):
  def __init__(self):
    super(ColorTransformNetwork, self).__init__()
  
    self.color_encoder = ColorEncoder()
    self.lineart_encoder = LineArtEncoder()
    self.distance_encoder = DistanceMapEncoder()

    self.color_transform_layer = ColorTransformLayer()
    self.sev = SEV()

    self.res_block_1 = ResBlock()
    self.res_block_2 = ResBlock()
    self.res_block_3 = ResBlock()
    self.res_block_4 = ResBlock()
    self.res_block_5 = ResBlock()
    self.res_block_6 = ResBlock()
    self.res_block_7 = ResBlock()
    self.res_block_8 = ResBlock()

    self.sim_conv = layers.Conv2D(3, kernel_size=(1, 1), strides=(1, 1), padding='same')
    self.mid_conv = layers.Conv2D(3, kernel_size=(1, 1), strides=(1, 1), padding='same')

    self.decoder = Decoder()
  
  def call(self, inputs):
    target_line_art_images = inputs[0]
    target_distance_maps = inputs[1]
    reference_color_images_0 = inputs[2]
    reference_line_art_images_0 = inputs[3]
    reference_distance_maps_0 = inputs[4]
    reference_color_images_1 = inputs[5]
    reference_line_art_images_1 = inputs[6]
    reference_distance_maps_1 = inputs[7]

    target_line_art_images_features = self.lineart_encoder(target_line_art_images) # EnL
    target_distance_maps_features = self.distance_encoder(target_distance_maps) # EnD

    reference_distance_maps_0_features = self.distance_encoder(reference_distance_maps_0) # EnD
    reference_distance_maps_1_features = self.distance_encoder(reference_distance_maps_1) # EnD

    reference_color_images_0_features = self.color_encoder(reference_color_images_0) # EnC
    reference_color_images_1_features = self.color_encoder(reference_color_images_1) # EnC

    f_sim = self.color_transform_layer([target_distance_maps_features,
                                  reference_distance_maps_0_features,
                                  reference_distance_maps_1_features,
                                  reference_color_images_0_features,
                                  reference_color_images_1_features])

    style_vector = self.sev([reference_line_art_images_0,
                            reference_color_images_0,
                            reference_line_art_images_1,
                            reference_color_images_1]) # [Batch, 1, 1, 512])

    Y_trans_sim = self.sim_conv(f_sim) # [Batch, 64, 64, 3]
    Y_trans_sim = layers.UpSampling2D(size=(2, 2))(Y_trans_sim)
    Y_trans_sim = layers.UpSampling2D(size=(2, 2))(Y_trans_sim) # [Batch, 256, 256, 3]
    
    res_input = layers.add([target_line_art_images_features, f_sim]) # [Batch, 64, 64, 256]

    x = self.res_block_1([res_input, style_vector])
    x = self.res_block_2([x, style_vector])
    x = self.res_block_3([x, style_vector])
    x = self.res_block_4([x, style_vector])
    x = self.res_block_5([x, style_vector])
    x = self.res_block_6([x, style_vector])
    x = self.res_block_7([x, style_vector])
    x = self.res_block_8([x, style_vector])

    Y_trans_mid = self.mid_conv(x)
    Y_trans_mid = layers.UpSampling2D(size=(2, 2))(Y_trans_mid)
    Y_trans_mid = layers.UpSampling2D(size=(2, 2))(Y_trans_mid)

    Y_trans = self.decoder(x)

    return Y_trans, Y_trans_mid, Y_trans_sim

Color Transform Discriminator

In [ ]:
class ColorTransformDiscriminator(tf.keras.Model):
  def __init__(self):
    super(ColorTransformDiscriminator, self).__init__()
    self.conv_1 = SpectralNormalization(layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"))
    self.conv_2 = SpectralNormalization(layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"))
    self.conv_3 = SpectralNormalization(layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"))
    self.conv_4 = SpectralNormalization(layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"))
    self.conv_5 = SpectralNormalization(layers.Conv2D(1, (3, 3), strides=(2, 2), padding="same"))

    self.activation = layers.LeakyReLU(0.2)
  
  def call(self, line_art, target_y_trans):
    x = tf.concat([line_art, target_y_trans], -1) # Bx256x256x6
    x = self.conv_1(x) # 128
    x = self.activation(x)
    x = self.conv_2(x) # 64
    x = self.activation(x)
    x = self.conv_3(x) # 32
    x = self.activation(x)
    x = self.conv_4(x) # 16
    x = self.activation(x)
    x = self.conv_5(x) # 8

    return x

VGG

In [ ]:
class Vgg19(tf.keras.Model):
  def __init__(self):
    super(Vgg19, self).__init__()
    # avoid shadowing the imported `layers` module inside __init__
    layer_names = ['block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1', 'block5_conv1'] 
    vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
    vgg.trainable = False
    
    outputs = [vgg.get_layer(name).output for name in layer_names]

    self.model = tf.keras.Model([vgg.input], outputs)
  
  def call(self, x):
    x = tf.keras.applications.vgg19.preprocess_input(x * 255.0)
    return self.model(x)

Losses

In [ ]:
def l1_loss(y, y_trans):
  return tf.reduce_mean(tf.abs(y - y_trans))
In [ ]:
def perceptual_loss(y_list, y_trans_list):
  loss = 0
  for feature_map_y, feature_map_y_trans in zip(y_list, y_trans_list):
    loss += tf.reduce_mean(tf.math.abs(feature_map_y - feature_map_y_trans))
  
  return (loss / 5) * 3e-2 # mean over the 5 VGG feature maps, scaled
In [ ]:
def get_gram_matrix(feature_map):
  B, H, W, C = feature_map.shape
  matrix = tf.transpose(feature_map, [0, 3, 1, 2])
  matrix = tf.reshape(matrix, [B, C, H * W])

  num_locations = tf.cast(H * W, tf.float32)

  gram_matrix = tf.linalg.matmul(matrix, matrix, transpose_b=True) # (B,C,HW) @ (B,HW,C) -> (B,C,C)
  gram_matrix = gram_matrix / num_locations

  return gram_matrix
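
A quick check (illustrative): the Gram matrix of a BxHxWxC feature map is BxCxC and symmetric, since it is the channel-by-channel inner product averaged over locations.

In [ ]:
# illustrative only
fm = tf.random.normal([1, 32, 32, 8])
g = get_gram_matrix(fm)
print(g.shape, bool(tf.reduce_all(tf.abs(g - tf.transpose(g, [0, 2, 1])) < 1e-4)))  # (1, 8, 8) True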
In [ ]:
def style_loss(y_list, y_trans_list):
  loss = 0
  for feature_map_y, feature_map_y_trans in zip(y_list, y_trans_list):
    loss += tf.reduce_mean(tf.abs(get_gram_matrix(feature_map_y) - get_gram_matrix(feature_map_y_trans)))
  
  return (loss / 5) * 1e-6
In [ ]:
def latent_constraint_loss(y, y_trans_sim, y_trans_mid):
  loss = tf.reduce_mean(tf.abs(y - y_trans_sim) + tf.abs(y - y_trans_mid))
  return loss
In [ ]:
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
In [ ]:
def generator_loss(y_trans_class):
  return cross_entropy(tf.ones_like(y_trans_class), y_trans_class)
In [ ]:
def compute_color_network_loss(y_trans_class, y, y_trans, y_trans_sim, y_trans_mid, y_list, y_trans_list, lambda_style=1000, lambda_l1=10):
  loss = 0
  gen_loss = generator_loss(y_trans_class)

  loss += gen_loss
  latent_loss = latent_constraint_loss(y, y_trans_sim, y_trans_mid)

  loss += latent_loss
  s_loss = style_loss(y_list, y_trans_list) * lambda_style

  loss += s_loss
  p_loss = perceptual_loss(y_list, y_trans_list)

  loss += p_loss
  l_loss = l1_loss(y, y_trans) * lambda_l1

  loss += l_loss

  return loss, gen_loss, latent_loss, s_loss, p_loss, l_loss
In [ ]:
def compute_discriminator_2d_loss(y_class, y_trans_class):
  real_loss = cross_entropy(tf.ones_like(y_class), y_class)
  fake_loss = cross_entropy(tf.zeros_like(y_trans_class), y_trans_class)
  loss = real_loss + fake_loss

  return loss

Color Transform Network Training

In [ ]:
# tf.keras.backend.clear_session()
In [ ]:
generator_lr = 1e-4
discriminator_lr = 1e-5

color_network_optimizer = tf.keras.optimizers.Adam(learning_rate=generator_lr, beta_1=0.5, beta_2=0.999)
discriminator_2d_optimizer = tf.keras.optimizers.Adam(learning_rate=discriminator_lr, beta_1=0.5, beta_2=0.999)

batch_size = 4
train_steps = len(name_folders) // batch_size
In [ ]:
color_network = ColorTransformNetwork()
discriminator_2d = ColorTransformDiscriminator()
vgg = Vgg19()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5
80142336/80134624 [==============================] - 1s 0us/step
In [ ]:
# color_network.summary()
In [ ]:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(color_network=color_network,
                                 discriminator_2d=discriminator_2d,
                                 color_network_optimizer=color_network_optimizer,
                                 discriminator_2d_optimizer=discriminator_2d_optimizer)
In [ ]:
print(tf.train.latest_checkpoint(checkpoint_dir))
./training_checkpoints/ckpt-350
In [ ]:
# tf.train.list_variables(tf.train.latest_checkpoint(checkpoint_dir))
In [ ]:
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
# checkpoint.restore("./training_checkpoints/ckpt-410")
Out[ ]:
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f03221124d0>
In [ ]:
# !rm -rf training_checkpoints/
In [ ]:
@tf.function
def train_step(x, y0, y1):
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    y_trans, y_trans_mid, y_trans_sim = color_network([x[1], x[2], y0[0], y0[1], y0[2], y1[0], y1[1], y1[2]], training=True)

    y = x[0]

    y_list = vgg(y)
    # y_trans_list = vgg(tf.keras.backend.clip(y_trans, 0, 1))
    y_trans_list = vgg(y_trans)

    y_class = discriminator_2d(x[1], y, training=True)
    y_trans_class = discriminator_2d(x[1], y_trans, training=True)

    color_network_loss, gen_loss, latent_loss, s_loss, p_loss, l_loss = compute_color_network_loss(y_trans_class, y, y_trans, y_trans_sim, y_trans_mid, y_list, y_trans_list)
    discriminator_2d_loss = compute_discriminator_2d_loss(y_class, y_trans_class)
  
  color_network_gradients = gen_tape.gradient(color_network_loss, color_network.trainable_variables)
  discriminator_2d_gradients = disc_tape.gradient(discriminator_2d_loss, discriminator_2d.trainable_variables)

  color_network_optimizer.apply_gradients(zip(color_network_gradients, color_network.trainable_variables))
  discriminator_2d_optimizer.apply_gradients(zip(discriminator_2d_gradients, discriminator_2d.trainable_variables))

  return color_network_loss, gen_loss, latent_loss, s_loss, p_loss, l_loss, discriminator_2d_loss
In [ ]:
train_loss_results = []
generator_loss_results = []
discriminator_loss_results = []

gen_loss_results = []
latent_loss_results = []
s_loss_results = []
p_loss_results = []
l_loss_results = []
In [ ]:
def plot_metrics(train_loss, generator_loss, discriminator_loss, gen_loss, latent_loss, s_loss, p_loss, l_loss):
  fig, ax = plt.subplots(2, 4, figsize=(20, 20))

  ax[0, 0].plot(np.arange(len(train_loss)), train_loss)
  ax[0, 0].set_title('train_loss')

  ax[0, 1].plot(np.arange(len(generator_loss)), generator_loss)
  ax[0, 1].set_title('generator_loss')

  ax[0, 2].plot(np.arange(len(discriminator_loss)), discriminator_loss)
  ax[0, 2].set_title('discriminator_loss')

  ax[0, 3].plot(np.arange(len(gen_loss)), gen_loss)
  ax[0, 3].set_title('gen_loss')

  ax[1, 0].plot(np.arange(len(latent_loss)), latent_loss)
  ax[1, 0].set_title('latent_loss')

  ax[1, 1].plot(np.arange(len(s_loss)), s_loss)
  ax[1, 1].set_title('s_loss')

  ax[1, 2].plot(np.arange(len(p_loss)), p_loss)
  ax[1, 2].set_title('p_loss')

  ax[1, 3].plot(np.arange(len(l_loss)), l_loss)
  ax[1, 3].set_title('l_loss')
In [ ]:
def train(epochs): # `epochs` was undefined in this section, so it is now an explicit argument
  for epoch in range(epochs):
    batch_time = time.time()
    epoch_time = time.time()
    step = 0
    epoch_count = f"{epoch + 1:02d}/{epochs}"

    for reference_0, middle, reference_1 in train_generator:
      color_network_loss, gen_loss, latent_loss, s_loss, p_loss, l_loss, discriminator_2d_loss = train_step(middle, reference_0, reference_1)
      
      color_network_loss = float(color_network_loss)
      discriminator_2d_loss = float(discriminator_2d_loss)
      loss = color_network_loss + discriminator_2d_loss
      step += 1

      print('\r', 'Epoch', epoch_count, '| Step', f"{step}/{train_steps}",
              '| Loss:', f"{loss:.5f}", '| Discriminator loss:', f"{discriminator_2d_loss:.5f}",
             '| Generator loss:', f"{color_network_loss:.5f}", "| Step Time:", f"{time.time() - batch_time:.2f}", end='')    
        
      batch_time = time.time()
      train_loss_results.append(loss)
      generator_loss_results.append(color_network_loss)
      discriminator_loss_results.append(discriminator_2d_loss)
      gen_loss_results.append(float(gen_loss))
      latent_loss_results.append(float(latent_loss))
      s_loss_results.append(float(s_loss))
      p_loss_results.append(float(p_loss))
      l_loss_results.append(float(l_loss))
      
    checkpoint.save(file_prefix=checkpoint_prefix)

    print('\r', 'Epoch', epoch_count, '| Step', f"{step}/{train_steps}",
          '| Loss:', f"{loss:.5f}", '| Discriminator loss:', f"{discriminator_2d_loss:.5f}",
          '| Generator loss:', f"{color_network_loss:.5f}", "| Epoch Time:", f"{time.time() - epoch_time:.2f}")

Train with no spectral normalization

In [ ]:
reference_0, middle, reference_1 = test_generator.getitem(["shot67", "shot165", "shot193", "shot244"])

y_trans, y_trans_mid, y_trans_sim = color_network([middle[1], middle[2],
                                                  reference_0[0], reference_0[1],
                                                  reference_0[2], reference_1[0], 
                                                  reference_1[1], reference_1[2]])

_, ax = plt.subplots(nrows=1, ncols=4, figsize=(16, 10))
ax[0].set_title("Image 1")
ax[0].imshow(y_trans[0])

ax[1].set_title("Image 2")
ax[1].imshow(y_trans[1])

ax[2].set_title("image 3")
ax[2].imshow(y_trans[2])

ax[3].set_title("image 4")
ax[3].imshow(y_trans[3])
Out[ ]:
<matplotlib.image.AxesImage at 0x7f81a980f850>