-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #315 from CAMBI-tech/s_simple_sampler_simulator
RSVP Copy Phrase Simulator
- Loading branch information
Showing
20 changed files
with
1,653 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
## RSVP Simulator | ||
|
||
### Overview | ||
|
||
This Simulator module aims to automate experimentation by sampling EEG data from prior sessions and running given models in a task loop, thus simulating a live session. | ||
|
||
### Run steps | ||
|
||
`main.py` is the entry point for program. After following `BciPy` readme steps for setup, run the module from terminal: | ||
|
||
``` | ||
(venv) $ python bcipy/simulator -h | ||
usage: simulator [-h] -d DATA_FOLDER [-g GLOB_PATTERN] -m MODEL_PATH -p PARAMETERS [-n N] | ||
optional arguments: | ||
-h, --help show this help message and exit | ||
-d DATA_FOLDER, --data_folder DATA_FOLDER | ||
Raw data folders to be processed. | ||
-g GLOB_PATTERN, --glob_pattern GLOB_PATTERN | ||
glob pattern to select a subset of data folders Ex. "*RSVP_Copy_Phrase*" | ||
-m MODEL_PATH, --model_path MODEL_PATH | ||
Signal models to be used | ||
-p PARAMETERS, --parameters PARAMETERS | ||
Parameter File to be used | ||
-n N Number of times to run the simulation | ||
``` | ||
|
||
For example, | ||
`$ python bcipy/simulator -d my_data_folder/ -p my_parameters.json -m my_models/ -n 5` | ||
|
||
#### Program Args | ||
|
||
- `d` : the data wrapper folder argument is necessary. This folder is expected to contain 1 or more session folders. Each session folder should contain | ||
_raw_data.csv_, _triggers.txt_, _parameters.json_. These files will be used to construct a data pool from which simulator will sample EEG and other device responses. The parameters file in each data folder will be used to check compatibility with the simulation/model parameters. | ||
- `g` : optional glob filter that can be used to select a subset of data within the wrapper directory. | ||
- Ex. `"*Matrix_Copy*Jan_2024*"` will select all data for all Matrix Copy Phrase sessions recorded in January of 2024 (assuming the BciPy folder naming convention). | ||
- Glob patterns can also include nested directories (ex. `"*/*Matrix_Copy*"`). | ||
- `p` : path to the parameters.json file used to run the simulation. These parameters will be applied to | ||
all raw_data files when loading. This file can specify various aspects of the simulation, including the language model to be used, the text to be spelled, etc. Timing-related parameters should generally match the parameters file used for training the signal model(s). | ||
- `m`: all pickle (.pkl) files in this directory will be loaded as signal models. | ||
|
||
#### Sim Output Details | ||
|
||
Output folders are generally located in the `simulator/generated` directory. Each simulation will create a new directory. The directory name will be prefixed with `SIM` and will include the current date and time. | ||
|
||
- `parameters.json` captures params used for the simulation. | ||
- `sim.log` is a log file for the simulation | ||
|
||
A directory is created for each simulation run. The directory contents are similar to the session output in a normal bcipy task. Each run directory contains: | ||
|
||
- `run_{n}.log` log file specific to the run, where n is the run number. | ||
- `session.json` session data output for the task, including evidence generated for each inquiry and overall metrics. | ||
- `session.xlsx` session data summarized in an excel spreadsheet with charts for easier visualization. | ||
|
||
## Main Components | ||
|
||
* Task - a simulation task to be run (ex. RSVP Copy Phrase) | ||
* TaskRunner - runs one or more iterations of a simulation | ||
* TaskFactory - constructs the hierarchy of objects needed for the simulation. | ||
* DataEngine - loads data to be used in a simulation and provides an API to query for data. | ||
* DataProcessor - used by the DataEngine to pre-process data. Pre-processed data can be classified by a signal model. | ||
* Sampler - strategy for sampling data from the data pool stored in the DataEngine. | ||
|
||
## Device Support | ||
|
||
The simulator is structured to support evidence from multiple devices (multimodal). However, it currently only includes processing for EEG device data. To provide support for models trained on data from other devices (ex. Gaze), a `RawDataProcessor` must be added for that device. The Processor pre-processes data collected from that device and prepares it for sampling. A `RawDataProcessor` is matched up to a given signal model using that model's metadata (metadata.device_spec.content_type). See the `data_process` module for more details. | ||
|
||
## Current Limitations | ||
|
||
* Only provides EEG support | ||
* Only one sampler maybe provided for all devices. Ideally we should support a different sampling strategy for each device. | ||
* Only Copy Phrase is currently supported. | ||
* Metrics are collected per run, but not summarized across all runs. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
""" Simulator package. View README """ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
""" Entry point to run Simulator """ | ||
|
||
from bcipy.simulator.task import task_runner | ||
|
||
if __name__ == '__main__': | ||
task_runner.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
"""Classes and functions related to loading and querying data to be used in a simulation.""" | ||
import logging | ||
from abc import ABC, abstractmethod | ||
from pathlib import Path | ||
from typing import Any, List, NamedTuple, Union | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from bcipy.helpers.exceptions import TaskConfigurationException | ||
from bcipy.helpers.parameters import Parameters | ||
from bcipy.simulator.data import data_process | ||
from bcipy.simulator.data.data_process import (ExtractedExperimentData, | ||
RawDataProcessor) | ||
from bcipy.simulator.util.artifact import TOP_LEVEL_LOGGER_NAME | ||
|
||
log = logging.getLogger(TOP_LEVEL_LOGGER_NAME) | ||
|
||
|
||
class Trial(NamedTuple): | ||
"""Data for a given trial (a symbol within an Inquiry). | ||
Attrs | ||
----- | ||
source - directory of the data source | ||
inquiry_n - starts at 0; does not reset at each series | ||
inquiry_pos - starts at 1; position in which the symbol was presented | ||
symbol - alphabet symbol that was presented | ||
target - 1 or 0 indicating a boolean of whether this was a target symbol | ||
eeg - EEG data associated with this trial | ||
""" | ||
source: str | ||
inquiry_n: int | ||
inquiry_pos: int | ||
symbol: str | ||
target: int | ||
eeg: np.ndarray # Channels by Samples ; ndarray.shape = (channel_n, sample_n) | ||
|
||
def __str__(self): | ||
fields = [ | ||
f"source='{self.source}'", f"inquiry_n={self.inquiry_n}", | ||
f"inquiry_pos={self.inquiry_pos}", f"symbol='{self.symbol}'", | ||
f"target={self.target}", f"eeg={self.eeg.shape}" | ||
] | ||
return f"Trial({', '.join(fields)})" | ||
|
||
def __repr__(self): | ||
return str(self) | ||
|
||
|
||
class QueryFilter(NamedTuple): | ||
"""Provides an API used to query a data engine for data.""" | ||
field: str | ||
operator: str | ||
value: Any | ||
|
||
def is_valid(self) -> bool: | ||
"""Check if the filter is valid.""" | ||
# pylint: disable=no-member | ||
return self.field in Trial._fields and self.operator in self.valid_operators and isinstance( | ||
self.value, Trial.__annotations__[self.field]) | ||
|
||
@property | ||
def valid_operators(self) -> List[str]: | ||
"""List of supported query operators""" | ||
return ["<", "<=", ">", ">=", "==", "!="] | ||
|
||
|
||
class DataEngine(ABC): | ||
"""Abstract class for an object that loads data from one or more sources, | ||
processes the data using a provided processor, and provides an interface | ||
for querying the processed data.""" | ||
|
||
def load(self): | ||
"""Load data from sources.""" | ||
|
||
@property | ||
def trials_df(self) -> pd.DataFrame: | ||
"""Returns a dataframe of Trial data.""" | ||
|
||
@abstractmethod | ||
def query(self, | ||
filters: List[QueryFilter], | ||
samples: int = 1) -> List[Trial]: | ||
"""Query the data.""" | ||
|
||
|
||
def convert_trials(data_source: ExtractedExperimentData) -> List[Trial]: | ||
"""Convert extracted data from a single data source to a list of Trials.""" | ||
trials = [] | ||
symbols_by_inquiry = data_source.symbols_by_inquiry | ||
labels_by_inquiry = data_source.labels_by_inquiry | ||
|
||
for i, inquiry_eeg in enumerate(data_source.trials_by_inquiry): | ||
# iterate through each inquiry | ||
inquiry_symbols = symbols_by_inquiry[i] | ||
inquiry_labels = labels_by_inquiry[i] | ||
|
||
for sym_i, symbol in enumerate(inquiry_symbols): | ||
# iterate through each symbol in the inquiry | ||
eeg_samples = [channel[sym_i] | ||
for channel in inquiry_eeg] # (channel_n, sample_n) | ||
trials.append( | ||
Trial(source=data_source.source_dir, | ||
inquiry_n=i, | ||
inquiry_pos=sym_i + 1, | ||
symbol=symbol, | ||
target=inquiry_labels[sym_i], | ||
eeg=np.array(eeg_samples))) | ||
return trials | ||
|
||
|
||
class RawDataEngine(DataEngine): | ||
""" | ||
Object that loads in list of session data folders and transforms data into | ||
a queryable data structure. | ||
""" | ||
|
||
def __init__(self, source_dirs: List[str], parameters: Parameters, | ||
data_processor: RawDataProcessor): | ||
self.source_dirs: List[str] = source_dirs | ||
self.parameters: Parameters = parameters | ||
|
||
self.data_processor = data_processor | ||
self.data: List[Union[ExtractedExperimentData, | ||
data_process.ExtractedExperimentData]] = [] | ||
self._trials_df = pd.DataFrame() | ||
|
||
self.load() | ||
|
||
def load(self) -> DataEngine: | ||
""" | ||
Processes raw data from data folders using provided parameter files. | ||
- Extracts and stores trial data, stimuli, and stimuli_labels by inquiries | ||
Returns: | ||
self for chaining | ||
""" | ||
if not self.data: | ||
log.debug( | ||
f"Loading data from {len(self.source_dirs)} source directories:" | ||
) | ||
rows = [] | ||
for i, source_dir in enumerate(self.source_dirs): | ||
log.debug(f"{i+1}. {Path(source_dir).name}") | ||
extracted_data = self.data_processor.process( | ||
source_dir, self.parameters) | ||
self.data.append(extracted_data) | ||
rows.extend(convert_trials(extracted_data)) | ||
|
||
self._trials_df = pd.DataFrame(rows) | ||
log.debug("Finished loading all data") | ||
return self | ||
|
||
@property | ||
def trials_df(self) -> pd.DataFrame: | ||
"""Dataframe of Trial data.""" | ||
if not self.data_loaded: | ||
self.load() | ||
return self._trials_df.copy() | ||
|
||
@property | ||
def data_loaded(self) -> bool: | ||
"""Check if the data has been loaded""" | ||
return bool(self.data) | ||
|
||
def query(self, | ||
filters: List[QueryFilter], | ||
samples: int = 1) -> List[Trial]: | ||
"""Query the engine for data using one or more filters. | ||
Parameters | ||
---------- | ||
filters - list of query filters | ||
samples - number of results to return. | ||
check_insufficient_results - if True, raises an exception when | ||
there are an insufficient number of samples. | ||
Returns a list of Trials. | ||
""" | ||
assert self.data_loaded, "Data must be loaded before querying." | ||
assert all(filt.is_valid() | ||
for filt in filters), "Filters must all be valid" | ||
assert samples >= 1, "Insufficient number of samples requested" | ||
|
||
expr = 'and '.join([self.query_condition(filt) for filt in filters]) | ||
filtered_data = self._trials_df.query(expr) | ||
if filtered_data is None or len(filtered_data) < samples: | ||
raise TaskConfigurationException( | ||
message="Not enough samples found") | ||
|
||
rows = filtered_data.sample(samples) | ||
return [Trial(*row) for row in rows.itertuples(index=False, name=None)] | ||
|
||
def query_condition(self, query_filter: QueryFilter) -> str: | ||
"""Returns the string representation of of the given query condition.""" | ||
value = query_filter.value | ||
if (isinstance(value, str)): | ||
value = f"'{value}'" | ||
return f"{query_filter.field} {query_filter.operator} {value}" |
Oops, something went wrong.