Source code for spidet.preprocess.preprocessing
from typing import List
import numpy as np
from loguru import logger
from spidet.domain.Trace import Trace
from spidet.preprocess.filtering import filter_signal, notch_filter_signal
from spidet.preprocess.resampling import resample_data
from spidet.preprocess.rescaling import rescale_data
[docs]
def apply_preprocessing_steps(
    traces: List[Trace],
    notch_freq: int,
    resampling_freq: int,
    bandpass_cutoff_low: int,
    bandpass_cutoff_high: int,
) -> np.ndarray[np.dtype[float]]:
    """
    Run the iEEG preprocessing pipeline on a set of channel recordings.

    The stages are executed strictly in this order, each one consuming the
    output of the previous one:

    1. Bandpass filter (butterworth forward-backward filter of order 2)
    2. Notch filter
    3. Rescaling
    4. Resampling

    Parameters
    ----------
    traces : List[Trace]
        Raw iEEG recordings, one Trace object per channel. The sampling
        frequency used for every stage is read from the first trace only.
    notch_freq : int
        Base frequency of the notch filter. The data is notch-filtered at
        this frequency and at its harmonics,
        e.g. notch_freq = 50 Hz -> harmonics = [50, 100, 150, etc.]
    resampling_freq : int
        Target sampling frequency applied after filtering and rescaling.
    bandpass_cutoff_low : int
        Lower cut-off frequency of the bandpass filter's passband.
    bandpass_cutoff_high : int
        Upper cut-off frequency of the bandpass filter's passband. It is
        also handed to the notch filter as its low-pass frequency.

    Returns
    -------
    numpy.ndarray[np.dtype[float]]
        2-dimensional array holding the preprocessed data; each row
        corresponds to one of the input traces.
    """
    # Gather the channel labels first so they can be logged and later
    # forwarded to the resampling stage.
    labels = [channel.label for channel in traces]
    logger.debug(f"Channels processed by worker: {labels}")

    # The first trace's sampling frequency is used for all channels.
    # NOTE(review): presumably all traces share the same rate - confirm with callers.
    sampling_freq = traces[0].sfreq

    # Collect the per-channel samples into a single array; this unfiltered
    # copy is needed again as the reference for rescaling.
    raw_data = np.array([channel.data for channel in traces])

    # Stage 1: bandpass filter
    logger.debug(
        f"Bandpass filter data between {bandpass_cutoff_low} and {bandpass_cutoff_high} Hz"
    )
    after_bandpass = filter_signal(
        data=raw_data,
        sfreq=sampling_freq,
        cutoff_freq_low=bandpass_cutoff_low,
        cutoff_freq_high=bandpass_cutoff_high,
    )

    # Stage 2: notch filter
    logger.debug(f"Apply notch filter at {notch_freq} Hz")
    after_notch = notch_filter_signal(
        eeg_data=after_bandpass,
        sfreq=sampling_freq,
        notch_frequency=notch_freq,
        low_pass_freq=bandpass_cutoff_high,
    )

    # Stage 3: rescale the filtered channels against the raw recording
    logger.debug("Rescale filtered data")
    rescaled = rescale_data(
        data_to_be_scaled=after_notch,
        original_data=raw_data,
        sfreq=sampling_freq,
    )

    # Stage 4: resample to the requested frequency and hand back the result
    logger.debug(f"Resample data at sampling frequency {resampling_freq} Hz")
    return resample_data(
        data=rescaled,
        channel_names=labels,
        sfreq=sampling_freq,
        resampling_freq=resampling_freq,
    )