import multiprocessing
import os
from typing import List, Tuple

import numpy as np
from loguru import logger
from scipy.signal.windows import hann

from spidet.domain.ActivationFunction import ActivationFunction
from spidet.domain.Trace import Trace
from spidet.load.data_loading import DataLoader
from spidet.preprocess.preprocessing import apply_preprocessing_steps
from spidet.preprocess.resampling import resample_data
from spidet.spike_detection.thresholding import ThresholdGenerator
from spidet.utils.times_utils import compute_rescaled_timeline
class LineLength:
    """
    This class provides all operations regarding the line-length transformation.

    Parameters
    ----------
    file_path: str
        Path to the file containing the iEEG data.
    dataset_paths: List[str], mandatory when the file is in .h5 format
        A list of paths to the traces to be included within an h5 file. This is only necessary in the case
        of h5 files. When None, the data loader is called once with None as its channel selection.
    exclude: List[str], optional
        A list of channel names that need to be excluded. This only applies in the case of .edf and .fif files.
    bipolar_reference: bool, optional, default = False
        If True, the bipolar references of the included channels will be computed. If channels already are
        in bipolar form this needs to be False.
    leads: List[str]
        A list of the leads included. Only necessary if bipolar_reference is True, otherwise can be None.
    bad_times: numpy.ndarray[numpy.dtype[float]]
        An optional N x 2 numpy array containing periods that must be excluded before applying
        the line-length transformation. Each of the N rows in the array represents a period to be excluded,
        defined by the start and end indices of the period in the original iEEG data.
        The defined periods will be set to zero with the transitions being smoothed by applying a hanning window
        to prevent spurious patterns.
    """

    def __init__(
        self,
        file_path: str,
        dataset_paths: List[str] = None,
        exclude: List[str] = None,
        bipolar_reference: bool = False,
        leads: List[str] = None,
        bad_times: np.ndarray = None,
    ):
        self.file_path = file_path
        self.dataset_paths = dataset_paths
        self.exclude = exclude
        self.bipolar_reference = bipolar_reference
        self.leads = leads
        self.bad_times = bad_times
        # Line-length window (ms) and output sampling frequency (Hz);
        # apply_parallel_line_length_pipeline overwrites both with its arguments.
        self.line_length_window: int = 40
        self.line_length_freq: int = 50

    def dampen_bad_times(
        self,
        data: np.ndarray,
        sfreq: int,
        orig_sfreq: int,
        window_length: int = 100,
    ) -> np.ndarray:
        """
        Dampens bad times within preprocessed iEEG data by setting values of bad times intervals to zero
        and applying hann windows (https://en.wikipedia.org/wiki/Hann_function) around starting and ending
        points in order to get smoothed transitions.

        Each period in ``self.bad_times`` is rescaled from ``orig_sfreq`` to ``sfreq`` and widened by
        half a transition window on both sides. The period is zeroed, and a half hann ramp of the same
        length is applied on either side. ``self.bad_times`` itself is not modified, so the method can be
        called repeatedly with the same result. Overlapping periods are combined so that no sample zeroed
        by one period is partially restored by the ramp of another. Periods lying entirely beyond the end
        of the data are ignored.

        Parameters
        ----------
        data : numpy.ndarray[np.dtype[float]]
            The preprocessed iEEG data (channels x samples).
        sfreq : int
            The sampling frequency of the preprocessed iEEG data.
        orig_sfreq : int
            The sampling frequency of the original iEEG data.
        window_length : int, optional, default = 100
            The length of the smoothed transition periods in milliseconds.

        Returns
        -------
        smoothed_data : numpy.ndarray[np.dtype[float]]
            The preprocessed iEEG data with artifacts being zeroed and having smoothed transition periods.
        """
        # Rescale a private copy: rescaling self.bad_times in place made repeated calls (or several
        # pool tasks sharing one unpickled instance) rescale and widen the periods more than once.
        bad_times = np.asarray(self.bad_times, dtype=float).reshape(-1, 2)
        bad_times = np.rint(bad_times * sfreq / orig_sfreq).astype(int)

        # Transition length in samples; 2 * integer is always even, so it splits into equal halves
        window = 2 * int(np.rint(window_length / 1000 * sfreq))
        half = window // 2

        # Inverted hann window: 1 at both ends, close to 0 in the middle
        hann_w = 1 - hann(window)
        left_hann = hann_w[:half]
        right_hann = hann_w[half + 1 :]

        # Widen every period by half a window and keep the mask slice
        # [start - half, end + half) within the bounds of the data
        starts = np.maximum(1 + half, bad_times[:, 0] - half)
        ends = np.minimum(data.shape[1] - half, bad_times[:, 1] + half)

        hann_mask = np.ones(data.shape)
        for start, end in zip(starts, ends):
            n_zeros = end - start + 1
            if n_zeros < 0:
                # The period lies entirely beyond the end of the data
                continue
            profile = np.concatenate((left_hann, np.zeros(n_zeros), right_hann))
            segment = hann_mask[:, start - half : end + half]
            # Take the minimum with the mask built so far so an overlapping period's
            # ramp cannot un-zero samples dampened by an earlier period
            np.minimum(segment, profile, out=segment)
        return hann_mask * data

    def compute_line_length(self, eeg_data: np.ndarray, sfreq: int) -> np.ndarray:
        """
        Performs the line-length transformation on the input EEG data.

        Parameters
        ----------
        eeg_data : numpy.ndarray
            Input EEG data (channels x samples).
        sfreq : int
            Sampling frequency of the input EEG data.

        Returns
        -------
        numpy.ndarray[Any, np.dtype[float]]
            Line length representation of the input EEG data with shape (channels, samples - 1).

        Notes
        -----
        For every time point, the line length is the summed absolute difference of consecutive
        samples over a window of ``w_eff`` samples, where ``w_eff`` is ``self.line_length_window``
        milliseconds rounded to an even number of samples. NaN differences are treated as zero. The
        signal is zero-padded after its second-to-last sample (the final input sample is not used).
        The result is shifted right by ``w_eff / 2`` samples so that each value is centred on its
        window, and the first ``w_eff / 2`` values are zero. [1]_

        References
        ----------
        .. [1]
            Koolen, N., Jansen, K., Vervisch, J., Matic, V., De Vos, M., Naulaers, G., & Van Huffel, S. (2014).
            Line length as a robust method to detect high-activity events:
            Automated burst detection in premature EEG recordings.
            Clinical Neurophysiology, 125(10), 1985–1994. https://doi.org/10.1016/j.clinph.2014.02.015
        """
        eeg_data = np.asarray(eeg_data, dtype=float)
        nr_channels, duration = eeg_data.shape
        # effective window size: the window (default 40 ms) rounded to the nearest even number of samples
        w_eff = 2 * round(sfreq * self.line_length_window / 2000)
        half = w_eff // 2
        # One output value per input sample except the last one
        n_out = max(duration - 1, 0)

        if w_eff == 0:
            # A window of zero samples has no consecutive differences to sum
            line_length = np.zeros((nr_channels, n_out))
        else:
            # Zero-pad after the second-to-last sample so every window is complete
            padded = np.concatenate(
                (eeg_data[:, :n_out], np.zeros((nr_channels, w_eff))), axis=1
            )
            steps = np.abs(np.diff(padded, axis=1))
            # Equivalent of nansum: NaN differences contribute nothing
            steps[np.isnan(steps)] = 0.0
            # A window of w_eff samples spans w_eff - 1 consecutive differences;
            # the sliding view avoids materialising w_eff shifted copies of the data
            windows = np.lib.stride_tricks.sliding_window_view(
                steps, w_eff - 1, axis=1
            )
            line_length = windows[:, :n_out].sum(axis=-1)

        # center the data on the window
        centered = np.zeros((nr_channels, n_out))
        centered[:, half:] = line_length[:, : n_out - half]
        return centered

    def line_length_pipeline(
        self,
        traces: "List[Trace]",
        notch_freq: int,
        resampling_freq: int,
        bandpass_cutoff_low: int,
        bandpass_cutoff_high: int,
    ) -> np.ndarray:
        """
        Preprocesses a set of traces and applies the line-length transformation.

        The steps are: preprocess via ``apply_preprocessing_steps``; dampen ``self.bad_times`` if set;
        apply the line-length transformation; resample to ``self.line_length_freq``. Negative values
        introduced by resampling are set to zero.

        Parameters
        ----------
        traces: List[Trace]
            The traces to process; the sampling frequency of the first trace is taken as the
            original sampling frequency of the bad times.
        notch_freq: int
            Frequency of the notch filter (and its harmonics).
        resampling_freq: int
            Frequency to resample the data to after filtering.
        bandpass_cutoff_low: int
            Lower cut-off frequency of the bandpass filter.
        bandpass_cutoff_high: int
            Higher cut-off frequency of the bandpass filter.

        Returns
        -------
        numpy.ndarray
            The resampled, non-negative line-length data.
        """
        # Extract channel names
        channel_names = [trace.label for trace in traces]

        # Preprocess the data
        preprocessed = apply_preprocessing_steps(
            traces=traces,
            notch_freq=notch_freq,
            resampling_freq=resampling_freq,
            bandpass_cutoff_low=bandpass_cutoff_low,
            bandpass_cutoff_high=bandpass_cutoff_high,
        )

        # Zero out bad times if any
        if self.bad_times is not None:
            logger.debug("Dampening bad times on preprocessed EEG with hann windows")
            preprocessed = self.dampen_bad_times(
                data=preprocessed, sfreq=resampling_freq, orig_sfreq=traces[0].sfreq
            )

        # Compute line length
        logger.debug("Apply line length computations")
        line_length = self.compute_line_length(
            eeg_data=preprocessed, sfreq=resampling_freq
        )

        # Downsample to line_length_freq (default 50 Hz)
        logger.debug(f"Resample line length at {self.line_length_freq} Hz")
        line_length_resampled_data = resample_data(
            data=line_length,
            channel_names=channel_names,
            sfreq=resampling_freq,
            resampling_freq=self.line_length_freq,
        )

        # Resampling produced some negative values, replace by 0
        line_length_resampled_data[line_length_resampled_data < 0] = 0
        return line_length_resampled_data

    @staticmethod
    def _split_evenly(items: list, n_chunks: int) -> List[list]:
        """Split ``items`` into ``n_chunks`` contiguous lists whose sizes differ by at most one."""
        # Split indices rather than the items themselves so that arbitrary objects stay
        # plain Python objects inside plain lists (no numpy array conversion).
        return [
            [items[i] for i in indices]
            for indices in np.array_split(np.arange(len(items)), n_chunks)
        ]

    @staticmethod
    def _file_stem(file_path: str) -> str:
        """Return the file name of ``file_path`` without its directory and final extension."""
        return os.path.splitext(os.path.basename(file_path))[0]

    def apply_parallel_line_length_pipeline(
        self,
        notch_freq: int = 50,
        resampling_freq: int = 500,
        bandpass_cutoff_low: float = 0.1,
        bandpass_cutoff_high: int = 200,
        line_length_freq: int = 50,
        line_length_window: int = 40,
        n_processes: int = 5,
    ) -> Tuple[float, List[str], np.ndarray]:
        """
        This function launches the line length pipeline, which first carries out the necessary preprocessing steps
        and then performs the line-length transformation of the preprocessed EEG data. The individual steps include

        1. reading the data from the provided file (supported file formats are .h5, .edf, .fif)
           using the :py:mod:`~spidet.load.data_loading` module, which transforms the data
           into a list of :py:mod:`~spidet.domain.Trace` objects,
        2. performing the necessary preprocessing steps by means of the
           :py:mod:`~spidet.preprocess.preprocessing` module,
        3. and applying the line-length transformation.

        Channels are loaded sequentially in subsets of about ten (due to memory limitations). Each
        subset's traces are split into at most ``n_processes`` chunks, and the chunks are processed
        in parallel.

        Parameters
        ----------
        notch_freq: int, optional, default = 50
            The frequency of the notch filter; data will be notch-filtered at this frequency
            and at the corresponding harmonics,
            e.g. notch_freq = 50 Hz -> harmonics = [50, 100, 150, etc.]
        resampling_freq: int, optional, default = 500
            The frequency to resample the data after filtering and rescaling
        bandpass_cutoff_low: float, optional, default = 0.1
            Cut-off frequency at the lower end of the passband of the bandpass filter.
        bandpass_cutoff_high: int, optional, default = 200
            Cut-off frequency at the higher end of the passband of the bandpass filter.
        line_length_freq: int, optional, default = 50
            Sampling frequency of the line-length transformed data
        line_length_window: int, optional, default = 40
            Window length used for the line-length operation (in milliseconds).
        n_processes: int, optional, default = 5
            Maximum number of parallel chunks per channel subset.

        Returns
        -------
        Tuple[float, List[str], numpy.ndarray[numpy.dtype[float]]]
            Tuple containing the start timestamp of the recording, a list of channel names
            corresponding to the channels in the line-length transformed data, and
            the line-length transformed data
        """
        # Set optional line length params (read by line_length_pipeline in the workers)
        self.line_length_freq = line_length_freq
        self.line_length_window = line_length_window

        data_loader = DataLoader()
        start_timestamp = None
        labels: List[str] = []
        line_length_list = []

        if self.dataset_paths is None:
            # No explicit channel selection (e.g. .edf/.fif input): a single read with None as the
            # selection. NOTE(review): assumes DataLoader.read_file accepts None here — confirm.
            channel_sets = [None]
        else:
            # Sequentially load subsets of about ten channels due to memory limitations
            nr_channel_subsets = max(1, len(self.dataset_paths) // 10)
            channel_sets = [
                list(channel_set)
                for channel_set in np.array_split(self.dataset_paths, nr_channel_subsets)
            ]

        for channel_set in channel_sets:
            traces: List[Trace] = data_loader.read_file(
                self.file_path,
                channel_set,
                self.exclude,
                self.bipolar_reference,
                self.leads,
            )

            # Extract the channel names
            labels.extend(trace.label for trace in traces)

            # Start time of the recording
            start_timestamp = traces[0].start_timestamp

            # Cap the chunk count by the number of traces actually loaded (bipolar referencing or
            # exclusions may yield fewer traces than requested channels) so no chunk is empty
            n_chunks = max(1, min(n_processes, len(traces)))
            chunks = self._split_evenly(traces, n_chunks)
            n_workers = min(n_chunks, multiprocessing.cpu_count())

            with multiprocessing.Pool(processes=n_workers) as pool:
                line_length = pool.starmap(
                    self.line_length_pipeline,
                    [
                        (
                            chunk,
                            notch_freq,
                            resampling_freq,
                            bandpass_cutoff_low,
                            bandpass_cutoff_high,
                        )
                        for chunk in chunks
                    ],
                )

            # Combine results from parallel processing (chunks keep the trace order)
            line_length_list.append(np.concatenate(line_length, axis=0))

        return start_timestamp, labels, np.concatenate(line_length_list, axis=0)

    def compute_unique_line_length(
        self,
        notch_freq: int = 50,
        resampling_freq: int = 500,
        bandpass_cutoff_low: float = 0.1,
        bandpass_cutoff_high: int = 200,
        n_processes: int = 5,
        line_length_freq: int = 50,
        line_length_window: int = 40,
    ) -> "ActivationFunction":
        """
        This function computes the standard deviation of the data after performing
        the line-length transformation using the :py:func:`apply_parallel_line_length_pipeline` method
        and wraps it into a single :py:class:`~spidet.domain.ActivationFunction` object.
        The defined parameters will be passed on to the :py:func:`apply_parallel_line_length_pipeline` method.

        Parameters
        ----------
        notch_freq: int, optional, default = 50
            The frequency of the notch filter; data will be notch-filtered at this frequency
            and at the corresponding harmonics,
            e.g. notch_freq = 50 Hz -> harmonics = [50, 100, 150, etc.]
        resampling_freq: int, optional, default = 500
            The frequency to resample the data after filtering and rescaling
        bandpass_cutoff_low: float, optional, default = 0.1
            Cut-off frequency at the lower end of the passband of the bandpass filter.
        bandpass_cutoff_high: int, optional, default = 200
            Cut-off frequency at the higher end of the passband of the bandpass filter.
        n_processes: int, optional, default = 5
            Number of parallel processes to use for the line-length pipeline
        line_length_freq: int, optional, default = 50
            Sampling frequency of the line-length transformed data
        line_length_window: int, optional, default = 40
            Window length used for the line-length operation (in milliseconds).

        Returns
        -------
        :py:class:`~spidet.domain.ActivationFunction`
            ActivationFunction containing the standard deviation of the line-length transformed data.
        """
        # Compute line length for each channel (done in parallel)
        start_timestamp, _, line_length = self.apply_parallel_line_length_pipeline(
            notch_freq=notch_freq,
            resampling_freq=resampling_freq,
            bandpass_cutoff_low=bandpass_cutoff_low,
            bandpass_cutoff_high=bandpass_cutoff_high,
            line_length_freq=line_length_freq,
            line_length_window=line_length_window,
            n_processes=n_processes,
        )

        # Compute standard deviation between line length channels which is our unique line length
        std_line_length = np.std(line_length, axis=0)

        # Compute times for x-axis
        times = compute_rescaled_timeline(
            start_timestamp=start_timestamp,
            length=line_length.shape[1],
            sfreq=line_length_freq,
        )

        # Generate threshold and detect periods
        threshold_generator = ThresholdGenerator(
            activation_function_matrix=std_line_length, sfreq=line_length_freq
        )
        threshold = threshold_generator.generate_threshold()
        detected_periods = threshold_generator.find_events(threshold)

        # Create unique id from the file name without its extension
        unique_id = f"{self._file_stem(self.file_path)}_std_line_length"

        return ActivationFunction(
            label="Std Line Length",
            unique_id=unique_id,
            times=times,
            data_array=std_line_length,
            detected_events_on=detected_periods.get(0)["events_on"],
            detected_events_off=detected_periods.get(0)["events_off"],
            event_threshold=threshold,
        )