%matplotlib inline

# When True, the figure cells below also write their plots to disk with
# plt.savefig; when False the figures are only displayed.
SAVE_FIGURES = False

import cv2
import matplotlib.pyplot as plt

import lxml.html
import requests
import urlparse
import posixpath
import itertools

from sklearn import preprocessing
from sklearn import svm
from sklearn.cross_validation import train_test_split
from sklearn.decomposition import PCA
from sklearn.grid_search import GridSearchCV
from sklearn.cross_validation import ShuffleSplit
from sklearn.metrics import confusion_matrix

import requests
import tempfile
import cv2
from PIL import Image
import pandas as pd
import numpy as np

def load_gif_url(url):
    """Download a GIF and return it as an image array read by OpenCV.

    The download is written to a temporary ``.gif`` file, opened with PIL,
    re-saved as a temporary ``.png`` and that PNG is read back with
    ``cv2.imread`` (presumably because ``cv2.imread`` cannot decode GIF
    files directly -- confirm).

    Args:
        url: Address of the GIF to fetch.

    Returns:
        The array returned by ``cv2.imread`` for the converted PNG.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        AssertionError: If OpenCV returned nothing or an empty image.
    """
    response = requests.get(url)
    # Fail here on a 4xx/5xx answer instead of handing an error page to PIL.
    response.raise_for_status()

    with tempfile.NamedTemporaryFile(suffix=".gif") as f:
        f.write(response.content)
        f.flush()
        img = Image.open(f.name)
        # Image.open is lazy: read the pixel data now, while the temporary
        # file still exists. It is deleted as soon as this block exits.
        img.load()

    with tempfile.NamedTemporaryFile(suffix=".png") as f:
        img.save(f.name)
        f.flush()
        src = cv2.imread(f.name)

    assert src is not None and len(src), "Empty"

    return src

def show_bw(bw, frameon=None):
    """Draw an image with the 'gray' colormap and no axis ticks.

    Args:
        bw: Image to display with ``plt.imshow``.
        frameon: When not None, forwarded to ``set_frame_on`` of the
            current axes before drawing.
    """
    if frameon is not None:
        axes = plt.gca()
        axes.set_frame_on(frameon)

    plt.imshow(bw, cmap='gray')

    # Remove the tick marks from both axes.
    plt.xticks([])
    plt.yticks([])


def get_bw(src):
    """Convert a BGR colour image to grayscale with ``cv2.cvtColor``."""
    gray = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)
    return gray

# Fetch the index page and parse it into an element tree.
text = requests.get("http://www.50states.com/us.htm").text
doc = lxml.html.document_fromstring(text)

# Every link in the bulleted list contributes one entry: the file name of the
# link target with its extension removed is taken as the state name.
states = []
for a in doc.findall(".//ul[@class='bulletedList']/li/a"):
    url = a.get("href")
    page_name = posixpath.basename(urlparse.urlsplit(url).path)
    state_name = posixpath.splitext(page_name)[0]
    states.append(state_name)
    
def make_url(state):
    """Return the address of the map GIF for ``state`` on 50states.com."""
    template = "http://www.50states.com/maps/%s.gif"
    return template % state

def get_state_color(state, dilate=True):
    """Download a state's map and return it as a filled black silhouette.

    The map is fetched with ``load_gif_url``, its top 150 rows are cropped
    off, and contours are detected on the inverted grayscale image. Every
    contour that is long enough and encloses enough area is drawn filled in
    black on a white canvas.

    Args:
        state: State name as used in the map URL (see ``make_url``).
        dilate: When true, apply one 3x3 dilation to the inverted image
            before contour detection.

    Returns:
        A ``uint8`` array with the same shape as the cropped colour image:
        255 everywhere except the accepted contours, which are filled with 0.
    """
    url = make_url(state)

    source = load_gif_url(url)

    # Drop the text at the top
    source = source[150:]

    # Convert 3 color channels to 1
    gray = get_bw(source)

    # Invert colors (per docs for findContours)
    inverted = 255 - gray

    # This seems to be required for Massachusetts
    if dilate:
        kernel = np.ones((3, 3), np.uint8)
        inverted = cv2.dilate(inverted, kernel, iterations=1)

    # findContours returns (contours, hierarchy) in OpenCV 2.x/4.x and
    # (image, contours, hierarchy) in 3.x; the last two items are the same in
    # every version. A copy is passed so `inverted` itself is never modified.
    contours, _hierarchy = cv2.findContours(
        inverted.copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]

    # Polygon approximation tolerance, as a fraction of the contour length.
    threshold = 0.02

    img = 255 * np.ones(source.shape, dtype=np.uint8)
    for i, cnt in enumerate(contours):
        cnt_len = cv2.arcLength(cnt, True)
        cc = cv2.approxPolyDP(cnt, threshold * cnt_len, True)

        area = cv2.contourArea(cc)

        # Skip short contours and ones whose simplified polygon is small.
        if cnt_len > 50 and area > 500:
            # A negative thickness fills the contour. -1 is the value of
            # cv2.cv.CV_FILLED, which only exists in OpenCV 2.x.
            cv2.drawContours(img, contours, i, (0, 0, 0), thickness=-1)

    return img

# Render Massachusetts with the dilation step switched off; the saved figure
# is named "massachusetts-fail".
show_bw(get_bw(get_state_color("massachusetts", dilate=False)))

if SAVE_FIGURES:
    plt.savefig(
        "/home/alex/git/octopress/source/images/post_images/2014-05-13-map-recognition/massachusetts-fail.png",
        bbox_inches="tight",
        pad_inches=0.01,
        transparent=True,
    )

# Download and clean up the map of every state, keyed by state name.
state_images_color = {}
for state in states:
    state_images_color[state] = get_state_color(state)

# Grayscale version of every cleaned-up map, keyed by state name.
# .items() is used instead of the Python-2-only .iteritems(); the loop visits
# the same pairs either way.
state_images = {}
for state, img_color in state_images_color.items():
    state_images[state] = get_bw(img_color)

# Overview figure: one subplot per state, in alphabetical order, laid out on
# a 7 x 8 grid and titled with the state name.
plt.figure(1, figsize=(100, 100))
for i, state in enumerate(sorted(state_images)):
    plt.subplot(7, 8, i + 1)
    state_image = state_images[state]
    show_bw(state_image)
    plt.title(state, fontsize=100)

if SAVE_FIGURES:
    plt.savefig(
        "/home/alex/git/octopress/source/images/post_images/2014-05-13-map-recognition/states.png",
        bbox_inches="tight",
        pad_inches=0,
        dpi=20,
    )

# Texas on its own; it is the example image used by the checks further down.
show_bw(state_images["texas"])

if SAVE_FIGURES:
    plt.savefig(
        "/home/alex/git/octopress/source/images/post_images/2014-05-13-map-recognition/texas.png",
        bbox_inches="tight",
        pad_inches=0.01,
        transparent=True,
    )

def rotate(input_img, angle):
    """Rotate an image about its centre, keeping the original canvas size.

    Args:
        input_img: Image array; its first two dimensions are taken as
            (rows, columns).
        angle: Rotation angle handed to ``cv2.getRotationMatrix2D``
            (degrees, per the OpenCV documentation).

    Returns:
        The rotated image with the same height and width as the input.
        Pixels not covered by the rotated source are set to 255.
    """
    rows, cols = input_img.shape[:2]
    # OpenCV points and sizes are (x, y) / (width, height), the reverse of
    # numpy's (rows, cols) shape. The centre must therefore be
    # (cols / 2, rows / 2); passing it the other way round rotates
    # non-square images about the wrong point.
    center = (cols / 2, rows / 2)
    rot = cv2.getRotationMatrix2D(center, angle, 1.0)
    rotated = cv2.warpAffine(input_img, rot, (cols, rows), borderValue=255)
    return rotated

# Sanity checks for rotate(): a rotation by 0, and two successive rotations
# by 180, must both reproduce the original image.
assert np.allclose(state_images['texas'], rotate(state_images['texas'], 0))
assert np.allclose(state_images['texas'], rotate(rotate(state_images['texas'], 180), 180))

def rescale(input_img, scale):
    """Resize an image by the factor ``scale`` along both axes.

    The output size is derived from the factor (``dsize=None``) and
    ``cv2.INTER_CUBIC`` interpolation is used.
    """
    return cv2.resize(
        input_img,
        dsize=None,
        fx=scale,
        fy=scale,
        interpolation=cv2.INTER_CUBIC,
    )

# Sanity check for rescale(): a scale factor of 1 must leave the image unchanged.
assert np.allclose(state_images['texas'], rescale(state_images['texas'], 1))

def translate(input_img, x, y):
    """Shift an image by ``x`` columns and ``y`` rows.

    The output canvas grows by the shift when it is positive, so content
    moved right or down stays inside the result. For a negative shift the
    canvas keeps the input size. Uncovered pixels are set to 255.
    """
    rows, cols = input_img.shape[0], input_img.shape[1]

    # warpAffine takes the output size as (width, height), i.e. in the
    # opposite order to the array's (rows, cols) shape.
    out_size = (cols + max(0, x), rows + max(0, y))

    shift = np.float32([[1, 0, x], [0, 1, y]])
    return cv2.warpAffine(input_img, shift, out_size, borderValue=255)

# Sanity check for translate(): a zero shift must leave the image unchanged.
assert np.allclose(state_images['texas'], translate(state_images['texas'], 0, 0))

# Example of combined transformations: Texas rotated by 66 and then shifted
# by 200 along x.
show_bw(translate(rotate(state_images["texas"],66), 200, 0))

if SAVE_FIGURES:
    plt.savefig("/home/alex/git/octopress/source/images/post_images/2014-05-13-map-recognition/texas-rotate.png", 
                bbox_inches="tight", pad_inches=0.015, transparent=True)

def get_binary(img):
    """Return an integer mask that is 1 where ``img`` equals 0 and 0 elsewhere."""
    is_black = img == 0
    return is_black.astype(int)

def get_hu_from_binary(img_binary):
    """Return the Hu moments of a binary image.

    The image moments are computed with the binary-image flag set to True
    and then passed to ``cv2.HuMoments``.
    """
    moments = cv2.moments(img_binary, True)
    return cv2.HuMoments(moments)

def get_all_hus(state):
    """Build a table of Hu moments for transformed copies of one state's image.

    The image ``state_images[state]`` is rotated (1 to 359 in steps of 1),
    rescaled (0.5, 1, 1.25, 2, 4) and translated (every pair of offsets
    drawn from 0, 10, 50, 100, 200). The Hu moments of each variant form
    one row.

    Args:
        state: Key into the module-level ``state_images`` dict.

    Returns:
        A DataFrame with one integer-labelled column per Hu moment, a
        "Method" column ("Rotate", "Scale" or "Translate") and a "State"
        column holding ``state``. The row index restarts for each method.
    """
    state_img = state_images[state]

    def hu_frame(variants, method):
        # One row per variant. Each Hu moment is unwrapped with [0], exactly
        # as the values come back from get_hu_from_binary.
        hus = [get_hu_from_binary(get_binary(variant)) for variant in variants]
        frame = pd.DataFrame.from_records([[x[0] for x in hu] for hu in hus])
        frame["Method"] = method
        return frame

    # Generators keep the work lazy: each variant is transformed and measured
    # one at a time.
    rotation_hus = hu_frame(
        (rotate(state_img, angle) for angle in range(1, 359 + 1)),
        "Rotate",
    )

    scale_hus = hu_frame(
        (rescale(state_img, scale) for scale in [0.5, 1, 1.25, 2, 4]),
        "Scale",
    )

    offsets = [0, 10, 50, 100, 200]
    translation_hus = hu_frame(
        (translate(state_img, dx, dy)
         for dx, dy in itertools.product(offsets, offsets)),
        "Translate",
    )

    all_hus = pd.concat([rotation_hus, scale_hus, translation_hus])
    all_hus["State"] = state
    return all_hus

# Compute the Hu-moment table of every state (see get_all_hus).
state_hus = {}
for state in states:
    state_hus[state] = get_all_hus(state)

# Stack the per-state tables into a single frame; `items` is a second name
# for that same frame.
all_state_hus = pd.concat(state_hus.values())
items = all_state_hus

# Column labels 0..6 of the table; used below as the feature columns.
features = range(0,6+1)

# Histogram of each raw feature column, one subplot per feature.
plt.figure(1, figsize=(10, 10), )
for i in features:
    plt.subplot(4,2,i+1)
    items[i].hist()
plt.tight_layout()

# Project the scaled raw features onto two whitened principal components and
# scatter-plot the result.
pca = PCA(n_components=2, whiten=True)

X = preprocessing.scale(all_state_hus[features])

pcaed = pca.fit_transform(X)
plt.scatter(pcaed[:,0], pcaed[:,1])

# Classification on the raw features: X holds the feature columns, y the
# state name of each row.
X = items[features]
y = items["State"]

# Hold out 20% of the rows for testing; the seed is fixed for repeatability.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# Fit the scaler on the training rows only, then apply that same transform
# to the test rows.
scaler = preprocessing.StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

#Change the dtype from object to string (required by sklearn)
y_train = y_train.astype(str)
y_test = y_test.astype(str)

# Each group below fits one SVC variant on the training rows and evaluates
# clf.score on the held-out rows. The score is a bare expression, presumably
# shown as the notebook cell's output.

# Linear kernel.
clf = svm.SVC(kernel="linear")
clf.fit(X_train, y_train)
clf.score(X_test, y_test)

# Default SVC parameters.
clf = svm.SVC()
clf.fit(X_train, y_train)
clf.score(X_test, y_test)

# Default kernel with gamma=50.
clf = svm.SVC(gamma=50)
clf.fit(X_train, y_train)
clf.score(X_test, y_test)

# Default kernel with gamma=10000.
clf = svm.SVC(gamma=10000)
clf.fit(X_train, y_train)
clf.score(X_test, y_test)

def log_ftr(ftr):
    """Return the natural log of the magnitude of ``ftr``, times its sign.

    ``ftr`` must provide an ``abs()`` method (it is called with pandas
    objects in this file). An entry that is exactly zero comes out as NaN,
    because log(0) is -inf and the sign of 0 is 0.
    """
    log_magnitude = np.log(ftr.abs())
    return log_magnitude * np.sign(ftr)

# Histogram of each feature column after the signed-log transform (log_ftr),
# one subplot per feature.
plt.figure(1, figsize=(10, 10), )
for i in features:
    plt.subplot(4,2,i+1)
    log_ftr(items[i]).hist()
plt.tight_layout()

# Same two-component whitened PCA scatter plot as for the raw features, but
# computed on the log-transformed, scaled features.
pca = PCA(n_components=2, whiten=True)

X = preprocessing.scale(log_ftr(all_state_hus[features]))

pcaed = pca.fit_transform(X)
plt.scatter(pcaed[:,0], pcaed[:,1])

# Classification on the log-transformed features, with the same 80/20 split
# and seed as used for the raw features.
X = log_ftr(items[features])
y = items["State"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# Fit the scaler on the training rows only. `scaler` is reused by the blur
# experiments further down.
scaler = preprocessing.StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

#Change the dtype from object to string (required by sklearn)
y_train = y_train.astype(str)
y_test = y_test.astype(str)

# Linear kernel.
clf = svm.SVC(kernel="linear")
clf.fit(X_train, y_train)
clf.score(X_test, y_test)

# Default SVC parameters. This is the classifier left bound to `clf` for the
# confusion matrix and blur experiments below.
clf = svm.SVC()
clf.fit(X_train, y_train)
clf.score(X_test, y_test)

# State names in sorted order; used as the class order of the confusion
# matrix and as the tick labels of both axes.
labels = sorted(set(y))

y_pred = clf.predict(X_test)

# Compute confusion matrix
cm = confusion_matrix(y_test.astype(str), y_pred.astype(str), labels=labels)

# Show confusion matrix in a separate window
fig = plt.figure(figsize=(8,8), dpi=5000)

ax1=fig.add_subplot(111)
ax1.set_frame_on(False)

# Counts are plotted on a log scale; cells with a count of zero become -inf.
# NOTE(review): the second argument (1) presumably selects figure number 1 and
# assumes the figure created above has that number -- confirm.
plt.matshow(np.log(cm), 1)
plt.ylabel('True label')
plt.xlabel('Predicted label')

# range() instead of the Python-2-only xrange(); the tick positions are the same.
plt.xticks(range(0, len(labels)), labels, rotation=90, fontsize="small")
plt.yticks(range(0, len(labels)), labels,fontsize="small")

if SAVE_FIGURES:
    plt.savefig("/home/alex/git/octopress/source/images/post_images/2014-05-13-map-recognition/confusion-final.png", bbox_inches="tight", pad_inches=0)

plt.show()

def get_raw_hus(img):
    """Return the Hu moments of ``img`` as a flat pandas Series.

    The image is binarised with ``get_binary`` first. Each moment returned
    by ``get_hu_from_binary`` is unwrapped with ``[0]``.
    """
    hu = get_hu_from_binary(get_binary(img))
    values = [moment[0] for moment in hu]
    return pd.Series(values)

def blur_img(img, blur):
    """Apply ``cv2.blur`` with a ``blur`` x ``blur`` kernel.

    A ``blur`` of 0 means "no blur": the input object is returned as is.
    """
    if blur != 0:
        kernel = (blur, blur)
        return cv2.blur(img, kernel)
    return img

def gblur_img(img, blur):
    """Apply ``cv2.GaussianBlur`` with a ``blur`` x ``blur`` kernel and sigmaX=0.

    A ``blur`` of 0 means "no blur": the input object is returned as is.
    """
    if blur != 0:
        kernel = (blur, blur)
        return cv2.GaussianBlur(img, ksize=kernel, sigmaX=0)
    return img

# Score the current classifier on one blurred image per state, for kernel
# sizes 0..49 (0 meaning the unblurred image). `labels` serves both as the
# keys used to look up the images and as the expected predictions. Each image
# goes through the same pipeline as the training data: Hu moments, signed
# log, then the fitted scaler.
blur_score = []
for blur in range(0, 50):
    score = clf.score(pd.DataFrame([
        scaler.transform(log_ftr(get_raw_hus(blur_img(state_images[state],blur)))) for state in labels]), labels)
    blur_score.append((blur, score))

pd.DataFrame(blur_score, columns=["blur", "score"]).set_index("blur").plot(figsize=(10,8),
            title="Score Degradation Due to Blur")

if SAVE_FIGURES:
    plt.savefig("/home/alex/git/octopress/source/images/post_images/2014-05-13-map-recognition/blur.png", bbox_inches="tight", pad_inches=0)

# Same experiment with Gaussian blur, using only the odd kernel sizes
# 1, 3, ..., 49.
# NOTE(review): presumably because GaussianBlur needs odd kernel sizes -- confirm.
gblur_score = []
for blur in range(1, 50, 2):
    score = clf.score(pd.DataFrame([
        scaler.transform(log_ftr(get_raw_hus(
            gblur_img(state_images[state],blur)))) for state in labels]), labels)
    gblur_score.append((blur, score))

pd.DataFrame(gblur_score, columns=["blur", "score"]).set_index("blur").plot(figsize=(10,8),
            title="Score Degradation Due to Gaussian Blur")

if SAVE_FIGURES:
    plt.savefig("/home/alex/git/octopress/source/images/post_images/2014-05-13-map-recognition/gblur.png", bbox_inches="tight", pad_inches=0)