Update EEGExtract for modern day libraries #5

Open · wants to merge 1 commit into base: main
78 changes: 45 additions & 33 deletions EEGExtract.py
@@ -3,9 +3,6 @@
 import pandas as pd
 import pywt
 from scipy import stats, signal, integrate
-from dit.other import tsallis_entropy
-import dit
-import librosa
 import statsmodels.api as sm
 import itertools
 from pyinform import mutualinfo
@@ -83,22 +80,28 @@ def burst_supression_detection(x,fs,suppression_threshold = 10):
     high = high / nyq
     be, ae = signal.butter(order, [low, high], btype='band')
     '''
+    #rng = np.random.default_rng().random()
     # CALCULATE ENVELOPE
     e = abs(signal.hilbert(x,axis=1));
-    # same as smooth(e,Fs/4) in MATLAB, apply 1/2 second smoothing
-    ME = np.array([np.convolve(el,np.ones(int(fs/4))/(fs/4),'same') for el in e.tolist()])
+    # same as smooth(e,Fs/8) in MATLAB, apply 1/8 second smoothing
+    ME = np.array([np.convolve(el,np.ones(int(fs/8))/(fs/8),'same') for el in e.tolist()])
     e = ME
+    #print("Thresh:", suppression_threshold, "\n\tSTD:", np.std(ME), "\n\tMean:", np.mean(ME))
     # DETECT SUPRESSIONS
     # apply threshold -- 10uv
     z = (ME<suppression_threshold)
+    #print(f"{rng} xShape: " + str(x.shape) + " z:" + str(z))
     # remove too-short suppression segments
-    z = fcnRemoveShortEvents(z,fs/2)
+    z = fcnRemoveShortEvents(z,fs/4)
     # remove too-short burst segments
-    b = fcnRemoveShortEvents(1-z,fs/2)
+    b = fcnRemoveShortEvents(1-z,fs/4)
     z = 1-b
+    #print(f"{rng}\tz-post:" + str(z))
     went_high = [np.where(np.array(chD[:-1]) < np.array(chD[1:]))[0].tolist() for chD in z.tolist()]
     went_low = [np.where(np.array(chD[:-1]) > np.array(chD[1:]))[0].tolist() for chD in z.tolist()]

+    #print(f"{rng}\twent-high:" + str(went_high))
+    #print(f"{rng}\twent-low:" + str(went_low))
     bursts = get_intervals(went_high,went_low)
     supressions = get_intervals(went_low,went_high)
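
For context, this hunk only changes the smoothing window (fs/4 to fs/8) and the minimum event length (fs/2 to fs/4). A quick standalone sketch of the envelope-plus-boxcar step on synthetic data (the signal, sampling rate, and threshold are invented, not taken from the PR):

```python
import numpy as np
from scipy import signal

# Illustrative sketch, not from the PR: Hilbert envelope plus the new 1/8 s boxcar.
fs = 100
t = np.arange(0, 10, 1 / fs)
# alternate 2 s of 50 uV oscillation ("burst") with 2 s of near-flat signal ("suppression")
x = np.where((t % 4) < 2, 50 * np.sin(2 * np.pi * 10 * t), 1.0)
env = np.abs(signal.hilbert(x))
win = int(fs / 8)                                  # 1/8 second window, as in the updated code
smoothed = np.convolve(env, np.ones(win) / win, 'same')
print("fraction below the 10 uV threshold:", (smoothed < 10).mean())   # roughly half
```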

@@ -196,23 +199,32 @@ def shannonEntropy(eegData, bin_min, bin_max, binWidth):

 ##########
 # Extract the tsalis Entropy
-def tsalisEntropy(eegData, bin_min, bin_max, binWidth, orders = [1]):
-    H = [np.zeros((eegData.shape[0], eegData.shape[2]))]*len(orders)
-    for chan in range(H[0].shape[0]):
-        for epoch in range(H[0].shape[1]):
-            counts, bins = np.histogram(eegData[chan,:,epoch], bins=np.arange(-200+1, 200, 2))
-            dist = dit.Distribution([str(bc).zfill(5) for bc in bins[:-1]],counts/sum(counts))
-            for ii,order in enumerate(orders):
-                H[ii][chan,epoch] = tsallis_entropy(dist,order)
+def tsalisEntropy(eegData, bin_min, bin_max, binWidth, orders=[1]):
+    num_orders = len(orders)
+    H = [np.zeros((eegData.shape[0], eegData.shape[2])) for _ in range(num_orders)]
+
+    for chan in range(eegData.shape[0]):
+        for epoch in range(eegData.shape[2]):
+            counts, bins = np.histogram(eegData[chan, :, epoch], bins=np.arange(bin_min, bin_max, binWidth))
+            probs = counts / np.sum(counts)
+            for ii, order in enumerate(orders):
+                q = order
+                if q == 1:
+                    H[ii][chan, epoch] = -np.sum(probs * np.log(probs + np.finfo(float).eps))
+                else:
+                    H[ii][chan, epoch] = (1 / (1 - q)) * (np.sum(probs ** q) - 1)
     return H

+
 ##########
 # Cepstrum Coefficients (n=2)
-def mfcc(eegData,fs,order=2):
+def mfcc(eegData,fs,order=2,enable=False):
     H = np.zeros((eegData.shape[0], eegData.shape[2],order))
+    if not enable: return H # Disable until future notice
     for chan in range(H.shape[0]):
         for epoch in range(H.shape[1]):
-            H[chan, epoch, : ] = librosa.feature.mfcc(np.asfortranarray(eegData[chan,:,epoch]), sr=fs)[0:order].T
+            break
+            #H[chan, epoch, : ] = librosa.feature.mfcc(np.asfortranarray(eegData[chan,:,epoch]), sr=fs)[0:order].T
     return H

 ##########
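
The rewritten tsalisEntropy drops the dit dependency in favour of a direct NumPy formula. A small sanity check (illustrative only, not part of the PR) that the q != 1 branch converges to the Shannon branch as q approaches 1:

```python
import numpy as np

# Sketch only: check that (1/(1-q)) * (sum(p**q) - 1) approaches -sum(p*log p) as q -> 1.
# The bin range and the synthetic epoch are illustrative, not the library defaults.
rng = np.random.default_rng(0)
x = rng.normal(0, 40, 2500)                        # one synthetic channel/epoch, in microvolts
counts, _ = np.histogram(x, bins=np.arange(-200, 200, 2))
p = counts / counts.sum()
shannon = -np.sum(p * np.log(p + np.finfo(float).eps))

def tsallis(q):
    return (1 / (1 - q)) * (np.sum(p ** q) - 1)

print(shannon, tsallis(1.0001), tsallis(2.0))      # the first two values nearly coincide
```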
@@ -335,13 +347,13 @@ def falseNearestNeighbor(eegData, fast=True):
     return out

 ##########
-# ARMA coefficients
-def arma(eegData,order=2):
+# ARIMA coefficients
+def arima(eegData,order=2):
     H = np.zeros((eegData.shape[0], eegData.shape[2],order))
     for chan in range(H.shape[0]):
         for epoch in range(H.shape[1]):
-            arma_mod = sm.tsa.ARMA(eegData[chan,:,epoch], order=(order,order))
-            arma_res = arma_mod.fit(trend='nc', disp=-1)
+            arma_mod = sm.tsa.ARIMA(eegData[chan,:,epoch], order=(order,0,1), trend='ct')
+            arma_res = arma_mod.fit()
             H[chan, epoch, : ] = arma_res.arparams
     return H
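
The arima() rewrite targets the modern statsmodels API, which replaces the removed sm.tsa.ARMA class. A minimal, self-contained example of that API on a synthetic AR(2) series (not taken from the repository):

```python
import numpy as np
from statsmodels.tsa.arima.model import ARIMA

# Illustrative only: fit ARIMA(p, 0, 1) with trend='ct' on a synthetic AR(2) series and
# read back the AR coefficients, mirroring what arima() stores per channel and epoch.
rng = np.random.default_rng(0)
y = np.zeros(500)
for i in range(2, 500):
    y[i] = 0.6 * y[i - 1] - 0.3 * y[i - 2] + rng.normal()
res = ARIMA(y, order=(2, 0, 1), trend='ct').fit()
print(res.arparams)                                # roughly [0.6, -0.3]
```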

Expand Down Expand Up @@ -373,8 +385,8 @@ def spikeNum(eegData,minNumSamples=7,stdAway = 3):
     for chan in range(H.shape[0]):
         for epoch in range(H.shape[1]):
             mean = np.mean(eegData[chan, :, epoch])
-            std = np.std(eegData[chan,:,epoch],axis=1)
-            H[chan,epoch] = len(signal.find_peaks(abs(eegData[chan,:,epoch]-mean), 3*std,epoch,width=7)[0])
+            std = np.std(eegData[chan,:,epoch])
+            H[chan,epoch] = len(signal.find_peaks(abs(eegData[chan,:,epoch]-mean), stdAway*std, width=minNumSamples)[0])
     return H

 ##########
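
The std fix above matters because np.std(eegData[chan,:,epoch]) operates on a 1-D slice, so the old axis=1 call raised an error. An illustrative spike count on synthetic data using the same find_peaks arguments (signal and values invented):

```python
import numpy as np
from scipy import signal

# Illustration only: count spikes in one synthetic channel/epoch the way spikeNum does,
# i.e. peaks of |x - mean| above stdAway standard deviations that are >= minNumSamples wide.
rng = np.random.default_rng(0)
x = rng.normal(0, 5, 2500)
x[1000:1012] += 60                                 # inject one wide artificial spike
stdAway, minNumSamples = 3, 7
dev = np.abs(x - x.mean())
peaks, _ = signal.find_peaks(dev, height=stdAway * np.std(x), width=minNumSamples)
print(len(peaks))                                  # expect 1 with this seed
```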
@@ -426,14 +438,14 @@ def eegVoltage(eegData,voltage=20):
 # Diffuse Slowing
 # look for diffuse slowing (bandpower max from frequency domain integral)
 # repeated integration of a huge tensor is really expensive
-def diffuseSlowing(eegData, Fs=100, fast=True):
+def diffuseSlowing(eegData, Fs=100, fast=True, debug=False):
     maxBP = np.zeros((eegData.shape[0], eegData.shape[2]))
     idx = np.zeros((eegData.shape[0], eegData.shape[2]))
     if fast:
         return idx
-    for j in range(1, Fs//2):
-        print("BP", j)
-        cbp = bandpower(eegData, Fs, [j-1, j])
+    for j in range(1, Fs//2 - 1):
+        if debug: print("BP", j)
+        cbp = bandPower(eegData, fs=Fs, lowcut=j-0.5, highcut=j+0.5)
         biggerCIdx = cbp > maxBP
         idx[biggerCIdx] = j
         maxBP[biggerCIdx] = cbp[biggerCIdx]
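
For intuition, here is a stand-in for the per-hertz scan that diffuseSlowing performs, using scipy's Welch PSD instead of the repo's bandPower helper (whose exact signature is assumed), on an invented signal:

```python
import numpy as np
from scipy import signal

# Stand-in sketch: find the 1 Hz band with the most power, analogous to the idx array
# that diffuseSlowing fills. Uses Welch PSD rather than the repo's bandPower helper.
Fs = 100
t = np.arange(0, 10, 1 / Fs)
x = np.sin(2 * np.pi * 3 * t) + 0.2 * np.sin(2 * np.pi * 20 * t)   # dominant power near 3 Hz
f, psd = signal.welch(x, fs=Fs, nperseg=2 * Fs)
band_power = [psd[(f >= j - 0.5) & (f < j + 0.5)].sum() for j in range(1, Fs // 2 - 1)]
print(1 + int(np.argmax(band_power)))                              # should print 3
```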
@@ -482,11 +494,11 @@ def shortSpikeNum(eegData,minNumSamples=7,stdAway = 3):

 ##########
 # Number of Bursts
-def numBursts(eegData,fs):
+def numBursts(eegData,fs,suppression_threshold=10):
     bursts = []
     supressions = []
     for epoch in range(eegData.shape[2]):
-        epochBurst,epochSupressions = burst_supression_detection(eegData[:,:,epoch],fs,suppression_threshold=10)#,low=30,high=49)
+        epochBurst,epochSupressions = burst_supression_detection(eegData[:,:,epoch],fs,suppression_threshold=suppression_threshold)#,low=30,high=49)
         bursts.append(epochBurst)
         supressions.append(epochSupressions)
     # Number of Bursts
@@ -498,11 +510,11 @@ def numBursts(eegData,fs):

 ##########
 # Burst length μ and σ
-def burstLengthStats(eegData,fs):
+def burstLengthStats(eegData,fs,suppression_threshold=10):
     bursts = []
     supressions = []
     for epoch in range(eegData.shape[2]):
-        epochBurst,epochSupressions = burst_supression_detection(eegData[:,:,epoch],fs,suppression_threshold=10)#,low=30,high=49)
+        epochBurst,epochSupressions = burst_supression_detection(eegData[:,:,epoch],fs,suppression_threshold=suppression_threshold)#,low=30,high=49)
         bursts.append(epochBurst)
         supressions.append(epochSupressions)
     # Number of Bursts
@@ -518,12 +530,12 @@ def burstLengthStats(eegData,fs):

 ##########
 # Burst band powers (δ, α, θ, β, γ)
-def burstBandPowers(eegData, lowcut, highcut, fs, order=7):
+def burstBandPowers(eegData, lowcut, highcut, fs, order=7, suppression_threshold=10):
     band_burst_powers = np.zeros((eegData.shape[0], eegData.shape[2]))
     bursts = []
     supressions = []
     for epoch in range(eegData.shape[2]):
-        epochBurst,epochSupressions = burst_supression_detection(eegData[:,:,epoch],fs,suppression_threshold=10)#,low=30,high=49)
+        epochBurst,epochSupressions = burst_supression_detection(eegData[:,:,epoch],fs,suppression_threshold=suppression_threshold)#,low=30,high=49)
         bursts.append(epochBurst)
         supressions.append(epochSupressions)
     eegData_band = filt_data(eegData, lowcut, highcut, fs, order=7)
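
The last three hunks thread the same new suppression_threshold argument through to burst_supression_detection. A hypothetical call showing the newly exposed parameter (the module alias, array shape, and values are assumptions, not from the PR):

```python
import numpy as np
import EEGExtract as eeg   # assumed module name

# Hypothetical call: eegData is [channels, samples, epochs] in microvolts; the shape,
# sampling rate, and 8 uV threshold are invented for illustration.
eegData = np.random.default_rng(0).normal(0, 15, (4, 1000, 3))
n_bursts = eeg.numBursts(eegData, fs=100, suppression_threshold=8)
burst_len = eeg.burstLengthStats(eegData, fs=100, suppression_threshold=8)
```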