diff --git a/src/lib/cli.py b/src/lib/cli.py index 926fa93..799b608 100644 --- a/src/lib/cli.py +++ b/src/lib/cli.py @@ -68,4 +68,6 @@ def main(): format = plot.default_save_format() - plot.save(args.save, format=format) + path = args.save / f"{args.get_save_file_stem()}.{format}" + plot.save_to_path(path) + print(f"wrote to {path}") diff --git a/src/lib/data/adaptor.py b/src/lib/data/adaptor.py index 00a0e05..486dcaf 100644 --- a/src/lib/data/adaptor.py +++ b/src/lib/data/adaptor.py @@ -5,6 +5,7 @@ import xarray as xr from lib.data.data_with_attrs import DataWithAttrs, Field, List, Metadata +from lib.has_name_fragments import HasNameFragments def _fail_apply_field(adaptor_type: type[Adaptor]): @@ -17,7 +18,7 @@ def _fail_apply_list(adaptor_type: type[Adaptor]): raise RuntimeError(message) -class Adaptor: +class Adaptor(HasNameFragments): def apply(self, data: DataWithAttrs) -> DataWithAttrs: if isinstance(data, List): return self.apply_list(data) @@ -33,9 +34,6 @@ def apply_list(self, data: List) -> DataWithAttrs: def apply_field(self, data: Field) -> DataWithAttrs: _fail_apply_list(self.__class__) - def get_name_fragments(self) -> list[str]: - return [] - class MetadataAdaptor(Adaptor): """Wraps `apply` to perform standard metadata mutations.""" @@ -49,8 +47,6 @@ def get_modified_unit_latex(self, metadata: Metadata) -> str: def apply(self, data: DataWithAttrs) -> DataWithAttrs: data = super().apply(data) - name_fragments = data.metadata.name_fragments + self.get_name_fragments() - var_infos = data.metadata.var_infos if data.metadata.active_key is not None and data.metadata.active_key in var_infos: display_latex = self.get_modified_display_latex(data.metadata) @@ -59,10 +55,7 @@ def apply(self, data: DataWithAttrs) -> DataWithAttrs: new_dim = old_dim.assign(display=display_latex, unit=unit_latex) var_infos = {**var_infos, data.metadata.active_key: new_dim} - return data.assign_metadata( - name_fragments=name_fragments, - var_infos=var_infos, - ) + return data.assign_metadata(var_infos=var_infos) class BareAdaptor(MetadataAdaptor): diff --git a/src/lib/data/adaptors/display.py b/src/lib/data/adaptors/display.py index 92edab6..e757e08 100644 --- a/src/lib/data/adaptors/display.py +++ b/src/lib/data/adaptors/display.py @@ -12,7 +12,6 @@ def __init__(self, target: str | None, value: str): def apply(self, data: DataWithAttrs) -> DataWithAttrs: metadata = data.metadata - name_fragments = metadata.name_fragments + self.get_name_fragments() target = self.target or metadata.active_key if target is None: @@ -24,7 +23,7 @@ def apply(self, data: DataWithAttrs) -> DataWithAttrs: old_dim = metadata.var_infos[target] new_dim = old_dim.assign(display=self.value) new_var_infos = {**metadata.var_infos, target: new_dim} - return data.assign_metadata(name_fragments=name_fragments, var_infos=new_var_infos) + return data.assign_metadata(var_infos=new_var_infos) def get_name_fragments(self) -> list[str]: return [f"display_{self.target or 'active'}={self.value}"] diff --git a/src/lib/data/adaptors/unit.py b/src/lib/data/adaptors/unit.py index 0e94f4f..45d7e8f 100644 --- a/src/lib/data/adaptors/unit.py +++ b/src/lib/data/adaptors/unit.py @@ -12,7 +12,6 @@ def __init__(self, target: str | None, value: str): def apply(self, data: DataWithAttrs) -> DataWithAttrs: metadata = data.metadata - name_fragments = metadata.name_fragments + self.get_name_fragments() target = self.target or metadata.active_key if target is None: @@ -24,7 +23,7 @@ def apply(self, data: DataWithAttrs) -> DataWithAttrs: old_dim = metadata.var_infos[target] new_dim = old_dim.assign(unit=self.value) new_var_infos = {**metadata.var_infos, target: new_dim} - return data.assign_metadata(name_fragments=name_fragments, var_infos=new_var_infos) + return data.assign_metadata(var_infos=new_var_infos) def get_name_fragments(self) -> list[str]: return [f"unit_{self.target or 'active'}={self.value}"] diff --git a/src/lib/data/adaptors/versus.py b/src/lib/data/adaptors/versus.py index 9ef2c50..3c127a2 100644 --- a/src/lib/data/adaptors/versus.py +++ b/src/lib/data/adaptors/versus.py @@ -14,9 +14,6 @@ def __init__(self, spatial_dims: list[str], time_dim: str | None, color_dim: str self.all_dims = spatial_dims + ([time_dim] if time_dim else []) + ([color_dim] if color_dim else []) def apply_field(self, data: Field) -> Field: - # going to cut out name fragments of inner adaptors, since they can be inferred - initial_name_fragments = data.metadata.name_fragments - # 1. apply implicit coordinate transforms, as necessary for dim_name in self.all_dims: # 1a. already have the coordinate; do nothing @@ -46,13 +43,9 @@ def apply_field(self, data: Field) -> Field: spatial_dims=self.spatial_dims.copy(), time_dim=self.time_dim, color_dim=self.color_dim, - name_fragments=initial_name_fragments, ) def apply_list(self, data: List) -> List: - # going to cut out name fragments of inner adaptors, since they can be inferred - initial_name_fragments = data.metadata.name_fragments - # 1. coordinate transform # TODO @@ -69,7 +62,6 @@ def apply_list(self, data: List) -> List: spatial_dims=spatial_dims, time_dim=self.time_dim, color_dim=self.color_dim, - name_fragments=initial_name_fragments, ) def get_name_fragments(self) -> list[str]: diff --git a/src/lib/data/adaptors/with_.py b/src/lib/data/adaptors/with_.py index 1065859..c7f66a6 100644 --- a/src/lib/data/adaptors/with_.py +++ b/src/lib/data/adaptors/with_.py @@ -8,10 +8,7 @@ def __init__(self, key: str): self.key = key def apply(self, data: DataWithAttrs) -> DataWithAttrs: - return data.assign_metadata( - active_key=self.key, - name_fragments=data.metadata.name_fragments + self.get_name_fragments(), - ) + return data.assign_metadata(active_key=self.key) def get_name_fragments(self) -> list[str]: return [f"with_{self.key}"] diff --git a/src/lib/data/data_with_attrs.py b/src/lib/data/data_with_attrs.py index d3d84f4..03d4597 100644 --- a/src/lib/data/data_with_attrs.py +++ b/src/lib/data/data_with_attrs.py @@ -19,7 +19,6 @@ @dataclass(kw_only=True, frozen=True) class Metadata: active_key: str | None = None - name_fragments: list[str] = field(default_factory=list) spatial_dims: list[str] = field(default_factory=list) time_dim: str | None = None diff --git a/src/lib/data/loader.py b/src/lib/data/loader.py index e69b2fc..939ace1 100644 --- a/src/lib/data/loader.py +++ b/src/lib/data/loader.py @@ -4,9 +4,10 @@ from lib.data.data_source import DataSource from lib.file_util import get_available_steps +from lib.has_name_fragments import HasNameFragments -class Loader(DataSource): +class Loader(DataSource, HasNameFragments): @classmethod @abstractmethod def discover_prefixes(cls, data_dir: Path) -> list[str]: @@ -22,7 +23,7 @@ def __init__(self, prefix: str, active_key: str | None = None): self.steps = get_available_steps(prefix + ".", "." + self.suffix()) self.active_key = active_key - def _get_name_fragments(self) -> list[str]: + def get_name_fragments(self) -> list[str]: fragments = [self.prefix] if self.active_key is not None: fragments.append(self.active_key) diff --git a/src/lib/data/loaders/field_bp.py b/src/lib/data/loaders/field_bp.py index 2f1a383..30d4d74 100644 --- a/src/lib/data/loaders/field_bp.py +++ b/src/lib/data/loaders/field_bp.py @@ -46,7 +46,6 @@ def get_data(self) -> Field: var_info = {key: lookup(self.prefix, key) for key in ds.variables} metadata = FieldMetadata( active_key=self.active_key, - name_fragments=self._get_name_fragments(), prefix=self.prefix, var_infos=var_info, ) diff --git a/src/lib/data/loaders/particle_bp.py b/src/lib/data/loaders/particle_bp.py index c49d3c1..0ff72af 100644 --- a/src/lib/data/loaders/particle_bp.py +++ b/src/lib/data/loaders/particle_bp.py @@ -116,7 +116,6 @@ def get_data(self) -> LazyList: # species suffix when looking up per-column metadata. var_infos = {key: lookup("prt", key) for key in data.dims} return data.assign_metadata( - name_fragments=self._get_name_fragments(), active_key=self.active_key, var_infos=var_infos, ) diff --git a/src/lib/data/loaders/particle_h5.py b/src/lib/data/loaders/particle_h5.py index 2c97aec..77ffa33 100644 --- a/src/lib/data/loaders/particle_h5.py +++ b/src/lib/data/loaders/particle_h5.py @@ -215,7 +215,6 @@ def get_data(self) -> LazyList: var_infos = {key: lookup(self.prefix, key) for key in df_with_metadata.dims} return df_with_metadata.assign_metadata( - name_fragments=self._get_name_fragments(), active_key=self.active_key, var_infos=var_infos, subject=Latex(r"\text{Particles}"), diff --git a/src/lib/data/pipeline.py b/src/lib/data/pipeline.py index a8e9411..2543998 100644 --- a/src/lib/data/pipeline.py +++ b/src/lib/data/pipeline.py @@ -5,9 +5,6 @@ class Pipeline(Adaptor): def __init__(self, *adaptors: Adaptor): self.adaptors = list(adaptors) - def get_name_fragments(self) -> list[str]: - return [fragment for adaptor in self.adaptors for fragment in adaptor.get_name_fragments()] - def apply(self, data): for adaptor in self.adaptors: data = adaptor.apply(data) diff --git a/src/lib/has_name_fragments.py b/src/lib/has_name_fragments.py new file mode 100644 index 0000000..b987939 --- /dev/null +++ b/src/lib/has_name_fragments.py @@ -0,0 +1,3 @@ +class HasNameFragments: + def get_name_fragments(self) -> list[str]: + return [] diff --git a/src/lib/parsing/args.py b/src/lib/parsing/args.py index 2acda96..7d4dd3b 100644 --- a/src/lib/parsing/args.py +++ b/src/lib/parsing/args.py @@ -29,3 +29,8 @@ def get_animation(self) -> Plot: plot.add_hook(hook) return plot + + def get_save_file_stem(self) -> str: + sources = [self.loader, *self.adaptors, *self.hooks] + fragments = [frag for src in sources for frag in src.get_name_fragments()] + return "-".join(fragments) diff --git a/src/lib/plotting/animated_plot.py b/src/lib/plotting/animated_plot.py index f17d4eb..fe1a117 100644 --- a/src/lib/plotting/animated_plot.py +++ b/src/lib/plotting/animated_plot.py @@ -60,7 +60,7 @@ def show(self): def allowed_save_formats(self) -> list[SaveFormat]: return ["mp4", "gif"] - def _save_to_path(self, path: Path): + def save_to_path(self, path: Path): self._initialize() writer = PillowWriter() if path.suffix == ".gif" else FFMpegWriter() self.anim.save(path, writer=writer) diff --git a/src/lib/plotting/hook.py b/src/lib/plotting/hook.py index d807d6a..9cb4749 100644 --- a/src/lib/plotting/hook.py +++ b/src/lib/plotting/hook.py @@ -1,7 +1,9 @@ from typing import Any +from lib.has_name_fragments import HasNameFragments -class Hook: + +class Hook(HasNameFragments): def post_add_hook(self, add_data: Any): pass diff --git a/src/lib/plotting/hooks/scale.py b/src/lib/plotting/hooks/scale.py index e181f48..d2ea080 100644 --- a/src/lib/plotting/hooks/scale.py +++ b/src/lib/plotting/hooks/scale.py @@ -32,6 +32,9 @@ def to_axis_scale(self, data: DataWithAttrs) -> plt_util.AxisScaleArg: def to_color_norm(self, data: DataWithAttrs) -> plt_util.ColorNormArg: return self.scale_key + def to_name_fragment_part(self) -> str: + return str(self.scale_key) + @classmethod def to_argparse_format(cls) -> str: return cls.scale_key @@ -73,6 +76,11 @@ def to_color_norm(self, data: DataWithAttrs) -> plt_util.ColorNormArg: linthresh = self.linear_threshold or self._choose_linear_threshold(data) return SymLogNorm(linthresh) + def to_name_fragment_part(self) -> str: + if self.linear_threshold is None: + return self.scale_key + return f"{self.scale_key}{parse_util.SUBARG_DELIM}{self.linear_threshold}" + @classmethod def to_argparse_format(cls) -> str: return f"{cls.scale_key}[{parse_util.SUBARG_DELIM}{cls.LINEAR_THRESHOLD_ARG_FORMAT}]" @@ -122,6 +130,10 @@ def pre_init_fig(self, init_data): message = f"'{self.dim_name}' isn't a dimension" raise Exception(message) + def get_name_fragments(self) -> list[str]: + maybe_dim_name = f"{self.dim_name}=" if self.dim_name is not None else "" + return [f"scale_{maybe_dim_name}{self.scale.to_name_fragment_part()}"] + ANY_SCALE_ARGS_FORMAT = "{" + ",".join(scale_type.to_argparse_format() for scale_type in SCALE_TYPES) + "}" SCALE_FORMAT = f"[dim_name=]{ANY_SCALE_ARGS_FORMAT}" @@ -133,7 +145,7 @@ def pre_init_fig(self, init_data): help="set the axis/color scale of the dependent variable or specified dimension (default: linear)", dest="hooks", ) -def parse_vline(arg: str) -> Scale: +def parse_scale(arg: str) -> Scale: if "=" in arg: dim_name, scale_arg = parse_util.parse_assignment(arg, SCALE_FORMAT) parse_util.check_identifier(dim_name, "dim_name") diff --git a/src/lib/plotting/plot.py b/src/lib/plotting/plot.py index 99d7150..8e33d60 100644 --- a/src/lib/plotting/plot.py +++ b/src/lib/plotting/plot.py @@ -32,7 +32,7 @@ def add_hook(self, hook: Hook): def show(self): ... @abstractmethod - def _save_to_path(self, path: Path): ... + def save_to_path(self, path: Path): ... @abstractmethod def allowed_save_formats(self) -> list[SaveFormat]: ... @@ -40,13 +40,6 @@ def allowed_save_formats(self) -> list[SaveFormat]: ... def default_save_format(self) -> SaveFormat: return self.allowed_save_formats()[0] - def save(self, dir: Path, format: SaveFormat | None = None): - format = format or self.default_save_format() - name = "-".join(self.data.metadata.name_fragments) + "." + format - path = dir / name - self._save_to_path(path) - print(f"wrote to {path}") - def pre_init_fig(self, init_data: Any): for hook in self.hooks: hook.pre_init_fig(init_data) diff --git a/src/lib/plotting/static_plot.py b/src/lib/plotting/static_plot.py index 3ed2732..9f0e37e 100644 --- a/src/lib/plotting/static_plot.py +++ b/src/lib/plotting/static_plot.py @@ -32,6 +32,6 @@ def show(self): def allowed_save_formats(self) -> list[SaveFormat]: return ["png"] - def _save_to_path(self, path: Path): + def save_to_path(self, path: Path): self._initialize() self.fig.savefig(path, dpi=300.0) diff --git a/tests/conftest.py b/tests/conftest.py index 9d354ba..ea4b2a5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,10 +45,9 @@ def make_save(args_list: list[str], save_dir: Path, format: SaveFormat, data_dir args = get_parsed_args(args_list) plot = args.get_animation() save_dir.mkdir(exist_ok=True) - plot.save(save_dir, format=format) - - name = "-".join(plot.data.metadata.name_fragments) + "." + format - return save_dir / name + path = save_dir / f"{args.get_save_file_stem()}.{format}" + plot.save_to_path(path) + return path finally: if data_dir is not None: CONFIG.data_dir = original_dir diff --git a/tests/test_save_filename.py b/tests/test_save_filename.py new file mode 100644 index 0000000..79b0b37 --- /dev/null +++ b/tests/test_save_filename.py @@ -0,0 +1,22 @@ +import pytest + +from lib.data.compile import compile_source +from lib.parsing.parse import get_parsed_args + + +def _stem(args_list: list[str]) -> str: + args = get_parsed_args(args_list) + compile_source(args.loader, args.adaptors) # mutates args.adaptors: appends default Versus + return args.get_save_file_stem() + + +@pytest.mark.parametrize( + "args_list, expected_stem", + [ + (["pfd", "hx_fc"], "pfd-hx_fc-vs_y,z;time=t"), + (["pfd", "hx_fc", "--nan0"], "pfd-hx_fc-nan0-vs_y,z;time=t"), + (["pfd", "hx_fc", "--scale", "log"], "pfd-hx_fc-vs_y,z;time=t-scale_log"), + ], +) +def test_save_file_stem(args_list, expected_stem): + assert _stem(args_list) == expected_stem