#!/usr/bin/env python
# coding: utf-8
# # Classifying all the Tribune images
#
# Using the simple model we created in [this notebook](Training-a-classification-model-for-the-Tribune.ipynb), let's attempt to classify all the images in the Tribune collection.
# In[65]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import base64
import copy
import functools
import os
import sys
import time
from urllib.parse import urlparse

import numpy as np
import pandas as pd
import requests
import tensorflow as tf
from IPython.display import display, HTML
from tqdm import tqdm_notebook
# Quieten TensorFlow's C++ logging. Level '2' presumably hides INFO and WARNING
# messages -- confirm against the TensorFlow docs.
# NOTE(review): tensorflow is already imported above this line; if the messages
# still appear, move this assignment before `import tensorflow`.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='2')
# In[66]:
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
def load_graph(model_file):
    """Read a serialized GraphDef from *model_file* and import it into a new Graph.

    ``tf.import_graph_def`` is called without a ``name`` argument, so every
    imported operation is addressed with the ``import/`` prefix.

    Args:
        model_file: Path to the binary GraphDef (``.pb``) file.

    Returns:
        A fresh ``tf.Graph`` containing the imported operations.
    """
    with open(model_file, "rb") as handle:
        serialized = handle.read()
    definition = tf.GraphDef()
    definition.ParseFromString(serialized)
    loaded = tf.Graph()
    with loaded.as_default():
        tf.import_graph_def(definition)
    return loaded
def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
                                input_mean=0, input_std=255):
    """Decode an image file and return it as a normalized, batched float array.

    The decoder is chosen from the file extension, compared case-insensitively:
    PNG, GIF or BMP, and JPEG for anything else. The image is then processed:

    - cast to float32;
    - given a leading batch dimension;
    - resized bilinearly to ``input_height`` x ``input_width``;
    - normalized as ``(pixel - input_mean) / input_std``.

    Args:
        file_name: Path to the image file.
        input_height: Target height in pixels.
        input_width: Target width in pixels.
        input_mean: Value subtracted from every pixel.
        input_std: Value every pixel is divided by after the subtraction.

    Returns:
        The evaluated tensor as a NumPy array of shape
        ``(1, input_height, input_width, channels)``.
    """
    input_name = "file_reader"
    lowered_name = file_name.lower()
    # Build the preprocessing ops in a private graph. Otherwise each call (one
    # per image when classifying a whole collection) keeps adding nodes to the
    # global default graph. The session is closed as soon as the value is read.
    with tf.Graph().as_default():
        file_reader = tf.read_file(file_name, input_name)
        if lowered_name.endswith(".png"):
            image_reader = tf.image.decode_png(file_reader, channels=3,
                                               name='png_reader')
        elif lowered_name.endswith(".gif"):
            # decode_gif yields a frame axis; squeeze drops it for single-frame GIFs.
            image_reader = tf.squeeze(tf.image.decode_gif(file_reader,
                                                          name='gif_reader'))
        elif lowered_name.endswith(".bmp"):
            image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader')
        else:
            image_reader = tf.image.decode_jpeg(file_reader, channels=3,
                                                name='jpeg_reader')
        float_caster = tf.cast(image_reader, tf.float32)
        dims_expander = tf.expand_dims(float_caster, 0)
        resized = tf.image.resize_bilinear(dims_expander,
                                           [input_height, input_width])
        normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
        with tf.Session() as sess:
            return sess.run(normalized)
def load_labels(label_file):
    """Return the lines of *label_file* as a list, trailing whitespace stripped.

    Line order is preserved, so list index ``i`` corresponds to line ``i`` of
    the file.

    Args:
        label_file: Path to a text file with one label per line.

    Returns:
        A list of label strings.
    """
    # The context manager closes the file handle, which was previously left open.
    with tf.gfile.GFile(label_file) as f:
        return [line.rstrip() for line in f.readlines()]
# In[67]:
@functools.lru_cache(maxsize=8)
def _load_model(model_file, label_file):
    """Load the frozen graph and label list once per (model_file, label_file) pair."""
    return load_graph(model_file), load_labels(label_file)


def label_tribune_image(file_name,
                        model_file="tensorflow-for-poets-2/tf_files/tribune_graph.pb",
                        label_file="tensorflow-for-poets-2/tf_files/tribune_labels.txt"):
    """Classify one image with the retrained Tribune model.

    The graph and labels are loaded once and reused on later calls. Previously
    they were re-read from disk for every image.

    Args:
        file_name: Path to the image to classify.
        model_file: Frozen GraphDef produced by the retraining step.
        label_file: Text file with one label per line, in output-index order.

    Returns:
        A dict mapping each of the (up to) five highest-scoring labels to its
        score, inserted from highest to lowest.
    """
    input_height = 224
    input_width = 224
    input_mean = 128
    input_std = 128
    input_layer = "input"
    output_layer = "final_result"

    graph, labels = _load_model(model_file, label_file)
    t = read_tensor_from_image_file(file_name,
                                    input_height=input_height,
                                    input_width=input_width,
                                    input_mean=input_mean,
                                    input_std=input_std)

    # load_graph imports without a name, so every op carries the "import/" prefix.
    input_operation = graph.get_operation_by_name("import/" + input_layer)
    output_operation = graph.get_operation_by_name("import/" + output_layer)
    with tf.Session(graph=graph) as sess:
        results = sess.run(output_operation.outputs[0],
                           {input_operation.outputs[0]: t})
    results = np.squeeze(results)
    top_k = results.argsort()[-5:][::-1]
    return {labels[i]: results[i] for i in top_k}
def detect_all(image_dir='/Volumes/bigdata/mydata/SLNSW/Tribune/images/500'):
    """Classify every ``.jpg`` image in *image_dir*.

    Works from local copies of the images rather than downloading them.
    Each image's score dict (from ``label_tribune_image``) gets an extra
    ``'image'`` key: the file name with ``-500.jpg`` removed. The dict is then
    appended to the module-level ``results`` list, which must already exist
    when this is called.

    Args:
        image_dir: Directory holding the ``*.jpg`` images. Defaults to the
            original author's local copy of the collection.
    """
    # Sort so that row order (and therefore the saved CSV) does not depend on
    # the filesystem's directory-listing order.
    images = sorted(name for name in os.listdir(image_dir) if name.endswith('.jpg'))
    for image in tqdm_notebook(images):
        scores = label_tribune_image(os.path.join(image_dir, image))
        scores['image'] = image.replace('-500.jpg', '')
        results.append(scores)
# In[ ]:
# Module-level accumulator: detect_all() appends one score dict per image to it.
results = list()
detect_all()
# In[70]:
# One row per image: a column per label score plus the 'image' id.
df = pd.DataFrame(results)
# In[71]:
# A notebook auto-displays a cell's last expression, but in this exported
# script bare expressions are discarded. Display the tables explicitly.
display(df.head())
# In[73]:
# Rounded copy is only displayed; the CSV below keeps full precision.
display(df.round(5))
# In[74]:
df.to_csv('classified.csv', index=False)
# In[78]:
df = pd.read_csv('classified.csv')
# In[79]:
display(df.describe())
# In[97]:
def _sample_images_html(subset, n=6,
                        image_dir='/Volumes/bigdata/mydata/SLNSW/Tribune/images/500'):
    """Return HTML showing up to *n* randomly sampled images from *subset*.

    *subset* is a slice of ``df``. Its ``image`` column holds the ids written
    by ``detect_all`` (the file names with ``-500.jpg`` removed), so each
    image is read back from ``image_dir/<id>-500.jpg``. Images are embedded
    as base64 data URIs so they render in the notebook without a
    web-accessible path to the local files.
    """
    # NOTE(review): read_csv parses ids that look numeric as ints (losing any
    # leading zeros), which would break the file path -- confirm ids are non-numeric.
    samples = subset.sample(min(n, len(subset)))  # sample(n) raises if fewer rows
    tags = []
    for image in samples['image'].tolist():
        path = os.path.join(image_dir, '{}-500.jpg'.format(image))
        with open(path, 'rb') as f:
            encoded = base64.b64encode(f.read()).decode('ascii')
        tags.append('<img src="data:image/jpeg;base64,{}" title="{}" '
                    'style="width: 200px; margin: 5px; display: inline-block;">'
                    .format(encoded, image))
    return ''.join(tags)


portraits = df.loc[df['portraits'] > 0.95]
# In[104]:
display(HTML(_sample_images_html(portraits)))
# In[103]:
protests = df.loc[df['protests'] > 0.95]
# In[101]:
display(HTML(_sample_images_html(protests)))
# In[ ]: