diff --git a/src/visiomode/models.py b/src/visiomode/models.py index 7fe9c4c..d0dcbb4 100644 --- a/src/visiomode/models.py +++ b/src/visiomode/models.py @@ -83,13 +83,10 @@ class Trial(Base): response: Response timestamp: typing.Optional[str] = None correction: bool = False - response_time: int = 0 + response_time: float = 0.0 stimulus: dict = dataclasses.field(default_factory=dict) sdt_type: str = "NA" - def __post_init__(self): - self.timestamp = datetime.datetime.now().isoformat() - def __repr__(self): return f"" @@ -140,7 +137,7 @@ def __post_init__(self): self.device = socket.gethostname() if not self.device else self.device self.animal_meta = {} if not self.animal_meta else self.animal_meta self.experimenter_meta = {} if not self.experimenter_meta else self.experimenter_meta - self.spec = self.task.get_spec if self.task else {} + self.spec = self.task.get_spec() if self.task else {} def to_dict(self): """Get class instance attributes as a dictionary. diff --git a/src/visiomode/webpanel/export.py b/src/visiomode/webpanel/export.py index 37da211..e912cdd 100644 --- a/src/visiomode/webpanel/export.py +++ b/src/visiomode/webpanel/export.py @@ -53,19 +53,15 @@ def to_nwb(session_path): nwbfile.subject = pynwb.file.Subject(subject_id=session["animal_id"]) - nwbfile.add_trial_column(name="stimulus", description="the visual stimuli during the trial") - nwbfile.add_trial_column(name="cue_onset", description="when the stimulus came on") - nwbfile.add_trial_column(name="response", description="trial response type (left, right, lever)") - nwbfile.add_trial_column(name="response_time", description="response timestamp") - nwbfile.add_trial_column(name="pos_x", description="response position in x-axis") - nwbfile.add_trial_column(name="pos_y", description="response position in y-axis") - nwbfile.add_trial_column(name="dist_x", description="response displacement in x-axis") - nwbfile.add_trial_column(name="dist_y", description="response displacement in y-axis") - nwbfile.add_trial_column(name="outcome", description="trial outcome") - nwbfile.add_trial_column(name="correction", description="whether trial was a correction trial") - nwbfile.add_trial_column(name="sdt_type", description="signal detection theory classification") - - for trial in _flatten_trials(session): + trials = list(flatten_trials(session)) + trial_keys = set().union(*(trial.keys() for trial in trials)) + for key in trial_keys: + try: + nwbfile.add_trial_column(name=key, description="") + except ValueError: + continue + + for trial in trials: nwbfile.add_trial(**trial) nwbfile.create_device( @@ -88,7 +84,7 @@ def to_csv(session_path): with open(session_path) as f: session = json.load(f) - df = pd.DataFrame(_flatten_trials(session)) + df = pd.DataFrame(flatten_trials(session)) fname = session_path.split("/")[-1].replace(".json", ".csv") outpath = config.cache_dir + os.sep + fname @@ -97,19 +93,18 @@ def to_csv(session_path): return fname -def _flatten_trials(session): +def flatten_trials(session): session_start_time = datetime.fromisoformat(session["timestamp"]) for trial in session.get("trials"): start_time = (datetime.fromisoformat(trial["timestamp"]) - session_start_time).total_seconds() - stop_time = start_time + trial["iti"] + float(session["spec"].get("stimulus_duration", 10000)) / 1000 + stimulus_duration = float((session.get("spec").get("stimulus_duration") or -1) / 1000) + + stop_time = start_time + trial["iti"] + stimulus_duration if trial["response"].get("timestamp"): stop_time = (datetime.fromisoformat(trial["response"]["timestamp"]) - session_start_time).total_seconds() - stimulus = trial["stimulus"].get("common_name") if trial["stimulus"] != "None" else "None" - cue_onset = start_time + trial["iti"] - response = trial["response"].get("name") response_time = trial["response_time"] @@ -120,10 +115,26 @@ def _flatten_trials(session): sdt_type = trial.get("sdt_type", "unavailable") + stimulus = {} + if trial.get("stimulus"): + if trial.get("stimulus") == "None": + stimulus = {} + elif trial.get("stimulus").get("common_name"): + # handle single stimulus on screen tasks + stimulus = {f"stim_{key}": value for key, value in trial.get("stimulus").items()} + elif trial.get("stimulus").get("target"): + # 2AFC / nAFC + target_stim = {f"target_{key}": value for key, value in trial.get("stimulus").get("target").items()} + distractor_stim = { + f"distractor_{key}": value for key, value in trial.get("stimulus").get("distractor").items() + } + stimulus = {**target_stim, **distractor_stim} + + cue_onset = start_time + trial["iti"] if stimulus else "NA" + yield { "start_time": start_time, "stop_time": stop_time, - "stimulus": stimulus, "cue_onset": cue_onset, "response": response, "response_time": response_time, @@ -134,4 +145,8 @@ def _flatten_trials(session): "dist_x": dist_x, "dist_y": dist_y, "sdt_type": sdt_type, + **stimulus, } + + +def flatten_stimulus_details(stimulus): ... diff --git a/tests/conftest.py b/tests/conftest.py index 2672db1..10a99cb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -85,6 +85,7 @@ def session(config): [ models.Trial( outcome="correct", + timestamp=str(datetime.now().isoformat()), iti=5.0, response=models.Response( timestamp=str(datetime.now().isoformat()), @@ -97,9 +98,11 @@ def session(config): response_time=1.0, sdt_type="hit", stimulus="None", + correction=False, ), models.Trial( outcome="incorrect", + timestamp=str(datetime.now().isoformat()), iti=5.0, response=models.Response( timestamp=str(datetime.now().isoformat()), @@ -112,6 +115,7 @@ def session(config): response_time=1.5, sdt_type="false_alarm", stimulus="None", + correction=False, ), ] )