import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as sts
from celluloid import Camera
from IPython.display import HTML
import jax.numpy as jnp
import jax.scipy as jsc
from jax.nn import softmax

# Notebook-wide figure defaults: wide, short canvases on a white background.
plt.rc('figure', figsize=(10.0, 3.0), dpi=120, facecolor="w")
# Fixed seed so the numpy-sampled data is reproducible between runs.
np.random.seed(222)


def kde(x, data, h):
    """Evaluate a Gaussian kernel density estimate of ``data`` at ``x``.

    Args:
        x: array of evaluation points; it is flattened into a column so that
            every point is compared against every entry of ``data``.
        data: array of samples, each used as the centre (``loc``) of a kernel.
        h: kernel width, passed as the normal distribution's ``scale``.

    Returns:
        One density value per evaluation point: the mean of the kernel pdfs
        over ``data``.
    """
    # (n_points, 1) against (n_samples,) broadcasts to (n_points, n_samples).
    kernel_values = jsc.stats.norm.pdf(x.reshape(-1, 1), loc=data, scale=h)
    return jnp.mean(kernel_values, axis=1)


def kde_hist(events, bins, bandwidth=None, density=False):
    """Histogram ``events`` by integrating a Gaussian KDE over each bin.

    Each event contributes a normal distribution centred on it; the content
    of a bin is the sum over events of that normal's cdf evaluated at the
    bin's upper edge minus the cdf at its lower edge.

    Args:
        events: (jax array-like) data to filter.
        bins: (jax array-like) bin edges; consecutive pairs define the
            intervals in which counts are calculated.
        bandwidth: (float) width (``scale``) of the individual kernels whose
            cdfs are summed over each bin. When ``None`` it defaults to
            ``events.shape[-1] ** -0.25``. Should be positive.
        density: (bool) whether or not to normalize the histogram to unit
            area over the binned range.

    Returns:
        Array with one entry per bin: the kde-smoothed counts, or the
        normalized density if ``density`` is true.
    """
    # Compare against None explicitly: `bandwidth or default` would call
    # bool() on the value, which fails for a traced JAX value (jit/grad with
    # respect to the bandwidth) and silently overrides an explicit 0.
    if bandwidth is None:
        # NOTE(review): the original comment calls this "Scott's rule -- the
        # same as the scipy implementation of kde". scipy's gaussian_kde uses
        # n ** (-1 / (d + 4)) times the data spread, i.e. n ** -0.2 in 1-D,
        # so this default is not numerically the same. Kept unchanged for
        # backward compatibility -- confirm the intended rule.
        bandwidth = events.shape[-1] ** -0.25

    # Evaluate every kernel's cdf once at every bin edge:
    # (n_edges, 1) against (n_events,) -> (n_edges, n_events).
    edge_cdf = jsc.stats.norm.cdf(bins.reshape(-1, 1), loc=events, scale=bandwidth)

    # Area of each kernel inside a bin = cdf(upper edge) - cdf(lower edge);
    # summing over events gives the kde contribution to that bin.
    counts = (edge_cdf[1:] - edge_cdf[:-1]).sum(axis=1)

    if density:
        # Normalize by bin width and total counts so the total area is 1.
        db = jnp.array(jnp.diff(bins), float)  # bin spacing
        return counts / db / counts.sum(axis=0)

    return counts


# This presentation is a Jupyter notebook!
# (thanks to https://github.com/damianavila/RISE )

# Demo cell: a bare expression, shown as cell output in the notebook.
2*2

# ---------------------------------------------------------------------------
# Animation 1: a standard histogram of the data as the true mean slides.
# ---------------------------------------------------------------------------
fig = plt.figure()
cam = Camera(fig)
plt.xlim([-1, 4])
plt.axis('off')

# Six equal-width bins on [-1, 4] and their midpoints.
bins = np.linspace(-1, 4, 7)
centers = bins[:-1] + np.diff(bins) / 2.0

# Fine grid for drawing smooth curves, and the sequence of means to animate.
grid = np.linspace(-1, 4, 500)
mu_range = np.linspace(1, 2, 100)

# One fixed set of standard-normal samples; each frame shifts it by mu.
data = np.random.normal(size=100)

# True pdf on the grid for every mean: one row per entry of mu_range.
truths = sts.norm(loc=mu_range.reshape(-1, 1)).pdf(grid)

for i, mu in enumerate(mu_range):
    # plot true data distribution
    plt.plot(grid, truths[i], color='C1')
    # histogram data
    plt.hist(data + mu, bins=bins, density=True, color='C0', alpha=0.6)
    # mark the current mean
    plt.axvline(mu, color='slategray', linestyle=':', alpha=0.6)
    cam.snap()

animation = cam.animate()
HTML(animation.to_html5_video())

# ---------------------------------------------------------------------------
# Animation 2: the same scan, using the kde curve and the kde-based histogram.
# ---------------------------------------------------------------------------
bw = 0.6

fig = plt.figure()
cam = Camera(fig)
plt.xlim([-1, 4])
plt.axis('off')

for i, mu in enumerate(mu_range):
    # true data distribution
    plt.plot(grid, truths[i], color='C1')
    # kernel density estimate of the shifted data
    plt.plot(grid, kde(grid, data + mu, h=bw), color='C9', linestyle=':')
    # histogram data (bin contents from the kde); bars span the 5-unit range
    plt.bar(
        centers,
        kde_hist(data + mu, bins=bins, bandwidth=bw, density=True),
        color='C9',
        width=5 / (len(bins) - 1),
        alpha=0.6,
    )
    # mark the current mean
    plt.axvline(mu, color='slategray', linestyle=':', alpha=0.6)
    cam.snap()

animation = cam.animate()
# animation.save('ani.gif',writer='imagemagick')
HTML(animation.to_html5_video())

# Optional jax demo