# Notebook setup: imports, GPU selection, and a fixed terminal size for tqdm.
from functools import partial
import os
import pickle

# These are set before jax and the JAX-based libraries below are imported.
# NOTE(review): presumably so the device selection takes effect at import
# time -- confirm before reordering any of the imports.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0" #select GPU, -1 means use CPU

import equinox as eqx
from flowjax.distributions import StandardNormal, Logistic, StudentT, AbstractDistribution, Normal
from flowjax.flows import masked_autoregressive_flow, block_neural_autoregressive_flow, triangular_spline_flow, coupling_flow
import h5py
import jax
import matplotlib.pyplot as plt
from matplotlib import colormaps
from matplotlib.colors import LogNorm
import numpy as np
import pandas as pd
from scipy import stats

from tqdm import trange, tqdm
import tqdm.utils as tutils
def ssl(x):
    """Return the fixed pair ``(100, 200)``; the argument ``x`` is ignored."""
    return 100, 200
# Replace tqdm's private screen-shape probe with the constant function above.
tutils._screen_shape_linux = ssl

# NOTE(review): names used later without an explicit import (e.g. `jnp`,
# `data_inv_transformation`, `flow_layers`) presumably come from this star
# import -- confirm against neural_net_defs.
from neural_net_defs import *

import gzip
import json
import optax
import diffrax
from jaxtyping import PyTree, Array
from copy import deepcopy

# Notebook display cell: lists the devices JAX can see.
jax.devices()

# NOTE(review): units are not stated in this file -- confirm.
tpc_r = 66.4
# Open the training archive and pull out the two arrays used below.
data_obj = np.load('kr83_sr1_50runs_FDCtrain.npz')

z = data_obj['z_corr']
conditions = data_obj['condition']

# Seed the PRNG, then peel off a sub-key that is used only to build the flow.
key = jax.random.PRNGKey(42)
key, flow_key = jax.random.split(key, 2)

# Notebook display cell.
conditions.shape

# Build a flow with the right structure, then overwrite its leaves with the
# weights stored on disk.
flow = coupling_flow(
    flow_key,
    base_dist=StandardNormal((2,)),
    invert=False,
    flow_layers=flow_layers,
    nn_width=NN_width,
    nn_depth=NN_depth,
    nn_activation=activation,
    cond_dim=conditions.shape[1],
    transformer=bijection,
)
flow = eqx.tree_deserialise_leaves(
    "../flow_posrec/posrec_flow_uniform_100e_2to5e_PMTs_turned_off.eqx", flow
)

# JIT-compiled sampler shared by every call to generate_samples.
compiled_sample = eqx.filter_jit(flow.sample)


def generate_samples(key, conditions, N_samples):
    """Sample the conditioned flow and map the draws back through
    ``data_inv_transformation``.

    ``compiled_sample`` is called with sample shape ``(N_samples,)`` and the
    given ``conditions``; its output is flattened to rows of two values
    before the inverse data transformation is applied.
    """
    raw = compiled_sample(key, (N_samples,), condition=conditions)
    flat = jnp.reshape(raw, (-1, 2))
    return data_inv_transformation(flat)
# Notebook display cell: shape of a small test draw.
generate_samples(key, conditions[0:4], 2).shape


def _layer_dims(in_size, width_size, out_size, depth):
    """Return the ``(in, out)`` sizes of the ``depth + 1`` stacked layers.

    ``depth == 0`` gives a single ``(in_size, out_size)`` layer; otherwise the
    stack is ``in_size -> width_size`` (``depth`` hidden widths in total)
    followed by ``width_size -> out_size``.
    """
    sizes = [in_size, *([width_size] * depth), out_size]
    return list(zip(sizes[:-1], sizes[1:]))


# Credit: this layer, and some of the default hyperparameters below, are taken from the
# FFJORD repo.
class ConcatSquash(eqx.Module):
    """Time-gated linear layer: ``lin1(y) * sigmoid(lin2(t)) + lin3(t)``."""

    lin1: eqx.nn.Linear
    lin2: eqx.nn.Linear
    lin3: eqx.nn.Linear

    def __init__(self, *, in_size, out_size, key, **kwargs):
        super().__init__(**kwargs)
        key1, key2, key3 = jax.random.split(key, 3)
        self.lin1 = eqx.nn.Linear(in_size, out_size, key=key1)
        # lin2 and lin3 both take the length-1 time vector as input.
        self.lin2 = eqx.nn.Linear(1, out_size, key=key2)
        self.lin3 = eqx.nn.Linear(1, out_size, use_bias=False, key=key3)

    def __call__(self, t, y):
        return self.lin1(y) * jax.nn.sigmoid(self.lin2(t)) + self.lin3(t)


class Func(eqx.Module):
    """Vector field ``f(t, y, args)`` built from ``ConcatSquash`` layers.

    Every layer except the last is followed by ``tanh``. ``args`` is accepted
    to match the ``(t, y, args)`` call signature but is not used.
    """

    # Annotated with the layer type that is actually stored.
    layers: list[ConcatSquash]

    def __init__(self, *, data_size, width_size, depth, key, **kwargs):
        super().__init__(**kwargs)
        # One key per layer; layer i is initialised from keys[i].
        keys = jax.random.split(key, depth + 1)
        dims = _layer_dims(data_size, width_size, data_size, depth)
        self.layers = [
            ConcatSquash(in_size=n_in, out_size=n_out, key=k)
            for (n_in, n_out), k in zip(dims, keys)
        ]

    def __call__(self, t, y, args):
        # Each layer's time inputs are Linear(1, ...), so t becomes shape (1,).
        t = jnp.asarray(t)[None]
        for layer in self.layers[:-1]:
            y = layer(t, y)
            y = jax.nn.tanh(y)
        y = self.layers[-1](t, y)
        return y


class MLPFunc(eqx.Module):
    """Vector field ``f(t, y, args)`` given by a plain MLP on ``concat(y, t)``.

    The first layer takes ``data_size + 1`` inputs (state plus time). Every
    layer except the last is followed by ``silu``. ``args`` is accepted to
    match the ``(t, y, args)`` call signature but is not used.
    """

    layers: list[eqx.nn.Linear]

    def __init__(self, *, data_size, width_size, depth, key, **kwargs):
        super().__init__(**kwargs)
        # One key per layer; layer i is initialised from keys[i].
        keys = jax.random.split(key, depth + 1)
        dims = _layer_dims(data_size + 1, width_size, data_size, depth)
        self.layers = [
            eqx.nn.Linear(n_in, n_out, key=k) for (n_in, n_out), k in zip(dims, keys)
        ]

    def __call__(self, t, y, args):
        t = jnp.asarray(t)[None]
        # Time enters only as one extra input feature appended to the state.
        y = jnp.concatenate((y, t), axis=-1)

        for layer in self.layers[:-1]:
            y = layer(y)
            y = jax.nn.silu(y)
        y = self.layers[-1](y)
        return y
def approx_logp_wrapper(t, y, args):
    """Augmented vector field using a single-probe divergence estimate.

    ``y`` is the pair ``(state, running_logp)``; only ``state`` is used.
    ``args`` must end with ``(eps, func)``. Returns ``(f, logp)`` where
    ``f = func(t, state, rest_of_args)`` and ``logp = sum(eps^T df/dy * eps)``.
    """
    y, _ = y
    *args, eps, func = args

    def fn(y):
        return func(t, y, args)

    f, vjp_fn = jax.vjp(fn, y)
    (eps_dfdy,) = vjp_fn(eps)
    logp = jnp.sum(eps_dfdy * eps)
    return f, logp


def exact_logp_wrapper(t, y, args):
    """Augmented vector field using the exact Jacobian trace.

    ``y`` is the pair ``(state, running_logp)``; only ``state`` is used.
    ``args`` must end with ``(eps, func)``; ``eps`` is ignored. Returns
    ``(f, logp)`` where ``logp = trace(df/dy)``.
    """
    y, _ = y
    *args, _, func = args

    def fn(y):
        return func(t, y, args)

    f, vjp_fn = jax.vjp(fn, y)
    (size,) = y.shape  # this implementation only works for 1D input
    eye = jnp.eye(size)
    # One VJP per basis vector yields the full Jacobian.
    (dfdy,) = jax.vmap(vjp_fn)(eye)
    logp = jnp.trace(dfdy)
    return f, logp


class CNF(eqx.Module):
    """Two-stage continuous flow with separate "extract" and "drift" fields.

    The forward direction integrates ``func_extract`` from ``t0`` to
    ``extract_t1`` and then ``func_drift`` from ``t0`` to the caller-supplied
    ``t1``. The inverse runs the two stages backwards in the opposite order.
    """

    func_drift: eqx.Module
    func_extract: eqx.Module
    data_size: int
    exact_logp: bool
    t0: float
    extract_t1: float
    dt0: float
    stepsizecontroller: diffrax.AbstractStepSizeController

    def __init__(
        self,
        *,
        data_size,
        exact_logp,
        width_size,
        depth,
        key,
        stepsizecontroller=diffrax.ConstantStepSize(),
        func=Func,
        **kwargs,
    ):
        keys = jax.random.split(key, 2)
        super().__init__(**kwargs)
        # Both fields share one architecture but get independent keys.
        self.func_drift = func(
            data_size=data_size,
            width_size=width_size,
            depth=depth,
            key=keys[0],
        )
        self.func_extract = func(
            data_size=data_size,
            width_size=width_size,
            depth=depth,
            key=keys[1],
        )
        self.data_size = data_size
        self.exact_logp = exact_logp
        self.t0 = 0
        self.extract_t1 = 10
        self.dt0 = 1
        self.stepsizecontroller = stepsizecontroller

    def _logp_term_and_noise(self, y, key):
        """Return the augmented ODE term and the probe noise it needs.

        Exact mode needs no noise, so ``None`` is returned for it.
        Approximate mode draws ``eps`` with the shape of ``y`` from ``key``.
        """
        if self.exact_logp:
            return diffrax.ODETerm(exact_logp_wrapper), None
        if key is None:
            raise ValueError(
                "a PRNG `key` is required when exact_logp is False"
            )
        return diffrax.ODETerm(approx_logp_wrapper), jax.random.normal(key, y.shape)

    def _solve_logp_stage(self, term, func, t_start, t_end, dt, state, eps):
        """Integrate one augmented stage and return ``(y, delta_log_likelihood)``."""
        sol = diffrax.diffeqsolve(
            term,
            diffrax.Tsit5(),
            t_start,
            t_end,
            dt,
            state,
            (eps, func),
            stepsize_controller=self.stepsizecontroller,
        )
        # Unpack the single saved value of each state component.
        (y,), (delta_log_likelihood,) = sol.ys
        return y, delta_log_likelihood

    def transform(self, *, y, t1):
        """Push ``y`` through the extract stage, then the drift stage up to ``t1``."""
        term = diffrax.ODETerm(self.func_extract)
        solver = diffrax.Tsit5()
        sol = diffrax.diffeqsolve(
            term,
            solver,
            self.t0,
            self.extract_t1,
            self.dt0,
            y,
            stepsize_controller=self.stepsizecontroller,
        )
        (y,) = sol.ys

        term = diffrax.ODETerm(self.func_drift)
        solver = diffrax.Tsit5()
        sol = diffrax.diffeqsolve(
            term,
            solver,
            self.t0,
            t1,
            self.dt0,
            y,
            stepsize_controller=self.stepsizecontroller,
        )
        (y,) = sol.ys
        return y

    def transform_and_log_det(self, *, y, t1, key=None):
        """Forward transform that also accumulates the log-density term.

        Args:
            y: input state.
            t1: end time of the drift stage.
            key: PRNG key for the probe noise; only needed (and then
                required) when ``exact_logp`` is False.

        Returns:
            ``(y, delta_log_likelihood)`` after both stages.

        Raises:
            ValueError: if ``exact_logp`` is False and ``key`` is None.
        """
        term, eps = self._logp_term_and_noise(y, key)
        state = (y, 0.0)
        state = self._solve_logp_stage(
            term, self.func_extract, self.t0, self.extract_t1, self.dt0, state, eps
        )
        return self._solve_logp_stage(
            term, self.func_drift, self.t0, t1, self.dt0, state, eps
        )

    def inverse_and_log_det(self, *, y, t1, key=None):
        """Inverse transform that also accumulates the log-density term.

        Runs the drift stage from ``t1`` back to ``t0``, then the extract
        stage from ``extract_t1`` back to ``t0``, both with step ``-dt0``.

        Args:
            y: input state.
            t1: start time of the (reversed) drift stage.
            key: PRNG key for the probe noise; only needed (and then
                required) when ``exact_logp`` is False.

        Returns:
            ``(y, delta_log_likelihood)`` after both stages.

        Raises:
            ValueError: if ``exact_logp`` is False and ``key`` is None.
        """
        term, eps = self._logp_term_and_noise(y, key)
        state = (y, 0.0)
        state = self._solve_logp_stage(
            term, self.func_drift, t1, self.t0, -self.dt0, state, eps
        )
        return self._solve_logp_stage(
            term, self.func_extract, self.extract_t1, self.t0, -self.dt0, state, eps
        )


from jax.scipy.interpolate import RegularGridInterpolator


def load_civ(
    civ_file_name="field_dependent_radius_depth_maps_B2d75n_C2d75n_G0d3p_A4d9p_T0d9n_PMTs1d3n_FSR0d65p_QPTFE_0d5n_0d4p.json.gz",
):
    """Load a gzipped JSON survival-probability map as an interpolator.

    The file must provide ``coordinate_system`` (one entry per axis, whose
    element ``[1]`` is unpacked into ``np.linspace`` and whose last value is
    used as that axis' length) and ``survival_probability_map`` (reshaped to
    those lengths). Points outside the grid evaluate to 0.

    Args:
        civ_file_name: path of the ``.json.gz`` map file.

    Returns:
        A ``jax.scipy.interpolate.RegularGridInterpolator``.
    """
    with gzip.open(civ_file_name, "rb") as f:
        file = json.load(f)

    axes = file['coordinate_system']
    grid = tuple(np.linspace(*ax[1]) for ax in axes)
    shape = [ax[1][-1] for ax in axes]
    values = np.array(file['survival_probability_map']).reshape(shape)
    civ_map = RegularGridInterpolator(
        grid,
        values,
        bounds_error=False,
        fill_value=0,
    )

    return civ_map


civ_map = load_civ()
vec_civ_map = jax.vmap(civ_map)

tpc_height = 148.6515
z_scale = 5
# Keep only events with -tpc_height < z < 0.
data_bool = (z>-tpc_height) & (z<0)
z_sel = z[data_bool]
# Flip the sign and divide by z_scale.
z_sel_scaled = -z_sel/z_scale
cond_sel = conditions[data_bool]
"metadata": {}, - "outputs": [], - "source": [ - "len(z_sel_scaled)/len(z)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "60d27398-22cd-4b52-9c66-972a48a69cde", - "metadata": {}, - "outputs": [], - "source": [ - "@jax.jit\n", - "def compute_r(xy_arr):\n", - " return jnp.sqrt(xy_arr[:,0]**2 + xy_arr[:,1]**2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d1af3831-5b8b-45c5-b511-014b6c90a8d0", - "metadata": {}, - "outputs": [], - "source": [ - "key, model_key = jax.random.split(key, 2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5c25ae05-4b92-47ca-b9d7-01712bbffbc6", - "metadata": {}, - "outputs": [], - "source": [ - "model = CNF(\n", - " data_size=2,\n", - " exact_logp=True,\n", - " width_size=48,\n", - " depth=3,\n", - " key=model_key,\n", - " stepsizecontroller=diffrax.PIDController(rtol=1e-3, atol=1e-6, dtmax=5),\n", - " func=MLPFunc\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "61f70527-822e-4c6e-ba24-8373ca65aa03", - "metadata": {}, - "outputs": [], - "source": [ - "def rolloff_func(x, rolloff=1e-2):\n", - " return x+rolloff*jnp.exp(-x/rolloff)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8da4c8b8-3acd-4aa5-9892-29e3ac28b779", - "metadata": {}, - "outputs": [], - "source": [ - "def curl_loss(key, model, z, x, extract_max_z=10.):\n", - " rand_z = jax.random.uniform(key, 1, minval=0.0, maxval=extract_max_z)\n", - " jac_drift = jax.jacfwd(lambda a:model.func_drift(z, a, 0.))(x)\n", - " jac_ex = jax.jacfwd(lambda a:model.func_extract(rand_z[0], a, 0.))(x)\n", - " return (jac_drift[1,0] - jac_drift[0,1])**2 + (jac_ex[1,0] - jac_ex[0,1])**2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a33d468a-4a68-4999-9d78-687b6298777b", - "metadata": {}, - "outputs": [], - "source": [ - "def single_likelihood_loss(key, model, condition, t1, z, min_p=1e-3, N_samples=4, tpc_r=66.4, 
curl_loss_multiplier=1000.):\n", - " keys = jax.random.split(key,2)\n", - " samples = generate_samples(keys[0], condition[np.newaxis,...], N_samples)\n", - " transformed_samples, logdet = eqx.filter_vmap(lambda y: model.transform_and_log_det(y=y, t1=t1))(samples)\n", - " sample_r = compute_r(transformed_samples)\n", - " p_surv = vec_civ_map(jnp.vstack((sample_r, np.repeat(z, N_samples))).T)\n", - " # p_surv = jnp.where(p_surv