import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import zscore
from sklearn.decomposition import PCA

# @title Figure settings
from matplotlib import rcParams

rcParams['figure.figsize'] = [20, 4]
rcParams['font.size'] = 15
rcParams['axes.spines.top'] = False
rcParams['axes.spines.right'] = False
rcParams['figure.autolayout'] = True

# @title Data retrieval
import os
import requests

# The dataset is split across three .npz files; fname[j] is downloaded from
# url[j] unless a file with that name already exists in the working directory.
fname = ['steinmetz_part%d.npz' % j for j in range(3)]
url = [
    "https://osf.io/agvxh/download",
    "https://osf.io/uv3mw/download",
    "https://osf.io/ehmw2/download",
]

# Seconds to wait for the connection / between received bytes. Without a
# timeout, requests.get can block forever on a stalled connection.
download_timeout = 60

for fn, u in zip(fname, url):
    if os.path.isfile(fn):
        continue
    try:
        r = requests.get(u, timeout=download_timeout)
    except (requests.ConnectionError, requests.Timeout):
        print("!!! Failed to download data !!!")
    else:
        if r.status_code != requests.codes.ok:
            print("!!! Failed to download data !!!")
        else:
            with open(fn, "wb") as fid:
                fid.write(r.content)

# @title Data loading
# Each part stores an array of per-recording dicts under the key 'dat', which
# needs allow_pickle=True to read. NOTE: unpickling can execute arbitrary code,
# so only load these files from the trusted source listed above.
parts = []
for fn in fname:
    # NpzFile keeps the underlying file open; the context manager closes it
    # once 'dat' has been read into memory.
    with np.load(fn, allow_pickle=True) as npz:
        parts.append(npz['dat'])
alldat = np.hstack(parts)

# Make a plot of which brain areas are present in each dataset.
# regions[k] and region_colors[k] name and color brain_groups[k] for k = 0..6.
# "other ctx" (index 3) is cortex outside the visual areas, not unclassified
# neurons. The final entry, "other", has no area list of its own (presumably a
# catch-all label for areas not listed below).
regions = ["vis ctx", "thal", "hipp", "other ctx", "midbrain", "basal ganglia",
           "cortical subplate", "other"]
region_colors = ['blue', 'red', 'green', 'darkblue', 'violet', 'lightblue',
                 'orange', 'gray']
brain_groups = [
    ["VISa", "VISam", "VISl", "VISp", "VISpm", "VISrl"],  # visual cortex
    ["CL", "LD", "LGd", "LH", "LP", "MD", "MG", "PO", "POL", "PT", "RT", "SPF",
     "TH", "VAL", "VPL", "VPM"],  # thalamus
    ["CA", "CA1", "CA2", "CA3", "DG", "SUB", "POST"],  # hippocampal
    ["ACA", "AUD", "COA", "DP", "ILA", "MOp", "MOs", "OLF", "ORB", "ORBm",
     "PIR", "PL", "SSp", "SSs", "RSP", "TT"],  # non-visual cortex
    ["APN", "IC", "MB", "MRN", "NB", "PAG", "RN", "SCs", "SCm", "SCig",
     "SCsg", "ZI"],  # midbrain
    ["ACB", "CP", "GPe", "LS", "LSc", "LSr", "MS", "OT", "SNr", "SI"],  # basal ganglia
    ["BLA", "BMA", "EP", "EPd", "MEA"],  # cortical subplate
]

# Assign each area a row index; "root" (not in any group) gets row 0.
area_to_index = dict(root=0)
counter = 1
for group in brain_groups:
    for area in group:
        area_to_index[area] = counter
        counter += 1
# counter is now the number of rows: "root" plus every listed area.

# Figure out which areas are in each dataset
areas_by_dataset = np.zeros((counter, len(alldat)), dtype=bool)
for j, d in enumerate(alldat):
    for area in np.unique(d['brain_area']):
        # Raises KeyError for an area that is not "root" or in brain_groups.
        i = area_to_index[area]
        areas_by_dataset[i, j] = True

# Show the binary matrix
plt.figure(figsize=(8, 10))
plt.imshow(areas_by_dataset, cmap="Greys", aspect="auto", interpolation="none")

# Label the axes
plt.xlabel("dataset")
plt.ylabel("area")

# Add tick labels, in the same order as area_to_index
yticklabels = ["root"]
for group in brain_groups:
    yticklabels.extend(group)
plt.yticks(np.arange(counter), yticklabels, fontsize=8)
plt.xticks(np.arange(len(alldat)), fontsize=9)

# Color each area's tick label by its region; "root" stays black.
ytickobjs = plt.gca().get_yticklabels()
ytickobjs[0].set_color("black")
tick = 1
for group, color in zip(brain_groups, region_colors):
    for area in group:
        ytickobjs[tick].set_color(color)
        tick += 1

plt.title("Brain areas present in each dataset")
plt.grid(True)
plt.show()

# @title Basic plots of population average
# select just one of the recordings here. 11 is nice because it has some neurons in vis ctx.
dat = alldat[11]
print(dat.keys())

dt = dat['bin_size']  # binning at 10 ms
NT = dat['spks'].shape[-1]
# dat['spks'] is indexed (neuron, trial, time bin): axis 1 is selected with the
# per-trial masks below and axis 2 has NT bins.

ax = plt.subplot(1, 5, 1)
response = dat['response']  # right - nogo - left (-1, 0, 1)
vis_right = dat['contrast_right']  # 0 - low - high
vis_left = dat['contrast_left']  # 0 - low - high
t = dt * np.arange(NT)
# Multiplying the mean count per bin by 1/dt gives a rate in Hz.
# Left responses are response > 0 only; ">= 0" would also pull in no-go trials.
plt.plot(t, 1/dt * dat['spks'][:, response > 0].mean(axis=(0, 1)))  # left responses
plt.plot(t, 1/dt * dat['spks'][:, response < 0].mean(axis=(0, 1)))  # right responses
plt.plot(t, 1/dt * dat['spks'][:, vis_right > 0].mean(axis=(0, 1)))  # stimulus on the right
plt.plot(t, 1/dt * dat['spks'][:, vis_right == 0].mean(axis=(0, 1)))  # no stimulus on the right
plt.legend(['left resp', 'right resp', 'right stim', 'no right stim'], fontsize=12)
ax.set(xlabel='time (sec)', ylabel='firing rate (Hz)')
plt.show()

nareas = 4  # only the top 4 regions are in this particular mouse
NN = len(dat['brain_area'])  # number of neurons
# Region label per neuron: j if its area is in brain_groups[j] (j < nareas),
# otherwise nareas ("other").
barea = nareas * np.ones(NN, )
for j in range(nareas):
    barea[np.isin(dat['brain_area'], brain_groups[j])] = j  # assign a number to each region

# @title plots by brain region and visual conditions
for j in range(nareas):
    ax = plt.subplot(1, nareas, j + 1)
    ax.set(title=regions[j])
    in_area = barea == j
    # Skip regions with no neurons in this recording (their mean would be NaN),
    # matching the response-type plots below.
    if not np.any(in_area):
        continue
    area_spks = dat['spks'][in_area]
    plt.plot(1/dt * area_spks[:, np.logical_and(vis_left == 0, vis_right > 0)].mean(axis=(0, 1)))
    plt.plot(1/dt * area_spks[:, np.logical_and(vis_left > 0, vis_right == 0)].mean(axis=(0, 1)))
    plt.plot(1/dt * area_spks[:, np.logical_and(vis_left == 0, vis_right == 0)].mean(axis=(0, 1)))
    plt.plot(1/dt * area_spks[:, np.logical_and(vis_left > 0, vis_right > 0)].mean(axis=(0, 1)))
    plt.text(.25, .92, 'n=%d' % np.sum(in_area), transform=ax.transAxes)
    if j == 0:
        plt.legend(['right only', 'left only', 'neither', 'both'], fontsize=12)
    ax.set(xlabel='binned time', ylabel='mean firing rate (Hz)')
plt.show()

# @title plots by brain region and response type
for j in range(nareas):
    ax = plt.subplot(1, nareas, j + 1)
    plt.title(regions[j])
    in_area = barea == j
    if not np.any(in_area):
        continue
    area_spks = dat['spks'][in_area]
    # Plot order must match the legend: response > 0 is left, response < 0 is right.
    plt.plot(1/dt * area_spks[:, response > 0].mean(axis=(0, 1)))
    plt.plot(1/dt * area_spks[:, response < 0].mean(axis=(0, 1)))
    plt.plot(1/dt * area_spks[:, response == 0].mean(axis=(0, 1)))
    if j == 0:
        plt.legend(['resp = left', 'resp = right', 'resp = none'], fontsize=12)
    ax.set(xlabel='binned time', ylabel='mean firing rate (Hz)')
plt.show()

# @title top PC directions from stimulus + response period, with projections of the entire duration
n_pcs = 5
# Fit PCA on time bins 51..129 only (79 bins; 0.79 s at 10 ms bins), treating
# every (trial, bin) pair in that window as one sample of the NN-neuron vector.
droll = np.reshape(dat['spks'][:, :, 51:130], (NN, -1))
droll = droll - np.mean(droll, axis=1)[:, np.newaxis]
model = PCA(n_components=n_pcs).fit(droll.T)
W = model.components_  # (n_pcs, NN)
# Project the full (uncentered) recording onto the PCs -> (PC, trial, time).
pc_10ms = W @ np.reshape(dat['spks'], (NN, -1))
pc_10ms = np.reshape(pc_10ms, (n_pcs, -1, NT))

# @title The top PCs capture most variance across the brain. What do they care about?
plt.figure(figsize=(20, 6))
for j in range(len(pc_10ms)):
    pc1 = pc_10ms[j]  # (trial, time) projection onto PC j

    # Top row: trial-averaged projection split by visual condition.
    ax = plt.subplot(2, len(pc_10ms) + 1, j + 1)
    plt.plot(pc1[np.logical_and(vis_left == 0, vis_right > 0), :].mean(axis=0))
    plt.plot(pc1[np.logical_and(vis_left > 0, vis_right == 0), :].mean(axis=0))
    plt.plot(pc1[np.logical_and(vis_left == 0, vis_right == 0), :].mean(axis=0))
    plt.plot(pc1[np.logical_and(vis_left > 0, vis_right > 0), :].mean(axis=0))
    if j == 0:
        plt.legend(['right only', 'left only', 'neither', 'both'], fontsize=8)
    ax.set(xlabel='binned time', ylabel='mean PC projection')
    plt.title('PC %d' % j)

    # Bottom row: trial-averaged projection split by response.
    ax = plt.subplot(2, len(pc_10ms) + 1, len(pc_10ms) + 1 + j + 1)
    plt.plot(pc1[response > 0, :].mean(axis=0))
    plt.plot(pc1[response < 0, :].mean(axis=0))
    plt.plot(pc1[response == 0, :].mean(axis=0))
    if j == 0:
        plt.legend(['resp = left', 'resp = right', 'resp = none'], fontsize=8)
    ax.set(xlabel='binned time', ylabel='mean PC projection')
    plt.title('PC %d' % j)
plt.show()

# @title now sort all trials by response latency and see if the PCs care about that.
isort = np.argsort(dat['response_time'].flatten())
for j in range(len(pc_10ms)):
    ax = plt.subplot(1, len(pc_10ms) + 1, j + 1)
    # z-score each time bin across trials (scipy's default axis=0) before display.
    pc1 = zscore(pc_10ms[j])
    plt.imshow(pc1[isort, :], aspect='auto', vmax=2, vmin=-2, cmap='gray')
    ax.set(xlabel='binned time', ylabel='trials sorted by latency')
    plt.title('PC %d' % j)
plt.show()

# @title correct vs incorrect trials
# the following are the correct responses:
# if vis_left > vis_right : response >0
# if vis_left < vis_right : response <0
# if vis_left = vis_right : response =0
# trials below red line are incorrect
is_correct = np.sign(response) == np.sign(vis_left - vis_right)
# sort by correct, and then by response: correct trials get keys in [-1.1, -0.9]
# and incorrect ones keys in [-0.1, 0.1], so every correct trial comes first.
isort = np.argsort(-is_correct.astype('float32') + response/10)
ncorrect = np.sum(is_correct)  # correct trials occupy the first ncorrect rows
for j in range(len(pc_10ms)):
    ax = plt.subplot(1, len(pc_10ms) + 1, j + 1)
    pc1 = zscore(pc_10ms[j])
    plt.imshow(pc1[isort, :], aspect='auto', vmax=2, vmin=-2, cmap='gray')
    ax.set(xlabel='binned time')
    if j == 0:
        ax.set(ylabel='trials sorted by correctness, then response')
    plt.title('PC %d' % j)
    # imshow centers row k at y = k, so the correct/incorrect boundary sits at
    # ncorrect - 0.5; axhline spans the axes without changing the x-limits.
    plt.axhline(ncorrect - 0.5, color='r')
plt.show()

# plot the behavioral data (pupil area is noisy because it's very small)
ax = plt.subplot(1, 5, 1)
plt.plot(dat['pupil'][0, :].mean(0))
ax.set(ylabel='pupil area', xlabel='binned time', title='Pupil dynamics')

yl = [-10, 10]
wheel_panels = [
    (response > 0, 'Left choices'),
    (response < 0, 'Right choices'),
    (response == 0, 'No go choices'),
]
for panel, (trials, title) in enumerate(wheel_panels, start=2):
    ax = plt.subplot(1, 5, panel)
    plt.plot(dat['wheel'][0, trials].mean(0))
    ax.set(ylim=yl, ylabel='wheel turns', xlabel='binned time', title=title)
plt.show()

# plots by brain region and visual conditions for PASSIVE trials
# (the "neither side" condition is not plotted)
vis_left_p = dat['contrast_left_passive']
vis_right_p = dat['contrast_right_passive']
for j in range(nareas):
    ax = plt.subplot(1, nareas, j + 1)
    plt.title(regions[j])
    in_area = barea == j
    # Skip regions with no neurons in this recording (their mean would be NaN).
    if not np.any(in_area):
        continue
    area_spks = dat['spks_passive'][in_area]
    plt.plot(1/dt * area_spks[:, np.logical_and(vis_left_p == 0, vis_right_p > 0)].mean(axis=(0, 1)))
    plt.plot(1/dt * area_spks[:, np.logical_and(vis_left_p > 0, vis_right_p == 0)].mean(axis=(0, 1)))
    plt.plot(1/dt * area_spks[:, np.logical_and(vis_left_p > 0, vis_right_p > 0)].mean(axis=(0, 1)))
    plt.text(.25, .92, 'n=%d' % np.sum(in_area), transform=ax.transAxes)
    if j == 0:
        plt.legend(['right only', 'left only', 'both'], fontsize=12)
    ax.set(xlabel='binned time', ylabel='mean firing rate (Hz)')
plt.show()

# for more variables check out the additional notebook (load_steinmetz_extra) which includes LFP, waveform shapes and exact spike times (non-binned)