From 3fdad88156f9f1c0dc517212c95d9c8715ce35e5 Mon Sep 17 00:00:00 2001 From: haykh Date: Tue, 10 Feb 2026 12:07:18 -0500 Subject: [PATCH 01/16] restructure for serialization --- nt2/__init__.py | 2 +- nt2/containers/container.py | 14 +- nt2/containers/data.py | 168 +++++++++------- nt2/containers/fields.py | 63 +++--- nt2/containers/particles.py | 388 +++++++++++++++++++----------------- nt2/containers/spectra.py | 5 +- nt2/plotters/export.py | 92 ++++++--- pyproject.toml | 1 + requirements.txt | 15 +- 9 files changed, 413 insertions(+), 335 deletions(-) diff --git a/nt2/__init__.py b/nt2/__init__.py index 8679bb6..684950b 100644 --- a/nt2/__init__.py +++ b/nt2/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.3.0" +__version__ = "1.4.0" import nt2.containers.data as nt2_data diff --git a/nt2/containers/container.py b/nt2/containers/container.py index da0da3b..cafb5e1 100644 --- a/nt2/containers/container.py +++ b/nt2/containers/container.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Callable, Optional from nt2.readers.base import BaseReader @@ -6,11 +6,15 @@ class BaseContainer: """Parent container class for holding any category data.""" + __path: str + __reader: BaseReader + __remap: Optional[dict[str, Callable[[str], str]]] + def __init__( self, path: str, reader: BaseReader, - remap: dict[str, Callable[[str], str]] | None = None, + remap: Optional[dict[str, Callable[[str], str]]] = None, ): """Initializer for the BaseContainer class. @@ -20,7 +24,7 @@ def __init__( The path to the data. reader : BaseReader The reader to be used for reading the data. - remap : dict[str, Callable[[str], str]] | None + remap : Optional[dict[str, Callable[[str], str]]] Remap dictionary to use to remap the data names (coords, fields, etc.). """ @@ -40,6 +44,6 @@ def reader(self) -> BaseReader: return self.__reader @property - def remap(self) -> dict[str, Callable[[str], str]] | None: - """dict[str, Callable[[str], str]]: The coordinate/field remap dictionary.""" + def remap(self) -> Optional[dict[str, Callable[[str], str]]]: + """{ str: (str) -> str } : The coordinate/field remap dictionary.""" return self.__remap diff --git a/nt2/containers/data.py b/nt2/containers/data.py index a4f6e00..328e545 100644 --- a/nt2/containers/data.py +++ b/nt2/containers/data.py @@ -65,6 +65,84 @@ class MoviePlotAccessor(acc_movie.accessor): pass +# Cartesian remapping functions +def remap_fields_cart(name: str) -> str: + name = name[1:] + fieldname = name.split("_")[0] + fieldname = fieldname.replace("0", "t") + fieldname = fieldname.replace("1", "x") + fieldname = fieldname.replace("2", "y") + fieldname = fieldname.replace("3", "z") + suffix = "_".join(name.split("_")[1:]) + return f"{fieldname}{'_' + suffix if suffix != '' else ''}" + + +def remap_coords_cart(name: str) -> str: + return { + "X1": "x", + "X2": "y", + "X3": "z", + }.get(name, name) + + +def remap_prtl_quantities_cart(name: str) -> str: + shortname = name[1:] + return { + "X1": "x", + "X2": "y", + "X3": "z", + "U1": "ux", + "U2": "uy", + "U3": "uz", + "W": "w", + }.get(shortname, shortname) + + +# Spherical remapping functions +def remap_fields_sph(name: str) -> str: + name = name[1:] + fieldname = name.split("_")[0] + fieldname = fieldname.replace("0", "t") + fieldname = fieldname.replace("1", "r") + fieldname = fieldname.replace("2", "th") + fieldname = fieldname.replace("3", "ph") + suffix = "_".join(name.split("_")[1:]) + return f"{fieldname}{'_' + suffix if suffix != '' else ''}" + + +def remap_coords_sph(name: str) -> str: + return { + 
"X1": "r", + "X2": "th", + "X3": "ph", + }.get(name, name) + + +def remap_prtl_quantities_sph(name: str) -> str: + shortname = name[1:] + return { + "X1": "r", + "X2": "th", + "X3": "ph", + "U1": "ur", + "U2": "uth", + "U3": "uph", + "W": "w", + }.get(shortname, shortname) + + +def compactify(lst: list[Any] | KeysView[Any]) -> str: + c = "" + cntr = 0 + for l_ in lst: + if cntr > 5: + c += "\n| " + cntr = 0 + c += f"{l_}, " + cntr += 1 + return c[:-2] + + class Data(Fields, Particles, Spectra): """Main class to manage all the data containers. @@ -135,69 +213,8 @@ def __init__( ) else: if attrs["Coordinates"] in [b"cart", "cart"]: - - def remap_fields(name: str) -> str: - name = name[1:] - fieldname = name.split("_")[0] - fieldname = fieldname.replace("0", "t") - fieldname = fieldname.replace("1", "x") - fieldname = fieldname.replace("2", "y") - fieldname = fieldname.replace("3", "z") - suffix = "_".join(name.split("_")[1:]) - return f"{fieldname}{'_' + suffix if suffix != '' else ''}" - - def remap_coords(name: str) -> str: - return { - "X1": "x", - "X2": "y", - "X3": "z", - }.get(name, name) - - def remap_prtl_quantities(name: str) -> str: - shortname = name[1:] - return { - "X1": "x", - "X2": "y", - "X3": "z", - "U1": "ux", - "U2": "uy", - "U3": "uz", - "W": "w", - }.get(shortname, shortname) - coord_system = CoordinateSystem.XYZ - elif attrs["Coordinates"] in [b"sph", "sph", b"qsph", "qsph"]: - - def remap_fields(name: str) -> str: - name = name[1:] - fieldname = name.split("_")[0] - fieldname = fieldname.replace("0", "t") - fieldname = fieldname.replace("1", "r") - fieldname = fieldname.replace("2", "th") - fieldname = fieldname.replace("3", "ph") - suffix = "_".join(name.split("_")[1:]) - return f"{fieldname}{'_' + suffix if suffix != '' else ''}" - - def remap_coords(name: str) -> str: - return { - "X1": "r", - "X2": "th", - "X3": "ph", - }.get(name, name) - - def remap_prtl_quantities(name: str) -> str: - shortname = name[1:] - return { - "X1": "r", - "X2": "th", - "X3": "ph", - "U1": "ur", - "U2": "uth", - "U3": "uph", - "W": "w", - }.get(shortname, shortname) - coord_system = CoordinateSystem.SPH else: @@ -206,9 +223,21 @@ def remap_prtl_quantities(name: str) -> str: ) if remap is None: remap = { - "coords": remap_coords, - "fields": remap_fields, - "particles": remap_prtl_quantities, + "coords": ( + remap_coords_cart + if coord_system == CoordinateSystem.XYZ + else remap_coords_sph + ), + "fields": ( + remap_fields_cart + if coord_system == CoordinateSystem.XYZ + else remap_fields_sph + ), + "particles": ( + remap_prtl_quantities_cart + if coord_system == CoordinateSystem.XYZ + else remap_prtl_quantities_sph + ), } break @@ -279,17 +308,6 @@ def attrs(self) -> dict[str, Any]: def to_str(self) -> str: """str: String representation of the all the enclosed dataframes.""" - def compactify(lst: list[Any] | KeysView[Any]) -> str: - c = "" - cntr = 0 - for l_ in lst: - if cntr > 5: - c += "\n| " - cntr = 0 - c += f"{l_}, " - cntr += 1 - return c[:-2] - string = "" if self.fields_defined: string += "FieldsDataset:\n" diff --git a/nt2/containers/fields.py b/nt2/containers/fields.py index ebc9914..5f323a3 100644 --- a/nt2/containers/fields.py +++ b/nt2/containers/fields.py @@ -5,26 +5,21 @@ import xarray as xr from nt2.containers.container import BaseContainer -from nt2.readers.base import BaseReader from nt2.utils import Layout class Fields(BaseContainer): """Parent class to manage the fields dataframe.""" - @staticmethod - @dask.delayed - def __read_field(path: str, reader: BaseReader, 
field: str, step: int) -> Any: + def _read_field(self, layout: Layout, field: str, step: int) -> Any: """Reads a field from the data. This is a dask-delayed function used further to build the dataset. Parameters ---------- - path : str - Main path to the data. - reader : BaseReader - Reader to use to read the data. + layout : Layout + Layout of the field. field : str Field to read. step : int @@ -36,7 +31,10 @@ def __read_field(path: str, reader: BaseReader, field: str, step: int) -> Any: Field data. """ - return reader.ReadArrayAtTimestep(path, "fields", field, step) + if layout == Layout.L: + return self.reader.ReadArrayAtTimestep(self.path, "fields", field, step) + else: + return self.reader.ReadArrayAtTimestep(self.path, "fields", field, step).T def __init__( self, @@ -53,7 +51,7 @@ def __init__( super(Fields, self).__init__(**kwargs) if self.reader.DefinesCategory(self.path, "fields"): self.__fields_defined = True - self.__fields = self.__read_fields() + self.__fields = self._read_fields() else: self.__fields_defined = False self.__fields = xr.Dataset() @@ -68,7 +66,7 @@ def fields(self) -> xr.Dataset: """xr.Dataset: The fields dataframe.""" return self.__fields - def __read_fields(self) -> xr.Dataset: + def _read_fields(self) -> xr.Dataset: """Helper function to read the fields dataframe.""" self.reader.VerifySameCategoryNames(self.path, "fields", "f") self.reader.VerifySameFieldShapes(self.path) @@ -98,17 +96,13 @@ def __read_fields(self) -> xr.Dataset: steps = self.reader.ReadPerTimestepVariable(self.path, "fields", "Step", "s") edge_coords = self.reader.ReadEdgeCoordsAtTimestep(self.path, first_step) - if self.remap is None or "coords" not in self.remap: - - def remap(x: str) -> str: - return x - - coord_remap = remap - else: - coord_remap = self.remap["coords"] new_edge_coords = {} for coord in edge_coords.keys(): - assoc_x = coord_remap(coord[:-1]) + assoc_x = ( + coord[:-1] + if (self.remap is None or "coords" not in self.remap) + else self.remap["coords"](coord[:-1]) + ) new_edge_coords[assoc_x + "_min"] = (assoc_x, edge_coords[coord][:-1]) new_edge_coords[assoc_x + "_max"] = (assoc_x, edge_coords[coord][1:]) edge_coords = new_edge_coords @@ -116,30 +110,19 @@ def remap(x: str) -> str: all_dims = {**times, **coords}.keys() all_coords = {**times, **coords, "s": ("t", steps["s"]), **edge_coords} - def remap_name(name: str) -> str: - """ - Remaps the field name if remap is provided - """ - if self.remap is not None and "fields" in self.remap: - return self.remap["fields"](name) - return name - - def get_field(name: str, step: int) -> Any: - """ - Reads a field from the data - """ - if layout == Layout.L: - return Fields.__read_field(self.path, self.reader, name, step) - else: - return Fields.__read_field(self.path, self.reader, name, step).T - return xr.Dataset( { - remap_name(name): xr.DataArray( + ( + remapped_name := ( + self.remap["fields"](name) + if (self.remap is not None and "fields" in self.remap) + else name + ) + ): xr.DataArray( da.stack( [ da.from_delayed( - get_field(name, step), + dask.delayed(self._read_field)(layout, name, step), shape=shape[:: -1 if layout == Layout.R else 1], dtype="float", ) @@ -147,7 +130,7 @@ def get_field(name: str, step: int) -> Any: ], axis=0, ), - name=remap_name(name), + name=remapped_name, dims=all_dims, coords=all_coords, ) diff --git a/nt2/containers/particles.py b/nt2/containers/particles.py index e7d4f0e..782c8e8 100644 --- a/nt2/containers/particles.py +++ b/nt2/containers/particles.py @@ -1,8 +1,9 @@ from typing import Any, 
Callable, List, Optional, Sequence, Tuple, Literal +import numpy.typing as npt from copy import copy +import dask import dask.dataframe as dd -from dask.delayed import delayed import pandas as pd import numpy as np @@ -149,23 +150,50 @@ def _coerce_selector_to_mask( return np.abs(series - s) == np.abs(series - s).min(), ("value", s) +def _attach_columns( + part: pd.DataFrame, + cols_tuple, + read_column, + metadtypes, +) -> pd.DataFrame: + if len(part) == 0: + for c in cols_tuple: + part[c] = np.array([], dtype=metadtypes[c]) + return part + st_val = int(part["st"].iloc[0]) + + arrays = {c: read_column(st_val, c) for c in cols_tuple} + + sel = part["row"].to_numpy() + for c in cols_tuple: + part[c] = np.asarray(arrays[c])[sel] + return part + + class ParticleDataset: + steps: npt.NDArray[np.int64] + times: npt.NDArray[np.float64] + colnames: List[str] + def __init__( self, species: List[int], - read_steps: Callable[[], np.ndarray], - read_times: Callable[[], np.ndarray], - read_column: Callable[[int, str], np.ndarray], - read_colnames: Callable[[int], List[str]], + steps: npt.NDArray[np.int64], + times: npt.NDArray[np.float64], + colnames: List[str], + read_column: Callable[ + [int, str], npt.NDArray[np.float64 | np.int64 | np.float32 | np.int32] + ], fprec: Optional[type] = np.float32, selection: Optional[dict[str, Selection]] = None, - ddf_index: dd.DataFrame | None = None, + ddf_index: Optional[dd.DataFrame] = None, ): self.species = species - self.read_steps = read_steps - self.read_times = read_times + self.steps = steps + self.times = times + self.colnames = colnames + self.read_column = read_column - self.read_colnames = read_colnames self.fprec = fprec self.index_cols = ("id", "sp") self._all_columns_cache: Optional[List[str]] = None @@ -180,9 +208,6 @@ def __init__( "id": Selection("range"), } - self.steps = read_steps() - self.times = read_times() - self._dtypes = { "id": np.int64, "sp": np.int32, @@ -195,6 +220,12 @@ def __init__( "ux": fprec, "uy": fprec, "uz": fprec, + "r": fprec, + "th": fprec, + "ph": fprec, + "ur": fprec, + "uth": fprec, + "uph": fprec, } if ddf_index is not None: @@ -213,7 +244,7 @@ def nbytes(self) -> int: @property def columns(self) -> List[str]: if self._all_columns_cache is None: - self._all_columns_cache = self.read_colnames(self.steps[0]) + self._all_columns_cache = self.colnames return self._all_columns_cache def sel( @@ -261,10 +292,10 @@ def sel( return ParticleDataset( species=self.species, - read_steps=self.read_steps, - read_times=self.read_times, + steps=self.steps, + times=self.times, + colnames=self.colnames, read_column=self.read_column, - read_colnames=self.read_colnames, fprec=self.fprec, selection=new_selection, ddf_index=ddf, @@ -311,27 +342,27 @@ def isel( ) return ParticleDataset( species=self.species, - read_steps=self.read_steps, - read_times=self.read_times, + steps=self.steps, + times=self.times, + colnames=self.colnames, read_column=self.read_column, - read_colnames=self.read_colnames, fprec=self.fprec, selection=new_selection, ddf_index=ddf, ) - def _build_index_ddf(self) -> dd.DataFrame: - def _load_index_partition(st: int, t: float, index_cols: Tuple[str, ...]): - cols = {c: self.read_column(st, c) for c in index_cols} - n = len(next(iter(cols.values()))) - df = pd.DataFrame(cols) - df["st"] = np.asarray(st, dtype=np.int64) - df["t"] = np.asarray(t, dtype=float) - df["row"] = np.arange(n, dtype=np.int64) - return df + def _load_index_partition(self, st: int, t: float, index_cols: Tuple[str, ...]): + cols = {c: 
self.read_column(st, c) for c in index_cols} + n = len(next(iter(cols.values()))) + df = pd.DataFrame(cols) + df["st"] = np.asarray(st, dtype=np.int64) + df["t"] = np.asarray(t, dtype=float) + df["row"] = np.arange(n, dtype=np.int64) + return df + def _build_index_ddf(self) -> dd.DataFrame: delayed_parts = [ - delayed(_load_index_partition)(st, t, self.index_cols) + dask.delayed(self._load_index_partition)(st, t, self.index_cols) for st, t in zip(self.steps, self.times) ] @@ -361,25 +392,16 @@ def load(self, cols: Sequence[str] | None = None) -> pd.DataFrame: } meta = self._ddf_index._meta.assign(**meta_dict) - read_column = self.read_column cols_tuple = tuple(cols) - def _attach_columns(part: pd.DataFrame) -> pd.DataFrame: - if len(part) == 0: - for c in cols_tuple: - part[c] = np.array([], dtype=meta.dtypes[c]) - return part - st_val = int(part["st"].iloc[0]) - - arrays = {c: read_column(st_val, c) for c in cols_tuple} - - sel = part["row"].to_numpy() - for c in cols_tuple: - part[c] = np.asarray(arrays[c])[sel] - return part - return ( - self._ddf_index.map_partitions(_attach_columns, meta=meta) + self._ddf_index.map_partitions( + _attach_columns, + cols_tuple=cols_tuple, + read_column=self.read_column, + metadtypes=meta.dtypes, + meta=meta, + ) .compute() .drop(columns=["row"]) ) @@ -515,6 +537,12 @@ def phase_plot( class Particles(BaseContainer): """Parent class to manage the particles dataframe.""" + __particles_defined: bool + __particles: Optional[ParticleDataset] + quantities: List[str] + sp_with_idx: List[int] + sp_without_idx: List[int] + def __init__(self, **kwargs: Any) -> None: """Initializer for the Particles class. @@ -530,7 +558,55 @@ def __init__(self, **kwargs: Any) -> None: and self.particles_present ): self.__particles_defined = True - self.__particles = self.__read_particles() + + valid_steps = self.nonempty_steps + quantities_ = [ + self.reader.ReadCategoryNamesAtTimestep( + self.path, "particles", "p", step + ) + for step in valid_steps + ] + self.quantities = sorted( + np.unique([q for qtys in quantities_ for q in qtys]) + ) + + unique_quantities = sorted( + list( + set( + str(q).split("_")[0] + for q in self.quantities + if not q.startswith("pIDX") and not q.startswith("pRNK") + ) + ) + ) + all_species = sorted( + list(set([int(str(q).split("_")[1]) for q in self.quantities])) + ) + + self.sp_with_idx = sorted( + [int(q.split("_")[1]) for q in self.quantities if q.startswith("pIDX")] + ) + self.sp_without_idx = sorted( + [sp for sp in all_species if sp not in self.sp_with_idx] + ) + + self.__particles = ParticleDataset( + species=all_species, + steps=np.array(self.reader.GetValidSteps(self.path, "particles")), + times=self.reader.ReadPerTimestepVariable( + self.path, "particles", "Time", "t" + )["t"], + colnames=[ + ( + self.remap["particles"](q) + if (self.remap is not None and "particles" in self.remap) + else q + ) + for q in unique_quantities + ] + + ["id", "sp"], + read_column=self._read_column, + ) else: self.__particles_defined = False self.__particles = None @@ -576,156 +652,110 @@ def particles(self) -> ParticleDataset | None: """ return self.__particles - def __read_particles(self) -> ParticleDataset: - """Helper function to read all particles data.""" - valid_steps = self.nonempty_steps - - quantities_ = [ - self.reader.ReadCategoryNamesAtTimestep(self.path, "particles", "p", step) - for step in valid_steps - ] - quantities = sorted(np.unique([q for qtys in quantities_ for q in qtys])) + def help_particles(self, prepend: str = "") -> str: + return 
self.particles.help(prepend) if self.particles is not None else "" - unique_quantities = sorted( - list( - set( - str(q).split("_")[0] - for q in quantities - if not q.startswith("pIDX") and not q.startswith("pRNK") - ) + def _get_count(self, step: int, sp: int) -> np.int64: + try: + return np.int64( + self.reader.ReadArrayShapeAtTimestep( + self.path, "particles", f"pX1_{sp}", step + )[0] ) - ) - all_species = sorted(list(set([int(str(q).split("_")[1]) for q in quantities]))) + except: + return np.int64(0) - sp_with_idx = sorted( - [int(q.split("_")[1]) for q in quantities if q.startswith("pIDX")] + def _species_has_quantity(self, read_colname: str, step: int, sp: int) -> bool: + return f"{read_colname}_{sp}" in self.reader.ReadCategoryNamesAtTimestep( + self.path, "particles", "p", step ) - sp_without_idx = sorted([sp for sp in all_species if sp not in sp_with_idx]) - - def remap_quantity(name: str) -> str: - """ - Remaps the particle quantity name if remap is provided - """ - if self.remap is not None and "particles" in self.remap: - return self.remap["particles"](name) - return name - - def GetCount(step: int, sp: int) -> np.int64: - try: - return np.int64( - self.reader.ReadArrayShapeAtTimestep( - self.path, "particles", f"pX1_{sp}", step - )[0] - ) - except: - return np.int64(0) - - def ReadSteps() -> np.ndarray: - return np.array(self.reader.GetValidSteps(self.path, "particles")) - - def ReadTimes() -> np.ndarray: - return self.reader.ReadPerTimestepVariable( - self.path, "particles", "Time", "t" - )["t"] - def ReadColnames(step: int) -> list[str]: - return [remap_quantity(q) for q in unique_quantities] + ["id", "sp"] - - def ReadColumn(step: int, colname: str) -> np.ndarray: - read_colname = None - if colname == "id": - idx = np.concat( + def _get_quantity_for_species( + self, + read_colname: str, + step: int, + sp: int, + ) -> npt.NDArray[np.float64 | np.int64]: + if f"{read_colname}_{sp}" in self.quantities: + return self.reader.ReadArrayAtTimestep( + self.path, "particles", f"{read_colname}_{sp}", step + ) + else: + return np.zeros(self._get_count(step, sp)) * np.nan + + def _read_column( + self, step: int, colname: str + ) -> npt.NDArray[np.float64 | np.int64 | np.float32 | np.int32]: + read_colname = None + if colname == "id": + idx = np.concat( + [ + self.reader.ReadArrayAtTimestep( + self.path, "particles", f"pIDX_{sp}", step + ).astype(np.int64) + for sp in self.sp_with_idx + ] + + [ + np.zeros(self._get_count(step, sp), dtype=np.int64) - 100 + for sp in self.sp_without_idx + ] + ) + if ( + len(self.sp_with_idx) > 0 + and f"pRNK_{self.sp_with_idx[0]}" in self.quantities + ): + rnk = np.concat( [ self.reader.ReadArrayAtTimestep( - self.path, "particles", f"pIDX_{sp}", step + self.path, "particles", f"pRNK_{sp}", step ).astype(np.int64) - for sp in sp_with_idx - ] - + [ - np.zeros(GetCount(step, sp), dtype=np.int64) - 100 - for sp in sp_without_idx - ] - ) - if len(sp_with_idx) > 0 and f"pRNK_{sp_with_idx[0]}" in quantities: - rnk = np.concat( - [ - self.reader.ReadArrayAtTimestep( - self.path, "particles", f"pRNK_{sp}", step - ).astype(np.int64) - for sp in sp_with_idx - ] - + [ - np.zeros(GetCount(step, sp), dtype=np.int64) - 100 - for sp in sp_without_idx - ] - ) - return (idx + rnk) * (idx + rnk + 1) // 2 + rnk - else: - return idx - elif colname == "x" or colname == "r": - read_colname = "pX1" - elif colname == "y" or colname == "th": - read_colname = "pX2" - elif colname == "z" or colname == "ph": - read_colname = "pX3" - elif colname == "ux" or colname == "ur": - 
read_colname = "pU1" - elif colname == "uy" or colname == "uth": - read_colname = "pU2" - elif colname == "uz" or colname == "uph": - read_colname = "pU3" - elif colname == "w": - read_colname = "pW" - elif colname == "sp": - return np.concat( - [ - np.zeros(GetCount(step, sp), dtype=np.int32) + sp - for sp in sp_with_idx + for sp in self.sp_with_idx ] + [ - np.zeros(GetCount(step, sp), dtype=np.int32) + sp - for sp in sp_without_idx + np.zeros(self._get_count(step, sp), dtype=np.int64) - 100 + for sp in self.sp_without_idx ] ) + return (idx + rnk) * (idx + rnk + 1) // 2 + rnk else: - read_colname = f"p{colname}" - - def species_has_quantity(sp: int) -> bool: - return ( - f"{read_colname}_{sp}" - in self.reader.ReadCategoryNamesAtTimestep( - self.path, "particles", "p", step - ) - ) - - def get_quantity_for_species(sp: int) -> np.ndarray: - if f"{read_colname}_{sp}" in quantities: - return self.reader.ReadArrayAtTimestep( - self.path, "particles", f"{read_colname}_{sp}", step - ) - else: - return np.zeros(GetCount(step, sp)) * np.nan - + return idx + elif colname == "x" or colname == "r": + read_colname = "pX1" + elif colname == "y" or colname == "th": + read_colname = "pX2" + elif colname == "z" or colname == "ph": + read_colname = "pX3" + elif colname == "ux" or colname == "ur": + read_colname = "pU1" + elif colname == "uy" or colname == "uth": + read_colname = "pU2" + elif colname == "uz" or colname == "uph": + read_colname = "pU3" + elif colname == "w": + read_colname = "pW" + elif colname == "sp": return np.concat( [ - get_quantity_for_species(sp) - for sp in sp_with_idx - if species_has_quantity(sp) + np.zeros(self._get_count(step, sp), dtype=np.int32) + sp + for sp in self.sp_with_idx ] + [ - get_quantity_for_species(sp) - for sp in sp_without_idx - if species_has_quantity(sp) + np.zeros(self._get_count(step, sp), dtype=np.int32) + sp + for sp in self.sp_without_idx ] ) - - return ParticleDataset( - species=all_species, - read_steps=ReadSteps, - read_times=ReadTimes, - read_colnames=ReadColnames, - read_column=ReadColumn, + else: + read_colname = f"p{colname}" + + return np.concat( + [ + self._get_quantity_for_species(read_colname, step, sp) + for sp in self.sp_with_idx + if self._species_has_quantity(read_colname, step, sp) + ] + + [ + self._get_quantity_for_species(read_colname, step, sp) + for sp in self.sp_without_idx + if self._species_has_quantity(read_colname, step, sp) + ] ) - - def help_particles(self, prepend="") -> str: - return self.particles.help(prepend) if self.particles is not None else "" diff --git a/nt2/containers/spectra.py b/nt2/containers/spectra.py index 7288b42..044656d 100644 --- a/nt2/containers/spectra.py +++ b/nt2/containers/spectra.py @@ -13,8 +13,7 @@ class Spectra(BaseContainer): """Parent class to manager the spectra dataframe.""" @staticmethod - @dask.delayed - def __read_spectrum(path: str, reader: BaseReader, spectrum: str, step: int) -> Any: + def read_spectrum(path: str, reader: BaseReader, spectrum: str, step: int) -> Any: """Reads a spectrum from the data. This is a dask-delayed function used further to build the dataset. 
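Note that with this patch the reader function is no longer decorated with `@dask.delayed`; the wrapping is applied at the call site instead (see `dask.delayed(self.read_spectrum)(...)` in the hunk below), so the attribute stored on the class remains a plain, picklable callable. A minimal sketch of this call-site-wrapping pattern, with illustrative names that are not part of nt2:

```python
import dask
import dask.array as da
import numpy as np


def load_step(step: int) -> np.ndarray:
    # stand-in for reader.ReadArrayAtTimestep(...); illustrative only
    return np.full((4, 4), float(step))


# wrap at the call site rather than decorating the definition, so that
# `load_step` itself stays an ordinary, serializable function
stack = da.stack(
    [
        da.from_delayed(dask.delayed(load_step)(s), shape=(4, 4), dtype="float")
        for s in range(3)
    ],
    axis=0,
)

print(stack.sum().compute())  # evaluating the graph triggers the delayed reads
```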
@@ -95,7 +94,7 @@ def remap_name(name: str) -> str: da.stack( [ da.from_delayed( - self.__read_spectrum( + dask.delayed(self.read_spectrum)( path=self.path, reader=self.reader, spectrum=spectrum, diff --git a/nt2/plotters/export.py b/nt2/plotters/export.py index 939639f..f7fb5da 100644 --- a/nt2/plotters/export.py +++ b/nt2/plotters/export.py @@ -1,4 +1,5 @@ from typing import Any, Callable +import matplotlib.pyplot as plt def makeFramesAndMovie( @@ -95,6 +96,20 @@ def makeMovie(**ffmpeg_kwargs: str | int | float) -> bool: return False +def _plot_and_save(ti: int, t: float, fpath: str, plot: Callable, data: Any) -> bool: + try: + if data is None: + plot(t) + else: + plot(t, data) + plt.savefig(f"{fpath}/{ti:05d}.png") + plt.close() + return True + except Exception as e: + print(f"Error: {e}") + return False + + def makeFrames( plot: Callable, times: list[float], @@ -145,36 +160,57 @@ def makeFrames( >>> makeFrames(plot_func, range(100), 'output/', num_cpus=16) """ - - from tqdm import tqdm - import multiprocessing as mp - import matplotlib.pyplot as plt + from loky import get_reusable_executor import os - global plotAndSave - - def plotAndSave(ti: int, t: float, fpath: str) -> bool: - try: - if data is None: - plot(t) - else: - plot(t, data) - plt.savefig(f"{fpath}/{ti:05d}.png") - plt.close() - return True - except Exception as e: - print(f"Error: {e}") - return False - - if num_cpus is None: - num_cpus = mp.cpu_count() + ex = get_reusable_executor(max_workers=num_cpus or (os.cpu_count() or 1)) + return [ + f.result() + for f in [ + ex.submit(_plot_and_save, ti, t, fpath, plot, data) + for ti, t in enumerate(times) + ] + ] - pool = mp.Pool(num_cpus) + # from tqdm import tqdm + # import multiprocessing as mp + # import os + # + # ctx = mp.get_context() + # if num_cpus is None: + # num_cpus = os.cpu_count() or 1 + # + # tasks = [(ti, t, fpath, plot, data) for ti, t in enumerate(times)] + # + # pool = mp.Pool(num_cpus) + # + # with ctx.Pool(processes=num_cpus) as pool: + # results = pool.starmap_async(_plot_and_save, tasks) + # out = results.get() + # + # return list(tqdm(out)) + + # global plotAndSave + # + # def plotAndSave(ti: int, t: float) -> bool: + # import matplotlib.pyplot as plt + # + # try: + # if data is None: + # plot(t) + # else: + # plot(t, data) + # plt.savefig(f"{fpath}/{ti:05d}.png") + # plt.close() + # return True + # except Exception as e: + # print(f"Error: {e}") + # return False # if fpath doesn't exist, create it - if not os.path.exists(fpath): - os.makedirs(fpath) - - tasks = [[ti, t, fpath] for ti, t in enumerate(times)] - results = [pool.apply_async(plotAndSave, t) for t in tasks] - return [result.get() for result in tqdm(results)] + # if not os.path.exists(fpath): + # os.makedirs(fpath) + # + # tasks = [(ti, t) for ti, t in enumerate(times)] + # results = [pool.apply_async(plotAndSave, t) for t in tasks] + # return [result.get() for result in tqdm(results)] diff --git a/pyproject.toml b/pyproject.toml index 2ea2644..f47ebf2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "tqdm", "contourpy", "typer", + "loky", ] requires-python = ">=3.8" authors = [{ name = "Hayk", email = "haykh.astro@gmail.com" }] diff --git a/requirements.txt b/requirements.txt index d8c67e5..23f0614 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ cachetools==6.1.0 certifi==2025.6.15 cffi==1.17.1 charset-normalizer==3.4.2 -click==8.1.8 +click==8.2.1 cloudpickle==3.1.1 colorcet==3.1.0 comm==0.2.2 @@ -51,9 +51,13 @@ json5==0.12.0 
jsonpointer==3.0.0 jsonschema==4.24.0 jsonschema-specifications==2025.4.1 +jupyter_client==8.8.0 +jupyter_core==5.9.1 +jupyterlab_pygments==0.3.0 kiwisolver==1.4.8 linkify-it-py==2.0.3 locket==1.0.0 +loky==3.5.6 lz4==4.4.4 Markdown==3.8.2 markdown-it-py==3.0.0 @@ -70,9 +74,10 @@ nbclient==0.10.2 nbconvert==7.16.6 nbformat==5.10.4 nest-asyncio==1.6.0 +-e git+ssh://git@github.com/entity-toolkit/nt2py.git@5021e028931ac5d0b0a417fc52a5c241e73c87e1#egg=nt2py numpy==2.3.1 overrides==7.7.0 -packaging==24.2 +packaging==25.0 pandas==2.3.0 pandocfilters==1.5.1 panel==1.7.2 @@ -82,7 +87,7 @@ partd==1.4.2 pathspec==0.12.1 pexpect==4.9.0 pillow==11.2.1 -platformdirs==4.3.7 +platformdirs==4.5.0 pluggy==1.6.0 prometheus_client==0.22.1 prompt_toolkit==3.0.51 @@ -106,10 +111,12 @@ referencing==0.36.2 requests==2.32.4 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 +rich==14.3.2 rpds-py==0.25.1 scipy==1.16.0 Send2Trash==1.8.3 setuptools==80.9.0 +shellingham==1.5.4 six==1.17.0 sniffio==1.3.1 sortedcontainers==2.4.0 @@ -123,10 +130,10 @@ tornado==6.5.1 tqdm==4.67.1 traitlets==5.14.3 trove-classifiers==2025.5.9.12 +typer==0.16.0 types-python-dateutil==2.9.0.20250516 types-setuptools==80.9.0.20250529 typing_extensions==4.14.0 -typer==0.16.0 tzdata==2025.2 uc-micro-py==1.0.3 uri-template==1.3.0 From 589126c3d604bf26a8b92f9259da1c65140aa386 Mon Sep 17 00:00:00 2001 From: hayk Date: Tue, 10 Feb 2026 12:19:57 -0500 Subject: [PATCH 02/16] Restore parallel movie frame generation with deterministic container token --- nt2/containers/container.py | 8 +++++++ nt2/plotters/export.py | 2 ++ nt2/tests/test_export.py | 43 +++++++++++++++++++++++++++++++++++++ nt2/tests/test_tokenize.py | 20 +++++++++++++++++ 4 files changed, 73 insertions(+) create mode 100644 nt2/tests/test_export.py create mode 100644 nt2/tests/test_tokenize.py diff --git a/nt2/containers/container.py b/nt2/containers/container.py index cafb5e1..aeb9772 100644 --- a/nt2/containers/container.py +++ b/nt2/containers/container.py @@ -47,3 +47,11 @@ def reader(self) -> BaseReader: def remap(self) -> Optional[dict[str, Callable[[str], str]]]: """{ str: (str) -> str } : The coordinate/field remap dictionary.""" return self.__remap + + def __dask_tokenize__(self) -> tuple[str, str, str]: + """Provide a deterministic Dask token for container instances.""" + return ( + self.__class__.__name__, + self.__path, + self.__reader.format.value, + ) diff --git a/nt2/plotters/export.py b/nt2/plotters/export.py index f7fb5da..bff28ac 100644 --- a/nt2/plotters/export.py +++ b/nt2/plotters/export.py @@ -163,6 +163,8 @@ def makeFrames( from loky import get_reusable_executor import os + os.makedirs(fpath, exist_ok=True) + ex = get_reusable_executor(max_workers=num_cpus or (os.cpu_count() or 1)) return [ f.result() diff --git a/nt2/tests/test_export.py b/nt2/tests/test_export.py new file mode 100644 index 0000000..1dc97da --- /dev/null +++ b/nt2/tests/test_export.py @@ -0,0 +1,43 @@ +from typing import Any + +from nt2.plotters.export import makeFrames + + +class _FakeFuture: + def __init__(self, value: bool): + self._value = value + + def result(self) -> bool: + return self._value + + +class _FakeExecutor: + def __init__(self): + self.calls: list[tuple[int, float, str, Any, Any]] = [] + + def submit(self, func, ti, t, fpath, plot, data): + self.calls.append((ti, t, fpath, plot, data)) + return _FakeFuture(func(ti, t, fpath, plot, data)) + + +def test_make_frames_uses_executor_with_data(tmp_path, monkeypatch): + ex = _FakeExecutor() + + monkeypatch.setattr( + 
"loky.get_reusable_executor", + lambda max_workers=None: ex, + ) + + called: list[float] = [] + + def plot_frame(t, d): + called.append(t) + + times = [0.0, 1.0, 2.0] + result = makeFrames(plot=plot_frame, times=times, fpath=str(tmp_path), data={"ok": 1}) + + assert result == [True, True, True] + assert len(ex.calls) == len(times) + assert called == times + for i in range(len(times)): + assert (tmp_path / f"{i:05d}.png").exists() diff --git a/nt2/tests/test_tokenize.py b/nt2/tests/test_tokenize.py new file mode 100644 index 0000000..7e49f74 --- /dev/null +++ b/nt2/tests/test_tokenize.py @@ -0,0 +1,20 @@ +from dask.base import tokenize + +from nt2.containers.container import BaseContainer +from nt2.readers.base import BaseReader +from nt2.utils import Format + + +class _Reader(BaseReader): + @property + def format(self) -> Format: + return Format.HDF5 + + +def test_base_container_has_deterministic_dask_token(): + container = BaseContainer(path="/tmp/sim", reader=_Reader(), remap=None) + + token1 = tokenize(container) + token2 = tokenize(container) + + assert token1 == token2 From 9b54d4d8c9b7b4f0b570e1292c585afb76800074 Mon Sep 17 00:00:00 2001 From: hayk Date: Tue, 10 Feb 2026 13:16:26 -0500 Subject: [PATCH 03/16] Add tqdm progress bar for frame rendering --- nt2/plotters/export.py | 15 +++++++++++---- nt2/tests/test_export.py | 9 +++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/nt2/plotters/export.py b/nt2/plotters/export.py index bff28ac..af4a217 100644 --- a/nt2/plotters/export.py +++ b/nt2/plotters/export.py @@ -161,17 +161,24 @@ def makeFrames( """ from loky import get_reusable_executor + from tqdm.auto import tqdm import os os.makedirs(fpath, exist_ok=True) ex = get_reusable_executor(max_workers=num_cpus or (os.cpu_count() or 1)) + futures = [ + ex.submit(_plot_and_save, ti, t, fpath, plot, data) + for ti, t in enumerate(times) + ] return [ f.result() - for f in [ - ex.submit(_plot_and_save, ti, t, fpath, plot, data) - for ti, t in enumerate(times) - ] + for f in tqdm( + futures, + total=len(futures), + desc="Rendering frames", + unit="frame", + ) ] # from tqdm import tqdm diff --git a/nt2/tests/test_export.py b/nt2/tests/test_export.py index 1dc97da..f402c53 100644 --- a/nt2/tests/test_export.py +++ b/nt2/tests/test_export.py @@ -28,6 +28,14 @@ def test_make_frames_uses_executor_with_data(tmp_path, monkeypatch): lambda max_workers=None: ex, ) + progress: list[tuple[int, int, str, str]] = [] + + def fake_tqdm(iterable, total, desc, unit): + progress.append((len(list(iterable)), total, desc, unit)) + return iterable + + monkeypatch.setattr("tqdm.auto.tqdm", fake_tqdm) + called: list[float] = [] def plot_frame(t, d): @@ -39,5 +47,6 @@ def plot_frame(t, d): assert result == [True, True, True] assert len(ex.calls) == len(times) assert called == times + assert progress == [(len(times), len(times), "Rendering frames", "frame")] for i in range(len(times)): assert (tmp_path / f"{i:05d}.png").exists() From 286e9cfc0b0ad06d57f82bf133aa25e745399ec0 Mon Sep 17 00:00:00 2001 From: haykh Date: Tue, 10 Feb 2026 13:55:28 -0500 Subject: [PATCH 04/16] readme + tqdm --- README.md | 102 ++++++++++++++++++++++++++++++++--------- nt2/plotters/export.py | 47 +------------------ 2 files changed, 82 insertions(+), 67 deletions(-) diff --git a/README.md b/README.md index 3bbfa6a..bbfbc3f 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ ## nt2.py -Python package for visualization and post-processing of the [`Entity`](https://github.com/entity-toolkit/entity) simulation data. 
For usage, please refer to the [documentation](https://entity-toolkit.github.io/wiki/getting-started/vis/#nt2py). The package is distributed via [`PyPI`](https://pypi.org/project/nt2py/): +Python package for visualization and post-processing of the [`Entity`](https://github.com/entity-toolkit/entity) simulation data. For usage, please refer to the [documentation](https://entity-toolkit.github.io/wiki/content/2-howto/2-vis/#nt2py). The package is distributed via [`PyPI`](https://pypi.org/project/nt2py/): ```sh pip install nt2py @@ -14,15 +14,13 @@ Simply pass the location to the data when initializing the main `Data` object: import nt2 data = nt2.Data("path/to/data") -# example: -# data = nt2.Data("path/to/shock") ``` The data is stored in specialized containers which can be accessed via corresponding attributes: ```python data.fields # < xr.Dataset -data.particles # < dict[int : xr.Dataset] +data.particles # < special object which returns a pd.DataFrame when .load() is called data.spectra # < xr.Dataset ``` @@ -30,9 +28,42 @@ data.spectra # < xr.Dataset > Note, that by default, the `hdf5` support is disabled in `nt2py` (i.e., only `ADIOS2` format is supported). To enable it, install the package as `pip install "nt2py[hdf5]"` instead of simply `pip install nt2py`. -#### Examples +#### Accessing the data -Plot a field (in cartesian space) at a specific time (or output step): +Fields and spectra are stored as lazily loaded `xarray` datasets (a collection of equal-sized arrays with shared axis coordinates). You may access the coordinates in each dimension using `.coords`: + +```python +data.fields.coords +data.spectra.coords +``` + +Individual arrays can be requested by simply using, e.g., `data.fields.Ex` etc. One can also use slicing/selecting via the coordinates, i.e., + +```python +data.fields.sel(t=5, method="nearest") +``` + +accesses all the fields at time `t=5` (using `method="nearest"` means it will take the closest time to value `5`). You may also access by index in each coordinate: + +```python +data.fields.isel(x=-1) +``` + +accesses all the fields in the last position along the `x` coordinate. + +Note that all these operations do not load the actual data into memory; instead, the data is only loaded when explicitly requested (i.e., when plotting or explicitly calling `.values` or `.load()`. + +Particles are stored in a special lazy container which acts very similar to `xarray`; you can still make selections using specific queries. For instance, + +```python +data.particles.sel(sp=[1, 2, 4]).isel(t=-1) +``` + +selects all the particles of species 1, 2, and 4 on the last timestep. The loading of the data itself is done by calling: `.load()` method, which returns a simple `pandas` dataframe. 
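As a concrete end-to-end sketch (the species number and column names below are illustrative: momentum columns are named `ux`/`uy`/`uz` in Cartesian runs and `ur`/`uth`/`uph` in spherical ones):

```python
# lazily select species 1 at the last output step, then materialize it
prtls = data.particles.sel(sp=1).isel(t=-1).load()

# from here on, `prtls` is a regular pandas DataFrame
print(prtls.columns)
print(prtls.ux.mean())
```

No particle data is read from disk until `.load()` is called.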
+ +#### Plotting + +Plot a field (in Cartesian coordinates) at a specific time (or output step): ```python data.fields.Ex.sel(t=10.0, method="nearest").plot() # time ~ 10 @@ -78,29 +109,35 @@ data.fields\ You can also create a movie of a single field quantity (can be custom): ```python -(data.fields.Ex * data.fields.Bx).sel(x=slice(None, 0.2)).movie.plot(name="ExBx", vmin=-0.01, vmax=0.01, cmap="BrBG") +(data.fields.Ex * data.fields.Bx).sel(x=slice(None, 0.2)).movie.plot(name="ExBx") ``` For particles, one can also make 2D phase-space plots: ```python -data.particles[1].sel(t=1.0, method="nearest").particles.phaseplot(x="x", y="uy", xnbins=100, ynbins=200, xlims=(0, 100), cmap="inferno") +data.particles.sel(sp=1).sel(t=1.0, method="nearest").phase_plot( + x_quantity=lambda f: f.x, + y_quantity=lambda f: f.ux, + xy_bins=(np.linspace(0, 60, 100), np.linspace(-2, 2, 100)), +) +``` + +or a spectrum plot: + +```python +data.particles.sel(sp=[1, 2]).sel(t=1.0, method="nearest").spectrum_plot() ``` You may also combine different quantities and plots (e.g., fields & particles) to produce a more customized movie: ```python def plot(t, data): - fig, ax = mpl.pyplot.subplots() + fig, ax = plt.subplots() data.fields.Ex.sel(t=t, method="nearest").sel(x=slice(None, 0.2)).plot( ax=ax, vmin=-0.001, vmax=0.001, cmap="BrBG" ) - for sp in range(1, 3): - ax.scatter( - data.particles[sp].sel(t=t, method="nearest").x, - data.particles[sp].sel(t=t, method="nearest").y, - c="r" if sp == 1 else "b", - ) + prtls = data.particles.sel(t=t, method="nearest").load() + ax.scatter(prtls.x, prtls.y, c="r" if prtls.sp == 1 else "b") ax.set_aspect(1) data.makeMovie(plot) ``` @@ -108,7 +145,7 @@ data.makeMovie(plot) You may also access the movie-making functionality directly in case you want to use it for other things: ```python -import nt2.export as nt2e +import nt2.plotters.export as nt2e def plot(t): ... @@ -127,16 +164,35 @@ nt2e.makeFramesAndMovie( ) ``` -### Dashboard +#### Raw readers -Support for the dask dashboard is still in beta, but you can access it by first launching the dashboard client: +In case you want to access the raw data without using `nt2py`'s `xarray`/`dask` lazy-loading, you may do so by using the readers. For example, for `ADIOS2` output data format: ```python -import nt2 -nt2.Dashboard() +import nt2.readers.adios2 as nt2a + +# define a reader +reader = nt2a.Reader() + +# get all the valid steps for particles +valid_steps = reader.GetValidSteps("path/to/sim", "particles") + +# get all variable names which have prefix "p" at the first valid step +variable_names = reader.ReadCategoryNamesAtTimestep( + "path/to/sim", "particles", "p", valid_steps[0] +) + +# convert the variable set into a list and take the first element +variable = list(variable_names)[0] + +# read the actual array from the file +reader.ReadArrayAtTimestep( + "path/to/sim", "particles", variable, valid_steps[0] +) ``` -This will output the port where the dashboard server is running, e.g., `Dashboard: http://127.0.0.1:8787/status`. Click on it (or enter in your browser) to open the dashboard. +There are many more functions available within the reader. For `hdf5`, you can simply change the import to `nt2.readers.hdf5`, and the rest should remain the same. + ### CLI @@ -170,7 +226,7 @@ nt2 plot myrun/mysimulation --fields "E.*;B.*" --sel "x=slice(-5, None); z=0.5" 1. Lazy loading and parallel processing of the simulation data with [`dask`](https://dask.org/). 2. 
Context-aware data manipulation with [`xarray`](http://xarray.pydata.org/en/stable/). -3. Parallel plotting and movie generation with [`multiprocessing`](https://docs.python.org/3/library/multiprocessing.html) and [`ffmpeg`](https://ffmpeg.org/). +3. Parallel plotting and movie generation with [`loky`](https://pypi.org/project/loky/) and [`ffmpeg`](https://ffmpeg.org/). 4. Command-line interface, the `nt2` command, for quick plotting (both movies and snapshots). ### Testing @@ -188,3 +244,5 @@ There are unit tests included with the code which also require downloading test - [x] Ghost cells support - [x] Usage examples - [ ] Parse the log file with timings +- [x] Raw reader +- [x] 3.14-compatible parallel output diff --git a/nt2/plotters/export.py b/nt2/plotters/export.py index af4a217..3ebb0da 100644 --- a/nt2/plotters/export.py +++ b/nt2/plotters/export.py @@ -161,7 +161,7 @@ def makeFrames( """ from loky import get_reusable_executor - from tqdm.auto import tqdm + from tqdm import tqdm import os os.makedirs(fpath, exist_ok=True) @@ -176,50 +176,7 @@ def makeFrames( for f in tqdm( futures, total=len(futures), - desc="Rendering frames", + desc=f"rendering frames to {fpath}", unit="frame", ) ] - - # from tqdm import tqdm - # import multiprocessing as mp - # import os - # - # ctx = mp.get_context() - # if num_cpus is None: - # num_cpus = os.cpu_count() or 1 - # - # tasks = [(ti, t, fpath, plot, data) for ti, t in enumerate(times)] - # - # pool = mp.Pool(num_cpus) - # - # with ctx.Pool(processes=num_cpus) as pool: - # results = pool.starmap_async(_plot_and_save, tasks) - # out = results.get() - # - # return list(tqdm(out)) - - # global plotAndSave - # - # def plotAndSave(ti: int, t: float) -> bool: - # import matplotlib.pyplot as plt - # - # try: - # if data is None: - # plot(t) - # else: - # plot(t, data) - # plt.savefig(f"{fpath}/{ti:05d}.png") - # plt.close() - # return True - # except Exception as e: - # print(f"Error: {e}") - # return False - - # if fpath doesn't exist, create it - # if not os.path.exists(fpath): - # os.makedirs(fpath) - # - # tasks = [(ti, t) for ti, t in enumerate(times)] - # results = [pool.apply_async(plotAndSave, t) for t in tasks] - # return [result.get() for result in tqdm(results)] From 2e5608175758f71ebc1460f663ce83f922cfc42c Mon Sep 17 00:00:00 2001 From: haykh Date: Tue, 10 Feb 2026 13:57:55 -0500 Subject: [PATCH 05/16] nt2py rm from req --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 23f0614..f2f1c65 100644 --- a/requirements.txt +++ b/requirements.txt @@ -74,7 +74,6 @@ nbclient==0.10.2 nbconvert==7.16.6 nbformat==5.10.4 nest-asyncio==1.6.0 --e git+ssh://git@github.com/entity-toolkit/nt2py.git@5021e028931ac5d0b0a417fc52a5c241e73c87e1#egg=nt2py numpy==2.3.1 overrides==7.7.0 packaging==25.0 From 3328803450b33edbae598ce03862df2384c991f8 Mon Sep 17 00:00:00 2001 From: haykh Date: Tue, 10 Feb 2026 14:02:30 -0500 Subject: [PATCH 06/16] unit tests on all py versions --- .../workflows/{publish.yml => unittests.yml} | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) rename .github/workflows/{publish.yml => unittests.yml} (65%) diff --git a/.github/workflows/publish.yml b/.github/workflows/unittests.yml similarity index 65% rename from .github/workflows/publish.yml rename to .github/workflows/unittests.yml index a6fbaa9..5af8ad5 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/unittests.yml @@ -1,29 +1,37 @@ -name: Python package +name: Unit tests on: [push] jobs: 
build: runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] steps: - uses: actions/checkout@v3 with: lfs: true - - name: Set up Python 3.12 + + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: - python-version: "3.12" + python-version: ${{ matrix.python-version }} + - name: Install dependencies run: | python -m pip install --upgrade pip - pip install pytest + pip install pytest if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Test with `pytest` run: | pytest + - name: Publish package - if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') + if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') && matrix.python-version == '3.14' uses: pypa/gh-action-pypi-publish@release/v1 with: password: ${{ secrets.PYPI_API_TOKEN }} From 4971f49064e697e270f85480a59970883cabf544 Mon Sep 17 00:00:00 2001 From: haykh Date: Tue, 10 Feb 2026 14:14:55 -0500 Subject: [PATCH 07/16] tests fixed --- .github/workflows/unittests.yml | 3 +- nt2/tests/test_export.py | 13 +-- pyproject.toml | 1 + requirements.txt | 159 +++----------------------------- shell.nix | 2 +- 5 files changed, 18 insertions(+), 160 deletions(-) diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml index 5af8ad5..9e43391 100644 --- a/.github/workflows/unittests.yml +++ b/.github/workflows/unittests.yml @@ -23,8 +23,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install pytest - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + pip install .[dev] - name: Test with `pytest` run: | diff --git a/nt2/tests/test_export.py b/nt2/tests/test_export.py index f402c53..244f080 100644 --- a/nt2/tests/test_export.py +++ b/nt2/tests/test_export.py @@ -28,25 +28,18 @@ def test_make_frames_uses_executor_with_data(tmp_path, monkeypatch): lambda max_workers=None: ex, ) - progress: list[tuple[int, int, str, str]] = [] - - def fake_tqdm(iterable, total, desc, unit): - progress.append((len(list(iterable)), total, desc, unit)) - return iterable - - monkeypatch.setattr("tqdm.auto.tqdm", fake_tqdm) - called: list[float] = [] def plot_frame(t, d): called.append(t) times = [0.0, 1.0, 2.0] - result = makeFrames(plot=plot_frame, times=times, fpath=str(tmp_path), data={"ok": 1}) + result = makeFrames( + plot=plot_frame, times=times, fpath=str(tmp_path), data={"ok": 1} + ) assert result == [True, True, True] assert len(ex.calls) == len(times) assert called == times - assert progress == [(len(times), len(times), "Rendering frames", "frame")] for i in range(len(times)): assert (tmp_path / f"{i:05d}.png").exists() diff --git a/pyproject.toml b/pyproject.toml index f47ebf2..a4a3c04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ classifiers = [ [project.optional-dependencies] hdf5 = ["h5py"] +dev = ["black", "pytest"] [project.urls] Repository = "https://github.com/entity-toolkit/nt2py" diff --git a/requirements.txt b/requirements.txt index f2f1c65..d0c4da6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,147 +1,12 @@ -adios2==2.10.1.100042 -anyio==4.9.0 -argon2-cffi==25.1.0 -argon2-cffi-bindings==21.2.0 -arrow==1.3.0 -asttokens==3.0.0 -async-lru==2.0.5 -attrs==25.3.0 -babel==2.17.0 -beautifulsoup4==4.13.4 -black==25.1.0 -bleach==6.2.0 -bokeh==3.7.3 -build==1.2.2.post1 -cachetools==6.1.0 -certifi==2025.6.15 -cffi==1.17.1 -charset-normalizer==3.4.2 -click==8.2.1 -cloudpickle==3.1.1 
-colorcet==3.1.0 -comm==0.2.2 -contourpy==1.3.2 -cycler==0.12.1 -dask==2025.5.1 -debugpy==1.8.14 -decorator==5.2.1 -defusedxml==0.7.1 -distributed==2025.5.1 -executing==2.2.0 -fastjsonschema==2.21.1 -flit_core==3.12.0 -fonttools==4.58.4 -fqdn==1.5.1 -fsspec==2025.5.1 -h11==0.16.0 -h5pickle==0.4.2 -h5py==3.14.0 -hatchling==1.27.0 -holoviews==1.21.0 -httpcore==1.0.9 -httpx==0.28.1 -hvplot==0.11.3 -idna==3.10 -importlib_metadata==8.7.0 -iniconfig==2.1.0 -isoduration==20.11.0 -jedi==0.19.2 -Jinja2==3.1.6 -json5==0.12.0 -jsonpointer==3.0.0 -jsonschema==4.24.0 -jsonschema-specifications==2025.4.1 -jupyter_client==8.8.0 -jupyter_core==5.9.1 -jupyterlab_pygments==0.3.0 -kiwisolver==1.4.8 -linkify-it-py==2.0.3 -locket==1.0.0 -loky==3.5.6 -lz4==4.4.4 -Markdown==3.8.2 -markdown-it-py==3.0.0 -MarkupSafe==3.0.2 -matplotlib==3.10.3 -matplotlib-inline==0.1.7 -mdit-py-plugins==0.4.2 -mdurl==0.1.2 -mistune==3.1.3 -msgpack==1.1.1 -mypy_extensions==1.1.0 -narwhals==1.44.0 -nbclient==0.10.2 -nbconvert==7.16.6 -nbformat==5.10.4 -nest-asyncio==1.6.0 -numpy==2.3.1 -overrides==7.7.0 -packaging==25.0 -pandas==2.3.0 -pandocfilters==1.5.1 -panel==1.7.2 -param==2.2.1 -parso==0.8.4 -partd==1.4.2 -pathspec==0.12.1 -pexpect==4.9.0 -pillow==11.2.1 -platformdirs==4.5.0 -pluggy==1.6.0 -prometheus_client==0.22.1 -prompt_toolkit==3.0.51 -psutil==7.0.0 -ptyprocess==0.7.0 -pure_eval==0.2.3 -pyarrow==20.0.0 -pycparser==2.22 -pyct==0.5.0 -Pygments==2.19.2 -pyparsing==3.2.3 -pyproject_hooks==1.2.0 -pytest==8.4.1 -python-dateutil==2.9.0.post0 -python-json-logger==3.3.0 -pytz==2025.2 -pyviz_comms==3.0.6 -PyYAML==6.0.2 -pyzmq==27.0.0 -referencing==0.36.2 -requests==2.32.4 -rfc3339-validator==0.1.4 -rfc3986-validator==0.1.1 -rich==14.3.2 -rpds-py==0.25.1 -scipy==1.16.0 -Send2Trash==1.8.3 -setuptools==80.9.0 -shellingham==1.5.4 -six==1.17.0 -sniffio==1.3.1 -sortedcontainers==2.4.0 -soupsieve==2.7 -stack-data==0.6.3 -tblib==3.1.0 -terminado==0.18.1 -tinycss2==1.4.0 -toolz==1.0.0 -tornado==6.5.1 -tqdm==4.67.1 -traitlets==5.14.3 -trove-classifiers==2025.5.9.12 -typer==0.16.0 -types-python-dateutil==2.9.0.20250516 -types-setuptools==80.9.0.20250529 -typing_extensions==4.14.0 -tzdata==2025.2 -uc-micro-py==1.0.3 -uri-template==1.3.0 -urllib3==2.5.0 -wcwidth==0.2.13 -webcolors==24.11.1 -webencodings==0.5.1 -websocket-client==1.8.0 -xarray==2025.6.1 -xyzservices==2025.4.0 -zict==3.0.0 -zipp==3.23.0 +types-setuptools +dask +adios2 +bokeh +xarray +numpy +scipy +matplotlib +tqdm +contourpy +typer +loky diff --git a/shell.nix b/shell.nix index 77eed82..146d016 100644 --- a/shell.nix +++ b/shell.nix @@ -23,7 +23,7 @@ pkgs.mkShell { if [ ! -d ".venv" ]; then python3 -m venv .venv source .venv/bin/activate - pip3 install -r requirements.txt + pip3 install ipykernel jupyterlab pip3 install pytest pip3 install -e . 
else From 9321fd3c55fc8d0acf1ea6c5007a53ed90728de4 Mon Sep 17 00:00:00 2001 From: haykh Date: Tue, 10 Feb 2026 14:31:07 -0500 Subject: [PATCH 08/16] types fixed --- .github/workflows/unittests.yml | 2 +- nt2/cli/main.py | 9 +++++---- nt2/containers/data.py | 24 +++++++++++++--------- nt2/containers/particles.py | 36 ++++++++++++++++----------------- nt2/plotters/export.py | 6 +++--- nt2/plotters/inspect.py | 36 ++++++++++++++++----------------- nt2/plotters/movie.py | 4 ++-- nt2/plotters/particles.py | 9 +++++---- nt2/tests/test_containers.py | 3 ++- nt2/utils.py | 3 ++- 10 files changed, 70 insertions(+), 62 deletions(-) diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml index 9e43391..005af95 100644 --- a/.github/workflows/unittests.yml +++ b/.github/workflows/unittests.yml @@ -23,7 +23,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install .[dev] + pip install .[hdf5,dev] - name: Test with `pytest` run: | diff --git a/nt2/cli/main.py b/nt2/cli/main.py index d84c141..f3e6f7a 100644 --- a/nt2/cli/main.py +++ b/nt2/cli/main.py @@ -1,3 +1,4 @@ +from typing import Union import typer, nt2, os from typing_extensions import Annotated import matplotlib.pyplot as plt @@ -22,11 +23,11 @@ def check_path(path: str) -> str: return path -def check_sel(sel: str) -> dict[str, int | float | slice]: +def check_sel(sel: str) -> dict[str, Union[int, float, slice]]: if sel == "": return {} sel_list = sel.strip().split(";") - sel_dict: dict[str, int | float | slice] = {} + sel_dict: dict[str, Union[int, float, slice]] = {} for _, s in enumerate(sel_list): coord, arg = s.strip().split("=", 1) coord = coord.strip() @@ -130,10 +131,10 @@ def plot( if sel != {}: slices = {} sels = {} - slices: dict[str, slice | float | int] = { + slices: dict[str, Union[slice, float, int]] = { k: v for k, v in sel.items() if isinstance(v, slice) } - sels: dict[str, slice | float | int] = { + sels: dict[str, Union[slice, float, int]] = { k: v for k, v in sel.items() if not isinstance(v, slice) } d = d.sel(**sels, method="nearest") diff --git a/nt2/containers/data.py b/nt2/containers/data.py index 328e545..ca4abb4 100644 --- a/nt2/containers/data.py +++ b/nt2/containers/data.py @@ -1,4 +1,4 @@ -from typing import Callable, Any +from typing import Callable, Any, Union, Optional import sys @@ -131,7 +131,7 @@ def remap_prtl_quantities_sph(name: str) -> str: }.get(shortname, shortname) -def compactify(lst: list[Any] | KeysView[Any]) -> str: +def compactify(lst: Union[list[Any], KeysView[Any]]) -> str: c = "" cntr = 0 for l_ in lst: @@ -153,9 +153,9 @@ class Data(Fields, Particles, Spectra): def __init__( self, path: str, - reader: BaseReader | None = None, - remap: dict[str, Callable[[str], str]] | None = None, - coord_system: CoordinateSystem | None = None, + reader: Optional[BaseReader] = None, + remap: Optional[dict[str, Callable[[str], str]]] = None, + coord_system: Optional[CoordinateSystem] = None, ): """Initializer for the Data class. @@ -163,12 +163,12 @@ def __init__( ---------- path : str Main path to the data - reader : BaseReader | None + reader : BaseReader, optional Reader to use to read the data. If None, it will be determined based on the file format. - remap : dict[str, Callable[[str], str]] | None + remap : dict[str, Callable[[str], str]], optional Remap dictionary to use to remap the data names (coords, fields, etc.). - coord_system : CoordinateSystem | None + coord_system : CoordinateSystem, optional Coordinate system of the data. 
If None, it will be determined based on the data attrs (if remap is also None). @@ -251,8 +251,8 @@ def __init__( def makeMovie( self, plot: Callable, - time: list[float] | None = None, - num_cpus: int | None = None, + time: Optional[list[float]] = None, + num_cpus: Optional[int] = None, **movie_kwargs: Any, ) -> bool: f"""Create animation with provided plot function. @@ -263,6 +263,10 @@ def makeMovie( A function that takes a single argument (time in physical units) and produces a plot. time : array_like, optional An array of time values to use for the animation. If not provided, the entire time range will be used. + num_cpus : int, optional + The number of CPUs to use for parallel processing. If None, it will use all available CPUs. + **movie_kwargs : dict + Additional keyword arguments to pass to the movie creation function. Returns ------- diff --git a/nt2/containers/particles.py b/nt2/containers/particles.py index 782c8e8..5822641 100644 --- a/nt2/containers/particles.py +++ b/nt2/containers/particles.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Optional, Sequence, Tuple, Literal +from typing import Any, Callable, List, Optional, Sequence, Tuple, Literal, Union import numpy.typing as npt from copy import copy @@ -13,15 +13,15 @@ from nt2.containers.container import BaseContainer -IntSelector = int | Sequence[int] | slice | Tuple[int, int] -FloatSelector = float | slice | Sequence[float] | Tuple[float, float] +IntSelector = Union[int, Sequence[int], slice, Tuple[int, int]] +FloatSelector = Union[float, slice, Sequence[float], Tuple[float, float]] class Selection: def __init__( self, type: Literal["value", "range", "list"], - value: Optional[int | float | list | tuple] = None, + value: Optional[Union[int, float, list, tuple]] = None, ): self.type = type self.value = value @@ -113,7 +113,7 @@ def __str__(self) -> str: def _coerce_selector_to_mask( - s: IntSelector | FloatSelector, + s: Union[IntSelector, FloatSelector], series: Any, inclusive_tuple: bool = True, method="exact", @@ -182,7 +182,7 @@ def __init__( times: npt.NDArray[np.float64], colnames: List[str], read_column: Callable[ - [int, str], npt.NDArray[np.float64 | np.int64 | np.float32 | np.int32] + [int, str], npt.NDArray[Union[np.float64, np.int64, np.float32, np.int32]] ], fprec: Optional[type] = np.float32, selection: Optional[dict[str, Selection]] = None, @@ -249,7 +249,7 @@ def columns(self) -> List[str]: def sel( self, - t: Optional[IntSelector | FloatSelector] = None, + t: Optional[Union[IntSelector, FloatSelector]] = None, st: Optional[IntSelector] = None, sp: Optional[IntSelector] = None, id: Optional[IntSelector] = None, @@ -381,7 +381,7 @@ def _build_index_ddf(self) -> dd.DataFrame: ddf = dd.from_delayed(delayed_parts, meta=meta) return ddf - def load(self, cols: Sequence[str] | None = None) -> pd.DataFrame: + def load(self, cols: Optional[Sequence[str]] = None) -> pd.DataFrame: if cols is None: cols = self.columns @@ -445,9 +445,9 @@ def __str__(self) -> str: def spectrum_plot( self, - ax: maxes.Axes | None = None, - bins: np.ndarray | None = None, - quantity: Callable[[pd.DataFrame], np.ndarray] | None = None, + ax: Optional[maxes.Axes] = None, + bins: Optional[npt.NDArray] = None, + quantity: Optional[Callable[[pd.DataFrame], npt.NDArray]] = None, ): if ax is None: ax = plt.gca() @@ -480,10 +480,10 @@ def spectrum_plot( def phase_plot( self, - ax: maxes.Axes | None = None, - x_quantity: Callable[[pd.DataFrame], np.ndarray] | None = None, - y_quantity: Callable[[pd.DataFrame], np.ndarray] | None = 
None, - xy_bins: Tuple[np.ndarray, np.ndarray] | None = None, + ax: Optional[maxes.Axes] = None, + x_quantity: Optional[Callable[[pd.DataFrame], np.ndarray]] = None, + y_quantity: Optional[Callable[[pd.DataFrame], np.ndarray]] = None, + xy_bins: Optional[Tuple[npt.NDArray, npt.NDArray]] = None, **kwargs: Any, ): if ax is None: @@ -641,7 +641,7 @@ def particles_defined(self) -> bool: return self.__particles_defined @property - def particles(self) -> ParticleDataset | None: + def particles(self) -> Optional[ParticleDataset]: """Returns the particles data. Returns @@ -675,7 +675,7 @@ def _get_quantity_for_species( read_colname: str, step: int, sp: int, - ) -> npt.NDArray[np.float64 | np.int64]: + ) -> npt.NDArray[Union[np.float64, np.int64]]: if f"{read_colname}_{sp}" in self.quantities: return self.reader.ReadArrayAtTimestep( self.path, "particles", f"{read_colname}_{sp}", step @@ -685,7 +685,7 @@ def _get_quantity_for_species( def _read_column( self, step: int, colname: str - ) -> npt.NDArray[np.float64 | np.int64 | np.float32 | np.int32]: + ) -> npt.NDArray[Union[np.float64, np.int64, np.float32, np.int32]]: read_colname = None if colname == "id": idx = np.concat( diff --git a/nt2/plotters/export.py b/nt2/plotters/export.py index 3ebb0da..53c2427 100644 --- a/nt2/plotters/export.py +++ b/nt2/plotters/export.py @@ -1,4 +1,4 @@ -from typing import Any, Callable +from typing import Any, Callable, Union, Optional import matplotlib.pyplot as plt @@ -36,7 +36,7 @@ def makeFramesAndMovie( raise ValueError("Failed to make frames") -def makeMovie(**ffmpeg_kwargs: str | int | float) -> bool: +def makeMovie(**ffmpeg_kwargs: Union[str, int, float]) -> bool: """ Create a movie from frames using the `ffmpeg` command-line tool. @@ -115,7 +115,7 @@ def makeFrames( times: list[float], fpath: str, data: Any = None, - num_cpus: int | None = None, + num_cpus: Optional[int] = None, ) -> list[bool]: """ Create plot frames from a set of timesteps of the same dataset. diff --git a/nt2/plotters/inspect.py b/nt2/plotters/inspect.py index 4effbd3..457c81e 100644 --- a/nt2/plotters/inspect.py +++ b/nt2/plotters/inspect.py @@ -1,4 +1,4 @@ -from typing import Any, Callable +from typing import Any, Callable, Optional, Union import matplotlib.pyplot as plt import matplotlib.figure as mfigure import xarray as xr @@ -114,24 +114,24 @@ def _fixed_axes_grid_with_cbars( def plot( self, - fig: mfigure.Figure | None = None, - name: str | None = None, - skip_fields: list[str] | None = None, - only_fields: list[str] | None = None, - fig_kwargs: dict[str, Any] | None = None, - plot_kwargs: dict[str, Any] | None = None, - movie_kwargs: dict[str, Any] | None = None, - set_aspect: str | None = "equal", - ) -> mfigure.Figure | bool: + fig: Optional[mfigure.Figure] = None, + name: Optional[str] = None, + skip_fields: Optional[list[str]] = None, + only_fields: Optional[list[str]] = None, + fig_kwargs: Optional[dict[str, Any]] = None, + plot_kwargs: Optional[dict[str, Any]] = None, + movie_kwargs: Optional[dict[str, Any]] = None, + set_aspect: Optional[str] = "equal", + ) -> Union[mfigure.Figure, bool]: """ Plots the overview plot for fields at a given time or step (or as a movie). Kwargs ------ - fig : matplotlib.figure.Figure | None, optional + fig : matplotlib.figure.Figure, optional The figure to plot the data (if None, a new figure is created). Default is None. - name : string | None, optional + name : string, optional Used when saving the frames and the movie. Default is None. 
skip_fields : list, optional @@ -152,7 +152,7 @@ def plot( movie_kwargs : dict, optional Additional keyword arguments for makeMovie. Default is {}. - set_aspect : str | None, optional + set_aspect : str, optional If None, the aspect ratio will not be enforced. Otherwise, this value is passed to `set_aspect` method of the axes. Default is 'equal'. Returns @@ -263,8 +263,8 @@ def _get_fields_to_plot( @staticmethod def _get_fields_minmax( data: xr.Dataset, fields: list[str] - ) -> dict[str, None | tuple[float, float]]: - minmax: dict[str, None | tuple[float, float]] = { + ) -> dict[str, Optional[tuple[float, float]]]: + minmax: dict[str, Optional[tuple[float, float]]] = { "E": None, "B": None, "J": None, @@ -300,7 +300,7 @@ def _get_fields_minmax( def plot_frame_1d( self, data: xr.Dataset, - fig: mfigure.Figure | None, + fig: Optional[mfigure.Figure], skip_fields: list[str], only_fields: list[str], fig_kwargs: dict[str, Any], @@ -375,12 +375,12 @@ def make_plot(ax: plt.Axes, fld: str): def plot_frame_2d( self, data: xr.Dataset, - fig: mfigure.Figure | None, + fig: Optional[mfigure.Figure], skip_fields: list[str], only_fields: list[str], fig_kwargs: dict[str, Any], plot_kwargs: dict[str, Any], - set_aspect: str | None, + set_aspect: Optional[str], ) -> mfigure.Figure: if len(data.dims) != 2: raise ValueError("Pass 2D data; use .sel or .isel to reduce dimension.") diff --git a/nt2/plotters/movie.py b/nt2/plotters/movie.py index fd64e05..a7093d1 100644 --- a/nt2/plotters/movie.py +++ b/nt2/plotters/movie.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Optional from nt2.plotters.export import ( makeFramesAndMovie, ) @@ -75,7 +75,7 @@ def plot_func(ti: int, _: Any) -> None: plt.gca().set_aspect("equal") plt.tight_layout() - num_cpus: int | None = movie_kwargs.pop("num_cpus", None) + num_cpus: Optional[int] = movie_kwargs.pop("num_cpus", None) return makeFramesAndMovie( name=name, data=self._obj, diff --git a/nt2/plotters/particles.py b/nt2/plotters/particles.py index b4bbe1b..285de78 100644 --- a/nt2/plotters/particles.py +++ b/nt2/plotters/particles.py @@ -1,5 +1,6 @@ import xarray as xr import numpy as np +from typing import Optional class ds_accessor: @@ -10,10 +11,10 @@ def phaseplot( self, x: str = "x", y: str = "ux", - xbins: None | np.ndarray = None, - ybins: None | np.ndarray = None, - xlims: None | tuple[float] = None, - ylims: None | tuple[float] = None, + xbins: Optional[np.ndarray] = None, + ybins: Optional[np.ndarray] = None, + xlims: Optional[tuple[float]] = None, + ylims: Optional[tuple[float]] = None, xnbins: int = 100, ynbins: int = 100, **kwargs, diff --git a/nt2/tests/test_containers.py b/nt2/tests/test_containers.py index 8ad2e92..45d0a83 100644 --- a/nt2/tests/test_containers.py +++ b/nt2/tests/test_containers.py @@ -1,4 +1,5 @@ import pytest +from typing import Union from nt2.readers.base import BaseReader from nt2.containers.fields import Fields @@ -17,7 +18,7 @@ def check_shape(shape1, shape2): @pytest.mark.parametrize( "test,field_container", [[test, fc] for test in TESTS for fc in [Data, Fields]] ) -def test_fields(test, field_container: type[Data] | type[Fields]): +def test_fields(test, field_container: Union[type[Data], type[Fields]]): reader: BaseReader = test["reader"]() PATH = test["path"] if test["fields"] == {}: diff --git a/nt2/utils.py b/nt2/utils.py index 700c738..12ea889 100644 --- a/nt2/utils.py +++ b/nt2/utils.py @@ -1,3 +1,4 @@ +from typing import Union from enum import Enum import os import re @@ -63,7 +64,7 @@ def 
DetermineDataFormat(path: str) -> Format: raise ValueError("Could not determine file format.") -def ToHumanReadable(num: float | int, suffix: str = "B") -> str: +def ToHumanReadable(num: Union[float, int], suffix: str = "B") -> str: """Convert a number to a human-readable format with SI prefixes. Parameters From 32c8e0cf8cbc20258ba283f4f68d3a2cf831bfef Mon Sep 17 00:00:00 2001 From: haykh Date: Tue, 10 Feb 2026 14:48:38 -0500 Subject: [PATCH 09/16] typing for 3.8,3.9 fixed --- nt2/cli/main.py | 10 +++---- nt2/containers/container.py | 10 +++---- nt2/containers/data.py | 12 ++++----- nt2/plotters/export.py | 8 +++--- nt2/plotters/inspect.py | 52 ++++++++++++++++++------------------ nt2/plotters/particles.py | 11 ++++---- nt2/plotters/polar.py | 4 +-- nt2/readers/adios2.py | 32 +++++++++++----------- nt2/readers/base.py | 26 +++++++++--------- nt2/readers/hdf5.py | 20 +++++++------- nt2/tests/test_containers.py | 11 +++----- nt2/tests/test_export.py | 6 ++--- 12 files changed, 100 insertions(+), 102 deletions(-) diff --git a/nt2/cli/main.py b/nt2/cli/main.py index f3e6f7a..a7ca8cf 100644 --- a/nt2/cli/main.py +++ b/nt2/cli/main.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, Dict import typer, nt2, os from typing_extensions import Annotated import matplotlib.pyplot as plt @@ -23,11 +23,11 @@ def check_path(path: str) -> str: return path -def check_sel(sel: str) -> dict[str, Union[int, float, slice]]: +def check_sel(sel: str) -> Dict[str, Union[int, float, slice]]: if sel == "": return {} sel_list = sel.strip().split(";") - sel_dict: dict[str, Union[int, float, slice]] = {} + sel_dict: Dict[str, Union[int, float, slice]] = {} for _, s in enumerate(sel_list): coord, arg = s.strip().split("=", 1) coord = coord.strip() @@ -131,10 +131,10 @@ def plot( if sel != {}: slices = {} sels = {} - slices: dict[str, Union[slice, float, int]] = { + slices: Dict[str, Union[slice, float, int]] = { k: v for k, v in sel.items() if isinstance(v, slice) } - sels: dict[str, Union[slice, float, int]] = { + sels: Dict[str, Union[slice, float, int]] = { k: v for k, v in sel.items() if not isinstance(v, slice) } d = d.sel(**sels, method="nearest") diff --git a/nt2/containers/container.py b/nt2/containers/container.py index aeb9772..430051b 100644 --- a/nt2/containers/container.py +++ b/nt2/containers/container.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional +from typing import Callable, Optional, Dict, Tuple from nt2.readers.base import BaseReader @@ -8,13 +8,13 @@ class BaseContainer: __path: str __reader: BaseReader - __remap: Optional[dict[str, Callable[[str], str]]] + __remap: Optional[Dict[str, Callable[[str], str]]] def __init__( self, path: str, reader: BaseReader, - remap: Optional[dict[str, Callable[[str], str]]] = None, + remap: Optional[Dict[str, Callable[[str], str]]] = None, ): """Initializer for the BaseContainer class. 
@@ -44,11 +44,11 @@ def reader(self) -> BaseReader: return self.__reader @property - def remap(self) -> Optional[dict[str, Callable[[str], str]]]: + def remap(self) -> Optional[Dict[str, Callable[[str], str]]]: """{ str: (str) -> str } : The coordinate/field remap dictionary.""" return self.__remap - def __dask_tokenize__(self) -> tuple[str, str, str]: + def __dask_tokenize__(self) -> Tuple[str, str, str]: """Provide a deterministic Dask token for container instances.""" return ( self.__class__.__name__, diff --git a/nt2/containers/data.py b/nt2/containers/data.py index ca4abb4..03e4471 100644 --- a/nt2/containers/data.py +++ b/nt2/containers/data.py @@ -1,4 +1,4 @@ -from typing import Callable, Any, Union, Optional +from typing import Callable, Any, Union, Optional, List, Dict import sys @@ -131,7 +131,7 @@ def remap_prtl_quantities_sph(name: str) -> str: }.get(shortname, shortname) -def compactify(lst: Union[list[Any], KeysView[Any]]) -> str: +def compactify(lst: Union[List[Any], KeysView[Any]]) -> str: c = "" cntr = 0 for l_ in lst: @@ -154,7 +154,7 @@ def __init__( self, path: str, reader: Optional[BaseReader] = None, - remap: Optional[dict[str, Callable[[str], str]]] = None, + remap: Optional[Dict[str, Callable[[str], str]]] = None, coord_system: Optional[CoordinateSystem] = None, ): """Initializer for the Data class. @@ -198,7 +198,7 @@ def __init__( self.__reader = reader # determine the coordinate system and remapping - self.__attrs: dict[str, Any] = {} + self.__attrs: Dict[str, Any] = {} for category in ["fields", "particles", "spectra"]: if self.__reader.DefinesCategory(path, category): valid_steps = self.__reader.GetValidSteps(path, category) @@ -251,7 +251,7 @@ def __init__( def makeMovie( self, plot: Callable, - time: Optional[list[float]] = None, + time: Optional[List[float]] = None, num_cpus: Optional[int] = None, **movie_kwargs: Any, ) -> bool: @@ -305,7 +305,7 @@ def coordinate_system(self) -> CoordinateSystem: return self.__coordinate_system @property - def attrs(self) -> dict[str, Any]: + def attrs(self) -> Dict[str, Any]: """dict[str, Any]: The attributes of the data.""" return self.__attrs diff --git a/nt2/plotters/export.py b/nt2/plotters/export.py index 53c2427..eb7dafd 100644 --- a/nt2/plotters/export.py +++ b/nt2/plotters/export.py @@ -1,11 +1,11 @@ -from typing import Any, Callable, Union, Optional +from typing import Any, Callable, Union, Optional, List import matplotlib.pyplot as plt def makeFramesAndMovie( name: str, plot: Callable, - times: list[float], + times: List[float], data: Any = None, **kwargs: Any, ) -> bool: @@ -112,11 +112,11 @@ def _plot_and_save(ti: int, t: float, fpath: str, plot: Callable, data: Any) -> def makeFrames( plot: Callable, - times: list[float], + times: List[float], fpath: str, data: Any = None, num_cpus: Optional[int] = None, -) -> list[bool]: +) -> List[bool]: """ Create plot frames from a set of timesteps of the same dataset. 
diff --git a/nt2/plotters/inspect.py b/nt2/plotters/inspect.py index 457c81e..7505de1 100644 --- a/nt2/plotters/inspect.py +++ b/nt2/plotters/inspect.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union, List, Dict, Tuple import matplotlib.pyplot as plt import matplotlib.figure as mfigure import xarray as xr @@ -12,7 +12,7 @@ def __init__(self, xarray_obj: xr.Dataset): def __axes_grid( self, - grouped_fields: dict[str, list[str]], + grouped_fields: Dict[str, List[str]], makeplot: Callable, nrows: int, ncols: int, @@ -21,7 +21,7 @@ def __axes_grid( aspect: float, pad: float, **fig_kwargs: Any, - ) -> tuple[mfigure.Figure, list[plt.Axes]]: + ) -> Tuple[mfigure.Figure, List[plt.Axes]]: if aspect > 1: axw = size / aspect axh = size @@ -50,7 +50,7 @@ def __axes_grid( @staticmethod def _fixed_axes_grid_with_cbars( - fields: list[str], + fields: List[str], makeplot: Callable, makecbar: Callable, nrows: int, @@ -61,7 +61,7 @@ def _fixed_axes_grid_with_cbars( pad: float, cbar_w: float, **fig_kwargs: Any, - ) -> tuple[mfigure.Figure, list[plt.Axes]]: + ) -> Tuple[mfigure.Figure, List[plt.Axes]]: from mpl_toolkits.axes_grid1 import Divider, Size if aspect > 1: @@ -86,7 +86,7 @@ def _fixed_axes_grid_with_cbars( v += [Size.Fixed(pad)] divider = Divider(fig, (0, 0, 1, 1), h, v, aspect=False) - axes: list[plt.Axes] = [] + axes: List[plt.Axes] = [] cntr = 0 for i in range(nrows): @@ -116,11 +116,11 @@ def plot( self, fig: Optional[mfigure.Figure] = None, name: Optional[str] = None, - skip_fields: Optional[list[str]] = None, - only_fields: Optional[list[str]] = None, - fig_kwargs: Optional[dict[str, Any]] = None, - plot_kwargs: Optional[dict[str, Any]] = None, - movie_kwargs: Optional[dict[str, Any]] = None, + skip_fields: Optional[List[str]] = None, + only_fields: Optional[List[str]] = None, + fig_kwargs: Optional[Dict[str, Any]] = None, + plot_kwargs: Optional[Dict[str, Any]] = None, + movie_kwargs: Optional[Dict[str, Any]] = None, set_aspect: Optional[str] = "equal", ) -> Union[mfigure.Figure, bool]: """ @@ -236,13 +236,13 @@ def plot_func(ti: int, _): @staticmethod def _get_fields_to_plot( - data: xr.Dataset, skip_fields: list[str], only_fields: list[str] - ) -> list[str]: + data: xr.Dataset, skip_fields: List[str], only_fields: List[str] + ) -> List[str]: import re nfields = len(data.data_vars) if nfields > 0: - keys: list[str] = [str(k) for k in data.keys()] + keys: List[str] = [str(k) for k in data.keys()] if len(only_fields) == 0: fields_to_plot = [ f for f in keys if not any([re.match(sf, f) for sf in skip_fields]) @@ -262,9 +262,9 @@ def _get_fields_to_plot( @staticmethod def _get_fields_minmax( - data: xr.Dataset, fields: list[str] - ) -> dict[str, Optional[tuple[float, float]]]: - minmax: dict[str, Optional[tuple[float, float]]] = { + data: xr.Dataset, fields: List[str] + ) -> Dict[str, Optional[Tuple[float, float]]]: + minmax: Dict[str, Optional[Tuple[float, float]]] = { "E": None, "B": None, "J": None, @@ -301,10 +301,10 @@ def plot_frame_1d( self, data: xr.Dataset, fig: Optional[mfigure.Figure], - skip_fields: list[str], - only_fields: list[str], - fig_kwargs: dict[str, Any], - plot_kwargs: dict[str, Any], + skip_fields: List[str], + only_fields: List[str], + fig_kwargs: Dict[str, Any], + plot_kwargs: Dict[str, Any], ) -> mfigure.Figure: if len(data.dims) != 1: raise ValueError("Pass 1D data; use .sel or .isel to reduce dimension.") @@ -315,7 +315,7 @@ def plot_frame_1d( fields_to_plot = self._get_fields_to_plot(data, 
skip_fields, only_fields) # group fields by their first letter - grouped_fields: dict[str, list[str]] = {} + grouped_fields: Dict[str, List[str]] = {} for f in fields_to_plot: key = f[0] if key not in grouped_fields: @@ -376,10 +376,10 @@ def plot_frame_2d( self, data: xr.Dataset, fig: Optional[mfigure.Figure], - skip_fields: list[str], - only_fields: list[str], - fig_kwargs: dict[str, Any], - plot_kwargs: dict[str, Any], + skip_fields: List[str], + only_fields: List[str], + fig_kwargs: Dict[str, Any], + plot_kwargs: Dict[str, Any], set_aspect: Optional[str], ) -> mfigure.Figure: if len(data.dims) != 2: diff --git a/nt2/plotters/particles.py b/nt2/plotters/particles.py index 285de78..a001e9a 100644 --- a/nt2/plotters/particles.py +++ b/nt2/plotters/particles.py @@ -1,6 +1,7 @@ import xarray as xr import numpy as np -from typing import Optional +import numpy.typing as npt +from typing import Optional, Tuple class ds_accessor: @@ -11,10 +12,10 @@ def phaseplot( self, x: str = "x", y: str = "ux", - xbins: Optional[np.ndarray] = None, - ybins: Optional[np.ndarray] = None, - xlims: Optional[tuple[float]] = None, - ylims: Optional[tuple[float]] = None, + xbins: Optional[npt.NDArray] = None, + ybins: Optional[npt.NDArray] = None, + xlims: Optional[Tuple[float, float]] = None, + ylims: Optional[Tuple[float, float]] = None, xnbins: int = 100, ynbins: int = 100, **kwargs, diff --git a/nt2/plotters/polar.py b/nt2/plotters/polar.py index f25ca95..a9dfbeb 100644 --- a/nt2/plotters/polar.py +++ b/nt2/plotters/polar.py @@ -1,5 +1,5 @@ import numpy as np -from typing import Any +from typing import Any, Dict from nt2.utils import DataIs2DPolar @@ -176,7 +176,7 @@ def fieldlines(self, fr, fth, start_points, **kwargs): fxs = self._obj[fr] * np.sin(ths) + self._obj[fth] * np.cos(ths) fys = self._obj[fr] * np.cos(ths) - self._obj[fth] * np.sin(ths) - props: dict[str, Any] = { + props: Dict[str, Any] = { "method": "nearest", "bounds_error": False, "fill_value": 0, diff --git a/nt2/readers/adios2.py b/nt2/readers/adios2.py index e01d4a0..fcfe17a 100644 --- a/nt2/readers/adios2.py +++ b/nt2/readers/adios2.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, List, Dict, Tuple import sys @@ -41,15 +41,15 @@ def ReadPerTimestepVariable( category: str, varname: str, newname: str, - ) -> dict[str, npt.NDArray[Any]]: - variables: list[float] = [] + ) -> Dict[str, npt.NDArray[Any]]: + variables: List[float] = [] for filename in self.GetValidFiles( path=path, category=category, ): with bp.FileReader(os.path.join(path, category, filename)) as f: - avail: dict[str, Any] = f.available_variables() - vars: list[str] = list(avail.keys()) + avail: Dict[str, Any] = f.available_variables() + vars: List[str] = list(avail.keys()) if varname in vars: var = f.inquire_variable(varname) if var is not None: @@ -67,11 +67,11 @@ def ReadEdgeCoordsAtTimestep( self, path: str, step: int, - ) -> dict[str, Any]: - dct: dict[str, npt.NDArray[Any]] = {} + ) -> Dict[str, Any]: + dct: Dict[str, npt.NDArray[Any]] = {} with bp.FileReader(self.FullPath(path, "fields", step)) as f: - avail: dict[str, Any] = f.available_variables() - vars: list[str] = list(avail.keys()) + avail: Dict[str, Any] = f.available_variables() + vars: List[str] = list(avail.keys()) for var in vars: if var.startswith("X") and var.endswith("e"): var_obj = f.inquire_variable(var) @@ -85,7 +85,7 @@ def ReadAttrsAtTimestep( path: str, category: str, step: int, - ) -> dict[str, Any]: + ) -> Dict[str, Any]: with bp.FileReader(self.FullPath(path, category, step)) as f: 
return {k: f.read_attribute(k) for k in f.available_attributes()} @@ -116,7 +116,7 @@ def ReadCategoryNamesAtTimestep( step: int, ) -> set[str]: with bp.FileReader(self.FullPath(path, category, step)) as f: - keys: list[str] = f.available_variables() + keys: List[str] = f.available_variables() return set( filter( lambda c: c.startswith(prefix), @@ -127,7 +127,7 @@ def ReadCategoryNamesAtTimestep( @override def ReadArrayShapeAtTimestep( self, path: str, category: str, quantity: str, step: int - ) -> tuple[int]: + ) -> Tuple[int]: with bp.FileReader(filename := self.FullPath(path, category, step)) as f: if quantity in f.available_variables(): var = f.inquire_variable(quantity) @@ -145,7 +145,7 @@ def ReadArrayShapeAtTimestep( @override def ReadArrayShapeExplicitlyAtTimestep( self, path: str, category: str, quantity: str, step: int - ) -> tuple[int]: + ) -> Tuple[int]: with bp.FileReader(filename := self.FullPath(path, category, step)) as f: if quantity in f.available_variables(): var = f.inquire_variable(quantity) @@ -163,7 +163,7 @@ def ReadArrayShapeExplicitlyAtTimestep( @override def ReadFieldCoordsAtTimestep( self, path: str, step: int - ) -> dict[str, npt.NDArray[Any]]: + ) -> Dict[str, npt.NDArray[Any]]: with bp.FileReader(filename := self.FullPath(path, "fields", step)) as f: def get_coord(c: str) -> npt.NDArray[Any]: @@ -173,13 +173,13 @@ def get_coord(c: str) -> npt.NDArray[Any]: else: raise ValueError(f"Field {c} is not a group in the {filename}") - keys: list[str] = list(f.available_variables()) + keys: List[str] = list(f.available_variables()) return {c: get_coord(c) for c in keys if re.match(r"^X[1|2|3]$", c)} @override def ReadFieldLayoutAtTimestep(self, path: str, step: int) -> Layout: with bp.FileReader(filename := self.FullPath(path, "fields", step)) as f: - attrs: dict[str, Any] = f.available_attributes() + attrs: Dict[str, Any] = f.available_attributes() keys = list(attrs.keys()) if "LayoutRight" not in keys: raise ValueError(f"LayoutRight attribute not found in the {filename}") diff --git a/nt2/readers/base.py b/nt2/readers/base.py index 1b21a06..40fb3de 100644 --- a/nt2/readers/base.py +++ b/nt2/readers/base.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, List, Tuple, Dict import numpy.typing as npt import os, re, logging @@ -12,7 +12,7 @@ class BaseReader: """ - skipped_files: list[str] + skipped_files: List[str] def __init__(self) -> None: """Initializer for the BaseReader class.""" @@ -52,7 +52,7 @@ def ReadPerTimestepVariable( category: str, varname: str, newname: str, - ) -> dict[str, npt.NDArray[Any]]: + ) -> Dict[str, npt.NDArray[Any]]: """Read a variable at each timestep and return a dictionary with the new name. Parameters @@ -79,7 +79,7 @@ def ReadAttrsAtTimestep( path: str, category: str, step: int, - ) -> dict[str, Any]: + ) -> Dict[str, Any]: """Read the attributes of a given timestep. Parameters @@ -103,7 +103,7 @@ def ReadEdgeCoordsAtTimestep( self, path: str, step: int, - ) -> dict[str, npt.NDArray[Any]]: + ) -> Dict[str, npt.NDArray[Any]]: """Read the coordinates of cell edges at a given timestep. Parameters @@ -204,7 +204,7 @@ def ReadArrayShapeAtTimestep( category: str, quantity: str, step: int, - ) -> tuple[int]: + ) -> Tuple[int]: """Read the shape of an array at a given timestep. Parameters @@ -232,7 +232,7 @@ def ReadArrayShapeExplicitlyAtTimestep( category: str, quantity: str, step: int, - ) -> tuple[int]: + ) -> Tuple[int]: """Read the shape of an array at a given timestep, without relying on metadata. 
Parameters @@ -260,7 +260,7 @@ def ReadFieldCoordsAtTimestep( self, path: str, step: int, - ) -> dict[str, npt.NDArray[Any]]: + ) -> Dict[str, npt.NDArray[Any]]: """Read the coordinates of the fields at a given timestep. Parameters @@ -301,7 +301,7 @@ def ReadFieldLayoutAtTimestep(self, path: str, step: int) -> Layout: # # # # # # # # # # # # # # # # # # # # # # # # @staticmethod - def CategoryFiles(path: str, category: str, format: str) -> list[str]: + def CategoryFiles(path: str, category: str, format: str) -> List[str]: """Get the list of files in a given category and format. Parameters @@ -360,7 +360,7 @@ def GetValidSteps( self, path: str, category: str, - ) -> list[int]: + ) -> List[int]: """Get valid timesteps (sorted) in a given path and category. Parameters @@ -376,7 +376,7 @@ def GetValidSteps( A list of valid timesteps in the given path and category. """ - steps: list[int] = [] + steps: List[int] = [] for filename in BaseReader.CategoryFiles( path=path, category=category, @@ -399,7 +399,7 @@ def GetValidFiles( self, path: str, category: str, - ) -> list[str]: + ) -> List[str]: """Get valid files (sorted by timestep) in a given path and category. Parameters @@ -415,7 +415,7 @@ def GetValidFiles( A list of valid files in the given path and category. """ - files: list[str] = [] + files: List[str] = [] for filename in BaseReader.CategoryFiles( path=path, category=category, diff --git a/nt2/readers/hdf5.py b/nt2/readers/hdf5.py index 4217753..4e8b4a4 100644 --- a/nt2/readers/hdf5.py +++ b/nt2/readers/hdf5.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, TYPE_CHECKING +from typing import Any, TYPE_CHECKING, List, Dict, Tuple import sys @@ -71,8 +71,8 @@ def ReadPerTimestepVariable( category: str, varname: str, newname: str, - ) -> dict[str, npt.NDArray[Any]]: - variables: list[Any] = [] + ) -> Dict[str, npt.NDArray[Any]]: + variables: List[Any] = [] h5 = _require_h5py() for filename in self.GetValidFiles( path=path, @@ -99,7 +99,7 @@ def ReadAttrsAtTimestep( path: str, category: str, step: int, - ) -> dict[str, Any]: + ) -> Dict[str, Any]: h5 = _require_h5py() with h5.File(self.FullPath(path, category, step), "r") as f: return {k: v for k, v in f.attrs.items()} @@ -109,7 +109,7 @@ def ReadEdgeCoordsAtTimestep( self, path: str, step: int, - ) -> dict[str, npt.NDArray[Any]]: + ) -> Dict[str, npt.NDArray[Any]]: h5 = _require_h5py() with h5.File(self.FullPath(path, "fields", step), "r") as f: f0 = Reader.__extract_step0(f) @@ -146,13 +146,13 @@ def ReadCategoryNamesAtTimestep( h5 = _require_h5py() with h5.File(self.FullPath(path, category, step), "r") as f: f0 = Reader.__extract_step0(f) - keys: list[str] = list(f0.keys()) + keys: List[str] = list(f0.keys()) return set(c for c in keys if c.startswith(prefix)) @override def ReadArrayShapeAtTimestep( self, path: str, category: str, quantity: str, step: int - ) -> tuple[int]: + ) -> Tuple[int]: h5 = _require_h5py() with h5.File(filename := self.FullPath(path, category, step), "r") as f: f0 = Reader.__extract_step0(f) @@ -172,7 +172,7 @@ def ReadArrayShapeAtTimestep( @override def ReadArrayShapeExplicitlyAtTimestep( self, path: str, category: str, quantity: str, step: int - ) -> tuple[int]: + ) -> Tuple[int]: h5 = _require_h5py() with h5.File(self.FullPath(path, category, step), "r") as f: f0 = Reader.__extract_step0(f) @@ -192,7 +192,7 @@ def ReadArrayShapeExplicitlyAtTimestep( @override def ReadFieldCoordsAtTimestep( self, path: str, step: int - ) -> dict[str, npt.NDArray[Any]]: + ) -> Dict[str, 
npt.NDArray[Any]]: h5 = _require_h5py() with h5.File(filename := self.FullPath(path, "fields", step), "r") as f: f0 = Reader.__extract_step0(f) @@ -204,7 +204,7 @@ def get_coord(c: str) -> Any: else: raise ValueError(f"Field {c} is not a group in the {filename}") - keys: list[str] = list(f0.keys()) + keys: List[str] = list(f0.keys()) return {c: get_coord(c) for c in keys if re.match(r"^X[1|2|3]$", c)} @override diff --git a/nt2/tests/test_containers.py b/nt2/tests/test_containers.py index 45d0a83..fece1c5 100644 --- a/nt2/tests/test_containers.py +++ b/nt2/tests/test_containers.py @@ -1,5 +1,5 @@ import pytest -from typing import Union +from typing import Union, List from nt2.readers.base import BaseReader from nt2.containers.fields import Fields @@ -24,8 +24,8 @@ def test_fields(test, field_container: Union[type[Data], type[Fields]]): if test["fields"] == {}: return - coords: list[str] = ["x", "y", "z"] - flds: list[str] = ["Ex", "Ey", "Ez", "Bx", "By", "Bz"] + coords: List[str] = ["x", "y", "z"] + flds: List[str] = ["Ex", "Ey", "Ez", "Bx", "By", "Bz"] def coord_remap(Xold: str) -> str: return { @@ -110,14 +110,12 @@ def field_remap(Fold: str): "test,particle_container", [[test, fc] for test in TESTS for fc in [Data, Particles]], ) -def test_particles(test, particle_container: type[Data] | type[Particles]): +def test_particles(test, particle_container: Union[type[Data], type[Particles]]): reader: BaseReader = test["reader"]() PATH = test["path"] if test["particles"] == {}: return - prtl_coords: list[str] = ["x", "y", "z", "ux", "uy", "uz", "w"] - def prtl_remap(Xold: str) -> str: return { "pX1": "x", @@ -130,7 +128,6 @@ def prtl_remap(Xold: str) -> str: }.get(Xold, Xold) if test.get("coords", "cart") != "cart": - prtl_coords = ["r", "th", "ph", "ur", "uth", "uph", "w"] prtl_remap = lambda Xold: { "pX1": "r", "pX2": "th", diff --git a/nt2/tests/test_export.py b/nt2/tests/test_export.py index 244f080..4a69b79 100644 --- a/nt2/tests/test_export.py +++ b/nt2/tests/test_export.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, List, Tuple from nt2.plotters.export import makeFrames @@ -13,7 +13,7 @@ def result(self) -> bool: class _FakeExecutor: def __init__(self): - self.calls: list[tuple[int, float, str, Any, Any]] = [] + self.calls: List[Tuple[int, float, str, Any, Any]] = [] def submit(self, func, ti, t, fpath, plot, data): self.calls.append((ti, t, fpath, plot, data)) @@ -28,7 +28,7 @@ def test_make_frames_uses_executor_with_data(tmp_path, monkeypatch): lambda max_workers=None: ex, ) - called: list[float] = [] + called: List[float] = [] def plot_frame(t, d): called.append(t) From e8690e4ab39d77b83921b086a5c793d4cb6d5bdd Mon Sep 17 00:00:00 2001 From: haykh Date: Tue, 10 Feb 2026 14:54:30 -0500 Subject: [PATCH 10/16] added typing_extensions + added Set --- nt2/readers/adios2.py | 8 ++++---- nt2/readers/base.py | 10 +++++----- nt2/readers/hdf5.py | 8 ++++---- pyproject.toml | 1 + requirements.txt | 12 ------------ 5 files changed, 14 insertions(+), 25 deletions(-) delete mode 100644 requirements.txt diff --git a/nt2/readers/adios2.py b/nt2/readers/adios2.py index fcfe17a..1694081 100644 --- a/nt2/readers/adios2.py +++ b/nt2/readers/adios2.py @@ -1,4 +1,4 @@ -from typing import Any, List, Dict, Tuple +from typing import Any, List, Dict, Tuple, Set import sys @@ -114,7 +114,7 @@ def ReadCategoryNamesAtTimestep( category: str, prefix: str, step: int, - ) -> set[str]: + ) -> Set[str]: with bp.FileReader(self.FullPath(path, category, step)) as f: keys: List[str] = 
f.available_variables() return set( @@ -127,7 +127,7 @@ def ReadCategoryNamesAtTimestep( @override def ReadArrayShapeAtTimestep( self, path: str, category: str, quantity: str, step: int - ) -> Tuple[int]: + ) -> Tuple[int, ...]: with bp.FileReader(filename := self.FullPath(path, category, step)) as f: if quantity in f.available_variables(): var = f.inquire_variable(quantity) @@ -145,7 +145,7 @@ def ReadArrayShapeAtTimestep( @override def ReadArrayShapeExplicitlyAtTimestep( self, path: str, category: str, quantity: str, step: int - ) -> Tuple[int]: + ) -> Tuple[int, ...]: with bp.FileReader(filename := self.FullPath(path, category, step)) as f: if quantity in f.available_variables(): var = f.inquire_variable(quantity) diff --git a/nt2/readers/base.py b/nt2/readers/base.py index 40fb3de..d53280d 100644 --- a/nt2/readers/base.py +++ b/nt2/readers/base.py @@ -1,4 +1,4 @@ -from typing import Any, List, Tuple, Dict +from typing import Any, List, Tuple, Dict, Set import numpy.typing as npt import os, re, logging @@ -155,7 +155,7 @@ def ReadCategoryNamesAtTimestep( category: str, prefix: str, step: int, - ) -> set[str]: + ) -> Set[str]: """Read the names of the variables in a given category and timestep. Parameters @@ -177,7 +177,7 @@ def ReadCategoryNamesAtTimestep( """ raise NotImplementedError("ReadCategoryNamesAtTimestep is not implemented") - def ReadParticleSpeciesAtTimestep(self, path: str, step: int) -> set[int]: + def ReadParticleSpeciesAtTimestep(self, path: str, step: int) -> Set[int]: """Read the particle species indices at a given timestep. Parameters @@ -204,7 +204,7 @@ def ReadArrayShapeAtTimestep( category: str, quantity: str, step: int, - ) -> Tuple[int]: + ) -> Tuple[int, ...]: """Read the shape of an array at a given timestep. Parameters @@ -232,7 +232,7 @@ def ReadArrayShapeExplicitlyAtTimestep( category: str, quantity: str, step: int, - ) -> Tuple[int]: + ) -> Tuple[int, ...]: """Read the shape of an array at a given timestep, without relying on metadata. 
Parameters diff --git a/nt2/readers/hdf5.py b/nt2/readers/hdf5.py index 4e8b4a4..0ba198b 100644 --- a/nt2/readers/hdf5.py +++ b/nt2/readers/hdf5.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, TYPE_CHECKING, List, Dict, Tuple +from typing import Any, TYPE_CHECKING, List, Dict, Tuple, Set import sys @@ -142,7 +142,7 @@ def ReadCategoryNamesAtTimestep( category: str, prefix: str, step: int, - ) -> set[str]: + ) -> Set[str]: h5 = _require_h5py() with h5.File(self.FullPath(path, category, step), "r") as f: f0 = Reader.__extract_step0(f) @@ -152,7 +152,7 @@ def ReadCategoryNamesAtTimestep( @override def ReadArrayShapeAtTimestep( self, path: str, category: str, quantity: str, step: int - ) -> Tuple[int]: + ) -> Tuple[int, ...]: h5 = _require_h5py() with h5.File(filename := self.FullPath(path, category, step), "r") as f: f0 = Reader.__extract_step0(f) @@ -172,7 +172,7 @@ def ReadArrayShapeAtTimestep( @override def ReadArrayShapeExplicitlyAtTimestep( self, path: str, category: str, quantity: str, step: int - ) -> Tuple[int]: + ) -> Tuple[int, ...]: h5 = _require_h5py() with h5.File(self.FullPath(path, category, step), "r") as f: f0 = Reader.__extract_step0(f) diff --git a/pyproject.toml b/pyproject.toml index a4a3c04..58c2852 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ name = "nt2py" dynamic = ["version"] dependencies = [ "types-setuptools", + "typing_extensions", "dask[complete]", "adios2", "bokeh", diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index d0c4da6..0000000 --- a/requirements.txt +++ /dev/null @@ -1,12 +0,0 @@ -types-setuptools -dask -adios2 -bokeh -xarray -numpy -scipy -matplotlib -tqdm -contourpy -typer -loky From 262422602ad486cb944465a000d5e958950dfae3 Mon Sep 17 00:00:00 2001 From: haykh Date: Tue, 10 Feb 2026 14:56:46 -0500 Subject: [PATCH 11/16] dict -> Dict for 3.8 support --- nt2/containers/particles.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/nt2/containers/particles.py b/nt2/containers/particles.py index 5822641..1966a52 100644 --- a/nt2/containers/particles.py +++ b/nt2/containers/particles.py @@ -1,4 +1,15 @@ -from typing import Any, Callable, List, Optional, Sequence, Tuple, Literal, Union +from typing import ( + Any, + Callable, + List, + Optional, + Sequence, + Tuple, + Literal, + Union, + Dict, + Type, +) import numpy.typing as npt from copy import copy @@ -184,8 +195,8 @@ def __init__( read_column: Callable[ [int, str], npt.NDArray[Union[np.float64, np.int64, np.float32, np.int32]] ], - fprec: Optional[type] = np.float32, - selection: Optional[dict[str, Selection]] = None, + fprec: Optional[Type] = np.float32, + selection: Optional[Dict[str, Selection]] = None, ddf_index: Optional[dd.DataFrame] = None, ): self.species = species From 62f6a344b102293d410430311431b5a0b140b895 Mon Sep 17 00:00:00 2001 From: haykh Date: Tue, 10 Feb 2026 14:58:19 -0500 Subject: [PATCH 12/16] list -> List for 3.8 support --- nt2/containers/particles.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nt2/containers/particles.py b/nt2/containers/particles.py index 1966a52..c7fba7c 100644 --- a/nt2/containers/particles.py +++ b/nt2/containers/particles.py @@ -628,7 +628,7 @@ def particles_present(self) -> bool: return len(self.nonempty_steps) > 0 @property - def nonempty_steps(self) -> list[int]: + def nonempty_steps(self) -> List[int]: """list[int]: List of timesteps that contain particles data.""" valid_steps = self.reader.GetValidSteps(self.path, 
"particles") return [ From 18a39fbcd57f87a435f11481c6c001cde4b97ed6 Mon Sep 17 00:00:00 2001 From: haykh Date: Tue, 10 Feb 2026 14:59:44 -0500 Subject: [PATCH 13/16] dict -> Dict for 3.8 support --- nt2/plotters/movie.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nt2/plotters/movie.py b/nt2/plotters/movie.py index a7093d1..9481c4e 100644 --- a/nt2/plotters/movie.py +++ b/nt2/plotters/movie.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Optional, Dict from nt2.plotters.export import ( makeFramesAndMovie, ) @@ -21,8 +21,8 @@ def __init__(self, xarray_obj: xr.DataArray) -> None: def plot( self, name: str, - movie_kwargs: dict[str, Any] = {}, - fig_kwargs: dict[str, Any] = {}, + movie_kwargs: Dict[str, Any] = {}, + fig_kwargs: Dict[str, Any] = {}, aspect_equal: bool = False, **kwargs: Any, ) -> bool: From 7f1bd6233ee0a129698609820e1905595740e18f Mon Sep 17 00:00:00 2001 From: haykh Date: Tue, 10 Feb 2026 15:02:22 -0500 Subject: [PATCH 14/16] rm keysview --- nt2/containers/data.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nt2/containers/data.py b/nt2/containers/data.py index 03e4471..ad96400 100644 --- a/nt2/containers/data.py +++ b/nt2/containers/data.py @@ -10,7 +10,6 @@ def override(method): return method -from collections.abc import KeysView from nt2.utils import ToHumanReadable import xarray as xr @@ -131,7 +130,7 @@ def remap_prtl_quantities_sph(name: str) -> str: }.get(shortname, shortname) -def compactify(lst: Union[List[Any], KeysView[Any]]) -> str: +def compactify(lst: Union[List[Any], Any]) -> str: c = "" cntr = 0 for l_ in lst: From f580e2447b05cf04fa2e7ff8d710b146cb017e80 Mon Sep 17 00:00:00 2001 From: haykh Date: Tue, 10 Feb 2026 15:04:28 -0500 Subject: [PATCH 15/16] type -> Type --- nt2/tests/test_containers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nt2/tests/test_containers.py b/nt2/tests/test_containers.py index fece1c5..9e8f058 100644 --- a/nt2/tests/test_containers.py +++ b/nt2/tests/test_containers.py @@ -1,5 +1,5 @@ import pytest -from typing import Union, List +from typing import Union, List, Type from nt2.readers.base import BaseReader from nt2.containers.fields import Fields @@ -18,7 +18,7 @@ def check_shape(shape1, shape2): @pytest.mark.parametrize( "test,field_container", [[test, fc] for test in TESTS for fc in [Data, Fields]] ) -def test_fields(test, field_container: Union[type[Data], type[Fields]]): +def test_fields(test, field_container: Union[Type[Data], Type[Fields]]): reader: BaseReader = test["reader"]() PATH = test["path"] if test["fields"] == {}: @@ -110,7 +110,7 @@ def field_remap(Fold: str): "test,particle_container", [[test, fc] for test in TESTS for fc in [Data, Particles]], ) -def test_particles(test, particle_container: Union[type[Data], type[Particles]]): +def test_particles(test, particle_container: Union[Type[Data], Type[Particles]]): reader: BaseReader = test["reader"]() PATH = test["path"] if test["particles"] == {}: From 795578fdb4279ea3a3d6a224264e52fef7eea731 Mon Sep 17 00:00:00 2001 From: haykh Date: Tue, 10 Feb 2026 15:09:11 -0500 Subject: [PATCH 16/16] concat -> concatenate --- nt2/containers/particles.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nt2/containers/particles.py b/nt2/containers/particles.py index c7fba7c..a352e4e 100644 --- a/nt2/containers/particles.py +++ b/nt2/containers/particles.py @@ -699,7 +699,7 @@ def _read_column( ) -> npt.NDArray[Union[np.float64, np.int64, np.float32, 
np.int32]]: read_colname = None if colname == "id": - idx = np.concat( + idx = np.concatenate( [ self.reader.ReadArrayAtTimestep( self.path, "particles", f"pIDX_{sp}", step @@ -715,7 +715,7 @@ def _read_column( len(self.sp_with_idx) > 0 and f"pRNK_{self.sp_with_idx[0]}" in self.quantities ): - rnk = np.concat( + rnk = np.concatenate( [ self.reader.ReadArrayAtTimestep( self.path, "particles", f"pRNK_{sp}", step @@ -745,7 +745,7 @@ def _read_column( elif colname == "w": read_colname = "pW" elif colname == "sp": - return np.concat( + return np.concatenate( [ np.zeros(self._get_count(step, sp), dtype=np.int32) + sp for sp in self.sp_with_idx @@ -758,7 +758,7 @@ def _read_column( else: read_colname = f"p{colname}" - return np.concat( + return np.concatenate( [ self._get_quantity_for_species(read_colname, step, sp) for sp in self.sp_with_idx
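
A note on the convention patches 08-16 converge on: PEP 585 builtin generics (dict[str, int], list[float]) need Python 3.9 and PEP 604 unions (X | None) need 3.10, so for 3.8 support every annotation has to come from typing (Dict, List, Tuple, Set, Optional, Union, Type), variable-length tuple shapes are spelled Tuple[int, ...], and @override needs either the typing_extensions backport (added to the dependencies in patch 10) or a no-op shim like the one visible in the patch 14 context. The sketch below illustrates the pattern with made-up names that are not part of nt2; np.concatenate is used because np.concat exists only as a NumPy 2.x alias, which is what patch 16 works around.

    import sys
    from typing import Dict, List, Optional, Tuple, Union

    import numpy as np
    import numpy.typing as npt

    if sys.version_info >= (3, 12):
        from typing import override  # in the stdlib from 3.12 on
    else:
        from typing_extensions import override  # backport for older interpreters

    class Base:
        def steps(self) -> Tuple[int, ...]:  # 3.8-safe variable-length tuple
            return ()

    class Derived(Base):
        @override
        def steps(self) -> Tuple[int, ...]:
            return (0, 1)

    # 3.9/3.10+ only:  dict[str, int | float] | None
    # 3.8-compatible:  Optional[Dict[str, Union[int, float]]]
    def summarize(values: Optional[List[float]] = None) -> Dict[str, Union[int, float]]:
        # np.concatenate is the portable spelling across NumPy 1.x and 2.x
        arr: npt.NDArray[np.float64] = (
            np.concatenate([np.asarray(values, dtype=np.float64)])
            if values
            else np.zeros(0, dtype=np.float64)
        )
        return {"count": int(arr.size), "total": float(arr.sum())}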