`
%% A.1. Data ingestion
% Start from a clean session: close all figures, wipe the workspace and
% clear the command window.
close all;
clear; clc;
% a. Set main path (tutorial root; every relative path below resolves
% against this folder)
cd('/data_code/cuttingEEG2021-TRFtutorial')
% b. Add other directories to path. Folders are added one call at a time
% in this exact order, so later entries end up higher on the search path.
search_dirs = {'code', ...
    'code/tutorials', ...
    'code/libs/cnsp_utils', ...
    'code/libs/cnsp_utils/cnd', ...
    'code/libs/eeglab', ...
    'code/libs/mTRF-Toolbox_v2/mtrf', ...
    'data/LalorNatSpeech/stim'};
for k = 1:numel(search_dirs)
    addpath(search_dirs{k})
end
clear search_dirs k
% c. Load data: the stimulus struct "stim" and the EEG struct "eeg" for
% subject 10 from the dataCND folder
disp('Loading data...')
load('./data/LalorNatSpeech/dataCND/dataStim.mat','stim');
load('./data/LalorNatSpeech/dataCND/dataSub10.mat','eeg');
%% A.2. Data preprocessing
% Band-pass (1-8 Hz), downsample, crop, clean, re-reference and normalize
% the EEG held in the "eeg" struct loaded in A.1.
% a. Set up highpass filter (1 Hz cutoff, order 3)
highpass_cutoff = 1;
highpass_order = 3;
hd_hpf = getHPFilt(eeg.fs,highpass_cutoff,highpass_order);
% b. Set up lowpass filter (8 Hz cutoff, order 3)
lowpass_cutoff = 8;
lowpass_order = 3;
hd_lpf = getLPFilt(eeg.fs,lowpass_cutoff,lowpass_order);
% c. Filter EEG recording channels (one cell per trial)
disp('Filtering recording channels...')
eeg.data = cellfun(@(x) filtfilthd(hd_hpf,x),eeg.data,'UniformOutput',false);
eeg.data = cellfun(@(x) filtfilthd(hd_lpf,x),eeg.data,'UniformOutput',false);
% d. Filter EEG external channels with the same filters, so they stay
% comparable to the recording channels (presumably these are the mastoid
% channels used by cndReref in step h -- confirm against the CND docs)
disp('Filtering external channels...')
eeg.extChan{1,1}.data = cellfun(@(x) filtfilthd(hd_hpf,x),eeg.extChan{1,1}.data,'UniformOutput',false);
eeg.extChan{1,1}.data = cellfun(@(x) filtfilthd(hd_lpf,x),eeg.extChan{1,1}.data,'UniformOutput',false);
% e. Downsample EEG data to 64 Hz
fs_new = 64;
disp('Downsampling data...')
eeg = cndDownsample(eeg,fs_new);
% f. Crop EEG and stimulus to a common length per trial.
% stim.data is indexed (feature,trial); the envelope used later is
% stim.data(1,:), so the trial's stimulus length comes from stim.data{1,i}.
% A plain linear index stim.data{i} would walk down the feature rows
% instead of across trials. Rows of every matrix are time samples.
for i = 1:numel(eeg.data)
    n_samples = min(size(stim.data{1,i},1),size(eeg.data{i},1));
    eeg.data{i} = eeg.data{i}(1:n_samples,:);
    eeg.extChan{1,1}.data{i} = eeg.extChan{1,1}.data{i}(1:n_samples,:);
    % Also trim the stimulus when it is longer than the EEG, so stimulus
    % and response keep matching sample counts
    for f = 1:size(stim.data,1)
        stim.data{f,i} = stim.data{f,i}(1:min(n_samples,end),:);
    end
end
% g. Interpolate bad channels (only when channel locations are available)
disp('Interpolating bad channels...')
if isfield(eeg,'chanlocs')
    for i = 1:numel(eeg.data)
        eeg.data{i} = removeBadChannels(eeg.data{i},eeg.chanlocs);
    end
end
% h. Re-reference EEG data to the mastoids
disp('Re-referencing EEG data...')
eeg = cndReref(eeg,'Mastoids');
% i. Normalize EEG data by one global standard deviation computed over all
% trials, samples and channels, which preserves relative amplitudes.
% data(:) gives a column cell, so cell2mat always stacks trials vertically,
% whatever the orientation of eeg.data.
disp('Normalizing data...')
eeg_data_mat = cell2mat(eeg.data(:));
eeg_std = std(eeg_data_mat(:));
clear eeg_data_mat % drop the temporary copy of the full dataset
eeg.data = cellfun(@(x) x/eeg_std,eeg.data,'UniformOutput',false);
%% B.1. Speech envelope visualization
% Compare the raw audio waveform with its envelope, computed with and
% without compression.
% a. Load audio (audio1.wav is expected on the MATLAB path; A.1 adds
% data/LalorNatSpeech/stim)
[audio,fs] = audioread('audio1.wav');
% b. Take left stereo channel
audio = audio(:,1);
% c. Compute speech envelope at 64 Hz (last argument 1 = no compression)
fs_env = 64;     % envelope sampling rate (Hz), matches the downsampled EEG
env_comp = 0.3;  % compression exponent for the second envelope
envelope = mTRFenvelope(audio,fs,fs_env,1,1);
% d. Compute speech envelope with compression
envelope_comp = mTRFenvelope(audio,fs,fs_env,1,env_comp);
% e. Plot waveform and both envelopes on a common time axis in seconds
figure(1)
hold on
plot((1:length(audio))/fs,audio)
plot((1:length(envelope))/fs_env,envelope,'LineWidth',2)
plot((1:length(envelope_comp))/fs_env,envelope_comp,'LineWidth',2)
% Legend entries are built from the actual rates instead of being
% hard-coded (e.g. 'Audio (44.1 kHz)' when fs = 44100)
legend(sprintf('Audio (%g kHz)',fs/1e3),sprintf('Env (%d Hz)',fs_env),...
    sprintf('Env%g (%d Hz)',env_comp,fs_env))
xlabel('Time (s)')
ylabel('Amplitude (a.u.)')
xlim([0.75,4])
%% B.2. Cross-validation
% a. Apply power-law compression (exponent 0.3) to the envelope, which is
% row 1 of stim.data (one cell per trial)
envelopes = cellfun(@(x) x.^0.3,stim.data(1,:),'UniformOutput',false);
% b. Define training and test sets. The last trial is held out for testing.
% Testing on the same trials used for training would only measure the
% in-sample fit, not generalization.
n_trials = numel(envelopes);
if n_trials < 2
    error('At least two trials are needed to hold out a test trial.')
end
train_trials = 1:n_trials-1;
stim_train = envelopes(train_trials);
eeg_train = eeg.data(train_trials);
stim_test = envelopes(n_trials);
eeg_test = eeg.data(n_trials);
% c. Model hyperparameters
chan = 85; % Fz
Dir = 1;   % mapping direction for mTRF (the trained model is used as an encoding model below)
tmin = -100; % minimum time lag (ms)
tmax = 350;  % maximum time lag (ms)
lambda_exp = -4:2:10;           % regularization exponents
lambda_vals = 10.^lambda_exp;   % lambda = 10^exponent
nlambda = numel(lambda_vals);
% d. Run fast cross-validation
disp('Running cross-validation...')
cv = mTRFcrossval(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,lambda_vals,...
    'zeropad',0,'fast',1);
% e. Plot CV accuracy at the channel of interest (mean +/- SEM over folds)
figure(2)
subplot(2,2,1)
errorbar(1:nlambda,mean(cv.r(:,:,chan)),std(cv.r(:,:,chan))/sqrt(numel(stim_train)),'linewidth',2)
set(gca,'xtick',1:nlambda,'xticklabel',lambda_exp), xlim([0,nlambda+1])
title('CV Accuracy')
xlabel('Regularization (1\times10^\lambda)')
ylabel('Correlation')
axis square, grid on
% f. Plot CV error at the channel of interest (mean +/- SEM over folds)
subplot(2,2,2)
errorbar(1:nlambda,mean(cv.err(:,:,chan)),std(cv.err(:,:,chan))/sqrt(numel(stim_train)),'linewidth',2)
set(gca,'xtick',1:nlambda,'xticklabel',lambda_exp), xlim([0,nlambda+1])
title('CV Error')
xlabel('Regularization (1\times10^\lambda)')
ylabel('MSE')
axis square, grid on
%% B.3. Model training
% a. Get optimal hyperparameters: the lambda with the highest mean CV
% correlation at the channel of interest
[rmax,idx] = max(mean(cv.r(:,:,chan)));
lambda = lambda_vals(idx);
% b. Train model on the training set with the selected lambda
disp('Training model...')
Emodel = mTRFtrain(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,lambda,...
    'zeropad',0);
% c. Plot TRF weights. Emodel.w is indexed (feature,lag,channel) here:
% lags run along dim 2 (matched to Emodel.t), channels along dim 3.
topo_lags = [13,18]; % lag indices shown as scalp topographies
% Shared symmetric colour limit: largest absolute weight over lags
% 13..18, all features and all channels (always a scalar)
w_win = Emodel.w(:,topo_lags(1):topo_lags(2),:);
lim = max(abs(w_win(:)));
figure(3)
subplot(2,2,1)
plot(Emodel.t,squeeze(Emodel.w))
xlim([-50,300])
title('Temporal Response Function (TRF)')
xlabel('Time lag (ms)')
ylabel('Amplitude (a.u.)')
subplot(2,2,2)
% Standard deviation across channels at each lag
plot(Emodel.t,std(Emodel.w,[],3))
xlim([-50,300])
title('Global Field Power (GFP)')
xlabel('Time lag (ms)')
subplot(2,2,3)
topoplot(squeeze(Emodel.w(:,topo_lags(1),:)),eeg.chanlocs,'maplimits',[-lim,lim],'whitebk','on')
title([num2str(Emodel.t(topo_lags(1))),' ms'])
subplot(2,2,4)
topoplot(squeeze(Emodel.w(:,topo_lags(2),:)),eeg.chanlocs,'maplimits',[-lim,lim],'whitebk','on')
title([num2str(Emodel.t(topo_lags(2))),' ms'])
%% B.4. Model testing
% a. Test model on the test set (stim_test / eeg_test)
disp('Testing model...')
[pred,test] = mTRFpredict(stim_test,eeg_test,Emodel,'zeropad',0);
% b. Plot recorded vs. predicted EEG for the first test trial at the
% channel of interest (rows are time samples), added to the CV figure
figure(2)
subplot(2,2,3)
plot((1:size(eeg_test{1},1))/eeg.fs,eeg_test{1}(:,chan),'linewidth',2), hold on
plot((1:size(pred{1},1))/eeg.fs,pred{1}(:,chan),'linewidth',2), hold off
xlim([0,10])
title('Prediction')
xlabel('Time (s)')
ylabel('Amplitude (a.u.)')
axis square, grid on
legend('Orig','Pred')
% c. Plot test correlation next to the best CV correlation.
% test.r is expected to hold one row per test trial and one column per
% channel (mTRF-Toolbox v2 -- confirm). The chan column is therefore
% averaged over trials. Linear indexing test.r(chan) would pick an
% unrelated element whenever there is more than one test trial.
subplot(2,2,4)
bar(1,rmax), hold on
bar(2,mean(test.r(:,chan))), hold off
set(gca,'xtick',1:2,'xticklabel',{'Val.','Test'})
title('Model Performance')
xlabel('Dataset')
ylabel('Correlation')
axis square, grid on
`