import numpy as np
import pandas as pd
import time
from sklearn.datasets import make_blobs
from sklearn.preprocessing import StandardScaler
from perceptronutils import * # local helper module: computeLoss, computeMistakes, computeGradient
import plotly.graph_objects as go
import plotly.express as px
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(color_codes=True)
sns.set_style('whitegrid')
sns.set_context('paper',font_scale=2)
%matplotlib inline
# a sample weight vector and bias for the line w0 + w[0]*x1 + w[1]*x2 = 0
w = [4,2]
w0 = 1
x1min,x1max = [-10,10]
x2min,x2max = [-10,10]
# points where the line crosses the x1 and x2 limits of the window
x11,x12 = x1min,(-w0-w[0]*x1min)/w[1]
x21,x22 = x1max,(-w0-w[0]*x1max)/w[1]
x31,x32 = (-w0-w[1]*x2min)/w[0],x2min
x41,x42 = (-w0-w[1]*x2max)/w[0],x2max
plt.plot([x11,x21,x31,x41],[x12,x22,x32,x42])
plt.xlim([-5,5])
plt.ylim([-5,5])
plt.xlabel('x1')
plt.ylabel('x2')
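For reference, the four endpoints computed above are the points where the line

$$w_0 + w_1 x_1 + w_2 x_2 = 0 \quad\Longrightarrow\quad x_2 = \frac{-w_0 - w_1 x_1}{w_2}$$

crosses the limits of the plotting window.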
Below is the implementation of the 0-1 loss function.
def zerooneobjfunc(w,X,y):
    return np.sum(y*np.dot(X,w) < 0)
We will consider a 2-D data set. The decision boundary (a line in 2-D) is represented using three parameters ($w_0, w_1, w_2$). We will evaluate the objective function for different combinations of these parameters and then plot it as a function of only the last two, holding $w_0$ fixed, since we cannot draw a 4-D plot.
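Concretely, the quantity evaluated at each grid point is the number of training mistakes made by the candidate weight vector:

$$J(\mathbf{w}) = \sum_{i=1}^{n} \mathbb{1}\left[y_i\,\mathbf{w}^\top\mathbf{x}_i < 0\right]$$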
# consider a simple 2-D data set
# prepare data
# make_blobs is a function provided by sklearn to generate synthetic data
X, y = make_blobs(n_samples=50, centers=2, n_features=2, random_state=0)
X = StandardScaler().fit_transform(X)
X_i = np.hstack([np.ones((X.shape[0],1)),X]) # prepend a bias column of ones
y[y == 0] = -1 # relabel classes from {0,1} to {-1,+1}
y = y[:,np.newaxis]
plt.scatter(X[:,0],X[:,1],c=y.flatten(),cmap='jet')
w1s = w2s = np.linspace(-200,200,200)
W1,W2 = np.meshgrid(w1s,w2s)
losses = []
w0 = 10 # hold the bias fixed so the loss can be plotted over (w1,w2)
for w1,w2 in zip(W1.flatten(),W2.flatten()):
    w = np.array([[w0],[w1],[w2]])
    #losses.append(computeLoss(X,y,w))
    losses.append(zerooneobjfunc(w,X_i,y))
losses = np.array(losses)
J = np.reshape(losses,W1.shape)
cs = 'oranges'
fig = go.Figure(data=[go.Surface(z=J, x=W1, y=W2,
                                 contours={"x": {"show": True},
                                           "y": {"show": True}},
                                 showscale=False,
                                 colorscale=cs)])
fig.update_layout(title='Loss function', autosize=False,
                  width=500, height=700,
                  margin=dict(l=65, r=50, b=65, t=90))
fig.show()
# an alternative no-plotly way of generating the 3d surface
fig = plt.figure(figsize=(14,8))
ax = fig.add_subplot(projection='3d')
ax.plot_surface(W1, W2, J, rstride=1, cstride=1,
cmap='autumn_r', edgecolor='none')
Here we study the different types of loss functions and how they differ from each other. In particular, the 0-1 loss is the one we would ideally like to minimize. However, as shown later, the objective function that uses the 0-1 loss is hard to optimize over. In the example below, we assume that we have only one training example, whose label is +1.
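Written out, with $f = \mathbf{w}^\top\mathbf{x}$ the linear score and $y \in \{-1,+1\}$ the label, the four losses implemented below are:

$$\ell_{0\text{-}1}(f,y) = \mathbb{1}[fy \le 0], \qquad \ell_{sq}(f,y) = \frac{1}{2}(f-y)^2,$$

$$\ell_{log}(f,y) = \log\left(1+e^{-fy}\right), \qquad \ell_{hinge}(f,y) = \max(0,\, 1-fy).$$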
def zerooneloss(f,y):
    if f*y > 0:
        return 0
    else:
        return 1
def squaredloss(f,y):
    return 0.5*np.power(f-y,2)
def logisticloss(f,y):
    return np.log(1+np.exp(-f*y))
    #return np.exp(-f*y)
def hingedloss(f,y):
    return max(0,1-f*y)
wtx = np.linspace(-1,3,500)
lossfuncs = {'0-1':zerooneloss,'squared':squaredloss,'logistic':logisticloss,'hinged':hingedloss}
losses = {}
for l,func in lossfuncs.items():
    loss = []
    for f in wtx:
        loss.append(func(f,1))
    losses[l] = loss
df_loss = pd.DataFrame.from_dict(losses,orient='columns')
df_loss['x'] = wtx
fig = px.line(df_loss, x='x', y=list(lossfuncs.keys()))
fig.show()
The conclusion from the above is that while the 0-1 loss is the loss we actually want to optimize, its shape is not conducive to optimization methods. The other loss functions approximate this step function with surrogates that are easier to optimize over.
def plotBoundary(X,y,w,ax):
    h = .02 # step size in the mesh
    X1_min, X1_max = X[:, 0].min() - .5, X[:, 0].max() + .5
    X2_min, X2_max = X[:, 1].min() - .5, X[:, 1].max() + .5
    X1_, X2_ = np.meshgrid(np.arange(X1_min, X1_max, h),np.arange(X2_min, X2_max, h))
    # prepend a bias column and compute the sign of the linear score at every grid point
    Xpred = np.c_[np.ones(len(X1_.ravel())),X1_.ravel(), X2_.ravel()]
    ypred = np.dot(Xpred,w)
    ypred_ = np.zeros(ypred.shape)
    ypred_[ypred >= 0] = 1
    ypred_[ypred < 0] = -1
    ypred = ypred_.reshape(X1_.shape)
    cm_bright = ListedColormap(['#FF0000', '#0000FF'])
    ax.pcolormesh(X1_, X2_, ypred, cmap=cm_bright, alpha=.1, shading='auto')
    sp = ax.scatter(X[:, 0], X[:, 1], c=y.flatten(), cmap=cm_bright)
    ax.set_xlim(X1_.min(), X1_.max())
    ax.set_ylim(X2_.min(), X2_.max())
    ax.set_xticks(())
    ax.set_yticks(())
# prepare data
# make_blobs is a function provided by sklearn to generate synthetic data
X, y = make_blobs(n_samples=20, centers=2, n_features=2, random_state=0)
X = StandardScaler().fit_transform(X)
y[y == 0] = -1
y = y[:,np.newaxis]
plt.scatter(X[:,0],X[:,1],c=y.flatten(),cmap='jet')
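The helpers computeLoss, computeMistakes, and computeGradient are imported from the local perceptronutils module, whose source is not shown here. Below is a minimal sketch of what they might look like, assuming they prepend a bias column internally and use the squared loss $J(\mathbf{w}) = \frac{1}{2}\lVert X\mathbf{w} - \mathbf{y}\rVert^2$; the gradients printed in the output below are consistent with $\nabla J = X^\top(X\mathbf{w} - \mathbf{y})$, and each iteration then takes a gradient descent step $\mathbf{w} \leftarrow \mathbf{w} - \eta\,\nabla J(\mathbf{w})$.
# hypothetical stand-ins for the perceptronutils helpers (assumed squared loss)
def addbias(X):
    # prepend a column of ones so that w[0] acts as the bias term
    return np.hstack([np.ones((X.shape[0],1)),X])
def computeLoss(X,y,w):
    # squared loss: 0.5*||Xw - y||^2
    f = np.dot(addbias(X),w)
    return 0.5*np.sum(np.power(f - y,2))
def computeMistakes(X,y,w):
    # number of points on the wrong side of the boundary
    f = np.dot(addbias(X),w)
    return np.sum(y*f < 0)
def computeGradient(X,y,w):
    # gradient of the squared loss: X^T (Xw - y)
    Xb = addbias(X)
    return np.dot(Xb.T,np.dot(Xb,w) - y)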
eta = 0.02
# initialize w
winit = np.array([1,1,1])
winit = winit[:,np.newaxis]
w = winit
losses = []
mistakes = []
numiters = 20
for iter in range(numiters):
    print("After iteration %d"%iter)
    # compute loss
    losses.append(computeLoss(X,y,w))
    # compute number of mistakes
    mistakes.append(computeMistakes(X,y,w))
    # compute gradient
    grad = computeGradient(X,y,w)
    print(losses)
    print(mistakes)
    print(grad)
    print(w)
    fig = plt.figure(figsize=(12, 12))
    ax = fig.add_subplot(numiters,3,3*iter+1)
    # plot current boundary
    plotBoundary(X,y,w,ax)
    ax.set_title('Iteration %d'%iter)
    # plot losses
    ax = fig.add_subplot(numiters,3,3*iter+2)
    ax.plot(range(len(losses)),losses,'-+')
    ax.set_xlim([0,numiters])
    ax.set_title('Loss %.2f'%losses[iter])
    # plot mistakes
    ax = fig.add_subplot(numiters,3,3*iter+3)
    ax.plot(range(len(mistakes)),mistakes,'-o')
    ax.set_xlim([0,numiters])
    ax.set_title('Mistakes %d'%mistakes[iter])
    # update weight with a gradient descent step
    w = w - eta*grad
(Per-iteration output, listing the loss, number of mistakes, gradient, and weights at each step, omitted for brevity: over the 20 iterations the loss decreases from 44.19 to 2.62 and the number of mistakes drops from 13 to 1.)