In this experiment, participants were presented with flashing gratings of two different frequencies.
Stimuli were presented for 3 s with an intertrial interval of 600 ms and random jitter of ±100ms. The task was to passively fixate the center of the screen while the stimuli were displayed. Six blocks of 2 min were recorded for a single partipant.
The two gratings were presented in a "pattern reversal" way, i.e., the grating was alternated with a reversed version of itself. For more info, see Section 2 of:
Danhua Zhu, Jordi Bieger, Gary Garcia Molina, and Ronald M. Aarts, "A Survey of Stimulation Methods Used in SSVEP-Based BCIs," Computational Intelligence and Neuroscience, vol. 2010, 12 pages, 2010.
This led to stimulation frequencies of 30 and 20 Hz, respectively.
Moreover, an external electrode was added to the Muse headband to collect data at POz.
import os
import sys
from collections import OrderedDict
import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
%matplotlib inline
sys.path.append('../muse_lsl/muse')
import utils
subject = 1
session = 1
raw = utils.load_data('visual/SSVEP', sfreq=256.,
subject_nb=subject, session_nb=session,
ch_ind=[0, 1, 2, 3, 4],
replace_ch_names={'Right AUX': 'POz'})
Creating RawArray with float64 data, n_channels=6, n_times=30732 Range : 0 ... 30731 = 0.000 ... 120.043 secs Ready. Creating RawArray with float64 data, n_channels=6, n_times=30732 Range : 0 ... 30731 = 0.000 ... 120.043 secs Ready. Creating RawArray with float64 data, n_channels=6, n_times=30720 Range : 0 ... 30719 = 0.000 ... 119.996 secs Ready. Creating RawArray with float64 data, n_channels=6, n_times=30732 Range : 0 ... 30731 = 0.000 ... 120.043 secs Ready. Creating RawArray with float64 data, n_channels=6, n_times=30732 Range : 0 ... 30731 = 0.000 ... 120.043 secs Ready. Creating RawArray with float64 data, n_channels=6, n_times=30720 Range : 0 ... 30719 = 0.000 ... 119.996 secs Ready.
We epoch the raw data according to the conditions, and visualize the spectrum of each condition. We expect clear peaks in the spectral domain at the stimulation frequencies of 30 and 20 Hz.
from mne import Epochs, find_events
events = find_events(raw)
event_id = {'30 Hz': 1, '20 Hz': 2}
epochs = Epochs(raw, events=events, event_id=event_id,
tmin=-0.5, tmax=4, baseline=None, preload=True,
verbose=False, picks=[0, 1, 2, 3, 4])
197 events found Events id: [1 2]
epochs['30 Hz']
<Epochs | n_events : 87 (all good), tmin : -0.5 (s), tmax : 4.0 (s), baseline : None, ~3.8 MB, data loaded>
from mne.time_frequency import psd_welch
f, axs = plt.subplots(2, 1, figsize=(10, 10))
psd1, freq1 = psd_welch(epochs['30 Hz'], n_fft=1028, n_per_seg=256 * 3)
psd2, freq2 = psd_welch(epochs['20 Hz'], n_fft=1028, n_per_seg=256 * 3)
psd1 = 10 * np.log10(psd1)
psd2 = 10 * np.log10(psd2)
psd1_mean = psd1.mean(0)
psd1_std = psd1.mean(0)
psd2_mean = psd2.mean(0)
psd2_std = psd2.mean(0)
axs[0].plot(freq1, psd1_mean[[0, 3], :].mean(0), color='b', label='30 Hz')
axs[0].plot(freq2, psd2_mean[[0, 3], :].mean(0), color='r', label='20 Hz')
axs[1].plot(freq1, psd1_mean[4, :], color='b', label='30 Hz')
axs[1].plot(freq2, psd2_mean[4, :], color='r', label='20 Hz')
axs[0].set_title('TP9 and TP10')
axs[1].set_title('POz')
axs[0].set_ylabel('Power Spectral Density (dB)')
axs[1].set_ylabel('Power Spectral Density (dB)')
axs[0].set_xlim((2, 50))
axs[1].set_xlim((2, 50))
axs[1].set_xlabel('Frequency (Hz)')
axs[0].legend()
axs[1].legend()
plt.show()
Effective window size : 4.016 (s) Effective window size : 4.016 (s)
We can see the expected peaks in both temporo-parietal and posterior electrodes.
This can also be seen using a time-frequency plot:
from mne.time_frequency import tfr_morlet
frequencies = np.logspace(1, 1.75, 60)
tfr, itc = tfr_morlet(epochs['30 Hz'], freqs=frequencies,
n_cycles=15, return_itc=True)
tfr.plot(picks=[4], baseline=(-0.5, -0.1), mode='logratio',
title='POz - 30 Hz stim');
tfr, itc = tfr_morlet(epochs['20 Hz'], freqs=frequencies,
n_cycles=15, return_itc=True)
tfr.plot(picks=[4], baseline=(-0.5, -0.1), mode='logratio',
title='POz - 20 Hz stim');
Applying baseline correction (mode: logratio)
Applying baseline correction (mode: logratio)
We filter data between 2 and 50 Hz.
raw_filt = raw.copy().filter(5, 50, method='iir')
Setting up band-pass filter from 5 - 50 Hz
Here we epoch data for -100 ms to 3 s after the stimulus. No baseline correction is needed (signal is bandpass filtered) and we reject every epochs were the signal exceeds 100 uV.
events = find_events(raw)
event_id = {'30 Hz': 1, '20 Hz': 2}
epochs = Epochs(raw_filt, events=events, event_id=event_id,
tmin=-0.3, tmax=3, baseline=None,
reject={'eeg': 100e-6}, preload=True,
verbose=False, picks=[0,4,2,3])
197 events found Events id: [1 2]
conditions = OrderedDict()
conditions['30 Hz'] = [1]
conditions['20 Hz'] = [2]
fig, ax = utils.plot_conditions(epochs, conditions=conditions,
ci=97.5, n_boot=1, title='',
diff_waveform=None, ylim=(-8, 8))
The stimulation pattern can also be seen in the temporal domain.
We use a filter bank approach on the original 4 Muse electrodes (to see how the headband alone without external electrodes could be used to classify SSVEP):
# Bandpass filter the raw data
muse_raw = raw.drop_channels(['POz'])
raw_filt_30Hz = muse_raw.copy().filter(25, 35, method='iir')
raw_filt_20Hz = muse_raw.copy().filter(15, 25, method='iir')
raw_filt_30Hz.rename_channels(lambda x: x + '_30Hz')
raw_filt_20Hz.rename_channels(lambda x: x + '_20Hz')
# Concatenate with the bandpass filtered channels
raw_all = raw_filt_30Hz.add_channels([raw_filt_20Hz],
force_update_info=True)
# Extract epochs
events = find_events(raw_all)
event_id = {'30 Hz': 1, '20 Hz': 2}
epochs_all = Epochs(raw_all, events=events, event_id=event_id, tmin=1,
tmax=3, baseline=None, reject={'eeg': 100e-6},
preload=True, verbose=False, add_eeg_ref=False)
Setting up band-pass filter from 25 - 35 Hz Setting up band-pass filter from 15 - 25 Hz 197 events found Events id: [1 2]
epochs_all.pick_types(eeg=True)
X = epochs_all.get_data() * 1e6
times = epochs.times
y = epochs_all.events[:, -1]
from sklearn.pipeline import make_pipeline
from mne.decoding import Vectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.model_selection import cross_val_score, StratifiedShuffleSplit
from pyriemann.estimation import Covariances, ERPCovariances, XdawnCovariances
from pyriemann.spatialfilters import CSP
from pyriemann.tangentspace import TangentSpace
from pyriemann.classification import MDM
from collections import OrderedDict
clfs = OrderedDict()
clfs['CSP + RegLDA'] = make_pipeline(Covariances(), CSP(4), LDA(shrinkage='auto', solver='eigen'))
clfs['Cov + TS'] = make_pipeline(Covariances(), TangentSpace(), LogisticRegression())
clfs['Cov + MDM'] = make_pipeline(Covariances(), MDM())
clfs['CSP + Cov + TS'] = make_pipeline(Covariances(), CSP(4, log=False), TangentSpace(), LogisticRegression())
# define cross validation
cv = StratifiedShuffleSplit(n_splits=20, test_size=0.25,
random_state=42)
# run cross validation for each pipeline
auc = []
methods = []
for m in clfs:
print(m)
try:
res = cross_val_score(clfs[m], X, y==2, scoring='roc_auc',
cv=cv, n_jobs=-1)
auc.extend(res)
methods.extend([m]*len(res))
except:
pass
results = pd.DataFrame(data=auc, columns=['AUC'])
results['Method'] = methods
CSP + RegLDA Cov + TS Cov + MDM CSP + Cov + TS
fig = plt.figure(figsize=[8,4])
sns.barplot(data=results, x='AUC', y='Method')
plt.xlim(0.4, 1)
sns.despine()