Skip to content

Commit dec141c

Browse files
committed
Make a movie of the convergence of a VMEC++ run for W7-X.
1 parent adef7d1 commit dec141c

3 files changed

Lines changed: 499 additions & 0 deletions

File tree

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# SPDX-FileCopyrightText: 2024-present Proxima Fusion GmbH <info@proximafusion.com>
2+
#
3+
# SPDX-License-Identifier: MIT
4+
"""Run VMEC++ via the Python API and take snapshots along the run."""
5+
6+
from pathlib import Path
7+
8+
import numpy as np
9+
10+
import vmecpp
11+
12+
# output folder for intermediate state files of VMEC++
13+
cache_folder = Path("/home/jons/results/vmec_w7x/movie_cache")
14+
Path.mkdir(cache_folder, parents=True, exist_ok=True)
15+
16+
input_file = "examples/data/w7x_generic_initial_guess.json"
17+
input = vmecpp.VmecInput.from_file(input_file)
18+
19+
# adjust as needed - we don't vendor the mgrid file, since it is too large
20+
input.mgrid_file = "/home/jons/results/vmec_w7x/mgrid_w7x.nc"
21+
22+
# optional: higher-res for nicer plots
23+
# input.mgrid_file = "/home/jons/results/vmec_w7x/mgrid_w7x_nv72.nc"
24+
# input.ntheta = 100
25+
# input.nzeta = 72
26+
27+
input.return_outputs_even_if_not_converged = True
28+
29+
maximum_iterations = 20000
30+
31+
# number of iterations between saving
32+
# step = 100
33+
step = 10
34+
35+
verbose = False
36+
max_threads = 6
37+
38+
saved_steps = []
39+
40+
currently_allowed_num_iterations = 1
41+
while currently_allowed_num_iterations < maximum_iterations:
42+
# only run up to given limit of number of iterations
43+
input.niter_array[0] = currently_allowed_num_iterations
44+
45+
cpp_indata = input._to_cpp_vmecindatapywrapper()
46+
# start all over again, because flow control flags are not saved (yet) for restarting
47+
output = vmecpp._vmecpp.run(
48+
cpp_indata,
49+
max_threads=max_threads,
50+
verbose=verbose,
51+
)
52+
53+
# print convergence progress
54+
print(
55+
"% 5d | % .3e | % .3e | % .3e"
56+
% (
57+
currently_allowed_num_iterations,
58+
output.wout.fsqr,
59+
output.wout.fsqz,
60+
output.wout.fsql,
61+
)
62+
)
63+
64+
# save outputs for later plotting
65+
output.save(cache_folder / f"vmecpp_w7x_{currently_allowed_num_iterations:04d}.h5")
66+
saved_steps.append(currently_allowed_num_iterations)
67+
68+
# early exis this loop when VMEC is converged
69+
if (
70+
output is not None
71+
and output.wout.fsqr < input.ftol_array[0]
72+
and output.wout.fsqz < input.ftol_array[0]
73+
and output.wout.fsql < input.ftol_array[0]
74+
):
75+
print("converged after", output.wout.maximum_iterations, "iterations")
76+
break
77+
78+
currently_allowed_num_iterations += step
79+
80+
np.savetxt(cache_folder / "saved_steps.dat", saved_steps, fmt="%d")
Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
# SPDX-FileCopyrightText: 2024-present Proxima Fusion GmbH <info@proximafusion.com>
2+
#
3+
# SPDX-License-Identifier: MIT
4+
"""Plot snapshots of VMEC++ taken along the run.
5+
6+
Needs the outputs from `convergence_movie_make_runs.py`.
7+
Call me as follows (using GNU parallel):
8+
`parallel python examples/plot_torus_vtk.py {} ::: path/to/vmecpp_w7x_*.h5`
9+
"""
10+
11+
import sys
12+
from pathlib import Path
13+
14+
import matplotlib.pyplot as plt
15+
import numpy as np
16+
import vtk
17+
18+
import vmecpp
19+
20+
if len(sys.argv) < 2:
21+
print(f"usage: {sys.argv[0]} vmecpp_out.h5")
22+
23+
vmecpp_out_filename = Path(sys.argv[1])
24+
if not Path.exists(vmecpp_out_filename):
25+
raise RuntimeError(
26+
"VMEC++ output file "
27+
+ str(vmecpp_out_filename)
28+
+ " does not exist. Run convergence_movie_make_runs.py to generate it."
29+
)
30+
31+
oq = vmecpp._vmecpp.OutputQuantities.load(vmecpp_out_filename)
32+
33+
ns = oq.wout.ns
34+
nfp = oq.wout.nfp
35+
36+
ntheta1 = 2 * (oq.indata.ntheta // 2)
37+
ntheta3 = ntheta1 // 2 + 1
38+
nzeta = oq.indata.nzeta
39+
40+
jxb_gradp = np.reshape(oq.jxbout.jxb_gradp, [ns, nzeta, ntheta3])
41+
42+
# extend to full poloidal range
43+
jxb_gradp_full = np.zeros([ns, nfp * nzeta, ntheta1])
44+
jxb_gradp_full[:, :nzeta, :ntheta3] = jxb_gradp
45+
jxb_gradp_full[:, :nzeta, ntheta3:] = np.roll(
46+
jxb_gradp[:, :, 1:-1][:, ::-1, ::-1], shift=1, axis=1
47+
)
48+
49+
# extend to full toroidal range
50+
for kp in range(1, nfp):
51+
jxb_gradp_full[:, kp * nzeta : (kp + 1) * nzeta, :] = jxb_gradp_full[:, :nzeta, :]
52+
53+
54+
def create_vtk_lut_from_matplotlib(cmap_name="jet", num_colors=256):
55+
"""Create a vtkLookupTable by sampling a Matplotlib colormap.
56+
57+
:param cmap_name: Name of a Matplotlib colormap (e.g., 'jet', 'viridis', etc.).
58+
:param num_colors: Number of discrete samples in the lookup table.
59+
:return: A vtkLookupTable filled with RGBA entries from the chosen colormap.
60+
"""
61+
# Create an empty lookup table in VTK
62+
lut = vtk.vtkLookupTable()
63+
lut.SetNumberOfTableValues(num_colors)
64+
lut.Build()
65+
66+
# Get the specified colormap from Matplotlib
67+
cmap = plt.get_cmap(cmap_name, num_colors)
68+
69+
# Fill the VTK lookup table by sampling the Matplotlib colormap
70+
for i in range(num_colors):
71+
fraction = i / (num_colors - 1)
72+
r, g, b, a = cmap(fraction)
73+
lut.SetTableValue(i, r, g, b, a)
74+
75+
return lut
76+
77+
78+
# Create a jet-like lookup table from Matplotlib
79+
lut = create_vtk_lut_from_matplotlib("jet", num_colors=256)
80+
81+
# Create the main objects for VTK
82+
renderer = vtk.vtkRenderer()
83+
render_window = vtk.vtkRenderWindow()
84+
render_window.SetOffScreenRendering(True)
85+
render_window.AddRenderer(renderer)
86+
render_window.SetAlphaBitPlanes(True)
87+
88+
# how many toroidal grid indices to "pull back" the flux surfaces for each layer
89+
delta_k = 6
90+
91+
# flux surfaces to render; adjusted for ns=51
92+
all_js = [1, 2, 3, 4, 6, 8, 10, 12, 14, 17, 20, 23, 27, 31, 35, 39, 44, 49]
93+
94+
for i, js in enumerate(all_js):
95+
num_toroidal = nfp * nzeta - i * delta_k
96+
97+
# Arrays to hold coordinates and scalars
98+
points = vtk.vtkPoints()
99+
scalars = vtk.vtkFloatArray()
100+
101+
# Build the torus in a parametric grid:
102+
# theta in [0, 2 pi), phi in [0, 2 pi)
103+
# We'll use modulo for wrap-around.
104+
for idx_theta in range(ntheta1):
105+
theta = 2.0 * np.pi * idx_theta / ntheta1
106+
107+
for idx_phi in range(min(num_toroidal + 1, nfp * nzeta)):
108+
phi = 2.0 * np.pi * idx_phi / (nfp * nzeta)
109+
110+
kernel = oq.wout.xm * theta - oq.wout.xn * phi
111+
112+
r = np.dot(oq.wout.rmnc[js, :], np.cos(kernel))
113+
x = r * np.cos(phi)
114+
y = r * np.sin(phi)
115+
z = np.dot(oq.wout.zmns[js, :], np.sin(kernel))
116+
117+
# Insert the point
118+
points.InsertNextPoint(x, z, y)
119+
120+
# Define the scalar field: MHD force residual
121+
scalar_value = jxb_gradp_full[js, idx_phi, idx_theta]
122+
scalars.InsertNextValue(scalar_value)
123+
124+
# Create a vtkPolyData to store the geometry
125+
poly_data = vtk.vtkPolyData()
126+
poly_data.SetPoints(points)
127+
128+
# We need to define connectivity (which points form each polygon).
129+
# We'll create a mesh of quadrilaterals, each made of two triangles.
130+
# Let idx(i,j) = i*n + j in a 1D index. We'll wrap around with modulo.
131+
cells = vtk.vtkCellArray()
132+
133+
def idx(idx_theta, idx_phi, num_toroidal):
134+
return idx_theta * min(num_toroidal + 1, nfp * nzeta) + idx_phi
135+
136+
for idx_theta in range(ntheta1):
137+
idx_theta_1 = (idx_theta + 1) % ntheta1
138+
for idx_phi in range(num_toroidal):
139+
idx_phi_1 = (idx_phi + 1) % (nfp * nzeta)
140+
141+
# We can form two triangles, or a single quad cell.
142+
# Here we'll make one quad: (i,j), (i+1,j), (i+1,j+1), (i,j+1)
143+
quad = vtk.vtkQuad()
144+
quad.GetPointIds().SetId(0, idx(idx_theta, idx_phi, num_toroidal))
145+
quad.GetPointIds().SetId(1, idx(idx_theta_1, idx_phi, num_toroidal))
146+
quad.GetPointIds().SetId(2, idx(idx_theta_1, idx_phi_1, num_toroidal))
147+
quad.GetPointIds().SetId(3, idx(idx_theta, idx_phi_1, num_toroidal))
148+
cells.InsertNextCell(quad)
149+
150+
poly_data.SetPolys(cells)
151+
152+
# Attach the scalars to the polydata
153+
poly_data.GetPointData().SetScalars(scalars)
154+
155+
# Create a lookup table so we can color the range of scalars
156+
scalar_min = scalars.GetRange()[0]
157+
scalar_max = scalars.GetRange()[1]
158+
159+
# Symmetrize colorbar range
160+
val_max = max(abs(scalar_min), abs(scalar_max))
161+
scalar_min = -val_max
162+
scalar_max = val_max
163+
164+
# Create a mapper for the polydata
165+
mapper = vtk.vtkPolyDataMapper()
166+
mapper.SetInputData(poly_data)
167+
mapper.SetScalarModeToUsePointData()
168+
169+
# The range in the data we want to map. We'll just fake a scalar range here;
170+
# in a real case, you'd have scalar data from the geometry or from a separate array.
171+
# We'll forcibly set a "ScalarRange" to demonstrate usage.
172+
# If you actually have point or cell scalars, set them on the data
173+
# and let the mapper pick it up.
174+
mapper.SetLookupTable(lut)
175+
mapper.SetColorModeToMapScalars()
176+
mapper.SetScalarRange(scalar_min, scalar_max)
177+
178+
# Create an actor using this mapper
179+
actor = vtk.vtkActor()
180+
actor.SetMapper(mapper)
181+
182+
# Add the actor to the renderer
183+
renderer.AddActor(actor)
184+
185+
# Add a scalar bar to show the color mapping
186+
scalar_bar = vtk.vtkScalarBarActor()
187+
scalar_bar.SetTitle("MHD Force Residual")
188+
scalar_bar.SetLookupTable(lut)
189+
scalar_bar.SetNumberOfLabels(5)
190+
scalar_bar.SetVerticalTitleSeparation(20)
191+
192+
# Place the scalar bar as a 2D overlay (no widget/interactor needed)
193+
# By default, it should appear on the right side of the image
194+
renderer.AddActor2D(scalar_bar)
195+
196+
# ------------------------------------------------------------------------------
197+
# Customize Font, Size, and Text Color
198+
# ------------------------------------------------------------------------------
199+
200+
# Access the text properties for title and labels separately
201+
title_text_prop = scalar_bar.GetTitleTextProperty()
202+
label_text_prop = scalar_bar.GetLabelTextProperty()
203+
204+
# Change font family (options include SetFontFamilyToArial, SetFontFamilyToTimes, etc.)
205+
title_text_prop.SetFontFamilyToArial()
206+
label_text_prop.SetFontFamilyToArial()
207+
208+
# Change font sizes (the actual rendered size also depends on the overall image size)
209+
title_text_prop.SetFontSize(24)
210+
label_text_prop.SetFontSize(16)
211+
212+
# Change text color to black (0,0,0)
213+
title_text_prop.SetColor(0, 0, 0)
214+
label_text_prop.SetColor(0, 0, 0)
215+
216+
# If you want to make the title bold or italic, you can do:
217+
title_text_prop.SetBold(False)
218+
title_text_prop.SetItalic(False)
219+
220+
label_text_prop.SetBold(False)
221+
label_text_prop.SetItalic(False)
222+
223+
# Adjust the bar ratio, width, or position if you need to
224+
scalar_bar.SetBarRatio(0.2)
225+
scalar_bar.SetWidth(0.1)
226+
scalar_bar.SetHeight(0.6)
227+
scalar_bar.SetPosition(0.82, 0.2)
228+
229+
# Adjust background color: transparent white
230+
renderer.SetBackground(1.0, 1.0, 1.0) # white
231+
renderer.SetBackgroundAlpha(0.0) # transparent
232+
233+
# Make sure we see our torus nicely
234+
renderer.ResetCamera()
235+
236+
# -------------------------------------------------------------------
237+
# Adjust the camera's elevation and azimuth
238+
# -------------------------------------------------------------------
239+
240+
camera = renderer.GetActiveCamera()
241+
242+
# Elevation is the angle above/below the view plane
243+
camera.Elevation(20.0) # tilt up by 20 degrees
244+
245+
# Azimuth is rotation around the scene (viewing down from above)
246+
camera.Azimuth(130.0) # rotate camera by 130 degrees around the focal point
247+
248+
# Dolly the camera to zoom in (factor > 1 => zoom in; factor < 1 => zoom out)
249+
camera.Dolly(1.5) # e.g., 1.5 means 50% closer to the focal point
250+
251+
# Update the camera's clipping range (otherwise geometry can get clipped)
252+
renderer.ResetCameraClippingRange()
253+
254+
# ------------------------------------------------------------------------------
255+
# Remove default lights and add a new light at the camera position
256+
# ------------------------------------------------------------------------------
257+
258+
# By default, VTK automatically creates lights. We turn that off:
259+
renderer.AutomaticLightCreationOff()
260+
renderer.RemoveAllLights()
261+
262+
# Create a new light
263+
light = vtk.vtkLight()
264+
light.SetLightTypeToSceneLight()
265+
266+
# Position the light where the camera is
267+
light.SetPosition(camera.GetPosition())
268+
269+
# Orient the light toward the camera's focal point
270+
light.SetFocalPoint(camera.GetFocalPoint())
271+
272+
# Optionally adjust light properties: color, intensity, etc.
273+
light.SetColor(1.0, 1.0, 1.0) # white light
274+
light.SetIntensity(1.0) # brightness factor (1.0 = default)
275+
renderer.AddLight(light)
276+
277+
# ------------------------------------------------------------------------------
278+
# Render off-screen and save to an image file
279+
# ------------------------------------------------------------------------------
280+
281+
render_window.SetSize(1920, 1080)
282+
render_window.Render()
283+
284+
# Capture RGBA from the render window
285+
window_to_image = vtk.vtkWindowToImageFilter()
286+
window_to_image.SetInput(render_window)
287+
288+
# Request an RGBA buffer (with alpha)
289+
window_to_image.SetInputBufferTypeToRGBA()
290+
291+
# Make sure we read the back buffer (without any prior composited front buffer)
292+
window_to_image.ReadFrontBufferOff()
293+
window_to_image.Update()
294+
295+
# Write the image to a file
296+
writer = vtk.vtkPNGWriter()
297+
writer.SetFileName(vmecpp_out_filename.with_suffix(".png"))
298+
writer.SetInputConnection(window_to_image.GetOutputPort())
299+
300+
# This tells the PNG writer to preserve the alpha channel
301+
if hasattr(writer, "SetWriteAlphaChannel"):
302+
print("using alpha channel for transparency")
303+
writer.SetWriteAlphaChannel(True)
304+
305+
# Actually write the file.
306+
writer.Write()

0 commit comments

Comments
 (0)