```python
    parameters = HMCParameters(
        jnp.ones(initial_state.position.shape[0], dtype=jnp.int32)
        * num_integration_steps,
        step_size,
        inverse_mass_matrix,
    )

    return last_chain_state, parameters, warmup_chain
```
We currently pass the parameters directly to the runtime. Although the parameter values are also passed in the `Trace` object, it would be more convenient to update the kernel's parameter values directly.
Referenced code (the snippet above): `mcx/mcx/inference/hmc.py`, lines 198 to 205 in 2a2b948.
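A minimal sketch of what updating the kernel's parameters could look like, assuming the kernel is an immutable object that stores an `HMCParameters` tuple. `HMCKernel`, its `update` method, and the field names given to `HMCParameters` here are hypothetical and not part of mcx's API; plain tuples stand in for JAX arrays so the example runs with the standard library alone, and the "adapted" values are made up for illustration.

```python
from dataclasses import dataclass, replace
from typing import NamedTuple, Tuple


class HMCParameters(NamedTuple):
    # Hypothetical field names; mcx's actual HMCParameters may differ.
    num_integration_steps: Tuple[int, ...]
    step_size: float
    inverse_mass_matrix: Tuple[float, ...]


@dataclass(frozen=True)
class HMCKernel:
    # Hypothetical kernel that stores its own parameters instead of
    # receiving them from the runtime on every call.
    parameters: HMCParameters

    def update(self, **changes) -> "HMCKernel":
        # NamedTuple._replace copies the parameters, changing only the given
        # fields; dataclasses.replace then builds a new (frozen) kernel.
        return replace(self, parameters=self.parameters._replace(**changes))


# Kernel built with default parameters before warmup.
kernel = HMCKernel(HMCParameters((1, 1), 1.0, (1.0, 1.0)))

# Pretend these values were produced by warmup (illustrative numbers only).
adapted = HMCParameters((10, 10), 0.05, (1.0, 0.5))

# Update the kernel in one step instead of threading `parameters` separately.
kernel = kernel.update(**adapted._asdict())

# A single value can also be changed, e.g. a smaller step size.
smaller = kernel.update(step_size=0.01)
```

Returning a new kernel rather than mutating the existing one keeps the sketch in the functional style JAX code generally favors; passing an unknown field name to `update` raises a `ValueError` from `_replace`.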