Commit

WIP testing
tab-cmd committed Oct 17, 2024
1 parent 2fedd80 commit 1106618
Showing 10 changed files with 96 additions and 71 deletions.
4 changes: 2 additions & 2 deletions bcipy/gui/BCInterface.py
@@ -32,7 +32,7 @@ class BCInterface(BCIGui):
max_length = 25
min_length = 1
timeout = 3
- font = 'Consolas'
+ font = 'Courier New'

def __init__(self, *args, **kwargs):
super(BCInterface, self).__init__(*args, **kwargs)
@@ -440,7 +440,7 @@ def offline_analysis(self) -> None:
Run offline analysis as a script in a new process.
"""
if not self.action_disabled():
- cmd = f'python {BCIPY_ROOT}/signal/model/offline_analysis.py --alert --p "{self.parameter_location}"'
+ cmd = f'bcipy-train --alert --p "{self.parameter_location}" -v'
subprocess.Popen(cmd, shell=True)

def action_disabled(self) -> bool:
3 changes: 1 addition & 2 deletions bcipy/gui/alert.py
@@ -22,7 +22,6 @@ def confirm(message: str) -> bool:
message_type=AlertMessageType.INFO,
message_response=AlertMessageResponse.OCE)
button = dialog.exec()

result = bool(button == AlertResponse.OK.value)
- app.quit()
+ QApplication.instance().quit()
return result
4 changes: 3 additions & 1 deletion bcipy/gui/bciui.py
@@ -231,7 +231,9 @@ def list_property(self, prop: str):

def run_bciui(ui: Type[BCIUI], *args, **kwargs):
# add app to kwargs
- app = QApplication(sys.argv)
+ app = QApplication(sys.argv).instance()
+ if not app:
+     app = QApplication(sys.argv)
ui_instance = ui(*args, **kwargs)
ui_instance.display()
app.exec()
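
Note: the change above keeps run_bciui from constructing a second QApplication when one already exists in the process. A minimal sketch of the same singleton-guard idiom, assuming PyQt6; the label widget is illustrative and not part of BciPy:

import sys

from PyQt6.QtWidgets import QApplication, QLabel


def get_app() -> QApplication:
    """Return the process-wide QApplication, creating it only if none exists."""
    app = QApplication.instance()  # returns None when no QApplication has been created yet
    if app is None:
        app = QApplication(sys.argv)
    return app


if __name__ == '__main__':
    app = get_app()
    label = QLabel('single QApplication instance')
    label.show()
    app.exec()

QApplication.instance() is a static method, so it can be called on the class without constructing an application first.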
7 changes: 3 additions & 4 deletions bcipy/helpers/tests/resources/mock_session/parameters.json
@@ -392,14 +392,13 @@
"type": "float"
},
"font": {
"value": "Overpass Mono Medium",
"value": "Courier New",
"section": "bci_config",
"name": "Font",
"helpTip": "Specifies the font used for all text stimuli. Default: Consolas",
"helpTip": "Specifies the font used for all text stimuli. Default: Courier New",
"recommended": [
"Courier New",
"Lucida Sans",
"Consolas"
"Lucida Sans"
],
"editable": true,
"type": "str"
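
As a quick check of the updated mock parameters, the value can be read back with BciPy's parameter loader. A minimal sketch, assuming the repository-relative path shown in the diff header and the value_cast behavior used elsewhere in this commit:

from bcipy.helpers.load import load_json_parameters

# Repository-relative path of the mock parameters file changed above.
params_path = 'bcipy/helpers/tests/resources/mock_session/parameters.json'
params = load_json_parameters(params_path, value_cast=True)

# With value_cast=True the entry resolves to its cast value, here the new
# default font name rather than the full config dictionary.
print(params['font'])  # expected: Courier New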
42 changes: 36 additions & 6 deletions bcipy/helpers/visualization.py
@@ -4,8 +4,6 @@
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

-import matplotlib
-matplotlib.use("QtAgg")
import matplotlib.pyplot as plt
import mne
import numpy as np
@@ -21,10 +19,11 @@
import bcipy.acquisition.devices as devices
from bcipy.config import (DEFAULT_DEVICE_SPEC_FILENAME,
DEFAULT_GAZE_IMAGE_PATH, RAW_DATA_FILENAME,
- TRIGGER_FILENAME, SESSION_LOG_FILENAME)
+ TRIGGER_FILENAME, SESSION_LOG_FILENAME,
+ DEFAULT_PARAMETERS_PATH)
from bcipy.helpers.acquisition import analysis_channels
from bcipy.helpers.convert import convert_to_mne
-from bcipy.helpers.load import choose_csv_file, load_raw_data
+from bcipy.helpers.load import choose_csv_file, load_raw_data, load_json_parameters
from bcipy.helpers.parameters import Parameters
from bcipy.helpers.raw_data import RawData
from bcipy.helpers.stimuli import mne_epochs
@@ -672,7 +671,8 @@ def visualize_evokeds(epochs: Tuple[Epochs, Epochs],
def visualize_session_data(
session_path: str,
parameters: Union[dict, Parameters],
- show=True) -> Figure:
+ show=True,
+ save=True) -> Figure:
"""Visualize Session Data.
This method is used to load and visualize EEG data after a session.
@@ -735,7 +735,7 @@ def visualize_session_data(
transform=default_transform,
plot_average=True,
plot_topomaps=True,
- save_path=session_path,
+ save_path=session_path if save else None,
show=show,
)

@@ -759,3 +759,33 @@ def visualize_gaze_accuracies(accuracy_dict: Dict[str, np.ndarray],
ax.set_title('Overall Accuracy: ' + str(round(accuracy, 2)))

return fig

+def erp():
+    import argparse
+
+    parser = argparse.ArgumentParser(description='Visualize ERP data')
+
+    parser.add_argument(
+        '-s', '--session_path',
+        type=str,
+        help='Path to the session directory',
+        required=True)
+    parser.add_argument(
+        '-p', '--parameters',
+        type=str,
+        help='Path to the parameters file',
+        default=DEFAULT_PARAMETERS_PATH)
+    parser.add_argument(
+        '--show',
+        action='store_true',
+        help='Whether to show the figure',
+        default=False)
+    parser.add_argument(
+        '--save',
+        action='store_true',
+        help='Whether to save the figure', default=True)
+
+    args = parser.parse_args()
+
+    parameters = load_json_parameters(args.parameters, value_cast=True)
+    visualize_session_data(args.session_path, parameters, args.show, args.save)
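
The erp() function added above appears to be exposed as a console entry point elsewhere in this branch (the analyze_erp change below invokes it as bcipy-erp-viz). A minimal sketch of launching it from Python, assuming that entry point name; the session and parameters paths are placeholders:

import subprocess

# Placeholder paths; point these at a real session directory and parameters file.
session_path = '/path/to/session_directory'
parameters_path = '/path/to/parameters.json'

cmd = (
    f"bcipy-erp-viz --session_path '{session_path}' "
    f"--parameters '{parameters_path}' --save"
)
# check=True raises CalledProcessError if the visualization process fails.
subprocess.run(cmd, shell=True, check=True)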
64 changes: 26 additions & 38 deletions bcipy/signal/model/offline_analysis.py
@@ -1,12 +1,12 @@
# mypy: disable-error-code="attr-defined"
import json
import logging
+import subprocess
from pathlib import Path
from typing import Tuple

import numpy as np
-import matplotlib
-matplotlib.use('QtAgg')

from matplotlib.figure import Figure
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import train_test_split
@@ -212,20 +212,18 @@ def analyze_erp(erp_data, parameters, device_spec, data_folder, estimate_balance

save_model(model, Path(data_folder, f"model_{model.auc:0.4f}.pkl"))
preferences.signal_model_directory = data_folder
- # this should have uncorrected trigger timing for display purposes
- figure_handles = visualize_erp(
-     erp_data,
-     channel_map,
-     trigger_timing,
-     labels,
-     trial_window,
-     transform=default_transform,
-     plot_average=True,
-     plot_topomaps=True,
-     save_path=data_folder if save_figures else None,
-     show=show_figures
- )
- return model, figure_handles

+ if save_figures or show_figures:
+     cmd = f"bcipy-erp-viz --session_path '{data_folder}' --parameters '{parameters['parameter_location']}'"
+     if save_figures:
+         cmd += " --save"
+     if show_figures:
+         cmd += " --show"
+     subprocess.run(
+         cmd,
+         shell=True
+     )
+ return model


def analyze_gaze(
@@ -259,15 +257,13 @@ def analyze_gaze(
"Individual": Fits a separate Gaussian for each symbol. Default model
"Centralized": Uses data from all symbols to fit a single centralized Gaussian
"""
- figures = []
- figure_handles = visualize_gaze(
+ visualize_gaze(
gaze_data,
save_path=save_figures,
img_path=f'{data_folder}/{MATRIX_IMAGE_FILENAME}',
show=show_figures,
raw_plot=plot_points,
)
- figures.extend(figure_handles)

channels = gaze_data.channels
type_amp = gaze_data.daq_type
@@ -358,15 +354,14 @@ def analyze_gaze(
means, covs = model.evaluate(test_re)

# Visualize the results:
- figure_handles = visualize_gaze_inquiries(
+ visualize_gaze_inquiries(
le, re,
means, covs,
save_path=save_figures,
img_path=f'{data_folder}/{MATRIX_IMAGE_FILENAME}',
show=show_figures,
raw_plot=plot_points,
)
- figures.extend(figure_handles)
left_eye_all.append(le)
right_eye_all.append(re)
means_all.append(means)
@@ -410,22 +405,20 @@ def analyze_gaze(
print(f"Overall accuracy: {accuracy:.2f}")

# Plot all accuracies as bar plot:
- figure_handles = visualize_gaze_accuracies(acc_all_symbols, accuracy, save_path=None, show=True)
- figures.extend(figure_handles)
+ visualize_gaze_accuracies(acc_all_symbols, accuracy, save_path=None, show=True)

if model_type == "Centralized":
cent_left = np.concatenate(np.array(centralized_data_left, dtype=object))
cent_right = np.concatenate(np.array(centralized_data_right, dtype=object))

# Visualize the results:
- figure_handles = visualize_centralized_data(
+ visualize_centralized_data(
cent_left, cent_right,
save_path=save_figures,
img_path=f'{data_folder}/{MATRIX_IMAGE_FILENAME}',
show=show_figures,
raw_plot=plot_points,
)
- figures.extend(figure_handles)

# Fit the model:
model.fit(cent_left)
@@ -438,37 +431,36 @@ def analyze_gaze(
le = preprocessed_data[sym][0]
re = preprocessed_data[sym][1]
# Visualize the results:
- figure_handles = visualize_gaze_inquiries(
+ visualize_gaze_inquiries(
le, re,
means, covs,
save_path=save_figures,
img_path=f'{data_folder}/{MATRIX_IMAGE_FILENAME}',
show=show_figures,
raw_plot=plot_points,
)
- figures.extend(figure_handles)
left_eye_all.append(le)
right_eye_all.append(re)
means_all.append(means)
covs_all.append(covs)

- fig_handles = visualize_results_all_symbols(
+ # TODO: add visualizations to subprocess
+ visualize_results_all_symbols(
left_eye_all, right_eye_all,
means_all, covs_all,
img_path=f'{data_folder}/{MATRIX_IMAGE_FILENAME}',
save_path=save_figures,
show=show_figures,
raw_plot=plot_points,
)
- figures.extend(fig_handles)

model.metadata = SignalModelMetadata(device_spec=device_spec,
transform=None)
log.info("Training complete for Eyetracker model. Saving data...")
save_model(
model,
Path(data_folder, f"model_{device_spec.content_type}_{model_type}.pkl"))
- return model, figures
+ return model


@report_execution_time
@@ -479,7 +471,7 @@ def offline_analysis(
estimate_balanced_acc: bool = False,
show_figures: bool = False,
save_figures: bool = False,
- ) -> Tuple[SignalModel, Figure]:
+ ) -> Tuple[SignalModel]:
"""Gets calibration data and trains the model in an offline fashion.
pickle dumps the model into a .pkl folder
@@ -510,7 +502,6 @@ def offline_analysis(
Returns:
--------
model (SignalModel): trained model
- figure_handles (Figure): handles to the ERP figures
"""
assert parameters, "Parameters are required for offline analysis."
if not data_folder:
@@ -533,30 +524,27 @@ def offline_analysis(
assert len(data_file_paths) > 0, "No data files found for offline analysis."

models = []
- figure_handles = []
log.info(f"Starting offline analysis for {data_file_paths}")
for raw_data_path in data_file_paths:
raw_data = load_raw_data(raw_data_path)
device_spec = devices_by_name.get(raw_data.daq_type)
# extract relevant information from raw data object eeg
if device_spec.content_type == "EEG":
- erp_model, erp_figure_handles = analyze_erp(
+ erp_model = analyze_erp(
raw_data, parameters, device_spec, data_folder, estimate_balanced_acc, save_figures, show_figures)
models.append(erp_model)
- figure_handles.extend(erp_figure_handles)

if device_spec.content_type == "Eyetracker":
- et_model, et_figure_handles = analyze_gaze(
+ et_model = analyze_gaze(
raw_data, parameters, device_spec, data_folder, save_figures, show_figures, model_type="Individual")
models.append(et_model)
- figure_handles.extend(et_figure_handles)

if alert_finished:
log.info("Alerting Offline Analysis Complete")
results = [f"{model.name}: {model.auc}" for model in models]
confirm(f"Offline analysis complete! \n Results={results}")
log.info("Offline analysis complete")
- return models, figure_handles
+ return models


def main():
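
With these changes, offline_analysis returns only the trained models; ERP figure generation is delegated to the bcipy-erp-viz subprocess when figures are requested. A minimal caller-side sketch, assuming the data_folder, parameters, and save_figures keyword parameters shown above; the paths are placeholders:

from bcipy.helpers.load import load_json_parameters
from bcipy.signal.model.offline_analysis import offline_analysis

# Placeholder locations for a recorded calibration session.
data_folder = '/path/to/calibration_session'
parameters = load_json_parameters('/path/to/parameters.json', value_cast=True)

# Previously this returned (models, figure_handles); now only the models
# come back, and figures are produced by the bcipy-erp-viz subprocess.
models = offline_analysis(
    data_folder=data_folder,
    parameters=parameters,
    save_figures=True)

for model in models:
    print(f'{model.name}: AUC={model.auc}')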
6 changes: 3 additions & 3 deletions bcipy/task/actions.py
@@ -96,7 +96,7 @@ def execute(self) -> TaskData:
"""
logger.info("Running offline analysis action")
try:
cmd = f"bcipy-train --parameters {self.parameters_path} -s"
cmd = f"bcipy-train --parameters {self.parameters_path} -s -v"
if self.alert_finished:
cmd += " --alert"
response = subprocess.run(
@@ -157,7 +157,7 @@ def execute(self) -> TaskData:
)

def alert(self):
- ...
+ pass


class ExperimentFieldCollectionAction(Task):
@@ -215,7 +215,7 @@ def __init__(
self.protocol_path = protocol_path or ''
self.last_task_dir = last_task_dir
self.default_transform = None
- self.trial_window = trial_window or (0, 1.0)
+ self.trial_window = trial_window or (0, 1.0) #TODO ask about this
self.static_offset = self.parameters.get("static_offset", 0)
self.report = Report(self.protocol_path)
self.report_sections: List[ReportSection] = []
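
For reference, a standalone sketch of the training invocation the action above performs, outside the Task class; the parameters path is a placeholder, and the -s/-v/--alert flags are taken from the diff rather than from separate CLI documentation:

import subprocess

parameters_path = '/path/to/parameters.json'  # placeholder
alert_when_done = True

cmd = f'bcipy-train --parameters {parameters_path} -s -v'
if alert_when_done:
    cmd += ' --alert'

# check=True surfaces a non-zero exit code from bcipy-train as an exception.
completed = subprocess.run(cmd, shell=True, check=True, capture_output=True)
print(completed.stdout.decode())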