@@ -340,33 +340,37 @@ def forward(
     def encode(self, audio_items: List[AudioItem], cpu_embed_cache_client: CpuEmbedCacheClient):
         uuids = []
         items: List[AudioItem] = []
+        per_audio_features: List[torch.Tensor] = []
         for i, item in enumerate(audio_items):
             if isinstance(item, AudioItem):
                 uuids.append(item.uuid)
                 items.append(item)
                 audio_data = read_shm(get_shm_name_data(item.uuid))
                 audio = BytesIO(audio_data)
-                audio, _ = librosa.load(audio, sr=16000)
+                audio, _ = librosa.load(audio, sr=self.processor.sampling_rate)
             else:
                 raise ValueError(f"cannot read audio which type is {type(item)}!")
 
-        input_features, feature_attention_mask = self.processor._preprocess(audio, return_attention_mask=True)
-        print(f"input_features is {input_features}, input_features.shape is {input_features.shape}")
-        print(f"feature_attention_mask is {feature_attention_mask}, shape is {feature_attention_mask.shape}")
-        if feature_attention_mask is not None:
-            audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
-            input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
-        else:
-            audio_feature_lengths = None
-        print(f"input_features is {input_features}, input_features.shape is {input_features.shape}")
-
-        feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
-        print(f"feature_lens is {feature_lens}")
-        audio_features = self.forward(
-            input_features,
-            feature_lens=feature_lens,
-        )
-        print(f"audio_features is {audio_features}, shape is {audio_features.shape}")
+            input_features, feature_attention_mask = self.processor._preprocess(audio, return_attention_mask=True)
+            print(f"input_features is {input_features}, input_features.shape is {input_features.shape}")
+            print(f"feature_attention_mask is {feature_attention_mask}, shape is {feature_attention_mask.shape}")
+            if feature_attention_mask is not None:
+                audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
+                input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
+            else:
+                audio_feature_lengths = None
+            print(f"input_features is {input_features}, input_features.shape is {input_features.shape}")
+
+            feature_lens = (
+                audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
+            )
+            print(f"feature_lens is {feature_lens}")
+            audio_features = self.forward(
+                input_features,
+                feature_lens=feature_lens,
+            )
+            per_audio_features.append(audio_features)
+            print(f"audio_features is {audio_features}, shape is {audio_features.shape}")
 
         ready_audio = obtain(self.cache_client.root.get_items_embed(uuids))
         ids_to_set = []
@@ -377,8 +381,9 @@ def encode(self, audio_items: List[AudioItem], cpu_embed_cache_client: CpuEmbedCacheClient):
             uid = uuids[i]
             item = items[i]
 
+            cur_embed = per_audio_features[i]
             cpu_embed_cache_client.copy_to_cache(
-                embed_tensor=audio_features, start_index_in_cache=item.start_index_in_embed_cache
+                embed_tensor=cur_embed, start_index_in_cache=item.start_index_in_embed_cache
             )
             ids_to_set.append(uid)
 