2424# local libraries
2525from .intervals import TimeInterval
2626
27+ logger = logging .getLogger (__name__ )
28+
2729
2830def replace_nan_with_batch_mean (data : np .ndarray ) -> np .ndarray :
2931 row , col = np .where (np .isnan (data ))
@@ -274,9 +276,9 @@ def __init__(self, datasets, session_names=None):
274276 session_names = [f"session_{ i } " for i in range (len (datasets ))]
275277 self .session_names = session_names
276278
277- # Print dataset sizes for debugging
279+ # Log dataset sizes for debugging
278280 for i , (name , dataset ) in enumerate (zip (session_names , datasets )):
279- print ( f "Dataset { i } : { name } , length = { len (dataset )} " )
281+ logger . debug ( "Dataset %s: %s , length = %s" , i , name , len (dataset ))
280282
281283 # Compute cumulative sizes for efficient indexing
282284 self .cumulative_sizes = []
@@ -377,7 +379,7 @@ def __init__(self, dataset, batch_size, drop_last=False, shuffle=False, seed=Non
377379
378380 # Get sessions
379381 self .session_names = list (dataset .session_indices .keys ())
380- print ( f "Sessions: { self .session_names } " )
382+ logger . debug ( "Sessions: %s" , self .session_names )
381383
382384 self .consumed_sessions = []
383385
@@ -401,8 +403,8 @@ def __init__(self, dataset, batch_size, drop_last=False, shuffle=False, seed=Non
401403 self .batches_per_session [session_name ] = num_batches
402404 total_batches += num_batches
403405
404- print ( f "Batches per session: { self .batches_per_session } " )
405- print ( f "Total batches: { total_batches } " )
406+ logger . debug ( "Batches per session: %s" , self .batches_per_session )
407+ logger . debug ( "Total batches: %s" , total_batches )
406408
407409 def __len__ (self ):
408410 """Return the total number of batches across all sessions."""
@@ -556,8 +558,10 @@ def __init__(
556558 # Track active sessions
557559 self .active_sessions = set (self .session_names )
558560
559- print (
560- f"Created FastSessionDataLoader with { len (self .session_names )} sessions and { len (self )} total batches"
561+ logger .debug (
562+ "Created FastSessionDataLoader with %s sessions and %s total batches" ,
563+ len (self .session_names ),
564+ len (self ),
561565 )
562566
563567 def __len__ (self ):
@@ -641,8 +645,10 @@ def set_state(self, state):
641645 if sampler_state is not None and hasattr (sampler , "set_state" ):
642646 sampler .set_state (sampler_state )
643647
644- print (
645- f"Restored dataloader state to batch { self .current_batch } , epoch { self .epoch } "
648+ logger .info (
649+ "Restored dataloader state to batch %s, epoch %s" ,
650+ self .current_batch ,
651+ self .epoch ,
646652 )
647653
648654 def __iter__ (self ):
0 commit comments