import mahotas as mh
from mahotas import polygon
import pymorph as pm

import networkx as nx

from scipy import ndimage as nd
import skimage.transform as transform
import skimage.io as sio
import scipy.misc as sm

import os
import math

def branchedPoints(skel, showSE=True):
    '''Detect branch (junction) points of a skeleton by hit-or-miss matching.

    ``skel`` is probed with 18 3x3 structuring elements (SEs) written in the
    mahotas hit-or-miss convention 0 = background, 1 = foreground,
    2 = don't care:

    * X0, X1 -- "+" and "x" shaped four-way crossings;
    * Y0-Y7  -- Y-shaped three-way forks (Y5-Y7 are ``np.rot90`` rotations);
    * T0-T7  -- T-shaped three-way forks.

    Parameters
    ----------
    skel : ndarray
        Skeleton image, passed as-is to ``mh.morph.hitmiss``.
    showSE : bool, optional
        If true (the default), draw all 18 SEs in a 4x5 matplotlib grid
        (black = background, blue = hit, red = don't care) and call
        ``plt.show()`` before returning.

    Returns
    -------
    bp : ndarray of int, same shape as ``skel``
        Element-wise sum of the 18 hit-or-miss responses. The SEs overlap
        (T0 contains X0, T1 contains X1), so one junction pixel can be
        counted several times -- an isolated "+" crossing matches X0, T0,
        T2, T4 and T6. Use ``bp > 0`` for a binary junction mask.

    Notes
    -----
    ``np``, ``plt`` and ``mpl`` are not imported by this module; they are
    presumably provided by a ``%pylab`` session -- TODO confirm.
    '''
    # Four-way crossings.
    X0 = np.array([[0, 1, 0],
                   [1, 1, 1],
                   [0, 1, 0]])
    X1 = np.array([[1, 0, 1],
                   [0, 1, 0],
                   [1, 0, 1]])
    X = [X0, X1]

    # T-like forks. T0 contains X0 and T1 contains X1, so crossings are
    # reported by these elements as well.
    T0 = np.array([[2, 1, 2],
                   [1, 1, 1],
                   [2, 2, 2]])
    T1 = np.array([[1, 2, 1],
                   [2, 1, 2],
                   [1, 2, 2]])
    T2 = np.array([[2, 1, 2],
                   [1, 1, 2],
                   [2, 1, 2]])
    T3 = np.array([[1, 2, 2],
                   [2, 1, 2],
                   [1, 2, 1]])
    T4 = np.array([[2, 2, 2],
                   [1, 1, 1],
                   [2, 1, 2]])
    T5 = np.array([[2, 2, 1],
                   [2, 1, 2],
                   [1, 2, 1]])
    T6 = np.array([[2, 1, 2],
                   [2, 1, 1],
                   [2, 1, 2]])
    T7 = np.array([[1, 2, 1],
                   [2, 1, 2],
                   [2, 2, 1]])
    T = [T0, T1, T2, T3, T4, T5, T6, T7]

    # Y-like forks; the last three are counter-clockwise rotations of
    # earlier ones.
    Y0 = np.array([[1, 0, 1],
                   [0, 1, 0],
                   [2, 1, 2]])
    Y1 = np.array([[0, 1, 0],
                   [1, 1, 2],
                   [0, 2, 1]])
    Y2 = np.array([[1, 0, 2],
                   [0, 1, 1],
                   [1, 0, 2]])
    Y3 = np.array([[0, 2, 1],
                   [1, 1, 2],
                   [0, 1, 0]])
    Y4 = np.array([[2, 1, 2],
                   [0, 1, 0],
                   [1, 0, 1]])
    Y5 = np.rot90(Y3)
    Y6 = np.rot90(Y4)
    Y7 = np.rot90(Y5)
    Y = [Y0, Y1, Y2, Y3, Y4, Y5, Y6, Y7]

    # Accumulate the responses (X, then Y, then T) into an int image.
    bp = np.zeros(skel.shape, dtype=int)
    for se in X + Y + T:
        bp = bp + mh.morph.hitmiss(skel, se)

    if showSE:
        titles = (['X0', 'X1'] +
                  ['Y%d' % i for i in range(8)] +
                  ['T%d' % i for i in range(8)])
        # With vmin=0, vmax=2: 0 -> black (background), 1 -> blue (hit),
        # 2 -> red (don't care).
        mycmap = mpl.colors.ListedColormap(['black', 'blue', 'red'])
        fig = plt.figure(figsize=(4, 5))
        for n, (name, se) in enumerate(zip(titles, X + Y + T), 1):
            ax = fig.add_subplot(4, 5, n, frameon=False, xticks=[], yticks=[])
            ax.set_title(name)
            ax.imshow(se, interpolation='nearest', vmin=0, vmax=2, cmap=mycmap)
        fig.subplots_adjust(hspace=0.1, wspace=0.08)
        plt.show()
    return bp

def endPoints(skel):
    '''Locate end points of a skeleton with eight hit-or-miss templates.

    Every 3x3 template (0 = background, 1 = foreground, 2 = don't care)
    fixes the centre pixel and one neighbour as foreground, leaves the two
    cells flanking that neighbour as don't care, and requires the other
    five cells to be background. The eight templates cover the eight
    neighbour directions.

    Parameters
    ----------
    skel : ndarray
        Skeleton image, passed as-is to ``mh.morph.hitmiss``.

    Returns
    -------
    ndarray
        Element-wise sum of the eight hit-or-miss responses. Because of the
        don't-care cells a pixel can match two templates (e.g. neighbours
        directly below and below-right), so values may exceed 1.
    '''
    templates = [
        np.array([[0, 0, 0],    # neighbour below
                  [0, 1, 0],
                  [2, 1, 2]]),
        np.array([[0, 0, 0],    # neighbour below-right
                  [0, 1, 2],
                  [0, 2, 1]]),
        np.array([[0, 0, 2],    # neighbour right
                  [0, 1, 1],
                  [0, 0, 2]]),
        np.array([[0, 2, 1],    # neighbour above-right
                  [0, 1, 2],
                  [0, 0, 0]]),
        np.array([[2, 1, 2],    # neighbour above
                  [0, 1, 0],
                  [0, 0, 0]]),
        np.array([[1, 2, 0],    # neighbour above-left
                  [2, 1, 0],
                  [0, 0, 0]]),
        np.array([[2, 0, 0],    # neighbour left
                  [1, 1, 0],
                  [2, 0, 0]]),
        np.array([[0, 0, 0],    # neighbour below-left
                  [2, 1, 0],
                  [1, 2, 0]]),
    ]
    hits = [mh.morph.hitmiss(skel, se) for se in templates]
    # Start the left fold from the first response instead of the int 0 so
    # the result keeps the dtype of the hit-or-miss responses.
    return sum(hits[1:], hits[0])

def pruning(skeleton, size):
    '''Prune a skeleton by stripping its end points ``size`` times.

    Each pass runs ``endPoints`` on the current skeleton and clears every
    pixel where that response is non-zero.

    Parameters
    ----------
    skeleton : ndarray
        Skeleton image.
    size : int
        Number of passes; ``range(size)`` is iterated, so 0 or a negative
        value returns ``skeleton`` untouched.

    Returns
    -------
    ndarray
        The pruned skeleton -- the boolean output of ``np.logical_and``
        after at least one pass, otherwise the input object itself.
    '''
    pruned = skeleton
    for _ in range(size):
        tips = endPoints(pruned)
        pruned = np.logical_and(pruned, np.logical_not(tips))
    return pruned

# Demo: run the detectors on a small hand-drawn 6x6 pattern and show the
# results in a row of four panels. ``np``, ``subplot``, ``title`` and
# ``imshow`` are not imported above -- presumably an IPython ``%pylab``
# session provides them (TODO confirm).
a = np.array([[0,0,0,0,0,0],
              [0,0,1,0,1,0],
              [0,1,1,0,1,0],
              [0,0,0,1,0,0],
              [0,0,1,0,1,0],
              [0,1,0,0,0,0]])
lab, _ = mh.label(a > 0)  # connected components of the pattern (panel 2)
sk = mh.thin(a)           # thinned pattern; only its dtype is reported
# Formatted into one string so the printed text is identical on
# Python 2 and Python 3.
print('%s %s' % (a.dtype, sk.dtype))
bp = branchedPoints(a > 0)
# Pixel count for each distinct value of bp; note bp holds per-pixel match
# counts from branchedPoints, not region labels.
h = mh.labeled.labeled_size(bp)
ep = endPoints(a)
subplot(141)
title('skeleton')
imshow(a, interpolation='nearest')

subplot(142)
title('label')
imshow(lab, interpolation='nearest')
subplot(143)
title('junction')
imshow(bp, interpolation='nearest')
subplot(144)
title('end-points')
imshow(ep, interpolation='nearest')

# Bare expression, presumably a notebook leftover: it only displays the
# mahotas version when evaluated as the last line of an interactive cell
# and has no effect when the file is run as a script.
mh.__version__