From bec1bb56d4db920f057f792ef0160934335e0aa2 Mon Sep 17 00:00:00 2001
From: Basak Celik
Date: Mon, 23 Oct 2023 05:56:44 -0400
Subject: [PATCH] Offline analysis PR comments addressed. Visualization doc
 updated

---
 bcipy/helpers/visualization.py         |  84 +++--
 bcipy/signal/model/offline_analysis.py | 458 +++++++++++++------------
 2 files changed, 295 insertions(+), 247 deletions(-)

diff --git a/bcipy/helpers/visualization.py b/bcipy/helpers/visualization.py
index 04d4ef4b8..1d07f7d85 100644
--- a/bcipy/helpers/visualization.py
+++ b/bcipy/helpers/visualization.py
@@ -113,27 +113,36 @@ def visualize_erp(


 def visualize_gaze(
-        data,
-        left_keys=['left_x', 'left_y'],
-        right_keys=['right_x', 'right_y'],
-        save_path=None,
-        show=False,
-        img_path=None,
-        screen_size=(1920, 1080),
-        heatmap=False,
-        raw_plot=False) -> Figure:
+        data: RawData,
+        left_keys: List[str] = ['left_x', 'left_y'],
+        right_keys: List[str] = ['right_x', 'right_y'],
+        save_path: Optional[str] = None,
+        show: Optional[bool] = False,
+        img_path: Optional[str] = None,
+        screen_size: Tuple[int, int] = (1920, 1080),
+        heatmap: Optional[bool] = False,
+        raw_plot: Optional[bool] = False) -> Figure:
     """Visualize Eye Data.

     Assumes that the data is collected using BciPy and a Tobii-nano eye tracker. The default
-    image used is for the matrix calibration task on a 1920x1080 screen.
+    image used is for the matrix calibration task on a 1920x1080 screen.
+
+    Generates a comparative matrix figure following the execution of a task. Given a set of
+    trial data, the gaze distributions are plotted and may be saved or shown in a window.
+
+    Returns the figure handle created.

     Parameters
     ----------
-    data: RawData
-    save_path: Optional[str]
-    show: Optional[bool]
-    img_path: Optional[str]
-    screen_size: Optional[Tuple[int, int]] TODO
+    data: RawData: raw gaze data to be visualized
+    left_keys: List[str]: list of channels for the left eye data. Default: ['left_x', 'left_y']
+    right_keys: List[str]: list of channels for the right eye data. Default: ['right_x', 'right_y']
+    save_path: Optional[str]: optional path to a save location for the generated figure
+    show: Optional[bool]: whether or not to show the generated figures. Default: False
+    img_path: Optional[str]: image to be used as the background. Default: matrix.png
+    screen_size: Tuple[int, int]: size of the screen used for Calibration/Copy Phrase tasks. Default: (1920, 1080)
+    heatmap: Optional[bool]: whether or not to plot the heatmap. Default: False
+    raw_plot: Optional[bool]: whether or not to plot the raw gaze data. Default: False
     """

     title = f'{data.daq_type} '
@@ -207,27 +216,40 @@ def visualize_gaze(


 def visualize_gaze_inquiries(
-        left_eye,
-        right_eye,
-        means=None,
-        covs=None,
-        save_path=None,
-        show=False,
-        img_path=None,
-        heatmap=False,
-        raw_plot=False) -> Figure:
-    """Visualize Eye Data.
+        left_eye: np.ndarray,
+        right_eye: np.ndarray,
+        means: Optional[np.ndarray] = None,
+        covs: Optional[np.ndarray] = None,
+        save_path: Optional[str] = None,
+        show: Optional[bool] = False,
+        img_path: Optional[str] = None,
+        screen_size: Tuple[int, int] = (1920, 1080),
+        heatmap: Optional[bool] = False,
+        raw_plot: Optional[bool] = False) -> Figure:
+    """Visualize Gaze Inquiries.

     Assumes that the data is collected using BciPy and a Tobii-nano eye tracker. The default
-    image used is for the matrix calibration task on a 1920x1080 screen.
+    image used is for the matrix calibration task on a 1920x1080 screen.
+
+    Generates a comparative matrix figure following the execution of offline analysis. Given a set of
+    trial data (left & right eye), the gaze distributions for each prompted symbol are plotted, along
+    with the contour plots of the means and covariances calculated by the Gaussian Mixture Model.
+    The figures may be saved or shown in a window.
+
+    Returns the figure handle created.

     Parameters
     ----------
-    left eye, right eye: (np.ndarray): Raw gaze data, plotting both eyes is optional
-    means, covs: Optional(np.ndarray): Means and covariances of the Gaussian Mixture Model
-    save_path: Optional[str]
-    show: Optional[bool]
-    img_path: Optional[str]
+    left_eye: np.ndarray: data array for the left eye
+    right_eye: np.ndarray: data array for the right eye
+    means: Optional[np.ndarray]: means of the Gaussian Mixture Model
+    covs: Optional[np.ndarray]: covariances of the Gaussian Mixture Model
+    save_path: Optional[str]: optional path to a save location for the generated figure
+    show: Optional[bool]: whether or not to show the generated figures. Default: False
+    img_path: Optional[str]: image to be used as the background. Default: matrix.png
+    screen_size: Tuple[int, int]: size of the screen used for Calibration/Copy Phrase tasks. Default: (1920, 1080)
+    heatmap: Optional[bool]: whether or not to plot the heatmap. Default: False
+    raw_plot: Optional[bool]: whether or not to plot the raw gaze data. Default: False
     """

     title = 'Raw Gaze Inquiries '

diff --git a/bcipy/signal/model/offline_analysis.py b/bcipy/signal/model/offline_analysis.py
index 114617e12..b2a4206e6 100644
--- a/bcipy/signal/model/offline_analysis.py
+++ b/bcipy/signal/model/offline_analysis.py
@@ -65,6 +65,242 @@ def subset_data(data: np.ndarray, labels: np.ndarray, test_size: float, random_s
     return train_data, test_data, train_labels, test_labels


+def analyze_erp(erp_data, parameters, device_spec, data_folder, estimate_balanced_acc,
+                save_figures=True, show_figures=False):
+    """Analyze ERP data and save the trained ERP model.
+
+    Extract relevant information from the raw data object.
+    Extract timing information from the trigger file.
+    Apply filtering and preprocessing on the raw data.
+    Reshape and label the data for the training procedure.
+    Fit the model to the data, using cross validation to select parameters.
+    Pickle dump the model into a .pkl file.
+    Generate and (optionally) save/show the ERP figures.
+    """
+    # Extract relevant session information from parameters file
+    trial_window = parameters.get("trial_window")
+    # This parameter does not exist in parameters files for multimodal datasets; add it manually.
+    # TODO: Update parameters files for multimodal datasets.
+    if trial_window is None:
+        trial_window = [0.0, 0.5]
+    window_length = trial_window[1] - trial_window[0]
+
+    prestim_length = parameters.get("prestim_length")
+    trials_per_inquiry = parameters.get("stim_length")
+    # The task buffer length defines the min time between two inquiries
+    # We use half of that time here to buffer during transforms
+    buffer = int(parameters.get("task_buffer_length") / 2)
+
+    # Get signal filtering information
+    transform_params = parameters.instantiate(ERPTransformParams)
+    downsample_rate = transform_params.down_sampling_rate
+    static_offset = parameters.get("static_trigger_offset")
+
+    log.info(
+        f"\nData processing settings: \n"
+        f"{str(transform_params)} \n"
+        f"Trial Window: {trial_window[0]}-{trial_window[1]}s, "
+        f"Prestimulus Buffer: {prestim_length}s, Poststimulus Buffer: {buffer}s \n"
+        f"Static offset: {static_offset}"
+    )
+    channels = erp_data.channels
+    type_amp = erp_data.daq_type
+    sample_rate = erp_data.sample_rate
+
+    # setup filtering
+    default_transform = get_default_transform(
+        sample_rate_hz=sample_rate,
+        notch_freq_hz=transform_params.notch_filter_frequency,
+        bandpass_low=transform_params.filter_low,
+        bandpass_high=transform_params.filter_high,
+        bandpass_order=transform_params.filter_order,
+        downsample_factor=transform_params.down_sampling_rate,
+    )
+
+    log.info(f"Channels read from csv: {channels}")
+    log.info(f"Device type: {type_amp}, fs={sample_rate}")
+
+    k_folds = parameters.get("k_folds")
+    model = PcaRdaKdeModel(k_folds=k_folds)
+
+    # Process triggers.txt files
+    trigger_targetness, trigger_timing, _ = trigger_decoder(
+        trigger_path=f"{data_folder}/{TRIGGER_FILENAME}",
+        exclusion=[TriggerType.PREVIEW, TriggerType.EVENT, TriggerType.FIXATION],
+        offset=static_offset,
+        device_type='EEG'
+    )
+
+    # update the trigger timing list to account for the initial trial window
+    corrected_trigger_timing = [timing + trial_window[0] for timing in trigger_timing]
+
+    # Channel map can be checked from raw_data.csv file or the devices.json located in the acquisition module
+    # The timestamp column [0] is already excluded.
+    channel_map = analysis_channels(channels, device_spec)
+    channels_used = [channels[i] for i, keep in enumerate(channel_map) if keep == 1]
+    log.info(f'Channels used in analysis: {channels_used}')
+
+    data, fs = erp_data.by_channel()
+
+    inquiries, inquiry_labels, inquiry_timing = model.reshaper(
+        trial_targetness_label=trigger_targetness,
+        timing_info=corrected_trigger_timing,
+        eeg_data=data,
+        sample_rate=sample_rate,
+        trials_per_inquiry=trials_per_inquiry,
+        channel_map=channel_map,
+        poststimulus_length=window_length,
+        prestimulus_length=prestim_length,
+        transformation_buffer=buffer,
+    )
+
+    inquiries, fs = filter_inquiries(inquiries, default_transform, sample_rate)
+    inquiry_timing = update_inquiry_timing(inquiry_timing, downsample_rate)
+    trial_duration_samples = int(window_length * fs)
+    data = model.reshaper.extract_trials(inquiries, trial_duration_samples, inquiry_timing)
+
+    # define the training classes using integers, where 0=nontargets/1=targets
+    labels = inquiry_labels.flatten()
+
+    # train and save the model as a pkl file
+    log.info("Training model. This will take some time...")
+    model = PcaRdaKdeModel(k_folds=k_folds)
+    model.fit(data, labels)
+    model.metadata = SignalModelMetadata(device_spec=device_spec,
+                                         transform=default_transform)
+    log.info(f"Training complete [AUC={model.auc:0.4f}]. Saving data...")
+
+    save_model(model, Path(data_folder, f"model_{model.auc:0.4f}.pkl"))
+    preferences.signal_model_directory = data_folder
+
+    # Using an 80/20 split, report on balanced accuracy
+    if estimate_balanced_acc:
+        train_data, test_data, train_labels, test_labels = subset_data(data, labels, test_size=0.2)
+        dummy_model = PcaRdaKdeModel(k_folds=k_folds)
+        dummy_model.fit(train_data, train_labels)
+        probs = dummy_model.predict_proba(test_data)
+        preds = probs.argmax(-1)
+        score = balanced_accuracy_score(test_labels, preds)
+        log.info(f"Balanced acc with 80/20 split: {score}")
+        del dummy_model, train_data, test_data, train_labels, test_labels, probs, preds
+
+    # this should have uncorrected trigger timing for display purposes
+    figure_handles = visualize_erp(
+        erp_data,
+        channel_map,
+        trigger_timing,
+        labels,
+        trial_window,
+        transform=default_transform,
+        plot_average=True,
+        plot_topomaps=True,
+        save_path=data_folder if save_figures else None,
+        show=show_figures
+    )
+
+
+def analyze_gaze(gaze_data, device_spec, data_folder, save_figures=False, show_figures=False):
+    """Analyze gaze data and save the trained gaze model.
+
+    Extract relevant information from the gaze data object.
+    Extract timing information from the trigger file.
+    Apply preprocessing on the raw data. Extract the data for each target label and each eye separately.
+    Extract the inquiries dictionary, with target symbols as keys and inquiry windows as values.
+    Fit the model to the data.
+    Pickle dump the model into a .pkl file.
+    Generate and (optionally) save/show the gaze figures.
+    """
+    figure_handles = visualize_gaze(
+        gaze_data,
+        save_path=data_folder if save_figures else None,
+        show=show_figures,
+        raw_plot=True,
+    )
+
+    channels = gaze_data.channels
+    type_amp = gaze_data.daq_type
+    sample_rate = gaze_data.sample_rate
+
+    log.info(f"Channels read from csv: {channels}")
+    log.info(f"Device type: {type_amp}, fs={sample_rate}")
+    channel_map = analysis_channels(channels, device_spec)
+
+    channels_used = [channels[i] for i, keep in enumerate(channel_map) if keep == 1]
+    log.info(f'Channels used in analysis: {channels_used}')
+
+    data, fs = gaze_data.by_channel()
+
+    model = GazeModel()
+
+    # Extract all Triggers info
+    trigger_targetness, trigger_timing, trigger_symbols = trigger_decoder(
+        trigger_path=f"{data_folder}/{TRIGGER_FILENAME}",
+        remove_pre_fixation=False,
+        exclusion=[
+            TriggerType.PREVIEW,
+            TriggerType.EVENT,
+            TriggerType.FIXATION,
+            TriggerType.SYSTEM,
+            TriggerType.OFFSET],
+        device_type='EYETRACKER',
+        apply_starting_offset=False
+    )
+    # trigger_timing includes PROMPT and excludes FIXATION
+
+    target_symbols = trigger_symbols[0::11]  # target symbols are the PROMPT triggers
+    # Use trigger_timing to generate time windows for each letter flashing.
+    # Each inquiry contributes 11 triggers (1 PROMPT + 10 stimuli), so take every
+    # 11th trigger as the start point of timing.
+    inq_start = trigger_timing[1::11]  # start of each inquiry (here we jump over prompts)
+
+    # Extract the inquiries dictionary with keys as target symbols and values as inquiry windows:
+    inquiries = model.reshaper(
+        inq_start_times=inq_start,
+        target_symbols=target_symbols,
+        gaze_data=data,
+        sample_rate=sample_rate
+    )
+
+    symbol_set = alphabet()
+
+    # Extract the data for each target label and each eye separately.
+    # Apply preprocessing:
+    preprocessed_data = {i: [] for i in symbol_set}
+    for i in symbol_set:
+        # Skip if there's no evidence for this symbol:
+        if len(inquiries[i]) == 0:
+            continue
+
+        left_eye, right_eye = extract_eye_info(inquiries[i])
+        preprocessed_data[i] = np.array([left_eye, right_eye])  # Channels x Sample Size x Dimensions(x,y)
+
+        # Train test split:
+        test_size = int(len(right_eye) * 0.2)
+        train_size = len(right_eye) - test_size
+        train_right_eye = right_eye[:train_size]
+        test_right_eye = right_eye[train_size:]
+
+        train_left_eye = left_eye[:train_size]
+        test_left_eye = left_eye[train_size:]
+
+        # Fit the model:
+        model.fit(train_right_eye)
+
+        scores, means, covs = model.get_scores(test_right_eye)
+
+        # Visualize the results:
+        figure_handles = visualize_gaze_inquiries(
+            left_eye, right_eye,
+            means, covs,
+            save_path=None,
+            show=show_figures,
+            raw_plot=True,
+        )
+
+    model.metadata = SignalModelMetadata(device_spec=device_spec,
+                                         transform=None)
+    log.info("Training complete for Eyetracker model. Saving data...")
+    save_model(
+        model,
+        Path(data_folder, f"model_{device_spec.content_type}.pkl"))
+
+
 @report_execution_time
 def offline_analysis(
     data_folder: str = None,
@@ -109,33 +345,6 @@ def offline_analysis(
     if not data_folder:
         data_folder = load_experimental_data()

-    # extract relevant session information from parameters file
-    trial_window = parameters.get("trial_window")
-    window_length = trial_window[1] - trial_window[0]
-
-    prestim_length = parameters.get("prestim_length")
-    trials_per_inquiry = parameters.get("stim_length")
-    # The task buffer length defines the min time between two inquiries
-    # We use half of that time here to buffer during transforms
-    buffer = int(parameters.get("task_buffer_length") / 2)
-
-    # get signal filtering information
-    transform_params = parameters.instantiate(ERPTransformParams)
-    downsample_rate = transform_params.down_sampling_rate
-    static_offset = parameters.get("static_trigger_offset")
-    if trial_window is None:
-        trial_window = [0.0, 0.5]
-    # NOTE: Had to add it manually for offline analysis of multimodal datasets,
-    # since the trial_window information was not available in the parameters file.
-
-    log.info(
-        f"\nData processing settings: \n"
-        f"{str(transform_params)} \n"
-        f"Trial Window: {trial_window[0]}-{trial_window[1]}s, "
-        f"Prestimulus Buffer: {prestim_length}s, Poststimulus Buffer: {buffer}s \n"
-        f"Static offset: {static_offset}"
-    )
-
     devices_by_name = devices.load(
         Path(data_folder, DEFAULT_DEVICE_SPEC_FILENAME), replace=True)
     data_file_paths = [
@@ -149,199 +358,16 @@ def offline_analysis(
         device_spec = devices_by_name.get(raw_data.daq_type)
         # extract relevant information from raw data object eeg
         if device_spec.content_type == "EEG":
-            channels = raw_data.channels
-            type_amp = raw_data.daq_type
-            sample_rate = raw_data.sample_rate
-
-            # setup filtering
-            default_transform = get_default_transform(
-                sample_rate_hz=sample_rate,
-                notch_freq_hz=transform_params.notch_filter_frequency,
-                bandpass_low=transform_params.filter_low,
-                bandpass_high=transform_params.filter_high,
-                bandpass_order=transform_params.filter_order,
-                downsample_factor=transform_params.down_sampling_rate,
-            )
-
-            log.info(f"Channels read from csv: {channels}")
-            log.info(f"Device type: {type_amp}, fs={sample_rate}")
-
-            k_folds = parameters.get("k_folds")
-            model = PcaRdaKdeModel(k_folds=k_folds)
-
-            # Process triggers.txt files
-            trigger_targetness, trigger_timing, _ = trigger_decoder(
-                trigger_path=f"{data_folder}/{TRIGGER_FILENAME}",
-                exclusion=[TriggerType.PREVIEW, TriggerType.EVENT, TriggerType.FIXATION],
-                offset=static_offset,
-                device_type='EEG'
-            )
-
-            # update the trigger timing list to account for the initial trial window
-            corrected_trigger_timing = [timing + trial_window[0] for timing in trigger_timing]
-
-            # Channel map can be checked from raw_data.csv file or the devices.json located in the acquisition module
-            # The timestamp column [0] is already excluded.
-            channel_map = analysis_channels(channels, device_spec)
-            channels_used = [channels[i] for i, keep in enumerate(channel_map) if keep == 1]
-            log.info(f'Channels used in analysis: {channels_used}')
-
-            data, fs = raw_data.by_channel()
-
-            inquiries, inquiry_labels, inquiry_timing = model.reshaper(
-                trial_targetness_label=trigger_targetness,
-                timing_info=corrected_trigger_timing,
-                eeg_data=data,
-                sample_rate=sample_rate,
-                trials_per_inquiry=trials_per_inquiry,
-                channel_map=channel_map,
-                poststimulus_length=window_length,
-                prestimulus_length=prestim_length,
-                transformation_buffer=buffer,
-            )
-
-            inquiries, fs = filter_inquiries(inquiries, default_transform, sample_rate)
-            inquiry_timing = update_inquiry_timing(inquiry_timing, downsample_rate)
-            trial_duration_samples = int(window_length * fs)
-            data = model.reshaper.extract_trials(inquiries, trial_duration_samples, inquiry_timing)
-
-            # define the training classes using integers, where 0=nontargets/1=targets
-            labels = inquiry_labels.flatten()
-
-            # train and save the model as a pkl file
-            log.info("Training model. This will take some time...")
-            model = PcaRdaKdeModel(k_folds=k_folds)
-            model.fit(data, labels)
-            model.metadata = SignalModelMetadata(device_spec=device_spec,
-                                                 transform=default_transform)
-            log.info(f"Training complete [AUC={model.auc:0.4f}]. Saving data...")
-
-            save_model(model, Path(data_folder, f"model_{model.auc:0.4f}.pkl"))
-            preferences.signal_model_directory = data_folder
-
-            # Using an 80/20 split, report on balanced accuracy
-            if estimate_balanced_acc:
-                train_data, test_data, train_labels, test_labels = subset_data(data, labels, test_size=0.2)
-                dummy_model = PcaRdaKdeModel(k_folds=k_folds)
-                dummy_model.fit(train_data, train_labels)
-                probs = dummy_model.predict_proba(test_data)
-                preds = probs.argmax(-1)
-                score = balanced_accuracy_score(test_labels, preds)
-                log.info(f"Balanced acc with 80/20 split: {score}")
-                del dummy_model, train_data, test_data, train_labels, test_labels, probs, preds
-
-            # this should have uncorrected trigger timing for display purposes
-            figure_handles = visualize_erp(
-                raw_data,
-                channel_map,
-                trigger_timing,
-                labels,
-                trial_window,
-                transform=default_transform,
-                plot_average=True,
-                plot_topomaps=True,
-                save_path=data_folder if save_figures else None,
-                show=show_figures
-            )
+            analyze_erp(raw_data, parameters, device_spec, data_folder, estimate_balanced_acc,
+                        save_figures, show_figures)

         if device_spec.content_type == "Eyetracker":
-            figure_handles = visualize_gaze(
-                raw_data,
-                save_path=data_folder if save_figures else None,
-                show=show_figures,
-                raw_plot=True,
-            )
-
-            channels = raw_data.channels
-            type_amp = raw_data.daq_type
-            sample_rate = raw_data.sample_rate
-
-            log.info(f"Channels read from csv: {channels}")
-            log.info(f"Device type: {type_amp}, fs={sample_rate}")
-            channel_map = analysis_channels(channels, device_spec)
-
-            channels_used = [channels[i] for i, keep in enumerate(channel_map) if keep == 1]
-            log.info(f'Channels used in analysis: {channels_used}')
-
-            data, fs = raw_data.by_channel()
-
-            model = GazeModel()
-
-            # Extract all Triggers info
-            trigger_targetness, trigger_timing, trigger_symbols = trigger_decoder(
-                trigger_path=f"{data_folder}/{TRIGGER_FILENAME}",
-                remove_pre_fixation=False,
-                exclusion=[
-                    TriggerType.PREVIEW,
-                    TriggerType.EVENT,
-                    TriggerType.FIXATION,
-                    TriggerType.SYSTEM,
-                    TriggerType.OFFSET],
-                device_type='EYETRACKER',
-                apply_starting_offset=False
-            )
-            ''' Trigger_timing includes PROMPT and excludes FIXATION '''
-
-            # Use trigger_timing to generate time windows for each letter flashing
-            # Take every 10th trigger as the start point of timing.
-            # trigger_symbols keeps the PROMPT info, use it to find the target symbol.
-            target_symbols = trigger_symbols[0::11]  # target symbols
-            inq_start = trigger_timing[1::11]  # start of each inquiry (here we jump over prompts)
-
-            # Extract the inquiries dictionary with keys as target symbols and values as inquiry windows:
-            inquiries = model.reshaper(
-                inq_start_times=inq_start,
-                target_symbols=target_symbols,
-                gaze_data=data,
-                sample_rate=sample_rate
-            )
-
-            symbol_set = alphabet()
-
-            # Extract the data for each target label and each eye separately.
-            # Apply preprocessing:
-            preprocessed_data = {i: [] for i in symbol_set}
-            for i in symbol_set:
-                # Skip if there's no evidence for this symbol:
-                if len(inquiries[i]) == 0:
-                    continue
-
-                left_eye, right_eye = extract_eye_info(inquiries[i])
-                preprocessed_data[i] = np.array([left_eye, right_eye])  # Channels x Sample Size x Dimensions(x,y)
-
-                # Train test split:
-                test_size = int(len(right_eye) * 0.2)
-                train_size = len(right_eye) - test_size
-                train_right_eye = right_eye[:train_size]
-                test_right_eye = right_eye[train_size:]
-
-                train_left_eye = left_eye[:train_size]
-                test_left_eye = left_eye[train_size:]
-
-                # Fit the model:
-                model.fit(train_right_eye)
-
-                scores, means, covs = model.get_scores(test_right_eye)
-
-                # Visualize the results:
-                figure_handles = visualize_gaze_inquiries(
-                    left_eye, right_eye,
-                    means, covs,
-                    save_path=None,
-                    show=show_figures,
-                    raw_plot=True,
-                )
-
-            model.metadata = SignalModelMetadata(device_spec=device_spec,
-                                                 transform=None)
-            log.info("Training complete for Eyetracker model. Saving data...")
-            save_model(
-                model,
-                Path(data_folder, f"model_{device_spec.content_type}.pkl"))
+            analyze_gaze(raw_data, device_spec, data_folder, save_figures, show_figures)

     if alert_finished:
         play_sound(f"{STATIC_AUDIO_PATH}/{parameters['alert_sound_file']}")

-    return model, figure_handles
+    return


 if __name__ == "__main__":
@@ -357,7 +383,7 @@ def offline_analysis(
     parser.set_defaults(alert=False)
     parser.set_defaults(balanced=False)
     parser.set_defaults(save_figures=False)
-    parser.set_defaults(show_figures=False)
+    parser.set_defaults(show_figures=True)
     args = parser.parse_args()

     log.info(f"Loading params from {args.parameters_file}")
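
For reviewers exercising the refactor locally: after this patch, offline_analysis() only loads the session and dispatches on each device's content_type, calling analyze_erp() for EEG and analyze_gaze() for eyetrackers. A minimal sketch of driving the entry point from Python, assuming an installed BciPy; the paths are hypothetical and the keyword names are taken from this diff rather than from a published API:

    # Hedged sketch: run the refactored offline analysis on one session folder.
    from bcipy.helpers.load import load_json_parameters
    from bcipy.signal.model.offline_analysis import offline_analysis

    # Load the session's parameters file (hypothetical path).
    parameters = load_json_parameters(
        "/path/to/session/parameters.json", value_cast=True)

    # Dispatches to analyze_erp() / analyze_gaze() per device in devices.json.
    offline_analysis(
        data_folder="/path/to/session",  # hypothetical session folder
        parameters=parameters,
        alert_finished=False,
        estimate_balanced_acc=True,      # adds the 80/20 balanced-accuracy report
        save_figures=True,
        show_figures=False,
    )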
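The stride-11 slicing in analyze_gaze() relies on each inquiry writing exactly 11 eyetracker triggers: one PROMPT followed by ten stimulus flashes (hence the comment fix from "every 10th" to "every 11th" above). A self-contained toy illustration of how the two slices recover the prompted targets and the inquiry start times; the symbol names and timestamps are made up:

    # Toy illustration of the stride-11 trigger slicing used in analyze_gaze().
    # Assumed trigger layout per inquiry: 1 PROMPT + 10 stimulus flashes.
    trigger_symbols = (["A"] + [f"stim_{k}" for k in range(10)]
                       + ["B"] + [f"stim_{k}" for k in range(10)])
    trigger_timing = [float(t) for t in range(len(trigger_symbols))]  # dummy timestamps

    target_symbols = trigger_symbols[0::11]  # PROMPT triggers -> ['A', 'B']
    inq_start = trigger_timing[1::11]        # first stimulus of each inquiry -> [1.0, 12.0]

    assert target_symbols == ["A", "B"]
    assert inq_start == [1.0, 12.0]

If a dataset uses a different stim_length, the hard-coded 11 would need to change with it.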
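The retyped visualize_gaze() signature keeps every previous default, so existing call sites should be unaffected. A usage sketch with the new keywords spelled out; the data file name is hypothetical, and load_raw_data is assumed to be the usual BciPy raw-data loader:

    # Hedged sketch: render the gaze figure for one raw data file.
    from bcipy.helpers.load import load_raw_data
    from bcipy.helpers.visualization import visualize_gaze

    gaze_data = load_raw_data("/path/to/session/eyetracker_data.csv")  # hypothetical file
    fig = visualize_gaze(
        gaze_data,
        left_keys=['left_x', 'left_y'],
        right_keys=['right_x', 'right_y'],
        screen_size=(1920, 1080),
        raw_plot=True,    # scatter the raw gaze samples
        heatmap=False,    # skip the density overlay
        save_path=None,   # set to a directory to write the figure to disk
        show=True,
    )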