import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import seaborn as sns
from IPython.display import HTML
def animateLists(xlist, ylist, *, frames=100, stride=10, lim=(-3, 3)):
    """Animate a 2-D sample trace as a growing seaborn joint plot.

    Adapted from
    https://stackoverflow.com/questions/46236902/redrawing-seaborn-figures-for-animations

    Frame ``i`` clears the grid and redraws a scatter plot of the first
    ``2 + i*stride`` samples in the joint panel, plus histograms with KDE
    curves in the two marginal panels. The first sample
    ``(xlist[0], ylist[0])`` is marked with a star in every frame.

    Parameters
    ----------
    xlist, ylist : sequences of float
        Sample coordinates, paired by index.
    frames : int, optional
        Number of animation frames (default 100).
    stride : int, optional
        Number of additional samples revealed per frame (default 10).
    lim : tuple of float, optional
        Axis limits used for both x and y (default (-3, 3)).

    Returns
    -------
    matplotlib.animation.FuncAnimation
        The animation bound to the JointGrid's figure.

    Raises
    ------
    ValueError
        If ``xlist`` is empty (there is no first sample to mark).
    """
    if len(xlist) == 0:
        raise ValueError("animateLists needs at least one sample")

    def mark_start(ax):
        # Star on the first sample. Previously this used (x[0], x[1]), i.e.
        # two x-coordinates, which put the marker in the wrong place.
        ax.plot(xlist[0], ylist[0], '*C1', ms=20,
                markeredgewidth=1, markeredgecolor='k')

    def prep_axes(g, xlim, ylim):
        # Clear all three panels and restore limits; clear() also resets the
        # tick styling, so the marginal tick labels/lines are hidden again.
        g.ax_joint.clear()
        g.ax_joint.set_xlim(xlim)
        g.ax_joint.set_ylim(ylim)
        g.ax_marg_x.clear()
        g.ax_marg_x.set_xlim(xlim)
        g.ax_marg_y.clear()
        g.ax_marg_y.set_ylim(ylim)
        plt.setp(g.ax_marg_x.get_xticklabels(), visible=False)
        plt.setp(g.ax_marg_y.get_yticklabels(), visible=False)
        plt.setp(g.ax_marg_x.yaxis.get_majorticklines(), visible=False)
        plt.setp(g.ax_marg_x.yaxis.get_minorticklines(), visible=False)
        plt.setp(g.ax_marg_y.xaxis.get_majorticklines(), visible=False)
        plt.setp(g.ax_marg_y.xaxis.get_minorticklines(), visible=False)
        plt.setp(g.ax_marg_x.get_yticklabels(), visible=False)
        plt.setp(g.ax_marg_y.get_xticklabels(), visible=False)

    def animate(i):
        prep_axes(g, lim, lim)
        n_shown = 2 + i * stride
        x, y = xlist[:n_shown], ylist[:n_shown]
        g.x, g.y = x, y
        sns.scatterplot(x=x, y=y, ax=g.ax_joint)
        sns.histplot(y=y, ax=g.ax_marg_y, kde=True)
        sns.histplot(x=x, ax=g.ax_marg_x, kde=True)
        mark_start(g.ax_joint)

    # Initial grid built from the first two samples.
    x, y = xlist[:2], ylist[:2]
    g = sns.JointGrid(x=x, y=y, height=4, xlim=lim, ylim=lim)
    mark_start(g.ax_joint)
    ani = animation.FuncAnimation(g.fig, animate, frames=frames, repeat=True)
    return ani
def logf(x, *, sigma=None):
    """Return the log-density of a zero-mean multivariate normal at ``x``.

    Computes ``-0.5*log det(2*pi*sigma) - 0.5 * x^T sigma^{-1} x``, so that
    ``exp(logf(x))`` is a properly normalized density.

    Parameters
    ----------
    x : array_like, shape (d,)
        Point at which to evaluate the log-density.
    sigma : array_like, shape (d, d), optional
        Covariance matrix; defaults to the d x d identity (standard normal).

    Returns
    -------
    float
        The log-density value.
    """
    x = np.asarray(x, dtype=float)
    if sigma is None:
        sigma = np.eye(x.shape[0])
    sigma = np.asarray(sigma, dtype=float)
    # Log of the normalizing constant: log((2*pi)^d * det(sigma)). This must
    # be a log, not a sqrt, for exp(logf) to integrate to 1.
    _, logdet = np.linalg.slogdet(2 * np.pi * sigma)
    # Quadratic form x^T sigma^{-1} x via a linear solve (no explicit inverse).
    quad = x @ np.linalg.solve(sigma, x)
    return -0.5 * logdet - 0.5 * quad
# Heatmap of exp(logf) on a 51 x 51 grid over [-2, 2] x [-2, 2].
density_axis = np.linspace(-2, 2, 51)
xs = np.array([(x, y) for x in density_axis for y in density_axis])
ys = np.array([np.exp(logf(x)) for x in xs])
# xs varies y fastest, so the reshaped array is indexed [x_index, y_index].
# imshow draws array rows vertically, so transpose to put x on the horizontal
# axis, and use origin="lower" so y increases upward.
plt.imshow(ys.reshape(51, 51).T, origin="lower",
           cmap=sns.light_palette("Navy", as_cmap=True))
plt.colorbar()
plt.xlabel("x")
plt.ylabel("y")
plt.xticks(range(0, 51, 10), np.round(density_axis[::10], 1))
plt.yticks(range(0, 51, 10), np.round(density_axis[::10], 1))
plt.show()
# Random-walk Metropolis sampling of exp(logf), starting from a uniform draw
# in [-1, 1) x [-1, 1).
x = np.array([2*np.random.random() - 1, 2*np.random.random() - 1])
logf_old = logf(x)
xlist, ylist = [], []
# parameters
N = 10_000
step = 0.1  # proposal *variance* per coordinate: the covariance below is step * I
accept = 0
for n in range(N):
    # Record the current state every iteration, including repeats after a
    # rejected proposal.
    xlist.append(x[0])
    ylist.append(x[1])
    # take a random (normal) step centred on the current state
    xnew = np.random.multivariate_normal(x, [[step, 0], [0, step]])
    logf_new = logf(xnew)
    # Metropolis acceptance test in log space: accept with prob min(1, f'/f).
    if np.log(np.random.random()) < logf_new - logf_old:
        x = xnew
        logf_old = logf_new
        accept += 1
print("acceptance:", accept/N)
# Output recorded from one run:
# acceptance: 0.8427
# Joint scatter of the second half of the chain (the first half is discarded,
# presumably as burn-in), with KDE curves over the marginal histograms.
kept_x = xlist[len(xlist) // 2:]
kept_y = ylist[len(ylist) // 2:]
sns.jointplot(
    x=kept_x,
    y=kept_y,
    xlim=(-3, 3),
    ylim=(-3, 3),
    marginal_kws={"kde": True},
)
plt.show()
# Same second-half samples, with a 2-D histogram in the joint panel and
# KDE-overlaid histograms in the two margins.
lim = (-3, 3)
x = xlist[len(xlist) // 2:]
y = ylist[len(ylist) // 2:]
g = sns.JointGrid(x=x, y=y, xlim=lim, ylim=lim, height=4)
sns.histplot(ax=g.ax_joint, x=x, y=y)
sns.histplot(ax=g.ax_marg_y, y=y, kde=True)
sns.histplot(ax=g.ax_marg_x, x=x, kde=True)
plt.show()
# Build the frame-by-frame animation of the chain and wrap its JavaScript/HTML
# rendering in an IPython HTML object.
ani = animateLists(xlist, ylist)
ani_html = ani.to_jshtml()
HTML(ani_html)
def logf(x, *, r=1.5, sigma=0.01):
    """Return the unnormalized log-density of a ring centred at the origin.

    ``exp(logf(x))`` is proportional to ``exp(-(r - |x|)^2 / sigma)``: a
    Gaussian ridge in the radius ``|x|``, peaked at ``|x| == r``, with
    variance ``sigma / 2``. No normalizing constant is included. A sampler
    that only uses differences of ``logf`` does not need one.

    Parameters
    ----------
    x : array_like
        Point at which to evaluate; its Euclidean norm is used as the radius.
    r : float, optional
        Radius of the ring (default 1.5).
    sigma : float, optional
        Width parameter of the ridge (default 0.01); smaller means thinner.

    Returns
    -------
    float
        The unnormalized log-density (0 exactly on the ring).
    """
    rx = np.linalg.norm(x)
    return -(r - rx)**2 / sigma
# Heatmap of exp(logf) on a 501 x 501 grid over [-3, 3] x [-3, 3].
density_axis = np.linspace(-3, 3, 501)
xs = np.array([(x, y) for x in density_axis for y in density_axis])
ys = np.array([np.exp(logf(x)) for x in xs])
# xs varies y fastest, so the reshaped array is indexed [x_index, y_index].
# imshow draws array rows vertically, so transpose to put x on the horizontal
# axis, and use origin="lower" so y increases upward.
plt.imshow(ys.reshape(501, 501).T, origin="lower",
           cmap=sns.light_palette("Navy", as_cmap=True))
plt.colorbar()
plt.xlabel("x")
plt.ylabel("y")
plt.xticks(range(0, 501, 100), np.round(density_axis[::100], 1))
plt.yticks(range(0, 501, 100), np.round(density_axis[::100], 1))
plt.show()
# Random-walk Metropolis sampling of the ring target. The chain starts from a
# uniform draw in [-1, 1) x [-1, 1); x0 keeps that starting point.
x0 = np.array([2*np.random.random() - 1, 2*np.random.random() - 1])
x = x0.copy()
logf_old = logf(x)
xlist, ylist = [], []

# Sampler parameters: N iterations; `step` is the per-coordinate proposal
# variance (the covariance handed to multivariate_normal is step * I);
# `accept` counts accepted proposals.
N, step, accept = 10_000, 0.1, 0

for n in range(N):
    # Store the current state before proposing, so rejections repeat it.
    xlist.append(x[0])
    ylist.append(x[1])
    # Gaussian proposal centred on the current state.
    xnew = np.random.multivariate_normal(x, [[step, 0], [0, step]])
    logf_new = logf(xnew)
    log_ratio = logf_new - logf_old
    # Accept when log(u) < log f(xnew) - log f(x), u ~ Uniform[0, 1).
    if np.log(np.random.random()) < log_ratio:
        x, logf_old = xnew, logf_new
        accept += 1
print(&q