@@ -84,9 +84,7 @@ def __exit__(self, *exc):
8484 self .close ()
8585
8686 @staticmethod
87- def create (
88- root_folder : str | Path , cache_data : bool = False , ** kwargs
89- ) -> "Interpolator" :
87+ def create (root_folder : str , cache_data : bool = False , ** kwargs ) -> Interpolator :
9088 """Factory method to create the appropriate interpolator for a modality.
9189
9290 Reads the ``meta.yml`` file in the folder to determine the modality type
@@ -111,24 +109,35 @@ def create(
111109 ValueError
112110 If the modality type is not supported.
113111 """
114- root_folder = str (root_folder )
115112 with open (Path (root_folder ) / "meta.yml" ) as file :
116113 meta_data = yaml .safe_load (file )
117114 modality = meta_data .get ("modality" )
118115
119116 if modality == "sequence" :
120117 if meta_data .get ("phase_shift_per_signal" , False ):
121118 return PhaseShiftedSequenceInterpolator (
122- root_folder , cache_data , ** kwargs
119+ root_folder , cache_data = cache_data , ** kwargs
123120 )
124121 else :
125- return SequenceInterpolator (root_folder , cache_data , ** kwargs )
122+ return SequenceInterpolator (
123+ root_folder , cache_data = cache_data , ** kwargs
124+ )
126125 elif modality == "screen" :
127- return ScreenInterpolator (root_folder , cache_data , ** kwargs )
126+ use_stimuli_names = kwargs .pop (
127+ "use_stimuli_names" , meta_data .get ("use_stimuli_names" , False )
128+ )
129+ return ScreenInterpolator (
130+ root_folder ,
131+ cache_data = cache_data ,
132+ use_stimuli_names = use_stimuli_names ,
133+ ** kwargs ,
134+ )
128135 elif modality == "time_interval" :
129- return TimeIntervalInterpolator (root_folder , cache_data , ** kwargs )
136+ return TimeIntervalInterpolator (
137+ root_folder , cache_data = cache_data , ** kwargs
138+ )
130139 elif modality == "spikes" :
131- return SpikeInterpolator (root_folder , cache_data , ** kwargs )
140+ return SpikeInterpolator (root_folder , cache_data = cache_data , ** kwargs )
132141 else :
133142 raise ValueError (
134143 f"There is no interpolator for { modality } . Please use 'sequence', 'screen', 'time_interval' as modality or provide a custom interpolator."
@@ -488,6 +497,8 @@ class ScreenInterpolator(Interpolator):
488497 native image size from metadata.
489498 normalize : bool, default=False
490499 If True, normalizes frames using stored mean/std statistics.
500+ use_stimuli_names : bool, default=False
501+ If True, uses ``stimulus_name`` from metadata to locate data files instead of trial keys.
491502 **kwargs
492503 Additional keyword arguments (ignored).
493504
@@ -508,10 +519,11 @@ class ScreenInterpolator(Interpolator):
508519 def __init__ (
509520 self ,
510521 root_folder : str ,
511- cache_data : bool = False , # New parameter
522+ cache_data : bool = False ,
512523 rescale : bool = False ,
513524 rescale_size : tuple [int , int ] | None = None ,
514525 normalize : bool = False ,
526+ use_stimuli_names : bool = False ,
515527 ** kwargs ,
516528 ) -> None :
517529 super ().__init__ (root_folder )
@@ -521,6 +533,7 @@ def __init__(
521533 self .valid_interval = TimeInterval (self .start_time , self .end_time )
522534 self .rescale = rescale
523535 self .cache_trials = cache_data # Store the cache preference
536+ self .use_stimuli_names = use_stimuli_names
524537 self ._parse_trials ()
525538
526539 # create mapping from image index to file index
@@ -605,8 +618,14 @@ def _parse_trials(self) -> None:
605618 metadatas , keys = self .read_combined_meta ()
606619
607620 for key , metadata in zip (keys , metadatas , strict = True ):
608- data_file_name = self .root_folder / "data" / f"{ key } .npy"
609- # Pass the cache_trials parameter when creating trials
621+ if self .use_stimuli_names :
622+ stimulus_name = metadata .get ("stimulus_name" )
623+ assert (
624+ stimulus_name is not None
625+ ), f"stimulus_name is required in metadata when use_stimuli_names is True, but not found for key: { key } "
626+ data_file_name = self .root_folder / "data" / f"{ stimulus_name } .npy"
627+ else :
628+ data_file_name = self .root_folder / "data" / f"{ key } .npy"
610629 self .trials .append (
611630 ScreenTrial .create (
612631 data_file_name , metadata , cache_data = self .cache_trials
0 commit comments