Skip to content

Commit 06b9891

Browse files
committed
bugs appear when packages update
1 parent f430310 commit 06b9891

2 files changed

Lines changed: 9 additions & 4 deletions

File tree

lecilab_behavior_analysis/df_transforms.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,9 @@ def add_port_where_animal_comes_from(df_in: pd.DataFrame) -> pd.DataFrame:
121121
# equal to the original dataframe
122122
df.loc[df_mouse_session.index, 'roa_choice'] = series_to_append
123123
# add also the port where the animal comes from
124-
df.loc[df_mouse_session.index, 'previous_port_before_stimulus'] = last_choice
124+
last_choice_series = pd.Series(last_choice, index=df_mouse_session.index)
125+
df.loc[df_mouse_session.index, 'previous_port_before_stimulus'] = last_choice_series.astype("string")
126+
125127

126128
return df
127129

@@ -218,10 +220,10 @@ def get_performance_by_difficulty_ratio(df: pd.DataFrame) -> pd.DataFrame:
218220

219221
def get_performance_by_difficulty_diff(df: pd.DataFrame) -> pd.DataFrame:
220222
df_copy = df.copy(deep=True)
221-
if df_copy["current_training_stage"].str.contains("visual").any():
223+
if df_copy["stimulus_modality"].str.contains("visual").any():
222224
stim_col = "visual_stimulus"
223225
diff_col = "visual_stimulus_diff"
224-
elif df_copy["current_training_stage"].str.contains("auditory").any():
226+
elif df_copy["stimulus_modality"].str.contains("auditory").any():
225227
stim_col = "auditory_stimulus"
226228
diff_col = "auditory_stimulus_diff"
227229
else:

lecilab_behavior_analysis/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1109,7 +1109,10 @@ def get_previous_row_index(df: pd.DataFrame, current_index) -> int:
11091109
raise IndexError("No previous row exists for the given index.")
11101110

11111111

1112-
def transform_side_choice_to_numeric(side: str) -> Union[int, float]:
1112+
def transform_side_choice_to_numeric(side) -> Union[int, float]:
1113+
if pd.isna(side):
1114+
return np.nan
1115+
11131116
match side:
11141117
case "left":
11151118
return 1

0 commit comments

Comments
 (0)