from typing import Dict, Optional, Tuple

import numpy as np
from loguru import logger
[docs]
class ThresholdGenerator:
    """
    Computes thresholds and detected events for a single activation function or a
    set of activation functions.

    Parameters
    ----------
    activation_function_matrix: numpy.ndarray
        A single activation function (1-D) or a set of activation functions (one per row)
        for which to compute events. A 1-D array is promoted to a matrix with one row.
    preprocessed_data: numpy.ndarray, optional
        The preprocessed iEEG data, produced by applying the preprocessing steps listed in the
        preprocessing section. It is indexed as (channel, time point) using the event indices found
        on the activation functions. If None, the channels involved in an event are not determined.
    sfreq: int
        The sampling frequency of the data contained in the activation functions, defaults to 50 Hz.
    z_threshold: int
        The z-threshold used for computing the channels involved in a particular event,
        defaults to 10.

    Attributes
    ----------
    thresholds: Dict[int, float]
        Thresholds per activation function, keyed by row index; filled by
        :py:func:`~generate_individual_thresholds`.
    """

    def __init__(
        self,
        activation_function_matrix: np.ndarray,
        preprocessed_data: Optional[np.ndarray] = None,
        sfreq: int = 50,
        z_threshold: int = 10,
    ):
        # Always work on a 2-D matrix so that rows can be iterated uniformly
        self.activation_function_matrix = (
            activation_function_matrix
            if activation_function_matrix.ndim > 1
            else activation_function_matrix[np.newaxis, :]
        )
        self.preprocessed_data = preprocessed_data
        self.sfreq = sfreq
        self.z_threshold = z_threshold
        self.thresholds: Dict[int, float] = {}

    def __determine_involved_channels(
        self, events_on: np.ndarray, events_off: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Determines for each event which channels of the preprocessed data are involved.

        A channel is involved in an event if its maximum z-score within the event window
        (relative to the median and standard deviation of the background, i.e. the samples
        before the first event and between consecutive events) exceeds the z-threshold.
        Events without enough involved channels are removed.

        Parameters
        ----------
        events_on: np.ndarray
            Onset indices of the events.
        events_off: np.ndarray
            Offset indices of the events (exclusive end of the event window).

        Returns
        -------
        Tuple[np.ndarray, np.ndarray, np.ndarray]
            The channel/event association matrix of shape (events, channels) and the onsets
            and offsets of the retained events. An involved channel holds its lag rank
            (1 for the channel that reaches its maximum z-score first), all others hold 0.
            If no preprocessed data or no events are available, an empty array is returned
            together with the unchanged onsets and offsets.
        """
        if self.preprocessed_data is None:
            logger.warning(
                "Cannot determine involved channels as preprocessed data is None"
            )
            return np.array([]), events_on, events_off

        nr_events = len(events_on)
        if nr_events == 0:
            logger.debug("Cannot determine involved channels as no events were found")
            return np.array([]), events_on, events_off

        nr_channels, nr_samples = self.preprocessed_data.shape[:2]

        # Background: samples before the first event and between consecutive events
        background = np.zeros(nr_samples, dtype=bool)
        if events_on[0] > 1:
            background[: events_on[0]] = True
        for offset, next_onset in zip(events_off[:-1], events_on[1:]):
            background[offset:next_onset] = True

        if not background.any():
            # Without background no z-score can be computed; previously this produced NaN
            # statistics (and numpy warnings), which implicitly dropped every event
            logger.warning(
                "Cannot determine involved channels as no background samples are available"
            )
            return np.zeros((0, nr_channels)), events_on[:0], events_off[:0]

        # TODO: check why np.median returns all zeros
        # Median and standard deviation of the background for each channel
        background_data = self.preprocessed_data[:, background]
        median_channels = np.median(background_data, axis=1)
        std_channels = np.std(background_data, axis=1)

        channels_involved = np.zeros((nr_events, nr_channels))
        for event, (onset, offset) in enumerate(zip(events_on, events_off)):
            event_window = self.preprocessed_data[:, onset:offset]

            # z-scores of each channel along the event window
            z_scores = (event_window - median_channels[:, None]) / std_channels[:, None]

            # Maximum z-score per channel and the position at which it is reached
            max_z = np.max(z_scores, axis=1)
            channel_lags = np.argmax(z_scores, axis=1)

            # Include channels having a z-score higher than the z-threshold
            channels = max_z > self.z_threshold
            if not channels.any():
                continue

            # Rank involved channels relative to the one reaching its maximum first
            min_lag = channel_lags[channels].min()
            channels_involved[event, :] = np.where(
                channels, channel_lags - min_lag + 1, 0
            )

        # For a large number of channels only events associated with multiple channels are
        # kept, otherwise events need to be associated with at least one channel
        min_channels = 2 if nr_channels > 50 else 1
        nr_involved = np.count_nonzero(channels_involved, axis=1)
        relevant_events = np.flatnonzero(nr_involved >= min_channels)

        return (
            channels_involved[relevant_events, :],
            events_on[relevant_events],
            events_off[relevant_events],
        )

    def generate_individual_thresholds(self) -> None:
        """
        Computes the threshold for each individual activation function based on
        :py:func:`~generate_threshold` and stores it in :py:attr:`thresholds`
        under the row index of the activation function.
        """
        for idx, activation_function in enumerate(self.activation_function_matrix):
            self.thresholds[idx] = self.generate_threshold(data=activation_function)

    def generate_threshold(self, data: Optional[np.ndarray] = None) -> float:
        """
        Computes the threshold for individual activation functions. The threshold is defined as the
        zero-crossing of the line that is fitted to the right of the histogram of a given
        activation function.

        Parameters
        ----------
        data: np.ndarray, optional
            This represents the data for which to compute the threshold. If None, the threshold is computed
            for the activation_function_matrix passed to the ThresholdGenerator at initialization.

        Returns
        -------
        float
            The threshold computed for either the data passed as a function argument or the activation function
            passed to the ThresholdGenerator at initialization.

        Raises
        ------
        IndexError
            If no local maximum (mode) is found in the smoothed histogram.
        """
        # Determine data to compute threshold for
        data = self.activation_function_matrix if data is None else data

        # One bin per 10 data points, capped at 1000 bins
        nr_bins = min(round(0.1 * data.shape[-1]), 1000)
        hist, bin_edges = np.histogram(data, bins=nr_bins)

        # TODO: check whether disregard bin 0 (Epitome)
        # Get rid of bin 0
        hist, bin_edges = hist[1:], bin_edges[1:]

        # Smooth hist once with a running mean of 10 dps, then 10 times with one of 3 dps
        running_mean_3 = np.ones(3) / 3
        hist_smoothed = np.convolve(hist, np.ones(10) / 10, mode="same")
        for _ in range(10):
            hist_smoothed = np.convolve(hist_smoothed, running_mean_3, mode="same")

        # TODO: check whether disregard 10 last dp, depending on smoothing
        hist_smoothed, bin_edges = hist_smoothed[:-10], bin_edges[:-10]

        # First differences, first value duplicated to keep the array length
        first_diff = np.diff(hist_smoothed, 1)
        first_diff = np.append(first_diff[0], first_diff)

        # Smooth first differences 10 times with a running mean of 3 dps
        first_diff_smoothed = first_diff
        for _ in range(10):
            first_diff_smoothed = np.convolve(
                first_diff_smoothed, running_mean_3, mode="same"
            )

        # First 2 modes of hist: sign of the (unsmoothed) first difference flips from + to -
        modes = np.nonzero(np.diff(np.sign(first_diff), 1) == -2)[0][:2]

        # Prefer the first mode beyond index 9 that lies within the first tenth of the bins,
        # otherwise fall back to the very first mode
        candidates = modes[(modes > 9) & (modes < len(bin_edges) / 10)]
        idx_mode = modes[0] if len(candidates) == 0 else candidates[0]

        # First inflection point to the right of the mode (steepest descent),
        # translated back to an index in the original hist
        idx_first_inf = np.argmin(first_diff_smoothed[idx_mode:])
        idx_first_inf += idx_mode - 1

        # Second differences, first value duplicated to keep the array length
        second_diff = np.diff(first_diff_smoothed, 1)
        second_diff = np.append(second_diff[0], second_diff)

        # Max of the second difference to the right of the first peak
        # -> corresponds to values around spikes
        idx_second_peak = np.argmax(second_diff[idx_first_inf:])
        idx_second_peak += idx_first_inf - 1

        # Range in hist used for fitting the line
        start_idx = np.max(
            [
                idx_mode,
                idx_first_inf
                - np.rint((idx_second_peak - idx_first_inf) / 2).astype(int),
            ],
        )
        end_idx = idx_second_peak
        if end_idx - start_idx <= 1:
            # Widen the range: extend the end index by 3 unless it lies before the
            # start index, in which case the range is anchored at the start index
            end_idx = end_idx + 3 if end_idx - start_idx + 3 > 2 else start_idx + 3
            logger.warning(
                f"End index for threshold line fit either before or too close to start index, modified to: {end_idx}"
            )

        threshold_fit = np.polyfit(
            bin_edges[start_idx:end_idx],
            hist_smoothed[start_idx:end_idx],
            deg=1,
        )

        # Zero-crossing of the fitted line (slope * x + intercept = 0)
        threshold = -threshold_fit[1] / threshold_fit[0]
        return threshold

    @staticmethod
    def _merge_events(
        events_on: np.ndarray, events_off: np.ndarray, keep_gap: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Merges consecutive events whose separating gap is not flagged in keep_gap."""
        # An onset survives if the gap before it is kept, an offset if the gap after it is kept
        return (
            events_on[np.append(True, keep_gap)],
            events_off[np.append(keep_gap, True)],
        )

    def find_events(self, threshold: Optional[float] = None) -> Dict[int, Dict]:
        """
        Computes the events for the activation functions in the activation_function_matrix, which was
        passed to the ThresholdGenerator at initialization. If the threshold argument is None,
        the computation is based on the thresholds generated for each activation function
        by :py:func:`~generate_individual_thresholds`

        Parameters
        ----------
        threshold: float, optional
            The threshold used to compute events for the activation_function_matrix. This can be useful e.g.
            if the activation_function_matrix contains a single activation function and events need to be
            computed based on a custom threshold. If given, it is applied to every activation function.

        Returns
        -------
        Dict[int, Dict]
            A nested dictionary containing the events for each activation function. A given activation function
            in the dictionary can be accessed by its respective index in the :py:attr:`activation_function_matrix`.
            The events for a given activation function are represented by a dictionary containing two index arrays
            corresponding to the onset, accessible by the "events_on"-key, and offset,
            accessible by the "events_off"-key, indices of the events, one binary masking array
            indicating the indices of all detected events, accessible via the "event_mask"-key, and the
            channel/event association, accessible via the "channels_involved"-key.

        Raises
        ------
        ValueError
            If no threshold is passed and none was generated for an activation function.
        """
        events = {}
        for idx, activation_function in enumerate(self.activation_function_matrix):
            # Resolve the threshold per row; the argument itself must not be overwritten,
            # otherwise the threshold of the first row would be reused for all other rows
            row_threshold = (
                threshold if threshold is not None else self.thresholds.get(idx)
            )
            if row_threshold is None:
                raise ValueError(
                    f"No threshold available for activation function {idx}; pass a threshold "
                    "or call generate_individual_thresholds() first"
                )

            # Mask indicating whether a specific time point exceeds the threshold
            event_mask = activation_function > row_threshold

            # Rising edges mark event onsets, falling edges mark the first point after an event
            transitions = np.diff(np.append(0, event_mask), 1)
            events_on = np.flatnonzero(transitions == 1)
            events_off = np.flatnonzero(transitions == -1)

            # Correct for any starting event not ending within recording period
            if len(events_on) > len(events_off):
                events_on = events_on[:-1]

            # Consider only events having a duration of at least 20 ms
            long_enough = (events_off - events_on) >= 0.02 * self.sfreq
            events_on, events_off = events_on[long_enough], events_off[long_enough]

            channel_event_assoc = []
            if len(events_on) > 0:
                # If gaps between events are < 40 ms, they are considered the same event
                keep_gap = (events_on[1:] - events_off[:-1]) >= 0.04 * self.sfreq
                events_on, events_off = self._merge_events(
                    events_on, events_off, keep_gap
                )

                # Add +/- 40 ms on either side of the events, zeroing out any negative values
                # and upper bounding values by maximum time point
                padding = 0.04 * self.sfreq
                events_on = np.maximum(0, events_on - padding).astype(int)
                events_off = np.minimum(
                    len(event_mask) - 1, events_off + padding
                ).astype(int)

                # Merge events that overlap after padding
                keep_gap = (events_on[1:] - events_off[:-1]) > 0
                events_on, events_off = self._merge_events(
                    events_on, events_off, keep_gap
                )

                # Determine which channels were involved in measuring which events
                (
                    channel_event_assoc,
                    events_on,
                    events_off,
                ) = self.__determine_involved_channels(events_on, events_off)

            # Binary mask covering all retained events (offsets inclusive)
            final_mask = np.zeros(len(activation_function), dtype=int)
            for on, off in zip(events_on, events_off):
                final_mask[on : off + 1] = 1

            events[idx] = {
                "events_on": events_on,
                "events_off": events_off,
                "event_mask": final_mask,
                "channels_involved": channel_event_assoc,
            }

        return events