diff --git a/geoapps_utils/utils/logger.py b/geoapps_utils/utils/logger.py index 7cfc933e..7eba746e 100644 --- a/geoapps_utils/utils/logger.py +++ b/geoapps_utils/utils/logger.py @@ -10,14 +10,61 @@ from __future__ import annotations import logging +from enum import Enum + + +class LoggerLevel(str, Enum): + """ + The different possible log levels. + """ + + WARNING = "warning" + INFO = "info" + DEBUG = "debug" + ERROR = "error" + CRITICAL = "critical" + + @property + def level(self) -> int: + """ + Get the current state of the logger. + """ + if self == LoggerLevel.WARNING: + return logging.WARNING + if self == LoggerLevel.INFO: + return logging.INFO + if self == LoggerLevel.DEBUG: + return logging.DEBUG + if self == LoggerLevel.ERROR: + return logging.ERROR + if self == LoggerLevel.CRITICAL: + return logging.CRITICAL + return logging.NOTSET + + @classmethod + def get_logger(cls, level: str | LoggerLevel) -> int: + """ + Get the logger level from a string or LoggerLevel. + + :param level: The log level as a string or LoggerLevel. + + :return: The corresponding logging level. + """ + if isinstance(level, str): + level = cls(level.lower()) + if not isinstance(level, cls): + raise TypeError(f"Level must be a string or LoggerLevel, got {type(level)}") + return level.level def get_logger( name: str | None = None, + *, timestamp: bool = False, level_name: bool = True, - propagate: bool = True, + propagate: bool | None = None, add_name: bool = True, + level: str | LoggerLevel | None = None, ) -> logging.Logger: """ Get a logger with a timestamped stream and specified log level. @@ -27,6 +74,7 @@ def get_logger( :param level_name: Whether to include the log level name in the log format. :param propagate: Whether to propagate log messages to the parent logger. :param add_name: Whether to include the logger name in the log format. + :param level: Logging level to use. :return: Configured logger instance. """ @@ -51,6 +99,11 @@ def get_logger( formatter = logging.Formatter(formatting + "%(message)s") stream_handler.setFormatter(formatter) log.addHandler(stream_handler) - log.propagate = propagate + + if level: + log.setLevel(LoggerLevel.get_logger(level)) + log.propagate = False + elif propagate is not None: + log.propagate = propagate return log diff --git a/tests/driver_test.py b/tests/driver_test.py index e597740c..464fd7d6 100644 --- a/tests/driver_test.py +++ b/tests/driver_test.py @@ -24,7 +24,7 @@ from pydantic import BaseModel, ConfigDict from geoapps_utils import assets_path -from geoapps_utils.base import Options, get_logger +from geoapps_utils.base import Options from geoapps_utils.driver.data import BaseData from geoapps_utils.driver.driver import BaseDriver, Driver from geoapps_utils.driver.params import BaseParams @@ -188,17 +188,3 @@ def test_fetch_driver(tmp_path): dict_params["run_command"] = "geoapps_utils.utils.plotting" with pytest.raises(SystemExit, match="1"): fetch_driver_class(dict_params) - - -def test_logger(caplog): - """ - Test that the logger is set up correctly. - """ - logger = get_logger("my-app") - with caplog.at_level("INFO"): - logger.info("Test log message") - - assert "Test log message" in caplog.text - assert "my-app" in caplog.text - assert caplog.records[0].levelname == "INFO" - assert caplog.records[0].name == "my-app" diff --git a/tests/logger_test.py b/tests/logger_test.py new file mode 100644 index 00000000..80e5bdb1 --- /dev/null +++ b/tests/logger_test.py @@ -0,0 +1,72 @@ +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' +# Copyright (c) 2025 Mira Geoscience Ltd. ' +# ' +# This file is part of geoapps-utils package. ' +# ' +# geoapps-utils is distributed under the terms and conditions of the MIT License ' +# (see LICENSE file at the root of this source code package). ' +# ' +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' +from __future__ import annotations + +import logging + +import pytest + +from geoapps_utils.utils.logger import get_logger + + +def test_logger_warning(caplog): + """ + Test that the logger is set up correctly. + """ + # test with everything + logger = get_logger( + "my-app", + timestamp=True, + level_name=True, + propagate=True, # will be set to false because level + add_name=True, + level="warning", + ) + + with caplog.at_level(logging.WARNING): + logger.warning("Test log message") + + assert "Test log message" in caplog.text + assert "my-app" in caplog.text + assert "WARNING" in caplog.text + + +def test_logger_info(caplog): + # test with nothing (expect propagate) + logger_2 = get_logger( + timestamp=False, + level_name=False, + propagate=True, + add_name=False, + ) + + with caplog.at_level(logging.INFO): + logger_2.info("Test log message") + + assert "Test log message" in caplog.text + assert caplog.records[0].levelname == "INFO" + assert caplog.records[0].name == "root" + + +def test_logger_no_propagate(caplog): + # test with propagate false + logger_3 = get_logger( + "my-app", timestamp=False, level_name=False, propagate=False, add_name=False + ) + + with caplog.at_level(logging.INFO): + logger_3.info("Test log message") + + assert caplog.text == "" + + +def test_logger_level_errors(): + with pytest.raises(TypeError, match="Level must be a string or LoggerLevel"): + get_logger(level=5) # type: ignore