Skip to content

Commit

Permalink
prettier logging
Browse files Browse the repository at this point in the history
  • Loading branch information
sreekaroo committed Nov 20, 2023
1 parent 8731cf5 commit f4faea1
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 16 deletions.
18 changes: 18 additions & 0 deletions bcipy/simulator/helpers/log_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import List

import pandas as pd


def format_alp_likelihoods(likelihoods, alp):
rounded = [round(lik, 3) for lik in likelihoods]
formatted = [f"{a} : {l}" for a, l in zip(alp, rounded)]
return formatted


def format_sample_rows(sample_rows: List[pd.Series]):
formatted_rows = []
for row in sample_rows:
new_row = row.drop(columns=['eeg'], axis=1, inplace=False)
formatted_rows.append(new_row.to_string(index=False, header=True))

return ", ".join(formatted_rows)
3 changes: 2 additions & 1 deletion bcipy/simulator/helpers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from bcipy.helpers.parameters import Parameters
from bcipy.helpers.symbols import alphabet
from bcipy.simulator.helpers.data_engine import RawDataEngine
from bcipy.simulator.helpers.log_utils import format_sample_rows
from bcipy.simulator.interfaces import SimState

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -55,7 +56,7 @@ def sample(self, state: SimState) -> np.ndarray:
row = filtered_data.sample(1)
sample_rows.append(row)

log.debug(f"EEG Samples: \n {', '.join([r.to_string() for r in sample_rows])}")
log.debug(f"EEG Samples: \n {format_sample_rows(sample_rows)}")
eeg_responses = [r['eeg'].to_numpy()[0] for r in sample_rows]
sample = self.model_input_reshaper(eeg_responses)

Expand Down
2 changes: 1 addition & 1 deletion bcipy/simulator/helpers/signal_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,4 @@ def process_raw_data_for_model(data_folder, parameters, reshaper: InquiryReshape
# define the training classes using integers, where 0=nontargets/1=targets
# labels = inquiry_labels.flatten()

return ExtractedExperimentData(inquiries, trials, inquiry_labels, inquiry_timing, (trigger_targetness, trigger_timing, trigger_symbols))
return ExtractedExperimentData(inquiries, trials, inquiry_labels, inquiry_timing, (trigger_targetness, trigger_timing, trigger_symbols))
5 changes: 3 additions & 2 deletions bcipy/simulator/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from bcipy.simulator.helpers.sampler import Sampler
from bcipy.simulator.helpers.state_manager import StateManager, SimState
from bcipy.simulator.helpers.types import InquiryResult
from bcipy.simulator.helpers.log_utils import format_alp_likelihoods
from bcipy.simulator.interfaces import MetricReferee, ModelHandler
from bcipy.simulator.simulator_base import Simulator

Expand Down Expand Up @@ -72,11 +73,11 @@ def run(self):
evidence = self.model_handler.generate_evidence(curr_state,
sampled_data) # TODO make this evidence be a dict (mapping of evidence type to evidence)

log.debug(f"Evidence for stimuli {curr_state.display_alphabet} \n {evidence}")
log.debug(f"Evidence for stimuli {curr_state.display_alphabet} \n {format_alp_likelihoods(evidence, self.symbol_set)}")

inq_record: InquiryResult = self.state_manager.update(evidence)
updated_state = self.state_manager.get_state()
log.debug(f"Fused Likelihoods {[str(round(p, 3)) for p in inq_record.fused_likelihood]}")
log.debug(f"Fused Likelihoods {format_alp_likelihoods(inq_record.fused_likelihood, self.symbol_set)}")

if inq_record.decision:
log.info(f"Decided {inq_record.decision} for target {inq_record.target}")
Expand Down
18 changes: 6 additions & 12 deletions bcipy/simulator/tests/sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ class DummyRef(MetricReferee):

if __name__ == "__main__":
args = dict()
args['data_folders'] = ["/Users/srikarananthoju/cambi/tab_test_dynamic/tab_test_dynamic_RSVP_Copy_Phrase_Thu_24_Aug_2023_18hr58min16sec_-0700",
# "/Users/srikarananthoju/cambi/tab_test_dynamic/tab_test_dynamic_RSVP_Copy_Phrase_Thu_24_Aug_2023_19hr07min50sec_-0700",
# "/Users/srikarananthoju/cambi/tab_test_dynamic/tab_test_dynamic_RSVP_Copy_Phrase_Thu_24_Aug_2023_19hr15min29sec_-0700"
args['data_folders'] = ["/Users/srikarananthoju/cambi/tab_test_dynamic/16sec_-0700",
"/Users/srikarananthoju/cambi/tab_test_dynamic/50sec_-0700",
# "/Users/srikarananthoju/cambi/tab_test_dynamic/29sec_-0700"
]
args['out_dir'] = Path(__file__).resolve().parent
model_file = Path(
"/Users/srikarananthoju/cambi/tab_test_dynamic/tab_test_dynamic_RSVP_Calibration_Thu_24_Aug_2023_18hr41min37sec_-0700/model_0.9524_200_800.pkl")
"/Users/srikarananthoju/cambi/tab_test_dynamic/calibr_37sec_-0700/model_0.9524_200_800.pkl"
# "/Users/srikarananthoju/cambi/tab_test_dynamic/calibr_37sec_-0700/model_0.9595.pkl"
)
sim_parameters = load_json_parameters("bcipy/simulator/sim_parameters.json", value_cast=True)

data_engine = RawDataEngine(args['data_folders'])
Expand All @@ -58,14 +60,6 @@ class DummyRef(MetricReferee):
sampler: Sampler = SimpleLetterSampler(data_engine)
sample: np.ndarray = sampler.sample(stateManager.get_state())

# model = PcaRdaKdeModel()
# model = model.load(model_file)
#
# eeg_evidence = model.predict(sample, stateManager.get_state().display_alphabet, alphabet())
#
# print(eeg_evidence.shape)
# print(eeg_evidence)

model_handler = DummyModelHandler(model_file)
sim = SimulatorCopyPhrase(data_engine, model_handler, sampler, stateManager, DummyRef())
sim.run()

0 comments on commit f4faea1

Please sign in to comment.