From cd081ab83608f39d034339b141cc3bedc7f8e539 Mon Sep 17 00:00:00 2001 From: Juehang Qin Date: Tue, 9 Dec 2025 21:48:08 -0600 Subject: [PATCH 01/10] Implement feature X to enhance user experience and optimize performance --- old_code/prototype_train_clean.ipynb | 1268 -------------------------- 1 file changed, 1268 deletions(-) delete mode 100644 old_code/prototype_train_clean.ipynb diff --git a/old_code/prototype_train_clean.ipynb b/old_code/prototype_train_clean.ipynb deleted file mode 100644 index 7503d5f..0000000 --- a/old_code/prototype_train_clean.ipynb +++ /dev/null @@ -1,1268 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "196a065c-5047-4095-8135-da88a976b28a", - "metadata": {}, - "outputs": [], - "source": [ - "from functools import partial\n", - "import os\n", - "import pickle\n", - "\n", - "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\" #select GPU, -1 means use CPU\n", - "\n", - "import equinox as eqx\n", - "from flowjax.distributions import StandardNormal, Logistic, StudentT, AbstractDistribution, Normal\n", - "from flowjax.flows import masked_autoregressive_flow, block_neural_autoregressive_flow, triangular_spline_flow, coupling_flow\n", - "import h5py\n", - "import jax\n", - "import matplotlib.pyplot as plt\n", - "from matplotlib import colormaps\n", - "from matplotlib.colors import LogNorm\n", - "import numpy as np\n", - "import pandas as pd\n", - "from scipy import stats\n", - "\n", - "from tqdm import trange, tqdm\n", - "import tqdm.utils as tutils\n", - "def ssl(x):\n", - " return 100, 200\n", - "tutils._screen_shape_linux = ssl\n", - "\n", - "from neural_net_defs import *" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1748bc84-858f-4bf3-9314-fd7865959fe9", - "metadata": {}, - "outputs": [], - "source": [ - "import gzip\n", - "import json\n", - "import optax\n", - "import diffrax\n", - "from jaxtyping import PyTree, Array\n", - "from copy import deepcopy" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cbf130a0-a2d3-40e7-95ca-9e180d78d47a", - "metadata": {}, - "outputs": [], - "source": [ - "jax.devices()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8da2f577-ee56-48d9-9d46-84e0ce22357a", - "metadata": {}, - "outputs": [], - "source": [ - "tpc_r = 66.4" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "308d52f0-bad0-48f9-881a-49b4cc2e85f5", - "metadata": {}, - "outputs": [], - "source": [ - "data_obj = np.load('kr83_sr1_50runs_FDCtrain.npz')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "03b26d51-c1da-408a-bf79-762471504a20", - "metadata": {}, - "outputs": [], - "source": [ - "z = data_obj['z_corr']\n", - "conditions = data_obj['condition']" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dcb6d249-bcf2-4bed-8e63-85fba118495a", - "metadata": {}, - "outputs": [], - "source": [ - "key = jax.random.PRNGKey(42)\n", - "key, flow_key = jax.random.split(key, 2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f7b4759e-2069-4760-ae6e-193effaea4bc", - "metadata": {}, - "outputs": [], - "source": [ - "conditions.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "18340d28-2e58-4824-84a6-690d5a628df4", - "metadata": {}, - "outputs": [], - "source": [ - "flow = coupling_flow(\n", - " flow_key, base_dist=StandardNormal((2,)), invert=False, flow_layers=flow_layers, nn_width=NN_width, nn_depth=NN_depth, nn_activation=activation, cond_dim=conditions.shape[1], transformer=bijection\n", - ")\n", - "flow = eqx.tree_deserialise_leaves(\"../flow_posrec/posrec_flow_uniform_100e_2to5e_PMTs_turned_off.eqx\", flow)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c679ddc9-2777-4cdf-9090-65651277ee7a", - "metadata": {}, - "outputs": [], - "source": [ - "compiled_sample = eqx.filter_jit(flow.sample)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6a83192a-19e1-40fd-b91c-fb4e584adecc", - "metadata": {}, - "outputs": [], - "source": [ - "def generate_samples(key, conditions, N_samples):\n", - " output = compiled_sample(key, (N_samples,), condition=conditions)\n", - " return data_inv_transformation(jnp.reshape(output, (-1,2)))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "36d923c0-6aef-4df8-8a0f-38b2453ca936", - "metadata": {}, - "outputs": [], - "source": [ - "generate_samples(key, conditions[0:4], 2).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "75825c60-ef39-40f0-b96b-b276def9aafe", - "metadata": {}, - "outputs": [], - "source": [ - "class Func(eqx.Module):\n", - " layers: list[eqx.nn.Linear]\n", - "\n", - " def __init__(self, *, data_size, width_size, depth, key, **kwargs):\n", - " super().__init__(**kwargs)\n", - " keys = jax.random.split(key, depth + 1)\n", - " layers = []\n", - " if depth == 0:\n", - " layers.append(\n", - " ConcatSquash(in_size=data_size, out_size=data_size, key=keys[0])\n", - " )\n", - " else:\n", - " layers.append(\n", - " ConcatSquash(in_size=data_size, out_size=width_size, key=keys[0])\n", - " )\n", - " for i in range(depth - 1):\n", - " layers.append(\n", - " ConcatSquash(\n", - " in_size=width_size, out_size=width_size, key=keys[i + 1]\n", - " )\n", - " )\n", - " layers.append(\n", - " ConcatSquash(in_size=width_size, out_size=data_size, key=keys[-1])\n", - " )\n", - " self.layers = layers\n", - "\n", - " def __call__(self, t, y, args):\n", - " t = jnp.asarray(t)[None]\n", - " for layer in self.layers[:-1]:\n", - " y = layer(t, y)\n", - " y = jax.nn.tanh(y)\n", - " y = self.layers[-1](t, y)\n", - " return y\n", - "\n", - "\n", - "# Credit: this layer, and some of the default hyperparameters below, are taken from the\n", - "# FFJORD repo.\n", - "class ConcatSquash(eqx.Module):\n", - " lin1: eqx.nn.Linear\n", - " lin2: eqx.nn.Linear\n", - " lin3: eqx.nn.Linear\n", - "\n", - " def __init__(self, *, in_size, out_size, key, **kwargs):\n", - " super().__init__(**kwargs)\n", - " key1, key2, key3 = jax.random.split(key, 3)\n", - " self.lin1 = eqx.nn.Linear(in_size, out_size, key=key1)\n", - " self.lin2 = eqx.nn.Linear(1, out_size, key=key2)\n", - " self.lin3 = eqx.nn.Linear(1, out_size, use_bias=False, key=key3)\n", - "\n", - " def __call__(self, t, y):\n", - " return self.lin1(y) * jax.nn.sigmoid(self.lin2(t)) + self.lin3(t)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5bd1fd91-379f-4882-bc58-15c1514fca3e", - "metadata": {}, - "outputs": [], - "source": [ - "class MLPFunc(eqx.Module):\n", - " layers: list[eqx.nn.Linear]\n", - " # layers_t: list[eqx.nn.Linear]\n", - "\n", - " def __init__(self, *, data_size, width_size, depth, key, **kwargs):\n", - " super().__init__(**kwargs)\n", - " keys = jax.random.split(key, depth + 1)\n", - " layers = []\n", - " # layers_t = []\n", - " if depth == 0:\n", - " layers.append(\n", - " eqx.nn.Linear(data_size+1, data_size, key=keys[0])\n", - " )\n", - " else:\n", - " layers.append(\n", - " eqx.nn.Linear(data_size+1, width_size, key=keys[0])\n", - " )\n", - " for i in range(depth - 1):\n", - " layers.append(\n", - " eqx.nn.Linear(\n", - " width_size, width_size, key=keys[i + 1]\n", - " )\n", - " )\n", - " layers.append(\n", - " eqx.nn.Linear(width_size, data_size, key=keys[-1])\n", - " )\n", - "\n", - " # if depth == 0:\n", - " # layers_t.append(\n", - " # eqx.nn.Linear(1, 1, key=keys[0])\n", - " # )\n", - " # else:\n", - " # layers_t.append(\n", - " # eqx.nn.Linear(1, width_size, key=keys[0])\n", - " # )\n", - " # for i in range(depth - 1):\n", - " # layers_t.append(\n", - " # eqx.nn.Linear(\n", - " # width_size, width_size, key=keys[i + 1]\n", - " # )\n", - " # )\n", - " # layers_t.append(\n", - " # eqx.nn.Linear(width_size, 1, key=keys[-1])\n", - " # )\n", - " self.layers = layers\n", - " # self.layers_t = layers_t\n", - "\n", - " def __call__(self, t, y, args):\n", - " t = jnp.asarray(t)[None]\n", - " # y_init = y\n", - " y = jnp.concatenate((y,t), axis=-1)\n", - "\n", - " for layer in self.layers[:-1]:\n", - " y = layer(y)\n", - " y = jax.nn.silu(y)\n", - " y = self.layers[-1](y)\n", - " # for layer in self.layers_t[:-1]:\n", - " # t = layer(t)\n", - " # t = jax.nn.silu(t)\n", - " # t = self.layers[-1](t)\n", - " return y" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "53334027-fbb2-4292-ade6-b0efcd4f3770", - "metadata": {}, - "outputs": [], - "source": [ - "def approx_logp_wrapper(t, y, args):\n", - " y, _ = y\n", - " *args, eps, func = args\n", - " fn = lambda y: func(t, y, args)\n", - " f, vjp_fn = jax.vjp(fn, y)\n", - " (eps_dfdy,) = vjp_fn(eps)\n", - " logp = jnp.sum(eps_dfdy * eps)\n", - " return f, logp\n", - "\n", - "\n", - "def exact_logp_wrapper(t, y, args):\n", - " y, _ = y\n", - " *args, _, func = args\n", - " fn = lambda y: func(t, y, args)\n", - " f, vjp_fn = jax.vjp(fn, y)\n", - " (size,) = y.shape # this implementation only works for 1D input\n", - " eye = jnp.eye(size)\n", - " (dfdy,) = jax.vmap(vjp_fn)(eye)\n", - " logp = jnp.trace(dfdy)\n", - " return f, logp" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dc74fc5b-27d1-434b-b87b-4f110a289cf2", - "metadata": {}, - "outputs": [], - "source": [ - "class CNF(eqx.Module):\n", - " func_drift: eqx.Module\n", - " func_extract: eqx.Module\n", - " data_size: int\n", - " exact_logp: bool\n", - " t0: float\n", - " extract_t1: float\n", - " dt0: float\n", - " stepsizecontroller: diffrax.AbstractStepSizeController\n", - " \n", - " def __init__(\n", - " self,\n", - " *,\n", - " data_size,\n", - " exact_logp,\n", - " width_size,\n", - " depth,\n", - " key,\n", - " stepsizecontroller=diffrax.ConstantStepSize(),\n", - " func=Func,\n", - " **kwargs,\n", - " ):\n", - " keys = jax.random.split(key, 2)\n", - " super().__init__(**kwargs)\n", - " self.func_drift = (\n", - " func(\n", - " data_size=data_size,\n", - " width_size=width_size,\n", - " depth=depth,\n", - " key=keys[0],\n", - " )\n", - " )\n", - " self.func_extract = (\n", - " func(\n", - " data_size=data_size,\n", - " width_size=width_size,\n", - " depth=depth,\n", - " key=keys[1],\n", - " )\n", - " )\n", - " self.data_size = data_size\n", - " self.exact_logp = exact_logp\n", - " self.t0 = 0\n", - " self.extract_t1 = 10\n", - " self.dt0 = 1\n", - " self.stepsizecontroller=stepsizecontroller\n", - "\n", - " def transform(self, *, y, t1):\n", - " term = diffrax.ODETerm(self.func_extract)\n", - " solver = diffrax.Tsit5()\n", - " sol = diffrax.diffeqsolve(term, solver, self.t0, self.extract_t1, self.dt0, y, stepsize_controller=self.stepsizecontroller)\n", - " (y,) = sol.ys\n", - " \n", - " term = diffrax.ODETerm(self.func_drift)\n", - " solver = diffrax.Tsit5()\n", - " sol = diffrax.diffeqsolve(term, solver, self.t0, t1, self.dt0, y, stepsize_controller=self.stepsizecontroller)\n", - " (y,) = sol.ys\n", - " return y\n", - "\n", - " def transform_and_log_det(self, *, y, t1):\n", - " if self.exact_logp:\n", - " term = diffrax.ODETerm(exact_logp_wrapper)\n", - " else:\n", - " term = diffrax.ODETerm(approx_logp_wrapper)\n", - " eps = jax.random.normal(key, y.shape)\n", - " delta_log_likelihood = 0.0\n", - " \n", - " y = (y, delta_log_likelihood)\n", - " solver = diffrax.Tsit5()\n", - " sol = diffrax.diffeqsolve(term, solver, self.t0, self.extract_t1, self.dt0, y, (eps, self.func_extract), stepsize_controller=self.stepsizecontroller)\n", - " (y,), (delta_log_likelihood,) = sol.ys\n", - "\n", - " y = (y, delta_log_likelihood)\n", - " solver = diffrax.Tsit5()\n", - " sol = diffrax.diffeqsolve(term, solver, self.t0, t1, self.dt0, y, (eps, self.func_drift), stepsize_controller=self.stepsizecontroller)\n", - " (y,), (delta_log_likelihood,) = sol.ys\n", - " return y, delta_log_likelihood\n", - "\n", - " def inverse_and_log_det(self, *, y, t1):\n", - " if self.exact_logp:\n", - " term = diffrax.ODETerm(exact_logp_wrapper)\n", - " else:\n", - " term = diffrax.ODETerm(approx_logp_wrapper)\n", - " eps = jax.random.normal(key, y.shape)\n", - " delta_log_likelihood = 0.0\n", - "\n", - " y = (y, delta_log_likelihood)\n", - " solver = diffrax.Tsit5()\n", - " sol = diffrax.diffeqsolve(term, solver, t1, self.t0, -self.dt0, y, (eps, self.func_drift), stepsize_controller=self.stepsizecontroller)\n", - " (y,), (delta_log_likelihood,) = sol.ys\n", - " \n", - " y = (y, delta_log_likelihood)\n", - " solver = diffrax.Tsit5()\n", - " sol = diffrax.diffeqsolve(term, solver, self.extract_t1, self.t0, -self.dt0, y, (eps, self.func_extract), stepsize_controller=self.stepsizecontroller)\n", - " (y,), (delta_log_likelihood,) = sol.ys\n", - " return y, delta_log_likelihood" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "13043c03-9e57-46d8-ad76-ddd5a34d3777", - "metadata": {}, - "outputs": [], - "source": [ - "from jax.scipy.interpolate import RegularGridInterpolator" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "85e2ebe5-932b-494a-82ea-de56125f3941", - "metadata": {}, - "outputs": [], - "source": [ - "def load_civ():\n", - " civ_file_name = \"field_dependent_radius_depth_maps_B2d75n_C2d75n_G0d3p_A4d9p_T0d9n_PMTs1d3n_FSR0d65p_QPTFE_0d5n_0d4p.json.gz\"\n", - " \n", - " with gzip.open(civ_file_name, \"rb\") as f:\n", - " file = json.load(f)\n", - " civ_map = RegularGridInterpolator(\n", - " tuple([np.linspace(*ax[1]) for ax in file['coordinate_system']]),\n", - " np.array(file['survival_probability_map']).reshape([ax[1][-1] for ax in file['coordinate_system']]),\n", - " bounds_error=False,\n", - " fill_value=0,\n", - " )\n", - " \n", - " return civ_map\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ad7b7361-487c-4062-a6ab-b6b8965355d5", - "metadata": {}, - "outputs": [], - "source": [ - "civ_map = load_civ()\n", - "vec_civ_map = jax.vmap(civ_map)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d89c4762-251f-44e9-be76-e70bd50fc63a", - "metadata": {}, - "outputs": [], - "source": [ - "tpc_height = 148.6515\n", - "z_scale = 5\n", - "data_bool = (z>-tpc_height) & (z<0)\n", - "z_sel = z[data_bool]\n", - "z_sel_scaled = -z_sel/z_scale\n", - "cond_sel = conditions[data_bool]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a17f4f67-dc5c-4716-9a3b-3aebdc723711", - "metadata": {}, - "outputs": [], - "source": [ - "len(z_sel_scaled)/len(z)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "60d27398-22cd-4b52-9c66-972a48a69cde", - "metadata": {}, - "outputs": [], - "source": [ - "@jax.jit\n", - "def compute_r(xy_arr):\n", - " return jnp.sqrt(xy_arr[:,0]**2 + xy_arr[:,1]**2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d1af3831-5b8b-45c5-b511-014b6c90a8d0", - "metadata": {}, - "outputs": [], - "source": [ - "key, model_key = jax.random.split(key, 2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5c25ae05-4b92-47ca-b9d7-01712bbffbc6", - "metadata": {}, - "outputs": [], - "source": [ - "model = CNF(\n", - " data_size=2,\n", - " exact_logp=True,\n", - " width_size=48,\n", - " depth=3,\n", - " key=model_key,\n", - " stepsizecontroller=diffrax.PIDController(rtol=1e-3, atol=1e-6, dtmax=5),\n", - " func=MLPFunc\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "61f70527-822e-4c6e-ba24-8373ca65aa03", - "metadata": {}, - "outputs": [], - "source": [ - "def rolloff_func(x, rolloff=1e-2):\n", - " return x+rolloff*jnp.exp(-x/rolloff)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8da4c8b8-3acd-4aa5-9892-29e3ac28b779", - "metadata": {}, - "outputs": [], - "source": [ - "def curl_loss(key, model, z, x, extract_max_z=10.):\n", - " rand_z = jax.random.uniform(key, 1, minval=0.0, maxval=extract_max_z)\n", - " jac_drift = jax.jacfwd(lambda a:model.func_drift(z, a, 0.))(x)\n", - " jac_ex = jax.jacfwd(lambda a:model.func_extract(rand_z[0], a, 0.))(x)\n", - " return (jac_drift[1,0] - jac_drift[0,1])**2 + (jac_ex[1,0] - jac_ex[0,1])**2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a33d468a-4a68-4999-9d78-687b6298777b", - "metadata": {}, - "outputs": [], - "source": [ - "def single_likelihood_loss(key, model, condition, t1, z, min_p=1e-3, N_samples=4, tpc_r=66.4, curl_loss_multiplier=1000.):\n", - " keys = jax.random.split(key,2)\n", - " samples = generate_samples(keys[0], condition[np.newaxis,...], N_samples)\n", - " transformed_samples, logdet = eqx.filter_vmap(lambda y: model.transform_and_log_det(y=y, t1=t1))(samples)\n", - " sample_r = compute_r(transformed_samples)\n", - " p_surv = vec_civ_map(jnp.vstack((sample_r, np.repeat(z, N_samples))).T)\n", - " # p_surv = jnp.where(p_surv Date: Tue, 9 Dec 2025 22:06:29 -0600 Subject: [PATCH 02/10] Update docstring to clarify package purpose and focus on electric field modeling --- src/fieldflow/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/fieldflow/__init__.py b/src/fieldflow/__init__.py index 8e883cc..3b8d603 100644 --- a/src/fieldflow/__init__.py +++ b/src/fieldflow/__init__.py @@ -1,8 +1,6 @@ """JAX-based normalizing flows for physical field modeling. -This package provides tools for modeling physical fields using continuous -normalizing flows, with a focus on electric field modeling for particle -detectors. +This package provides tools for modeling electric fields in TPCs using physics-informed neural ODEs. """ from fieldflow.config import Config, load_config From 4adc53f99958614b376d0153a80bf3a3e5d5ce0e Mon Sep 17 00:00:00 2001 From: Juehang Qin Date: Tue, 9 Dec 2025 22:36:47 -0600 Subject: [PATCH 03/10] Update configuration and documentation for FieldFlow - Adjust TPC radius in sample_config.toml from 66.4 cm to 129.96 cm. - Enhance docstrings across multiple modules to clarify functionality and parameters, focusing on electric field modeling and position reconstruction. - Improve descriptions in model.py, train.py, and utils.py to better reflect the purpose and usage of functions and classes. --- sample_config.toml | 2 +- src/fieldflow/__init__.py | 15 ++- src/fieldflow/__main__.py | 23 ++-- src/fieldflow/config.py | 111 +++++++++++-------- src/fieldflow/dataloader.py | 66 ++++++++---- src/fieldflow/model.py | 173 ++++++++++++++++++++++------- src/fieldflow/posrec.py | 100 +++++++++-------- src/fieldflow/train.py | 210 +++++++++++++++++++----------------- src/fieldflow/utils.py | 14 +-- 9 files changed, 455 insertions(+), 259 deletions(-) diff --git a/sample_config.toml b/sample_config.toml index 3bfee82..0597ea2 100644 --- a/sample_config.toml +++ b/sample_config.toml @@ -64,7 +64,7 @@ multisteps_every_k = 4 # Gradient accumulation steps for MultiSteps optimizer [experiment] # Physical experimental setup parameters tpc_height = 148.6515 # Height of the TPC in cm (for filtering z coordinates) -tpc_r = 66.4 # Radius of the TPC in cm (for boundary constraints) +tpc_r = 129.96 # Radius of the TPC in cm (for boundary constraints) [posrec] # Position reconstruction flow model parameters diff --git a/src/fieldflow/__init__.py b/src/fieldflow/__init__.py index 3b8d603..3876f7b 100644 --- a/src/fieldflow/__init__.py +++ b/src/fieldflow/__init__.py @@ -1,6 +1,17 @@ -"""JAX-based normalizing flows for physical field modeling. +"""JAX-based continuous normalizing flows for electric field modeling. -This package provides tools for modeling electric fields in TPCs using physics-informed neural ODEs. +FieldFlow provides tools for modeling electric fields in dual-phase Time +Projection Chambers (TPCs) using continuous normalizing flows (CNFs). The +architecture mirrors the physical structure of dual-phase TPCs, with separate +neural networks for the extraction field (z-independent distortions) and +drift field (z-dependent distortions). + +The library supports two approaches for enforcing Maxwell's equations: + +- **Scalar potential method**: Models the field as the negative gradient of a + learned scalar potential, which is curl-free by construction. +- **Vector field with curl loss**: Directly learns the vector field while + penalizing non-zero curl during training. """ from fieldflow.config import Config, load_config diff --git a/src/fieldflow/__main__.py b/src/fieldflow/__main__.py index 7ff199e..d2b06d0 100644 --- a/src/fieldflow/__main__.py +++ b/src/fieldflow/__main__.py @@ -1,7 +1,12 @@ -"""Entry point for FieldFlow training. +"""Command-line interface for FieldFlow training. -This module provides a simple command-line interface for training FieldFlow -models from configuration, with optional fine-tuning of pre-trained models. +This module provides the CLI entry point for training CNF models to learn +electric field distortions in dual-phase TPCs. Supports training from +scratch or fine-tuning pretrained models. + +Usage: + python -m fieldflow config.toml + python -m fieldflow config.toml --pretrained model.eqx """ import argparse @@ -23,14 +28,18 @@ def create_model_from_config(config, key): - """Create a CNF model from configuration. + """Create a CNF model from configuration parameters. + + Initializes a ContinuousNormalizingFlow with architecture and ODE solver + settings from the config. Uses DriftFromPotential (scalar method) or + MLPFunc (vector method) based on config.model.scalar. Args: - config: Configuration object containing model parameters - key: JAX PRNG key for model initialization + config: Config object with model parameters. + key: JAX random key for parameter initialization. Returns: - Initialized ContinuousNormalizingFlow model + Initialized ContinuousNormalizingFlow model. """ # Create step size controller based on config if config.model.use_pid_controller: diff --git a/src/fieldflow/config.py b/src/fieldflow/config.py index fdc609a..8b5480a 100644 --- a/src/fieldflow/config.py +++ b/src/fieldflow/config.py @@ -18,28 +18,31 @@ @dataclass class ModelConfig: - """Configuration for continuous normalizing flow model architecture and - behavior. + """Configuration for continuous normalizing flow model architecture. - This class encapsulates the parameters used to define a Continuous - Normalizing Flow model, including neural network architecture, ODE - solver settings, and model-specific hyperparameters. + This class defines parameters for the CNF model that learns electric field + distortions in dual-phase TPCs. The model uses separate neural networks for + extraction (z-independent) and drift (z-dependent) field components. Attributes: - data_size: Dimensionality of the input data. - exact_logp: Whether to use exact log probability calculation. - width_size: Width of neural network layers. - depth: Depth of neural network. - scalar: Whether to use scalar field instead of vector field. - use_pid_controller: Whether to use PIDController instead of - ConstantStepSize. - rtol: Relative tolerance for PIDController. - atol: Absolute tolerance for PIDController. - dtmax: Maximum step size for PIDController. - dtmin: Minimum step size for PIDController. - t0: Starting time for ODE. - extract_t1: End time for extract phase. - dt0: Initial time step. + data_size: Dimensionality of the input data (default 2 for x,y). + exact_logp: If True, compute exact log probability using full Jacobian + trace. If False, use Hutchinson trace estimator (faster but + approximate). + width_size: Width of hidden layers in the neural networks. + depth: Number of hidden layers in the neural networks. + scalar: If True, use scalar potential method (curl-free by + construction). If False, use direct vector field with curl + penalty loss. + use_pid_controller: If True, use adaptive PID step size controller. + If False, use constant step size. + rtol: Relative tolerance for adaptive ODE solver. + atol: Absolute tolerance for adaptive ODE solver. + dtmax: Maximum step size for adaptive ODE solver. + dtmin: Minimum step size for adaptive ODE solver. + t0: Starting time for ODE integration. + extract_t1: End time for extraction phase ODE integration. + dt0: Initial time step for ODE solver. """ data_size: int = 2 @@ -63,7 +66,23 @@ class ModelConfig: @dataclass class PosRecFlowConfig: - """Configuration for Position Reconstruction Flow model.""" + """Configuration for the position reconstruction normalizing flow. + + The position reconstruction flow is a pretrained conditional normalizing + flow that maps detector hit patterns to (x, y) position distributions. + This flow provides the prior distribution for CNF training. + + Attributes: + flow_layers: Number of coupling layers in the normalizing flow. + nn_width: Width of hidden layers in coupling layer neural networks. + nn_depth: Number of hidden layers in coupling layer neural networks. + invert_bool: Whether to invert the flow direction. + cond_dim: Dimension of the conditioning vector (hit pattern size). + spline_knots: Number of knots for rational quadratic spline bijections. + spline_interval: Interval parameter for spline transformations. + radius_buffer: Buffer beyond TPC radius for coordinate transformation, + allowing predictions slightly outside the physical boundary. + """ # Neural network architecture flow_layers: int = 5 @@ -86,27 +105,31 @@ class PosRecFlowConfig: class TrainingConfig: """Configuration for training CNF models. - This class encapsulates the parameters used during the training process, - including optimization settings, data handling, and training strategies. + This class defines parameters for the training process including + optimization, batching, loss computation, and multi-GPU parallelization. Attributes: - seed: Random seed for training reproducibility. - learning_rate: Initial learning rate for the optimizer. - weight_decay: L2 regularization parameter. + seed: Random seed for reproducibility. + learning_rate: Initial learning rate for AdamW optimizer. + weight_decay: L2 regularization coefficient. epochs: Number of training epochs. - batch_size: Batch size for training. - enable_scheduler: Whether to enable learning rate scheduling. - epoch_start: First epoch to begin training at. - n_samples: Number of samples per instance for likelihood estimation. - n_train: Size of training set. - n_test: Size of test/validation set. - use_best: Whether to use the best model based on validation. - curl_loss_multiplier: Coefficient for curl loss component. - z_scale: Scaling factor for z dimension. - multisteps_every_k: Steps for MultiSteps optimizer. - num_devices: Number of devices to use for data parallelization. - save_iter: Every n number of epochs to save the model - save_file_name: Path and file name start (/path/to/model_name) + batch_size: Number of samples per training batch. + enable_scheduler: If True, use learning rate schedule that reduces + LR at epochs 20 and 70. If False, use constant learning rate. + epoch_start: Starting epoch number (useful for resuming training). + n_samples: Number of Monte Carlo samples per event for likelihood + estimation from the position reconstruction flow. + n_train: Number of training samples to use from the dataset. + n_test: Number of samples for validation/test set. + use_best: If True, return the model with lowest validation loss. + curl_loss_multiplier: Weight for curl penalty term (only used when + scalar=False in ModelConfig). + z_scale: Scaling factor to normalize z coordinates. + multisteps_every_k: Gradient accumulation steps before optimizer + update. + num_devices: Number of GPUs for data parallelization. + save_iter: Save model checkpoint every N epochs. + save_file_name: Base filename for saved model checkpoints. """ # Training process parameters @@ -138,14 +161,16 @@ class TrainingConfig: @dataclass class ExperimentConfig: - """Configuration for experimental parameters. + """Configuration for the physical TPC geometry. - This class encapsulates parameters that describe the physical - experimental setup and constraints. + This class defines physical parameters of the dual-phase Time Projection + Chamber that constrain the model. Attributes: - tpc_height: Height of the TPC for filtering z coordinates. - tpc_r: Radius of the TPC for boundary constraints. + tpc_height: Height of the TPC drift region in cm. Used to filter + events by z coordinate (keeping -tpc_height < z < 0). + tpc_r: Radius of the cylindrical TPC in cm. Used for boundary + constraints in the loss function. """ tpc_height: float = 259.92 diff --git a/src/fieldflow/dataloader.py b/src/fieldflow/dataloader.py index 79e0973..323d77c 100644 --- a/src/fieldflow/dataloader.py +++ b/src/fieldflow/dataloader.py @@ -1,7 +1,8 @@ """Data loading utilities for FieldFlow. -This module provides functions for loading hitpattern data and CIV maps, -extracted directly from the working notebook code. +This module provides functions for loading detector data required for +CNF training, including hit patterns and charge-insensitive-volume (CIV) survival +probability maps. """ from pathlib import Path @@ -16,15 +17,22 @@ def load_civ_map(file_path: str | Path) -> RegularGridInterpolator: - """Load CIV map from npz file. + """Load charge-insensitive-volume (CIV) survival probability map. - Direct adaptation of load_civ() function from notebook. + The CIV map provides the probability that charge generated at a given + (r, z) position survives to be detected. This is used in the likelihood + loss to account for position-dependent detection efficiency. Args: - file_path: Path to the .npz file containing CIV map data + file_path: Path to the .npz file containing CIV map data with keys + 'R' (radial coordinates), 'Z' (z coordinates), and 'vals' + (survival probabilities). Returns: - RegularGridInterpolator function for CIV map + RegularGridInterpolator that evaluates CIV probability at (r, z). + + Raises: + FileNotFoundError: If the CIV map file does not exist. """ file_path = Path(file_path) if not file_path.exists(): @@ -50,17 +58,27 @@ def load_hitpatterns( tpc_height: float, z_scale: float, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """Load and process hitpattern data. + """Load and preprocess hit pattern data for CNF training. - Direct extraction from notebook data loading and processing code. + Loads detector hit patterns and their associated z coordinates, filters + to valid TPC volume, and applies z-coordinate scaling. Args: - file_path: Path to the .npz file containing hitpattern data - tpc_height: Height of the TPC for filtering z coordinates - z_scale: Scaling factor for z coordinates + file_path: Path to .npz file with 'z_corr' (corrected z coordinates) + and 'condition' (hit pattern arrays). + tpc_height: Height of TPC drift region in cm. Events with + z < -tpc_height or z > 0 are filtered out. + z_scale: Scaling factor for z coordinates. The scaled z is computed + as -z / z_scale. Returns: - Tuple of (z_sel, z_sel_scaled, cond_sel) + Tuple of (z_sel, z_sel_scaled, cond_sel) where: + - z_sel: Filtered z coordinates in cm + - z_sel_scaled: Scaled z coordinates (used as ODE integration time) + - cond_sel: Corresponding hit pattern conditioning vectors + + Raises: + FileNotFoundError: If the hitpattern file does not exist. """ file_path = Path(file_path) if not file_path.exists(): @@ -85,15 +103,20 @@ def load_data_from_config( hitpattern_path: str | Path, civ_map_path: str | Path, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, RegularGridInterpolator]: - """Load all data using configuration parameters. + """Load all training data using configuration parameters. + + Convenience function that loads both hit patterns and CIV map using + parameters from the configuration object. Args: - config: Configuration object containing training parameters - hitpattern_path: Path to hitpattern .npz file - civ_map_path: Path to CIV map JSON file + config: Configuration object with training.z_scale and + experiment.tpc_height parameters. + hitpattern_path: Path to hit pattern .npz file. + civ_map_path: Path to CIV map .npz file. Returns: - Tuple of (z_sel, z_sel_scaled, cond_sel, civ_map) + Tuple of (z_sel, z_sel_scaled, cond_sel, civ_map) containing + the loaded and preprocessed data. """ # Extract parameters from config z_scale = config.training.z_scale @@ -113,14 +136,13 @@ def load_data_from_config( def create_vectorized_civ_map(civ_map: RegularGridInterpolator): - """Create vectorized version of CIV map interpolator. - - Direct from notebook: vec_civ_map = jax.vmap(civ_map) + """Create a vectorized (batched) version of the CIV map interpolator. Args: - civ_map: RegularGridInterpolator for CIV map + civ_map: RegularGridInterpolator for single-point CIV evaluation. Returns: - Vectorized CIV map function + Vectorized function that can evaluate CIV at multiple points + simultaneously. """ return jax.vmap(civ_map) diff --git a/src/fieldflow/model.py b/src/fieldflow/model.py index 8845355..adc0ccd 100644 --- a/src/fieldflow/model.py +++ b/src/fieldflow/model.py @@ -1,5 +1,15 @@ -""" -Model definitions for continuous normalizing flow for electric field modeling. +"""Neural network models for continuous normalizing flows. + +This module defines the neural network architectures used to model electric +field distortions in dual-phase TPCs. The CNF uses separate networks for +the extraction field (z-independent) and drift field (z-dependent). + +Two approaches are provided for enforcing Maxwell's equations: + +- **MLPFunc**: Direct vector field parameterization. Requires explicit curl + penalty during training to enforce curl-free constraint. +- **DriftFromPotential**: Scalar potential parameterization where the drift + is the negative gradient of the potential. Curl-free by construction. """ import diffrax @@ -70,8 +80,14 @@ def fn(y): class MLPFunc(eqx.Module): - """Multilayer perceptron that models the drift function in - a continuous normalizing flow. + """Multilayer perceptron that directly models the drift vector field. + + This network takes (x, y, t) as input and outputs a 2D drift vector. + When used for electric field modeling, training should include a curl + penalty to encourage physically valid (curl-free) fields. + + Attributes: + layers: List of linear layers forming the MLP. """ layers: list[eqx.nn.Linear] @@ -124,8 +140,14 @@ def __call__(self, t, y, args): # noqa: ARG002 class ScalarMLPFunc(eqx.Module): - """Multilayer perceptron that models the scalar potential in - a continuous normalizing flow. + """Multilayer perceptron that models a scalar potential field. + + This network takes (x, y, t) as input and outputs a scalar value + representing the potential at that point. Used by DriftFromPotential + to derive curl-free drift fields via automatic differentiation. + + Attributes: + layers: List of linear layers forming the MLP. """ layers: list[eqx.nn.Linear] @@ -180,11 +202,28 @@ def __call__(self, t, y, args): # noqa: ARG002 return jnp.squeeze(y, axis=-1) class DriftFromPotential(eqx.Module): - model: ScalarMLPFunc - """ - Drift function derived from a scalar potential model. + """Drift function derived from a scalar potential. + + This class wraps a ScalarMLPFunc and computes the drift as the negative + gradient of the learned potential. This guarantees curl-free fields by + construction, satisfying Maxwell's equations without explicit penalties. + + Attributes: + model: The underlying scalar potential network. """ + + model: ScalarMLPFunc + def __init__(self, *, data_size, width_size, depth, key, **kwargs): + """Initialize the drift-from-potential model. + + Args: + data_size: Dimensionality of the spatial coordinates. + width_size: Width of hidden layers in the potential network. + depth: Number of hidden layers in the potential network. + key: JAX random key for parameter initialization. + **kwargs: Additional arguments passed to parent class. + """ super().__init__(**kwargs) # Initialize the scalar potential model self.model = ScalarMLPFunc( @@ -195,27 +234,54 @@ def __init__(self, *, data_size, width_size, depth, key, **kwargs): ) def __call__(self, t, y, args): + """Compute drift as negative gradient of the scalar potential. + + Args: + t: Current time (z coordinate in physical terms). + y: Current spatial position (x, y). + args: Additional arguments (unused, for interface compatibility). + + Returns: + Drift vector pointing in direction of steepest potential descent. + """ gradient = jax.grad(lambda y: self.scalar_pot(t, y, args))(y) return -gradient def scalar_pot(self, t, y, args): + """Evaluate the scalar potential at a point. + + Args: + t: Current time. + y: Spatial position. + args: Additional arguments passed to the model. + + Returns: + Scalar potential value (summed over batch if applicable). + """ return self.model(t, y, args).sum() class ContinuousNormalizingFlow(eqx.Module): - """Continuous normalizing flow using neural ODEs. + """Continuous normalizing flow for dual-phase TPC field modeling. + + This model uses neural ODEs to learn electric field distortions. The + architecture mirrors dual-phase TPC physics with two sequential flows: + + 1. **Extraction phase** (func_extract): Models z-independent field + distortions that affect the (x, y) distribution uniformly regardless + of drift distance. + 2. **Drift phase** (func_drift): Models z-dependent field distortions + where the effect accumulates with drift distance. Attributes: - func_drift (eqx.Module): Neural network modeling the drift function. - data_size (int): Dimensionality of the data. - exact_logp (bool): Whether to use exact log probability computation. - t0 (float): Initial time for ODE integration. - dt0 (float): Initial time step for ODE integration. - stepsizecontroller (diffrax.AbstractStepSizeController): Controls - adaptive stepping. - - func_extract (eqx.Module): Neural network - modeling the extraction function. - extract_t1 (float): final extraction time + func_drift: Neural network for the z-dependent drift field. + func_extract: Neural network for the z-independent extraction field. + data_size: Dimensionality of the spatial data (typically 2 for x, y). + exact_logp: If True, compute exact Jacobian trace. If False, use + Hutchinson estimator. + t0: Initial time for ODE integration. + dt0: Initial step size for ODE solver. + extract_t1: Integration end time for extraction phase. + stepsizecontroller: Adaptive step size controller for ODE solver. """ func_drift: eqx.Module @@ -243,6 +309,27 @@ def __init__( extract_t1 = 10, **kwargs, ): + """Initialize the continuous normalizing flow. + + Creates two neural networks with identical architecture: one for the + extraction phase and one for the drift phase. + + Args: + data_size: Dimensionality of spatial coordinates. + exact_logp: Whether to use exact Jacobian trace computation. + width_size: Width of hidden layers in drift/extraction networks. + depth: Number of hidden layers in drift/extraction networks. + key: JAX random key for parameter initialization. + stepsizecontroller: ODE step size controller. Defaults to + ConstantStepSize if not provided. + func: Neural network class for drift/extraction functions. + Use MLPFunc for vector field or DriftFromPotential for + scalar potential approach. + t0: Initial time for ODE integration. + dt0: Initial step size for ODE solver. + extract_t1: End time for extraction phase integration. + **kwargs: Additional arguments passed to parent class. + """ if stepsizecontroller is None: stepsizecontroller = diffrax.ConstantStepSize() keys = jax.random.split(key, 2) @@ -269,14 +356,18 @@ def __init__( self.extract_t1 = extract_t1 def transform(self, *, y, t1): - """Transform data through the flow without computing log determinants. + """Transform coordinates through extraction and drift phases. + + Applies the full field distortion model without tracking probability + changes. First applies the extraction field (z-independent), then + the drift field (z-dependent, integrated to time t1). Args: - y (jnp.ndarray): Input data points to transform. - t1 (float): Target time for the transformation. + y: Input spatial coordinates of shape (data_size,). + t1: Target time for drift phase (corresponds to scaled z depth). Returns: - jnp.ndarray: Transformed data points. + Transformed coordinates after both flow phases. """ term = diffrax.ODETerm(self.func_extract) solver = diffrax.Euler() @@ -306,17 +397,21 @@ def transform(self, *, y, t1): return y def transform_and_log_det(self, *, y, t1, key): - """Transform data and compute log determinant of the transformation. + """Transform coordinates and compute the log determinant Jacobian. + + Applies extraction then drift phases while accumulating the log + determinant of the transformation Jacobian, needed for density + estimation via the change of variables formula. Args: - y (jnp.ndarray): Input data points to transform. - t1 (float): Target time for the transformation. - key (jax.random.key): Random key for stochastic operations. + y: Input spatial coordinates of shape (data_size,). + t1: Target time for drift phase (scaled z coordinate). + key: JAX random key (used for Hutchinson estimator if exact_logp + is False). Returns: - tuple: (transformed_y, log_determinant) where transformed_y is the - transformed data and log_determinant is the change in log - probability. + Tuple of (transformed_y, log_det) where log_det is the accumulated + log determinant from both phases. """ if self.exact_logp: term = diffrax.ODETerm(exact_logp_wrapper) @@ -359,15 +454,17 @@ def transform_and_log_det(self, *, y, t1, key): def inverse_and_log_det(self, *, y, t1, key): """Apply inverse transformation and compute the log determinant. + Reverses the flow transformation: first inverts the drift phase + (from t1 back to t0), then inverts the extraction phase. + Args: - y (jnp.ndarray): Input data points to inverse transform. - t1 (float): Starting time for the inverse transformation. - key (jax.random.key): Random key for stochastic operations. + y: Transformed coordinates to invert. + t1: Time to invert from for drift phase (scaled z coordinate). + key: JAX random key for Hutchinson estimator if needed. Returns: - tuple: (inverse_y, log_determinant) where inverse_y is the inverse - transformed data and log_determinant is the change in log - probability. + Tuple of (original_y, log_det) where original_y is the recovered + input coordinates and log_det is the accumulated log determinant. """ if self.exact_logp: term = diffrax.ODETerm(exact_logp_wrapper) diff --git a/src/fieldflow/posrec.py b/src/fieldflow/posrec.py index 687b50c..bbe516d 100644 --- a/src/fieldflow/posrec.py +++ b/src/fieldflow/posrec.py @@ -1,5 +1,12 @@ -"""Pretrained conditional normalizing flow for position reconstruction from -detector hit patterns, including coordinate transformations. +"""Position reconstruction flow and coordinate transformations. + +This module provides the pretrained position reconstruction normalizing flow +that maps detector hit patterns to (x, y) position distributions. It also +includes coordinate transformations between the normalized flow space and +physical detector coordinates. + +The position reconstruction flow serves as a prior for CNF training, +providing samples of likely (x, y) positions given observed hit patterns. """ import copy @@ -287,15 +294,18 @@ def inverse_and_log_det( def get_unconstrain_transform(): - """Get the unconstrain transform. + """Create the transformation from bounded to unbounded coordinates. + + This transformation maps normalized coordinates in [-1, 1] to unbounded + space suitable for the normalizing flow's standard normal base + distribution. The transformation chain is: - Maps from [-1,1] coordinates to unbounded space via: - 1. [-1,1] -> [-1+eps, 1-eps] (avoids saturation) - 2. unbounding transform (arctanh or StandardNormalToUnitBall.inverse) - 3. final scaling + 1. Scale by (1 - eps) to avoid boundary saturation + 2. Apply inverse of StandardNormalToUnitBall (unit ball -> normal) + 3. Apply final identity scaling Returns: - Transformation object + Chain bijection implementing the unconstrain transformation. """ # Common first step: avoid saturation affine_eps = Affine( @@ -315,6 +325,8 @@ def get_unconstrain_transform(): return Chain([affine_eps, unbounding_transform, affine_scale]) +#: Vectorized constraint transformation from unbounded to [-1, 1] space. +#: Applies the inverse of get_unconstrain_transform() to batches of points. constrain_vec = jax.vmap(get_unconstrain_transform().inverse) @@ -322,18 +334,22 @@ def get_unconstrain_transform(): def data_inv_transformation( data: Array, tpc_r: float, radius_buffer: float ) -> Array: - """Transform flow coordinates back to physical (x,y) coordinates. + """Transform from flow space to physical (x, y) coordinates. - This function applies the full coordinate transformation chain to convert - from normalized flow space back to physical detector coordinates. + Converts samples from the normalizing flow's output space back to + physical detector coordinates in centimeters. + + The transformation: + 1. Applies constrain_vec to map from unbounded to [-1, 1] + 2. Scales by (tpc_r + radius_buffer) to get physical coordinates Args: - data: Array of shape (N, 2) in flow coordinate space - tpc_r: TPC radius in cm - radius_buffer: Buffer for predictions beyond TPC radius + data: Array of shape (N, 2) in flow coordinate space. + tpc_r: TPC radius in cm. + radius_buffer: Additional buffer beyond TPC radius in cm. Returns: - Array of shape (N, 2) in physical coordinates (cm) + Array of shape (N, 2) with physical (x, y) coordinates in cm. """ @@ -355,22 +371,21 @@ def generate_samples_for_cnf( tpc_r: float = 129.96, # Default matches experiment.tpc_r radius_buffer: float = 0.0, # Default matches posrec.radius_buffer ) -> Array: - """Generate samples from position reconstruction flow for CNF training. + """Generate position samples from the reconstruction flow for CNF training. - This function provides a clean interface for CNF training to sample - from the position reconstruction flow and get properly transformed - physical coordinates. + Samples (x, y) positions from the pretrained position reconstruction flow + conditioned on hit patterns, then transforms to physical coordinates. Args: - key: Random key for sampling - conditions: Conditioning information (hit patterns) - n_samples: Number of samples to generate - posrec_model: Pretrained position reconstruction flow model - tpc_r: TPC radius in cm (default: 66.4) - radius_buffer: Buffer for predictions beyond TPC radius (default: 20.0) + key: JAX random key for sampling. + conditions: Hit pattern conditioning array of shape (1, cond_dim). + n_samples: Number of position samples to generate. + posrec_model: Pretrained position reconstruction flow model. + tpc_r: TPC radius in cm (default: 129.96). + radius_buffer: Buffer beyond TPC radius in cm (default: 0.0). Returns: - Array of shape (n_samples, 2) in physical coordinates + Array of shape (n_samples, 2) with sampled (x, y) coordinates in cm. """ # Sample from the position reconstruction flow output = posrec_model.sample(key, (n_samples,), condition=conditions) @@ -382,24 +397,21 @@ def generate_samples_for_cnf( def posrec_flow(pretrained_posrec_flow_path, config: "Config"): - """ - Load a pretrained position reconstruction flow model, which is a coupling - flow model with rational quadratic spline bijections. The model uses a - standard normal base distribution. - - Parameters - ---------- - pretrained_posrec_flow_path : str or Path - Path to the pretrained model weights file. Should be compatible with - equinox's tree serialization format. - config : Config - Configuration object containing position reconstruction flow - parameters. - - Returns - ------- - eqx.Module - A pretrained coupling flow model with loaded weights. + """Load a pretrained position reconstruction flow model. + + Creates a coupling flow architecture matching the pretrained model and + loads saved weights. The flow uses rational quadratic spline bijections + with a standard normal base distribution. + + Args: + pretrained_posrec_flow_path: Path to the saved model weights file + (equinox serialization format). + config: Configuration object with posrec parameters (flow_layers, + nn_width, nn_depth, spline_knots, spline_interval, cond_dim, + invert_bool). + + Returns: + Loaded coupling flow model ready for inference. """ bijection = RationalQuadraticSpline( knots=config.posrec.spline_knots, diff --git a/src/fieldflow/train.py b/src/fieldflow/train.py index 2eee305..3404437 100644 --- a/src/fieldflow/train.py +++ b/src/fieldflow/train.py @@ -1,7 +1,9 @@ -"""Training infrastructure for FieldFlow continuous normalizing flow models. +"""Training infrastructure for FieldFlow CNF models. -This module provides loss functions, training loops, and utilities for training -CNF models to learn drift fields from position reconstruction data. +This module provides the loss functions and training loop for learning +electric field distortions in dual-phase TPCs. Training uses position +samples from a pretrained reconstruction flow weighted by charge-in-volume +survival probabilities. """ import json @@ -26,17 +28,17 @@ def rolloff_func(x: Array, rolloff: float = 1e-2) -> Array: - """Apply rolloff regularization to prevent numerical issues. + """Apply soft lower bound to prevent log(0) numerical issues. - This function ensures that probabilities don't get too close to zero, - which can cause numerical instability in log computations. + Smoothly regularizes small values using: x + rolloff * exp(-x/rolloff). + This approaches x for large values and rolloff for x near 0. Args: - x: Input array to regularize - rolloff: Minimum value parameter + x: Input array (typically survival probabilities). + rolloff: Soft minimum value parameter. Returns: - Regularized array + Regularized array with values bounded away from zero. """ return x + rolloff * jnp.exp(-x / rolloff) @@ -48,20 +50,22 @@ def curl_loss( x: Array, extract_max_z: float = 10.0, # noqa: ARG001 ) -> float: - """Compute curl penalty for vector field to encourage curl-free flow. + """Compute curl penalty for enforcing Maxwell's equations. - This loss encourages the learned drift field to have minimal curl, - which is a physical constraint for certain types of fields. + For electrostatic fields, curl(E) = 0. This loss penalizes non-zero curl + in the learned drift field when using the vector field approach (MLPFunc). + Not needed when using scalar potential (DriftFromPotential) since gradient + fields are curl-free by construction. Args: - key: Random key for sampling (unused but kept for interface) - model: CNF model with func_drift method - z: Current z coordinate (time parameter) - x: Spatial coordinates [x, y] - extract_max_z: Maximum z value for random sampling (unused) + key: Unused (kept for interface compatibility). + model: CNF model containing func_drift. + z: Current z coordinate (ODE time parameter). + x: Spatial coordinates (x, y) at which to evaluate curl. + extract_max_z: Unused (kept for interface compatibility). Returns: - Curl penalty loss value + Squared curl value (∂v_y/∂x - ∂v_x/∂y)². """ jac_drift = jax.jacfwd(lambda a: model.func_drift(z, a, 0.0))(x) @@ -86,29 +90,31 @@ def single_likelihood_loss( curl_loss_multiplier: float = 1000.0, scalar: bool = False, ) -> float: - """Compute likelihood loss for a single data point. + """Compute negative log-likelihood loss for a single event. - This function computes the negative log-likelihood for a single event, - incorporating survival probability from CIV maps and curl penalty. + The loss combines: + 1. Monte Carlo estimation of the likelihood using samples from the + position reconstruction flow, transformed through the CNF + 2. CIV survival probability weighting + 3. Optional curl penalty for vector field approach Args: - key: Random key for sampling - model: CNF model to train - condition: Hit pattern conditioning information - t1: Target time for transformation (scaled z coordinate) - z: Physical z coordinate - posrec_model: Pretrained position reconstruction model - civ_map: Charge-in-volume survival probability map - tpc_r: TPC radius for boundary constraints - radius_buffer: Buffer for predictions beyond TPC radius - min_p: Minimum survival probability (for numerical stability) - n_samples: Number of samples for Monte Carlo estimation - curl_loss_multiplier: Weight for curl penalty term - scalar: False if using vector method, True if using scalar pot method - + key: JAX random key for sampling. + model: CNF model being trained. + condition: Hit pattern conditioning vector. + t1: ODE integration time (scaled z coordinate). + z: Physical z coordinate in cm (for CIV lookup). + posrec_model: Pretrained position reconstruction flow. + civ_map: Charge-in-volume survival probability interpolator. + tpc_r: TPC radius in cm for boundary constraints. + radius_buffer: Buffer beyond TPC radius for position sampling. + min_p: Minimum survival probability (numerical stability). + n_samples: Number of Monte Carlo samples. + curl_loss_multiplier: Weight for curl penalty term. + scalar: If True, skip curl penalty (using scalar potential method). Returns: - Negative log-likelihood loss value + Combined negative log-likelihood and curl penalty loss. """ keys = jax.random.split(key, 2 + n_samples) @@ -173,23 +179,26 @@ def likelihood_loss( scalar = False, **kwargs, ) -> float: - """Compute vectorized likelihood loss over a batch of data. + """Compute batched likelihood loss over multiple events. + + Vectorizes single_likelihood_loss over a batch of events and returns + the mean loss. Args: - model: CNF model to train - key: Random key for sampling - conditions: Batch of hit pattern conditions - t1s: Batch of scaled z coordinates - zs: Batch of physical z coordinates - posrec_model: Pretrained position reconstruction model - civ_map: Charge-in-volume survival probability map - tpc_r: TPC radius for boundary constraints - n_samples: Number of samples per data point - scalar: If True, omit curl loss component - **kwargs: Additional arguments passed to single_likelihood_loss + model: CNF model being trained. + key: JAX random key for sampling. + conditions: Batch of hit patterns, shape (batch_size, cond_dim). + t1s: Batch of ODE times (scaled z), shape (batch_size,). + zs: Batch of physical z coordinates, shape (batch_size,). + posrec_model: Pretrained position reconstruction flow. + civ_map: CIV survival probability interpolator. + tpc_r: TPC radius in cm. + n_samples: Monte Carlo samples per event. + scalar: If True, skip curl penalty. + **kwargs: Additional arguments for single_likelihood_loss. Returns: - Mean loss over the batch + Mean loss over the batch. """ keys = jax.random.split(key, len(zs)) vec_loss = eqx.filter_vmap( @@ -212,13 +221,21 @@ def likelihood_loss( def create_optimizer(config: "Config") -> optax.GradientTransformation: - """Create optimizer from configuration. + """Create AdamW optimizer with optional learning rate schedule. + + When enable_scheduler is True, uses a piecewise constant schedule: + - Epochs 0-19: learning_rate + - Epochs 20-69: learning_rate * 0.5 + - Epochs 70+: learning_rate * 0.1 Args: - config: Configuration object containing training parameters + config: Configuration with training.learning_rate, + training.weight_decay, training.enable_scheduler, and + training.multisteps_every_k. Returns: - Configured optax optimizer + Configured optax optimizer with gradient accumulation and + finite gradient checking. """ # Create learning rate schedule if config.training.enable_scheduler: @@ -254,11 +271,11 @@ def create_optimizer(config: "Config") -> optax.GradientTransformation: def save_model(model, path): - """Save model to disk. + """Save model weights to disk using equinox serialization. Args: - model: Trained model to save - path: Output path for the saved model + model: Equinox model to save. + path: File path for saved model (typically .eqx extension). """ eqx.tree_serialise_leaves(path, model) print(f"Model saved to {path}") @@ -289,39 +306,40 @@ def train( scalar: bool = False, epoch_start: int = 0, ) -> tuple[eqx.Module, list, list]: - """Train a continuous normalizing flow model. + """Train a CNF model with multi-GPU support. - This function implements the main training loop with epoch-level data - sharding for optimal multi-GPU performance. Data is resharded once per - epoch with batch-first distribution across devices. + Implements the main training loop with automatic data sharding across + multiple GPUs. Data is resharded once per epoch for optimal performance. + Saves periodic checkpoints and tracks train/validation loss history. Args: - key: Random key for training - model: CNF model to train - optim: Optimizer (from create_optimizer) - epochs: Number of training epochs - conditions: All hit pattern conditions - t1s: All scaled z coordinates - zs: All physical z coordinates - posrec_model: Pretrained position reconstruction model - civ_map: Charge-in-volume survival probability map - n_train: Number of training samples - n_batch: Batch size - n_samples: Samples per likelihood evaluation - n_test: Number of test samples - tpc_r: TPC radius for boundary constraints - radius_buffer: Buffer for predictions beyond TPC radius - use_best: Whether to return best model based on validation loss - save_iter: Every n number of epochs to save the model - save_file_name: Name of the file to save the model, default is "model" - output_path: Path to directory for saving, default current directory - loss_fn: Loss function to use - num_devices: Number of devices to use for data parallelization - scalar: If True, omit curl loss component - epoch_start: First epoch to begin training at + key: JAX random key for training. + model: CNF model to train. + optim: Optax optimizer (from create_optimizer). + epochs: Number of training epochs. + conditions: Full dataset of hit patterns. + t1s: Full dataset of scaled z coordinates (ODE times). + zs: Full dataset of physical z coordinates. + posrec_model: Frozen pretrained position reconstruction model. + civ_map: CIV survival probability interpolator. + n_train: Number of samples for training split. + n_batch: Batch size per training step. + n_samples: Monte Carlo samples per event in loss computation. + n_test: Number of samples for validation split. + tpc_r: TPC radius in cm. + radius_buffer: Buffer beyond TPC radius for position sampling. + use_best: If True, return model with lowest validation loss. + save_iter: Save checkpoint every N epochs. + save_file_name: Base filename for checkpoints. + output_path: Directory for saving checkpoints and loss logs. + loss_fn: Loss function (default: likelihood_loss). + num_devices: Number of GPUs for data parallelization. + scalar: If True, use scalar potential (no curl penalty). + epoch_start: Starting epoch number (for resuming training). Returns: - Tuple of (trained_model, train_loss_history, test_loss_history) + Tuple of (model, train_losses, val_losses, best_epoch) where + model is trained (or best if use_best=True). """ opt_state = optim.init(eqx.filter(model, eqx.is_array)) @@ -558,24 +576,24 @@ def train_model_from_config( config: "Config", output_path: str = "", ) -> tuple[eqx.Module, list, list]: - """Train model using configuration parameters. + """Train a CNF model using parameters from a Config object. - Convenience function that creates optimizer and calls train() with - parameters from the configuration object. + Convenience wrapper that extracts training parameters from the config + and calls train(). Creates the optimizer internally. Args: - key: Random key for training - model: CNF model to train - conditions: Hit pattern conditions - t1s: Scaled z coordinates - zs: Physical z coordinates - posrec_model: Pretrained position reconstruction model - civ_map: Charge-in-volume survival probability map - config: Configuration object - output_path: Path to directory for saving, default is current directory + key: JAX random key for training. + model: CNF model to train. + conditions: Full dataset of hit patterns. + t1s: Full dataset of scaled z coordinates. + zs: Full dataset of physical z coordinates. + posrec_model: Pretrained position reconstruction model. + civ_map: CIV survival probability interpolator. + config: Configuration object with all training parameters. + output_path: Directory for saving checkpoints. Returns: - Tuple of (trained_model, train_loss_history, test_loss_history) + Tuple of (model, train_losses, val_losses, best_epoch). """ optimizer = create_optimizer(config) diff --git a/src/fieldflow/utils.py b/src/fieldflow/utils.py index 11f452a..86dab79 100644 --- a/src/fieldflow/utils.py +++ b/src/fieldflow/utils.py @@ -1,4 +1,8 @@ -"""Utility functions""" +"""Utility functions for FieldFlow. + +This module provides helper functions for coordinate computations and +model manipulation. +""" import equinox as eqx import jax @@ -7,15 +11,13 @@ @jax.jit def compute_r(xy_arr): - """Compute radii from (x,y) coordinates. - - Direct extraction from notebook code. + """Compute radial distances from (x, y) coordinates. Args: - xy_arr: Array of shape (N, 2) containing x,y coordinates + xy_arr: Array of shape (N, 2) containing x, y coordinates. Returns: - Array of shape (N,) containing computed radii + Array of shape (N,) containing radial distances sqrt(x² + y²). """ return jnp.sqrt(xy_arr[:, 0] ** 2 + xy_arr[:, 1] ** 2) From 051e46ca94944329eedb15c4c2248f666e220455 Mon Sep 17 00:00:00 2001 From: Juehang Qin Date: Tue, 9 Dec 2025 22:48:59 -0600 Subject: [PATCH 04/10] linting change --- src/fieldflow/dataloader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fieldflow/dataloader.py b/src/fieldflow/dataloader.py index 323d77c..9e51559 100644 --- a/src/fieldflow/dataloader.py +++ b/src/fieldflow/dataloader.py @@ -1,8 +1,8 @@ """Data loading utilities for FieldFlow. This module provides functions for loading detector data required for -CNF training, including hit patterns and charge-insensitive-volume (CIV) survival -probability maps. +CNF training, including hit patterns and charge-insensitive-volume (CIV) +survival probability maps. """ from pathlib import Path From 6d4f7913bb7fbafda2b9864a5682633083c5fe02 Mon Sep 17 00:00:00 2001 From: Juehang Qin <39200111+juehang@users.noreply.github.com> Date: Wed, 10 Dec 2025 07:27:27 -0600 Subject: [PATCH 05/10] Update src/fieldflow/train.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/fieldflow/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fieldflow/train.py b/src/fieldflow/train.py index 3404437..f0e9039 100644 --- a/src/fieldflow/train.py +++ b/src/fieldflow/train.py @@ -2,7 +2,7 @@ This module provides the loss functions and training loop for learning electric field distortions in dual-phase TPCs. Training uses position -samples from a pretrained reconstruction flow weighted by charge-in-volume +samples from a pretrained reconstruction flow weighted by charge-insensitive-volume survival probabilities. """ From e238e855c1d479544271c9946d1a5554c250dc29 Mon Sep 17 00:00:00 2001 From: Juehang Qin <39200111+juehang@users.noreply.github.com> Date: Wed, 10 Dec 2025 07:27:52 -0600 Subject: [PATCH 06/10] Update src/fieldflow/train.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/fieldflow/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fieldflow/train.py b/src/fieldflow/train.py index f0e9039..47058e6 100644 --- a/src/fieldflow/train.py +++ b/src/fieldflow/train.py @@ -105,7 +105,7 @@ def single_likelihood_loss( t1: ODE integration time (scaled z coordinate). z: Physical z coordinate in cm (for CIV lookup). posrec_model: Pretrained position reconstruction flow. - civ_map: Charge-in-volume survival probability interpolator. + civ_map: Charge-insensitive-volume survival probability interpolator. tpc_r: TPC radius in cm for boundary constraints. radius_buffer: Buffer beyond TPC radius for position sampling. min_p: Minimum survival probability (numerical stability). From 13f8b88e74196a0dba90a130ae3c019546222800 Mon Sep 17 00:00:00 2001 From: Juehang Qin <39200111+juehang@users.noreply.github.com> Date: Wed, 10 Dec 2025 07:30:28 -0600 Subject: [PATCH 07/10] Update src/fieldflow/train.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/fieldflow/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fieldflow/train.py b/src/fieldflow/train.py index 47058e6..b670011 100644 --- a/src/fieldflow/train.py +++ b/src/fieldflow/train.py @@ -575,7 +575,7 @@ def train_model_from_config( civ_map: RegularGridInterpolator, config: "Config", output_path: str = "", -) -> tuple[eqx.Module, list, list]: +) -> tuple[eqx.Module, list, list, int]: """Train a CNF model using parameters from a Config object. Convenience wrapper that extracts training parameters from the config From a3be2906bd8ede60e5c6c1039aa4e3562b6d48b8 Mon Sep 17 00:00:00 2001 From: Juehang Qin <39200111+juehang@users.noreply.github.com> Date: Wed, 10 Dec 2025 07:30:40 -0600 Subject: [PATCH 08/10] Update src/fieldflow/train.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/fieldflow/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fieldflow/train.py b/src/fieldflow/train.py index b670011..5077e4c 100644 --- a/src/fieldflow/train.py +++ b/src/fieldflow/train.py @@ -305,7 +305,7 @@ def train( num_devices: int = 1, scalar: bool = False, epoch_start: int = 0, -) -> tuple[eqx.Module, list, list]: +) -> tuple[eqx.Module, list, list, int]: """Train a CNF model with multi-GPU support. Implements the main training loop with automatic data sharding across From f7a3a414e8b844a2324ca5621c0974597384e6de Mon Sep 17 00:00:00 2001 From: Juehang Qin Date: Wed, 10 Dec 2025 09:32:22 -0600 Subject: [PATCH 09/10] lint fix --- src/fieldflow/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fieldflow/train.py b/src/fieldflow/train.py index 5077e4c..ebff24c 100644 --- a/src/fieldflow/train.py +++ b/src/fieldflow/train.py @@ -2,8 +2,8 @@ This module provides the loss functions and training loop for learning electric field distortions in dual-phase TPCs. Training uses position -samples from a pretrained reconstruction flow weighted by charge-insensitive-volume -survival probabilities. +samples from a pretrained reconstruction flow weighted by +charge-insensitive-volume survival probabilities. """ import json From 5df323c85a90f5753ddd8cd985c0f31527197a19 Mon Sep 17 00:00:00 2001 From: napoliion Date: Thu, 11 Dec 2025 01:48:40 -0600 Subject: [PATCH 10/10] Update sample config file parameters --- sample_config.toml | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/sample_config.toml b/sample_config.toml index 0597ea2..b01e922 100644 --- a/sample_config.toml +++ b/sample_config.toml @@ -11,20 +11,21 @@ description = "Sample configuration for FieldFlow training" # Model architecture parameters data_size = 2 # Dimensionality of the input data (2D for x,y coordinates) exact_logp = true # Use exact log probability computation (more accurate but slower) -width_size = 48 # Width of neural network hidden layers -depth = 3 # Number of hidden layers in the neural network +width_size = 256 # Width of neural network hidden layers +depth = 16 # Number of hidden layers in the neural network +scalar = true # Use scalar potential model # ODE solver settings - these control the accuracy and efficiency of the flow -use_pid_controller = true # Use adaptive PIDController (recommended) vs constant step size +use_pid_controller = false # Use adaptive PIDController (recommended) vs constant step size rtol = 1e-3 # Relative tolerance for PIDController (smaller = more accurate, slower) atol = 1e-6 # Absolute tolerance for PIDController (smaller = more accurate, slower) -dtmax = 5.0 # Maximum step size for PIDController +dtmax = 2.0 # Maximum step size for PIDController dtmin = 0.05 # Minimum step size for PIDController # Time integration parameters t0 = 0.0 # Starting time for ODE integration extract_t1 = 10.0 # End time for extract phase -dt0 = 1.0 # Initial time step size +dt0 = 0.5 # Initial time step size [training] # Multi-GPU Training Support: @@ -39,45 +40,45 @@ dt0 = 1.0 # Initial time step size # Training process parameters seed = 42 # Random seed for reproducibility -learning_rate = 2e-3 # Initial learning rate (will be scheduled during training) +learning_rate = 1e-5 # Initial learning rate (will be scheduled during training) weight_decay = 1e-4 # L2 regularization parameter -epochs = 100 # Number of training epochs +epochs = 300 # Number of training epochs enable_scheduler = true # Enable learning rate scheduling for training from scratch # When false, uses constant LR = learning_rate * 0.01 # Set to false when loading pretrained models to continue training # Data and batching parameters -batch_size = 2048 # Training batch size (adjust based on GPU memory) +batch_size = 1024 # Training batch size (adjust based on GPU memory) num_devices = 1 # Number of GPUs for data parallelization # Examples: 1 (single GPU), 2 (dual GPU), 4 (quad GPU), 8 (octa GPU) # Note: batch_size should be divisible by num_devices for optimal performance n_samples = 16 # Number of samples per instance for likelihood estimation -n_train = 200000 # Size of training set -n_test = 20000 # Size of test/validation set +n_train = 65536 # Size of training set +n_test = 4096 # Size of test/validation set # Training strategy parameters use_best = true # Use the best model based on validation loss (recommended) curl_loss_multiplier = 1000.0 # Weight for curl penalty (encourages curl-free fields) z_scale = 5.0 # Scaling factor for z dimension coordinates -multisteps_every_k = 4 # Gradient accumulation steps for MultiSteps optimizer +multisteps_every_k = 2 # Gradient accumulation steps for MultiSteps optimizer [experiment] # Physical experimental setup parameters -tpc_height = 148.6515 # Height of the TPC in cm (for filtering z coordinates) +tpc_height = 259.92 # Height of the TPC in cm (for filtering z coordinates) tpc_r = 129.96 # Radius of the TPC in cm (for boundary constraints) [posrec] # Position reconstruction flow model parameters # These should match the pretrained position reconstruction model -flow_layers = 5 # Number of coupling layers in the flow +flow_layers = 6 # Number of coupling layers in the flow nn_width = 128 # Width of neural networks in coupling layers -nn_depth = 3 # Depth of neural networks in coupling layers +nn_depth = 6 # Depth of neural networks in coupling layers invert_bool = false # Whether to invert the flow (should match pretrained model) -cond_dim = 494 # Conditioning dimension (should match hit pattern size) +cond_dim = 860 # Conditioning dimension (should match hit pattern size) # Spline transformation parameters spline_knots = 5 # Number of knots for rational quadratic splines spline_interval = 5.0 # Interval for spline transformations # Coordinate transformation parameters -radius_buffer = 20.0 # Buffer for predictions beyond TPC radius (in cm) \ No newline at end of file +radius_buffer = 0.0 # Buffer for predictions beyond TPC radius (in cm) \ No newline at end of file