diff --git a/src/osekit/core_api/base_dataset.py b/src/osekit/core_api/base_dataset.py index ffd250ef..098fd468 100644 --- a/src/osekit/core_api/base_dataset.py +++ b/src/osekit/core_api/base_dataset.py @@ -148,7 +148,10 @@ def move_files(self, folder: Path) -> None: Destination folder in which the dataset files will be moved. """ - for file in tqdm(self.files, disable=os.environ.get("DISABLE_TQDM", "")): + for file in tqdm( + self.files, + disable=os.getenv("DISABLE_TQDM", "False").lower() in ("true", "1", "t"), + ): file.move(folder) self._folder = folder @@ -186,7 +189,7 @@ def write( last = len(self.data) if last is None else last for data in tqdm( self.data[first:last], - disable=os.environ.get("DISABLE_TQDM", ""), + disable=os.getenv("DISABLE_TQDM", "False").lower() in ("true", "1", "t"), ): data.write(folder=folder, link=link) @@ -348,7 +351,7 @@ def _get_base_data_from_files_timedelta_total( for data_begin in tqdm( date_range(begin, end, freq=freq, inclusive="left"), - disable=os.environ.get("DISABLE_TQDM", ""), + disable=os.getenv("DISABLE_TQDM", "False").lower() in ("true", "1", "t"), ): data_end = Timestamp(data_begin + data_duration) while ( @@ -395,7 +398,7 @@ def _get_base_data_from_files_timedelta_file( files_chunk = [] for idx, file in tqdm( enumerate(files[first:last]), - disable=os.environ.get("DISABLE_TQDM", ""), + disable=os.getenv("DISABLE_TQDM", "False").lower() in ("true", "1", "t"), ): if file in files_chunk: continue @@ -492,7 +495,10 @@ def from_folder( # noqa: PLR0913 supported_file_extensions = [] valid_files = [] rejected_files = [] - for file in tqdm(folder.iterdir(), disable=os.environ.get("DISABLE_TQDM", "")): + for file in tqdm( + folder.iterdir(), + disable=os.getenv("DISABLE_TQDM", "False").lower() in ("true", "1", "t"), + ): if file.suffix.lower() not in supported_file_extensions: continue try: diff --git a/src/osekit/public_api/export_analysis.py b/src/osekit/public_api/export_analysis.py index 258c93b2..de63ebd8 100644 --- a/src/osekit/public_api/export_analysis.py +++ b/src/osekit/public_api/export_analysis.py @@ -239,24 +239,24 @@ def create_parser() -> argparse.ArgumentParser: parser.add_argument( "--tqdm-disable", required=False, - type=str, - default="true", + action=argparse.BooleanOptionalAction, + default=True, help="Disable TQDM progress bars.", ) parser.add_argument( "--multiprocessing", required=False, - type=str, - default="false", + action=argparse.BooleanOptionalAction, + default=False, help="Turn multiprocessing on or off.", ) parser.add_argument( "--use-logging-setup", required=False, - type=str, - default="false", + action=argparse.BooleanOptionalAction, + default=False, help="Call osekit.setup_logging() before running the analysis.", ) @@ -284,12 +284,12 @@ def main() -> None: """Export an analysis.""" args = create_parser().parse_args() - os.environ["DISABLE_TQDM"] = "" if not args.tqdm_disable else str(args.tqdm_disable) + os.environ["DISABLE_TQDM"] = str(args.tqdm_disable) - if args.use_logging_setup.lower() == "true": + if args.use_logging_setup: setup_logging() - config.multiprocessing["is_active"] = args.multiprocessing.lower() == "true" + config.multiprocessing["is_active"] = args.multiprocessing if (nb_processes := args.nb_processes) is not None: config.multiprocessing["nb_processes"] = ( None if nb_processes.lower() == "none" else int(nb_processes) diff --git a/src/osekit/utils/job.py b/src/osekit/utils/job.py index 738675ab..9a22c68e 100644 --- a/src/osekit/utils/job.py +++ b/src/osekit/utils/job.py @@ -256,6 +256,16 @@ def progress(self) -> None: return self._status = JobStatus(self._status.value + 1) + def _build_arg_string(self) -> str: + """Build a string representation of the job's arguments.""" + arg_list = [] + for key, value in self.script_args.items(): + if isinstance(value, bool): + arg_list.append(f"--{'no-' if not value else ''}{key}") + else: + arg_list.append(f"--{key} {value}") + return " ".join(arg_list) + def write_pbs(self, path: Path) -> None: """Write a PBS file matching the job. @@ -287,7 +297,8 @@ def write_pbs(self, path: Path) -> None: for key, value in request.items() if value ) - script = f"python {self.script_path} {' '.join(f'--{key} {value}' for key, value in self.script_args.items())}" + + script = f"python {self.script_path} {self._build_arg_string()}" pbs = f"{preamble}\n{request_str}\n{self.venv_activate_script}\n{script}" with path.open("w") as file: diff --git a/src/osekit/utils/multiprocess_utils.py b/src/osekit/utils/multiprocess_utils.py index e20d1043..cb5eebb3 100644 --- a/src/osekit/utils/multiprocess_utils.py +++ b/src/osekit/utils/multiprocess_utils.py @@ -43,7 +43,11 @@ def multiprocess( if bypass_multiprocessing or not config.multiprocessing["is_active"]: return list( func(element, *args, **kwargs) - for element in tqdm(enumerable, disable=os.environ.get("DISABLE_TQDM", "")) + for element in tqdm( + enumerable, + disable=os.getenv("DISABLE_TQDM", "False").lower() + in ("true", "1", "t"), + ) ) partial_func = partial(func, *args, **kwargs) @@ -53,6 +57,7 @@ def multiprocess( tqdm( pool.imap(partial_func, enumerable), total=len(list(enumerable)), - disable=os.environ.get("DISABLE_TQDM", ""), + disable=os.getenv("DISABLE_TQDM", "False").lower() + in ("true", "1", "t"), ), ) diff --git a/tests/test_export_analysis.py b/tests/test_export_analysis.py index 35afb3d3..7b606ee2 100644 --- a/tests/test_export_analysis.py +++ b/tests/test_export_analysis.py @@ -1,6 +1,7 @@ import argparse import logging import os +import shlex from pathlib import Path import pytest @@ -10,6 +11,7 @@ from osekit.core_api.spectro_dataset import SpectroDataset from osekit.public_api import export_analysis from osekit.public_api.export_analysis import create_parser +from osekit.utils.job import Job def test_parser_factory() -> None: @@ -64,9 +66,9 @@ def test_argument_defaults() -> None: assert args.downsampling_quality is None assert args.upsampling_quality is None assert args.umask == 0o002 # noqa: PLR2004 - assert args.tqdm_disable == "true" - assert args.multiprocessing == "false" - assert args.use_logging_setup == "false" + assert args.tqdm_disable + assert not args.multiprocessing + assert not args.use_logging_setup assert args.nb_processes is None assert args.dataset_json_path is None @@ -74,71 +76,71 @@ def test_argument_defaults() -> None: @pytest.fixture def script_arguments() -> dict: return { - "--analysis": 2, - "--ads-json": r"path/to/ads.json", - "--sds-json": r"path/to/ads.json", - "--subtype": "FLOAT", - "--matrix-folder-path": r"out/matrix", - "--spectrogram-folder-path": r"out/spectro", - "--welch-folder-path": r"out/welch", - "--first": 10, - "--last": 12, - "--downsampling-quality": "HQ", - "--upsampling-quality": "VHQ", - "--umask": 0o022, - "--tqdm-disable": "False", - "--multiprocessing": "True", - "--nb-processes": "3", # String because it might be "None" - "--use-logging-setup": "True", - "--dataset-json-path": r"path/to/dataset.json", + "analysis": 2, + "ads-json": r"path/to/ads.json", + "sds-json": r"path/to/ads.json", + "subtype": "FLOAT", + "matrix-folder-path": r"out/matrix", + "spectrogram-folder-path": r"out/spectro", + "welch-folder-path": r"out/welch", + "first": 10, + "last": 12, + "downsampling-quality": "HQ", + "upsampling-quality": "VHQ", + "umask": 0o022, + "tqdm-disable": False, + "multiprocessing": True, + "nb-processes": "3", # String because it might be "None" + "use-logging-setup": True, + "dataset-json-path": r"path/to/dataset.json", } def test_specified_arguments(script_arguments: dict) -> None: parser = create_parser() - args = parser.parse_args( - [str(arg_part) for arg in script_arguments.items() for arg_part in arg], - ) - - assert args.analysis == script_arguments["--analysis"] - assert args.ads_json == script_arguments["--ads-json"] - assert args.sds_json == script_arguments["--sds-json"] - assert args.subtype == script_arguments["--subtype"] - assert args.matrix_folder_path == script_arguments["--matrix-folder-path"] - assert args.spectrogram_folder_path == script_arguments["--spectrogram-folder-path"] - assert args.welch_folder_path == script_arguments["--welch-folder-path"] - assert args.first == script_arguments["--first"] - assert args.last == script_arguments["--last"] - assert args.downsampling_quality == script_arguments["--downsampling-quality"] - assert args.upsampling_quality == script_arguments["--upsampling-quality"] - assert args.umask == script_arguments["--umask"] - assert args.tqdm_disable == script_arguments["--tqdm-disable"] - assert args.multiprocessing == script_arguments["--multiprocessing"] - assert args.use_logging_setup == script_arguments["--use-logging-setup"] - assert args.nb_processes == script_arguments["--nb-processes"] - assert args.dataset_json_path == script_arguments["--dataset-json-path"] + parsed_str = Job(Path(), script_arguments)._build_arg_string() + + args = parser.parse_args(shlex.split(parsed_str)) + + assert args.analysis == script_arguments["analysis"] + assert args.ads_json == script_arguments["ads-json"] + assert args.sds_json == script_arguments["sds-json"] + assert args.subtype == script_arguments["subtype"] + assert args.matrix_folder_path == script_arguments["matrix-folder-path"] + assert args.spectrogram_folder_path == script_arguments["spectrogram-folder-path"] + assert args.welch_folder_path == script_arguments["welch-folder-path"] + assert args.first == script_arguments["first"] + assert args.last == script_arguments["last"] + assert args.downsampling_quality == script_arguments["downsampling-quality"] + assert args.upsampling_quality == script_arguments["upsampling-quality"] + assert args.umask == script_arguments["umask"] + assert args.tqdm_disable == script_arguments["tqdm-disable"] + assert args.multiprocessing == script_arguments["multiprocessing"] + assert args.use_logging_setup == script_arguments["use-logging-setup"] + assert args.nb_processes == script_arguments["nb-processes"] + assert args.dataset_json_path == script_arguments["dataset-json-path"] def test_main_script(monkeypatch: pytest.MonkeyPatch, script_arguments: dict) -> None: class MockedArgs: def __init__(self, *args: list, **kwargs: dict) -> None: - self.analysis = script_arguments["--analysis"] - self.ads_json = script_arguments["--ads-json"] - self.sds_json = script_arguments["--sds-json"] - self.subtype = script_arguments["--subtype"] - self.matrix_folder_path = script_arguments["--matrix-folder-path"] - self.spectrogram_folder_path = script_arguments["--spectrogram-folder-path"] - self.welch_folder_path = script_arguments["--welch-folder-path"] - self.first = script_arguments["--first"] - self.last = script_arguments["--last"] - self.downsampling_quality = script_arguments["--downsampling-quality"] - self.upsampling_quality = script_arguments["--upsampling-quality"] - self.umask = script_arguments["--umask"] - self.tqdm_disable = script_arguments["--tqdm-disable"] - self.multiprocessing = script_arguments["--multiprocessing"] - self.use_logging_setup = script_arguments["--use-logging-setup"] - self.nb_processes = script_arguments["--nb-processes"] + self.analysis = script_arguments["analysis"] + self.ads_json = script_arguments["ads-json"] + self.sds_json = script_arguments["sds-json"] + self.subtype = script_arguments["subtype"] + self.matrix_folder_path = script_arguments["matrix-folder-path"] + self.spectrogram_folder_path = script_arguments["spectrogram-folder-path"] + self.welch_folder_path = script_arguments["welch-folder-path"] + self.first = script_arguments["first"] + self.last = script_arguments["last"] + self.downsampling_quality = script_arguments["downsampling-quality"] + self.upsampling_quality = script_arguments["upsampling-quality"] + self.umask = script_arguments["umask"] + self.tqdm_disable = script_arguments["tqdm-disable"] + self.multiprocessing = script_arguments["multiprocessing"] + self.use_logging_setup = script_arguments["use-logging-setup"] + self.nb_processes = script_arguments["nb-processes"] self.dataset_json_path = "none" def return_mocked_attr(*args: list, **kwargs: dict) -> MockedArgs: @@ -172,31 +174,33 @@ def mock_write_analysis(*args: list, **kwargs: dict) -> None: export_analysis.main() - assert os.environ["DISABLE_TQDM"] == str(args.tqdm_disable) - assert config.multiprocessing["is_active"] is True - assert config.multiprocessing["nb_processes"] == 3 + assert ( + os.environ["DISABLE_TQDM"].lower() in ("true", "1", "t") + ) == args.tqdm_disable + assert config.multiprocessing["is_active"] + assert config.multiprocessing["nb_processes"] == 3 # noqa: PLR2004 assert ( config.resample_quality_settings["downsample"] - == script_arguments["--downsampling-quality"] + == script_arguments["downsampling-quality"] ) assert ( config.resample_quality_settings["upsample"] - == script_arguments["--upsampling-quality"] + == script_arguments["upsampling-quality"] ) - assert calls["ads_json"] == Path(script_arguments["--ads-json"]) - assert calls["sds_json"] == Path(script_arguments["--sds-json"]) + assert calls["ads_json"] == Path(script_arguments["ads-json"]) + assert calls["sds_json"] == Path(script_arguments["sds-json"]) # write_analysis - assert calls["analysis_type"].value == script_arguments["--analysis"] - assert calls["ads"] == Path(script_arguments["--ads-json"]) - assert calls["sds"] == Path(script_arguments["--sds-json"]) - assert calls["subtype"] == script_arguments["--subtype"] - assert calls["matrix_folder_path"] == Path(script_arguments["--matrix-folder-path"]) + assert calls["analysis_type"].value == script_arguments["analysis"] + assert calls["ads"] == Path(script_arguments["ads-json"]) + assert calls["sds"] == Path(script_arguments["sds-json"]) + assert calls["subtype"] == script_arguments["subtype"] + assert calls["matrix_folder_path"] == Path(script_arguments["matrix-folder-path"]) assert calls["spectrogram_folder_path"] == Path( - script_arguments["--spectrogram-folder-path"], + script_arguments["spectrogram-folder-path"], ) - assert calls["welch_folder_path"] == Path(script_arguments["--welch-folder-path"]) - assert calls["first"] == script_arguments["--first"] - assert calls["last"] == script_arguments["--last"] + assert calls["welch_folder_path"] == Path(script_arguments["welch-folder-path"]) + assert calls["first"] == script_arguments["first"] + assert calls["last"] == script_arguments["last"] assert calls["link"] is True assert calls["logger"] == logging.getLogger() diff --git a/tests/test_job.py b/tests/test_job.py index e247ef27..3b3052b5 100644 --- a/tests/test_job.py +++ b/tests/test_job.py @@ -355,14 +355,18 @@ def write_pbs(self, path: Path) -> None: job_builder.create_job( script_path=script, - script_args={"les": "fantômes", "de": "baleines"}, + script_args={"les": "fantômes", "de": "baleines", "bool": False}, name="idylle_des_abysses", output_folder=output_dir, ) keywords = called["init_job"] assert keywords["script_path"] == script - assert keywords["script_args"] == {"les": "fantômes", "de": "baleines"} + assert keywords["script_args"] == { + "les": "fantômes", + "de": "baleines", + "bool": False, + } assert keywords["name"] == "idylle_des_abysses" assert keywords["output_folder"] == output_dir @@ -373,6 +377,33 @@ def write_pbs(self, path: Path) -> None: assert job_builder.jobs[0].status == JobStatus.PREPARED +def test_build_arg_string_booleans(tmp_path: Path): + job_builder = JobBuilder() + assert job_builder.jobs == [] + + output_dir = tmp_path / "output" + output_dir.mkdir() + script = tmp_path / "script.py" + script.write_text("") + + job_builder.create_job( + script_path=script, + script_args={ + "danser": False, + "avec": True, + "le": 0.3, + "vent": "test" + }, + name="danser_avec_le_vent", + output_folder=output_dir, + ) + + job = next(iter(job_builder.jobs)) + arg_str = job._build_arg_string() + + assert arg_str == "--no-danser --avec --le 0.3 --vent test" + + def test_job_builder_submit(monkeypatch: pytest.MonkeyPatch) -> None: submitted_jobs = []