Skip to content

Commit ca58d4e

Browse files
committed
Add a method to load a VmecWout object from a NetCDF file.
1 parent 59519b4 commit ca58d4e

2 files changed

Lines changed: 30 additions & 2 deletions

File tree

src/vmecpp/__init__.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -830,7 +830,31 @@ def _to_cpp_wout(self) -> _vmecpp.WOutFileContents:
830830

831831
return cpp_wout
832832

833-
# TODO(eguiraud): implement from_wout_file
833+
@staticmethod
834+
def from_wout_file(wout_filename: str | Path) -> VmecWOut:
835+
"""Load wout contents in NetCDF format.
836+
837+
This is the format used by Fortran VMEC implementations and the one expected by
838+
SIMSOPT.
839+
"""
840+
with netCDF4.Dataset(wout_filename, "r") as fnc:
841+
fnc.set_auto_mask(False)
842+
attrs = {}
843+
for key in fnc.variables:
844+
if key.endswith("__logical__"):
845+
attrs[key[:-11]] = fnc[key][()] != 0
846+
elif key == "volume_p":
847+
attrs["volume"] = fnc[key][()]
848+
elif key in ["xm", "xn", "xm_nyq", "xn_nyq"]:
849+
attrs[key] = np.array(fnc[key][()], dtype=int)
850+
elif key in ["pmass_type", "piota_type", "pcurr_type", "mgrid_file"]:
851+
attrs[key] = fnc[key][()].tobytes().decode("ascii")
852+
else:
853+
attrs[key] = fnc[key][()]
854+
if "lmns_full" not in fnc.variables:
855+
attrs["lmns_full"] = None
856+
return VmecWOut(**attrs)
857+
raise RuntimeError("Failed to load NetCDF wout file " + str(wout_filename))
834858

835859

836860
class Threed1Volumetrics(pydantic.BaseModel):

tests/test_init.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,16 @@ def cma_output() -> vmecpp.VmecOutput:
8888
return vmec_output
8989

9090

91-
def test_vmecwout_save(cma_output):
91+
def test_vmecwout_io(cma_output):
9292
with tempfile.NamedTemporaryFile() as tmp_file:
9393
cma_output.wout.save(tmp_file.name)
9494

9595
assert Path(tmp_file.name).exists()
9696

97+
# check that from_wout_file can load the file as well
98+
loaded_wout = vmecpp.VmecWOut.from_wout_file(tmp_file.name)
99+
assert loaded_wout is not None
100+
97101
test_dataset = netCDF4.Dataset(tmp_file.name, "r")
98102

99103
expected_dataset = netCDF4.Dataset(TEST_DATA_DIR / "wout_cma.nc", "r")

0 commit comments

Comments
 (0)