From 644b7bb545ef15b43508ba63c5473083bfe06165 Mon Sep 17 00:00:00 2001 From: MatthieuCMira <109624972+MatthieuCMira@users.noreply.github.com> Date: Thu, 28 Aug 2025 11:41:27 -0400 Subject: [PATCH 1/2] update get_logger to accept set_level as a string. Improve set logger tests. --- geoapps_utils/utils/logger.py | 59 ++++++++++++++++++++++++++- tests/driver_test.py | 16 +------- tests/logger_test.py | 75 +++++++++++++++++++++++++++++++++++ 3 files changed, 134 insertions(+), 16 deletions(-) create mode 100644 tests/logger_test.py diff --git a/geoapps_utils/utils/logger.py b/geoapps_utils/utils/logger.py index 7cfc933e..184cac59 100644 --- a/geoapps_utils/utils/logger.py +++ b/geoapps_utils/utils/logger.py @@ -10,14 +10,66 @@ from __future__ import annotations import logging +from enum import Enum + + +class LoggerLevel(str, Enum): + """ + The different possible log levels. + """ + + WARNING = "warning" + INFO = "info" + DEBUG = "debug" + ERROR = "error" + CRITICAL = "critical" + + @property + def level(self) -> int: + """ + Get the current state of the logger. + """ + if self == LoggerLevel.WARNING: + return logging.WARNING + if self == LoggerLevel.INFO: + return logging.INFO + if self == LoggerLevel.DEBUG: + return logging.DEBUG + if self == LoggerLevel.ERROR: + return logging.ERROR + if self == LoggerLevel.CRITICAL: + return logging.CRITICAL + return logging.NOTSET + + @classmethod + def get_logger(cls, level: str | LoggerLevel) -> int: + """ + Get the logger level from a string or LoggerLevel. + + :param level: The log level as a string or LoggerLevel. + + :return: The corresponding logging level. + """ + if isinstance(level, str): + try: + level = cls(level.lower()) + except ValueError as error: + raise KeyError( + f"Invalid log level: '{level}'. Choose from {list(cls)}" + ) from error + if not isinstance(level, cls): + raise TypeError(f"Level must be a string or LoggerLevel, got {type(level)}") + return level.level def get_logger( name: str | None = None, + *, timestamp: bool = False, level_name: bool = True, propagate: bool = True, add_name: bool = True, + level: str | LoggerLevel | None = None, ) -> logging.Logger: """ Get a logger with a timestamped stream and specified log level. @@ -27,6 +79,7 @@ def get_logger( :param level_name: Whether to include the log level name in the log format. :param propagate: Whether to propagate log messages to the parent logger. :param add_name: Whether to include the logger name in the log format. + :param level: Logging level to use. :return: Configured logger instance. """ @@ -51,6 +104,10 @@ def get_logger( formatter = logging.Formatter(formatting + "%(message)s") stream_handler.setFormatter(formatter) log.addHandler(stream_handler) - log.propagate = propagate + + if level: + log.setLevel(LoggerLevel.get_logger(level)) + elif propagate: + log.propagate = propagate return log diff --git a/tests/driver_test.py b/tests/driver_test.py index e597740c..464fd7d6 100644 --- a/tests/driver_test.py +++ b/tests/driver_test.py @@ -24,7 +24,7 @@ from pydantic import BaseModel, ConfigDict from geoapps_utils import assets_path -from geoapps_utils.base import Options, get_logger +from geoapps_utils.base import Options from geoapps_utils.driver.data import BaseData from geoapps_utils.driver.driver import BaseDriver, Driver from geoapps_utils.driver.params import BaseParams @@ -188,17 +188,3 @@ def test_fetch_driver(tmp_path): dict_params["run_command"] = "geoapps_utils.utils.plotting" with pytest.raises(SystemExit, match="1"): fetch_driver_class(dict_params) - - -def test_logger(caplog): - """ - Test that the logger is set up correctly. - """ - logger = get_logger("my-app") - with caplog.at_level("INFO"): - logger.info("Test log message") - - assert "Test log message" in caplog.text - assert "my-app" in caplog.text - assert caplog.records[0].levelname == "INFO" - assert caplog.records[0].name == "my-app" diff --git a/tests/logger_test.py b/tests/logger_test.py new file mode 100644 index 00000000..bf16c77f --- /dev/null +++ b/tests/logger_test.py @@ -0,0 +1,75 @@ +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' +# Copyright (c) 2025 Mira Geoscience Ltd. ' +# ' +# This file is part of geoapps-utils package. ' +# ' +# geoapps-utils is distributed under the terms and conditions of the MIT License ' +# (see LICENSE file at the root of this source code package). ' +# ' +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' +from __future__ import annotations + +import logging + +import pytest + +from geoapps_utils.utils.logger import get_logger + + +def test_logger_warning(caplog): + """ + Test that the logger is set up correctly. + """ + # test with everything + logger = get_logger( + "my-app", + timestamp=True, + level_name=True, + propagate=True, # will be set to false because level + add_name=True, + level="warning", + ) + + with caplog.at_level(logging.WARNING): + logger.warning("Test log message") + + assert "Test log message" in caplog.text + assert "my-app" in caplog.text + assert "WARNING" in caplog.text + + +def test_logger_info(caplog): + # test with nothing (expect propagate) + logger_2 = get_logger( + timestamp=False, + level_name=False, + propagate=True, + add_name=False, + ) + + with caplog.at_level(logging.INFO): + logger_2.info("Test log message") + + assert "Test log message" in caplog.text + assert caplog.records[0].levelname == "INFO" + assert caplog.records[0].name == "root" + + +def test_logger_no_propagate(caplog): + # test with propagate false + logger_3 = get_logger( + "my-app", timestamp=False, level_name=False, propagate=False, add_name=False + ) + + with caplog.at_level(logging.INFO): + logger_3.info("Test log message") + + assert caplog.text == "" + + +def test_logger_level_errors(): + with pytest.raises(KeyError, match="Invalid log level: 'bidon'. Choose from"): + get_logger(level="bidon") + + with pytest.raises(TypeError, match="Level must be a string or LoggerLevel"): + get_logger(level=5) # type: ignore From 287c50d831befc5c3b874d80517a01f727494331 Mon Sep 17 00:00:00 2001 From: MatthieuCMira <109624972+MatthieuCMira@users.noreply.github.com> Date: Fri, 29 Aug 2025 09:17:38 -0400 Subject: [PATCH 2/2] remove catching logger error, will raises a clear enough error anyway --- geoapps_utils/utils/logger.py | 7 +------ tests/logger_test.py | 3 --- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/geoapps_utils/utils/logger.py b/geoapps_utils/utils/logger.py index 184cac59..e8a65856 100644 --- a/geoapps_utils/utils/logger.py +++ b/geoapps_utils/utils/logger.py @@ -51,12 +51,7 @@ def get_logger(cls, level: str | LoggerLevel) -> int: :return: The corresponding logging level. """ if isinstance(level, str): - try: - level = cls(level.lower()) - except ValueError as error: - raise KeyError( - f"Invalid log level: '{level}'. Choose from {list(cls)}" - ) from error + level = cls(level.lower()) if not isinstance(level, cls): raise TypeError(f"Level must be a string or LoggerLevel, got {type(level)}") return level.level diff --git a/tests/logger_test.py b/tests/logger_test.py index bf16c77f..80e5bdb1 100644 --- a/tests/logger_test.py +++ b/tests/logger_test.py @@ -68,8 +68,5 @@ def test_logger_no_propagate(caplog): def test_logger_level_errors(): - with pytest.raises(KeyError, match="Invalid log level: 'bidon'. Choose from"): - get_logger(level="bidon") - with pytest.raises(TypeError, match="Level must be a string or LoggerLevel"): get_logger(level=5) # type: ignore