Ohshiro, Angelaki and DeAngelis (2011) Nature Neuroscience
A while ago I came across a cool paper by Ohshiro, Angelaki and DeAngelis on a model of multisensory integration [1], and I finally managed to dive into it. I think it is a beautiful piece of work, so I decided to put together a notebook on it. Also, one of my lab mates asked me for advice on testing this model, so perhaps this notebook will be useful for people who just want to get started.
The model uses divisive normalization as a key operation, and explains a variety of observations about the response of multisensory neurons. Before moving on, I would really recommend reading the paper, for several reasons. First, I think it is really well written. Second, I will not be able to explain the model better than the authors themselves. Finally, I might have missed something important. This notebook is not intended as a self-contained document. To reinforce that, I won't write much. In fact, I will just reproduce the figures. The code will be moderately commented.
The paper presents two models. The first one - the spatial model - deals with multisensory integration of cues to spatial position. The second deals with multisensory integration of visual and vestibular cues to heading direction.
Both models rely on the same general architecture. We start with two layers of unisensory neurons, each tuned to a different cue, say auditory and visual. The activity of these neurons increases sublinearly with stimulus intensity, as is commonly observed in the brain. The unisensory responses are then combined linearly by neurons in the multisensory layer, and the resulting activities are divisively normalized by the pooled activity of the multisensory population.
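In shorthand (my paraphrase of the operation implemented in normalizePopResp below - see the paper for the exact equations): if E_j is the net excitatory drive to multisensory neuron j, i.e. a weighted sum d1*input1 + d2*input2 of the two unisensory inputs, then its normalized response is R_j = E_j^n / (alpha^n + mean_k(E_k^n)), where the mean runs over the whole multisensory population, n is an exponent and alpha is a semi-saturation constant.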
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
Following the structure of the paper, we start with the spatial model. First, let's define some functions to (1) compute the unisensory response, (2) compute the multisensory response, and (3) apply divisive normalization; a quick check of how they chain together follows the parameter definitions below.
def Gaussian(thetaHat, theta, sigma):
    ''' 2-d Gaussian for unisensory receptive fields centered at thetaHat '''
    return np.exp(-1.0 * ((theta[0] - thetaHat[0])**2 + (theta[1] - thetaHat[1])**2) / (2*sigma**2))

def computeUniPopResp(thetaHat, theta, c, sigma=2.0):
    ''' unisensory population response to a stimulus at theta with intensity c '''
    return c * Gaussian(thetaHat, theta, sigma)

def computeMultiPopResp(uniPopResp, D):
    ''' multisensory response: weighted sum of the two unisensory responses,
        for every combination of dominance weights in D = [d1 values, d2 values] '''
    uniPopShape = np.shape(uniPopResp[0])
    multiPopResp = np.zeros((uniPopShape[0], uniPopShape[1], len(D[0]), len(D[1])))
    for i, d0 in enumerate(D[0]):
        for j, d1 in enumerate(D[1]):
            multiPopResp[:,:,i,j] = d0 * uniPopResp[0] + d1 * uniPopResp[1]
    return multiPopResp

def normalizePopResp(popResp, n, alpha):
    ''' divisive normalization; the normalization pool is the whole multisensory population '''
    popResp = popResp**n   # avoid modifying the caller's array in place
    return popResp / (alpha**n + np.mean(popResp))
sigma = 2.0 # sd for Gaussian RFs
thetaX, thetaY = np.meshgrid(np.arange(29)+1, np.arange(29)+1) # RF positions
thetaHat = [thetaX, thetaY]
theta = [6.0, 6.0] # stimulus position
n = np.arange(3)+1 # exponents for divisive normalization
alpha = 1. # semi-saturation constant
c = np.zeros((12,)) # stimulus intensities
c[1:] = 2**np.arange(11, dtype=float)
cX, cY = np.meshgrid(np.arange(len(c)-1), np.arange(len(c)-1))
D = [np.linspace(0,1,5), np.linspace(0,1,5)] # modality weights [d1, d2]
R = np.zeros((len(c), len(c), len(n)))
xi, yi = 5, 5 # indices to the preferred neuron (thetaHat = theta)
di, dj = 4, 4 # indices into D for the case d1 = d2 = 1
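Before running the full intensity sweep, here is a quick sanity check (my own addition, not part of the paper) of how the three functions chain together and what shape comes out:
# both cues at intensity c[3] = 4.0, exponent n = 2
uniPopR = [np.sqrt(computeUniPopResp(thetaHat, theta, c[3])),
           np.sqrt(computeUniPopResp(thetaHat, theta, c[3]))]
multiPopR = normalizePopResp(computeMultiPopResp(uniPopR, D), 2, alpha)
print(multiPopR.shape)            # (29, 29, 5, 5): RF position (2 axes) x d1 index x d2 index
print(multiPopR[xi, yi, di, dj])  # response of the neuron tuned to the stimulus, with d1 = d2 = 1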
for i, c1 in enumerate(c):
    for j, c2 in enumerate(c):
        # square root: sublinear growth of the unisensory responses with intensity
        uniPopR = [np.sqrt(computeUniPopResp(thetaHat, theta, c1)),
                   np.sqrt(computeUniPopResp(thetaHat, theta, c2))]
        for k, nk in enumerate(n):
            multiPopR = normalizePopResp(computeMultiPopResp(uniPopR, D), nk, alpha)
            R[i, j, k] = multiPopR[xi,yi,di,dj]
fig1 = plt.figure(figsize=(8,8))
# panel a: bimodal response as a function of the two stimulus intensities (exponent n = 2)
plt.subplot(2,2,1)
plt.imshow(R[1:,1:,1], interpolation='bilinear', origin='lower', clim=(0,110))
plt.colorbar()
levels = np.linspace(np.amin(R[:,:,1]), np.amax(R[:,:,1]), 24)
plt.contour(cX, cY, R[1:,1:,1], levels, origin='lower', colors='k')
plt.xlim(0,10)
plt.ylim(0,10)
# panel b: bimodal response (black), unimodal responses (red, blue) and their sum (dashed) vs. intensity
plt.subplot(2,2,2)
plt.plot(np.diag(R[1:,1:,1]),'k.-')
plt.plot(R[0,1:,1]+2,'r.-') # +2 to add offset for visibility
plt.plot(R[1:,0,1],'b.-')
plt.plot(R[0,1:,1] + R[1:,0,1],'k--')
plt.ylim(0,200)
plt.ylabel('Firing rate')
plt.xlabel('Log intensity: input 1,2')
# panel c: additivity index = bimodal response / sum of the two unimodal responses
addIndex = R[1:,1:,1] / (R[0,1:,1].reshape(len(c)-1,-1) + R[1:,0,1].reshape(-1,len(c)-1))
plt.subplot(2,2,3)
plt.imshow(addIndex, interpolation='bilinear', origin='lower', clim=(0,2))
plt.colorbar()
plt.contour(cX, cY, addIndex, origin='lower', colors='k')
plt.xlim(0,10)
plt.ylim(0,10)
# panel d: additivity index along the diagonal for the three exponents n = 1 (blue), 2 (black), 3 (red)
plt.subplot(2,2,4)
lineSpec = ['b.-','k.-','r.-']
plt.plot([0,10],[1,1],'k--')
for k in np.arange(len(n)):
plt.plot(np.diag(R[1:,1:,k])/(R[0,1:,k]+R[1:,0,k]), lineSpec[k])
plt.ylim(0.5,4.0)
plt.ylabel('Additivity index')
plt.yscale('log')
plt.yticks([0.5,1,2,4])
plt.gca().get_yaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
plt.xlabel('Log intensity: input 1,2')
plt.tight_layout(pad=0.4, w_pad=1.0, h_pad=2.0)
plt.show()
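Next, we look at the effect of a spatial offset between the two cues (the spatial principle): the second stimulus is displaced from the first by 0 to 3 RF widths, and we again plot the responses and the additivity index as a function of intensity.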
nk = 2 # exponent for divisive normalization
alpha = 1.
theta = 6.0*np.ones((2,))
sigmaOffset = np.arange(4)*sigma # positional offset
R = np.zeros((len(c), len(c), len(sigmaOffset)))
xi, yi = 5, 5
di, dj = 4, 4
for i, c1 in enumerate(c):
for j, c2 in enumerate(c):
for k, dsigma in enumerate(sigmaOffset):
uniPopR = [np.sqrt(computeUniPopResp(thetaHat, theta, c1)),
np.sqrt(computeUniPopResp(thetaHat, theta + np.array([dsigma,0]), c2))]
multiPopR = normalizePopResp(computeMultiPopResp(uniPopR, D), nk, alpha)
R[i, j, k] = multiPopR[xi,yi,di,dj]
fig3 = plt.figure(figsize=(10,6))
for k, dsigma in enumerate(sigmaOffset):
plt.subplot(2,len(sigmaOffset),k+1)
plt.plot(np.diag(R[1:,1:,k]),'k.-')
plt.plot(R[0,1:,k],'r.-')
plt.plot(R[1:,0,k],'b.-')
plt.ylim(0,120)
    plt.title('Offset = ' + str(k) + r'$\sigma$')
if k == 0:
plt.ylabel('Firing rate')
plt.subplot(2,len(sigmaOffset),k+1+len(sigmaOffset))
plt.plot([0,10],[1,1],'k--')
plt.plot(np.diag(R[1:,1:,k])/(R[0,1:,k]+R[1:,0,k]), 'k.-')
plt.ylim(0.25,2.0)
plt.yscale('log')
plt.yticks([0.25,0.5,1,2])
plt.gca().get_yaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
if k == 0:
plt.ylabel('Additivity Index')
plt.xlabel('Log intensity: input 1,2')
plt.tight_layout()
plt.show()
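So far both inputs drove the multisensory neuron with equal weight (d1 = d2 = 1). The next simulation keeps d1 = 1 and lowers the dominance weight of the second modality, d2, from 1 to 0, i.e. it looks at neurons that are increasingly dominated by input 1.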
nk = 2
alpha = 1.
D = [np.linspace(0,1,5), np.linspace(0,1,5)]
di = 4
dj = np.array([4,3,2,0])
R = np.zeros((len(c), len(c), len(dj)))
xi, yi = 5, 5
for i, c1 in enumerate(c):
for j, c2 in enumerate(c):
uniPopR = [np.sqrt(computeUniPopResp(thetaHat, theta, c1)), np.sqrt(computeUniPopResp(thetaHat, theta, c2))]
multiPopR = normalizePopResp(computeMultiPopResp(uniPopR, D), nk, alpha)
for k, djk in enumerate(dj):
R[i, j, k] = multiPopR[xi,yi,di,djk]
fig4 = plt.figure(figsize=(10,6))
for k, djk in enumerate(dj):
plt.subplot(2,len(dj),k+1)
plt.plot(np.diag(R[1:,1:,k]),'k.-')
plt.plot(R[0,1:,k],'r.-')
plt.plot(R[1:,0,k],'b.-')
plt.ylim(0,120)
plt.title(r'$ d_1=1.0, \ d_2= ' + str(D[1][djk]) + '$')
if k == 0:
plt.ylabel('Firing rate')
plt.subplot(2,len(dj),k+1+len(dj))
plt.plot([0,10],[1,1],'k--')
plt.plot(np.diag(R[1:,1:,k])/(R[0,1:,k]+R[1:,0,k]), 'k.-')
plt.ylim(0.25,2.0)
plt.yscale('log')
plt.yticks([0.25,0.5,1,2])
plt.gca().get_yaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
if k == 0:
plt.ylabel('Additivity Index')
plt.xlabel('Log intensity: input 1,2')
plt.tight_layout()
plt.show()
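Finally, we probe the model with two stimuli of the same modality, inputs 1a and 1b, presented at increasing spatial offsets while the other modality receives no input; their responses are summed before the sublinear nonlinearity, so this shows how within-modality combination differs from cross-modal combination.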
nk = 2
alpha = 1.
theta = 6.0*np.ones((2,))
sigmaOffset = np.arange(4)*sigma
R = np.zeros((len(c), len(c), len(sigmaOffset)))
xi, yi = 5, 5
di, dj = 4, 4
for i, c1a in enumerate(c):
for j, c1b in enumerate(c):
for k, dsigma in enumerate(sigmaOffset):
uniPopR1 = np.sqrt(computeUniPopResp(thetaHat, theta, c1a) +
computeUniPopResp(thetaHat, theta + np.array([dsigma,0]), c1b))
uniPopR = [uniPopR1, np.zeros(np.shape(uniPopR1))]
multiPopR = normalizePopResp(computeMultiPopResp(uniPopR, D), nk, alpha)
R[i, j, k] = multiPopR[xi,yi,di,dj]
fig5 = plt.figure(figsize=(10,6))
for k, dsigma in enumerate(sigmaOffset):
plt.subplot(2,len(sigmaOffset),k+1)
plt.plot(np.diag(R[1:,1:,k]),'k.-')
plt.plot(R[0,1:,k],'r.--')
plt.plot(R[1:,0,k],'r.-')
plt.ylim(0,120)
    plt.title('Offset = ' + str(k) + r'$\sigma$')
if k == 0:
plt.ylabel('Firing rate')
plt.subplot(2,len(sigmaOffset),k+1+len(sigmaOffset))
plt.plot([0,10],[1,1],'k--')
plt.plot(np.diag(R[1:,1:,k])/(R[0,1:,k]+R[1:,0,k]), 'k.-')
plt.ylim(0.25,2.0)
plt.yscale('log')
plt.yticks([0.25,0.5,1,2])
plt.gca().get_yaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
if k == 0:
plt.ylabel('Additivity Index')
plt.xlabel('Log intensity: input 1a, 1b')
plt.tight_layout()
plt.show()
The last section of the paper looks at multisensory integration of vestibular and visual signals. I will work on these simulations at a later time.
[1] Ohshiro T, Angelaki DE, DeAngelis GC (2011) A normalization model of multisensory integration. Nat Neurosci 14(6):775-782.