|
275 | 275 | }, |
276 | 276 | { |
277 | 277 | "cell_type": "code", |
278 | | - "execution_count": 8, |
| 278 | + "execution_count": null, |
279 | 279 | "metadata": {}, |
280 | 280 | "outputs": [], |
281 | 281 | "source": [ |
282 | | - "# step size for ode solver\n", |
283 | | - "step_size = 0.05\n", |
284 | | - "\n", |
285 | | - "norm = cm.colors.Normalize(vmax=50, vmin=0)\n", |
286 | | - "\n", |
287 | 282 | "batch_size = 50000 # batch size\n", |
288 | | - "eps_time = 1e-2\n", |
289 | | - "T = torch.linspace(0,1,10) # sample times\n", |
| 283 | + "T = torch.linspace(0,1,10) # sample times -> step size 0.1\n", |
290 | 284 | "T = T.to(device=device)\n", |
291 | 285 | "\n", |
292 | 286 | "x_init = torch.randn((batch_size, 2), dtype=torch.float32, device=device)\n", |
293 | 287 | "solver = ODESolver(velocity_model=wrapped_vf) # create an ODESolver class\n", |
294 | | - "sol = solver.sample(time_grid=T, x_init=x_init, method='midpoint', step_size=step_size, return_intermediates=True) # sample from the model" |
| 288 | + "sol = solver.sample(time_grid=T, x_init=x_init, method='midpoint', step_size=None, return_intermediates=True) # sample from the model" |
295 | 289 | ] |
296 | 290 | }, |
297 | 291 | { |
|
375 | 369 | }, |
376 | 370 | { |
377 | 371 | "cell_type": "code", |
378 | | - "execution_count": 12, |
| 372 | + "execution_count": null, |
379 | 373 | "metadata": {}, |
380 | 374 | "outputs": [], |
381 | 375 | "source": [ |
|
385 | 379 | "# compute log likelihood with unbiased hutchinson estimator, average over num_acc\n", |
386 | 380 | "num_acc = 10\n", |
387 | 381 | "log_p_acc = 0\n", |
| 382 | + "step_size = 0.5\n", |
388 | 383 | "\n", |
389 | 384 | "for i in range(num_acc):\n", |
390 | 385 | " _, log_p = solver.compute_likelihood(x_1=x_1, method='midpoint', step_size=step_size, exact_divergence=False, log_p0=gaussian_log_density)\n", |
|
0 commit comments