Skip to content

Commit 0eb2343

Browse files
committed
exact profile evaluation
1 parent 64fa233 commit 0eb2343

5 files changed

Lines changed: 153 additions & 58 deletions

File tree

README.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -318,9 +318,6 @@ VMEC++:
318318
* `cubic_spline_ip`
319319
* `pedestal`
320320
* `rational`
321-
* `line_segment`
322-
* `line_segment_i`
323-
* `line_segment_ip`
324321
* `nice_quadratic`
325322
* `sum_cossq_s`
326323
* `sum_cossq_sqrts`

src/vmecpp/__init__.py

Lines changed: 58 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@
4747

4848
AuxFType = typing.Annotated[
4949
_ArrayType,
50-
pydantic.BeforeValidator(lambda x: _util.pad_to_target(x, ndfmax, 0.0)),
50+
pydantic.BeforeValidator(lambda x: _util.right_pad(x, ndfmax, 0.0)),
5151
]
5252
AuxSType = typing.Annotated[
5353
_ArrayType,
54-
pydantic.BeforeValidator(lambda x: _util.pad_to_target(x, ndfmax, -1.0)),
54+
pydantic.BeforeValidator(lambda x: _util.right_pad(x, ndfmax, -1.0)),
5555
]
5656

5757
MgridModeType: typing.TypeAlias = typing.Annotated[
@@ -770,7 +770,8 @@ class VmecWOut(BaseModelWithNumpy):
770770
"""Radial derivative of enclosed toroidal magnetic flux ``phi'`` on the full-
771771
grid."""
772772

773-
chipf: jt.Float[np.ndarray, "n_surfaces"]
773+
# Defaulted for backwards compatibility with old wout files
774+
chipf: jt.Float[np.ndarray, "n_surfaces"] = np.array([])
774775
"""Radial derivative of enclosed poloidal magnetic flux ``chi'`` on the full-
775776
grid."""
776777

@@ -1095,11 +1096,13 @@ class VmecWOut(BaseModelWithNumpy):
10951096
"""Volume-averaged magnetic field strength."""
10961097

10971098
# In the C++ WOutFileContents this is called safety_factor.
1098-
q_factor: jt.Float[np.ndarray, "n_surfaces"]
1099+
# Defaulted for backwards compatibility with old wout files.
1100+
q_factor: jt.Float[np.ndarray, "n_surfaces"] = np.array([])
10991101
r"""Safety factor :math:`q = 1/\iota` on the full-grid."""
11001102

11011103
# In the C++ WOutFileContents this is called poloidal_flux.
1102-
chi: jt.Float[np.ndarray, "n_surfaces"]
1104+
# Defaulted for backwards compatibility with old wout files.
1105+
chi: jt.Float[np.ndarray, "n_surfaces"] = np.array([])
11031106
r"""Enclosed poloidal magnetic flux :math:`\chi` on the full-grid."""
11041107

11051108
# In the C++ WOutFileContents this is called spectral_width.
@@ -1417,45 +1420,15 @@ def _from_cpp_wout(cpp_wout: _vmecpp.VmecppWOut) -> VmecWOut:
14171420
attrs["gmnc"] = _pad_and_transpose(cpp_wout.gmnc, attrs["mnmax_nyq"])
14181421

14191422
# These attributes have zero-padding at the end up to a fixed length
1420-
attrs["am"] = np.pad(cpp_wout.am, (0, preset - len(cpp_wout.am)))
1421-
attrs["ac"] = np.pad(cpp_wout.ac, (0, preset - len(cpp_wout.ac)))
1422-
attrs["ai"] = np.pad(cpp_wout.ai, (0, preset - len(cpp_wout.ai)))
1423-
attrs["am_aux_s"] = np.pad(
1424-
cpp_wout.am_aux_s,
1425-
(0, ndfmax - len(cpp_wout.am_aux_s)),
1426-
mode="constant",
1427-
constant_values=-1.0,
1428-
)
1429-
attrs["am_aux_f"] = np.pad(
1430-
cpp_wout.am_aux_f,
1431-
(0, ndfmax - len(cpp_wout.am_aux_f)),
1432-
mode="constant",
1433-
constant_values=0.0,
1434-
)
1435-
attrs["ac_aux_s"] = np.pad(
1436-
cpp_wout.ac_aux_s,
1437-
(0, ndfmax - len(cpp_wout.ac_aux_s)),
1438-
mode="constant",
1439-
constant_values=-1.0,
1440-
)
1441-
attrs["ac_aux_f"] = np.pad(
1442-
cpp_wout.ac_aux_f,
1443-
(0, ndfmax - len(cpp_wout.ac_aux_f)),
1444-
mode="constant",
1445-
constant_values=0.0,
1446-
)
1447-
attrs["ai_aux_s"] = np.pad(
1448-
cpp_wout.ai_aux_s,
1449-
(0, ndfmax - len(cpp_wout.ai_aux_s)),
1450-
mode="constant",
1451-
constant_values=-1.0,
1452-
)
1453-
attrs["ai_aux_f"] = np.pad(
1454-
cpp_wout.ai_aux_f,
1455-
(0, ndfmax - len(cpp_wout.ai_aux_f)),
1456-
mode="constant",
1457-
constant_values=0.0,
1458-
)
1423+
attrs["am"] = _util.right_pad(cpp_wout.am, preset)
1424+
attrs["ac"] = _util.right_pad(cpp_wout.ac, preset)
1425+
attrs["ai"] = _util.right_pad(cpp_wout.ai, preset)
1426+
attrs["am_aux_s"] = _util.right_pad(cpp_wout.am_aux_s, ndfmax, -1.0)
1427+
attrs["am_aux_f"] = _util.right_pad(cpp_wout.am_aux_f, ndfmax)
1428+
attrs["ac_aux_s"] = _util.right_pad(cpp_wout.ac_aux_s, ndfmax, -1.0)
1429+
attrs["ac_aux_f"] = _util.right_pad(cpp_wout.ac_aux_f, ndfmax)
1430+
attrs["ai_aux_s"] = _util.right_pad(cpp_wout.ai_aux_s, ndfmax, -1.0)
1431+
attrs["ai_aux_f"] = _util.right_pad(cpp_wout.ai_aux_f, ndfmax)
14591432

14601433
attrs["restart_reason_timetrace"] = cpp_wout.restart_reasons
14611434

@@ -2020,6 +1993,46 @@ def _pad_and_transpose(
20201993
return stacked
20211994

20221995

1996+
def populate_raw_profile(
1997+
vmec_input: VmecInput,
1998+
field: typing.Literal["pressure", "iota", "current"],
1999+
f: typing.Callable[[np.ndarray], np.ndarray],
2000+
) -> None:
2001+
"""Populate a line segment profile using callable ``f``.
2002+
2003+
The callable is evaluated on all unique ``s`` values required for the
2004+
multi-grid steps (full and half grids). The resulting knots and values are
2005+
stored in the auxiliary arrays for the chosen profile.
2006+
"""
2007+
s_values: set[float] = set()
2008+
for ns in vmec_input.ns_array:
2009+
full_grid = np.linspace(0.0, 1.0, ns)
2010+
half_grid = full_grid - 0.5 * (full_grid[1] - full_grid[0])
2011+
s_values.update(full_grid)
2012+
s_values.update(half_grid)
2013+
knots = np.array(np.sort(np.array(list(s_values))))
2014+
values = np.array(f(knots))
2015+
2016+
if field == "pressure":
2017+
vmec_input.pmass_type = "line_segment"
2018+
vmec_input.am_aux_s = knots
2019+
vmec_input.am_aux_f = values
2020+
vmec_input.am = np.array([])
2021+
elif field == "iota":
2022+
vmec_input.piota_type = "line_segment"
2023+
vmec_input.ai_aux_s = knots
2024+
vmec_input.ai_aux_f = values
2025+
vmec_input.ai = np.array([])
2026+
elif field == "current":
2027+
vmec_input.pcurr_type = "line_segment_i"
2028+
vmec_input.ac_aux_s = knots
2029+
vmec_input.ac_aux_f = values
2030+
vmec_input.ac = np.array([])
2031+
else:
2032+
msg = "field must be one of 'pressure', 'iota', 'current'"
2033+
raise ValueError(msg)
2034+
2035+
20232036
# Ordered this way to ensure run, VmecInput, and VmecOutput are the first three
20242037
# items in the generated documentation.
20252038
__all__ = [
@@ -2032,4 +2045,5 @@ def _pad_and_transpose(
20322045
"Threed1Volumetrics",
20332046
"MakegridParameters",
20342047
"MagneticFieldResponseTable",
2048+
"populate_raw_profile",
20352049
]

src/vmecpp/_util.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,9 @@ def sparse_to_dense_coefficients_implicit(
414414
return sparse_to_dense_coefficients(maybe_sparse_list, mpol, ntor)
415415

416416

417-
def pad_to_target(value, target_length: int, default_value: float):
417+
def pad_to_target(
418+
value: np.ndarray, target_length: int, default_value: float
419+
) -> np.ndarray:
418420
if len(value) <= target_length:
419421
return np.pad(
420422
value,
@@ -427,3 +429,13 @@ def pad_to_target(value, target_length: int, default_value: float):
427429
f"length {target_length} allowed for serialization"
428430
)
429431
raise ValueError(msg)
432+
433+
434+
def right_pad(
435+
arr: np.ndarray, target_length: int, default_value: float = 0.0
436+
) -> np.ndarray:
437+
"""Right-pad an array with zeros to a given length.
438+
439+
If the array is longer than the target length, leave it unchanged.
440+
"""
441+
return pad_to_target(arr, max(len(arr), target_length), default_value)

src/vmecpp/cpp/vmecpp/vmec/radial_profiles/radial_profiles.cc

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -653,24 +653,74 @@ double RadialProfiles::evalRational(const std::vector<double>& coeffs,
653653
return 0.0;
654654
}
655655

656+
// Linear interpolation between closest points and associated knots.
657+
// Clamp the profile if x is outside the range of knots.
656658
double RadialProfiles::evalLineSegment(const std::vector<double>& splineKnots,
657659
const std::vector<double>& splineValues,
658660
double x) {
659-
// TODO(jons): implement `line_segment`
660-
(void)splineKnots;
661-
(void)splineValues;
662-
(void)x;
663-
return 0.0;
661+
const int n = static_cast<int>(splineKnots.size());
662+
if (n < 2 || n != static_cast<int>(splineValues.size())) {
663+
return 0.0;
664+
}
665+
666+
auto it = std::lower_bound(splineKnots.begin(), splineKnots.end(), x);
667+
if (it == splineKnots.end()) {
668+
return splineKnots.back(); // x is out of bounds
669+
}
670+
if (it == splineKnots.begin()) {
671+
return splineValues[0]; // x is below the first knot
672+
}
673+
const double x0 = *it;
674+
const double x1 = *(it + 1);
675+
int ilow = static_cast<int>(std::distance(splineKnots.begin(), it));
676+
const double y0 = splineValues[ilow];
677+
const double y1 = splineValues[ilow + 1];
678+
const double t = (x - x0) / (x1 - x0);
679+
return (1.0 - t) * y0 + t * y1;
664680
}
665681

666682
double RadialProfiles::evalLineSegmentIntegrated(
667683
const std::vector<double>& splineKnots,
668684
const std::vector<double>& splineValues, double x) {
669-
// TODO(jons): implement `line_segment_i`
670-
(void)splineKnots;
671-
(void)splineValues;
672-
(void)x;
673-
return 0.0;
685+
const int n = static_cast<int>(splineKnots.size());
686+
if (n < 2 || n != static_cast<int>(splineValues.size())) {
687+
return 0.0;
688+
}
689+
690+
auto integrate_segment = [](double x0, double x1, double y0, double y1) {
691+
const double m = (y1 - y0) / (x1 - x0);
692+
const double b = y0 - m * x0;
693+
return m * 0.5 * (x1 * x1 - x0 * x0) + b * (x1 - x0);
694+
};
695+
696+
double xi = x;
697+
double result = 0.0;
698+
699+
if (xi <= splineKnots.front()) {
700+
result += integrate_segment(0.0, xi, splineValues[0],
701+
evalLineSegment(splineKnots, splineValues, xi));
702+
return result;
703+
}
704+
705+
int idx = 0;
706+
while (idx < n - 1 && xi > splineKnots[idx + 1]) {
707+
result += integrate_segment(splineKnots[idx], splineKnots[idx + 1],
708+
splineValues[idx], splineValues[idx + 1]);
709+
++idx;
710+
}
711+
712+
const double x0 = splineKnots[idx];
713+
const double x1 = std::min(xi, splineKnots[idx + 1]);
714+
const double y0 = splineValues[idx];
715+
const double y1 = splineValues[idx + 1];
716+
result += integrate_segment(x0, x1, y0, y1);
717+
718+
if (xi > splineKnots.back()) {
719+
result += integrate_segment(splineKnots.back(), xi, splineValues[n - 1],
720+
evalLineSegment(splineKnots, splineValues, xi));
721+
}
722+
723+
return result;
674724
}
675725

676726
double RadialProfiles::evalNiceQuadratic(const std::vector<double>& coeffs,

tests/test_init.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,28 @@ def assert_aux_defaults(wout: vmecpp.VmecWOut):
582582
np.testing.assert_almost_equal(wout.am_aux_f[:2], np.array([2.0, 3.0]))
583583

584584

585+
def test_populate_raw_profile_knots():
586+
vmec_input = vmecpp.VmecInput.default()
587+
vmec_input.ns_array = np.array([5, 9])
588+
589+
def f(s):
590+
return s**2
591+
592+
vmecpp.populate_raw_profile(vmec_input, "pressure", f)
593+
594+
s_values = set()
595+
for ns in vmec_input.ns_array:
596+
delta = 1.0 / float(ns - 1)
597+
s_values.update(i * delta for i in range(ns))
598+
s_values.update((i - 0.5) * delta for i in range(ns))
599+
expected_knots = np.array(sorted(s_values), dtype=float)
600+
601+
n = len(expected_knots)
602+
np.testing.assert_allclose(vmec_input.am_aux_s[:n], expected_knots)
603+
np.testing.assert_allclose(vmec_input.am_aux_f[:n], expected_knots**2)
604+
assert vmec_input.pmass_type == "line_segment"
605+
606+
585607
def test_default_preset():
586608
# Default construction doesn't throw an exception
587609
default_preset = vmecpp.VmecInput.default()

0 commit comments

Comments
 (0)