@@ -84,7 +84,9 @@ def __exit__(self, *exc):
8484 self .close ()
8585
8686 @staticmethod
87- def create (root_folder : str , cache_data : bool = False , ** kwargs ) -> Interpolator :
87+ def create (
88+ root_folder : str | Path , cache_data : bool = False , ** kwargs
89+ ) -> "Interpolator" :
8890 """Factory method to create the appropriate interpolator for a modality.
8991
9092 Reads the ``meta.yml`` file in the folder to determine the modality type
@@ -109,35 +111,24 @@ def create(root_folder: str, cache_data: bool = False, **kwargs) -> Interpolator
109111 ValueError
110112 If the modality type is not supported.
111113 """
114+ root_folder = str (root_folder )
112115 with open (Path (root_folder ) / "meta.yml" ) as file :
113116 meta_data = yaml .safe_load (file )
114117 modality = meta_data .get ("modality" )
115118
116119 if modality == "sequence" :
117120 if meta_data .get ("phase_shift_per_signal" , False ):
118121 return PhaseShiftedSequenceInterpolator (
119- root_folder , cache_data = cache_data , ** kwargs
122+ root_folder , cache_data , ** kwargs
120123 )
121124 else :
122- return SequenceInterpolator (
123- root_folder , cache_data = cache_data , ** kwargs
124- )
125+ return SequenceInterpolator (root_folder , cache_data , ** kwargs )
125126 elif modality == "screen" :
126- use_stimuli_names = kwargs .pop (
127- "use_stimuli_names" , meta_data .get ("use_stimuli_names" , False )
128- )
129- return ScreenInterpolator (
130- root_folder ,
131- cache_data = cache_data ,
132- use_stimuli_names = use_stimuli_names ,
133- ** kwargs ,
134- )
127+ return ScreenInterpolator (root_folder , cache_data , ** kwargs )
135128 elif modality == "time_interval" :
136- return TimeIntervalInterpolator (
137- root_folder , cache_data = cache_data , ** kwargs
138- )
129+ return TimeIntervalInterpolator (root_folder , cache_data , ** kwargs )
139130 elif modality == "spikes" :
140- return SpikeInterpolator (root_folder , cache_data = cache_data , ** kwargs )
131+ return SpikeInterpolator (root_folder , cache_data , ** kwargs )
141132 else :
142133 raise ValueError (
143134 f"There is no interpolator for { modality } . Please use 'sequence', 'screen', 'time_interval' as modality or provide a custom interpolator."
@@ -497,8 +488,6 @@ class ScreenInterpolator(Interpolator):
497488 native image size from metadata.
498489 normalize : bool, default=False
499490 If True, normalizes frames using stored mean/std statistics.
500- use_stimuli_names : bool, default=False
501- If True, uses ``stimulus_name`` from metadata to locate data files instead of trial keys.
502491 **kwargs
503492 Additional keyword arguments (ignored).
504493
@@ -519,11 +508,10 @@ class ScreenInterpolator(Interpolator):
519508 def __init__ (
520509 self ,
521510 root_folder : str ,
522- cache_data : bool = False ,
511+ cache_data : bool = False , # New parameter
523512 rescale : bool = False ,
524513 rescale_size : tuple [int , int ] | None = None ,
525514 normalize : bool = False ,
526- use_stimuli_names : bool = False ,
527515 ** kwargs ,
528516 ) -> None :
529517 super ().__init__ (root_folder )
@@ -533,7 +521,6 @@ def __init__(
533521 self .valid_interval = TimeInterval (self .start_time , self .end_time )
534522 self .rescale = rescale
535523 self .cache_trials = cache_data # Store the cache preference
536- self .use_stimuli_names = use_stimuli_names
537524 self ._parse_trials ()
538525
539526 # create mapping from image index to file index
@@ -618,14 +605,8 @@ def _parse_trials(self) -> None:
618605 metadatas , keys = self .read_combined_meta ()
619606
620607 for key , metadata in zip (keys , metadatas , strict = True ):
621- if self .use_stimuli_names :
622- stimulus_name = metadata .get ("stimulus_name" )
623- assert (
624- stimulus_name is not None
625- ), f"stimulus_name is required in metadata when use_stimuli_names is True, but not found for key: { key } "
626- data_file_name = self .root_folder / "data" / f"{ stimulus_name } .npy"
627- else :
628- data_file_name = self .root_folder / "data" / f"{ key } .npy"
608+ data_file_name = self .root_folder / "data" / f"{ key } .npy"
609+ # Pass the cache_trials parameter when creating trials
629610 self .trials .append (
630611 ScreenTrial .create (
631612 data_file_name , metadata , cache_data = self .cache_trials
0 commit comments