diff --git a/PyPIC3D/__main__.py b/PyPIC3D/__main__.py index 775213c..e51bd44 100644 --- a/PyPIC3D/__main__.py +++ b/PyPIC3D/__main__.py @@ -9,24 +9,13 @@ import time import jax from jax import block_until_ready +from jax import lax import jax.numpy as jnp from tqdm import tqdm #from memory_profiler import profile # Importing relevant libraries -from PyPIC3D.diagnostics.plotting import ( - write_particles_phase_space, write_data -) - -from PyPIC3D.diagnostics.openPMD import ( - write_openpmd_particles, write_openpmd_fields -) - -from PyPIC3D.diagnostics.vtk import ( - plot_field_slice_vtk, plot_vectorfield_slice_vtk, plot_vtk_particles -) - from PyPIC3D.utils import ( dump_parameters_to_toml, load_config_file, compute_energy, setup_pmd_files @@ -36,13 +25,6 @@ initialize_simulation ) -from PyPIC3D.diagnostics.fluid_quantities import ( - compute_mass_density -) - -from PyPIC3D.rho import compute_rho - - # Importing functions from the PyPIC3D package ############################################################################################################ @@ -53,7 +35,9 @@ def run_PyPIC3D(config_file): loop, particles, fields, world, simulation_parameters, constants, plotting_parameters, plasma_parameters, solver, electrostatic, verbose, GPUs, Nt, curl_func, J_func, relativistic = initialize_simulation(config_file) # initialize the simulation - jit_loop = jax.jit(loop, static_argnames=('curl_func', 'J_func', 'solver', 'relativistic')) + # `loop` is already jitted in `PyPIC3D.evolve`; avoid double-jitting. + jit_loop = loop + step_impl = getattr(loop, "__wrapped__", None) dt = world['dt'] output_dir = simulation_parameters['output_dir'] @@ -69,109 +53,163 @@ def run_PyPIC3D(config_file): # Compute the energy of the system initial_energy = e_energy + b_energy + kinetic_energy - if plotting_parameters['plot_openpmd_fields']: setup_pmd_files( os.path.join(output_dir, "data"), "fields", ".h5") - if plotting_parameters['plot_openpmd_particles']: setup_pmd_files( os.path.join(output_dir, "data"), "particles", ".h5") + if plotting_parameters.get('plot_openpmd_fields', False) or plotting_parameters.get('plot_openpmd_particles', False): + setup_pmd_files(os.path.join(output_dir, "data"), "fields", ".h5") + setup_pmd_files(os.path.join(output_dir, "data"), "particles", ".h5") # setup the openPMD files if needed ############################################################################################################ ###################################################### SIMULATION LOOP ##################################### - - for t in tqdm(range(Nt)): - - # plot the data - if t % plotting_parameters['plotting_interval'] == 0: - - plot_num = t // plotting_parameters['plotting_interval'] - # determine the plot number - - E, B, J, rho, *rest = fields - # unpack the fields - - e_energy, b_energy, kinetic_energy = compute_energy(particles, E, B, world, constants) - # Compute the energy of the system - write_data(f"{output_dir}/data/total_energy.txt", t * dt, e_energy + b_energy + kinetic_energy) - write_data(f"{output_dir}/data/energy_error.txt", t * dt, abs( initial_energy - (e_energy + b_energy + kinetic_energy)) / max(initial_energy, 1e-10)) - write_data(f"{output_dir}/data/electric_field_energy.txt", t * dt, e_energy) - write_data(f"{output_dir}/data/magnetic_field_energy.txt", t * dt, b_energy) - write_data(f"{output_dir}/data/kinetic_energy.txt", t * dt, kinetic_energy) - # Write the total energy to a file - total_momentum = sum(particle_species.momentum() for particle_species in particles) - # Total momentum of the particles - write_data(f"{output_dir}/data/total_momentum.txt", t * dt, total_momentum) - # Write the total momentum to a file - - # for species in particles: - # write_data(f"{output_dir}/data/{species.name}_kinetic_energy.txt", t * dt, species.kinetic_energy()) - - - if plotting_parameters['plot_phasespace']: - write_particles_phase_space(particles, t, output_dir) - - - - if plotting_parameters['plot_vtk_scalars']: - rho = compute_rho(particles, rho, world, constants) - # calculate the charge density based on the particle positions - mass_density = compute_mass_density(particles, rho, world) - # calculate the mass density based on the particle positions - - fields_mag = [rho[:,world['Ny']//2,:], mass_density[:,world['Ny']//2,:]] - plot_field_slice_vtk(fields_mag, scalar_field_names, 1, vertex_grid, t, "scalar_field", output_dir, world) - # Plot the scalar fields in VTK format - - - if plotting_parameters['plot_vtk_vectors']: - vector_field_slices = [ [E[0][:,world['Ny']//2,:], E[1][:,world['Ny']//2,:], E[2][:,world['Ny']//2,:]], - [B[0][:,world['Ny']//2,:], B[1][:,world['Ny']//2,:], B[2][:,world['Ny']//2,:]], - [J[0][:,world['Ny']//2,:], J[1][:,world['Ny']//2,:], J[2][:,world['Ny']//2,:]]] - plot_vectorfield_slice_vtk(vector_field_slices, vector_field_names, 1, vertex_grid, t, 'vector_field', output_dir, world) - # Plot the vector fields in VTK format - - if plotting_parameters['plot_vtk_particles']: - plot_vtk_particles(particles, plot_num, output_dir) - # Plot the particles in VTK format - - if plotting_parameters['plot_openpmd_particles']: - write_openpmd_particles(particles, world, constants, os.path.join(output_dir, "data"), plot_num, t, "particles", ".h5") - # Write the particles in openPMD format - - if plotting_parameters['plot_openpmd_fields']: - write_openpmd_fields(fields, world, os.path.join(output_dir, "data"), plot_num, t, "fields", ".h5") - # Write the fields in openPMD format - - fields = (E, B, J, rho, *rest) - # repack the fields - - particles, fields = jit_loop( - particles, - fields, - world, - constants, - curl_func, - J_func, - solver, - relativistic=relativistic, - ) - # time loop to update the particles and fields + do_plotting = bool(plotting_parameters.get("plotting", True)) + plotting_interval = int(plotting_parameters.get("plotting_interval", 0) or 0) + + def do_diagnostics(t, particles, fields): + if not (do_plotting and plotting_interval > 0 and (t % plotting_interval == 0)): + return fields + + from PyPIC3D.diagnostics.plotting import write_data, write_particles_phase_space + + plot_num = t // plotting_interval + # determine the plot number + + E, B, J, rho, *rest = fields + # unpack the fields + + e_energy, b_energy, kinetic_energy = compute_energy(particles, E, B, world, constants) + # Compute the energy of the system + write_data(f"{output_dir}/data/total_energy.txt", t * dt, e_energy + b_energy + kinetic_energy) + write_data(f"{output_dir}/data/energy_error.txt", t * dt, abs( initial_energy - (e_energy + b_energy + kinetic_energy)) / max(initial_energy, 1e-10)) + write_data(f"{output_dir}/data/electric_field_energy.txt", t * dt, e_energy) + write_data(f"{output_dir}/data/magnetic_field_energy.txt", t * dt, b_energy) + write_data(f"{output_dir}/data/kinetic_energy.txt", t * dt, kinetic_energy) + # Write the total energy to a file + total_momentum = sum(particle_species.momentum() for particle_species in particles) + # Total momentum of the particles + write_data(f"{output_dir}/data/total_momentum.txt", t * dt, total_momentum) + # Write the total momentum to a file + + if plotting_parameters['plot_phasespace']: + write_particles_phase_space(particles, t, output_dir) + + if plotting_parameters['plot_vtk_scalars']: + from PyPIC3D.diagnostics.vtk import plot_field_slice_vtk + from PyPIC3D.diagnostics.fluid_quantities import compute_mass_density + from PyPIC3D.rho import compute_rho + + rho = compute_rho(particles, rho, world, constants) + # calculate the charge density based on the particle positions + mass_density = compute_mass_density(particles, rho, world) + # calculate the mass density based on the particle positions + + fields_mag = [rho[:,world['Ny']//2,:], mass_density[:,world['Ny']//2,:]] + plot_field_slice_vtk(fields_mag, scalar_field_names, 1, vertex_grid, t, "scalar_field", output_dir, world) + # Plot the scalar fields in VTK format + + if plotting_parameters['plot_vtk_vectors']: + from PyPIC3D.diagnostics.vtk import plot_vectorfield_slice_vtk + vector_field_slices = [ [E[0][:,world['Ny']//2,:], E[1][:,world['Ny']//2,:], E[2][:,world['Ny']//2,:]], + [B[0][:,world['Ny']//2,:], B[1][:,world['Ny']//2,:], B[2][:,world['Ny']//2,:]], + [J[0][:,world['Ny']//2,:], J[1][:,world['Ny']//2,:], J[2][:,world['Ny']//2,:]]] + plot_vectorfield_slice_vtk(vector_field_slices, vector_field_names, 1, vertex_grid, t, 'vector_field', output_dir, world) + # Plot the vector fields in VTK format + + if plotting_parameters['plot_vtk_particles']: + from PyPIC3D.diagnostics.vtk import plot_vtk_particles + plot_vtk_particles(particles, plot_num, output_dir) + # Plot the particles in VTK format + + if plotting_parameters['plot_openpmd_particles']: + from PyPIC3D.diagnostics.openPMD import write_openpmd_particles + write_openpmd_particles(particles, world, constants, os.path.join(output_dir, "data"), plot_num, t, "particles", ".h5") + # Write the particles in openPMD format + + if plotting_parameters['plot_openpmd_fields']: + from PyPIC3D.diagnostics.openPMD import write_openpmd_fields + write_openpmd_fields(fields, world, os.path.join(output_dir, "data"), plot_num, t, "fields", ".h5") + # Write the fields in openPMD format + + return (E, B, J, rho, *rest) + + use_scan = bool(simulation_parameters.get("use_scan", False)) + scan_chunk = int(simulation_parameters.get("scan_chunk", 256) or 256) + + if use_scan: + def _scan_chunk(particles, fields, *, n_steps: int): + def body(carry, _): + p, f = carry + p, f = jit_loop( + p, + f, + world, + constants, + curl_func, + J_func, + solver, + relativistic=relativistic, + ) + return (p, f), None + + (p, f), _ = lax.scan(body, (particles, fields), xs=None, length=n_steps) + return p, f + + scan_chunk_jit = jax.jit(_scan_chunk, donate_argnums=(0, 1), static_argnames=("n_steps",)) + + chunk = scan_chunk + if do_plotting and plotting_interval > 0 and (plotting_interval % scan_chunk): + chunk = plotting_interval + + pbar = tqdm(total=Nt) + t = 0 + while t < Nt: + fields = do_diagnostics(t, particles, fields) + remaining = Nt - t + n_steps = chunk if remaining >= chunk else remaining + if do_plotting and plotting_interval > 0: + n_steps = min(n_steps, plotting_interval - (t % plotting_interval)) + + particles, fields = scan_chunk_jit(particles, fields, n_steps=n_steps) + t += n_steps + pbar.update(n_steps) + + pbar.close() + + else: + for t in tqdm(range(Nt)): + + fields = do_diagnostics(t, particles, fields) + + particles, fields = jit_loop( + particles, + fields, + world, + constants, + curl_func, + J_func, + solver, + relativistic=relativistic, + ) + # time loop to update the particles and fields return Nt, plotting_parameters, simulation_parameters, plasma_parameters, constants, particles, fields, world def main(): ###################### JAX SETTINGS ######################################################################## - jax.config.update("jax_enable_x64", True) - # set Jax to use 64 bit precision + toml_file = load_config_file() + # load the configuration file + + enable_x64 = bool(toml_file.get("simulation_parameters", {}).get("enable_x64", True)) + jax.config.update("jax_enable_x64", enable_x64) + # set Jax to use 64 bit precision (configurable via `simulation_parameters.enable_x64`) # jax.config.update("jax_debug_nans", True) # debugging for nans - jax.config.update('jax_platform_name', 'cpu') - # set Jax to use CPUs + platform = toml_file.get("simulation_parameters", {}).get("platform_name", "cpu") + jax.config.update("jax_platform_name", platform) + # set Jax platform via config (default: cpu) #jax.config.update("jax_disable_jit", True) ############################################################################################################ - toml_file = load_config_file() - # load the configuration file - start = time.time() # start the timer diff --git a/PyPIC3D/evolve.py b/PyPIC3D/evolve.py index 5cdacff..e0525b2 100644 --- a/PyPIC3D/evolve.py +++ b/PyPIC3D/evolve.py @@ -22,7 +22,7 @@ E_from_A, B_from_A, update_vector_potential ) -@partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic")) +@partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic"), donate_argnums=(0, 1)) def time_loop_electrostatic(particles, fields, world, constants, curl_func, J_func, solver, relativistic=True): """ Advances the simulation by one time step for an electrostatic Particle-In-Cell (PIC) loop. @@ -74,7 +74,7 @@ def time_loop_electrostatic(particles, fields, world, constants, curl_func, J_fu return particles, fields -@partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic")) +@partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic"), donate_argnums=(0, 1)) def time_loop_electrodynamic(particles, fields, world, constants, curl_func, J_func, solver, relativistic=True): """ Advance an electrodynamic Particle-In-Cell (PIC) system by one time step. @@ -154,7 +154,7 @@ def time_loop_electrodynamic(particles, fields, world, constants, curl_func, J_f return particles, fields -@partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic")) +@partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic"), donate_argnums=(0, 1)) def time_loop_vector_potential(particles, fields, world, constants, curl_func, J_func, solver, relativistic=True): """ Advance a PIC (Particle-In-Cell) simulation by one time step using a diff --git a/PyPIC3D/initialization.py b/PyPIC3D/initialization.py index 665a38b..195cb20 100644 --- a/PyPIC3D/initialization.py +++ b/PyPIC3D/initialization.py @@ -5,7 +5,6 @@ import functools from functools import partial import toml -import matplotlib.pyplot as plt import jax.numpy as jnp #from memory_profiler import profile @@ -34,11 +33,6 @@ plot_initial_histograms ) -from PyPIC3D.diagnostics.openPMD import ( - write_openpmd_initial_particles, write_openpmd_initial_fields -) - - from PyPIC3D.evolve import ( time_loop_electrodynamic, time_loop_electrostatic, time_loop_vector_potential ) @@ -121,6 +115,9 @@ def default_parameters(): "shape_factor" : 1, # shape factor for the simulation (1 for 1st order, 2 for 2nd order) "current_calculation": "j_from_rhov", # current calculation method: esirkepov, villasenor_buneman, j_from_rhov "filter_j": "bilinear", # filter for the current density: bilinear, digital, none + "use_scan": False, # batch timesteps with lax.scan to reduce dispatch overhead + "scan_chunk": 256, # number of timesteps per compiled scan chunk + "platform_name": "cpu", # cpu|gpu (if supported by your JAX install) } # dictionary for simulation parameters @@ -248,12 +245,6 @@ def initialize_simulation(toml_file): } # set the simulation world parameters - world = convert_to_jax_compatible(world) - constants = convert_to_jax_compatible(constants) - simulation_parameters = convert_to_jax_compatible(simulation_parameters) - plotting_parameters = convert_to_jax_compatible(plotting_parameters) - # convert the world parameters to jax compatible format - # if solver == "vector_potential": # B_grid, E_grid = build_collocated_grid(world) # # build the grid for the fields @@ -275,12 +266,24 @@ def initialize_simulation(toml_file): particles = load_particles_from_toml(toml_file, simulation_parameters, world, constants) # load the particles from the configuration file - for species in particles: - name = species.get_name() - name = name.replace(" ", "_") - # replace spaces with underscores in the name - plot_initial_histograms(species, world, path=f"{simulation_parameters['output_dir']}/data", name=name) - # plot the initial histograms of the particles + # Convert `world` and `constants` to JAX-compatible PyTrees after particle creation, + # so particle metadata (dx/dt/domain size) stays as plain Python scalars and doesn't + # induce per-step device work when PyTrees are reconstructed. + grids = world.pop("grids", None) + world = convert_to_jax_compatible(world) + if grids is not None: + world["grids"] = grids + constants = convert_to_jax_compatible(constants) + + if plotting_parameters.get("plotting", True): + for species in particles: + name = species.get_name().replace(" ", "_") + plot_initial_histograms( + species, + world, + path=f"{simulation_parameters['output_dir']}/data", + name=name, + ) print_stats(world) # print the statistics of the simulation @@ -295,7 +298,13 @@ def initialize_simulation(toml_file): particle_sanity_check(particles) # ensure the arrays for the particles are of the correct shape - if plotting_parameters['dump_particles']: + if plotting_parameters.get('dump_particles', False): + try: + from PyPIC3D.diagnostics.openPMD import write_openpmd_initial_particles + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "openpmd-api is required for `dump_particles=true` diagnostics." + ) from e write_openpmd_initial_particles(particles, world, constants, simulation_parameters['output_dir']) # write the initial particles to an openPMD file @@ -307,7 +316,7 @@ def initialize_simulation(toml_file): # convert the E, B, and J tuples into one big list fields = load_external_fields_from_toml(fields, toml_file) # add any external fields to the simulation - E, B, J = fields[:3], fields[3:6], fields[6:9] + E, B, J = tuple(fields[:3]), tuple(fields[3:6]), tuple(fields[6:9]) # convert the fields list back into tuples if solver == "spectral": @@ -362,7 +371,13 @@ def initialize_simulation(toml_file): fields = (E, B, J, rho, phi) # define the fields tuple for the electrodynamic and electrostatic solvers - if plotting_parameters['dump_fields']: + if plotting_parameters.get('dump_fields', False): + try: + from PyPIC3D.diagnostics.openPMD import write_openpmd_initial_fields + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "openpmd-api is required for `dump_fields=true` diagnostics." + ) from e write_openpmd_initial_fields(fields, world, simulation_parameters['output_dir'], filename="initial_fields.h5") # write the initial fields to an openPMD file diff --git a/PyPIC3D/particle.py b/PyPIC3D/particle.py index 4655277..3343ce4 100644 --- a/PyPIC3D/particle.py +++ b/PyPIC3D/particle.py @@ -535,24 +535,26 @@ def __init__(self, name, N_particles, charge, mass, T, v1, v2, v3, x1, x2, x3, \ xwind, ywind, zwind, dx, dy, dz, weight=1, x_bc="periodic", y_bc="periodic", \ z_bc="periodic", update_x=True, update_y=True, update_z=True, \ update_vx=True, update_vy=True, update_vz=True, update_pos=True, update_v=True, shape=1, dt = 0): + # Keep particle metadata as plain Python scalars so PyTree aux_data does not + # trigger device work or extra jitted scalar kernels when reconstructed. self.name = name - self.N_particles = N_particles - self.charge = charge - self.mass = mass - self.weight = weight - self.T = T + self.N_particles = int(N_particles) + self.charge = float(charge) + self.mass = float(mass) + self.weight = float(weight) + self.T = float(T) self.v1 = v1 self.v2 = v2 self.v3 = v3 - self.dx = dx - self.dy = dy - self.dz = dz - self.x_wind = xwind - self.y_wind = ywind - self.z_wind = zwind - self.half_x_wind = 0.5 * xwind - self.half_y_wind = 0.5 * ywind - self.half_z_wind = 0.5 * zwind + self.dx = float(dx) + self.dy = float(dy) + self.dz = float(dz) + self.x_wind = float(xwind) + self.y_wind = float(ywind) + self.z_wind = float(zwind) + self.half_x_wind = 0.5 * self.x_wind + self.half_y_wind = 0.5 * self.y_wind + self.half_z_wind = 0.5 * self.z_wind self.x_bc = x_bc self.y_bc = y_bc self.z_bc = z_bc @@ -572,7 +574,7 @@ def __init__(self, name, N_particles, charge, mass, T, v1, v2, v3, x1, x2, x3, \ self.update_pos = update_pos self.update_v = update_v self.shape = shape - self.dt = dt + self.dt = float(dt) self.x1 = x1 self.x2 = x2 @@ -710,38 +712,41 @@ def tree_unflatten(cls, aux_data, children): update_vx, update_vy, update_vz, shape, dt = aux_data - obj = cls( - name=name, - N_particles=N_particles, - charge=charge, - mass=mass, - T=T, - x1=x1, - x2=x2, - x3=x3, - v1=v1, - v2=v2, - v3=v3, - xwind=x_wind, - ywind=y_wind, - zwind=z_wind, - dx=dx, - dy=dy, - dz=dz, - weight=weight, - x_bc=x_bc, - y_bc=y_bc, - z_bc=z_bc, - update_x=update_x, - update_y=update_y, - update_z=update_z, - update_vx=update_vx, - update_vy=update_vy, - update_vz=update_vz, - update_pos=update_pos, - update_v=update_v, - shape=shape, - dt=dt - ) + obj = cls.__new__(cls) + + obj.name = name + obj.N_particles = int(N_particles) + obj.charge = float(charge) + obj.mass = float(mass) + obj.weight = float(weight) + obj.T = float(T) + + obj.v1, obj.v2, obj.v3 = v1, v2, v3 + obj.x1, obj.x2, obj.x3 = x1, x2, x3 + + obj.dx = float(dx) + obj.dy = float(dy) + obj.dz = float(dz) + obj.x_wind = float(x_wind) + obj.y_wind = float(y_wind) + obj.z_wind = float(z_wind) + obj.half_x_wind = 0.5 * obj.x_wind + obj.half_y_wind = 0.5 * obj.y_wind + obj.half_z_wind = 0.5 * obj.z_wind + + obj.x_bc, obj.y_bc, obj.z_bc = x_bc, y_bc, z_bc + obj.x_periodic = x_bc == 'periodic' + obj.x_reflecting = x_bc == 'reflecting' + obj.y_periodic = y_bc == 'periodic' + obj.y_reflecting = y_bc == 'reflecting' + obj.z_periodic = z_bc == 'periodic' + obj.z_reflecting = z_bc == 'reflecting' + + obj.update_x, obj.update_y, obj.update_z = update_x, update_y, update_z + obj.update_vx, obj.update_vy, obj.update_vz = update_vx, update_vy, update_vz + obj.update_pos, obj.update_v = update_pos, update_v + + obj.shape = shape + obj.dt = float(dt) return obj