Skip to content

Commit 9e9abd2

Browse files
committed
Merge main and resolve logging conflicts
2 parents 730140e + 183bfb2 commit 9e9abd2

5 files changed

Lines changed: 40 additions & 22 deletions

File tree

experanto/datasets.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import functools
44
import importlib
55
import json
6+
import logging
67
import os
78
from collections.abc import Iterable
89
from pathlib import Path
@@ -30,6 +31,8 @@
3031
# see .configs.py for the definition of DEFAULT_MODALITY_CONFIG
3132
DEFAULT_MODALITY_CONFIG = dict()
3233

34+
logger = logging.getLogger(__name__)
35+
3336

3437
class SimpleChunkedDataset(Dataset):
3538
def __init__(
@@ -483,11 +486,11 @@ def get_valid_intervals_from_filters(
483486
filter_function = self._get_callable_filter(filter_config)
484487
valid_intervals_: List[TimeInterval] = filter_function(device_=device) # type: ignore[assignment]
485488
if visualize:
486-
print(f"modality: {modality}, filter: {filter_name}")
489+
logger.info("modality: %s, filter: %s", modality, filter_name)
487490
visualization_string = get_stats_for_valid_interval(
488491
valid_intervals_, self.start_time, self.end_time
489492
)
490-
print(visualization_string)
493+
logger.info("%s", visualization_string)
491494
if valid_intervals is None:
492495
valid_intervals = valid_intervals_
493496
else:
@@ -722,15 +725,16 @@ def get_data_key_from_root_folder(self, root_folder):
722725
dataset_name = root_folder.split("_gaze")[0].split("datasets/")[1]
723726
return dataset_name
724727
else:
725-
print(
726-
f"No 'data_key' found in {meta_file_path}, using folder name instead"
728+
logger.info(
729+
"No 'data_key' found in %s, using folder name instead",
730+
meta_file_path,
727731
)
728-
except json.JSONDecodeError:
729-
print(f"Error: {meta_file_path} is not a valid JSON file")
732+
except json.JSONDecodeError as e:
733+
logger.warning("Error loading %s: %s", meta_file_path, e)
730734
except Exception as e:
731-
print(f"Error loading {meta_file_path}: {str(e)}")
735+
logger.warning("Error loading %s: %s", meta_file_path, e)
732736
else:
733-
print(f"No metadata file found at {meta_file_path}")
737+
logger.warning("No metadata file found at %s", meta_file_path)
734738
return os.path.basename(root_folder)
735739

736740
def __len__(self):

experanto/experiment.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from .configs import DEFAULT_MODALITY_CONFIG
1313
from .interpolators import Interpolator
1414

15-
log = logging.getLogger(__name__)
15+
logger = logging.getLogger(__name__)
1616

1717

1818
class Experiment:
@@ -80,10 +80,10 @@ def _load_devices(self) -> None:
8080

8181
for d in device_folders:
8282
if d.name not in self.modality_config:
83-
log.info(f"Skipping {d.name} data... ")
83+
logger.info("Skipping %s data", d.name)
8484
continue
85+
logger.info("Parsing %s data", d.name)
8586

86-
log.info(f"Parsing {d.name} data... ")
8787
interp_conf = self.modality_config[d.name]["interpolation"]
8888

8989
if (
@@ -109,11 +109,17 @@ def _load_devices(self) -> None:
109109
)
110110

111111
self.devices[d.name] = dev
112+
<<<<<<< HEAD
112113
if dev.start_time is not None:
113114
self.start_time = min(self.start_time, dev.start_time)
114115
if dev.end_time is not None:
115116
self.end_time = max(self.end_time, dev.end_time)
116117
log.info("Parsing finished")
118+
=======
119+
self.start_time = dev.start_time
120+
self.end_time = dev.end_time
121+
logger.info("Parsing finished")
122+
>>>>>>> main
117123

118124
@property
119125
def device_names(self):

experanto/interpolators.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import json
4+
import logging
45
import os
56
import re
67
import typing
@@ -18,6 +19,8 @@
1819

1920
from .intervals import TimeInterval
2021

22+
logger = logging.getLogger(__name__)
23+
2124

2225
class Interpolator:
2326
"""Abstract base class for time series interpolation.
@@ -574,7 +577,7 @@ def is_numbered_yml(file_name):
574577

575578
def read_combined_meta(self) -> tuple[list, list]:
576579
if not (self.root_folder / "combined_meta.json").exists():
577-
print("Combining metadatas...")
580+
logger.info("Combining metadata files...")
578581
self._combine_metadatas()
579582

580583
with open(self.root_folder / "combined_meta.json", "r") as file:

experanto/utils.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
# local libraries
2525
from .intervals import TimeInterval
2626

27+
logger = logging.getLogger(__name__)
28+
2729

2830
def replace_nan_with_batch_mean(data: np.ndarray) -> np.ndarray:
2931
row, col = np.where(np.isnan(data))
@@ -274,9 +276,9 @@ def __init__(self, datasets, session_names=None):
274276
session_names = [f"session_{i}" for i in range(len(datasets))]
275277
self.session_names = session_names
276278

277-
# Print dataset sizes for debugging
279+
# Log dataset sizes for debugging
278280
for i, (name, dataset) in enumerate(zip(session_names, datasets)):
279-
print(f"Dataset {i}: {name}, length = {len(dataset)}")
281+
logger.debug("Dataset %s: %s, length = %s", i, name, len(dataset))
280282

281283
# Compute cumulative sizes for efficient indexing
282284
self.cumulative_sizes = []
@@ -377,7 +379,7 @@ def __init__(self, dataset, batch_size, drop_last=False, shuffle=False, seed=Non
377379

378380
# Get sessions
379381
self.session_names = list(dataset.session_indices.keys())
380-
print(f"Sessions: {self.session_names}")
382+
logger.debug("Sessions: %s", self.session_names)
381383

382384
self.consumed_sessions = []
383385

@@ -401,8 +403,8 @@ def __init__(self, dataset, batch_size, drop_last=False, shuffle=False, seed=Non
401403
self.batches_per_session[session_name] = num_batches
402404
total_batches += num_batches
403405

404-
print(f"Batches per session: {self.batches_per_session}")
405-
print(f"Total batches: {total_batches}")
406+
logger.debug("Batches per session: %s", self.batches_per_session)
407+
logger.debug("Total batches: %s", total_batches)
406408

407409
def __len__(self):
408410
"""Return the total number of batches across all sessions."""
@@ -556,8 +558,10 @@ def __init__(
556558
# Track active sessions
557559
self.active_sessions = set(self.session_names)
558560

559-
print(
560-
f"Created FastSessionDataLoader with {len(self.session_names)} sessions and {len(self)} total batches"
561+
logger.debug(
562+
"Created FastSessionDataLoader with %s sessions and %s total batches",
563+
len(self.session_names),
564+
len(self),
561565
)
562566

563567
def __len__(self):
@@ -641,8 +645,10 @@ def set_state(self, state):
641645
if sampler_state is not None and hasattr(sampler, "set_state"):
642646
sampler.set_state(sampler_state)
643647

644-
print(
645-
f"Restored dataloader state to batch {self.current_batch}, epoch {self.epoch}"
648+
logger.info(
649+
"Restored dataloader state to batch %s, epoch %s",
650+
self.current_batch,
651+
self.epoch,
646652
)
647653

648654
def __iter__(self):

tests/test_sequence_interpolator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,5 +548,4 @@ def test_interpolation_mode_not_implemented():
548548

549549

550550
if __name__ == "__main__":
551-
print("Running tests")
552551
pytest.main([__file__])

0 commit comments

Comments (0)