Skip to content

Commit d897db3

Browse files
committed
Normalize config defaults in build_dag
1 parent 7840ff3 commit d897db3

1 file changed

Lines changed: 32 additions & 22 deletions

File tree

src/_pytask/dag_command.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import click
1212
import networkx as nx
13+
from click._utils import Sentinel
1314
from rich.text import Text
1415

1516
from _pytask.click import ColoredCommand
@@ -156,28 +157,8 @@ def build_dag(raw_config: dict[str, Any]) -> nx.DiGraph:
156157

157158
raw_config = {**DEFAULTS_FROM_CLI, **raw_config} # ty: ignore[invalid-assignment]
158159

159-
paths_value = raw_config["paths"]
160-
# Convert tuple to list since parse_paths expects Path | list[Path]
161-
if isinstance(paths_value, tuple):
162-
paths_value = list(paths_value)
163-
if not isinstance(paths_value, (Path, list)):
164-
msg = f"paths must be Path or list, got {type(paths_value)}"
165-
raise TypeError(msg) # noqa: TRY301
166-
# Cast is justified - we validated at runtime
167-
raw_config["paths"] = parse_paths(cast("Path | list[Path]", paths_value))
168-
169-
if raw_config["config"] is not None:
170-
config_value = raw_config["config"]
171-
if not isinstance(config_value, (str, Path)):
172-
msg = f"config must be str or Path, got {type(config_value)}"
173-
raise TypeError(msg) # noqa: TRY301
174-
raw_config["config"] = Path(config_value).resolve()
175-
raw_config["root"] = raw_config["config"].parent
176-
else:
177-
(
178-
raw_config["root"],
179-
raw_config["config"],
180-
) = find_project_root_and_config(raw_config["paths"])
160+
raw_config["paths"] = _normalize_paths_value(raw_config["paths"])
161+
_normalize_config_value(raw_config)
181162

182163
if raw_config["config"] is not None:
183164
config_from_file = read_config(raw_config["config"])
@@ -215,6 +196,35 @@ def build_dag(raw_config: dict[str, Any]) -> nx.DiGraph:
215196
return _refine_dag(session)
216197

217198

199+
def _normalize_paths_value(paths_value: Any) -> list[Path]:
200+
"""Normalize paths from the programmatic interface."""
201+
if isinstance(paths_value, tuple):
202+
paths_value = list(paths_value)
203+
if not isinstance(paths_value, (Path, list)):
204+
msg = f"paths must be Path or list, got {type(paths_value)}"
205+
raise TypeError(msg)
206+
# Cast is justified - we validated at runtime
207+
return parse_paths(cast("Path | list[Path]", paths_value))
208+
209+
210+
def _normalize_config_value(raw_config: dict[str, Any]) -> None:
211+
"""Normalize config value from the programmatic interface."""
212+
config_value = raw_config["config"]
213+
if isinstance(config_value, Sentinel):
214+
config_value = None
215+
raw_config["config"] = None
216+
if config_value is not None:
217+
if not isinstance(config_value, (str, Path)):
218+
msg = f"config must be str or Path, got {type(config_value)}"
219+
raise TypeError(msg)
220+
raw_config["config"] = Path(config_value).resolve()
221+
raw_config["root"] = raw_config["config"].parent
222+
else:
223+
raw_config["root"], raw_config["config"] = find_project_root_and_config(
224+
raw_config["paths"]
225+
)
226+
227+
218228
def _refine_dag(session: Session) -> nx.DiGraph:
219229
"""Refine the dag for plotting."""
220230
dag = _shorten_node_labels(session.dag, session.config["paths"])

0 commit comments

Comments
 (0)