Source code for spidet.spike_detection.spike_detection_pipeline

import logging
import multiprocessing
import os
from datetime import datetime
from typing import Tuple, List, Dict

import numpy as np
import pandas as pd
from loguru import logger
from scipy.special import rel_entr
from sklearn.preprocessing import normalize
from pathlib import Path

from spidet.utils import logging_utils

from spidet.domain.BasisFunction import BasisFunction
from spidet.domain.ActivationFunction import ActivationFunction
from spidet.spike_detection.clustering import BasisFunctionClusterer
from spidet.spike_detection.line_length import LineLength
from spidet.spike_detection.nmf import Nmf
from spidet.spike_detection.thresholding import ThresholdGenerator
from spidet.utils.times_utils import compute_rescaled_timeline
from spidet.utils.plotting_utils import plot_w_and_consensus_matrix


class SpikeDetectionPipeline:
    r"""
    This class builds the heart of the automatic-spike-detection library.

    It provides an end-to-end pipeline that takes in a path to a file
    containing an iEEG recording and returns periods of abnormal activity.
    The pipeline is a multistep process that includes

        1. reading the data from the provided file (supported file formats
           are .h5, .edf, .fif) and transforming the data into a list of
           :py:mod:`~spidet.domain.Trace` objects,
        2. performing the necessary preprocessing steps by means of the
           :py:mod:`~spidet.preprocess.preprocessing` module,
        3. applying the line-length transformation using the
           :py:mod:`~spidet.spike_detection.line_length` module,
        4. performing Nonnegative Matrix Factorization to extract the most
           discriminating metapatterns, done by the
           :py:mod:`~spidet.spike_detection.nmf` module and
        5. computing periods of abnormal activity by means of the
           :py:mod:`~spidet.spike_detection.thresholding` module.

    Parameters
    ----------
    file_path: str
        Path to the file containing the iEEG data.

    results_dir: str
        Path to the directory where the folder containing the results of the
        spike detection run should be saved. If None, the results folder will
        be saved in the user's home directory. By default, the results contain
        the metrics representing the computation of the optimal rank and two
        plots for both the basis functions and the consensus matrices of the
        different ranks.

    save_nmf_matrices: bool
        If True, in addition to the results saved by default, the W matrix
        containing the basis functions and the H matrix containing the
        activation functions for each rank, the line-length transformed data
        and the standard deviation of the line length are saved.

    sparseness: float
        A floating point number :math:`\in [0, 1]`. If this parameter is
        non-zero, nonnegative matrix factorization is run with sparseness
        constraints.

    bad_times: numpy.ndarray[numpy.dtype[float]]
        An optional N x 2 numpy array containing periods that must be excluded
        before applying the line-length transformation. Each of the N rows in
        the array represents a period to be excluded, defined by the start and
        end indices of the period in the original iEEG data. The defined
        periods will be set to zero with the transitions being smoothed by
        applying a hanning window to prevent spurious patterns.

    nmf_runs: int
        The number of nonnegative matrix factorization runs performed for each
        rank, default is 100.

    rank_range: Tuple[int, int]
        A tuple defining the range of ranks (both ends inclusive) for which to
        perform the nonnegative matrix factorization, default is (2, 5).

    line_length_freq: int
        The sampling frequency of the line-length transformed data, default
        is 50 Hz. It is used for thresholding and for the time axis of the
        activation functions; :py:meth:`run` overwrites it with its own
        ``line_length_freq`` argument so that both always match the data.
    """

    def __init__(
        self,
        file_path: str,
        results_dir: str = None,
        save_nmf_matrices: bool = False,
        sparseness: float = 0.0,
        bad_times: np.ndarray[np.dtype[float]] = None,
        nmf_runs: int = 100,
        rank_range: Tuple[int, int] = (2, 5),
        line_length_freq: int = 50,
    ):
        # sparseness and file_path must be set first: the results folder
        # name is derived from both of them.
        self.sparseness = sparseness
        self.file_path = file_path
        self.results_dir: str = self.__create_results_dir(results_dir)
        self.save_nmf_matrices: bool = save_nmf_matrices
        self.bad_times = bad_times
        self.nmf_runs: int = nmf_runs
        self.rank_range: Tuple[int, int] = rank_range
        self.line_length_freq = line_length_freq

        # Configure logger
        logging_utils.add_logger_with_process_name(self.results_dir)

    def __create_results_dir(self, results_dir: str):
        """
        Create (if necessary) and return the folder the results are saved to.

        The folder is named ``<NMF|NMFSC>_<file name>_<timestamp>``, where dots
        and spaces in the file name are replaced by underscores. It is created
        inside ``results_dir``, or inside the user's home directory if
        ``results_dir`` is None.
        """
        # os.path.basename understands the path separators of the running
        # platform, unlike a manual search for "/".
        filename_for_saving = (
            os.path.basename(self.file_path).replace(".", "_").replace(" ", "_")
        )
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        nmf_version = "NMFSC" if self.sparseness != 0.0 else "NMF"
        folder_name = "_".join([nmf_version, filename_for_saving, timestamp])

        parent_dir = Path.home() if results_dir is None else results_dir
        results_dir = os.path.join(parent_dir, folder_name)

        os.makedirs(results_dir, exist_ok=True)
        return results_dir

    @staticmethod
    def _resolve_pool_size(n_tasks: int, n_cpus: int) -> int:
        """
        Return the number of worker processes to start for ``n_tasks`` tasks.

        Two CPUs are left free when possible, no more workers than tasks are
        started, and at least one worker is always returned because
        ``multiprocessing.Pool`` rejects a process count below 1.
        """
        return max(1, min(n_tasks, n_cpus - 2))

    @staticmethod
    def __compute_cdf(matrix: np.ndarray, bins: np.ndarray):
        """
        Compute the cumulative histogram of the upper triangle of ``matrix``.

        A small offset is added to every value so that later divisions and
        logarithms never see an exact zero.
        """
        # NOTE(review): the values include the diagonal (np.triu_indices with
        # its default k=0) while the normaliser N * (N - 1) / 2 counts only the
        # strictly upper triangle, and density=True already rescales the
        # histogram. Left unchanged because it affects the reported metrics --
        # confirm that this is intended.
        N = matrix.shape[0]
        values = matrix[np.triu_indices(N)]
        counts, _ = np.histogram(values, bins=bins, density=True)
        cdf_vals = np.cumsum(counts) / (N * (N - 1) / 2)
        return cdf_vals + 1e-10  # add a small offset to avoid div0!

    @staticmethod
    def __compute_cdf_area(cdf_vals, bin_width):
        """Return the area under the CDF, ignoring its last bin."""
        return np.sum(cdf_vals[:-1]) * bin_width

    @staticmethod
    def __compute_delta_k(areas, cdfs):
        """
        Compute the change between the CDFs of consecutive ranks.

        Returns
        -------
        delta_k: numpy.ndarray
            ``areas[0]`` for the first rank and the relative change of the
            area with respect to the previous rank for all others.
        delta_y: numpy.ndarray
            Zero for the first rank and the summed ``rel_entr`` between the
            CDF of a rank and the CDF of the previous rank for all others.
        """
        delta_k = np.zeros(len(areas))
        delta_y = np.zeros(len(areas))
        delta_k[0] = areas[0]
        for i in range(1, len(areas)):
            delta_k[i] = (areas[i] - areas[i - 1]) / areas[i - 1]
            delta_y[i] = np.sum(rel_entr(cdfs[:, i], cdfs[:, i - 1]))
        return delta_k, delta_y

    def __calculate_statistics(self, consensus_matrices: List[np.ndarray]):
        """
        Derive the rank-selection metrics from the consensus matrices.

        Parameters
        ----------
        consensus_matrices: List[numpy.ndarray]
            One consensus matrix per rank, ordered by increasing rank.

        Returns
        -------
        areas, delta_k, delta_y: numpy.ndarray
            Area under the CDF and the two change measures, one entry per rank.
        k_opt: int
            The rank with the largest ``delta_k``.
        """
        k_min, k_max = self.rank_range
        n_ranks = k_max - k_min + 1

        bins = np.linspace(0, 1, 101)
        bin_width = bins[1] - bins[0]
        num_bins = len(bins) - 1

        cdfs = np.zeros((num_bins, n_ranks))
        areas = np.zeros(n_ranks)

        for idx, consensus in enumerate(consensus_matrices):
            cdf_vals = self.__compute_cdf(consensus, bins)
            areas[idx] = self.__compute_cdf_area(cdf_vals, bin_width)
            cdfs[:, idx] = cdf_vals

        delta_k, delta_y = self.__compute_delta_k(areas, cdfs)
        k_opt = np.argmax(delta_k) + k_min if delta_k.size > 0 else k_min

        return areas, delta_k, delta_y, k_opt

    def perform_nmf_steps_for_rank(
        self,
        preprocessed_data: np.ndarray,
        rank: int,
        n_runs: int,
    ) -> Tuple[
        Dict,
        np.ndarray[np.dtype[float]],
        np.ndarray[np.dtype[float]],
        np.ndarray[np.dtype[float]],
        Dict[int, Dict],
        Dict[int, float],
        Dict[int, int],
    ]:
        """
        Run NMF, basis function clustering and thresholding for a single rank.

        Parameters
        ----------
        preprocessed_data: numpy.ndarray
            The data matrix to factorize.
        rank: int
            The rank of the factorization.
        n_runs: int
            The number of NMF runs performed for this rank.

        Returns
        -------
        tuple
            ``(metrics, consensus, sorted_h, sorted_w, spike_annotations,
            thresholds, cluster_assignments)``, where H and W are sorted
            according to the cluster assignment of the basis functions.
        """
        logger.debug(f"Starting Spike detection pipeline for rank {rank}")

        #####################
        #   NMF             #
        #####################

        # Instantiate nmf classifier
        nmf_classifier = Nmf(rank=rank, sparseness=self.sparseness)

        # Run NMF consensus clustering for specified rank and number of runs (default = 100)
        metrics, consensus, h_best, w_best = nmf_classifier.nmf_run(
            preprocessed_data, n_runs
        )

        #####################
        # CLUSTERING BS FCT #
        #####################

        # Initialize kmeans classifier
        kmeans = BasisFunctionClusterer(n_clusters=2, use_cosine_dist=True)

        # Cluster into noise / basis function and sort according to cluster assignment
        sorted_w, sorted_h, cluster_assignments = kmeans.cluster_and_sort(
            h_matrix=h_best, w_matrix=w_best
        )
        # TODO check if necessary: cluster_assignments = np.where(cluster_assignments == 1, "BF", "noise")

        #####################
        #   THRESHOLDING    #
        #####################

        # Use the configured sampling frequency of the line-length data
        # rather than a hard-coded 50 Hz.
        threshold_generator = ThresholdGenerator(
            sorted_h, preprocessed_data, sfreq=self.line_length_freq
        )
        threshold_generator.generate_individual_thresholds()
        spike_annotations = threshold_generator.find_events()

        return (
            metrics,
            consensus,
            sorted_h,
            sorted_w,
            spike_annotations,
            threshold_generator.thresholds,
            cluster_assignments,
        )

    def __save_nmf_results(
        self,
        data_matrix: np.ndarray,
        rank_list: List[int],
        h_matrices: List[np.ndarray],
        w_matrices: List[np.ndarray],
        consensus_matrices: List[np.ndarray],
        event_annotations: List[Dict],
    ) -> None:
        """Save line length, per-rank H/W/consensus matrices and event annotations as CSV."""
        # Saving line length and std line length
        np.savetxt(f"{self.results_dir}/line_length.csv", data_matrix, delimiter=",")
        np.savetxt(
            f"{self.results_dir}/std_line_length.csv",
            np.std(data_matrix, axis=0),
            delimiter=",",
        )

        per_rank = zip(
            rank_list, h_matrices, w_matrices, consensus_matrices, event_annotations
        )
        for rank, h_matrix, w_matrix, consensus_matrix, spikes in per_rank:
            # Saving Consensus, W and H matrices
            saving_path = os.path.join(self.results_dir, f"k={rank}")
            os.makedirs(saving_path, exist_ok=True)

            np.savetxt(f"{saving_path}/H_best.csv", h_matrix, delimiter=",")
            np.savetxt(f"{saving_path}/W_best.csv", w_matrix, delimiter=",")
            np.savetxt(
                f"{saving_path}/consensus_matrix.csv",
                consensus_matrix,
                delimiter=",",
            )

            # Saving event annotations: one on/off column pair per
            # activation function
            headers = []
            event_times = []
            for h_idx, events in spikes.items():
                event_times.append(events.get("events_on"))
                event_times.append(events.get("events_off"))
                headers.extend(
                    [f"h{h_idx + 1}_events_on", f"h{h_idx + 1}_events_off"]
                )

            df_spike_times = pd.DataFrame(event_times).transpose()
            df_spike_times.columns = headers
            df_spike_times.to_csv(f"{saving_path}/event_annotations.csv")

    def parallel_processing(
        self,
        preprocessed_data: np.ndarray[np.dtype[float]],
        channel_names: List[str],
    ) -> Tuple[
        np.ndarray[np.dtype[float]],
        np.ndarray[np.dtype[float]],
        Dict[int, Dict],
        Dict[int, float],
        Dict[int, int],
    ]:
        """
        Run the NMF steps for every rank in ``rank_range`` in parallel.

        The metrics of all ranks and the plots of the W and consensus matrices
        are always written to the results directory; the matrices and event
        annotations of every rank are written as well if
        ``save_nmf_matrices`` is True.

        Parameters
        ----------
        preprocessed_data: numpy.ndarray[numpy.dtype[float]]
            The line-length transformed data.
        channel_names: List[str]
            The channel names, passed on to the plotting routine.

        Returns
        -------
        tuple
            ``(h_opt, w_opt, events_opt, thresholds_opt, assignments_opt)``
            for the optimal rank.
        """
        # List of ranks to run NMF for
        rank_list = list(range(self.rank_range[0], self.rank_range[1] + 1))
        nr_ranks = len(rank_list)

        # Normalize for NMF (preprocessed data needs to be non-negative)
        data_matrix = normalize(preprocessed_data)

        # Using all cores except 2 if necessary, but never fewer than one:
        # multiprocessing.Pool raises for a process count below 1.
        n_cores = self._resolve_pool_size(nr_ranks, multiprocessing.cpu_count())

        logger.debug(
            f"Running NMF on {n_cores} cores "
            f"for ranks {rank_list} and {self.nmf_runs} runs each"
        )

        with multiprocessing.Pool(processes=n_cores) as pool:
            results = pool.starmap(
                self.perform_nmf_steps_for_rank,
                [(data_matrix, rank, self.nmf_runs) for rank in rank_list],
            )

        # Extract return objects from results: transpose the list of
        # per-rank tuples into one list per returned object
        (
            metrics,
            consensus_matrices,
            h_matrices,
            w_matrices,
            event_annotations,
            thresholds,
            cluster_assignments,
        ) = (list(column) for column in zip(*results))

        # Calculate final statistics
        C, delta_k, delta_y, k_opt = self.__calculate_statistics(consensus_matrices)

        # Get objects for the optimal rank
        idx_opt = k_opt - self.rank_range[0]
        h_opt = h_matrices[idx_opt]
        w_opt = w_matrices[idx_opt]
        events_opt = event_annotations[idx_opt]
        thresholds_opt = thresholds[idx_opt]
        assignments_opt = cluster_assignments[idx_opt]

        # Generate metrics data frame
        metrics_df = pd.DataFrame(metrics)
        metrics_df["AUC"] = C
        metrics_df["delta_k (CDF)"] = delta_k
        metrics_df["delta_y (KL-div)"] = delta_y

        logger.debug(f"Optimal rank: k = {k_opt}")

        # Plot and save W and consensus matrices
        plot_w_and_consensus_matrix(
            w_matrices=w_matrices,
            consensus_matrices=consensus_matrices,
            experiment_dir=self.results_dir,
            channel_names=channel_names,
        )

        # Saving metrics as CSV
        logger.debug("Saving metrics")
        metrics_path = os.path.join(self.results_dir, "metrics.csv")
        metrics_df.to_csv(metrics_path, index=False)

        # Saving H and W matrices, event annotations and line length matrix
        if self.save_nmf_matrices:
            logger.debug(
                f"Saving LineLength and Consensus, W, H matrices and corresponding event annotations for ranks {rank_list}"
            )
            self.__save_nmf_results(
                data_matrix=data_matrix,
                rank_list=rank_list,
                h_matrices=h_matrices,
                w_matrices=w_matrices,
                consensus_matrices=consensus_matrices,
                event_annotations=event_annotations,
            )

        return h_opt, w_opt, events_opt, thresholds_opt, assignments_opt

    def run(
        self,
        channel_paths: List[str] = None,
        bipolar_reference: bool = False,
        exclude: List[str] = None,
        leads: List[str] = None,
        notch_freq: int = 50,
        resampling_freq: int = 500,
        bandpass_cutoff_low: float = 0.1,
        bandpass_cutoff_high: int = 200,
        line_length_freq: int = 50,
        line_length_window: int = 40,
    ) -> Tuple[List["BasisFunction"], List["ActivationFunction"]]:
        """
        This method triggers a complete run of the spike detection pipeline
        with the arguments passed to the :py:class:`SpikeDetectionPipeline`
        at initialization.

        Parameters
        ----------
        channel_paths: List[str]
            A list of paths to the traces to be included within an h5 file.
            This is only necessary in the case of h5 files.

        bipolar_reference: bool
            If True, the bipolar references of the included channels will be
            computed. If channels already are in bipolar form this needs to be
            False.

        exclude: List[str]
            A list of channel names that need to be excluded. This only applies
            in the case of .edf and .fif files.

        leads: List[str]
            A list of the leads included. Only necessary if bipolar_reference
            is True, otherwise can be None.

        notch_freq: int, optional, default = 50
            The frequency of the notch filter; data will be notch-filtered at
            this frequency and at the corresponding harmonics,
            e.g. notch_freq = 50 Hz -> harmonics = [50, 100, 150, etc.]

        resampling_freq: int, optional, default = 500
            The frequency to resample the data after filtering and rescaling

        bandpass_cutoff_low: float, optional, default = 0.1
            Cut-off frequency at the lower end of the passband of the bandpass
            filter.

        bandpass_cutoff_high: int, optional, default = 200
            Cut-off frequency at the higher end of the passband of the bandpass
            filter.

        line_length_freq: int, optional, default = 50
            Sampling frequency of the line-length transformed data. This value
            replaces the ``line_length_freq`` given at initialization, so that
            thresholding and the returned time axis match the data.

        line_length_window: int, optional, default = 40
            Window length used for the line-length operation (in milliseconds).

        Returns
        -------
        Tuple[List[BasisFunction], List[ActivationFunction]]
            Two lists containing the :py:mod:`~spidet.domain.BasisFunction` and
            :py:mod:`~spidet.domain.ActivationFunction`, where each activation
            function contains the corresponding detected events.
        """
        # The line-length data is produced with the frequency passed to this
        # method, so every later step has to use that same frequency.
        if line_length_freq != self.line_length_freq:
            logger.warning(
                f"line_length_freq passed to run ({line_length_freq}) differs from "
                f"the value set at initialization ({self.line_length_freq}); "
                f"using {line_length_freq}"
            )
            self.line_length_freq = line_length_freq

        # Instantiate a LineLength instance
        line_length = LineLength(
            file_path=self.file_path,
            dataset_paths=channel_paths,
            exclude=exclude,
            bipolar_reference=bipolar_reference,
            leads=leads,
            bad_times=self.bad_times,
        )

        # Perform line length steps to compute line length
        (
            start_timestamp,
            channel_names,
            line_length_matrix,
        ) = line_length.apply_parallel_line_length_pipeline(
            notch_freq=notch_freq,
            resampling_freq=resampling_freq,
            bandpass_cutoff_low=bandpass_cutoff_low,
            bandpass_cutoff_high=bandpass_cutoff_high,
            line_length_freq=line_length_freq,
            line_length_window=line_length_window,
        )

        # Run parallelized NMF
        (
            h_opt,
            w_opt,
            spikes_opt,
            thresholds_opt,
            assignments_opt,
        ) = self.parallel_processing(
            preprocessed_data=line_length_matrix, channel_names=channel_names
        )

        # Create unique id prefix: the file name without its last extension.
        # Path.stem leaves names without an extension untouched, whereas
        # slicing at rfind(".") would cut off their last character.
        unique_id_prefix = Path(self.file_path).stem

        # Compute times for H x-axis
        times = compute_rescaled_timeline(
            start_timestamp=start_timestamp,
            length=h_opt.shape[1],
            sfreq=self.line_length_freq,
        )

        # Create return objects
        basis_functions: List[BasisFunction] = []
        activation_functions: List[ActivationFunction] = []

        # Iterating over the dicts yields their keys; assignments_opt only
        # takes part in the zip so that the shortest input bounds the loop.
        for idx, (bf, sdf, spikes_idx, threshold_idx, _) in enumerate(
            zip(w_opt.T, h_opt, spikes_opt, thresholds_opt, assignments_opt)
        ):
            # Create BasisFunction
            label_bf = f"W{idx + 1}"
            basis_fct = BasisFunction(
                label=label_bf,
                unique_id=f"{unique_id_prefix}_{label_bf}",
                channel_names=channel_names,
                data_array=bf,
            )

            # Create ActivationFunction
            label_af = f"H{idx + 1}"
            events = spikes_opt[spikes_idx]
            activation_fct = ActivationFunction(
                label=label_af,
                unique_id=f"{unique_id_prefix}_{label_af}",
                times=times,
                data_array=sdf,
                detected_events_on=events["events_on"],
                detected_events_off=events["events_off"],
                event_threshold=thresholds_opt.get(threshold_idx),
            )

            basis_functions.append(basis_fct)
            activation_functions.append(activation_fct)

        return basis_functions, activation_functions