Classifying all the Tribune images

Using the simple model we created in this notebook, let's attempt to classify all the images in the Tribune collection.

In [65]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import pandas as pd
import os
from urllib.parse import urlparse
import requests
from IPython.display import display, HTML
import copy
from tqdm import tqdm_notebook


import sys
import time

import numpy as np

# Suppress TensorFlow's routine log messages (this needs to be set before tensorflow is imported)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
In [66]:
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

def load_graph(model_file):
  graph = tf.Graph()
  graph_def = tf.GraphDef()

  with open(model_file, "rb") as f:
    graph_def.ParseFromString(f.read())
  with graph.as_default():
    tf.import_graph_def(graph_def)

  return graph

def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
				input_mean=0, input_std=255):
  input_name = "file_reader"
  output_name = "normalized"
  file_reader = tf.read_file(file_name, input_name)
  if file_name.endswith(".png"):
    image_reader = tf.image.decode_png(file_reader, channels = 3,
                                       name='png_reader')
  elif file_name.endswith(".gif"):
    image_reader = tf.squeeze(tf.image.decode_gif(file_reader,
                                                  name='gif_reader'))
  elif file_name.endswith(".bmp"):
    image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader')
  else:
    image_reader = tf.image.decode_jpeg(file_reader, channels = 3,
                                        name='jpeg_reader')
  float_caster = tf.cast(image_reader, tf.float32)
  dims_expander = tf.expand_dims(float_caster, 0)
  resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
  normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
  # Use a context manager so the session is closed once the tensor has been evaluated
  with tf.Session() as sess:
    result = sess.run(normalized)

  return result

def load_labels(label_file):
  label = []
  proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines()
  for l in proto_as_ascii_lines:
    label.append(l.rstrip())
  return label
In [67]:
def label_tribune_image(file_name):
    model_file = "tensorflow-for-poets-2/tf_files/tribune_graph.pb"
    label_file = "tensorflow-for-poets-2/tf_files/tribune_labels.txt"
    input_height = 224
    input_width = 224
    input_mean = 128
    input_std = 128
    input_layer = "input"
    output_layer = "final_result"
    graph = load_graph(model_file)
    t = read_tensor_from_image_file(file_name,
                                  input_height=input_height,
                                  input_width=input_width,
                                  input_mean=input_mean,
                                  input_std=input_std)

    input_name = "import/" + input_layer
    output_name = "import/" + output_layer
    input_operation = graph.get_operation_by_name(input_name)
    output_operation = graph.get_operation_by_name(output_name)

    with tf.Session(graph=graph) as sess:
        start = time.time()
        results = sess.run(output_operation.outputs[0],
                      {input_operation.outputs[0]: t})
        end = time.time()
    results = np.squeeze(results)

    top_k = results.argsort()[-5:][::-1]
    labels = load_labels(label_file)
    scores = {}
    for i in top_k:
        scores[labels[i]] = results[i]
    return scores


def detect_all():
    '''
    I've already got copies of all the images, so I'll just point the script at them.
    Scores for each image are appended to the global `results` list defined in the next cell.
    '''
    image_dir = '/Volumes/bigdata/mydata/SLNSW/Tribune/images/500'
    images = [i for i in os.listdir(image_dir) if i.endswith('.jpg')]
    for image in tqdm_notebook(images):
        img_file = os.path.join(image_dir, image)
        scores = label_tribune_image(img_file)
        scores['image'] = image.replace('-500.jpg', '')
        results.append(scores)
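One thing worth noting: label_tribune_image() reloads the model graph and opens a new TensorFlow session for every image, which adds a lot of overhead across 18,000+ files. Here's a minimal sketch (untested here, and assuming the same model file, labels and input settings as above) of how the loop could be restructured to load the graph and labels once and reuse a single session. read_tensor_from_image_file() still builds its preprocessing ops on each call, but the expensive graph loading only happens once.

def label_all_images(image_dir, model_file, label_file):
    # Load the graph and labels once, rather than once per image
    graph = load_graph(model_file)
    labels = load_labels(label_file)
    input_operation = graph.get_operation_by_name('import/input')
    output_operation = graph.get_operation_by_name('import/final_result')
    all_scores = []
    images = [i for i in os.listdir(image_dir) if i.endswith('.jpg')]
    # Keep a single session open for all the images
    with tf.Session(graph=graph) as sess:
        for image in tqdm_notebook(images):
            t = read_tensor_from_image_file(os.path.join(image_dir, image),
                                            input_height=224, input_width=224,
                                            input_mean=128, input_std=128)
            predictions = np.squeeze(sess.run(output_operation.outputs[0],
                                              {input_operation.outputs[0]: t}))
            scores = {labels[i]: predictions[i] for i in predictions.argsort()[::-1]}
            scores['image'] = image.replace('-500.jpg', '')
            all_scores.append(scores)
    return all_scores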
In [ ]:
results = []
detect_all()
In [70]:
df = pd.DataFrame(results)
In [71]:
df.head()
Out[71]:
image portraits protests
0 FL1817297 2.656395e-09 1.000000
1 FL1817300 7.659095e-03 0.992341
2 FL1817303 4.747240e-05 0.999953
3 FL1817307 1.007780e-04 0.999899
4 FL1817310 1.188131e-10 1.000000
In [73]:
df.round(5)
Out[73]:
image portraits protests
0 FL1817297 0.00000 1.00000
1 FL1817300 0.00766 0.99234
2 FL1817303 0.00005 0.99995
3 FL1817307 0.00010 0.99990
4 FL1817310 0.00000 1.00000
5 FL1817311 0.97723 0.02277
6 FL1817312 0.97215 0.02785
7 FL1817313 0.00029 0.99971
8 FL1817317 0.99869 0.00131
9 FL1817320 1.00000 0.00000
10 FL1817323 0.00072 0.99928
11 FL1817324 0.02886 0.97114
12 FL1817325 0.99999 0.00001
13 FL1817328 1.00000 0.00000
14 FL1817330 0.95103 0.04897
15 FL1817335 1.00000 0.00000
16 FL1817336 0.99618 0.00382
17 FL1817337 0.99382 0.00618
18 FL1817342 0.98216 0.01784
19 FL1817344 0.99674 0.00326
20 FL1817346 0.13909 0.86091
21 FL1817348 0.91126 0.08874
22 FL1817349 0.00000 1.00000
23 FL1817350 0.00000 1.00000
24 FL1817351 0.00000 1.00000
25 FL1817352 0.00000 1.00000
26 FL1817356 0.00005 0.99995
27 FL1817359 0.99467 0.00533
28 FL1817360 0.00000 1.00000
29 FL1817362 0.08009 0.91991
... ... ... ...
18487 FL4464422 0.99992 0.00008
18488 FL4464425 0.00120 0.99880
18489 FL4464428 0.00752 0.99248
18490 FL4464430 0.00000 1.00000
18491 FL4464433 1.00000 0.00000
18492 FL4464435 0.00120 0.99880
18493 FL4464436 0.54698 0.45302
18494 FL4464438 0.95447 0.04553
18495 FL4464440 0.00000 1.00000
18496 FL4464441 0.00002 0.99998
18497 FL4464442 0.00000 1.00000
18498 FL4464444 0.00001 0.99999
18499 FL4464447 0.00000 1.00000
18500 FL4464449 0.00001 0.99999
18501 FL4464451 0.03624 0.96376
18502 FL4464453 0.00000 1.00000
18503 FL4464455 0.00226 0.99774
18504 FL4464457 0.76447 0.23553
18505 FL4464459 0.00004 0.99996
18506 FL4464460 0.98578 0.01422
18507 FL4464465 0.00017 0.99983
18508 FL4464468 0.00010 0.99990
18509 FL4464469 0.00236 0.99764
18510 FL4464471 0.07899 0.92101
18511 FL4464472 0.10720 0.89280
18512 FL4464474 0.99999 0.00001
18513 FL4464476 0.00000 1.00000
18514 FL4464477 0.00001 0.99999
18515 FL4464480 0.87323 0.12677
18516 FL4464482 0.88942 0.11058

18517 rows × 3 columns

In [74]:
df.to_csv('classified.csv', index=False)
In [78]:
df = pd.read_csv('classified.csv')
In [79]:
df.describe()
Out[79]:
portraits protests
count 1.851700e+04 1.851700e+04
mean 2.729989e-01 7.270011e-01
std 4.200135e-01 4.200135e-01
min 6.136425e-17 2.156611e-14
25% 2.731243e-07 2.685431e-01
50% 2.668783e-04 9.997332e-01
75% 7.314569e-01 9.999998e-01
max 1.000000e+00 1.000000e+00
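The quartiles suggest most scores sit close to either 0 or 1, so a simple threshold should separate the confident classifications from the uncertain ones. As a quick check (just an illustrative snippet, using the 0.95 cut-off applied below), we can count how many images clear the threshold for each label:

# Count images confidently assigned to each label (the 0.95 threshold is arbitrary)
threshold = 0.95
print('Portraits:', (df['portraits'] > threshold).sum())
print('Protests:', (df['protests'] > threshold).sum())
print('Uncertain:', ((df['portraits'] <= threshold) & (df['protests'] <= threshold)).sum())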
In [97]:
portraits = df.loc[df['portraits'] > 0.95]
In [104]:
samples = portraits.sample(6)
html = ''
for image in samples['image'].tolist():
    html += '<a title="{0}" target="_blank" href="https://s3-ap-southeast-2.amazonaws.com/wraggetribune/images/{0}.jpg"><img style="width: 300px; height: 300px; float: left; margin: 10px; object-fit: contain;" src="https://s3-ap-southeast-2.amazonaws.com/wraggetribune/images/500/{0}-500.jpg"></a>'.format(image)
display(HTML(html))
In [103]:
protests = df.loc[df['protests'] > 0.95]
In [101]:
samples = protests.sample(6)
html = ''
for image in samples['image'].tolist():
    html += '<a title="{0}" target="_blank" href="https://s3-ap-southeast-2.amazonaws.com/wraggetribune/images/{0}.jpg"><img style="width: 300px; height: 300px; float: left; margin: 10px; object-fit: contain;" src="https://s3-ap-southeast-2.amazonaws.com/wraggetribune/images/500/{0}-500.jpg"></a>'.format(image)
display(HTML(html))
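If we want to come back to these subsets later without re-running the classifier, we could save each one to its own CSV file alongside the full results (the filenames here are just suggestions):

# Save the high-confidence subsets for later use (filenames are arbitrary)
portraits.to_csv('portraits.csv', index=False)
protests.to_csv('protests.csv', index=False)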
In [ ]: