From 82b3ab4088a82b3bc951175a0e9d5f9ee971a77a Mon Sep 17 00:00:00 2001 From: arrjon Date: Sat, 6 Sep 2025 11:31:27 +0200 Subject: [PATCH 001/101] allow tensor in DiagonalNormal dimension --- bayesflow/distributions/diagonal_normal.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/bayesflow/distributions/diagonal_normal.py b/bayesflow/distributions/diagonal_normal.py index 6b64445c7..25a7797df 100644 --- a/bayesflow/distributions/diagonal_normal.py +++ b/bayesflow/distributions/diagonal_normal.py @@ -57,7 +57,7 @@ def __init__( self.trainable_parameters = trainable_parameters self.seed_generator = seed_generator or keras.random.SeedGenerator() - self.dim = None + self.dims = None self._mean = None self._std = None @@ -65,10 +65,10 @@ def build(self, input_shape: Shape) -> None: if self.built: return - self.dim = int(input_shape[-1]) + self.dims = input_shape[1:] - self.mean = ops.cast(ops.broadcast_to(self.mean, (self.dim,)), "float32") - self.std = ops.cast(ops.broadcast_to(self.std, (self.dim,)), "float32") + self.mean = ops.cast(ops.broadcast_to(self.mean, self.dims), "float32") + self.std = ops.cast(ops.broadcast_to(self.std, self.dims), "float32") if self.trainable_parameters: self._mean = self.add_weight( @@ -91,14 +91,16 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor: result = -0.5 * ops.sum((samples - self._mean) ** 2 / self._std**2, axis=-1) if normalize: - log_normalization_constant = -0.5 * self.dim * math.log(2.0 * math.pi) - ops.sum(ops.log(self._std)) + log_normalization_constant = -0.5 * ops.sum(self.dims) * math.log(2.0 * math.pi) - ops.sum( + ops.log(self._std) + ) result += log_normalization_constant return result @allow_batch_size def sample(self, batch_shape: Shape) -> Tensor: - return self._mean + self._std * keras.random.normal(shape=batch_shape + (self.dim,), seed=self.seed_generator) + return self._mean + self._std * keras.random.normal(shape=batch_shape + self.dims, seed=self.seed_generator) def get_config(self): base_config = super().get_config() From 8fbf7374ca95826e6f7f954e237c0c8d1a955b95 Mon Sep 17 00:00:00 2001 From: arrjon Date: Sun, 7 Sep 2025 15:04:54 +0200 Subject: [PATCH 002/101] fix sum dims --- bayesflow/distributions/diagonal_normal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/distributions/diagonal_normal.py b/bayesflow/distributions/diagonal_normal.py index 25a7797df..9cf068137 100644 --- a/bayesflow/distributions/diagonal_normal.py +++ b/bayesflow/distributions/diagonal_normal.py @@ -91,7 +91,7 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor: result = -0.5 * ops.sum((samples - self._mean) ** 2 / self._std**2, axis=-1) if normalize: - log_normalization_constant = -0.5 * ops.sum(self.dims) * math.log(2.0 * math.pi) - ops.sum( + log_normalization_constant = -0.5 * np.sum(self.dims) * math.log(2.0 * math.pi) - ops.sum( ops.log(self._std) ) result += log_normalization_constant From 5c27246b3aa48f9c8ba400596200c20ada7251ae Mon Sep 17 00:00:00 2001 From: arrjon Date: Sun, 7 Sep 2025 15:08:16 +0200 Subject: [PATCH 003/101] fix batch_shape for sample --- bayesflow/approximators/continuous_approximator.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index fb2e95a56..f27a612f0 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -535,7 
+535,10 @@ def _sample( inference_conditions = keras.ops.broadcast_to( inference_conditions, (batch_size, num_samples, *keras.ops.shape(inference_conditions)[2:]) ) - batch_shape = keras.ops.shape(inference_conditions)[:-1] + batch_shape = ( + batch_size, + num_samples, + ) else: batch_shape = (num_samples,) From c684bcace2add939945da94af16f3de8b0f9cdc9 Mon Sep 17 00:00:00 2001 From: arrjon Date: Sun, 7 Sep 2025 17:55:10 +0200 Subject: [PATCH 004/101] dims to tuple --- bayesflow/distributions/diagonal_normal.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/bayesflow/distributions/diagonal_normal.py b/bayesflow/distributions/diagonal_normal.py index 9cf068137..83b3e556b 100644 --- a/bayesflow/distributions/diagonal_normal.py +++ b/bayesflow/distributions/diagonal_normal.py @@ -65,7 +65,7 @@ def build(self, input_shape: Shape) -> None: if self.built: return - self.dims = input_shape[1:] + self.dims = tuple(input_shape[1:]) self.mean = ops.cast(ops.broadcast_to(self.mean, self.dims), "float32") self.std = ops.cast(ops.broadcast_to(self.std, self.dims), "float32") @@ -91,9 +91,7 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor: result = -0.5 * ops.sum((samples - self._mean) ** 2 / self._std**2, axis=-1) if normalize: - log_normalization_constant = -0.5 * np.sum(self.dims) * math.log(2.0 * math.pi) - ops.sum( - ops.log(self._std) - ) + log_normalization_constant = -0.5 * sum(self.dims) * math.log(2.0 * math.pi) - ops.sum(ops.log(self._std)) result += log_normalization_constant return result From 06976344cb70d9e268392437d0240268cf3d4778 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 13:04:34 +0200 Subject: [PATCH 005/101] first draft compositional --- .../approximators/continuous_approximator.py | 164 +++++++++++ .../diffusion_model/diffusion_model.py | 261 ++++++++++++++++++ bayesflow/networks/inference_network.py | 17 ++ 3 files changed, 442 insertions(+) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index f27a612f0..5a183922f 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -638,3 +638,167 @@ def _batch_size_from_data(self, data: Mapping[str, any]) -> int: inference variables as present. """ return keras.ops.shape(data["inference_variables"])[0] + + def compositional_sample( + self, + *, + num_samples: int, + conditions: Mapping[str, np.ndarray], + split: bool = False, + **kwargs, + ) -> dict[str, np.ndarray]: + """ + Generates compositional samples from the approximator given input conditions. + The `conditions` dictionary should have shape (n_datasets, n_compositional_conditions, ...). + This method handles the extra compositional dimension appropriately. + + Parameters + ---------- + num_samples : int + Number of samples to generate. + conditions : dict[str, np.ndarray] + Dictionary of conditioning variables as NumPy arrays with shape + (n_datasets, n_compositional_conditions, ...). + split : bool, default=False + Whether to split the output arrays along the last axis and return one column vector per target variable + samples. + **kwargs : dict + Additional keyword arguments for the adapter and sampling process. + + Returns + ------- + dict[str, np.ndarray] + Dictionary containing generated samples with compositional structure preserved. 
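+
+        Examples
+        --------
+        Minimal usage sketch; the condition key, array shapes, and the output
+        key below are placeholders for whatever the fitted adapter maps to the
+        approximator's variables, not fixed names:
+
+        >>> # 4 datasets, each split into 16 compositional condition sets
+        >>> conditions = {"observables": np.random.rand(4, 16, 50, 2)}
+        >>> samples = approximator.compositional_sample(
+        ...     num_samples=1000, conditions=conditions
+        ... )
+        >>> samples["parameters"].shape  # -> (4, 1000, num_params) in this sketch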
+ """ + original_shapes = {} + flattened_conditions = {} + for key, value in conditions.items(): # Flatten compositional dimensions + original_shapes[key] = value.shape + n_datasets, n_comp = value.shape[:2] + flattened_shape = (n_datasets * n_comp,) + value.shape[2:] + flattened_conditions[key] = value.reshape(flattened_shape) + n_datasets, n_comp = original_shapes[next(iter(original_shapes))][:2] + + # Prepare data using existing method (handles adaptation and standardization) + prepared_conditions = self._prepare_data(flattened_conditions, **kwargs) + + # Remove any superfluous keys, just retain actual conditions + prepared_conditions = {k: v for k, v in prepared_conditions.items() if k in self.CONDITION_KEYS} + + # Sample using compositional sampling + samples = self._compositional_sample( + num_samples=num_samples, n_datasets=n_datasets, n_compositional=n_comp, **prepared_conditions, **kwargs + ) + + if "inference_variables" in self.standardize: + samples = self.standardize_layers["inference_variables"](samples, forward=False) + + samples = {"inference_variables": samples} + samples = keras.tree.map_structure(keras.ops.convert_to_numpy, samples) + + # Back-transform quantities and samples + samples = self._back_transform_compositional(samples, original_shapes, **kwargs) + + if split: + samples = split_arrays(samples, axis=-1) + return samples + + def _compositional_sample( + self, + num_samples: int, + n_datasets: int, + n_compositional: int, + inference_conditions: Tensor = None, + summary_variables: Tensor = None, + **kwargs, + ) -> Tensor: + """ + Internal method for compositional sampling. + """ + if self.summary_network is None: + if summary_variables is not None: + raise ValueError("Cannot use summary variables without a summary network.") + else: + if summary_variables is None: + raise ValueError("Summary variables are required when a summary network is present.") + + if self.summary_network is not None: + summary_outputs = self.summary_network( + summary_variables, **filter_kwargs(kwargs, self.summary_network.call) + ) + inference_conditions = concatenate_valid([inference_conditions, summary_outputs], axis=-1) + + if inference_conditions is not None: + # Reshape conditions for compositional sampling + # From (n_datasets * n_comp, dims) to (n_datasets, n_comp, dims) + condition_dims = keras.ops.shape(inference_conditions)[-1] + inference_conditions = keras.ops.reshape( + inference_conditions, (n_datasets, n_compositional, condition_dims) + ) + + # Expand for num_samples: (n_datasets, n_comp, dims) -> (n_datasets, n_comp, num_samples, dims) + inference_conditions = keras.ops.expand_dims(inference_conditions, axis=2) + inference_conditions = keras.ops.broadcast_to( + inference_conditions, (n_datasets, n_compositional, num_samples, condition_dims) + ) + + batch_shape = (n_datasets, n_compositional, num_samples) + else: + raise ValueError("Cannot perform compositional sampling without inference conditions.") + + return self.inference_network.sample( + batch_shape, + conditions=inference_conditions, + compositional=True, + **filter_kwargs(kwargs, self.inference_network.sample), + ) + + def _back_transform_compositional( + self, samples: dict[str, np.ndarray], original_shapes: dict[str, tuple], **kwargs + ) -> dict[str, np.ndarray]: + """ + Back-transform compositional samples, handling the extra compositional dimension. 
+ """ + # Get the sample shape to understand the compositional structure + inference_samples = samples["inference_variables"] + sample_shape = inference_samples.shape + + # Determine compositional dimensions from original shapes + # Assuming all condition keys have the same compositional structure + first_key = next(iter(original_shapes.keys())) + n_datasets, n_compositional = original_shapes[first_key][:2] + + # Reshape samples to match compositional structure if needed + if len(sample_shape) == 3: # (n_datasets * n_comp, num_samples, dims) + num_samples, dims = sample_shape[1], sample_shape[2] + inference_samples = inference_samples.reshape(n_datasets, n_compositional, num_samples, dims) + samples["inference_variables"] = inference_samples + + # For back-transformation, we might need to flatten again temporarily + # depending on how the adapter expects the data + flattened_samples = {} + for key, value in samples.items(): + if len(value.shape) == 4: # (n_datasets, n_comp, num_samples, dims) + n_d, n_c, n_s, dims = value.shape + flattened_samples[key] = value.reshape(n_d * n_c, n_s, dims) + else: + flattened_samples[key] = value + + # Apply inverse transformation + transformed = self.adapter(flattened_samples, inverse=True, strict=False, **kwargs) + + # Reshape back to compositional structure + final_samples = {} + for key, value in transformed.items(): + if key in original_shapes: + # Reshape to include compositional dimension + if len(value.shape) >= 2: + num_samples = value.shape[1] + remaining_dims = value.shape[2:] if len(value.shape) > 2 else () + final_samples[key] = value.reshape(n_datasets, n_compositional, num_samples, *remaining_dims) + else: + final_samples[key] = value + else: + final_samples[key] = value + + return final_samples diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index ca8a634e9..e815b89db 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -362,6 +362,7 @@ def _forward( conditions: Tensor = None, density: bool = False, training: bool = False, + compositional: bool = False, **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: integrate_kwargs = {"start_time": 0.0, "stop_time": 1.0} @@ -412,6 +413,7 @@ def _inverse( conditions: Tensor = None, density: bool = False, training: bool = False, + compositional: bool = False, **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} @@ -541,3 +543,262 @@ def compute_metrics( base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage) return base_metrics | {"loss": loss} + + def compositional_velocity( + self, + xz: Tensor, + time: float | Tensor, + stochastic_solver: bool, + conditions: Tensor, + training: bool = False, + ) -> Tensor: + """ + Computes the compositional velocity for multiple datasets using the formula: + s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) + + Parameters + ---------- + xz : Tensor + The current state of the latent variable, shape (n_datasets, n_compositional, ...) + time : float or Tensor + Time step for the diffusion process + stochastic_solver : bool + Whether to use stochastic (SDE) or deterministic (ODE) formulation + conditions : Tensor + Conditional inputs with compositional structure (n_datasets, n_compositional, ...) 
+ training : bool, optional + Whether in training mode + + Returns + ------- + Tensor + Compositional velocity of same shape as input xz + """ + if conditions is None: + raise ValueError("Conditions are required for compositional sampling") + + # Get shapes for compositional structure + n_datasets, n_compositional = ops.shape(xz)[0], ops.shape(xz)[1] + print(xz.shape, n_datasets, n_compositional) + + # Calculate standard noise schedule components + log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) + alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + + # Compute individual dataset scores + individual_scores = self._compute_individual_scores(xz, log_snr_t, alpha_t, sigma_t, conditions, training) + + # Compute prior score component + prior_score = self.compute_prior_score(xz) + + # Combine scores using compositional formula + # s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) + n = ops.cast(n_compositional, dtype=ops.dtype(time)) + time_tensor = ops.cast(time, dtype=ops.dtype(xz)) + + # Sum individual scores across compositional dimension + summed_individual_scores = ops.sum(individual_scores, axis=1, keepdims=True) + + # Prior contribution: (1-n)(1-t) * prior_score + prior_weight = (1.0 - n) * (1.0 - time_tensor) + weighted_prior = prior_weight * prior_score + + # Combined score + compositional_score = weighted_prior + summed_individual_scores + + # Broadcast back to full compositional shape + compositional_score = ops.broadcast_to(compositional_score, ops.shape(xz)) + + # Compute velocity using standard drift-diffusion formulation + f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) + + if stochastic_solver: + # SDE: dz = [f(z,t) - g(t)² * score(z,t)] dt + g(t) dW + velocity = f - g_squared * compositional_score + else: + # ODE: dz = [f(z,t) - 0.5 * g(t)² * score(z,t)] dt + velocity = f - 0.5 * g_squared * compositional_score + + print(velocity.shape) + return velocity + + def _compute_individual_scores( + self, + xz: Tensor, + log_snr_t: Tensor, + alpha_t: Tensor, + sigma_t: Tensor, + conditions: Tensor, + training: bool, + ) -> Tensor: + """ + Compute individual dataset scores s_ψ(θ,t,yᵢ) for each compositional condition. + + Returns + ------- + Tensor + Individual scores with shape (n_datasets, n_compositional, ...) 
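+
+        Notes
+        -----
+        These per-dataset scores are reduced in ``compositional_velocity``
+        following the composition rule quoted there; roughly (names and shapes
+        below are illustrative only):
+
+        >>> # individual_scores: (n_datasets, n_comp, num_samples, dims)
+        >>> summed = ops.sum(individual_scores, axis=1)
+        >>> score = (1.0 - n_comp) * (1.0 - t) * prior_score + summed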
+ """ + # Apply subnet to each compositional condition separately + transformed_log_snr = self._transform_log_snr(log_snr_t) + + # Reshape for processing: flatten compositional dimension temporarily + original_shape = ops.shape(xz) + n_datasets, n_comp = original_shape[0], original_shape[1] + remaining_dims = original_shape[2:] + + # Flatten for subnet application + xz_flat = ops.reshape(xz, (n_datasets * n_comp,) + remaining_dims) + log_snr_flat = ops.reshape(transformed_log_snr, (n_datasets * n_comp,) + ops.shape(transformed_log_snr)[2:]) + conditions_flat = ops.reshape(conditions, (n_datasets * n_comp,) + ops.shape(conditions)[2:]) + alpha_flat = ops.reshape(alpha_t, (n_datasets * n_comp,) + ops.shape(alpha_t)[2:]) + sigma_flat = ops.reshape(sigma_t, (n_datasets * n_comp,) + ops.shape(sigma_t)[2:]) + + # Apply subnet + subnet_out = self._apply_subnet(xz_flat, log_snr_flat, conditions=conditions_flat, training=training) + pred = self.output_projector(subnet_out, training=training) + + # Convert prediction to x + x_pred = self.convert_prediction_to_x( + pred=pred, z=xz_flat, alpha_t=alpha_flat, sigma_t=sigma_flat, log_snr_t=log_snr_flat + ) + + # Compute score: (α_t * x_pred - z) / σ_t² + score = (alpha_flat * x_pred - xz_flat) / ops.square(sigma_flat) + + # Reshape back to compositional structure + score = ops.reshape(score, original_shape) + + return score + + def _compositional_forward( + self, + x: Tensor, + conditions: Tensor = None, + density: bool = False, + training: bool = False, + **kwargs, + ) -> Tensor | tuple[Tensor, Tensor]: + """ + Forward pass for compositional diffusion. + """ + integrate_kwargs = {"start_time": 0.0, "stop_time": 1.0} + integrate_kwargs = integrate_kwargs | self.integrate_kwargs + integrate_kwargs = integrate_kwargs | kwargs + + if integrate_kwargs["method"] == "euler_maruyama": + raise ValueError("Stochastic methods are not supported for forward integration.") + + # x is sampled from a normal distribution, must be scaled with var 1/n_compositional + x = x / ops.sqrt(ops.cast(ops.shape(x)[1], dtype=ops.dtype(x))) + + if density: + + def deltas(time, xz): + v = self.compositional_velocity( + xz, time=time, stochastic_solver=False, conditions=conditions, training=training + ) + # For density, we need trace but compositional trace is complex + # Simplified version - could be extended + trace = ops.zeros(ops.shape(xz)[:-1] + (1,), dtype=ops.dtype(xz)) + return {"xz": v, "trace": trace} + + state = { + "xz": x, + "trace": ops.zeros(ops.shape(x)[:-1] + (1,), dtype=ops.dtype(x)), + } + state = integrate(deltas, state, **integrate_kwargs) + + z = state["xz"] + # Simplified density computation + log_density = self.base_distribution.log_prob(ops.mean(z, axis=1)) + ops.squeeze(state["trace"], axis=-1) + return z, log_density + + def deltas(time, xz): + return { + "xz": self.compositional_velocity( + xz, time=time, stochastic_solver=False, conditions=conditions, training=training + ) + } + + state = {"xz": x} + state = integrate(deltas, state, **integrate_kwargs) + z = state["xz"] + return z + + def _compositional_inverse( + self, + z: Tensor, + conditions: Tensor = None, + density: bool = False, + training: bool = False, + **kwargs, + ) -> Tensor | tuple[Tensor, Tensor]: + """ + Inverse pass for compositional diffusion (sampling). 
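+
+        Notes
+        -----
+        This path is intended to be reached through ``InferenceNetwork.sample``
+        once ``compositional=True`` is forwarded to ``call``; a rough sketch of
+        the equivalent direct call (tensor shapes are illustrative only):
+
+        >>> z = diffusion_model.base_distribution.sample((n_datasets, num_samples))
+        >>> theta = diffusion_model(z, conditions=conditions, inverse=True, compositional=True)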
+ """ + integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} + integrate_kwargs = integrate_kwargs | self.integrate_kwargs + integrate_kwargs = integrate_kwargs | kwargs + + if density: + if integrate_kwargs["method"] == "euler_maruyama": + raise ValueError("Stochastic methods are not supported for density computation.") + + def deltas(time, xz): + v = self.compositional_velocity( + xz, time=time, stochastic_solver=False, conditions=conditions, training=training + ) + trace = ops.zeros(ops.shape(xz)[:-1] + (1,), dtype=ops.dtype(xz)) + return {"xz": v, "trace": trace} + + state = { + "xz": z, + "trace": ops.zeros(ops.shape(z)[:-1] + (1,), dtype=ops.dtype(z)), + } + state = integrate(deltas, state, **integrate_kwargs) + + x = state["xz"] + log_density = self.base_distribution.log_prob(ops.mean(z, axis=1)) - ops.squeeze(state["trace"], axis=-1) + return x, log_density + + state = {"xz": z} + + if integrate_kwargs["method"] == "euler_maruyama": + + def deltas(time, xz): + return { + "xz": self.compositional_velocity( + xz, time=time, stochastic_solver=True, conditions=conditions, training=training + ) + } + + def diffusion(time, xz): + return {"xz": self.diffusion_term(xz, time=time, training=training)} + + state = integrate_stochastic( + drift_fn=deltas, + diffusion_fn=diffusion, + state=state, + seed=self.seed_generator, + **integrate_kwargs, + ) + else: + + def deltas(time, xz): + return { + "xz": self.compositional_velocity( + xz, time=time, stochastic_solver=False, conditions=conditions, training=training + ) + } + + state = integrate(deltas, state, **integrate_kwargs) + + x = state["xz"] + return x + + @staticmethod + def compute_prior_score(xz: Tensor) -> Tensor: + return ops.ones_like(xz) # todo: Placeholder implementation + # raise NotImplementedError('Please implement the prior score computation method.') diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py index b092ce2cb..f2e5c512f 100644 --- a/bayesflow/networks/inference_network.py +++ b/bayesflow/networks/inference_network.py @@ -27,11 +27,18 @@ def call( conditions: Tensor = None, inverse: bool = False, density: bool = False, + compositional: bool = False, training: bool = False, **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: if inverse: + if compositional: + return self._inverse_compositional( + xz, conditions=conditions, density=density, training=training, **kwargs + ) return self._inverse(xz, conditions=conditions, density=density, training=training, **kwargs) + if compositional: + return self._forward_compositional(xz, conditions=conditions, density=density, training=training, **kwargs) return self._forward(xz, conditions=conditions, density=density, training=training, **kwargs) def _forward( @@ -44,6 +51,16 @@ def _inverse( ) -> Tensor | tuple[Tensor, Tensor]: raise NotImplementedError + def _forward_compositional( + self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs + ) -> Tensor | tuple[Tensor, Tensor]: + raise NotImplementedError + + def _inverse_compositional( + self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs + ) -> Tensor | tuple[Tensor, Tensor]: + raise NotImplementedError + @allow_batch_size def sample(self, batch_shape: Shape, conditions: Tensor = None, **kwargs) -> Tensor: samples = self.base_distribution.sample(batch_shape) From b8e849e4393f3a5199c39beb9db6a7ed710a4063 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 13:12:52 +0200 Subject: [PATCH 006/101] 
first draft compositional --- bayesflow/workflows/basic_workflow.py | 31 +++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/bayesflow/workflows/basic_workflow.py b/bayesflow/workflows/basic_workflow.py index 34fa03794..f1941ac3a 100644 --- a/bayesflow/workflows/basic_workflow.py +++ b/bayesflow/workflows/basic_workflow.py @@ -286,6 +286,37 @@ def sample( """ return self.approximator.sample(num_samples=num_samples, conditions=conditions, **kwargs) + def compositional_sample( + self, + *, + num_samples: int, + conditions: Mapping[str, np.ndarray], + **kwargs, + ) -> dict[str, np.ndarray]: + """ + Draws `num_samples` samples from the approximator given specified composition conditions. + The `conditions` dictionary should have shape (n_datasets, n_compositional_conditions, ...). + + Parameters + ---------- + num_samples : int + The number of samples to generate. + conditions : dict[str, np.ndarray] + A dictionary where keys represent variable names and values are + NumPy arrays containing the adapted simulated variables. Keys used as summary or inference + conditions during training should be present. + Should have shape (n_datasets, n_compositional_conditions, ...). + **kwargs : dict, optional + Additional keyword arguments passed to the approximator's sampling function. + + Returns + ------- + dict[str, np.ndarray] + A dictionary where keys correspond to variable names and + values are arrays containing the generated samples. + """ + return self.approximator.compositional_sample(num_samples=num_samples, conditions=conditions, **kwargs) + def estimate( self, *, From a280af32c5dd6ea9c657079b8e08116f7cde374b Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 13:15:16 +0200 Subject: [PATCH 007/101] first draft compositional --- bayesflow/networks/diffusion_model/diffusion_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index e815b89db..fa32b2bc6 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -672,7 +672,7 @@ def _compute_individual_scores( return score - def _compositional_forward( + def _forward_compositional( self, x: Tensor, conditions: Tensor = None, @@ -727,7 +727,7 @@ def deltas(time, xz): z = state["xz"] return z - def _compositional_inverse( + def _inverse_compositional( self, z: Tensor, conditions: Tensor = None, From b9faf31104821473f76c9dfa773cb403f7f82190 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 14:03:01 +0200 Subject: [PATCH 008/101] first draft compositional --- .../approximators/continuous_approximator.py | 23 ++++++++++--------- .../diffusion_model/diffusion_model.py | 6 ++--- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index 5a183922f..c28424d6f 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -730,19 +730,19 @@ def _compositional_sample( if inference_conditions is not None: # Reshape conditions for compositional sampling - # From (n_datasets * n_comp, dims) to (n_datasets, n_comp, dims) - condition_dims = keras.ops.shape(inference_conditions)[-1] + # From (n_datasets * n_comp, ...., dims) to (n_datasets, n_comp, ...., dims) + condition_dims = keras.ops.shape(inference_conditions)[2:] inference_conditions = 
keras.ops.reshape( - inference_conditions, (n_datasets, n_compositional, condition_dims) + inference_conditions, (n_datasets, n_compositional, *condition_dims) ) # Expand for num_samples: (n_datasets, n_comp, dims) -> (n_datasets, n_comp, num_samples, dims) inference_conditions = keras.ops.expand_dims(inference_conditions, axis=2) inference_conditions = keras.ops.broadcast_to( - inference_conditions, (n_datasets, n_compositional, num_samples, condition_dims) + inference_conditions, (n_datasets, n_compositional, num_samples, *condition_dims) ) - batch_shape = (n_datasets, n_compositional, num_samples) + batch_shape = (n_datasets, num_samples) else: raise ValueError("Cannot perform compositional sampling without inference conditions.") @@ -769,18 +769,19 @@ def _back_transform_compositional( n_datasets, n_compositional = original_shapes[first_key][:2] # Reshape samples to match compositional structure if needed - if len(sample_shape) == 3: # (n_datasets * n_comp, num_samples, dims) - num_samples, dims = sample_shape[1], sample_shape[2] - inference_samples = inference_samples.reshape(n_datasets, n_compositional, num_samples, dims) + if len(sample_shape) == 3: # (n_datasets * n_comp, num_samples, ..., dims) + num_samples, dims = sample_shape[1], sample_shape[2:] + inference_samples = inference_samples.reshape(n_datasets, n_compositional, num_samples, *dims) samples["inference_variables"] = inference_samples # For back-transformation, we might need to flatten again temporarily # depending on how the adapter expects the data flattened_samples = {} for key, value in samples.items(): - if len(value.shape) == 4: # (n_datasets, n_comp, num_samples, dims) - n_d, n_c, n_s, dims = value.shape - flattened_samples[key] = value.reshape(n_d * n_c, n_s, dims) + if len(value.shape) == 4: # (n_datasets, n_comp, num_samples, ..., dims) + n_d, n_c, n_s = value.shape[:3] + dims = value.shape[3:] + flattened_samples[key] = value.reshape(n_d * n_c, n_s, *dims) else: flattened_samples[key] = value diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index fa32b2bc6..157bc360c 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -578,8 +578,8 @@ def compositional_velocity( raise ValueError("Conditions are required for compositional sampling") # Get shapes for compositional structure - n_datasets, n_compositional = ops.shape(xz)[0], ops.shape(xz)[1] - print(xz.shape, n_datasets, n_compositional) + n_compositional = ops.shape(conditions)[1] + print(ops.shape(xz), ops.shape(conditions)) # Calculate standard noise schedule components log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) @@ -620,7 +620,7 @@ def compositional_velocity( # ODE: dz = [f(z,t) - 0.5 * g(t)² * score(z,t)] dt velocity = f - 0.5 * g_squared * compositional_score - print(velocity.shape) + print(velocity.shape, velocity) return velocity def _compute_individual_scores( From 9b7eb1696ee04bbff2b72e75950e51984de2e611 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 14:11:20 +0200 Subject: [PATCH 009/101] fix shapes --- bayesflow/approximators/continuous_approximator.py | 2 +- bayesflow/networks/diffusion_model/diffusion_model.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index c28424d6f..12a0878af 100644 --- 
a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -731,7 +731,7 @@ def _compositional_sample( if inference_conditions is not None: # Reshape conditions for compositional sampling # From (n_datasets * n_comp, ...., dims) to (n_datasets, n_comp, ...., dims) - condition_dims = keras.ops.shape(inference_conditions)[2:] + condition_dims = keras.ops.shape(inference_conditions)[1:] inference_conditions = keras.ops.reshape( inference_conditions, (n_datasets, n_compositional, *condition_dims) ) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 157bc360c..132cf2f80 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -579,7 +579,7 @@ def compositional_velocity( # Get shapes for compositional structure n_compositional = ops.shape(conditions)[1] - print(ops.shape(xz), ops.shape(conditions)) + print(ops.shape(xz), ops.shape(conditions)) # (1, 100, 2), (1, 2, 100, 2) # Calculate standard noise schedule components log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) @@ -644,12 +644,13 @@ def _compute_individual_scores( transformed_log_snr = self._transform_log_snr(log_snr_t) # Reshape for processing: flatten compositional dimension temporarily - original_shape = ops.shape(xz) + original_shape = ops.shape(conditions) n_datasets, n_comp = original_shape[0], original_shape[1] - remaining_dims = original_shape[2:] # Flatten for subnet application - xz_flat = ops.reshape(xz, (n_datasets * n_comp,) + remaining_dims) + xz_flat = ops.expand_dims(xz, axis=1) # (n_datasets, 1, ...) + xz_flat = ops.broadcast_to(xz_flat, (n_datasets, n_comp) + ops.shape(xz)[1:]) + xz_flat = ops.reshape(xz_flat, (n_datasets * n_comp,) + ops.shape(xz)[1:]) log_snr_flat = ops.reshape(transformed_log_snr, (n_datasets * n_comp,) + ops.shape(transformed_log_snr)[2:]) conditions_flat = ops.reshape(conditions, (n_datasets * n_comp,) + ops.shape(conditions)[2:]) alpha_flat = ops.reshape(alpha_t, (n_datasets * n_comp,) + ops.shape(alpha_t)[2:]) From e79aac11555d37b192d0cbf26da092f2814a4f12 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 14:12:51 +0200 Subject: [PATCH 010/101] fix shapes --- bayesflow/networks/diffusion_model/diffusion_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 132cf2f80..7233a5ce5 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -598,7 +598,7 @@ def compositional_velocity( time_tensor = ops.cast(time, dtype=ops.dtype(xz)) # Sum individual scores across compositional dimension - summed_individual_scores = ops.sum(individual_scores, axis=1, keepdims=True) + summed_individual_scores = ops.sum(individual_scores, axis=1) # Prior contribution: (1-n)(1-t) * prior_score prior_weight = (1.0 - n) * (1.0 - time_tensor) @@ -608,7 +608,7 @@ def compositional_velocity( compositional_score = weighted_prior + summed_individual_scores # Broadcast back to full compositional shape - compositional_score = ops.broadcast_to(compositional_score, ops.shape(xz)) + # compositional_score = ops.broadcast_to(compositional_score, ops.shape(xz)) # Compute velocity using standard drift-diffusion formulation f, g_squared = 
self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) From 8a802409e649012d61f9571268c5fb291a138d86 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 14:29:51 +0200 Subject: [PATCH 011/101] fix shapes --- .../diffusion_model/diffusion_model.py | 45 +++++++++++++------ 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 7233a5ce5..666e1986b 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -583,10 +583,12 @@ def compositional_velocity( # Calculate standard noise schedule components log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) - log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:1] + (1,)) alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) # Compute individual dataset scores + print(xz.shape, log_snr_t.shape, alpha_t.shape, sigma_t.shape, conditions.shape) + # (1, 100, 2) (1, 100, 1) (1, 100, 1) (1, 100, 1) (1, 2, 100, 2) individual_scores = self._compute_individual_scores(xz, log_snr_t, alpha_t, sigma_t, conditions, training) # Compute prior score component @@ -643,18 +645,34 @@ def _compute_individual_scores( # Apply subnet to each compositional condition separately transformed_log_snr = self._transform_log_snr(log_snr_t) - # Reshape for processing: flatten compositional dimension temporarily - original_shape = ops.shape(conditions) - n_datasets, n_comp = original_shape[0], original_shape[1] + # Get shapes + xz_shape = ops.shape(xz) # (n_datasets, num_samples, ..., dims) + conditions_shape = ops.shape(conditions) # (n_datasets, n_compositional, num_samples, ..., dims) + n_datasets, n_compositional = conditions_shape[0], conditions_shape[1] + conditions_dims = tuple(conditions_shape[3:]) + num_samples = xz_shape[1] + dims = tuple(xz_shape[2:]) + + # Expand xz to match compositional structure + xz_expanded = ops.expand_dims(xz, axis=1) # (n_datasets, 1, num_samples, ..., dims) + xz_expanded = ops.broadcast_to(xz_expanded, (n_datasets, n_compositional, num_samples) + dims) + + # Expand noise schedule components to match compositional structure + log_snr_expanded = ops.expand_dims(transformed_log_snr, axis=1) + log_snr_expanded = ops.broadcast_to(log_snr_expanded, (n_datasets, n_compositional, num_samples) + dims) - # Flatten for subnet application - xz_flat = ops.expand_dims(xz, axis=1) # (n_datasets, 1, ...) 
- xz_flat = ops.broadcast_to(xz_flat, (n_datasets, n_comp) + ops.shape(xz)[1:]) - xz_flat = ops.reshape(xz_flat, (n_datasets * n_comp,) + ops.shape(xz)[1:]) - log_snr_flat = ops.reshape(transformed_log_snr, (n_datasets * n_comp,) + ops.shape(transformed_log_snr)[2:]) - conditions_flat = ops.reshape(conditions, (n_datasets * n_comp,) + ops.shape(conditions)[2:]) - alpha_flat = ops.reshape(alpha_t, (n_datasets * n_comp,) + ops.shape(alpha_t)[2:]) - sigma_flat = ops.reshape(sigma_t, (n_datasets * n_comp,) + ops.shape(sigma_t)[2:]) + alpha_expanded = ops.expand_dims(alpha_t, axis=1) + alpha_expanded = ops.broadcast_to(alpha_expanded, (n_datasets, n_compositional, num_samples) + dims) + + sigma_expanded = ops.expand_dims(sigma_t, axis=1) + sigma_expanded = ops.broadcast_to(sigma_expanded, (n_datasets, n_compositional, num_samples) + dims) + + # Flatten for subnet application: (n_datasets * n_compositional, num_samples, ..., dims) + xz_flat = ops.reshape(xz_expanded, (n_datasets * n_compositional, num_samples) + dims) + log_snr_flat = ops.reshape(log_snr_expanded, (n_datasets * n_compositional, num_samples) + dims) + alpha_flat = ops.reshape(alpha_expanded, (n_datasets * n_compositional, num_samples) + dims) + sigma_flat = ops.reshape(sigma_expanded, (n_datasets * n_compositional, num_samples) + dims) + conditions_flat = ops.reshape(conditions, (n_datasets * n_compositional, num_samples) + conditions_dims) # Apply subnet subnet_out = self._apply_subnet(xz_flat, log_snr_flat, conditions=conditions_flat, training=training) @@ -669,8 +687,7 @@ def _compute_individual_scores( score = (alpha_flat * x_pred - xz_flat) / ops.square(sigma_flat) # Reshape back to compositional structure - score = ops.reshape(score, original_shape) - + score = ops.reshape(score, (n_datasets, n_compositional, num_samples)) return score def _forward_compositional( From 00fbc619626db2e5b180a9e9c0fc0a2025600fe8 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 14:31:30 +0200 Subject: [PATCH 012/101] fix shapes --- bayesflow/networks/diffusion_model/diffusion_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 666e1986b..3943b8986 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -583,7 +583,7 @@ def compositional_velocity( # Calculate standard noise schedule components log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) - log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:1] + (1,)) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) # Compute individual dataset scores From e6158e7c76b418b5c37a35579b16ddc8495e8fab Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 14:33:37 +0200 Subject: [PATCH 013/101] fix shapes --- .../networks/diffusion_model/diffusion_model.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 3943b8986..94f278b7d 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -579,7 +579,6 @@ def compositional_velocity( # Get shapes for compositional structure n_compositional = ops.shape(conditions)[1] - print(ops.shape(xz), 
ops.shape(conditions)) # (1, 100, 2), (1, 2, 100, 2) # Calculate standard noise schedule components log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) @@ -587,8 +586,6 @@ def compositional_velocity( alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) # Compute individual dataset scores - print(xz.shape, log_snr_t.shape, alpha_t.shape, sigma_t.shape, conditions.shape) - # (1, 100, 2) (1, 100, 1) (1, 100, 1) (1, 100, 1) (1, 2, 100, 2) individual_scores = self._compute_individual_scores(xz, log_snr_t, alpha_t, sigma_t, conditions, training) # Compute prior score component @@ -659,19 +656,19 @@ def _compute_individual_scores( # Expand noise schedule components to match compositional structure log_snr_expanded = ops.expand_dims(transformed_log_snr, axis=1) - log_snr_expanded = ops.broadcast_to(log_snr_expanded, (n_datasets, n_compositional, num_samples) + dims) + log_snr_expanded = ops.broadcast_to(log_snr_expanded, (n_datasets, n_compositional, num_samples, 1)) alpha_expanded = ops.expand_dims(alpha_t, axis=1) - alpha_expanded = ops.broadcast_to(alpha_expanded, (n_datasets, n_compositional, num_samples) + dims) + alpha_expanded = ops.broadcast_to(alpha_expanded, (n_datasets, n_compositional, num_samples, 1)) sigma_expanded = ops.expand_dims(sigma_t, axis=1) - sigma_expanded = ops.broadcast_to(sigma_expanded, (n_datasets, n_compositional, num_samples) + dims) + sigma_expanded = ops.broadcast_to(sigma_expanded, (n_datasets, n_compositional, num_samples, 1)) # Flatten for subnet application: (n_datasets * n_compositional, num_samples, ..., dims) xz_flat = ops.reshape(xz_expanded, (n_datasets * n_compositional, num_samples) + dims) - log_snr_flat = ops.reshape(log_snr_expanded, (n_datasets * n_compositional, num_samples) + dims) - alpha_flat = ops.reshape(alpha_expanded, (n_datasets * n_compositional, num_samples) + dims) - sigma_flat = ops.reshape(sigma_expanded, (n_datasets * n_compositional, num_samples) + dims) + log_snr_flat = ops.reshape(log_snr_expanded, (n_datasets * n_compositional, num_samples, 1)) + alpha_flat = ops.reshape(alpha_expanded, (n_datasets * n_compositional, num_samples, 1)) + sigma_flat = ops.reshape(sigma_expanded, (n_datasets * n_compositional, num_samples, 1)) conditions_flat = ops.reshape(conditions, (n_datasets * n_compositional, num_samples) + conditions_dims) # Apply subnet From 1ac39b2521212647cdcc41de5010510a26615c8b Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 14:38:34 +0200 Subject: [PATCH 014/101] fix shapes --- .../approximators/continuous_approximator.py | 53 +------------------ .../diffusion_model/diffusion_model.py | 3 +- 2 files changed, 2 insertions(+), 54 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index 12a0878af..c8bd77a57 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -697,7 +697,7 @@ def compositional_sample( samples = keras.tree.map_structure(keras.ops.convert_to_numpy, samples) # Back-transform quantities and samples - samples = self._back_transform_compositional(samples, original_shapes, **kwargs) + samples = self.adapter(samples, inverse=True, strict=False, **kwargs) if split: samples = split_arrays(samples, axis=-1) @@ -752,54 +752,3 @@ def _compositional_sample( compositional=True, **filter_kwargs(kwargs, self.inference_network.sample), ) - - def _back_transform_compositional( - self, samples: dict[str, 
np.ndarray], original_shapes: dict[str, tuple], **kwargs - ) -> dict[str, np.ndarray]: - """ - Back-transform compositional samples, handling the extra compositional dimension. - """ - # Get the sample shape to understand the compositional structure - inference_samples = samples["inference_variables"] - sample_shape = inference_samples.shape - - # Determine compositional dimensions from original shapes - # Assuming all condition keys have the same compositional structure - first_key = next(iter(original_shapes.keys())) - n_datasets, n_compositional = original_shapes[first_key][:2] - - # Reshape samples to match compositional structure if needed - if len(sample_shape) == 3: # (n_datasets * n_comp, num_samples, ..., dims) - num_samples, dims = sample_shape[1], sample_shape[2:] - inference_samples = inference_samples.reshape(n_datasets, n_compositional, num_samples, *dims) - samples["inference_variables"] = inference_samples - - # For back-transformation, we might need to flatten again temporarily - # depending on how the adapter expects the data - flattened_samples = {} - for key, value in samples.items(): - if len(value.shape) == 4: # (n_datasets, n_comp, num_samples, ..., dims) - n_d, n_c, n_s = value.shape[:3] - dims = value.shape[3:] - flattened_samples[key] = value.reshape(n_d * n_c, n_s, *dims) - else: - flattened_samples[key] = value - - # Apply inverse transformation - transformed = self.adapter(flattened_samples, inverse=True, strict=False, **kwargs) - - # Reshape back to compositional structure - final_samples = {} - for key, value in transformed.items(): - if key in original_shapes: - # Reshape to include compositional dimension - if len(value.shape) >= 2: - num_samples = value.shape[1] - remaining_dims = value.shape[2:] if len(value.shape) > 2 else () - final_samples[key] = value.reshape(n_datasets, n_compositional, num_samples, *remaining_dims) - else: - final_samples[key] = value - else: - final_samples[key] = value - - return final_samples diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 94f278b7d..69e6db619 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -619,7 +619,6 @@ def compositional_velocity( # ODE: dz = [f(z,t) - 0.5 * g(t)² * score(z,t)] dt velocity = f - 0.5 * g_squared * compositional_score - print(velocity.shape, velocity) return velocity def _compute_individual_scores( @@ -684,7 +683,7 @@ def _compute_individual_scores( score = (alpha_flat * x_pred - xz_flat) / ops.square(sigma_flat) # Reshape back to compositional structure - score = ops.reshape(score, (n_datasets, n_compositional, num_samples)) + score = ops.reshape(score, (n_datasets, n_compositional, num_samples) + dims) return score def _forward_compositional( From 9fd9cf887d09bc29c03b40f7a9f802a55fe2971e Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 14:52:17 +0200 Subject: [PATCH 015/101] add minibatch --- .../diffusion_model/diffusion_model.py | 70 +++++++++++++------ 1 file changed, 50 insertions(+), 20 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 69e6db619..1a302c5e4 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -550,6 +550,7 @@ def compositional_velocity( time: float | Tensor, stochastic_solver: bool, conditions: Tensor, + mini_batch_size: int | None, 
training: bool = False, ) -> Tensor: """ @@ -566,6 +567,8 @@ def compositional_velocity( Whether to use stochastic (SDE) or deterministic (ODE) formulation conditions : Tensor Conditional inputs with compositional structure (n_datasets, n_compositional, ...) + mini_batch_size : int or None + Size of mini-batches for processing compositional conditions to save memory. training : bool, optional Whether in training mode @@ -579,35 +582,35 @@ def compositional_velocity( # Get shapes for compositional structure n_compositional = ops.shape(conditions)[1] + n = ops.cast(n_compositional, dtype=ops.dtype(time)) + time_tensor = ops.cast(time, dtype=ops.dtype(xz)) # Calculate standard noise schedule components log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + if mini_batch_size is not None and mini_batch_size < n_compositional: + # sample random indices for mini-batch processing + idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) + conditions_batch = conditions[:, idx[:mini_batch_size]] + else: + conditions_batch = conditions + # Compute individual dataset scores - individual_scores = self._compute_individual_scores(xz, log_snr_t, alpha_t, sigma_t, conditions, training) + individual_scores = self._compute_individual_scores(xz, log_snr_t, alpha_t, sigma_t, conditions_batch, training) # Compute prior score component prior_score = self.compute_prior_score(xz) - # Combine scores using compositional formula - # s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) - n = ops.cast(n_compositional, dtype=ops.dtype(time)) - time_tensor = ops.cast(time, dtype=ops.dtype(xz)) - - # Sum individual scores across compositional dimension - summed_individual_scores = ops.sum(individual_scores, axis=1) + # Combine scores using compositional formula, mean over individual scores and scale with n to get sum + summed_individual_scores = n_compositional * ops.mean(individual_scores, axis=1) - # Prior contribution: (1-n)(1-t) * prior_score - prior_weight = (1.0 - n) * (1.0 - time_tensor) - weighted_prior = prior_weight * prior_score + # Prior contribution + weighted_prior_score = (1.0 - n) * (1.0 - time_tensor) * prior_score # Combined score - compositional_score = weighted_prior + summed_individual_scores - - # Broadcast back to full compositional shape - # compositional_score = ops.broadcast_to(compositional_score, ops.shape(xz)) + compositional_score = weighted_prior_score + summed_individual_scores # Compute velocity using standard drift-diffusion formulation f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) @@ -700,6 +703,7 @@ def _forward_compositional( integrate_kwargs = {"start_time": 0.0, "stop_time": 1.0} integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs + mini_batch_size = integrate_kwargs.get("mini_batch_size", None) if integrate_kwargs["method"] == "euler_maruyama": raise ValueError("Stochastic methods are not supported for forward integration.") @@ -711,7 +715,12 @@ def _forward_compositional( def deltas(time, xz): v = self.compositional_velocity( - xz, time=time, stochastic_solver=False, conditions=conditions, training=training + xz, + time=time, + stochastic_solver=False, + conditions=conditions, + mini_batch_size=mini_batch_size, + training=training, ) # For density, we need trace but 
compositional trace is complex # Simplified version - could be extended @@ -732,7 +741,12 @@ def deltas(time, xz): def deltas(time, xz): return { "xz": self.compositional_velocity( - xz, time=time, stochastic_solver=False, conditions=conditions, training=training + xz, + time=time, + stochastic_solver=False, + conditions=conditions, + mini_batch_size=mini_batch_size, + training=training, ) } @@ -755,6 +769,7 @@ def _inverse_compositional( integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs + mini_batch_size = integrate_kwargs.get("mini_batch_size", None) if density: if integrate_kwargs["method"] == "euler_maruyama": @@ -762,7 +777,12 @@ def _inverse_compositional( def deltas(time, xz): v = self.compositional_velocity( - xz, time=time, stochastic_solver=False, conditions=conditions, training=training + xz, + time=time, + stochastic_solver=False, + conditions=conditions, + mini_batch_size=mini_batch_size, + training=training, ) trace = ops.zeros(ops.shape(xz)[:-1] + (1,), dtype=ops.dtype(xz)) return {"xz": v, "trace": trace} @@ -784,7 +804,12 @@ def deltas(time, xz): def deltas(time, xz): return { "xz": self.compositional_velocity( - xz, time=time, stochastic_solver=True, conditions=conditions, training=training + xz, + time=time, + stochastic_solver=True, + conditions=conditions, + mini_batch_size=mini_batch_size, + training=training, ) } @@ -803,7 +828,12 @@ def diffusion(time, xz): def deltas(time, xz): return { "xz": self.compositional_velocity( - xz, time=time, stochastic_solver=False, conditions=conditions, training=training + xz, + time=time, + stochastic_solver=False, + conditions=conditions, + mini_batch_size=mini_batch_size, + training=training, ) } From 830e9295b9bfa03c6ed714564baba2fb3e7c7860 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 15:11:35 +0200 Subject: [PATCH 016/101] add compositional_bridge --- .../diffusion_model/diffusion_model.py | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 1a302c5e4..3ebb0dd43 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -544,6 +544,24 @@ def compute_metrics( base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage) return base_metrics | {"loss": loss} + @staticmethod + def compositional_bridge(time: Tensor) -> Tensor: + """ + Bridge function for compositional diffusion. In the simplest case, this is just 1. + + Parameters + ---------- + time: Tensor + Time step for the diffusion process. + + Returns + ------- + Tensor + Bridge function value with same shape as time. 
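+
+        Examples
+        --------
+        Subclasses can override this hook (together with ``compute_prior_score``)
+        to re-weight the combined score over time; a hypothetical sketch, not
+        part of this implementation:
+
+        >>> class MyCompositionalDiffusion(DiffusionModel):
+        ...     @staticmethod
+        ...     def compute_prior_score(xz: Tensor) -> Tensor:
+        ...         return -xz  # score of a standard normal prior on theta
+        ...     @staticmethod
+        ...     def compositional_bridge(time: Tensor) -> Tensor:
+        ...         return 1.0 + 0.5 * (1.0 - time)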
+ + """ + return ops.ones_like(time) + def compositional_velocity( self, xz: Tensor, @@ -610,7 +628,7 @@ def compositional_velocity( weighted_prior_score = (1.0 - n) * (1.0 - time_tensor) * prior_score # Combined score - compositional_score = weighted_prior_score + summed_individual_scores + compositional_score = self.compositional_bridge(time_tensor) * (weighted_prior_score + summed_individual_scores) # Compute velocity using standard drift-diffusion formulation f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) @@ -703,13 +721,14 @@ def _forward_compositional( integrate_kwargs = {"start_time": 0.0, "stop_time": 1.0} integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs - mini_batch_size = integrate_kwargs.get("mini_batch_size", None) + mini_batch_size = integrate_kwargs.pop("mini_batch_size", None) if integrate_kwargs["method"] == "euler_maruyama": raise ValueError("Stochastic methods are not supported for forward integration.") # x is sampled from a normal distribution, must be scaled with var 1/n_compositional - x = x / ops.sqrt(ops.cast(ops.shape(x)[1], dtype=ops.dtype(x))) + scale_latent = ops.shape(conditions)[1] * self.compositional_bridge(ops.ones(1)) + x = x / ops.sqrt(ops.cast(scale_latent, dtype=ops.dtype(x))) if density: @@ -769,7 +788,7 @@ def _inverse_compositional( integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs - mini_batch_size = integrate_kwargs.get("mini_batch_size", None) + mini_batch_size = integrate_kwargs.pop("mini_batch_size", None) if density: if integrate_kwargs["method"] == "euler_maruyama": @@ -844,5 +863,4 @@ def deltas(time, xz): @staticmethod def compute_prior_score(xz: Tensor) -> Tensor: - return ops.ones_like(xz) # todo: Placeholder implementation - # raise NotImplementedError('Please implement the prior score computation method.') + raise NotImplementedError("Please implement the prior score computation method.") From f97594b522a598ebb2a353ba8a11aefa192006da Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 15:31:26 +0200 Subject: [PATCH 017/101] fix mini batch randomness --- .../diffusion_model/diffusion_model.py | 99 ++++--------------- 1 file changed, 20 insertions(+), 79 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 3ebb0dd43..9dd3dc899 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -568,7 +568,7 @@ def compositional_velocity( time: float | Tensor, stochastic_solver: bool, conditions: Tensor, - mini_batch_size: int | None, + mini_batch_idx: Sequence | None, training: bool = False, ) -> Tensor: """ @@ -585,8 +585,8 @@ def compositional_velocity( Whether to use stochastic (SDE) or deterministic (ODE) formulation conditions : Tensor Conditional inputs with compositional structure (n_datasets, n_compositional, ...) - mini_batch_size : int or None - Size of mini-batches for processing compositional conditions to save memory. + mini_batch_idx : Sequence + Indices for mini-batch selection along the compositional axis. 
training : bool, optional Whether in training mode @@ -608,14 +608,11 @@ def compositional_velocity( log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) - if mini_batch_size is not None and mini_batch_size < n_compositional: - # sample random indices for mini-batch processing - idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) - conditions_batch = conditions[:, idx[:mini_batch_size]] + # Compute individual dataset scores + if mini_batch_idx is not None: + conditions_batch = conditions[:, mini_batch_idx] else: conditions_batch = conditions - - # Compute individual dataset scores individual_scores = self._compute_individual_scores(xz, log_snr_t, alpha_t, sigma_t, conditions_batch, training) # Compute prior score component @@ -707,73 +704,6 @@ def _compute_individual_scores( score = ops.reshape(score, (n_datasets, n_compositional, num_samples) + dims) return score - def _forward_compositional( - self, - x: Tensor, - conditions: Tensor = None, - density: bool = False, - training: bool = False, - **kwargs, - ) -> Tensor | tuple[Tensor, Tensor]: - """ - Forward pass for compositional diffusion. - """ - integrate_kwargs = {"start_time": 0.0, "stop_time": 1.0} - integrate_kwargs = integrate_kwargs | self.integrate_kwargs - integrate_kwargs = integrate_kwargs | kwargs - mini_batch_size = integrate_kwargs.pop("mini_batch_size", None) - - if integrate_kwargs["method"] == "euler_maruyama": - raise ValueError("Stochastic methods are not supported for forward integration.") - - # x is sampled from a normal distribution, must be scaled with var 1/n_compositional - scale_latent = ops.shape(conditions)[1] * self.compositional_bridge(ops.ones(1)) - x = x / ops.sqrt(ops.cast(scale_latent, dtype=ops.dtype(x))) - - if density: - - def deltas(time, xz): - v = self.compositional_velocity( - xz, - time=time, - stochastic_solver=False, - conditions=conditions, - mini_batch_size=mini_batch_size, - training=training, - ) - # For density, we need trace but compositional trace is complex - # Simplified version - could be extended - trace = ops.zeros(ops.shape(xz)[:-1] + (1,), dtype=ops.dtype(xz)) - return {"xz": v, "trace": trace} - - state = { - "xz": x, - "trace": ops.zeros(ops.shape(x)[:-1] + (1,), dtype=ops.dtype(x)), - } - state = integrate(deltas, state, **integrate_kwargs) - - z = state["xz"] - # Simplified density computation - log_density = self.base_distribution.log_prob(ops.mean(z, axis=1)) + ops.squeeze(state["trace"], axis=-1) - return z, log_density - - def deltas(time, xz): - return { - "xz": self.compositional_velocity( - xz, - time=time, - stochastic_solver=False, - conditions=conditions, - mini_batch_size=mini_batch_size, - training=training, - ) - } - - state = {"xz": x} - state = integrate(deltas, state, **integrate_kwargs) - z = state["xz"] - return z - def _inverse_compositional( self, z: Tensor, @@ -790,6 +720,17 @@ def _inverse_compositional( integrate_kwargs = integrate_kwargs | kwargs mini_batch_size = integrate_kwargs.pop("mini_batch_size", None) + # x is sampled from a normal distribution, must be scaled with var 1/n_compositional + n_compositional = ops.shape(conditions)[1] + scale_latent = n_compositional * self.compositional_bridge(ops.ones(1)) + z = z / ops.sqrt(ops.cast(scale_latent, dtype=ops.dtype(z))) + + if mini_batch_size is not None and mini_batch_size < n_compositional: + # sample random indices for mini-batch processing + mini_batch_idx = 
keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) + else: + mini_batch_idx = None + if density: if integrate_kwargs["method"] == "euler_maruyama": raise ValueError("Stochastic methods are not supported for density computation.") @@ -800,7 +741,7 @@ def deltas(time, xz): time=time, stochastic_solver=False, conditions=conditions, - mini_batch_size=mini_batch_size, + mini_batch_idx=mini_batch_idx, training=training, ) trace = ops.zeros(ops.shape(xz)[:-1] + (1,), dtype=ops.dtype(xz)) @@ -827,7 +768,7 @@ def deltas(time, xz): time=time, stochastic_solver=True, conditions=conditions, - mini_batch_size=mini_batch_size, + mini_batch_idx=mini_batch_idx, training=training, ) } @@ -851,7 +792,7 @@ def deltas(time, xz): time=time, stochastic_solver=False, conditions=conditions, - mini_batch_size=mini_batch_size, + mini_batch_idx=mini_batch_idx, training=training, ) } From 7219a71aac81c72bf686acd02e3a5d0ade316fe6 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 15:37:34 +0200 Subject: [PATCH 018/101] fix mini batch randomness --- .../diffusion_model/diffusion_model.py | 27 ++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 9dd3dc899..8cf58c4f8 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -725,17 +725,18 @@ def _inverse_compositional( scale_latent = n_compositional * self.compositional_bridge(ops.ones(1)) z = z / ops.sqrt(ops.cast(scale_latent, dtype=ops.dtype(z))) - if mini_batch_size is not None and mini_batch_size < n_compositional: - # sample random indices for mini-batch processing - mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) - else: - mini_batch_idx = None - if density: if integrate_kwargs["method"] == "euler_maruyama": raise ValueError("Stochastic methods are not supported for density computation.") def deltas(time, xz): + if mini_batch_size is not None and mini_batch_size < n_compositional: + # sample random indices for mini-batch processing + mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) + mini_batch_idx = mini_batch_idx[:mini_batch_size] + else: + mini_batch_idx = None + v = self.compositional_velocity( xz, time=time, @@ -762,6 +763,13 @@ def deltas(time, xz): if integrate_kwargs["method"] == "euler_maruyama": def deltas(time, xz): + if mini_batch_size is not None and mini_batch_size < n_compositional: + # sample random indices for mini-batch processing + mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) + mini_batch_idx = mini_batch_idx[:mini_batch_size] + else: + mini_batch_idx = None + return { "xz": self.compositional_velocity( xz, @@ -786,6 +794,13 @@ def diffusion(time, xz): else: def deltas(time, xz): + if mini_batch_size is not None and mini_batch_size < n_compositional: + # sample random indices for mini-batch processing + mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) + mini_batch_idx = mini_batch_idx[:mini_batch_size] + else: + mini_batch_idx = None + return { "xz": self.compositional_velocity( xz, From a10026a00e7fa04eae7bccffb062b77d9c8e1ce9 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 15:45:41 +0200 Subject: [PATCH 019/101] fix mini batch randomness --- .../diffusion_model/diffusion_model.py | 49 ++++++++----------- 1 
file changed, 20 insertions(+), 29 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 8cf58c4f8..37fc9033b 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -568,7 +568,7 @@ def compositional_velocity( time: float | Tensor, stochastic_solver: bool, conditions: Tensor, - mini_batch_idx: Sequence | None, + mini_batch_size: int | None = None, training: bool = False, ) -> Tensor: """ @@ -585,8 +585,8 @@ def compositional_velocity( Whether to use stochastic (SDE) or deterministic (ODE) formulation conditions : Tensor Conditional inputs with compositional structure (n_datasets, n_compositional, ...) - mini_batch_idx : Sequence - Indices for mini-batch selection along the compositional axis. + mini_batch_size : int or None + Mini batch size for computing individual scores. If None, use all conditions. training : bool, optional Whether in training mode @@ -609,7 +609,10 @@ def compositional_velocity( alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) # Compute individual dataset scores - if mini_batch_idx is not None: + if mini_batch_size is not None and mini_batch_size < n_compositional: + # sample random indices for mini-batch processing + mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) + mini_batch_idx = mini_batch_idx[:mini_batch_size] conditions_batch = conditions[:, mini_batch_idx] else: conditions_batch = conditions @@ -720,6 +723,14 @@ def _inverse_compositional( integrate_kwargs = integrate_kwargs | kwargs mini_batch_size = integrate_kwargs.pop("mini_batch_size", None) + if mini_batch_size is not None: + # if backend is jax, mini batching does not work + if ops.__name__ == "jax": + raise ValueError( + "Mini batching is not supported with JAX backend. Set mini_batch_size to None " + "or use another backend." 
+ ) + # x is sampled from a normal distribution, must be scaled with var 1/n_compositional n_compositional = ops.shape(conditions)[1] scale_latent = n_compositional * self.compositional_bridge(ops.ones(1)) @@ -730,19 +741,12 @@ def _inverse_compositional( raise ValueError("Stochastic methods are not supported for density computation.") def deltas(time, xz): - if mini_batch_size is not None and mini_batch_size < n_compositional: - # sample random indices for mini-batch processing - mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) - mini_batch_idx = mini_batch_idx[:mini_batch_size] - else: - mini_batch_idx = None - v = self.compositional_velocity( xz, time=time, stochastic_solver=False, conditions=conditions, - mini_batch_idx=mini_batch_idx, + mini_batch_size=mini_batch_size, training=training, ) trace = ops.zeros(ops.shape(xz)[:-1] + (1,), dtype=ops.dtype(xz)) @@ -763,20 +767,13 @@ def deltas(time, xz): if integrate_kwargs["method"] == "euler_maruyama": def deltas(time, xz): - if mini_batch_size is not None and mini_batch_size < n_compositional: - # sample random indices for mini-batch processing - mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) - mini_batch_idx = mini_batch_idx[:mini_batch_size] - else: - mini_batch_idx = None - return { "xz": self.compositional_velocity( xz, time=time, stochastic_solver=True, conditions=conditions, - mini_batch_idx=mini_batch_idx, + mini_batch_size=mini_batch_size, training=training, ) } @@ -794,20 +791,13 @@ def diffusion(time, xz): else: def deltas(time, xz): - if mini_batch_size is not None and mini_batch_size < n_compositional: - # sample random indices for mini-batch processing - mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) - mini_batch_idx = mini_batch_idx[:mini_batch_size] - else: - mini_batch_idx = None - return { "xz": self.compositional_velocity( xz, time=time, stochastic_solver=False, conditions=conditions, - mini_batch_idx=mini_batch_idx, + mini_batch_size=mini_batch_size, training=training, ) } @@ -819,4 +809,5 @@ def deltas(time, xz): @staticmethod def compute_prior_score(xz: Tensor) -> Tensor: - raise NotImplementedError("Please implement the prior score computation method.") + return ops.ones_like(xz) + # raise NotImplementedError("Please implement the prior score computation method.") From 457eb5d6a7b7d3ec82ab7811cb4e7dbcd7f976ff Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 16:12:32 +0200 Subject: [PATCH 020/101] add prior score --- .../approximators/continuous_approximator.py | 32 ++++++++++++++++++- .../diffusion_model/diffusion_model.py | 22 +++++++------ bayesflow/networks/inference_network.py | 17 ++++++++-- bayesflow/workflows/basic_workflow.py | 7 +++- 4 files changed, 64 insertions(+), 14 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index c8bd77a57..0a046d57e 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -644,6 +644,7 @@ def compositional_sample( *, num_samples: int, conditions: Mapping[str, np.ndarray], + compute_prior_score: Callable[[Mapping[str, np.ndarray]], np.ndarray], split: bool = False, **kwargs, ) -> dict[str, np.ndarray]: @@ -659,6 +660,8 @@ def compositional_sample( conditions : dict[str, np.ndarray] Dictionary of conditioning variables as NumPy arrays with shape (n_datasets, n_compositional_conditions, ...). 
+ compute_prior_score : Callable[[Mapping[str, np.ndarray]], np.ndarray] + A function that computes the log probability of samples under the prior distribution. split : bool, default=False Whether to split the output arrays along the last axis and return one column vector per target variable samples. @@ -685,9 +688,34 @@ def compositional_sample( # Remove any superfluous keys, just retain actual conditions prepared_conditions = {k: v for k, v in prepared_conditions.items() if k in self.CONDITION_KEYS} + # Prepare prior scores to handle adapter + def compute_prior_score_pre(_samples: Tensor) -> Tensor: + if "inference_variables" in self.standardize: + _samples, log_det_jac_standardize = self.standardize_layers["inference_variables"]( + _samples, forward=False, log_det_jac=True + ) + else: + log_det_jac_standardize = 0 + _samples = {"inference_variables": _samples} + _samples = keras.tree.map_structure(keras.ops.convert_to_numpy, _samples) + adapted_samples, log_det_jac = self.adapter( + _samples, inverse=True, strict=False, log_det_jac=True, **kwargs + ) + prior_score = keras.ops.convert_to_tensor(compute_prior_score(adapted_samples)) + if log_det_jac is not None: + prior_score += keras.ops.convert_to_tensor(log_det_jac) + if log_det_jac_standardize is not None: + prior_score += keras.ops.convert_to_tensor(log_det_jac_standardize) + return prior_score + # Sample using compositional sampling samples = self._compositional_sample( - num_samples=num_samples, n_datasets=n_datasets, n_compositional=n_comp, **prepared_conditions, **kwargs + num_samples=num_samples, + n_datasets=n_datasets, + n_compositional=n_comp, + compute_prior_score=compute_prior_score_pre, + **prepared_conditions, + **kwargs, ) if "inference_variables" in self.standardize: @@ -708,6 +736,7 @@ def _compositional_sample( num_samples: int, n_datasets: int, n_compositional: int, + compute_prior_score: Callable[[Tensor], Tensor], inference_conditions: Tensor = None, summary_variables: Tensor = None, **kwargs, @@ -750,5 +779,6 @@ def _compositional_sample( batch_shape, conditions=inference_conditions, compositional=True, + compute_prior_score=compute_prior_score, **filter_kwargs(kwargs, self.inference_network.sample), ) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 37fc9033b..b65c61d01 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Literal +from typing import Literal, Callable import keras from keras import ops @@ -568,6 +568,7 @@ def compositional_velocity( time: float | Tensor, stochastic_solver: bool, conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], mini_batch_size: int | None = None, training: bool = False, ) -> Tensor: @@ -585,6 +586,8 @@ def compositional_velocity( Whether to use stochastic (SDE) or deterministic (ODE) formulation conditions : Tensor Conditional inputs with compositional structure (n_datasets, n_compositional, ...) + compute_prior_score: Callable + Function to compute the prior score ∇_θ log p(θ). mini_batch_size : int or None Mini batch size for computing individual scores. If None, use all conditions. 
training : bool, optional @@ -619,7 +622,7 @@ def compositional_velocity( individual_scores = self._compute_individual_scores(xz, log_snr_t, alpha_t, sigma_t, conditions_batch, training) # Compute prior score component - prior_score = self.compute_prior_score(xz) + prior_score = compute_prior_score(xz) # Combine scores using compositional formula, mean over individual scores and scale with n to get sum summed_individual_scores = n_compositional * ops.mean(individual_scores, axis=1) @@ -710,13 +713,14 @@ def _compute_individual_scores( def _inverse_compositional( self, z: Tensor, - conditions: Tensor = None, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], density: bool = False, training: bool = False, **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: """ - Inverse pass for compositional diffusion (sampling). + Inverse pass for compositional diffusion sampling. """ integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} integrate_kwargs = integrate_kwargs | self.integrate_kwargs @@ -725,7 +729,7 @@ def _inverse_compositional( if mini_batch_size is not None: # if backend is jax, mini batching does not work - if ops.__name__ == "jax": + if keras.backend.backend() == "jax": raise ValueError( "Mini batching is not supported with JAX backend. Set mini_batch_size to None " "or use another backend." @@ -746,6 +750,7 @@ def deltas(time, xz): time=time, stochastic_solver=False, conditions=conditions, + compute_prior_score=compute_prior_score, mini_batch_size=mini_batch_size, training=training, ) @@ -773,6 +778,7 @@ def deltas(time, xz): time=time, stochastic_solver=True, conditions=conditions, + compute_prior_score=compute_prior_score, mini_batch_size=mini_batch_size, training=training, ) @@ -797,6 +803,7 @@ def deltas(time, xz): time=time, stochastic_solver=False, conditions=conditions, + compute_prior_score=compute_prior_score, mini_batch_size=mini_batch_size, training=training, ) @@ -806,8 +813,3 @@ def deltas(time, xz): x = state["xz"] return x - - @staticmethod - def compute_prior_score(xz: Tensor) -> Tensor: - return ops.ones_like(xz) - # raise NotImplementedError("Please implement the prior score computation method.") diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py index f2e5c512f..250c93b22 100644 --- a/bayesflow/networks/inference_network.py +++ b/bayesflow/networks/inference_network.py @@ -1,3 +1,4 @@ +from typing import Callable import keras from bayesflow.types import Shape, Tensor @@ -52,12 +53,24 @@ def _inverse( raise NotImplementedError def _forward_compositional( - self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs + self, + x: Tensor, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + density: bool = False, + training: bool = False, + **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: raise NotImplementedError def _inverse_compositional( - self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs + self, + z: Tensor, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + density: bool = False, + training: bool = False, + **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: raise NotImplementedError diff --git a/bayesflow/workflows/basic_workflow.py b/bayesflow/workflows/basic_workflow.py index f1941ac3a..2ef326dae 100644 --- a/bayesflow/workflows/basic_workflow.py +++ b/bayesflow/workflows/basic_workflow.py @@ -291,6 +291,7 @@ def compositional_sample( *, num_samples: int, conditions: 
Mapping[str, np.ndarray], + prior_score: Callable[[Mapping[str, np.ndarray]], np.ndarray], **kwargs, ) -> dict[str, np.ndarray]: """ @@ -306,6 +307,8 @@ def compositional_sample( NumPy arrays containing the adapted simulated variables. Keys used as summary or inference conditions during training should be present. Should have shape (n_datasets, n_compositional_conditions, ...). + prior_score : Callable[[Mapping[str, np.ndarray]], np.ndarray] + A function that computes the log probability of samples under the prior distribution. **kwargs : dict, optional Additional keyword arguments passed to the approximator's sampling function. @@ -315,7 +318,9 @@ def compositional_sample( A dictionary where keys correspond to variable names and values are arrays containing the generated samples. """ - return self.approximator.compositional_sample(num_samples=num_samples, conditions=conditions, **kwargs) + return self.approximator.compositional_sample( + num_samples=num_samples, conditions=conditions, prior_score=prior_score, **kwargs + ) def estimate( self, From 7de473649f066fce40dd93c74512c6894ac1be74 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 16:14:15 +0200 Subject: [PATCH 021/101] add prior score --- bayesflow/workflows/basic_workflow.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bayesflow/workflows/basic_workflow.py b/bayesflow/workflows/basic_workflow.py index 2ef326dae..cfa63545b 100644 --- a/bayesflow/workflows/basic_workflow.py +++ b/bayesflow/workflows/basic_workflow.py @@ -291,7 +291,7 @@ def compositional_sample( *, num_samples: int, conditions: Mapping[str, np.ndarray], - prior_score: Callable[[Mapping[str, np.ndarray]], np.ndarray], + compute_prior_score: Callable[[Mapping[str, np.ndarray]], np.ndarray], **kwargs, ) -> dict[str, np.ndarray]: """ @@ -307,7 +307,7 @@ def compositional_sample( NumPy arrays containing the adapted simulated variables. Keys used as summary or inference conditions during training should be present. Should have shape (n_datasets, n_compositional_conditions, ...). - prior_score : Callable[[Mapping[str, np.ndarray]], np.ndarray] + compute_prior_score : Callable[[Mapping[str, np.ndarray]], np.ndarray] A function that computes the log probability of samples under the prior distribution. **kwargs : dict, optional Additional keyword arguments passed to the approximator's sampling function. @@ -319,7 +319,7 @@ def compositional_sample( values are arrays containing the generated samples. 
""" return self.approximator.compositional_sample( - num_samples=num_samples, conditions=conditions, prior_score=prior_score, **kwargs + num_samples=num_samples, conditions=conditions, compute_prior_score=compute_prior_score, **kwargs ) def estimate( From 1ee0e785086b4e28d92e037b3c84c62ddd44c1f4 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 16:40:20 +0200 Subject: [PATCH 022/101] add prior score draft --- .../approximators/continuous_approximator.py | 28 ++++++++++++++----- bayesflow/networks/inference_network.py | 24 ++++++++++++---- 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index 0a046d57e..cbc7d30c3 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -701,12 +701,27 @@ def compute_prior_score_pre(_samples: Tensor) -> Tensor: adapted_samples, log_det_jac = self.adapter( _samples, inverse=True, strict=False, log_det_jac=True, **kwargs ) - prior_score = keras.ops.convert_to_tensor(compute_prior_score(adapted_samples)) - if log_det_jac is not None: - prior_score += keras.ops.convert_to_tensor(log_det_jac) - if log_det_jac_standardize is not None: - prior_score += keras.ops.convert_to_tensor(log_det_jac_standardize) - return prior_score + prior_score = compute_prior_score(adapted_samples) + prior_score_final = {} + for i, key in enumerate(adapted_samples): # todo: assumes same order, might incorrect + prior_score_final[key] = prior_score[key] + if len(log_det_jac_standardize) > 0: + prior_score_final[key] += log_det_jac_standardize[:, i] + if len(log_det_jac) > 0: + prior_score_final[key] += log_det_jac[:, i] + prior_score_final[key] = keras.ops.convert_to_tensor(prior_score_final[key]) + # make a tensor + out = keras.ops.concatenate(list(prior_score_final.values()), axis=-1) + return out + + # Test prior score function, useful for debugging + test = self.inference_network.base_distribution.sample((n_datasets, num_samples)) + test = compute_prior_score_pre(test) + if test.shape[:2] != (n_datasets, num_samples): + raise ValueError( + "The provided compute_prior_score function does not return the correct shape. " + f"Expected ({n_datasets}, {num_samples}, ...), got {test.shape}." 
+ ) # Sample using compositional sampling samples = self._compositional_sample( @@ -778,7 +793,6 @@ def _compositional_sample( return self.inference_network.sample( batch_shape, conditions=inference_conditions, - compositional=True, compute_prior_score=compute_prior_score, **filter_kwargs(kwargs, self.inference_network.sample), ) diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py index 250c93b22..4fd22e468 100644 --- a/bayesflow/networks/inference_network.py +++ b/bayesflow/networks/inference_network.py @@ -28,18 +28,32 @@ def call( conditions: Tensor = None, inverse: bool = False, density: bool = False, - compositional: bool = False, + compute_prior_score: Callable[[Tensor], Tensor] = None, training: bool = False, **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: if inverse: - if compositional: + if compute_prior_score is not None: return self._inverse_compositional( xz, conditions=conditions, density=density, training=training, **kwargs ) - return self._inverse(xz, conditions=conditions, density=density, training=training, **kwargs) - if compositional: - return self._forward_compositional(xz, conditions=conditions, density=density, training=training, **kwargs) + return self._inverse( + xz, + conditions=conditions, + compute_prior_score=compute_prior_score, + density=density, + training=training, + **kwargs, + ) + if compute_prior_score is not None: + return self._forward_compositional( + xz, + conditions=conditions, + compute_prior_score=compute_prior_score, + density=density, + training=training, + **kwargs, + ) return self._forward(xz, conditions=conditions, density=density, training=training, **kwargs) def _forward( From f71359bbdb0cd95b042876df38766bcd150ff4dd Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 16:41:46 +0200 Subject: [PATCH 023/101] add prior score draft --- bayesflow/networks/inference_network.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py index 4fd22e468..9488f644d 100644 --- a/bayesflow/networks/inference_network.py +++ b/bayesflow/networks/inference_network.py @@ -35,16 +35,14 @@ def call( if inverse: if compute_prior_score is not None: return self._inverse_compositional( - xz, conditions=conditions, density=density, training=training, **kwargs + xz, + conditions=conditions, + compute_prior_score=compute_prior_score, + density=density, + training=training, + **kwargs, ) - return self._inverse( - xz, - conditions=conditions, - compute_prior_score=compute_prior_score, - density=density, - training=training, - **kwargs, - ) + return self._inverse(xz, conditions=conditions, density=density, training=training, **kwargs) if compute_prior_score is not None: return self._forward_compositional( xz, From 6210c07ade674e5cbf0ef5a2def2b21b1125d7e7 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 17:01:14 +0200 Subject: [PATCH 024/101] add prior score draft --- .../approximators/continuous_approximator.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index cbc7d30c3..b4e543495 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -14,6 +14,7 @@ squeeze_inner_estimates_dict, concatenate_valid, concatenate_valid_shapes, + expand_right_as, ) from bayesflow.utils.serialization import 
serialize, deserialize, serializable @@ -690,29 +691,28 @@ def compositional_sample( # Prepare prior scores to handle adapter def compute_prior_score_pre(_samples: Tensor) -> Tensor: + return keras.ops.zeros_like(_samples) if "inference_variables" in self.standardize: _samples, log_det_jac_standardize = self.standardize_layers["inference_variables"]( _samples, forward=False, log_det_jac=True ) else: log_det_jac_standardize = 0 - _samples = {"inference_variables": _samples} - _samples = keras.tree.map_structure(keras.ops.convert_to_numpy, _samples) + _samples = keras.tree.map_structure(keras.ops.convert_to_numpy, {"inference_variables": _samples}) adapted_samples, log_det_jac = self.adapter( _samples, inverse=True, strict=False, log_det_jac=True, **kwargs ) prior_score = compute_prior_score(adapted_samples) - prior_score_final = {} - for i, key in enumerate(adapted_samples): # todo: assumes same order, might incorrect - prior_score_final[key] = prior_score[key] - if len(log_det_jac_standardize) > 0: - prior_score_final[key] += log_det_jac_standardize[:, i] + for key in adapted_samples: + prior_score[key] = prior_score[key] if len(log_det_jac) > 0: - prior_score_final[key] += log_det_jac[:, i] - prior_score_final[key] = keras.ops.convert_to_tensor(prior_score_final[key]) + prior_score[key] += log_det_jac[key] + prior_score[key] = keras.ops.convert_to_tensor(prior_score[key]) # make a tensor - out = keras.ops.concatenate(list(prior_score_final.values()), axis=-1) - return out + out = keras.ops.concatenate( + list(prior_score.values()), axis=-1 + ) # todo: assumes same order, might be incorrect + return out + expand_right_as(log_det_jac_standardize, out) # Test prior score function, useful for debugging test = self.inference_network.base_distribution.sample((n_datasets, num_samples)) From bcb9f60a63ddd2ade564121a6f15e9933824a132 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 17:19:10 +0200 Subject: [PATCH 025/101] add prior score draft --- bayesflow/approximators/continuous_approximator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index b4e543495..5c727eac0 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -691,7 +691,6 @@ def compositional_sample( # Prepare prior scores to handle adapter def compute_prior_score_pre(_samples: Tensor) -> Tensor: - return keras.ops.zeros_like(_samples) if "inference_variables" in self.standardize: _samples, log_det_jac_standardize = self.standardize_layers["inference_variables"]( _samples, forward=False, log_det_jac=True @@ -707,7 +706,8 @@ def compute_prior_score_pre(_samples: Tensor) -> Tensor: prior_score[key] = prior_score[key] if len(log_det_jac) > 0: prior_score[key] += log_det_jac[key] - prior_score[key] = keras.ops.convert_to_tensor(prior_score[key]) + + prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score) # make a tensor out = keras.ops.concatenate( list(prior_score.values()), axis=-1 From 455f03c2336549ea8c9cd8161af8b16f1af5150d Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 17:25:22 +0200 Subject: [PATCH 026/101] fix dtype --- bayesflow/networks/diffusion_model/diffusion_model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index b65c61d01..3099788bd 100644 --- 
a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -603,8 +603,6 @@ def compositional_velocity( # Get shapes for compositional structure n_compositional = ops.shape(conditions)[1] - n = ops.cast(n_compositional, dtype=ops.dtype(time)) - time_tensor = ops.cast(time, dtype=ops.dtype(xz)) # Calculate standard noise schedule components log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) @@ -628,9 +626,10 @@ def compositional_velocity( summed_individual_scores = n_compositional * ops.mean(individual_scores, axis=1) # Prior contribution - weighted_prior_score = (1.0 - n) * (1.0 - time_tensor) * prior_score + weighted_prior_score = (1.0 - n_compositional) * (1.0 - time) * prior_score # Combined score + time_tensor = ops.cast(time, dtype=ops.dtype(xz)) compositional_score = self.compositional_bridge(time_tensor) * (weighted_prior_score + summed_individual_scores) # Compute velocity using standard drift-diffusion formulation From 89523a98494d40046d75581d146e466a6e295b49 Mon Sep 17 00:00:00 2001 From: arrjon Date: Tue, 9 Sep 2025 15:16:10 +0200 Subject: [PATCH 027/101] fix docstring --- bayesflow/approximators/continuous_approximator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index 5c727eac0..2cd4225c6 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -662,7 +662,7 @@ def compositional_sample( Dictionary of conditioning variables as NumPy arrays with shape (n_datasets, n_compositional_conditions, ...). compute_prior_score : Callable[[Mapping[str, np.ndarray]], np.ndarray] - A function that computes the log probability of samples under the prior distribution. + A function that computes the score of the log prior distribution. split : bool, default=False Whether to split the output arrays along the last axis and return one column vector per target variable samples. 
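A minimal sketch of what a user-supplied compute_prior_score callable could look like for the docstring fixed above, assuming a standard normal prior and an inference variable named "theta"; both the prior and the key are illustrative and not part of these patches:

import numpy as np

def compute_prior_score(samples: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    # Score of a standard normal prior: grad_theta log N(theta | 0, I) = -theta.
    # Keys must match the adapted inference variables; "theta" is a placeholder here.
    return {"theta": -samples["theta"]}

# Hypothetical call, mirroring the workflow API added in these patches:
# samples = workflow.compositional_sample(
#     num_samples=500, conditions=conditions, compute_prior_score=compute_prior_score
# )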
From e55631dcec299cecc0e1ce27ebf0cd668ab960ba Mon Sep 17 00:00:00 2001 From: arrjon Date: Tue, 9 Sep 2025 15:30:35 +0200 Subject: [PATCH 028/101] fix batch_shape in sample --- bayesflow/approximators/continuous_approximator.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index f27a612f0..b60f4e4bd 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -535,10 +535,9 @@ def _sample( inference_conditions = keras.ops.broadcast_to( inference_conditions, (batch_size, num_samples, *keras.ops.shape(inference_conditions)[2:]) ) - batch_shape = ( - batch_size, - num_samples, - ) + + target_dim = self.inference_network.base_distribution.dims + batch_shape = keras.ops.shape(inference_conditions)[: -len(target_dim)] else: batch_shape = (num_samples,) From 3eaff24a1d314435040a688559e3c85d5235f9c6 Mon Sep 17 00:00:00 2001 From: arrjon Date: Tue, 9 Sep 2025 15:35:56 +0200 Subject: [PATCH 029/101] fix batch_shape for point approximator --- bayesflow/approximators/continuous_approximator.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index b60f4e4bd..13ba32cb9 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -536,8 +536,12 @@ def _sample( inference_conditions, (batch_size, num_samples, *keras.ops.shape(inference_conditions)[2:]) ) - target_dim = self.inference_network.base_distribution.dims - batch_shape = keras.ops.shape(inference_conditions)[: -len(target_dim)] + if hasattr(self.inference_network, "base_distribution"): + target_shape_len = len(self.inference_network.base_distribution.dims) + else: + # point approximator has no base_distribution + target_shape_len = 1 + batch_shape = keras.ops.shape(inference_conditions)[:-target_shape_len] else: batch_shape = (num_samples,) From e97e375f458eadd41ed19b0a574ad0fa0daf9b59 Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 10 Sep 2025 18:11:31 +0200 Subject: [PATCH 030/101] fix docstring --- bayesflow/simulators/sequential_simulator.py | 4 ++-- bayesflow/simulators/simulator.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/bayesflow/simulators/sequential_simulator.py b/bayesflow/simulators/sequential_simulator.py index 21e1542e6..96ab0ead3 100644 --- a/bayesflow/simulators/sequential_simulator.py +++ b/bayesflow/simulators/sequential_simulator.py @@ -88,7 +88,7 @@ def _single_sample(self, batch_shape_ext, **kwargs) -> dict[str, np.ndarray]: return self.sample(batch_shape=(1, *tuple(batch_shape_ext)), **kwargs) def sample_parallel( - self, batch_shape: Shape, n_jobs: int = -1, verbose: int = 0, **kwargs + self, batch_shape: Shape, n_jobs: int = -1, verbose: int = 1, **kwargs ) -> dict[str, np.ndarray]: """ Sample in parallel from the sequential simulator. @@ -101,7 +101,7 @@ def sample_parallel( n_jobs : int, optional Number of parallel jobs. -1 uses all available cores. Default is -1. verbose : int, optional - Verbosity level for joblib. Default is 0 (no output). + Verbosity level for joblib. Default is 1 (minimal output). **kwargs Additional keyword arguments passed to each simulator. These may include previously sampled outputs used as inputs for subsequent simulators. 
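A brief usage sketch for the parallel sampling path whose docstring is adjusted above; the simulator construction is only indicated, since it depends on the user's model:

# Assuming `simulator` is an already-constructed SequentialSimulator:
# draws = simulator.sample_parallel(batch_shape=(1000,), n_jobs=-1, verbose=1)
# `draws` maps variable names to np.ndarray batches; n_jobs=-1 uses all cores via joblib.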
diff --git a/bayesflow/simulators/simulator.py b/bayesflow/simulators/simulator.py index 00d3d84f3..53d54e455 100644 --- a/bayesflow/simulators/simulator.py +++ b/bayesflow/simulators/simulator.py @@ -95,3 +95,8 @@ def accept_all_predicate(x): return np.full((sample_size,), True) return self.rejection_sample(batch_shape, predicate=accept_all_predicate, sample_size=sample_size, **kwargs) + + def sample_parallel( + self, batch_shape: Shape, n_jobs: int = -1, verbose: int = 1, **kwargs + ) -> dict[str, np.ndarray]: + raise NotImplementedError From caa2d67ec4f934984d9a932515472ec9e432b4e8 Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 10 Sep 2025 18:54:15 +0200 Subject: [PATCH 031/101] fix float32 --- bayesflow/approximators/continuous_approximator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index 546de64ed..7a0d757d9 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -706,7 +706,8 @@ def compute_prior_score_pre(_samples: Tensor) -> Tensor: ) prior_score = compute_prior_score(adapted_samples) for key in adapted_samples: - prior_score[key] = prior_score[key] + if isinstance(prior_score[key], np.ndarray): + prior_score[key] = prior_score[key].astype("float32") if len(log_det_jac) > 0: prior_score[key] += log_det_jac[key] From 1ac9bff056cb539a45a76391e2733afff6fb99ca Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 12:34:57 +0200 Subject: [PATCH 032/101] reorganize --- .../diffusion_model/diffusion_model.py | 71 +++++++++++++++---- 1 file changed, 59 insertions(+), 12 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 3099788bd..332165a5e 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -593,6 +593,64 @@ def compositional_velocity( training : bool, optional Whether in training mode + Returns + ------- + Tensor + Compositional velocity of same shape as input xz + """ + # Calculate standard noise schedule components + log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) + + compositional_score = self.compositional_score( + xz=xz, + time=time, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + + # Compute velocity using standard drift-diffusion formulation + f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) + + if stochastic_solver: + # SDE: dz = [f(z,t) - g(t)² * score(z,t)] dt + g(t) dW + velocity = f - g_squared * compositional_score + else: + # ODE: dz = [f(z,t) - 0.5 * g(t)² * score(z,t)] dt + velocity = f - 0.5 * g_squared * compositional_score + + return velocity + + def compositional_score( + self, + xz: Tensor, + time: float | Tensor, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + mini_batch_size: int | None = None, + training: bool = False, + ) -> Tensor: + """ + Computes the compositional score for multiple datasets using the formula: + s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) + + Parameters + ---------- + xz : Tensor + The current state of the latent variable, shape (n_datasets, n_compositional, ...) 
+ time : float or Tensor + Time step for the diffusion process + conditions : Tensor + Conditional inputs with compositional structure (n_datasets, n_compositional, ...) + compute_prior_score: Callable + Function to compute the prior score ∇_θ log p(θ). + mini_batch_size : int or None + Mini batch size for computing individual scores. If None, use all conditions. + training : bool, optional + Whether in training mode + Returns ------- Tensor @@ -631,18 +689,7 @@ def compositional_velocity( # Combined score time_tensor = ops.cast(time, dtype=ops.dtype(xz)) compositional_score = self.compositional_bridge(time_tensor) * (weighted_prior_score + summed_individual_scores) - - # Compute velocity using standard drift-diffusion formulation - f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) - - if stochastic_solver: - # SDE: dz = [f(z,t) - g(t)² * score(z,t)] dt + g(t) dW - velocity = f - g_squared * compositional_score - else: - # ODE: dz = [f(z,t) - 0.5 * g(t)² * score(z,t)] dt - velocity = f - 0.5 * g_squared * compositional_score - - return velocity + return compositional_score def _compute_individual_scores( self, From df23f892a5d9b7d23eec0c9ff2ffe61a266e8698 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 14:02:09 +0200 Subject: [PATCH 033/101] add annealed_langevin --- .../diffusion_model/diffusion_model.py | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 332165a5e..c3a625e06 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -16,12 +16,15 @@ integrate_stochastic, logging, tensor_utils, + filter_kwargs, ) from bayesflow.utils.serialization import serialize, deserialize, serializable from .schedules.noise_schedule import NoiseSchedule from .dispatch import find_noise_schedule +ArrayLike = int | float | Tensor + # disable module check, use potential module after moving from experimental @serializable("bayesflow.networks", disable_module_check=True) @@ -840,6 +843,26 @@ def diffusion(time, xz): seed=self.seed_generator, **integrate_kwargs, ) + elif integrate_kwargs["method"] == "langevin": + + def scores(time, xz): + return { + "xz": self.compositional_score( + xz, + time=time, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + } + + state = annealed_langevin( + score_fn=scores, + state=state, + seed=self.seed_generator, + **filter_kwargs(integrate_kwargs, annealed_langevin), + ) else: def deltas(time, xz): @@ -859,3 +882,50 @@ def deltas(time, xz): x = state["xz"] return x + + +def annealed_langevin( + score_fn: Callable, + state: dict[str, ArrayLike], + steps: int, + seed: keras.random.SeedGenerator, + L: int = 5, + start_time: ArrayLike = None, + stop_time: ArrayLike = None, + eps: float = 0.01, +) -> dict[str, ArrayLike]: + """ + Annealed Langevin dynamics for diffusion sampling. 
+ + for t = T-1,...,1: + for s = 1,...,L: + eta ~ N(0, I) + theta <- theta + (dt[t]/2) * psi(theta, t) + sqrt(dt[t]) * eta + """ + ratio = keras.ops.convert_to_tensor( + (stop_time + eps) / start_time, dtype=keras.ops.dtype(next(iter(state.values()))) + ) + + T = steps + # main loops + for t_T in range(T - 1, 0, -1): + t = t_T / T + dt = keras.ops.convert_to_tensor(stop_time, dtype=keras.ops.dtype(next(iter(state.values())))) * ( + ratio ** (stop_time - t) + ) + + sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt)) + # inner L Langevin steps at level t + for _ in range(L): + # score + drift = score_fn(t, **filter_kwargs(state, score_fn)) + # noise + eta = { + k: keras.random.normal(keras.ops.shape(v), dtype=keras.ops.dtype(v), seed=seed) + for k, v in state.items() + } + + # update + for k, d in drift.items(): + state[k] = state[k] + 0.5 * dt * d + sqrt_dt * eta[k] + return state From 0a87694f654be4039f17e02060e357fdc9e07c70 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 14:44:42 +0200 Subject: [PATCH 034/101] fix annealed_langevin --- .../diffusion_model/diffusion_model.py | 35 +++++++++---------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index c3a625e06..08118a220 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -859,6 +859,7 @@ def scores(time, xz): state = annealed_langevin( score_fn=scores, + noise_schedule=self.noise_schedule, state=state, seed=self.seed_generator, **filter_kwargs(integrate_kwargs, annealed_langevin), @@ -886,13 +887,14 @@ def deltas(time, xz): def annealed_langevin( score_fn: Callable, + noise_schedule: Callable, state: dict[str, ArrayLike], steps: int, seed: keras.random.SeedGenerator, - L: int = 5, start_time: ArrayLike = None, stop_time: ArrayLike = None, - eps: float = 0.01, + langevin_corrector_steps: int = 5, + step_size_factor: float = 0.1, ) -> dict[str, ArrayLike]: """ Annealed Langevin dynamics for diffusion sampling. 
@@ -902,30 +904,25 @@ def annealed_langevin( eta ~ N(0, I) theta <- theta + (dt[t]/2) * psi(theta, t) + sqrt(dt[t]) * eta """ - ratio = keras.ops.convert_to_tensor( - (stop_time + eps) / start_time, dtype=keras.ops.dtype(next(iter(state.values()))) - ) + log_snr_t = noise_schedule.get_log_snr(t=start_time, training=False) + _, max_sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) - T = steps # main loops - for t_T in range(T - 1, 0, -1): - t = t_T / T - dt = keras.ops.convert_to_tensor(stop_time, dtype=keras.ops.dtype(next(iter(state.values())))) * ( - ratio ** (stop_time - t) - ) - - sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt)) - # inner L Langevin steps at level t - for _ in range(L): - # score + for step in range(steps - 1, 0, -1): + t = step / steps + log_snr_t = noise_schedule.get_log_snr(t=t, training=False) + _, sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + annealing_step_size = step_size_factor * keras.ops.square(sigma_t / max_sigma_t) + + sqrt_dt = keras.ops.sqrt(keras.ops.abs(annealing_step_size)) + for _ in range(langevin_corrector_steps): drift = score_fn(t, **filter_kwargs(state, score_fn)) - # noise - eta = { + noise = { k: keras.random.normal(keras.ops.shape(v), dtype=keras.ops.dtype(v), seed=seed) for k, v in state.items() } # update for k, d in drift.items(): - state[k] = state[k] + 0.5 * dt * d + sqrt_dt * eta[k] + state[k] = state[k] + 0.5 * annealing_step_size * d + sqrt_dt * noise[k] return state From 64d43735dc9136b622605bb9b17ea372629cbee9 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 16:00:39 +0200 Subject: [PATCH 035/101] add predictor corrector sampling --- .../diffusion_model/diffusion_model.py | 17 ++++++++ bayesflow/utils/integrate.py | 41 +++++++++++++++++-- 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 08118a220..fd1415616 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -836,9 +836,26 @@ def deltas(time, xz): def diffusion(time, xz): return {"xz": self.diffusion_term(xz, time=time, training=training)} + scores = None + if "corrector_steps" in integrate_kwargs: + if integrate_kwargs["corrector_steps"] > 0: + + def scores(time, xz): + return { + "xz": self.compositional_score( + xz, + time=time, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + } + state = integrate_stochastic( drift_fn=deltas, diffusion_fn=diffusion, + score_fn=scores, state=state, seed=self.seed_generator, **integrate_kwargs, diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index b197ea975..be269ebaa 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -401,11 +401,17 @@ def integrate_stochastic( steps: int, seed: keras.random.SeedGenerator, method: str = "euler_maruyama", + score_fn: Callable = None, + corrector_steps: int = 0, **kwargs, ) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, Sequence[ArrayLike]]]]: """ Integrates a stochastic differential equation from start_time to stop_time. + When score_fn is provided, performs predictor-corrector sampling where: + - Predictor: reverse diffusion SDE solver + - Corrector: annealed Langevin dynamics with step size e = sqrt(dim) + Args: drift_fn: Function that computes the drift term. diffusion_fn: Function that computes the diffusion term. 
@@ -415,11 +421,13 @@ def integrate_stochastic( steps: Number of integration steps. seed: Random seed for noise generation. method: Integration method to use, e.g., 'euler_maruyama'. + score_fn: Optional score function for predictor-corrector sampling. + Should take (time, **state) and return score dict. + corrector_steps: Number of corrector steps to take after each predictor step. **kwargs: Additional arguments to pass to the step function. Returns: - If return_noise is False, returns the final state dictionary. - If return_noise is True, returns a tuple of (final_state, noise_history). + Final state dictionary after integration. """ if steps <= 0: raise ValueError("Number of steps must be positive.") @@ -438,17 +446,44 @@ def integrate_stochastic( step_size = (stop_time - start_time) / steps sqrt_dt = keras.ops.sqrt(keras.ops.abs(step_size)) - # Pre-generate noise history: shape = (steps, *state_shape) + # Pre-generate noise history for predictor: shape = (steps, *state_shape) noise_history = {} for key, val in state.items(): noise_history[key] = ( keras.random.normal((steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed) * sqrt_dt ) + # Pre-generate corrector noise if score_fn is provided: shape = (steps, corrector_steps, *state_shape) + corrector_noise_history = {} + if score_fn is not None and corrector_steps > 0: + for key, val in state.items(): + corrector_noise_history[key] = keras.random.normal( + (steps, corrector_steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed + ) + def body(_loop_var, _loop_state): _current_state, _current_time = _loop_state _noise_i = {k: noise_history[k][_loop_var] for k in _current_state.keys()} + + # Predictor step new_state, new_time = step_fn(state=_current_state, time=_current_time, step_size=step_size, noise=_noise_i) + + # Corrector steps: annealed Langevin dynamics if score_fn is provided + if score_fn is not None: + first_key = next(iter(new_state.keys())) + dim = keras.ops.cast(keras.ops.shape(new_state[first_key])[-1], keras.ops.dtype(new_state[first_key])) + e = keras.ops.sqrt(dim) + sqrt_2e = keras.ops.sqrt(2.0 * e) + + for corrector_step in range(corrector_steps): + score = score_fn(new_time, **filter_kwargs(new_state, score_fn)) + _corrector_noise = {k: corrector_noise_history[k][_loop_var, corrector_step] for k in new_state.keys()} + + # Corrector update: x_i+1 = x_i + e * score + sqrt(2e) * noise_corrector + for k in new_state.keys(): + if k in score: + new_state[k] = new_state[k] + e * score[k] + sqrt_2e * _corrector_noise[k] + return new_state, new_time final_state, final_time = keras.ops.fori_loop(0, steps, body, (state, start_time)) From 5b4236862a6c76682a1b03c54eebd278314549fe Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 16:13:32 +0200 Subject: [PATCH 036/101] add predictor corrector sampling --- .../diffusion_model/diffusion_model.py | 81 ++++++++++++++++--- 1 file changed, 68 insertions(+), 13 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index fd1415616..69dae59ac 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -246,6 +246,55 @@ def _apply_subnet( else: return self.subnet(x=xz, t=log_snr, conditions=conditions, training=training) + def score( + self, + xz: Tensor, + time: float | Tensor = None, + log_snr_t: Tensor = None, + conditions: Tensor = None, + training: bool = False, + ) -> Tensor: + """ + Computes the 
score of the target or latent variable `xz`. + + Parameters + ---------- + xz : Tensor + The current state of the latent variable `z`, typically of shape (..., D), + where D is the dimensionality of the latent space. + time : float or Tensor + Scalar or tensor representing the time (or noise level) at which the score + should be computed. Will be broadcasted to xz. If None, log_snr_t must be provided. + log_snr_t : Tensor + The log signal-to-noise ratio at time `t`. If None, time must be provided. + conditions : Tensor, optional + Conditional inputs to the network, such as conditioning variables + or encoder outputs. Shape must be broadcastable with `xz`. Default is None. + training : bool, optional + Whether the model is in training mode. Affects behavior of dropout, batch norm, + or other stochastic layers. Default is False. + + Returns + ------- + Tensor + The score tensor of the same shape as `xz`, representing the gradient of the + log-density with respect to `xz` at the given `time`. + """ + if log_snr_t is None: + log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) + alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + + subnet_out = self._apply_subnet( + xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training + ) + pred = self.output_projector(subnet_out, training=training) + + x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t) + + score = (alpha_t * x_pred - xz) / ops.square(sigma_t) + return score + def velocity( self, xz: Tensor, @@ -282,19 +331,10 @@ def velocity( The velocity tensor of the same shape as `xz`, representing the right-hand side of the SDE or ODE at the given `time`.
""" - # calculate the current noise level and transform into correct shape log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) - alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) - - subnet_out = self._apply_subnet( - xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training - ) - pred = self.output_projector(subnet_out, training=training) - x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t) - - score = (alpha_t * x_pred - xz) / ops.square(sigma_t) + score = self.score(xz, log_snr_t=log_snr_t, conditions=conditions, training=training) # compute velocity f, g of the SDE or ODE f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) @@ -452,9 +492,24 @@ def deltas(time, xz): def diffusion(time, xz): return {"xz": self.diffusion_term(xz, time=time, training=training)} + score_fn = None + if "corrector_steps" in integrate_kwargs: + if integrate_kwargs["corrector_steps"] > 0: + + def score_fn(time, xz): + return { + "xz": self.score( + xz, + time=time, + conditions=conditions, + training=training, + ) + } + state = integrate_stochastic( drift_fn=deltas, diffusion_fn=diffusion, + score_fn=score_fn, state=state, seed=self.seed_generator, **integrate_kwargs, @@ -836,11 +891,11 @@ def deltas(time, xz): def diffusion(time, xz): return {"xz": self.diffusion_term(xz, time=time, training=training)} - scores = None + score_fn = None if "corrector_steps" in integrate_kwargs: if integrate_kwargs["corrector_steps"] > 0: - def scores(time, xz): + def score_fn(time, xz): return { "xz": self.compositional_score( xz, @@ -855,7 +910,7 @@ def scores(time, xz): state = integrate_stochastic( drift_fn=deltas, diffusion_fn=diffusion, - score_fn=scores, + score_fn=score_fn, state=state, seed=self.seed_generator, **integrate_kwargs, From 94029414a6c58ec56aeeb2c5d75741fde015e283 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 16:27:44 +0200 Subject: [PATCH 037/101] add predictor corrector sampling --- .../diffusion_model/diffusion_model.py | 2 ++ bayesflow/utils/integrate.py | 29 ++++++++++++++----- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 69dae59ac..e303e961d 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -510,6 +510,7 @@ def score_fn(time, xz): drift_fn=deltas, diffusion_fn=diffusion, score_fn=score_fn, + noise_schedule=self.noise_schedule, state=state, seed=self.seed_generator, **integrate_kwargs, @@ -911,6 +912,7 @@ def score_fn(time, xz): drift_fn=deltas, diffusion_fn=diffusion, score_fn=score_fn, + noise_schedule=self.noise_schedule, state=state, seed=self.seed_generator, **integrate_kwargs, diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index be269ebaa..b3127a737 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -403,6 +403,7 @@ def integrate_stochastic( method: str = "euler_maruyama", score_fn: Callable = None, corrector_steps: int = 0, + noise_schedule=None, **kwargs, ) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, Sequence[ArrayLike]]]]: """ @@ -424,6 +425,7 @@ def integrate_stochastic( score_fn: Optional score function for predictor-corrector 
sampling. Should take (time, **state) and return score dict. corrector_steps: Number of corrector steps to take after each predictor step. + noise_schedule: Noise schedule object for computing lambda_t and alpha_t in corrector. **kwargs: Additional arguments to pass to the step function. Returns: @@ -455,7 +457,10 @@ def integrate_stochastic( # Pre-generate corrector noise if score_fn is provided: shape = (steps, corrector_steps, *state_shape) corrector_noise_history = {} - if score_fn is not None and corrector_steps > 0: + if corrector_steps > 0: + if score_fn is None or noise_schedule is None: + raise ValueError("Please provide both score_fn and noise_schedule when using corrector_steps > 0.") + for key, val in state.items(): corrector_noise_history[key] = keras.random.normal( (steps, corrector_steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed @@ -469,19 +474,29 @@ def body(_loop_var, _loop_state): new_state, new_time = step_fn(state=_current_state, time=_current_time, step_size=step_size, noise=_noise_i) # Corrector steps: annealed Langevin dynamics if score_fn is provided - if score_fn is not None: - first_key = next(iter(new_state.keys())) - dim = keras.ops.cast(keras.ops.shape(new_state[first_key])[-1], keras.ops.dtype(new_state[first_key])) - e = keras.ops.sqrt(dim) - sqrt_2e = keras.ops.sqrt(2.0 * e) - + if corrector_steps > 0: for corrector_step in range(corrector_steps): score = score_fn(new_time, **filter_kwargs(new_state, score_fn)) _corrector_noise = {k: corrector_noise_history[k][_loop_var, corrector_step] for k in new_state.keys()} + # Compute noise schedule components for corrector step size + log_snr_t = noise_schedule.get_log_snr(t=new_time, training=False) + alpha_t, _ = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + lambda_t = keras.ops.exp(-log_snr_t) # lambda_t from noise schedule + # Corrector update: x_i+1 = x_i + e * score + sqrt(2e) * noise_corrector + # where e = 2*alpha_t * (lambda_t * ||z|| / ||score||)**2 for k in new_state.keys(): if k in score: + z_norm = keras.ops.norm(new_state[k], axis=-1, keepdims=True) + score_norm = keras.ops.norm(score[k], axis=-1, keepdims=True) + + # Prevent division by zero + score_norm = keras.ops.maximum(score_norm, 1e-8) + + e = 2.0 * alpha_t * (lambda_t * z_norm / score_norm) ** 2 + sqrt_2e = keras.ops.sqrt(2.0 * e) + new_state[k] = new_state[k] + e * score[k] + sqrt_2e * _corrector_noise[k] return new_state, new_time From e0b3bd5dfdfc6320cb35daa9429cbb4b816e8f69 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 16:32:27 +0200 Subject: [PATCH 038/101] add predictor corrector sampling --- bayesflow/utils/integrate.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index b3127a737..a46d3e78a 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -404,6 +404,7 @@ def integrate_stochastic( score_fn: Callable = None, corrector_steps: int = 0, noise_schedule=None, + r: float = 0.1, **kwargs, ) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, Sequence[ArrayLike]]]]: """ @@ -426,6 +427,7 @@ def integrate_stochastic( Should take (time, **state) and return score dict. corrector_steps: Number of corrector steps to take after each predictor step. noise_schedule: Noise schedule object for computing lambda_t and alpha_t in corrector. + r: Scaling factor for corrector step size. **kwargs: Additional arguments to pass to the step function. 
Returns: @@ -482,10 +484,9 @@ def body(_loop_var, _loop_state): # Compute noise schedule components for corrector step size log_snr_t = noise_schedule.get_log_snr(t=new_time, training=False) alpha_t, _ = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) - lambda_t = keras.ops.exp(-log_snr_t) # lambda_t from noise schedule # Corrector update: x_i+1 = x_i + e * score + sqrt(2e) * noise_corrector - # where e = 2*alpha_t * (lambda_t * ||z|| / ||score||)**2 + # where e = 2*alpha_t * (r * ||z|| / ||score||)**2 for k in new_state.keys(): if k in score: z_norm = keras.ops.norm(new_state[k], axis=-1, keepdims=True) @@ -494,7 +495,7 @@ def body(_loop_var, _loop_state): # Prevent division by zero score_norm = keras.ops.maximum(score_norm, 1e-8) - e = 2.0 * alpha_t * (lambda_t * z_norm / score_norm) ** 2 + e = 2.0 * alpha_t * (r * z_norm / score_norm) ** 2 sqrt_2e = keras.ops.sqrt(2.0 * e) new_state[k] = new_state[k] + e * score[k] + sqrt_2e * _corrector_noise[k] From 89361f75282dfa367136b2658c3fd70453c6ea93 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 16:42:13 +0200 Subject: [PATCH 039/101] add predictor corrector sampling --- .../diffusion_model/diffusion_model.py | 67 ------------------- bayesflow/utils/integrate.py | 8 +-- 2 files changed, 4 insertions(+), 71 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index e303e961d..81c64bfbd 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -16,15 +16,12 @@ integrate_stochastic, logging, tensor_utils, - filter_kwargs, ) from bayesflow.utils.serialization import serialize, deserialize, serializable from .schedules.noise_schedule import NoiseSchedule from .dispatch import find_noise_schedule -ArrayLike = int | float | Tensor - # disable module check, use potential module after moving from experimental @serializable("bayesflow.networks", disable_module_check=True) @@ -917,27 +914,6 @@ def score_fn(time, xz): seed=self.seed_generator, **integrate_kwargs, ) - elif integrate_kwargs["method"] == "langevin": - - def scores(time, xz): - return { - "xz": self.compositional_score( - xz, - time=time, - conditions=conditions, - compute_prior_score=compute_prior_score, - mini_batch_size=mini_batch_size, - training=training, - ) - } - - state = annealed_langevin( - score_fn=scores, - noise_schedule=self.noise_schedule, - state=state, - seed=self.seed_generator, - **filter_kwargs(integrate_kwargs, annealed_langevin), - ) else: def deltas(time, xz): @@ -957,46 +933,3 @@ def deltas(time, xz): x = state["xz"] return x - - -def annealed_langevin( - score_fn: Callable, - noise_schedule: Callable, - state: dict[str, ArrayLike], - steps: int, - seed: keras.random.SeedGenerator, - start_time: ArrayLike = None, - stop_time: ArrayLike = None, - langevin_corrector_steps: int = 5, - step_size_factor: float = 0.1, -) -> dict[str, ArrayLike]: - """ - Annealed Langevin dynamics for diffusion sampling. 
- - for t = T-1,...,1: - for s = 1,...,L: - eta ~ N(0, I) - theta <- theta + (dt[t]/2) * psi(theta, t) + sqrt(dt[t]) * eta - """ - log_snr_t = noise_schedule.get_log_snr(t=start_time, training=False) - _, max_sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) - - # main loops - for step in range(steps - 1, 0, -1): - t = step / steps - log_snr_t = noise_schedule.get_log_snr(t=t, training=False) - _, sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) - annealing_step_size = step_size_factor * keras.ops.square(sigma_t / max_sigma_t) - - sqrt_dt = keras.ops.sqrt(keras.ops.abs(annealing_step_size)) - for _ in range(langevin_corrector_steps): - drift = score_fn(t, **filter_kwargs(state, score_fn)) - noise = { - k: keras.random.normal(keras.ops.shape(v), dtype=keras.ops.dtype(v), seed=seed) - for k, v in state.items() - } - - # update - for k, d in drift.items(): - state[k] = state[k] + 0.5 * annealing_step_size * d + sqrt_dt * noise[k] - return state diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index a46d3e78a..961015b8f 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -404,7 +404,7 @@ def integrate_stochastic( score_fn: Callable = None, corrector_steps: int = 0, noise_schedule=None, - r: float = 0.1, + step_size_factor: float = 0.1, **kwargs, ) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, Sequence[ArrayLike]]]]: """ @@ -427,7 +427,7 @@ def integrate_stochastic( Should take (time, **state) and return score dict. corrector_steps: Number of corrector steps to take after each predictor step. noise_schedule: Noise schedule object for computing lambda_t and alpha_t in corrector. - r: Scaling factor for corrector step size. + step_size_factor: Scaling factor for corrector step size. **kwargs: Additional arguments to pass to the step function. 
Returns: @@ -489,13 +489,13 @@ def body(_loop_var, _loop_state): # where e = 2*alpha_t * (r * ||z|| / ||score||)**2 for k in new_state.keys(): if k in score: - z_norm = keras.ops.norm(new_state[k], axis=-1, keepdims=True) + z_norm = keras.ops.norm(_corrector_noise[k], axis=-1, keepdims=True) score_norm = keras.ops.norm(score[k], axis=-1, keepdims=True) # Prevent division by zero score_norm = keras.ops.maximum(score_norm, 1e-8) - e = 2.0 * alpha_t * (r * z_norm / score_norm) ** 2 + e = 2.0 * alpha_t * (step_size_factor * z_norm / score_norm) ** 2 sqrt_2e = keras.ops.sqrt(2.0 * e) new_state[k] = new_state[k] + e * score[k] + sqrt_2e * _corrector_noise[k] From 5969bd380514acd883ed3820b32ecdf514003059 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 16:59:52 +0200 Subject: [PATCH 040/101] robust mean scores --- bayesflow/networks/diffusion_model/diffusion_model.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 81c64bfbd..f56655c05 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -734,13 +734,11 @@ def compositional_score( individual_scores = self._compute_individual_scores(xz, log_snr_t, alpha_t, sigma_t, conditions_batch, training) # Compute prior score component - prior_score = compute_prior_score(xz) + weighted_prior_score = (1.0 - time) * compute_prior_score(xz) # Combine scores using compositional formula, mean over individual scores and scale with n to get sum - summed_individual_scores = n_compositional * ops.mean(individual_scores, axis=1) - - # Prior contribution - weighted_prior_score = (1.0 - n_compositional) * (1.0 - time) * prior_score + weighted_individual_scores = individual_scores - weighted_prior_score + summed_individual_scores = n_compositional * ops.mean(weighted_individual_scores, axis=1) # Combined score time_tensor = ops.cast(time, dtype=ops.dtype(xz)) From e983cf7a746b13eae0122bd1d6f57f7c20095cc1 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 17:10:49 +0200 Subject: [PATCH 041/101] add some tests --- .../diffusion_model/diffusion_model.py | 2 +- .../test_compositional_sampling.py | 178 ++++++++++++++++++ 2 files changed, 179 insertions(+), 1 deletion(-) create mode 100644 tests/test_networks/test_diffusion_model/test_compositional_sampling.py diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index f56655c05..25a6b4c7c 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -737,7 +737,7 @@ def compositional_score( weighted_prior_score = (1.0 - time) * compute_prior_score(xz) # Combine scores using compositional formula, mean over individual scores and scale with n to get sum - weighted_individual_scores = individual_scores - weighted_prior_score + weighted_individual_scores = individual_scores - keras.ops.expand_dims(weighted_prior_score, axis=1) summed_individual_scores = n_compositional * ops.mean(weighted_individual_scores, axis=1) # Combined score diff --git a/tests/test_networks/test_diffusion_model/test_compositional_sampling.py b/tests/test_networks/test_diffusion_model/test_compositional_sampling.py new file mode 100644 index 000000000..4fa0ebf59 --- /dev/null +++ b/tests/test_networks/test_diffusion_model/test_compositional_sampling.py @@ -0,0 +1,178 @@ +import keras 
+import pytest + + +@pytest.fixture +def simple_diffusion_model(): + """Create a simple diffusion model for testing compositional sampling.""" + from bayesflow.networks.diffusion_model import DiffusionModel + from bayesflow.networks import MLP + + return DiffusionModel( + subnet=MLP(widths=[32, 32]), + noise_schedule="cosine", + prediction_type="noise", + loss_type="noise", + ) + + +@pytest.fixture +def compositional_conditions(): + """Create test conditions for compositional sampling.""" + batch_size = 2 + n_compositional = 3 + n_samples = 4 + condition_dim = 5 + + return keras.random.normal((batch_size, n_compositional, n_samples, condition_dim)) + + +@pytest.fixture +def compositional_state(): + """Create test state for compositional sampling.""" + batch_size = 2 + n_samples = 4 + param_dim = 3 + + return keras.random.normal((batch_size, n_samples, param_dim)) + + +@pytest.fixture +def mock_prior_score(): + """Create a mock prior score function for testing.""" + + def prior_score_fn(theta): + # Simple quadratic prior: -0.5 * ||theta||^2 + return -theta + + return prior_score_fn + + +def test_compositional_score_shape( + simple_diffusion_model, compositional_state, compositional_conditions, mock_prior_score +): + """Test that compositional score returns correct shapes.""" + # Build the model + state_shape = keras.ops.shape(compositional_state) + conditions_shape = keras.ops.shape(compositional_conditions) + simple_diffusion_model.build(state_shape, conditions_shape) + + time = 0.5 + + score = simple_diffusion_model.compositional_score( + xz=compositional_state, + time=time, + conditions=compositional_conditions, + compute_prior_score=mock_prior_score, + training=False, + ) + + expected_shape = keras.ops.shape(compositional_state) + actual_shape = keras.ops.shape(score) + + assert keras.ops.all(keras.ops.equal(expected_shape, actual_shape)), ( + f"Expected shape {expected_shape}, got {actual_shape}" + ) + + +def test_compositional_score_no_conditions_raises_error(simple_diffusion_model, compositional_state, mock_prior_score): + """Test that compositional score raises error when conditions is None.""" + simple_diffusion_model.build(keras.ops.shape(compositional_state), None) + + with pytest.raises(ValueError, match="Conditions are required for compositional sampling"): + simple_diffusion_model.compositional_score( + xz=compositional_state, time=0.5, conditions=None, compute_prior_score=mock_prior_score, training=False + ) + + +def test_inverse_compositional_basic( + simple_diffusion_model, compositional_state, compositional_conditions, mock_prior_score +): + """Test basic compositional inverse sampling.""" + state_shape = keras.ops.shape(compositional_state) + conditions_shape = keras.ops.shape(compositional_conditions) + simple_diffusion_model.build(state_shape, conditions_shape) + + # Test inverse sampling with ODE method + result = simple_diffusion_model._inverse_compositional( + z=compositional_state, + conditions=compositional_conditions, + compute_prior_score=mock_prior_score, + density=False, + training=False, + method="euler", + steps=5, + start_time=1.0, + stop_time=0.0, + ) + + expected_shape = keras.ops.shape(compositional_state) + actual_shape = keras.ops.shape(result) + + assert keras.ops.all(keras.ops.equal(expected_shape, actual_shape)), ( + f"Expected shape {expected_shape}, got {actual_shape}" + ) + + +def test_inverse_compositional_euler_maruyama_with_corrector( + simple_diffusion_model, compositional_state, compositional_conditions, mock_prior_score +): + """Test 
compositional inverse sampling with Euler-Maruyama and corrector steps.""" + state_shape = keras.ops.shape(compositional_state) + conditions_shape = keras.ops.shape(compositional_conditions) + simple_diffusion_model.build(state_shape, conditions_shape) + + result = simple_diffusion_model._inverse_compositional( + z=compositional_state, + conditions=compositional_conditions, + compute_prior_score=mock_prior_score, + density=False, + training=False, + method="euler_maruyama", + steps=5, + corrector_steps=2, + start_time=1.0, + stop_time=0.0, + ) + + expected_shape = keras.ops.shape(compositional_state) + actual_shape = keras.ops.shape(result) + + assert keras.ops.all(keras.ops.equal(expected_shape, actual_shape)), ( + f"Expected shape {expected_shape}, got {actual_shape}" + ) + + +@pytest.mark.parametrize("noise_schedule_name", ["cosine", "edm"]) +def test_compositional_sampling_with_different_schedules( + noise_schedule_name, compositional_state, compositional_conditions, mock_prior_score +): + """Test compositional sampling with different noise schedules.""" + from bayesflow.networks.diffusion_model import DiffusionModel + from bayesflow.networks import MLP + + diffusion_model = DiffusionModel( + subnet=MLP(widths=[32, 32]), + noise_schedule=noise_schedule_name, + prediction_type="noise", + loss_type="noise", + ) + + state_shape = keras.ops.shape(compositional_state) + conditions_shape = keras.ops.shape(compositional_conditions) + diffusion_model.build(state_shape, conditions_shape) + + score = diffusion_model.compositional_score( + xz=compositional_state, + time=0.5, + conditions=compositional_conditions, + compute_prior_score=mock_prior_score, + training=False, + ) + + expected_shape = keras.ops.shape(compositional_state) + actual_shape = keras.ops.shape(score) + + assert keras.ops.all(keras.ops.equal(expected_shape, actual_shape)), ( + f"Expected shape {expected_shape}, got {actual_shape}" + ) From eac9aaf562eda96326f48473097e324223abf0cd Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 17:30:57 +0200 Subject: [PATCH 042/101] minor fixes --- bayesflow/networks/diffusion_model/diffusion_model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 25a6b4c7c..a6ff78510 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -734,13 +734,13 @@ def compositional_score( individual_scores = self._compute_individual_scores(xz, log_snr_t, alpha_t, sigma_t, conditions_batch, training) # Compute prior score component - weighted_prior_score = (1.0 - time) * compute_prior_score(xz) + prior_score = compute_prior_score(xz) + weighted_prior_score = (1.0 - n_compositional) * (1.0 - time) * prior_score - # Combine scores using compositional formula, mean over individual scores and scale with n to get sum - weighted_individual_scores = individual_scores - keras.ops.expand_dims(weighted_prior_score, axis=1) - summed_individual_scores = n_compositional * ops.mean(weighted_individual_scores, axis=1) + # Sum individual scores across compositional dimensiont + summed_individual_scores = n_compositional * ops.mean(individual_scores, axis=1) - # Combined score + # Combined score using compositional formula: (1-n)(1-t)∇log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) time_tensor = ops.cast(time, dtype=ops.dtype(xz)) compositional_score = self.compositional_bridge(time_tensor) * (weighted_prior_score + 
summed_individual_scores) return compositional_score From 2a9b0e100c2f9fa24bf3ae53cfdcc6a3c0044024 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 17:41:25 +0200 Subject: [PATCH 043/101] minor fixes --- .../diffusion_model/diffusion_model.py | 41 ++++--------------- 1 file changed, 9 insertions(+), 32 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index a6ff78510..7a75b0a9b 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -721,7 +721,6 @@ def compositional_score( # Calculate standard noise schedule components log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) - alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) # Compute individual dataset scores if mini_batch_size is not None and mini_batch_size < n_compositional: @@ -731,13 +730,13 @@ def compositional_score( conditions_batch = conditions[:, mini_batch_idx] else: conditions_batch = conditions - individual_scores = self._compute_individual_scores(xz, log_snr_t, alpha_t, sigma_t, conditions_batch, training) + individual_scores = self._compute_individual_scores(xz, log_snr_t, conditions_batch, training) # Compute prior score component prior_score = compute_prior_score(xz) weighted_prior_score = (1.0 - n_compositional) * (1.0 - time) * prior_score - # Sum individual scores across compositional dimensiont + # Sum individual scores across compositional dimensions summed_individual_scores = n_compositional * ops.mean(individual_scores, axis=1) # Combined score using compositional formula: (1-n)(1-t)∇log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) @@ -749,8 +748,6 @@ def _compute_individual_scores( self, xz: Tensor, log_snr_t: Tensor, - alpha_t: Tensor, - sigma_t: Tensor, conditions: Tensor, training: bool, ) -> Tensor: @@ -762,9 +759,6 @@ def _compute_individual_scores( Tensor Individual scores with shape (n_datasets, n_compositional, ...) 
""" - # Apply subnet to each compositional condition separately - transformed_log_snr = self._transform_log_snr(log_snr_t) - # Get shapes xz_shape = ops.shape(xz) # (n_datasets, num_samples, ..., dims) conditions_shape = ops.shape(conditions) # (n_datasets, n_compositional, num_samples, ..., dims) @@ -777,38 +771,21 @@ def _compute_individual_scores( xz_expanded = ops.expand_dims(xz, axis=1) # (n_datasets, 1, num_samples, ..., dims) xz_expanded = ops.broadcast_to(xz_expanded, (n_datasets, n_compositional, num_samples) + dims) - # Expand noise schedule components to match compositional structure - log_snr_expanded = ops.expand_dims(transformed_log_snr, axis=1) + # Expand log_snr_t to match compositional structure + log_snr_expanded = ops.expand_dims(log_snr_t, axis=1) log_snr_expanded = ops.broadcast_to(log_snr_expanded, (n_datasets, n_compositional, num_samples, 1)) - alpha_expanded = ops.expand_dims(alpha_t, axis=1) - alpha_expanded = ops.broadcast_to(alpha_expanded, (n_datasets, n_compositional, num_samples, 1)) - - sigma_expanded = ops.expand_dims(sigma_t, axis=1) - sigma_expanded = ops.broadcast_to(sigma_expanded, (n_datasets, n_compositional, num_samples, 1)) - - # Flatten for subnet application: (n_datasets * n_compositional, num_samples, ..., dims) + # Flatten for score computation: (n_datasets * n_compositional, num_samples, ..., dims) xz_flat = ops.reshape(xz_expanded, (n_datasets * n_compositional, num_samples) + dims) log_snr_flat = ops.reshape(log_snr_expanded, (n_datasets * n_compositional, num_samples, 1)) - alpha_flat = ops.reshape(alpha_expanded, (n_datasets * n_compositional, num_samples, 1)) - sigma_flat = ops.reshape(sigma_expanded, (n_datasets * n_compositional, num_samples, 1)) conditions_flat = ops.reshape(conditions, (n_datasets * n_compositional, num_samples) + conditions_dims) - # Apply subnet - subnet_out = self._apply_subnet(xz_flat, log_snr_flat, conditions=conditions_flat, training=training) - pred = self.output_projector(subnet_out, training=training) - - # Convert prediction to x - x_pred = self.convert_prediction_to_x( - pred=pred, z=xz_flat, alpha_t=alpha_flat, sigma_t=sigma_flat, log_snr_t=log_snr_flat - ) - - # Compute score: (α_t * x_pred - z) / σ_t² - score = (alpha_flat * x_pred - xz_flat) / ops.square(sigma_flat) + # Use standard score function + scores_flat = self.score(xz_flat, log_snr_t=log_snr_flat, conditions=conditions_flat, training=training) # Reshape back to compositional structure - score = ops.reshape(score, (n_datasets, n_compositional, num_samples) + dims) - return score + scores = ops.reshape(scores_flat, (n_datasets, n_compositional, num_samples) + dims) + return scores def _inverse_compositional( self, From 9a1ba32dc6e28b49b97cdb87ad0e41d8bbe518bd Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 20:17:00 +0200 Subject: [PATCH 044/101] add test for compute_prior_score_pre --- .../approximators/continuous_approximator.py | 13 +-- tests/test_approximators/conftest.py | 53 +++++++++ .../test_compositional_prior_score.py | 109 ++++++++++++++++++ .../test_diffusion_model/conftest.py | 47 ++++++++ .../test_compositional_sampling.py | 46 -------- 5 files changed, 214 insertions(+), 54 deletions(-) create mode 100644 tests/test_approximators/test_compositional_prior_score.py diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index 7a0d757d9..075358c5f 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ 
b/bayesflow/approximators/continuous_approximator.py @@ -699,7 +699,7 @@ def compute_prior_score_pre(_samples: Tensor) -> Tensor: _samples, forward=False, log_det_jac=True ) else: - log_det_jac_standardize = 0 + log_det_jac_standardize = keras.ops.cast(0.0, dtype="float32") _samples = keras.tree.map_structure(keras.ops.convert_to_numpy, {"inference_variables": _samples}) adapted_samples, log_det_jac = self.adapter( _samples, inverse=True, strict=False, log_det_jac=True, **kwargs @@ -708,15 +708,12 @@ def compute_prior_score_pre(_samples: Tensor) -> Tensor: for key in adapted_samples: if isinstance(prior_score[key], np.ndarray): prior_score[key] = prior_score[key].astype("float32") - if len(log_det_jac) > 0: - prior_score[key] += log_det_jac[key] + if len(log_det_jac) > 0 and key in log_det_jac: + prior_score[key] -= expand_right_as(log_det_jac[key], prior_score[key]) prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score) - # make a tensor - out = keras.ops.concatenate( - list(prior_score.values()), axis=-1 - ) # todo: assumes same order, might be incorrect - return out + expand_right_as(log_det_jac_standardize, out) + out = keras.ops.concatenate(list(prior_score.values()), axis=-1) + return out - keras.ops.expand_dims(log_det_jac_standardize, axis=-1) # Test prior score function, useful for debugging test = self.inference_network.base_distribution.sample((n_datasets, num_samples)) diff --git a/tests/test_approximators/conftest.py b/tests/test_approximators/conftest.py index a56802a3e..528e7969b 100644 --- a/tests/test_approximators/conftest.py +++ b/tests/test_approximators/conftest.py @@ -220,3 +220,56 @@ def approximator_with_summaries(request): ) case _: raise ValueError("Invalid param for approximator class.") + + +@pytest.fixture +def simple_log_simulator(): + """Create a simple simulator for testing.""" + import numpy as np + from bayesflow.simulators import Simulator + from bayesflow.utils.decorators import allow_batch_size + from bayesflow.types import Shape, Tensor + + class SimpleSimulator(Simulator): + """Simple simulator that generates mean and scale parameters.""" + + @allow_batch_size + def sample(self, batch_shape: Shape) -> dict[str, Tensor]: + # Generate parameters in original space + loc = np.random.normal(0.0, 1.0, size=batch_shape + (2,)) # location parameters + scale = np.random.lognormal(0.0, 0.5, size=batch_shape + (2,)) # scale parameters > 0 + + # Generate some dummy conditions + conditions = np.random.normal(0.0, 1.0, size=batch_shape + (3,)) + + return dict( + loc=loc.astype("float32"), scale=scale.astype("float32"), conditions=conditions.astype("float32") + ) + + return SimpleSimulator() + + +@pytest.fixture +def transforming_adapter(): + """Create an adapter that applies log transformation to scale parameters.""" + from bayesflow.adapters import Adapter + + adapter = Adapter() + adapter.to_array() + adapter.convert_dtype("float64", "float32") + + # Apply log transformation to scale parameters (to make them unbounded) + adapter.log(["scale"]) + + adapter.concatenate(["loc", "scale"], into="inference_variables") + adapter.concatenate(["conditions"], into="inference_conditions") + adapter.keep(["inference_variables", "inference_conditions"]) + return adapter + + +@pytest.fixture +def diffusion_network(): + """Create a diffusion network for compositional sampling.""" + from bayesflow.networks import DiffusionModel, MLP + + return DiffusionModel(subnet=MLP(widths=[32, 32])) diff --git 
a/tests/test_approximators/test_compositional_prior_score.py b/tests/test_approximators/test_compositional_prior_score.py new file mode 100644 index 000000000..cd4b81413 --- /dev/null +++ b/tests/test_approximators/test_compositional_prior_score.py @@ -0,0 +1,109 @@ +"""Tests for compositional sampling and prior score computation with adapters.""" + +import numpy as np +import keras + +from bayesflow import ContinuousApproximator +from bayesflow.utils import expand_right_as + + +def mock_prior_score_original_space(data_dict): + """Mock prior score function that expects data in original (loc, scale) space.""" + # The function receives data in the same format the compute_prior_score_pre creates + # after running the inverse adapter + loc = data_dict["loc"] + scale = data_dict["scale"] + + # Simple prior: N(0,1) for loc, LogNormal(0,0.5) for scale + loc_score = -loc + scale_score = -1.0 / scale - np.log(scale) / (0.25 * scale) + + return {"loc": loc_score, "scale": scale_score} + + +def test_prior_score_transforming_adapter(simple_log_simulator, transforming_adapter, diffusion_network): + """Test that prior scores work correctly with transforming adapter (log transformation).""" + + # Create approximator with transforming adapter + approximator = ContinuousApproximator( + adapter=transforming_adapter, + inference_network=diffusion_network, + ) + + # Generate test data and adapt it + data = simple_log_simulator.sample((2,)) + adapted_data = transforming_adapter(data) + + # Build approximator + approximator.build_from_data(adapted_data) + + # Test compositional sampling + n_datasets, n_compositional = 3, 5 + conditions = {"conditions": np.random.normal(0.0, 1.0, (n_datasets, n_compositional, 3)).astype("float32")} + + # This should work - the compute_prior_score_pre function should handle the inverse transformation + samples = approximator.compositional_sample( + num_samples=10, + conditions=conditions, + compute_prior_score=mock_prior_score_original_space, + ) + + assert "loc" in samples + assert "scale" in samples + assert samples["loc"].shape == (n_datasets, 10, 2) + assert samples["scale"].shape == (n_datasets, 10, 2) + + +def test_prior_score_jacobian_correction(simple_log_simulator, transforming_adapter, diffusion_network): + """Test that Jacobian correction is applied correctly in compute_prior_score_pre.""" + + # Create approximator with transforming adapter + approximator = ContinuousApproximator( + adapter=transforming_adapter, inference_network=diffusion_network, standardize=[] + ) + + # Build with dummy data + dummy_data_dict = simple_log_simulator.sample((1,)) + adapted_dummy_data = transforming_adapter(dummy_data_dict) + approximator.build_from_data(adapted_dummy_data) + + # Get the internal compute_prior_score_pre function + def get_compute_prior_score_pre(): + def compute_prior_score_pre(_samples): + if "inference_variables" in approximator.standardize: + _samples, log_det_jac_standardize = approximator.standardize_layers["inference_variables"]( + _samples, forward=False, log_det_jac=True + ) + else: + log_det_jac_standardize = keras.ops.cast(0.0, dtype="float32") + + _samples = keras.tree.map_structure(keras.ops.convert_to_numpy, {"inference_variables": _samples}) + adapted_samples, log_det_jac = approximator.adapter(_samples, inverse=True, strict=False, log_det_jac=True) + + prior_score = mock_prior_score_original_space(adapted_samples) + for key in adapted_samples: + if isinstance(prior_score[key], np.ndarray): + prior_score[key] = prior_score[key].astype("float32") + if 
len(log_det_jac) > 0 and key in log_det_jac: + prior_score[key] -= expand_right_as(log_det_jac[key], prior_score[key]) + + prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score) + out = keras.ops.concatenate(list(prior_score.values()), axis=-1) + return out - keras.ops.expand_dims(log_det_jac_standardize, axis=-1) + + return compute_prior_score_pre + + compute_prior_score_pre = get_compute_prior_score_pre() + + # Test with a known transformation + y_samples = adapted_dummy_data["inference_variables"] + scores = compute_prior_score_pre(y_samples) + scores_np = keras.ops.convert_to_numpy(scores)[0] # Remove batch dimension + + # With Jacobian correction: score_transformed = score_original - log|J| + old_scores = mock_prior_score_original_space(dummy_data_dict) + det_jac_scale = y_samples[0, 2:].sum() + expected_scores = np.array([old_scores["loc"][0], old_scores["scale"][0] - det_jac_scale]).flatten() + + # Check that scores are reasonably close + np.testing.assert_allclose(scores_np, expected_scores, rtol=1e-5, atol=1e-6) diff --git a/tests/test_networks/test_diffusion_model/conftest.py b/tests/test_networks/test_diffusion_model/conftest.py index b1ee915ae..581b4abde 100644 --- a/tests/test_networks/test_diffusion_model/conftest.py +++ b/tests/test_networks/test_diffusion_model/conftest.py @@ -1,4 +1,5 @@ import pytest +import keras @pytest.fixture() @@ -21,3 +22,49 @@ def edm_noise_schedule(): ) def noise_schedule(request): return request.getfixturevalue(request.param) + + +@pytest.fixture +def simple_diffusion_model(): + """Create a simple diffusion model for testing compositional sampling.""" + from bayesflow.networks.diffusion_model import DiffusionModel + from bayesflow.networks import MLP + + return DiffusionModel( + subnet=MLP(widths=[32, 32]), + noise_schedule="cosine", + prediction_type="noise", + loss_type="noise", + ) + + +@pytest.fixture +def compositional_conditions(): + """Create test conditions for compositional sampling.""" + batch_size = 2 + n_compositional = 3 + n_samples = 4 + condition_dim = 5 + + return keras.random.normal((batch_size, n_compositional, n_samples, condition_dim)) + + +@pytest.fixture +def compositional_state(): + """Create test state for compositional sampling.""" + batch_size = 2 + n_samples = 4 + param_dim = 3 + + return keras.random.normal((batch_size, n_samples, param_dim)) + + +@pytest.fixture +def mock_prior_score(): + """Create a mock prior score function for testing.""" + + def prior_score_fn(theta): + # Simple quadratic prior: -0.5 * ||theta||^2 + return -theta + + return prior_score_fn diff --git a/tests/test_networks/test_diffusion_model/test_compositional_sampling.py b/tests/test_networks/test_diffusion_model/test_compositional_sampling.py index 4fa0ebf59..2757bd28a 100644 --- a/tests/test_networks/test_diffusion_model/test_compositional_sampling.py +++ b/tests/test_networks/test_diffusion_model/test_compositional_sampling.py @@ -2,52 +2,6 @@ import pytest -@pytest.fixture -def simple_diffusion_model(): - """Create a simple diffusion model for testing compositional sampling.""" - from bayesflow.networks.diffusion_model import DiffusionModel - from bayesflow.networks import MLP - - return DiffusionModel( - subnet=MLP(widths=[32, 32]), - noise_schedule="cosine", - prediction_type="noise", - loss_type="noise", - ) - - -@pytest.fixture -def compositional_conditions(): - """Create test conditions for compositional sampling.""" - batch_size = 2 - n_compositional = 3 - n_samples = 4 - condition_dim = 5 - - return 
keras.random.normal((batch_size, n_compositional, n_samples, condition_dim)) - - -@pytest.fixture -def compositional_state(): - """Create test state for compositional sampling.""" - batch_size = 2 - n_samples = 4 - param_dim = 3 - - return keras.random.normal((batch_size, n_samples, param_dim)) - - -@pytest.fixture -def mock_prior_score(): - """Create a mock prior score function for testing.""" - - def prior_score_fn(theta): - # Simple quadratic prior: -0.5 * ||theta||^2 - return -theta - - return prior_score_fn - - def test_compositional_score_shape( simple_diffusion_model, compositional_state, compositional_conditions, mock_prior_score ): From 93b59ba0da9e3b556ff0c7b718219688139709d9 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 20:44:56 +0200 Subject: [PATCH 045/101] fix order of prior scores --- bayesflow/approximators/continuous_approximator.py | 2 +- tests/test_approximators/conftest.py | 2 +- tests/test_approximators/test_compositional_prior_score.py | 7 ++++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index 075358c5f..e4e4f09c2 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -712,7 +712,7 @@ def compute_prior_score_pre(_samples: Tensor) -> Tensor: prior_score[key] -= expand_right_as(log_det_jac[key], prior_score[key]) prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score) - out = keras.ops.concatenate(list(prior_score.values()), axis=-1) + out = keras.ops.concatenate([prior_score[key] for key in adapted_samples], axis=-1) return out - keras.ops.expand_dims(log_det_jac_standardize, axis=-1) # Test prior score function, useful for debugging diff --git a/tests/test_approximators/conftest.py b/tests/test_approximators/conftest.py index 528e7969b..5587901b5 100644 --- a/tests/test_approximators/conftest.py +++ b/tests/test_approximators/conftest.py @@ -261,7 +261,7 @@ def transforming_adapter(): # Apply log transformation to scale parameters (to make them unbounded) adapter.log(["scale"]) - adapter.concatenate(["loc", "scale"], into="inference_variables") + adapter.concatenate(["scale", "loc"], into="inference_variables") adapter.concatenate(["conditions"], into="inference_conditions") adapter.keep(["inference_variables", "inference_conditions"]) return adapter diff --git a/tests/test_approximators/test_compositional_prior_score.py b/tests/test_approximators/test_compositional_prior_score.py index cd4b81413..96ac7d29e 100644 --- a/tests/test_approximators/test_compositional_prior_score.py +++ b/tests/test_approximators/test_compositional_prior_score.py @@ -88,7 +88,7 @@ def compute_prior_score_pre(_samples): prior_score[key] -= expand_right_as(log_det_jac[key], prior_score[key]) prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score) - out = keras.ops.concatenate(list(prior_score.values()), axis=-1) + out = keras.ops.concatenate([prior_score[key] for key in adapted_samples], axis=-1) return out - keras.ops.expand_dims(log_det_jac_standardize, axis=-1) return compute_prior_score_pre @@ -102,8 +102,9 @@ def compute_prior_score_pre(_samples): # With Jacobian correction: score_transformed = score_original - log|J| old_scores = mock_prior_score_original_space(dummy_data_dict) - det_jac_scale = y_samples[0, 2:].sum() - expected_scores = np.array([old_scores["loc"][0], old_scores["scale"][0] - det_jac_scale]).flatten() + # order of 
parameters is flipped due to concatenation in adapter + det_jac_scale = y_samples[0, :2].sum() + expected_scores = np.array([old_scores["scale"][0] - det_jac_scale, old_scores["loc"][0]]).flatten() # Check that scores are reasonably close np.testing.assert_allclose(scores_np, expected_scores, rtol=1e-5, atol=1e-6) From 922040d4cd36caf9f0a0baa2964ac9bc7cefe7b2 Mon Sep 17 00:00:00 2001 From: arrjon Date: Sat, 13 Sep 2025 13:29:34 +0200 Subject: [PATCH 046/101] fix prior scores standardize --- .../approximators/continuous_approximator.py | 48 ++++++++--- tests/test_approximators/conftest.py | 15 ++++ .../test_compositional_prior_score.py | 79 ++----------------- 3 files changed, 58 insertions(+), 84 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index e4e4f09c2..24b02f145 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -14,7 +14,6 @@ squeeze_inner_estimates_dict, concatenate_valid, concatenate_valid_shapes, - expand_right_as, ) from bayesflow.utils.serialization import serialize, deserialize, serializable @@ -695,25 +694,52 @@ def compositional_sample( # Prepare prior scores to handle adapter def compute_prior_score_pre(_samples: Tensor) -> Tensor: if "inference_variables" in self.standardize: - _samples, log_det_jac_standardize = self.standardize_layers["inference_variables"]( - _samples, forward=False, log_det_jac=True - ) - else: - log_det_jac_standardize = keras.ops.cast(0.0, dtype="float32") + _samples = self.standardize_layers["inference_variables"](_samples, forward=False) _samples = keras.tree.map_structure(keras.ops.convert_to_numpy, {"inference_variables": _samples}) adapted_samples, log_det_jac = self.adapter( _samples, inverse=True, strict=False, log_det_jac=True, **kwargs ) + + if len(log_det_jac) > 0: + problematic_keys = [key for key in log_det_jac if log_det_jac[key] != 0.0] + raise NotImplementedError( + f"Cannot use compositional sampling with adapters " + f"that have non-zero log_det_jac. Problematic keys: {problematic_keys}" + ) + prior_score = compute_prior_score(adapted_samples) for key in adapted_samples: - if isinstance(prior_score[key], np.ndarray): - prior_score[key] = prior_score[key].astype("float32") - if len(log_det_jac) > 0 and key in log_det_jac: - prior_score[key] -= expand_right_as(log_det_jac[key], prior_score[key]) + prior_score[key] = prior_score[key].astype(np.float32) prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score) out = keras.ops.concatenate([prior_score[key] for key in adapted_samples], axis=-1) - return out - keras.ops.expand_dims(log_det_jac_standardize, axis=-1) + + if "inference_variables" in self.standardize: + # Apply jacobian correction from standardization + # For standardization T^{-1}(z) = z * std + mean, the jacobian is diagonal with std on diagonal + # The gradient of log|det(J)| w.r.t. z is 0 since log|det(J)| = sum(log(std)) is constant w.r.t. 
z + # But we need to transform the score: score_z = score_x * std where x = T^{-1}(z) + standardize_layer = self.standardize_layers["inference_variables"] + + # Compute the correct standard deviation for all components + std_components = [] + for idx in range(len(standardize_layer.moving_mean)): + std_val = standardize_layer.moving_std(idx) + std_components.append(std_val) + + # Concatenate std components to match the shape of out + if len(std_components) == 1: + std = std_components[0] + else: + std = keras.ops.concatenate(std_components, axis=-1) + + # Expand std to match batch dimension of out + std_expanded = keras.ops.expand_dims(std, (0, 1)) # Add batch, sample dimensions + std_expanded = keras.ops.tile(std_expanded, [n_datasets, num_samples, 1]) + + # Apply the jacobian: score_z = score_x * std + out = out * std_expanded + return out # Test prior score function, useful for debugging test = self.inference_network.base_distribution.sample((n_datasets, num_samples)) diff --git a/tests/test_approximators/conftest.py b/tests/test_approximators/conftest.py index 5587901b5..befc0da06 100644 --- a/tests/test_approximators/conftest.py +++ b/tests/test_approximators/conftest.py @@ -249,6 +249,21 @@ def sample(self, batch_shape: Shape) -> dict[str, Tensor]: return SimpleSimulator() +@pytest.fixture +def identity_adapter(): + """Create an adapter that applies no transformation to the parameters.""" + from bayesflow.adapters import Adapter + + adapter = Adapter() + adapter.to_array() + adapter.convert_dtype("float64", "float32") + + adapter.concatenate(["loc"], into="inference_variables") + adapter.concatenate(["conditions"], into="inference_conditions") + adapter.keep(["inference_variables", "inference_conditions"]) + return adapter + + @pytest.fixture def transforming_adapter(): """Create an adapter that applies log transformation to scale parameters.""" diff --git a/tests/test_approximators/test_compositional_prior_score.py b/tests/test_approximators/test_compositional_prior_score.py index 96ac7d29e..02be46c00 100644 --- a/tests/test_approximators/test_compositional_prior_score.py +++ b/tests/test_approximators/test_compositional_prior_score.py @@ -1,38 +1,31 @@ """Tests for compositional sampling and prior score computation with adapters.""" import numpy as np -import keras from bayesflow import ContinuousApproximator -from bayesflow.utils import expand_right_as def mock_prior_score_original_space(data_dict): - """Mock prior score function that expects data in original (loc, scale) space.""" - # The function receives data in the same format the compute_prior_score_pre creates - # after running the inverse adapter + """Mock prior score function that expects data in original space.""" loc = data_dict["loc"] - scale = data_dict["scale"] - # Simple prior: N(0,1) for loc, LogNormal(0,0.5) for scale + # Simple prior: N(0,1) for loc loc_score = -loc - scale_score = -1.0 / scale - np.log(scale) / (0.25 * scale) + return {"loc": loc_score} - return {"loc": loc_score, "scale": scale_score} - -def test_prior_score_transforming_adapter(simple_log_simulator, transforming_adapter, diffusion_network): +def test_prior_score_identity_adapter(simple_log_simulator, identity_adapter, diffusion_network): """Test that prior scores work correctly with transforming adapter (log transformation).""" # Create approximator with transforming adapter approximator = ContinuousApproximator( - adapter=transforming_adapter, + adapter=identity_adapter, inference_network=diffusion_network, ) # Generate test data and adapt it 
data = simple_log_simulator.sample((2,)) - adapted_data = transforming_adapter(data) + adapted_data = identity_adapter(data) # Build approximator approximator.build_from_data(adapted_data) @@ -40,8 +33,6 @@ def test_prior_score_transforming_adapter(simple_log_simulator, transforming_ada # Test compositional sampling n_datasets, n_compositional = 3, 5 conditions = {"conditions": np.random.normal(0.0, 1.0, (n_datasets, n_compositional, 3)).astype("float32")} - - # This should work - the compute_prior_score_pre function should handle the inverse transformation samples = approximator.compositional_sample( num_samples=10, conditions=conditions, @@ -49,62 +40,4 @@ def test_prior_score_transforming_adapter(simple_log_simulator, transforming_ada ) assert "loc" in samples - assert "scale" in samples assert samples["loc"].shape == (n_datasets, 10, 2) - assert samples["scale"].shape == (n_datasets, 10, 2) - - -def test_prior_score_jacobian_correction(simple_log_simulator, transforming_adapter, diffusion_network): - """Test that Jacobian correction is applied correctly in compute_prior_score_pre.""" - - # Create approximator with transforming adapter - approximator = ContinuousApproximator( - adapter=transforming_adapter, inference_network=diffusion_network, standardize=[] - ) - - # Build with dummy data - dummy_data_dict = simple_log_simulator.sample((1,)) - adapted_dummy_data = transforming_adapter(dummy_data_dict) - approximator.build_from_data(adapted_dummy_data) - - # Get the internal compute_prior_score_pre function - def get_compute_prior_score_pre(): - def compute_prior_score_pre(_samples): - if "inference_variables" in approximator.standardize: - _samples, log_det_jac_standardize = approximator.standardize_layers["inference_variables"]( - _samples, forward=False, log_det_jac=True - ) - else: - log_det_jac_standardize = keras.ops.cast(0.0, dtype="float32") - - _samples = keras.tree.map_structure(keras.ops.convert_to_numpy, {"inference_variables": _samples}) - adapted_samples, log_det_jac = approximator.adapter(_samples, inverse=True, strict=False, log_det_jac=True) - - prior_score = mock_prior_score_original_space(adapted_samples) - for key in adapted_samples: - if isinstance(prior_score[key], np.ndarray): - prior_score[key] = prior_score[key].astype("float32") - if len(log_det_jac) > 0 and key in log_det_jac: - prior_score[key] -= expand_right_as(log_det_jac[key], prior_score[key]) - - prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score) - out = keras.ops.concatenate([prior_score[key] for key in adapted_samples], axis=-1) - return out - keras.ops.expand_dims(log_det_jac_standardize, axis=-1) - - return compute_prior_score_pre - - compute_prior_score_pre = get_compute_prior_score_pre() - - # Test with a known transformation - y_samples = adapted_dummy_data["inference_variables"] - scores = compute_prior_score_pre(y_samples) - scores_np = keras.ops.convert_to_numpy(scores)[0] # Remove batch dimension - - # With Jacobian correction: score_transformed = score_original - log|J| - old_scores = mock_prior_score_original_space(dummy_data_dict) - # order of parameters is flipped due to concatenation in adapter - det_jac_scale = y_samples[0, :2].sum() - expected_scores = np.array([old_scores["scale"][0] - det_jac_scale, old_scores["loc"][0]]).flatten() - - # Check that scores are reasonably close - np.testing.assert_allclose(scores_np, expected_scores, rtol=1e-5, atol=1e-6) From b2991d177bb24deab200cf6419ad0c823781febb Mon Sep 17 00:00:00 2001 From: arrjon Date: Sat, 13 Sep 
2025 13:51:37 +0200 Subject: [PATCH 047/101] better standard values for compositional --- .../networks/diffusion_model/diffusion_model.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 7a75b0a9b..b026d5ea9 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -799,21 +799,21 @@ def _inverse_compositional( """ Inverse pass for compositional diffusion sampling. """ - integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} + n_compositional = ops.shape(conditions)[1] + integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0, "corrector_steps": 1} integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs - mini_batch_size = integrate_kwargs.pop("mini_batch_size", None) - - if mini_batch_size is not None: - # if backend is jax, mini batching does not work - if keras.backend.backend() == "jax": + if keras.backend.backend() == "jax": + mini_batch_size = integrate_kwargs.pop("mini_batch_size", None) + if mini_batch_size is not None: raise ValueError( "Mini batching is not supported with JAX backend. Set mini_batch_size to None " "or use another backend." ) + else: + mini_batch_size = integrate_kwargs.get("mini_batch_size", int(n_compositional * 0.1)) # x is sampled from a normal distribution, must be scaled with var 1/n_compositional - n_compositional = ops.shape(conditions)[1] scale_latent = n_compositional * self.compositional_bridge(ops.ones(1)) z = z / ops.sqrt(ops.cast(scale_latent, dtype=ops.dtype(z))) From d2a36a8349bbe95acee2e31ba4f065958205a751 Mon Sep 17 00:00:00 2001 From: arrjon Date: Sat, 13 Sep 2025 13:57:32 +0200 Subject: [PATCH 048/101] better compositional_bridge --- bayesflow/networks/diffusion_model/diffusion_model.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index b026d5ea9..c9e2a2271 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -3,6 +3,7 @@ import keras from keras import ops +import numpy as np from ..inference_network import InferenceNetwork from bayesflow.types import Tensor, Shape @@ -600,10 +601,10 @@ def compute_metrics( base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage) return base_metrics | {"loss": loss} - @staticmethod - def compositional_bridge(time: Tensor) -> Tensor: + def compositional_bridge(self, time: Tensor) -> Tensor: """ - Bridge function for compositional diffusion. In the simplest case, this is just 1. + Bridge function for compositional diffusion. In the simplest case, this is just 1 if d0 == d1. + Otherwise, it can be used to scale the compositional score over time. Parameters ---------- @@ -616,7 +617,7 @@ def compositional_bridge(time: Tensor) -> Tensor: Bridge function value with same shape as time. 
""" - return ops.ones_like(time) + return ops.exp(-np.log(self.compositional_d0 / self.compositional_d1) * time) def compositional_velocity( self, @@ -812,6 +813,8 @@ def _inverse_compositional( ) else: mini_batch_size = integrate_kwargs.get("mini_batch_size", int(n_compositional * 0.1)) + self.compositional_d0 = float(integrate_kwargs.pop("compositional_d0", 1.0)) + self.compositional_d1 = float(integrate_kwargs.pop("compositional_d1", 1.0)) # x is sampled from a normal distribution, must be scaled with var 1/n_compositional scale_latent = n_compositional * self.compositional_bridge(ops.ones(1)) From 0ff960f8dfd80237830a368f3333e4f9e91196c7 Mon Sep 17 00:00:00 2001 From: arrjon Date: Sat, 13 Sep 2025 13:59:01 +0200 Subject: [PATCH 049/101] fix integrate_kwargs --- bayesflow/networks/diffusion_model/diffusion_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index c9e2a2271..36063b577 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -812,7 +812,7 @@ def _inverse_compositional( "or use another backend." ) else: - mini_batch_size = integrate_kwargs.get("mini_batch_size", int(n_compositional * 0.1)) + mini_batch_size = integrate_kwargs.pop("mini_batch_size", int(n_compositional * 0.1)) self.compositional_d0 = float(integrate_kwargs.pop("compositional_d0", 1.0)) self.compositional_d1 = float(integrate_kwargs.pop("compositional_d1", 1.0)) From b2ef75522268811558899b123f63cd8ec2ae38ff Mon Sep 17 00:00:00 2001 From: arrjon Date: Sat, 13 Sep 2025 14:06:07 +0200 Subject: [PATCH 050/101] fix integrate_kwargs --- bayesflow/networks/diffusion_model/diffusion_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 36063b577..99b712203 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -617,7 +617,7 @@ def compositional_bridge(self, time: Tensor) -> Tensor: Bridge function value with same shape as time. 
""" - return ops.exp(-np.log(self.compositional_d0 / self.compositional_d1) * time) + return ops.exp(-np.log(self.compositional_bridge_d0 / self.compositional_bridge_d1) * time) def compositional_velocity( self, @@ -813,8 +813,8 @@ def _inverse_compositional( ) else: mini_batch_size = integrate_kwargs.pop("mini_batch_size", int(n_compositional * 0.1)) - self.compositional_d0 = float(integrate_kwargs.pop("compositional_d0", 1.0)) - self.compositional_d1 = float(integrate_kwargs.pop("compositional_d1", 1.0)) + self.compositional_bridge_d0 = float(integrate_kwargs.pop("compositional_bridge_d0", 1.0)) + self.compositional_bridge_d1 = float(integrate_kwargs.pop("compositional_bridge_d1", 1.0)) # x is sampled from a normal distribution, must be scaled with var 1/n_compositional scale_latent = n_compositional * self.compositional_bridge(ops.ones(1)) @@ -893,6 +893,7 @@ def score_fn(time, xz): **integrate_kwargs, ) else: + integrate_kwargs.pop("corrector_steps", None) def deltas(time, xz): return { From ca7f3bdaf9700d2fd6268c844c4abf6fb5963139 Mon Sep 17 00:00:00 2001 From: arrjon Date: Tue, 16 Sep 2025 12:21:13 +0200 Subject: [PATCH 051/101] fix kwargs in sample --- bayesflow/networks/transformers/mab.py | 4 +- .../networks/transformers/set_transformer.py | 10 ++- tests/test_approximators/test_sample.py | 90 +++++++++++++++++++ 3 files changed, 99 insertions(+), 5 deletions(-) diff --git a/bayesflow/networks/transformers/mab.py b/bayesflow/networks/transformers/mab.py index 5bd7c9dff..eddb8cf09 100644 --- a/bayesflow/networks/transformers/mab.py +++ b/bayesflow/networks/transformers/mab.py @@ -3,7 +3,7 @@ from bayesflow.networks import MLP from bayesflow.types import Tensor -from bayesflow.utils import layer_kwargs +from bayesflow.utils import layer_kwargs, filter_kwargs from bayesflow.utils.decorators import sanitize_input_shape from bayesflow.utils.serialization import serializable @@ -111,7 +111,7 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) - """ h = self.input_projector(seq_x) + self.attention( - query=seq_x, key=seq_y, value=seq_y, training=training, **kwargs + query=seq_x, key=seq_y, value=seq_y, training=training, **filter_kwargs(kwargs, self.attention.call) ) if self.ln_pre is not None: h = self.ln_pre(h, training=training) diff --git a/bayesflow/networks/transformers/set_transformer.py b/bayesflow/networks/transformers/set_transformer.py index d0d748067..94690f3ef 100644 --- a/bayesflow/networks/transformers/set_transformer.py +++ b/bayesflow/networks/transformers/set_transformer.py @@ -1,7 +1,7 @@ import keras from bayesflow.types import Tensor -from bayesflow.utils import check_lengths_same +from bayesflow.utils import check_lengths_same, filter_kwargs from bayesflow.utils.serialization import serializable from ..summary_network import SummaryNetwork @@ -147,7 +147,11 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor: out : Tensor Output of shape (batch_size, set_size, output_dim) """ - summary = self.attention_blocks(input_set, training=training, **kwargs) - summary = self.pooling_by_attention(summary, training=training, **kwargs) + summary = self.attention_blocks( + input_set, training=training, **filter_kwargs(kwargs, self.attention_blocks.call) + ) + summary = self.pooling_by_attention( + summary, training=training, **filter_kwargs(kwargs, self.pooling_by_attention.call) + ) summary = self.output_projector(summary) return summary diff --git a/tests/test_approximators/test_sample.py 
b/tests/test_approximators/test_sample.py index c62ffc581..e76b72a40 100644 --- a/tests/test_approximators/test_sample.py +++ b/tests/test_approximators/test_sample.py @@ -1,3 +1,4 @@ +import pytest import keras from tests.utils import check_combination_simulator_adapter @@ -16,3 +17,92 @@ def test_approximator_sample(approximator, simulator, batch_size, adapter): samples = approximator.sample(num_samples=2, conditions=data) assert isinstance(samples, dict) + + +@pytest.mark.parametrize("inference_network_type", ["flow_matching", "diffusion_model"]) +@pytest.mark.parametrize("summary_network_type", ["none", "deep_set", "set_transformer", "time_series"]) +@pytest.mark.parametrize("method", ["euler", "rk45", "euler_maruyama"]) +def test_approximator_sample_with_integration_methods( + inference_network_type, summary_network_type, method, simulator, adapter +): + """Test approximator sampling with different integration methods and summary networks. + + Tests flow matching and diffusion models with different ODE/SDE solvers: + - euler, rk45: Available for both flow matching and diffusion models + - euler_maruyama: Only for diffusion models (stochastic) + + Also tests with different summary network types. + """ + batch_size = 8 # Use smaller batch size for faster tests + check_combination_simulator_adapter(simulator, adapter) + + # Skip euler_maruyama for flow matching (deterministic model) + if inference_network_type == "flow_matching" and method == "euler_maruyama": + pytest.skip("euler_maruyama is only available for diffusion models") + + # Create inference network based on type + if inference_network_type == "flow_matching": + from bayesflow.networks import FlowMatching, MLP + + inference_network = FlowMatching( + subnet=MLP(widths=[32, 32]), + integrate_kwargs={"steps": 10}, # Use fewer steps for faster tests + ) + elif inference_network_type == "diffusion_model": + from bayesflow.networks import DiffusionModel, MLP + + inference_network = DiffusionModel( + subnet=MLP(widths=[32, 32]), + integrate_kwargs={"steps": 10}, # Use fewer steps for faster tests + ) + else: + pytest.skip(f"Unsupported inference network type: {inference_network_type}") + + # Create summary network based on type + summary_network = None + if summary_network_type != "none": + if summary_network_type == "deep_set": + from bayesflow.networks import DeepSet, MLP + + summary_network = DeepSet(subnet=MLP(widths=[16, 16])) + elif summary_network_type == "set_transformer": + from bayesflow.networks import SetTransformer + + summary_network = SetTransformer(embed_dims=[16, 16], mlp_widths=[16, 16]) + elif summary_network_type == "time_series": + from bayesflow.networks import TimeSeriesNetwork + + summary_network = TimeSeriesNetwork(subnet_kwargs={"widths": [16, 16]}, cell_type="lstm") + else: + pytest.skip(f"Unsupported summary network type: {summary_network_type}") + + # Update adapter to include summary variables if summary network is present + from bayesflow import ContinuousApproximator + + adapter = ContinuousApproximator.build_adapter( + inference_variables=["mean", "std"], + summary_variables=["x"], # Use x as summary variable for testing + ) + + # Create approximator + from bayesflow import ContinuousApproximator + + approximator = ContinuousApproximator( + adapter=adapter, inference_network=inference_network, summary_network=summary_network + ) + + # Generate test data + num_batches = 2 # Use fewer batches for faster tests + data = simulator.sample((num_batches * batch_size,)) + + # Build approximator + batch = 
adapter(data) + batch = keras.tree.map_structure(keras.ops.convert_to_tensor, batch) + batch_shapes = keras.tree.map_structure(keras.ops.shape, batch) + approximator.build(batch_shapes) + + # Test sampling with the specified method + samples = approximator.sample(num_samples=2, conditions=data, method=method) + + # Verify results + assert isinstance(samples, dict) From 2c161c6d85f675eac9347d036cbc93b76524b323 Mon Sep 17 00:00:00 2001 From: arrjon Date: Tue, 16 Sep 2025 15:32:20 +0200 Subject: [PATCH 052/101] fix kwargs in set transformer --- bayesflow/networks/transformers/isab.py | 1 + bayesflow/networks/transformers/mab.py | 4 ++-- bayesflow/networks/transformers/set_transformer.py | 10 +++------- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/bayesflow/networks/transformers/isab.py b/bayesflow/networks/transformers/isab.py index 03f15a561..1b763c2b3 100644 --- a/bayesflow/networks/transformers/isab.py +++ b/bayesflow/networks/transformers/isab.py @@ -107,5 +107,6 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor: batch_size = keras.ops.shape(input_set)[0] inducing_points_expanded = keras.ops.expand_dims(self.inducing_points, axis=0) inducing_points_tiled = keras.ops.tile(inducing_points_expanded, [batch_size, 1, 1]) + print(kwargs) h = self.mab0(inducing_points_tiled, input_set, training=training, **kwargs) return self.mab1(input_set, h, training=training, **kwargs) diff --git a/bayesflow/networks/transformers/mab.py b/bayesflow/networks/transformers/mab.py index eddb8cf09..5bd7c9dff 100644 --- a/bayesflow/networks/transformers/mab.py +++ b/bayesflow/networks/transformers/mab.py @@ -3,7 +3,7 @@ from bayesflow.networks import MLP from bayesflow.types import Tensor -from bayesflow.utils import layer_kwargs, filter_kwargs +from bayesflow.utils import layer_kwargs from bayesflow.utils.decorators import sanitize_input_shape from bayesflow.utils.serialization import serializable @@ -111,7 +111,7 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) - """ h = self.input_projector(seq_x) + self.attention( - query=seq_x, key=seq_y, value=seq_y, training=training, **filter_kwargs(kwargs, self.attention.call) + query=seq_x, key=seq_y, value=seq_y, training=training, **kwargs ) if self.ln_pre is not None: h = self.ln_pre(h, training=training) diff --git a/bayesflow/networks/transformers/set_transformer.py b/bayesflow/networks/transformers/set_transformer.py index 94690f3ef..7e9da76ea 100644 --- a/bayesflow/networks/transformers/set_transformer.py +++ b/bayesflow/networks/transformers/set_transformer.py @@ -1,7 +1,7 @@ import keras from bayesflow.types import Tensor -from bayesflow.utils import check_lengths_same, filter_kwargs +from bayesflow.utils import check_lengths_same from bayesflow.utils.serialization import serializable from ..summary_network import SummaryNetwork @@ -147,11 +147,7 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor: out : Tensor Output of shape (batch_size, set_size, output_dim) """ - summary = self.attention_blocks( - input_set, training=training, **filter_kwargs(kwargs, self.attention_blocks.call) - ) - summary = self.pooling_by_attention( - summary, training=training, **filter_kwargs(kwargs, self.pooling_by_attention.call) - ) + summary = self.attention_blocks(input_set, training=training) + summary = self.pooling_by_attention(summary, training=training) summary = self.output_projector(summary) return summary From 9d4c1a1c605c7e226ea72f97321e1a55c7e718ac Mon Sep 17 
00:00:00 2001 From: arrjon Date: Tue, 16 Sep 2025 15:37:38 +0200 Subject: [PATCH 053/101] fix kwargs in set transformer --- bayesflow/networks/transformers/mab.py | 4 ++-- bayesflow/networks/transformers/set_transformer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bayesflow/networks/transformers/mab.py b/bayesflow/networks/transformers/mab.py index 5bd7c9dff..eddb8cf09 100644 --- a/bayesflow/networks/transformers/mab.py +++ b/bayesflow/networks/transformers/mab.py @@ -3,7 +3,7 @@ from bayesflow.networks import MLP from bayesflow.types import Tensor -from bayesflow.utils import layer_kwargs +from bayesflow.utils import layer_kwargs, filter_kwargs from bayesflow.utils.decorators import sanitize_input_shape from bayesflow.utils.serialization import serializable @@ -111,7 +111,7 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) - """ h = self.input_projector(seq_x) + self.attention( - query=seq_x, key=seq_y, value=seq_y, training=training, **kwargs + query=seq_x, key=seq_y, value=seq_y, training=training, **filter_kwargs(kwargs, self.attention.call) ) if self.ln_pre is not None: h = self.ln_pre(h, training=training) diff --git a/bayesflow/networks/transformers/set_transformer.py b/bayesflow/networks/transformers/set_transformer.py index 7e9da76ea..bd8290272 100644 --- a/bayesflow/networks/transformers/set_transformer.py +++ b/bayesflow/networks/transformers/set_transformer.py @@ -148,6 +148,6 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor: Output of shape (batch_size, set_size, output_dim) """ summary = self.attention_blocks(input_set, training=training) - summary = self.pooling_by_attention(summary, training=training) + summary = self.pooling_by_attention(summary, training=training, **kwargs) summary = self.output_projector(summary) return summary From ea0659d14962e4b423a42e5bbf53dd79f1797eb9 Mon Sep 17 00:00:00 2001 From: arrjon Date: Tue, 16 Sep 2025 15:38:40 +0200 Subject: [PATCH 054/101] remove print --- bayesflow/networks/transformers/isab.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bayesflow/networks/transformers/isab.py b/bayesflow/networks/transformers/isab.py index 1b763c2b3..03f15a561 100644 --- a/bayesflow/networks/transformers/isab.py +++ b/bayesflow/networks/transformers/isab.py @@ -107,6 +107,5 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor: batch_size = keras.ops.shape(input_set)[0] inducing_points_expanded = keras.ops.expand_dims(self.inducing_points, axis=0) inducing_points_tiled = keras.ops.tile(inducing_points_expanded, [batch_size, 1, 1]) - print(kwargs) h = self.mab0(inducing_points_tiled, input_set, training=training, **kwargs) return self.mab1(input_set, h, training=training, **kwargs) From 9220816a1662dc42ddd2014ed4d7098b8230ac6d Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 22 Sep 2025 10:31:49 +0200 Subject: [PATCH 055/101] new class for compositional diffusion --- .../networks/diffusion_model/__init__.py | 1 + .../compositional_diffusion_model.py | 412 ++++++++++++++++++ .../diffusion_model/diffusion_model.py | 315 +------------ 3 files changed, 414 insertions(+), 314 deletions(-) create mode 100644 bayesflow/networks/diffusion_model/compositional_diffusion_model.py diff --git a/bayesflow/networks/diffusion_model/__init__.py b/bayesflow/networks/diffusion_model/__init__.py index 341c84c62..ca8aa19be 100644 --- a/bayesflow/networks/diffusion_model/__init__.py +++ b/bayesflow/networks/diffusion_model/__init__.py @@ -1,4 +1,5 @@ from 
.diffusion_model import DiffusionModel +from .compositional_diffusion_model import CompositionalDiffusionModel from .schedules import CosineNoiseSchedule from .schedules import EDMNoiseSchedule from .schedules import NoiseSchedule diff --git a/bayesflow/networks/diffusion_model/compositional_diffusion_model.py b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py new file mode 100644 index 000000000..cde1290ed --- /dev/null +++ b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py @@ -0,0 +1,412 @@ +from typing import Literal, Callable + +import keras +import numpy as np +from keras import ops + +from bayesflow.types import Tensor +from bayesflow.utils import ( + expand_right_as, + integrate, + integrate_stochastic, +) +from bayesflow.utils.serialization import serializable +from diffusion_model import DiffusionModel +from .schedules.noise_schedule import NoiseSchedule + + +# disable module check, use potential module after moving from experimental +@serializable("bayesflow.networks", disable_module_check=True) +class CompositionalDiffusionModel(DiffusionModel): + """Compositional Diffusion Model for Amortized Bayesian Inference. Allows learning a single + diffusion model on individual i.i.d. simulations that can then perform inference for multiple simulations by leveraging a + compositional score function as in [2]. + + [1] Score-Based Generative Modeling through Stochastic Differential Equations: Song et al. (2021) + [2] Compositional Score Modeling for Simulation-Based Inference: Geffner et al. (2023) + [3] Compositional amortized inference for large-scale hierarchical Bayesian models: Arruda et al. (2025) + """ + + MLP_DEFAULT_CONFIG = { + "widths": (256, 256, 256, 256, 256), + "activation": "mish", + "kernel_initializer": "he_normal", + "residual": True, + "dropout": 0.0, + "spectral_normalization": False, + } + + INTEGRATE_DEFAULT_CONFIG = { + "method": "euler_maruyama", + "corrector_steps": 1, + "steps": 100, + } + + def __init__( + self, + *, + subnet: str | type | keras.Layer = "mlp", + noise_schedule: Literal["edm", "cosine"] | NoiseSchedule | type = "edm", + prediction_type: Literal["velocity", "noise", "F", "x"] = "F", + loss_type: Literal["velocity", "noise", "F"] = "noise", + subnet_kwargs: dict[str, any] = None, + schedule_kwargs: dict[str, any] = None, + integrate_kwargs: dict[str, any] = None, + **kwargs, + ): + """ + Initializes a diffusion model with configurable subnet architecture, noise schedule, + and prediction/loss types for amortized Bayesian inference. + + Note that score-based diffusion is the most sluggish of all available samplers, + so expect slower inference times than flow matching and much slower than normalizing flows. + + Parameters + ---------- + subnet : str, type or keras.Layer, optional + Architecture for the transformation network. Can be "mlp", a custom network class, or + a Layer object, e.g., `bayesflow.networks.MLP(widths=[32, 32])`. Default is "mlp". + noise_schedule : {'edm', 'cosine'} or NoiseSchedule or type, optional + Noise schedule controlling the diffusion dynamics. Can be a string identifier, + a schedule class, or a pre-initialized schedule instance. Default is "edm". + prediction_type : {'velocity', 'noise', 'F', 'x'}, optional + Output format of the model's prediction. Default is "F". + loss_type : {'velocity', 'noise', 'F'}, optional + Loss function used to train the model. Default is "noise". + subnet_kwargs : dict[str, any], optional + Additional keyword arguments passed to the subnet constructor.
Default is None. + schedule_kwargs : dict[str, any], optional + Additional keyword arguments passed to the noise schedule constructor. Default is None. + integrate_kwargs : dict[str, any], optional + Configuration dictionary for integration during training or inference. Default is None. + concatenate_subnet_input: bool, optional + Flag for advanced users to control whether all inputs to the subnet should be concatenated + into a single vector or passed as separate arguments. If set to False, the subnet + must accept three separate inputs: 'x' (noisy parameters), 't' (log signal-to-noise ratio), + and optional 'conditions'. Default is True. + + **kwargs + Additional keyword arguments passed to the base class and internal components. + """ + super().__init__( + subnet=subnet, + noise_schedule=noise_schedule, + prediction_type=prediction_type, + loss_type=loss_type, + subnet_kwargs=subnet_kwargs, + schedule_kwargs=schedule_kwargs, + integrate_kwargs=integrate_kwargs, + **kwargs, + ) + + def compositional_bridge(self, time: Tensor) -> Tensor: + """ + Bridge function for compositional diffusion. In the simplest case, this is just 1 if d0 == d1. + Otherwise, it can be used to scale the compositional score over time. + + Parameters + ---------- + time: Tensor + Time step for the diffusion process. + + Returns + ------- + Tensor + Bridge function value with same shape as time. + + """ + return ops.exp(-np.log(self.compositional_bridge_d0 / self.compositional_bridge_d1) * time) + + def compositional_velocity( + self, + xz: Tensor, + time: float | Tensor, + stochastic_solver: bool, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + mini_batch_size: int | None = None, + training: bool = False, + ) -> Tensor: + """ + Computes the compositional velocity for multiple datasets using the formula: + s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) + + Parameters + ---------- + xz : Tensor + The current state of the latent variable, shape (n_datasets, n_compositional, ...) + time : float or Tensor + Time step for the diffusion process + stochastic_solver : bool + Whether to use stochastic (SDE) or deterministic (ODE) formulation + conditions : Tensor + Conditional inputs with compositional structure (n_datasets, n_compositional, ...) + compute_prior_score: Callable + Function to compute the prior score ∇_θ log p(θ). + mini_batch_size : int or None + Mini batch size for computing individual scores. If None, use all conditions. 
+ training : bool, optional + Whether in training mode + + Returns + ------- + Tensor + Compositional velocity of same shape as input xz + """ + # Calculate standard noise schedule components + log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) + + compositional_score = self.compositional_score( + xz=xz, + time=time, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + + # Compute velocity using standard drift-diffusion formulation + f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) + + if stochastic_solver: + # SDE: dz = [f(z,t) - g(t)² * score(z,t)] dt + g(t) dW + velocity = f - g_squared * compositional_score + else: + # ODE: dz = [f(z,t) - 0.5 * g(t)² * score(z,t)] dt + velocity = f - 0.5 * g_squared * compositional_score + + return velocity + + def compositional_score( + self, + xz: Tensor, + time: float | Tensor, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + mini_batch_size: int | None = None, + training: bool = False, + ) -> Tensor: + """ + Computes the compositional score for multiple datasets using the formula: + s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) + + Parameters + ---------- + xz : Tensor + The current state of the latent variable, shape (n_datasets, n_compositional, ...) + time : float or Tensor + Time step for the diffusion process + conditions : Tensor + Conditional inputs with compositional structure (n_datasets, n_compositional, ...) + compute_prior_score: Callable + Function to compute the prior score ∇_θ log p(θ). + mini_batch_size : int or None + Mini batch size for computing individual scores. If None, use all conditions. 
+ training : bool, optional + Whether in training mode + + Returns + ------- + Tensor + Compositional score of same shape as input xz + """ + if conditions is None: + raise ValueError("Conditions are required for compositional sampling") + + # Get shapes for compositional structure + n_compositional = ops.shape(conditions)[1] + + # Calculate standard noise schedule components + log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) + + # Compute individual dataset scores + if mini_batch_size is not None and mini_batch_size < n_compositional: + # sample random indices for mini-batch processing + mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) + mini_batch_idx = mini_batch_idx[:mini_batch_size] + conditions_batch = conditions[:, mini_batch_idx] + else: + conditions_batch = conditions + individual_scores = self._compute_individual_scores(xz, log_snr_t, conditions_batch, training) + + # Compute prior score component + prior_score = compute_prior_score(xz) + weighted_prior_score = (1.0 - n_compositional) * (1.0 - time) * prior_score + + # Sum individual scores across compositional dimensions + summed_individual_scores = n_compositional * ops.mean(individual_scores, axis=1) + + # Combined score using compositional formula: (1-n)(1-t)∇log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) + time_tensor = ops.cast(time, dtype=ops.dtype(xz)) + compositional_score = self.compositional_bridge(time_tensor) * (weighted_prior_score + summed_individual_scores) + return compositional_score + + def _compute_individual_scores( + self, + xz: Tensor, + log_snr_t: Tensor, + conditions: Tensor, + training: bool, + ) -> Tensor: + """ + Compute individual dataset scores s_ψ(θ,t,yᵢ) for each compositional condition. + + Returns + ------- + Tensor + Individual scores with shape (n_datasets, n_compositional, ...)
+ """ + # Get shapes + xz_shape = ops.shape(xz) # (n_datasets, num_samples, ..., dims) + conditions_shape = ops.shape(conditions) # (n_datasets, n_compositional, num_samples, ..., dims) + n_datasets, n_compositional = conditions_shape[0], conditions_shape[1] + conditions_dims = tuple(conditions_shape[3:]) + num_samples = xz_shape[1] + dims = tuple(xz_shape[2:]) + + # Expand xz to match compositional structure + xz_expanded = ops.expand_dims(xz, axis=1) # (n_datasets, 1, num_samples, ..., dims) + xz_expanded = ops.broadcast_to(xz_expanded, (n_datasets, n_compositional, num_samples) + dims) + + # Expand log_snr_t to match compositional structure + log_snr_expanded = ops.expand_dims(log_snr_t, axis=1) + log_snr_expanded = ops.broadcast_to(log_snr_expanded, (n_datasets, n_compositional, num_samples, 1)) + + # Flatten for score computation: (n_datasets * n_compositional, num_samples, ..., dims) + xz_flat = ops.reshape(xz_expanded, (n_datasets * n_compositional, num_samples) + dims) + log_snr_flat = ops.reshape(log_snr_expanded, (n_datasets * n_compositional, num_samples, 1)) + conditions_flat = ops.reshape(conditions, (n_datasets * n_compositional, num_samples) + conditions_dims) + + # Use standard score function + scores_flat = self.score(xz_flat, log_snr_t=log_snr_flat, conditions=conditions_flat, training=training) + + # Reshape back to compositional structure + scores = ops.reshape(scores_flat, (n_datasets, n_compositional, num_samples) + dims) + return scores + + def _inverse_compositional( + self, + z: Tensor, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + density: bool = False, + training: bool = False, + **kwargs, + ) -> Tensor | tuple[Tensor, Tensor]: + """ + Inverse pass for compositional diffusion sampling. + """ + n_compositional = ops.shape(conditions)[1] + integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0, "corrector_steps": 1} + integrate_kwargs = integrate_kwargs | self.integrate_kwargs + integrate_kwargs = integrate_kwargs | kwargs + if keras.backend.backend() == "jax": + mini_batch_size = integrate_kwargs.pop("mini_batch_size", None) + if mini_batch_size is not None: + raise ValueError( + "Mini batching is not supported with JAX backend. Set mini_batch_size to None " + "or use another backend." 
+ ) + else: + mini_batch_size = integrate_kwargs.pop("mini_batch_size", int(n_compositional * 0.1)) + self.compositional_bridge_d0 = float(integrate_kwargs.pop("compositional_bridge_d0", 1.0)) + self.compositional_bridge_d1 = float(integrate_kwargs.pop("compositional_bridge_d1", 1.0)) + + # x is sampled from a normal distribution, must be scaled with var 1/n_compositional + scale_latent = n_compositional * self.compositional_bridge(ops.ones(1)) + z = z / ops.sqrt(ops.cast(scale_latent, dtype=ops.dtype(z))) + + if density: + if integrate_kwargs["method"] == "euler_maruyama": + raise ValueError("Stochastic methods are not supported for density computation.") + + def deltas(time, xz): + v = self.compositional_velocity( + xz, + time=time, + stochastic_solver=False, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + trace = ops.zeros(ops.shape(xz)[:-1] + (1,), dtype=ops.dtype(xz)) + return {"xz": v, "trace": trace} + + state = { + "xz": z, + "trace": ops.zeros(ops.shape(z)[:-1] + (1,), dtype=ops.dtype(z)), + } + state = integrate(deltas, state, **integrate_kwargs) + + x = state["xz"] + log_density = self.base_distribution.log_prob(ops.mean(z, axis=1)) - ops.squeeze(state["trace"], axis=-1) + return x, log_density + + state = {"xz": z} + + if integrate_kwargs["method"] == "euler_maruyama": + + def deltas(time, xz): + return { + "xz": self.compositional_velocity( + xz, + time=time, + stochastic_solver=True, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + } + + def diffusion(time, xz): + return {"xz": self.diffusion_term(xz, time=time, training=training)} + + score_fn = None + if "corrector_steps" in integrate_kwargs: + if integrate_kwargs["corrector_steps"] > 0: + + def score_fn(time, xz): + return { + "xz": self.compositional_score( + xz, + time=time, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + } + + state = integrate_stochastic( + drift_fn=deltas, + diffusion_fn=diffusion, + score_fn=score_fn, + noise_schedule=self.noise_schedule, + state=state, + seed=self.seed_generator, + **integrate_kwargs, + ) + else: + integrate_kwargs.pop("corrector_steps", None) + + def deltas(time, xz): + return { + "xz": self.compositional_velocity( + xz, + time=time, + stochastic_solver=False, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + } + + state = integrate(deltas, state, **integrate_kwargs) + + x = state["xz"] + return x diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 99b712203..9955c4abc 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -1,9 +1,8 @@ from collections.abc import Sequence -from typing import Literal, Callable +from typing import Literal import keras from keras import ops -import numpy as np from ..inference_network import InferenceNetwork from bayesflow.types import Tensor, Shape @@ -600,315 +599,3 @@ def compute_metrics( base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage) return base_metrics | {"loss": loss} - - def compositional_bridge(self, time: Tensor) -> Tensor: - """ - Bridge function for compositional diffusion. 
In the simplest case, this is just 1 if d0 == d1. - Otherwise, it can be used to scale the compositional score over time. - - Parameters - ---------- - time: Tensor - Time step for the diffusion process. - - Returns - ------- - Tensor - Bridge function value with same shape as time. - - """ - return ops.exp(-np.log(self.compositional_bridge_d0 / self.compositional_bridge_d1) * time) - - def compositional_velocity( - self, - xz: Tensor, - time: float | Tensor, - stochastic_solver: bool, - conditions: Tensor, - compute_prior_score: Callable[[Tensor], Tensor], - mini_batch_size: int | None = None, - training: bool = False, - ) -> Tensor: - """ - Computes the compositional velocity for multiple datasets using the formula: - s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) - - Parameters - ---------- - xz : Tensor - The current state of the latent variable, shape (n_datasets, n_compositional, ...) - time : float or Tensor - Time step for the diffusion process - stochastic_solver : bool - Whether to use stochastic (SDE) or deterministic (ODE) formulation - conditions : Tensor - Conditional inputs with compositional structure (n_datasets, n_compositional, ...) - compute_prior_score: Callable - Function to compute the prior score ∇_θ log p(θ). - mini_batch_size : int or None - Mini batch size for computing individual scores. If None, use all conditions. - training : bool, optional - Whether in training mode - - Returns - ------- - Tensor - Compositional velocity of same shape as input xz - """ - # Calculate standard noise schedule components - log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) - log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) - - compositional_score = self.compositional_score( - xz=xz, - time=time, - conditions=conditions, - compute_prior_score=compute_prior_score, - mini_batch_size=mini_batch_size, - training=training, - ) - - # Compute velocity using standard drift-diffusion formulation - f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) - - if stochastic_solver: - # SDE: dz = [f(z,t) - g(t)² * score(z,t)] dt + g(t) dW - velocity = f - g_squared * compositional_score - else: - # ODE: dz = [f(z,t) - 0.5 * g(t)² * score(z,t)] dt - velocity = f - 0.5 * g_squared * compositional_score - - return velocity - - def compositional_score( - self, - xz: Tensor, - time: float | Tensor, - conditions: Tensor, - compute_prior_score: Callable[[Tensor], Tensor], - mini_batch_size: int | None = None, - training: bool = False, - ) -> Tensor: - """ - Computes the compositional score for multiple datasets using the formula: - s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) - - Parameters - ---------- - xz : Tensor - The current state of the latent variable, shape (n_datasets, n_compositional, ...) - time : float or Tensor - Time step for the diffusion process - conditions : Tensor - Conditional inputs with compositional structure (n_datasets, n_compositional, ...) - compute_prior_score: Callable - Function to compute the prior score ∇_θ log p(θ). - mini_batch_size : int or None - Mini batch size for computing individual scores. If None, use all conditions. 
- training : bool, optional - Whether in training mode - - Returns - ------- - Tensor - Compositional velocity of same shape as input xz - """ - if conditions is None: - raise ValueError("Conditions are required for compositional sampling") - - # Get shapes for compositional structure - n_compositional = ops.shape(conditions)[1] - - # Calculate standard noise schedule components - log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) - log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) - - # Compute individual dataset scores - if mini_batch_size is not None and mini_batch_size < n_compositional: - # sample random indices for mini-batch processing - mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) - mini_batch_idx = mini_batch_idx[:mini_batch_size] - conditions_batch = conditions[:, mini_batch_idx] - else: - conditions_batch = conditions - individual_scores = self._compute_individual_scores(xz, log_snr_t, conditions_batch, training) - - # Compute prior score component - prior_score = compute_prior_score(xz) - weighted_prior_score = (1.0 - n_compositional) * (1.0 - time) * prior_score - - # Sum individual scores across compositional dimensions - summed_individual_scores = n_compositional * ops.mean(individual_scores, axis=1) - - # Combined score using compositional formula: (1-n)(1-t)∇log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) - time_tensor = ops.cast(time, dtype=ops.dtype(xz)) - compositional_score = self.compositional_bridge(time_tensor) * (weighted_prior_score + summed_individual_scores) - return compositional_score - - def _compute_individual_scores( - self, - xz: Tensor, - log_snr_t: Tensor, - conditions: Tensor, - training: bool, - ) -> Tensor: - """ - Compute individual dataset scores s_ψ(θ,t,yᵢ) for each compositional condition. - - Returns - ------- - Tensor - Individual scores with shape (n_datasets, n_compositional, ...) 
- """ - # Get shapes - xz_shape = ops.shape(xz) # (n_datasets, num_samples, ..., dims) - conditions_shape = ops.shape(conditions) # (n_datasets, n_compositional, num_samples, ..., dims) - n_datasets, n_compositional = conditions_shape[0], conditions_shape[1] - conditions_dims = tuple(conditions_shape[3:]) - num_samples = xz_shape[1] - dims = tuple(xz_shape[2:]) - - # Expand xz to match compositional structure - xz_expanded = ops.expand_dims(xz, axis=1) # (n_datasets, 1, num_samples, ..., dims) - xz_expanded = ops.broadcast_to(xz_expanded, (n_datasets, n_compositional, num_samples) + dims) - - # Expand log_snr_t to match compositional structure - log_snr_expanded = ops.expand_dims(log_snr_t, axis=1) - log_snr_expanded = ops.broadcast_to(log_snr_expanded, (n_datasets, n_compositional, num_samples, 1)) - - # Flatten for score computation: (n_datasets * n_compositional, num_samples, ..., dims) - xz_flat = ops.reshape(xz_expanded, (n_datasets * n_compositional, num_samples) + dims) - log_snr_flat = ops.reshape(log_snr_expanded, (n_datasets * n_compositional, num_samples, 1)) - conditions_flat = ops.reshape(conditions, (n_datasets * n_compositional, num_samples) + conditions_dims) - - # Use standard score function - scores_flat = self.score(xz_flat, log_snr_t=log_snr_flat, conditions=conditions_flat, training=training) - - # Reshape back to compositional structure - scores = ops.reshape(scores_flat, (n_datasets, n_compositional, num_samples) + dims) - return scores - - def _inverse_compositional( - self, - z: Tensor, - conditions: Tensor, - compute_prior_score: Callable[[Tensor], Tensor], - density: bool = False, - training: bool = False, - **kwargs, - ) -> Tensor | tuple[Tensor, Tensor]: - """ - Inverse pass for compositional diffusion sampling. - """ - n_compositional = ops.shape(conditions)[1] - integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0, "corrector_steps": 1} - integrate_kwargs = integrate_kwargs | self.integrate_kwargs - integrate_kwargs = integrate_kwargs | kwargs - if keras.backend.backend() == "jax": - mini_batch_size = integrate_kwargs.pop("mini_batch_size", None) - if mini_batch_size is not None: - raise ValueError( - "Mini batching is not supported with JAX backend. Set mini_batch_size to None " - "or use another backend." 
- ) - else: - mini_batch_size = integrate_kwargs.pop("mini_batch_size", int(n_compositional * 0.1)) - self.compositional_bridge_d0 = float(integrate_kwargs.pop("compositional_bridge_d0", 1.0)) - self.compositional_bridge_d1 = float(integrate_kwargs.pop("compositional_bridge_d1", 1.0)) - - # x is sampled from a normal distribution, must be scaled with var 1/n_compositional - scale_latent = n_compositional * self.compositional_bridge(ops.ones(1)) - z = z / ops.sqrt(ops.cast(scale_latent, dtype=ops.dtype(z))) - - if density: - if integrate_kwargs["method"] == "euler_maruyama": - raise ValueError("Stochastic methods are not supported for density computation.") - - def deltas(time, xz): - v = self.compositional_velocity( - xz, - time=time, - stochastic_solver=False, - conditions=conditions, - compute_prior_score=compute_prior_score, - mini_batch_size=mini_batch_size, - training=training, - ) - trace = ops.zeros(ops.shape(xz)[:-1] + (1,), dtype=ops.dtype(xz)) - return {"xz": v, "trace": trace} - - state = { - "xz": z, - "trace": ops.zeros(ops.shape(z)[:-1] + (1,), dtype=ops.dtype(z)), - } - state = integrate(deltas, state, **integrate_kwargs) - - x = state["xz"] - log_density = self.base_distribution.log_prob(ops.mean(z, axis=1)) - ops.squeeze(state["trace"], axis=-1) - return x, log_density - - state = {"xz": z} - - if integrate_kwargs["method"] == "euler_maruyama": - - def deltas(time, xz): - return { - "xz": self.compositional_velocity( - xz, - time=time, - stochastic_solver=True, - conditions=conditions, - compute_prior_score=compute_prior_score, - mini_batch_size=mini_batch_size, - training=training, - ) - } - - def diffusion(time, xz): - return {"xz": self.diffusion_term(xz, time=time, training=training)} - - score_fn = None - if "corrector_steps" in integrate_kwargs: - if integrate_kwargs["corrector_steps"] > 0: - - def score_fn(time, xz): - return { - "xz": self.compositional_score( - xz, - time=time, - conditions=conditions, - compute_prior_score=compute_prior_score, - mini_batch_size=mini_batch_size, - training=training, - ) - } - - state = integrate_stochastic( - drift_fn=deltas, - diffusion_fn=diffusion, - score_fn=score_fn, - noise_schedule=self.noise_schedule, - state=state, - seed=self.seed_generator, - **integrate_kwargs, - ) - else: - integrate_kwargs.pop("corrector_steps", None) - - def deltas(time, xz): - return { - "xz": self.compositional_velocity( - xz, - time=time, - stochastic_solver=False, - conditions=conditions, - compute_prior_score=compute_prior_score, - mini_batch_size=mini_batch_size, - training=training, - ) - } - - state = integrate(deltas, state, **integrate_kwargs) - - x = state["xz"] - return x From ee1c3209a429f0c03e1fa698ddeb6f9bf5c5fe5a Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 22 Sep 2025 10:34:16 +0200 Subject: [PATCH 056/101] fix import --- .../networks/diffusion_model/compositional_diffusion_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesflow/networks/diffusion_model/compositional_diffusion_model.py b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py index cde1290ed..8095aba7d 100644 --- a/bayesflow/networks/diffusion_model/compositional_diffusion_model.py +++ b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py @@ -11,7 +11,7 @@ integrate_stochastic, ) from bayesflow.utils.serialization import serializable -from diffusion_model import DiffusionModel +from .diffusion_model import DiffusionModel from .schedules.noise_schedule import NoiseSchedule @@ -299,7 +299,7 @@ def 
_inverse_compositional( Inverse pass for compositional diffusion sampling. """ n_compositional = ops.shape(conditions)[1] - integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0, "corrector_steps": 1} + integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs if keras.backend.backend() == "jax": From e6513c1992fca355cdd0ebbaeecfda9db0710731 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 26 Sep 2025 15:41:40 +0200 Subject: [PATCH 057/101] add import --- bayesflow/networks/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/networks/__init__.py b/bayesflow/networks/__init__.py index f71d4b536..fb9819445 100644 --- a/bayesflow/networks/__init__.py +++ b/bayesflow/networks/__init__.py @@ -7,7 +7,7 @@ from .consistency_models import ConsistencyModel from .coupling_flow import CouplingFlow from .deep_set import DeepSet -from .diffusion_model import DiffusionModel +from .diffusion_model import DiffusionModel, CompositionalDiffusionModel from .flow_matching import FlowMatching from .inference_network import InferenceNetwork from .point_inference_network import PointInferenceNetwork From e87f9d153dac4b9a1fdfe944b7417a768967c66b Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 26 Sep 2025 18:13:29 +0200 Subject: [PATCH 058/101] fix mini_batch_size --- .../networks/diffusion_model/compositional_diffusion_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/networks/diffusion_model/compositional_diffusion_model.py b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py index 8095aba7d..abd7d49a9 100644 --- a/bayesflow/networks/diffusion_model/compositional_diffusion_model.py +++ b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py @@ -310,7 +310,7 @@ def _inverse_compositional( "or use another backend." ) else: - mini_batch_size = integrate_kwargs.pop("mini_batch_size", int(n_compositional * 0.1)) + mini_batch_size = min(integrate_kwargs.pop("mini_batch_size", int(n_compositional * 0.1)), 1) self.compositional_bridge_d0 = float(integrate_kwargs.pop("compositional_bridge_d0", 1.0)) self.compositional_bridge_d1 = float(integrate_kwargs.pop("compositional_bridge_d1", 1.0)) From 983cb8d399097dd9d9451b02b830c21a16febc8b Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 26 Sep 2025 18:16:02 +0200 Subject: [PATCH 059/101] fix mini_batch_size --- .../networks/diffusion_model/compositional_diffusion_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/networks/diffusion_model/compositional_diffusion_model.py b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py index abd7d49a9..171184314 100644 --- a/bayesflow/networks/diffusion_model/compositional_diffusion_model.py +++ b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py @@ -310,7 +310,7 @@ def _inverse_compositional( "or use another backend." 
) else: - mini_batch_size = min(integrate_kwargs.pop("mini_batch_size", int(n_compositional * 0.1)), 1) + mini_batch_size = max(integrate_kwargs.pop("mini_batch_size", int(n_compositional * 0.1)), 1) self.compositional_bridge_d0 = float(integrate_kwargs.pop("compositional_bridge_d0", 1.0)) self.compositional_bridge_d1 = float(integrate_kwargs.pop("compositional_bridge_d1", 1.0)) From 3b887dc8ff8e3a7c8970a8e7a5445db73bc49aa3 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 27 Oct 2025 11:20:26 +0100 Subject: [PATCH 060/101] fix scm --- .../stable_consistency_model/stable_consistency_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py b/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py index 6ce27fdf4..dc092ab4e 100644 --- a/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py +++ b/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py @@ -307,7 +307,7 @@ def compute_metrics( r = 1.0 # TODO: if consistency distillation training (not supported yet) is unstable, add schedule here def f_teacher(x, t): - o = self._apply_subnet(x / self.sigma, self.time_emb(t), conditions, training=stage == "training") + o = self._apply_subnet(x, self.time_emb(t), conditions, training=stage == "training") return self.subnet_projector(o) primals = (xt / self.sigma, t) @@ -321,7 +321,7 @@ def f_teacher(x, t): cos_sin_dFdt = ops.stop_gradient(cos_sin_dFdt) # calculate output of the network - subnet_out = self._apply_subnet(x / self.sigma, self.time_emb(t), conditions, training=stage == "training") + subnet_out = self._apply_subnet(xt / self.sigma, self.time_emb(t), conditions, training=stage == "training") student_out = self.subnet_projector(subnet_out) # calculate the tangent From 64516a464ba828a6dd45db7f03f5e4cf4c7cb79a Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 29 Oct 2025 17:31:44 +0100 Subject: [PATCH 061/101] fix saving --- .../stable_consistency_model/stable_consistency_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py b/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py index dc092ab4e..0f787d44f 100644 --- a/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py +++ b/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py @@ -105,7 +105,6 @@ def __init__( ) embedding_kwargs = embedding_kwargs or {} - self.embedding_kwargs = embedding_kwargs self.time_emb = FourierEmbedding(**embedding_kwargs) self.time_emb_dim = self.time_emb.embed_dim @@ -123,13 +122,14 @@ def get_config(self): config = { "subnet": self.subnet, "sigma": self.sigma, - "embedding_kwargs": self.embedding_kwargs, + "time_emb": self.time_emb, "concatenate_subnet_input": self._concatenate_subnet_input, } return base_config | serialize(config) - def _discretize_time(self, num_steps: int, rho: float = 3.5, **kwargs): + @staticmethod + def _discretize_time(num_steps: int, rho: float = 3.5): t = keras.ops.linspace(0.0, pi / 2, num_steps) times = keras.ops.exp((t - pi / 2) * rho) * pi / 2 times = keras.ops.concatenate([keras.ops.zeros((1,)), times[1:]], axis=0) From 034b2c4cf6b1fc69175c406440f6e5eb806d68ee Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 24 Nov 2025 17:17:40 +0100 Subject: [PATCH 062/101] correct rk45 and add tsit5 --- bayesflow/utils/integrate.py | 253 ++++++++++++++++++++++++----- 
tests/test_utils/test_integrate.py | 28 ++++ 2 files changed, 238 insertions(+), 43 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index b197ea975..a9bd6ea3f 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -29,6 +29,7 @@ def euler_step( k1 = fn(time, **filter_kwargs(state, fn)) if use_adaptive_step_size: + # Use Heun's method (RK2) as embedded pair for proper error estimation intermediate_state = state.copy() for key, delta in k1.items(): intermediate_state[key] = state[key] + step_size * delta @@ -39,18 +40,23 @@ def euler_step( if set(k1.keys()) != set(k2.keys()): raise ValueError("Keys of the deltas do not match. Please return zero for unchanged variables.") - # compute next step size - intermediate_error = keras.ops.stack([keras.ops.norm(k2[key] - k1[key], ord=2, axis=-1) for key in k1]) - new_step_size = step_size * tolerance / (intermediate_error + 1e-9) + # Heun's (RK2) solution + heun_state = state.copy() + for key in k1.keys(): + heun_state[key] = state[key] + 0.5 * step_size * (k1[key] + k2[key]) - new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) + # Error estimate: difference between Euler and Heun + intermediate_error = keras.ops.stack( + [keras.ops.norm(heun_state[key] - intermediate_state[key], ord=2, axis=-1) for key in k1] + ) - # consolidate step size - new_step_size = keras.ops.take(new_step_size, keras.ops.argmin(keras.ops.abs(new_step_size))) + max_error = keras.ops.max(intermediate_error) + new_step_size = step_size * keras.ops.sqrt(tolerance / (max_error + 1e-9)) + + new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) else: new_step_size = step_size - # apply updates new_state = state.copy() for key in k1.keys(): new_state[key] = state[key] + step_size * k1[key] @@ -60,6 +66,16 @@ def euler_step( return new_state, new_time, new_step_size +def add_scaled(state, ks, coeffs, h): + out = {} + for key, y in state.items(): + acc = keras.ops.zeros_like(y) + for c, k in zip(coeffs, ks): + acc = acc + c * k[key] + out[key] = y + h * acc + return out + + def rk45_step( fn: Callable, state: dict[str, ArrayLike], @@ -70,57 +86,151 @@ def rk45_step( max_step_size: ArrayLike = float("inf"), use_adaptive_step_size: bool = False, ) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): + """ + Dormand-Prince 5(4) method with embedded error estimation. + """ step_size = last_step_size + h = step_size k1 = fn(time, **filter_kwargs(state, fn)) + k2 = fn(time + h * (1 / 5), **add_scaled(state, [k1], [1 / 5], h)) + k3 = fn(time + h * (3 / 10), **add_scaled(state, [k1, k2], [3 / 40, 9 / 40], h)) + k4 = fn(time + h * (4 / 5), **add_scaled(state, [k1, k2, k3], [44 / 45, -56 / 15, 32 / 9], h)) + k5 = fn( + time + h * (8 / 9), + **add_scaled(state, [k1, k2, k3, k4], [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], h), + ) + k6 = fn( + time + h, + **add_scaled(state, [k1, k2, k3, k4, k5], [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], h), + ) - intermediate_state = state.copy() - for key, delta in k1.items(): - intermediate_state[key] = state[key] + 0.5 * step_size * delta + # check all keys are equal + if not all(set(k.keys()) == set(k1.keys()) for k in [k2, k3, k4, k5, k6]): + raise ValueError("Keys of the deltas do not match. 
Please return zero for unchanged variables.") - k2 = fn(time + 0.5 * step_size, **filter_kwargs(intermediate_state, fn)) + # 5th order solution + new_state = {} + for key in k1.keys(): + new_state[key] = state[key] + h * ( + 35 / 384 * k1[key] + 500 / 1113 * k3[key] + 125 / 192 * k4[key] - 2187 / 6784 * k5[key] + 11 / 84 * k6[key] + ) - intermediate_state = state.copy() - for key, delta in k2.items(): - intermediate_state[key] = state[key] + 0.5 * step_size * delta + if use_adaptive_step_size: + k7 = fn(time + h, **filter_kwargs(new_state, fn)) + + # 4th order embedded solution + err_state = {} + for key in k1.keys(): + y4 = state[key] + h * ( + 5179 / 57600 * k1[key] + + 7571 / 16695 * k3[key] + + 393 / 640 * k4[key] + - 92097 / 339200 * k5[key] + + 187 / 2100 * k6[key] + + 1 / 40 * k7[key] + ) + err_state[key] = new_state[key] - y4 - k3 = fn(time + 0.5 * step_size, **filter_kwargs(intermediate_state, fn)) + err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()]) + err = keras.ops.max(err_norm) - intermediate_state = state.copy() - for key, delta in k3.items(): - intermediate_state[key] = state[key] + step_size * delta + new_step_size = h * keras.ops.clip(0.9 * (tolerance / (err + 1e-12)) ** 0.2, 0.2, 5.0) + new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) + else: + new_step_size = step_size - k4 = fn(time + step_size, **filter_kwargs(intermediate_state, fn)) + new_time = time + h + return new_state, new_time, new_step_size - if use_adaptive_step_size: - intermediate_state = state.copy() - for key, delta in k4.items(): - intermediate_state[key] = state[key] + 0.5 * step_size * delta - k5 = fn(time + 0.5 * step_size, **filter_kwargs(intermediate_state, fn)) +def tsit5_step( + fn: Callable, + state: dict[str, ArrayLike], + time: ArrayLike, + last_step_size: ArrayLike, + tolerance: ArrayLike = 1e-6, + min_step_size: ArrayLike = -float("inf"), + max_step_size: ArrayLike = float("inf"), + use_adaptive_step_size: bool = False, +): + """ + Implements a single step of the Tsitouras 5/4 Runge-Kutta method. + """ + step_size = last_step_size + h = step_size - # check all keys are equal - if not all(set(k.keys()) == set(k1.keys()) for k in [k2, k3, k4, k5]): - raise ValueError("Keys of the deltas do not match. 
Please return zero for unchanged variables.") + # Butcher tableau coefficients + c2 = 0.161 + c3 = 0.327 + c4 = 0.9 + c5 = 0.9800255409045097 - # compute next step size - intermediate_error = keras.ops.stack([keras.ops.norm(k5[key] - k4[key], ord=2, axis=-1) for key in k5.keys()]) - new_step_size = step_size * tolerance / (intermediate_error + 1e-9) + k1 = fn(time, **filter_kwargs(state, fn)) + k2 = fn(time + h * c2, **add_scaled(state, [k1], [0.161], h)) + k3 = fn(time + h * c3, **add_scaled(state, [k1, k2], [-0.0084806554923570, 0.3354806554923570], h)) + k4 = fn( + time + h * c4, **add_scaled(state, [k1, k2, k3], [2.897153057105494, -6.359448489975075, 4.362295432869581], h) + ) + k5 = fn( + time + h * c5, + **add_scaled( + state, [k1, k2, k3, k4], [4.325279681768730, -11.74888356406283, 7.495539342889836, -0.09249506636175525], h + ), + ) + k6 = fn( + time + h, + **add_scaled( + state, + [k1, k2, k3, k4, k5], + [5.86145544294270, -12.92096931784711, 8.159367898576159, -0.07158497328140100, -0.02826905039406838], + h, + ), + ) - new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) + # 5th order solution: b coefficients + new_state = {} + for key in state.keys(): + new_state[key] = state[key] + h * ( + 0.09646076681806523 * k1[key] + + 0.01 * k2[key] + + 0.4798896504144996 * k3[key] + + 1.379008574103742 * k4[key] + - 3.290069515436081 * k5[key] + + 2.324710524099774 * k6[key] + ) - # consolidate step size - new_step_size = keras.ops.take(new_step_size, keras.ops.argmin(keras.ops.abs(new_step_size))) - else: - new_step_size = step_size + if use_adaptive_step_size: + # 7th stage evaluation + k7 = fn(time + h, **filter_kwargs(new_state, fn)) + + # 4th order embedded solution: b_hat coefficients + y4 = {} + for key in state.keys(): + y4[key] = state[key] + h * ( + 0.001780011052226 * k1[key] + + 0.000816434459657 * k2[key] + - 0.007880878010262 * k3[key] + + 0.144711007173263 * k4[key] + - 0.582357165452555 * k5[key] + + 0.458082105929187 * k6[key] + + (1.0 / 66.0) * k7[key] + ) - # apply updates - new_state = state.copy() - for key in k1.keys(): - new_state[key] = state[key] + (step_size / 6.0) * (k1[key] + 2.0 * k2[key] + 2.0 * k3[key] + k4[key]) + # Error estimate + err_state = {} + for key in state.keys(): + err_state[key] = new_state[key] - y4[key] - new_time = time + step_size + err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()]) + err = keras.ops.max(err_norm) + new_step_size = h * keras.ops.clip(0.9 * (tolerance / (err + 1e-12)) ** 0.2, 0.2, 5.0) + new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) + else: + new_step_size = h + + new_time = time + h return new_state, new_time, new_step_size @@ -141,6 +251,8 @@ def integrate_fixed( step_fn = euler_step case "rk45": step_fn = rk45_step + case "tsit5": + step_fn = tsit5_step case str() as name: raise ValueError(f"Unknown integration method name: {name!r}") case other: @@ -180,6 +292,8 @@ def integrate_adaptive( step_fn = euler_step case "rk45": step_fn = rk45_step + case "tsit5": + step_fn = tsit5_step case str() as name: raise ValueError(f"Unknown integration method name: {name!r}") case other: @@ -249,6 +363,8 @@ def integrate_scheduled( step_fn = euler_step case "rk45": step_fn = rk45_step + case "tsit5": + step_fn = tsit5_step case str() as name: raise ValueError(f"Unknown integration method name: {name!r}") case other: @@ -401,11 +517,19 @@ def integrate_stochastic( steps: int, seed: keras.random.SeedGenerator, method: str = "euler_maruyama", + 
score_fn: Callable = None, + corrector_steps: int = 0, + noise_schedule=None, + step_size_factor: float = 0.1, **kwargs, ) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, Sequence[ArrayLike]]]]: """ Integrates a stochastic differential equation from start_time to stop_time. + When score_fn is provided, performs predictor-corrector sampling where: + - Predictor: reverse diffusion SDE solver + - Corrector: annealed Langevin dynamics with step size e = sqrt(dim) + Args: drift_fn: Function that computes the drift term. diffusion_fn: Function that computes the diffusion term. @@ -415,11 +539,15 @@ def integrate_stochastic( steps: Number of integration steps. seed: Random seed for noise generation. method: Integration method to use, e.g., 'euler_maruyama'. + score_fn: Optional score function for predictor-corrector sampling. + Should take (time, **state) and return score dict. + corrector_steps: Number of corrector steps to take after each predictor step. + noise_schedule: Noise schedule object for computing lambda_t and alpha_t in corrector. + step_size_factor: Scaling factor for corrector step size. **kwargs: Additional arguments to pass to the step function. Returns: - If return_noise is False, returns the final state dictionary. - If return_noise is True, returns a tuple of (final_state, noise_history). + Final state dictionary after integration. """ if steps <= 0: raise ValueError("Number of steps must be positive.") @@ -438,17 +566,56 @@ def integrate_stochastic( step_size = (stop_time - start_time) / steps sqrt_dt = keras.ops.sqrt(keras.ops.abs(step_size)) - # Pre-generate noise history: shape = (steps, *state_shape) + # Pre-generate noise history for predictor: shape = (steps, *state_shape) noise_history = {} for key, val in state.items(): noise_history[key] = ( keras.random.normal((steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed) * sqrt_dt ) + # Pre-generate corrector noise if score_fn is provided: shape = (steps, corrector_steps, *state_shape) + corrector_noise_history = {} + if corrector_steps > 0: + if score_fn is None or noise_schedule is None: + raise ValueError("Please provide both score_fn and noise_schedule when using corrector_steps > 0.") + + for key, val in state.items(): + corrector_noise_history[key] = keras.random.normal( + (steps, corrector_steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed + ) + def body(_loop_var, _loop_state): _current_state, _current_time = _loop_state _noise_i = {k: noise_history[k][_loop_var] for k in _current_state.keys()} + + # Predictor step new_state, new_time = step_fn(state=_current_state, time=_current_time, step_size=step_size, noise=_noise_i) + + # Corrector steps: annealed Langevin dynamics if score_fn is provided + if corrector_steps > 0: + for corrector_step in range(corrector_steps): + score = score_fn(new_time, **filter_kwargs(new_state, score_fn)) + _corrector_noise = {k: corrector_noise_history[k][_loop_var, corrector_step] for k in new_state.keys()} + + # Compute noise schedule components for corrector step size + log_snr_t = noise_schedule.get_log_snr(t=new_time, training=False) + alpha_t, _ = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + + # Corrector update: x_i+1 = x_i + e * score + sqrt(2e) * noise_corrector + # where e = 2*alpha_t * (r * ||z|| / ||score||)**2 + for k in new_state.keys(): + if k in score: + z_norm = keras.ops.norm(_corrector_noise[k], axis=-1, keepdims=True) + score_norm = keras.ops.norm(score[k], axis=-1, keepdims=True) + + # Prevent division 
by zero + score_norm = keras.ops.maximum(score_norm, 1e-8) + + e = 2.0 * alpha_t * (step_size_factor * z_norm / score_norm) ** 2 + sqrt_2e = keras.ops.sqrt(2.0 * e) + + new_state[k] = new_state[k] + e * score[k] + sqrt_2e * _corrector_noise[k] + return new_state, new_time final_state, final_time = keras.ops.fori_loop(0, steps, body, (state, start_time)) diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index db5c448d7..75e65ca1d 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -1,4 +1,11 @@ import numpy as np +import keras +import pytest +from bayesflow.utils import integrate + + +TOLERANCE_ADAPTIVE = 1e-6 # Adaptive solvers should be very accurate. +TOLERANCE_EULER = 1e-3 # Euler with fixed steps requires a larger tolerance def test_scheduled_integration(): @@ -34,3 +41,24 @@ def fn(t, x): scipy_kwargs={"atol": 1e-6, "rtol": 1e-6}, )["x"] np.testing.assert_allclose(exact_result, result, atol=1e-6, rtol=1e-6) + + +@pytest.mark.parametrize( + "method, atol", [("euler", TOLERANCE_EULER), ("rk45", TOLERANCE_ADAPTIVE), ("tsit5", TOLERANCE_ADAPTIVE)] +) +def test_analytical_integration(method, atol): + def fn(t, x): + return {"x": keras.ops.convert_to_tensor([2.0 * t])} + + initial_state = {"x": keras.ops.convert_to_tensor([1.0])} + T_final = 2.0 + num_steps = 100 + analytical_result = 1.0 + T_final**2 + + result = integrate(fn, initial_state, start_time=0.0, stop_time=T_final, steps=num_steps, method=method)["x"] + result_adaptive = integrate( + fn, initial_state, start_time=0.0, stop_time=T_final, steps="adaptive", method=method, max_steps=1_000 + )["x"] + np.testing.assert_allclose(result, analytical_result, atol=atol, rtol=0.1) + + np.testing.assert_allclose(result_adaptive, analytical_result, atol=atol, rtol=0.01) From 3cae88379fe4d05a5ef741ea1fd591d19378e231 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 24 Nov 2025 17:19:58 +0100 Subject: [PATCH 063/101] add predictor corrector --- .../diffusion_model/diffusion_model.py | 76 ++++++++++++++++--- 1 file changed, 66 insertions(+), 10 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index ca8a634e9..bc77a884d 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -243,6 +243,55 @@ def _apply_subnet( else: return self.subnet(x=xz, t=log_snr, conditions=conditions, training=training) + def score( + self, + xz: Tensor, + time: float | Tensor = None, + log_snr_t: Tensor = None, + conditions: Tensor = None, + training: bool = False, + ) -> Tensor: + """ + Computes the score of the target or latent variable `xz`. + + Parameters + ---------- + xz : Tensor + The current state of the latent variable `z`, typically of shape (..., D), + where D is the dimensionality of the latent space. + time : float or Tensor + Scalar or tensor representing the time (or noise level) at which the velocity + should be computed. Will be broadcasted to xz. If None, log_snr_t must be provided. + log_snr_t : Tensor + The log signal-to-noise ratio at time `t`. If None, time must be provided. + conditions : Tensor, optional + Conditional inputs to the network, such as conditioning variables + or encoder outputs. Shape must be broadcastable with `xz`. Default is None. + training : bool, optional + Whether the model is in training mode. Affects behavior of dropout, batch norm, + or other stochastic layers. Default is False. 
+ + Returns + ------- + Tensor + The velocity tensor of the same shape as `xz`, representing the right-hand + side of the SDE or ODE at the given `time`. + """ + if log_snr_t is None: + log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) + alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + + subnet_out = self._apply_subnet( + xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training + ) + pred = self.output_projector(subnet_out, training=training) + + x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t) + + score = (alpha_t * x_pred - xz) / ops.square(sigma_t) + return score + def velocity( self, xz: Tensor, @@ -279,19 +328,10 @@ def velocity( The velocity tensor of the same shape as `xz`, representing the right-hand side of the SDE or ODE at the given `time`. """ - # calculate the current noise level and transform into correct shape log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) - alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) - - subnet_out = self._apply_subnet( - xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training - ) - pred = self.output_projector(subnet_out, training=training) - - x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t) - score = (alpha_t * x_pred - xz) / ops.square(sigma_t) + score = self.score(xz, log_snr_t=log_snr_t, conditions=conditions, training=training) # compute velocity f, g of the SDE or ODE f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) @@ -447,9 +487,25 @@ def deltas(time, xz): def diffusion(time, xz): return {"xz": self.diffusion_term(xz, time=time, training=training)} + score_fn = None + if "corrector_steps" in integrate_kwargs: + if integrate_kwargs["corrector_steps"] > 0: + + def score_fn(time, xz): + return { + "xz": self.score( + xz, + time=time, + conditions=conditions, + training=training, + ) + } + state = integrate_stochastic( drift_fn=deltas, diffusion_fn=diffusion, + score_fn=score_fn, + noise_schedule=self.noise_schedule, state=state, seed=self.seed_generator, **integrate_kwargs, From 39682b162a5562c7dd98e1c8813b39c8f570b4a6 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 24 Nov 2025 18:09:47 +0100 Subject: [PATCH 064/101] add adaptive sampler SDE --- bayesflow/utils/integrate.py | 227 +++++++++++++++++++++++------ tests/test_utils/test_integrate.py | 108 +++++++++++++- 2 files changed, 288 insertions(+), 47 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index a9bd6ea3f..86d94c1e1 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -1,4 +1,5 @@ from collections.abc import Callable, Sequence +from typing import Dict, Tuple, Optional from functools import partial import keras @@ -505,24 +506,115 @@ def euler_maruyama_step( base = base + diffusion[key] * noise[key] new_state[key] = base - return new_state, time + step_size + return new_state, time + step_size, step_size + + +def shark_step( + drift_fn: Callable, + diffusion_fn: Callable, + state: Dict[str, ArrayLike], + time: ArrayLike, + step_size: ArrayLike, + noise: Dict[str, ArrayLike], + use_adaptive_step_size: bool = False, + tolerance: 
ArrayLike = 1e-3, + min_step_size: ArrayLike = 1e-6, + max_step_size: ArrayLike = float("inf"), + half_noises: Optional[Tuple[Dict[str, ArrayLike], Dict[str, ArrayLike]]] = None, + bridge_aux: Optional[Dict[str, ArrayLike]] = None, + validate_split: bool = True, +) -> Tuple[Dict[str, ArrayLike], ArrayLike] | Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike]: + """ + Shifted Additive-noise Runge-Kutta method for additive SDEs. + """ + h = step_size + t = time + + # full step: midpoint drift, diffusion at midpoint time + k1 = drift_fn(t, **filter_kwargs(state, drift_fn)) + mid_state = {k: state[k] + 0.5 * h * k1[k] for k in state} + k2 = drift_fn(t + 0.5 * h, **filter_kwargs(mid_state, drift_fn)) + g_mid = diffusion_fn(t + 0.5 * h, **filter_kwargs(state, diffusion_fn)) + + det_full = {k: state[k] + h * k2[k] for k in state} + sto_full = {k: g_mid[k] * noise[k] for k in g_mid} + y_full = {k: det_full[k] + sto_full.get(k, keras.ops.zeros_like(det_full[k])) for k in det_full} + + if not use_adaptive_step_size: + return y_full, t + h, h + + # prepare two half step noises without drawing randomness here + if half_noises is not None: + dW1, dW2 = half_noises + if set(dW1.keys()) != set(noise.keys()) or set(dW2.keys()) != set(noise.keys()): + raise ValueError("half_noises must have the same keys as noise") + if validate_split: + sum_diff = {k: dW1[k] + dW2[k] - noise[k] for k in noise} + parts = [] + for v in sum_diff.values(): + if not hasattr(v, "shape") or len(v.shape) == 0: + v = keras.ops.reshape(v, (1,)) + parts.append(keras.ops.norm(v, ord=2, axis=-1)) + if float(keras.ops.max(keras.ops.stack(parts))) > 1e-6: + raise ValueError("half_noises do not sum to provided noise") + else: + if bridge_aux is None: + raise ValueError("Provide either half_noises or bridge_aux when use_adaptive_step_size is True") + if set(bridge_aux.keys()) != set(noise.keys()): + raise ValueError("bridge_aux must have the same keys as noise") + sqrt_h = keras.ops.sqrt(h + 1e-12) + dW1 = {k: 0.5 * noise[k] + 0.5 * sqrt_h * bridge_aux[k] for k in noise} + dW2 = {k: noise[k] - dW1[k] for k in noise} + + half = 0.5 * h + + # first half step on [t, t + h 2] + k1h = drift_fn(t, **filter_kwargs(state, drift_fn)) + mid1 = {k: state[k] + 0.5 * half * k1h[k] for k in state} + k2h = drift_fn(t + 0.5 * half, **filter_kwargs(mid1, drift_fn)) + g_q1 = diffusion_fn(t + 0.5 * half, **filter_kwargs(state, diffusion_fn)) + y_half = {k: state[k] + half * k2h[k] + g_q1.get(k, 0) * dW1.get(k, 0) for k in state} + + # second half step on [t + h 2, t + h] + k1h2 = drift_fn(t + half, **filter_kwargs(y_half, drift_fn)) + mid2 = {k: y_half[k] + 0.5 * half * k1h2[k] for k in y_half} + k2h2 = drift_fn(t + 1.5 * half, **filter_kwargs(mid2, drift_fn)) + g_q2 = diffusion_fn(t + 1.5 * half, **filter_kwargs(state, diffusion_fn)) + y_twohalf = {k: y_half[k] + half * k2h2[k] + g_q2.get(k, 0) * dW2.get(k, 0) for k in y_half} + + # error estimate + parts = [] + for k in y_full: + v = y_full[k] - y_twohalf[k] + if not hasattr(v, "shape") or len(v.shape) == 0: + v = keras.ops.reshape(v, (1,)) + parts.append(keras.ops.norm(v, ord=2, axis=-1)) + err = keras.ops.max(keras.ops.stack(parts)) + + # controller for strong order one on additive noise local error ~ h^{3 2} + factor = 0.9 * (tolerance / (err + 1e-12)) ** (2.0 / 3.0) + h_new = keras.ops.clip(h * keras.ops.clip(factor, 0.2, 5.0), min_step_size, max_step_size) + + return y_full, t + h, h_new def integrate_stochastic( drift_fn: Callable, diffusion_fn: Callable, - state: dict[str, ArrayLike], + state: 
Dict[str, ArrayLike], start_time: ArrayLike, stop_time: ArrayLike, - steps: int, seed: keras.random.SeedGenerator, + min_steps: int = 10, + max_steps: int = 10_000, + steps: int | Literal["adaptive"] = 100, method: str = "euler_maruyama", score_fn: Callable = None, corrector_steps: int = 0, noise_schedule=None, step_size_factor: float = 0.1, **kwargs, -) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, Sequence[ArrayLike]]]]: +) -> Union[Dict[str, ArrayLike], Tuple[Dict[str, ArrayLike], Dict[str, Sequence[ArrayLike]]]]: """ Integrates a stochastic differential equation from start_time to stop_time. @@ -535,88 +627,131 @@ def integrate_stochastic( diffusion_fn: Function that computes the diffusion term. state: Dictionary containing the initial state. start_time: Starting time for integration. - stop_time: Ending time for integration. - steps: Number of integration steps. + stop_time: Ending time for integration. steps: Number of integration steps. seed: Random seed for noise generation. - method: Integration method to use, e.g., 'euler_maruyama'. + min_steps: Minimum number of steps for adaptive integration. + max_steps: Maximum number of steps for adaptive integration. + steps: Number of steps or 'adaptive' for adaptive step sizing. Only 'shark' method supports adaptive steps. + method: Integration method to use, e.g., 'euler_maruyama' or 'shark'. score_fn: Optional score function for predictor-corrector sampling. - Should take (time, **state) and return score dict. + Should take (time, **state) and return score dict. corrector_steps: Number of corrector steps to take after each predictor step. noise_schedule: Noise schedule object for computing lambda_t and alpha_t in corrector. step_size_factor: Scaling factor for corrector step size. **kwargs: Additional arguments to pass to the step function. - Returns: - Final state dictionary after integration. + Returns: Final state dictionary after integration. """ - if steps <= 0: - raise ValueError("Number of steps must be positive.") + use_adaptive = False + if isinstance(steps, str) and steps in ["adaptive", "dynamic"]: + if start_time is None or stop_time is None: + raise ValueError( + "Please provide start_time and stop_time for the integration, was " + f"'start_time={start_time}', 'stop_time={stop_time}'." 
+ ) + if min_steps <= 0 or max_steps <= 0: + raise ValueError("min_steps and max_steps must be positive.") + if max_steps < min_steps: + raise ValueError("max_steps must be greater or equal to min_steps.") + use_adaptive = True + loop_steps = max_steps + initial_step = (stop_time - start_time) / float(min_steps) + elif isinstance(steps, int): + if steps <= 0: + raise ValueError("Number of steps must be positive.") + use_adaptive = False + loop_steps = steps + initial_step = (stop_time - start_time) / float(steps) + else: + raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})") - # Select step function based on method match method: case "euler_maruyama": step_fn = euler_maruyama_step + if use_adaptive: + raise ValueError("Adaptive step size is not supported for Euler Maruyama method.") + case "shark": + step_fn = shark_step case other: raise TypeError(f"Invalid integration method: {other!r}") - # Prepare step function with partial application step_fn = partial(step_fn, drift_fn=drift_fn, diffusion_fn=diffusion_fn, **kwargs) - # Time step - step_size = (stop_time - start_time) / steps - sqrt_dt = keras.ops.sqrt(keras.ops.abs(step_size)) - - # Pre-generate noise history for predictor: shape = (steps, *state_shape) - noise_history = {} + # pre generate standard normals scale by sqrt(dt) inside the loop using the current dt + z_history = {} + bridge_history = {} for key, val in state.items(): - noise_history[key] = ( - keras.random.normal((steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed) * sqrt_dt - ) + shape = keras.ops.shape(val) + z_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) + if method == "shark" and use_adaptive: + bridge_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) - # Pre-generate corrector noise if score_fn is provided: shape = (steps, corrector_steps, *state_shape) + # pre generate corrector noise if requested corrector_noise_history = {} if corrector_steps > 0: if score_fn is None or noise_schedule is None: raise ValueError("Please provide both score_fn and noise_schedule when using corrector_steps > 0.") - for key, val in state.items(): + shape = keras.ops.shape(val) corrector_noise_history[key] = keras.random.normal( - (steps, corrector_steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed + (loop_steps, corrector_steps, *shape), dtype=keras.ops.dtype(val), seed=seed ) - def body(_loop_var, _loop_state): - _current_state, _current_time = _loop_state - _noise_i = {k: noise_history[k][_loop_var] for k in _current_state.keys()} - - # Predictor step - new_state, new_time = step_fn(state=_current_state, time=_current_time, step_size=step_size, noise=_noise_i) - - # Corrector steps: annealed Langevin dynamics if score_fn is provided + def body(_i, _loop_state): + _current_state, _current_time, _current_step = _loop_state + + # clamp last step to hit stop_time + remaining = stop_time - _current_time + dt = keras.ops.minimum(_current_step, remaining) + + sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt)) + _noise_i = {k: z_history[k][_i] * sqrt_dt for k in _current_state.keys()} + + if method == "shark" and use_adaptive: + _bridge = {k: bridge_history[k][_i] for k in _current_state.keys()} + out = step_fn( + state=_current_state, + time=_current_time, + step_size=dt, + noise=_noise_i, + bridge_aux=_bridge, + use_adaptive_step_size=True, + ) + new_state, new_time, new_step = out + else: + out = step_fn(state=_current_state, 
time=_current_time, step_size=dt, noise=_noise_i) + if isinstance(out, tuple) and len(out) == 2: + new_state, new_time = out + new_step = _current_step + else: + new_state, new_time, new_step = out + + # corrector if corrector_steps > 0: - for corrector_step in range(corrector_steps): + for j in range(corrector_steps): score = score_fn(new_time, **filter_kwargs(new_state, score_fn)) - _corrector_noise = {k: corrector_noise_history[k][_loop_var, corrector_step] for k in new_state.keys()} + _z_corr = {k: corrector_noise_history[k][_i, j] for k in new_state.keys()} - # Compute noise schedule components for corrector step size log_snr_t = noise_schedule.get_log_snr(t=new_time, training=False) alpha_t, _ = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) - # Corrector update: x_i+1 = x_i + e * score + sqrt(2e) * noise_corrector - # where e = 2*alpha_t * (r * ||z|| / ||score||)**2 for k in new_state.keys(): if k in score: - z_norm = keras.ops.norm(_corrector_noise[k], axis=-1, keepdims=True) + z_norm = keras.ops.norm(_z_corr[k], axis=-1, keepdims=True) score_norm = keras.ops.norm(score[k], axis=-1, keepdims=True) - - # Prevent division by zero score_norm = keras.ops.maximum(score_norm, 1e-8) e = 2.0 * alpha_t * (step_size_factor * z_norm / score_norm) ** 2 - sqrt_2e = keras.ops.sqrt(2.0 * e) + new_state[k] = new_state[k] + e * score[k] + keras.ops.sqrt(2.0 * e) * _z_corr[k] - new_state[k] = new_state[k] + e * score[k] + sqrt_2e * _corrector_noise[k] + return new_state, new_time, new_step - return new_state, new_time + final_state, final_time, last_step = keras.ops.fori_loop(0, loop_steps, body, (state, start_time, initial_step)) + + if use_adaptive and float(final_time) < float(stop_time): + logging.warning( + f"Reached max_steps={max_steps} before stop_time. " + f"final_time={float(final_time)} stop_time={float(stop_time)}" + ) - final_state, final_time = keras.ops.fori_loop(0, steps, body, (state, start_time)) return final_state diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index 75e65ca1d..142a0bbb7 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -1,12 +1,17 @@ import numpy as np import keras import pytest -from bayesflow.utils import integrate +from bayesflow.utils import integrate, integrate_stochastic TOLERANCE_ADAPTIVE = 1e-6 # Adaptive solvers should be very accurate. TOLERANCE_EULER = 1e-3 # Euler with fixed steps requires a larger tolerance +# tolerances for SDE tests +TOL_MEAN = 3e-2 +TOL_VAR = 5e-2 +TOL_DET = 1e-3 + def test_scheduled_integration(): import keras @@ -62,3 +67,104 @@ def fn(t, x): np.testing.assert_allclose(result, analytical_result, atol=atol, rtol=0.1) np.testing.assert_allclose(result_adaptive, analytical_result, atol=atol, rtol=0.01) + + +@pytest.mark.parametrize( + "method,use_adapt", + [ + ("euler_maruyama", False), + ("shark", False), + ("shark", True), + ], +) +def test_additive_OU_weak_means_and_vars(method, use_adapt): + """ + Ornstein Uhlenbeck with additive noise + dX = a X dt + sigma dW + Exact at time T: + E[X_T] = x0 * exp(a T) + Var[X_T] = sigma^2 * (exp(2 a T) - 1) / (2 a) + We verify weak accuracy by matching empirical mean and variance. 
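    For concreteness, these closed-form moments work out to roughly 0.4415 (mean) and
    0.1081 (variance) for the parameter values used in this test. A standalone NumPy
    check (illustrative only, independent of the integrators under test):

        import numpy as np

        a, sigma, x0, T = -1.0, 0.5, 1.2, 1.0
        exact_mean = x0 * np.exp(a * T)                                 # ~ 0.4415
        exact_var = sigma**2 * (np.exp(2.0 * a * T) - 1.0) / (2.0 * a)  # ~ 0.1081
        print(exact_mean, exact_var)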
+ """ + # SDE parameters + a = -1.0 + sigma = 0.5 + x0 = 1.2 + T = 1.0 + + # batch of trajectories + N = 20000 # large enough to control sampling error + seed = keras.random.SeedGenerator(42) + + def drift_fn(t, x): + return {"x": a * x} + + def diffusion_fn(t, x): + # additive noise, independent of state + return {"x": keras.ops.convert_to_tensor([sigma])} + + initial_state = {"x": keras.ops.ones((N,)) * x0} + steps = 200 if not use_adapt else "adaptive" + + # expected mean and variance + exp_mean = x0 * np.exp(a * T) + exp_var = sigma**2 * (np.exp(2.0 * a * T) - 1.0) / (2.0 * a) + + out = integrate_stochastic( + drift_fn=drift_fn, + diffusion_fn=diffusion_fn, + state=initial_state, + start_time=0.0, + stop_time=T, + steps=steps, + seed=seed, + method=method, + ) + + xT = np.array(out["x"]) + emp_mean = float(xT.mean()) + emp_var = float(xT.var()) + np.testing.assert_allclose(emp_mean, exp_mean, atol=TOL_MEAN, rtol=0.0) + np.testing.assert_allclose(emp_var, exp_var, atol=TOL_VAR, rtol=0.0) + + +@pytest.mark.parametrize( + "method,use_adapt", + [ + ("euler_maruyama", False), + ("shark", False), + ("shark", True), + ], +) +def test_zero_noise_reduces_to_deterministic(method, use_adapt): + """ + With zero diffusion the SDE reduces to the ODE + dX = a X dt + """ + a = 0.7 + x0 = 0.9 + T = 1.25 + steps = 200 if not use_adapt else "adaptive" + seed = keras.random.SeedGenerator(999) + + def drift_fn(t, x): + return {"x": a * x} + + def diffusion_fn(t, x): + # identically zero diffusion + return {"x": keras.ops.convert_to_tensor([0.0])} + + initial_state = {"x": keras.ops.ones((256,)) * x0} + out = integrate_stochastic( + drift_fn=drift_fn, + diffusion_fn=diffusion_fn, + state=initial_state, + start_time=0.0, + stop_time=T, + steps=steps, + seed=seed, + method=method, + )["x"] + + exact = x0 * np.exp(a * T) + np.testing.assert_allclose(np.array(out).mean(), exact, atol=TOL_DET, rtol=0.1) From 770abc72ea00af98226581382b2a13d6c3f34dda Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 24 Nov 2025 18:24:53 +0100 Subject: [PATCH 065/101] add shark --- bayesflow/networks/diffusion_model/diffusion_model.py | 6 +++--- bayesflow/utils/integrate.py | 7 +------ 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index bc77a884d..659641f51 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -408,7 +408,7 @@ def _forward( integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs - if integrate_kwargs["method"] == "euler_maruyama": + if integrate_kwargs["method"] in ["euler_maruyama", "shark"]: raise ValueError("Stochastic methods are not supported for forward integration.") if density: @@ -458,7 +458,7 @@ def _inverse( integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs if density: - if integrate_kwargs["method"] == "euler_maruyama": + if integrate_kwargs["method"] in ["euler_maruyama", "shark"]: raise ValueError("Stochastic methods are not supported for density computation.") def deltas(time, xz): @@ -477,7 +477,7 @@ def deltas(time, xz): return x, log_density state = {"xz": z} - if integrate_kwargs["method"] == "euler_maruyama": + if integrate_kwargs["method"] in ["euler_maruyama", "shark"]: def deltas(time, xz): return { diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 
86d94c1e1..d7935c1ee 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -717,14 +717,9 @@ def body(_i, _loop_state): bridge_aux=_bridge, use_adaptive_step_size=True, ) - new_state, new_time, new_step = out else: out = step_fn(state=_current_state, time=_current_time, step_size=dt, noise=_noise_i) - if isinstance(out, tuple) and len(out) == 2: - new_state, new_time = out - new_step = _current_step - else: - new_state, new_time, new_step = out + new_state, new_time, new_step = out # corrector if corrector_steps > 0: From eba68922c73bc699d1d1f1eecaf6215e3597e9a2 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 24 Nov 2025 18:30:26 +0100 Subject: [PATCH 066/101] rm warn --- bayesflow/utils/integrate.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index d7935c1ee..ed35f9b71 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -742,11 +742,4 @@ def body(_i, _loop_state): return new_state, new_time, new_step final_state, final_time, last_step = keras.ops.fori_loop(0, loop_steps, body, (state, start_time, initial_step)) - - if use_adaptive and float(final_time) < float(stop_time): - logging.warning( - f"Reached max_steps={max_steps} before stop_time. " - f"final_time={float(final_time)} stop_time={float(stop_time)}" - ) - return final_state From e901b733c8610f53f59f4b69ad142b8ab9704eba Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 24 Nov 2025 18:35:03 +0100 Subject: [PATCH 067/101] fix dt --- bayesflow/utils/integrate.py | 6 +++--- tests/test_utils/test_integrate.py | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index ed35f9b71..bd0a580ae 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -699,10 +699,10 @@ def integrate_stochastic( def body(_i, _loop_state): _current_state, _current_time, _current_step = _loop_state - - # clamp last step to hit stop_time remaining = stop_time - _current_time - dt = keras.ops.minimum(_current_step, remaining) + sign = keras.ops.sign(remaining) + dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), keras.ops.abs(remaining)) + dt = sign * dt_mag sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt)) _noise_i = {k: z_history[k][_i] * sqrt_dt for k in _current_state.keys()} diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index 142a0bbb7..d328f9476 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -119,6 +119,7 @@ def diffusion_fn(t, x): steps=steps, seed=seed, method=method, + max_steps=1_000, ) xT = np.array(out["x"]) @@ -164,6 +165,7 @@ def diffusion_fn(t, x): steps=steps, seed=seed, method=method, + max_steps=1_000, )["x"] exact = x0 * np.exp(a * T) From e8be555ed14730786799e7a6d6a8f131d9cd6b36 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 24 Nov 2025 19:06:16 +0100 Subject: [PATCH 068/101] fix adaptive step size --- bayesflow/utils/integrate.py | 44 ++++++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index bd0a580ae..9e61534c4 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -475,6 +475,8 @@ def euler_maruyama_step( time: ArrayLike, step_size: ArrayLike, noise: dict[str, ArrayLike], + min_step_size: ArrayLike = None, + max_step_size: ArrayLike = None, ) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): """ Performs a 
single Euler-Maruyama step for stochastic differential equations. @@ -486,6 +488,8 @@ def euler_maruyama_step( time: Current time scalar tensor. step_size: Time increment dt. noise: Mapping of variable names to dW noise tensors. + min_step_size: Minimum allowed step size (not used here). + max_step_size: Maximum allowed step size (not used here). Returns: new_state: Updated state after one Euler-Maruyama step. @@ -516,19 +520,22 @@ def shark_step( time: ArrayLike, step_size: ArrayLike, noise: Dict[str, ArrayLike], + min_step_size: ArrayLike, + max_step_size: ArrayLike, use_adaptive_step_size: bool = False, tolerance: ArrayLike = 1e-3, - min_step_size: ArrayLike = 1e-6, - max_step_size: ArrayLike = float("inf"), half_noises: Optional[Tuple[Dict[str, ArrayLike], Dict[str, ArrayLike]]] = None, bridge_aux: Optional[Dict[str, ArrayLike]] = None, validate_split: bool = True, -) -> Tuple[Dict[str, ArrayLike], ArrayLike] | Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike]: +) -> Union[Tuple[Dict[str, ArrayLike], ArrayLike], Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike]]: """ - Shifted Additive-noise Runge-Kutta method for additive SDEs. + Shifted Additive noise Runge Kutta for additive SDEs. """ + # direction aware handling h = step_size t = time + h_sign = keras.ops.sign(h) + h_mag = keras.ops.abs(h) # full step: midpoint drift, diffusion at midpoint time k1 = drift_fn(t, **filter_kwargs(state, drift_fn)) @@ -562,20 +569,20 @@ def shark_step( raise ValueError("Provide either half_noises or bridge_aux when use_adaptive_step_size is True") if set(bridge_aux.keys()) != set(noise.keys()): raise ValueError("bridge_aux must have the same keys as noise") - sqrt_h = keras.ops.sqrt(h + 1e-12) + sqrt_h = keras.ops.sqrt(h_mag + 1e-12) # use magnitude dW1 = {k: 0.5 * noise[k] + 0.5 * sqrt_h * bridge_aux[k] for k in noise} dW2 = {k: noise[k] - dW1[k] for k in noise} half = 0.5 * h - # first half step on [t, t + h 2] + # first half step k1h = drift_fn(t, **filter_kwargs(state, drift_fn)) mid1 = {k: state[k] + 0.5 * half * k1h[k] for k in state} k2h = drift_fn(t + 0.5 * half, **filter_kwargs(mid1, drift_fn)) g_q1 = diffusion_fn(t + 0.5 * half, **filter_kwargs(state, diffusion_fn)) y_half = {k: state[k] + half * k2h[k] + g_q1.get(k, 0) * dW1.get(k, 0) for k in state} - # second half step on [t + h 2, t + h] + # second half step k1h2 = drift_fn(t + half, **filter_kwargs(y_half, drift_fn)) mid2 = {k: y_half[k] + 0.5 * half * k1h2[k] for k in y_half} k2h2 = drift_fn(t + 1.5 * half, **filter_kwargs(mid2, drift_fn)) @@ -591,9 +598,14 @@ def shark_step( parts.append(keras.ops.norm(v, ord=2, axis=-1)) err = keras.ops.max(keras.ops.stack(parts)) - # controller for strong order one on additive noise local error ~ h^{3 2} + # controller for strong order one on additive noise factor = 0.9 * (tolerance / (err + 1e-12)) ** (2.0 / 3.0) - h_new = keras.ops.clip(h * keras.ops.clip(factor, 0.2, 5.0), min_step_size, max_step_size) + h_prop = h * keras.ops.clip(factor, 0.2, 5.0) + + # clip by magnitude bounds then restore original sign + mag = keras.ops.abs(h_prop) + mag_new = keras.ops.clip(mag, min_step_size, max_step_size) + h_new = h_sign * mag_new return y_full, t + h, h_new @@ -656,12 +668,17 @@ def integrate_stochastic( use_adaptive = True loop_steps = max_steps initial_step = (stop_time - start_time) / float(min_steps) + + span_mag = keras.ops.abs(stop_time - start_time) + min_step_size = span_mag / keras.ops.cast(max_steps, span_mag.dtype) + max_step_size = span_mag / keras.ops.cast(min_steps, span_mag.dtype) elif 
isinstance(steps, int): if steps <= 0: raise ValueError("Number of steps must be positive.") use_adaptive = False loop_steps = steps initial_step = (stop_time - start_time) / float(steps) + min_step_size, max_step_size = initial_step, initial_step else: raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})") @@ -675,7 +692,14 @@ def integrate_stochastic( case other: raise TypeError(f"Invalid integration method: {other!r}") - step_fn = partial(step_fn, drift_fn=drift_fn, diffusion_fn=diffusion_fn, **kwargs) + step_fn = partial( + step_fn, + drift_fn=drift_fn, + diffusion_fn=diffusion_fn, + min_step_size=min_step_size, + max_step_size=max_step_size, + **kwargs, + ) # pre generate standard normals scale by sqrt(dt) inside the loop using the current dt z_history = {} From 36a16b35c5d3ee04652252ea2d902585cfedc2d4 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 24 Nov 2025 19:31:32 +0100 Subject: [PATCH 069/101] refactor stochastic integrator --- bayesflow/utils/integrate.py | 307 ++++++++++++++++++++++++++++++++++- 1 file changed, 302 insertions(+), 5 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 9e61534c4..428e4e3b5 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -15,6 +15,7 @@ from . import logging ArrayLike = int | float | Tensor +StateDict = Dict[str, ArrayLike] def euler_step( @@ -475,6 +476,7 @@ def euler_maruyama_step( time: ArrayLike, step_size: ArrayLike, noise: dict[str, ArrayLike], + use_adaptive_step_size: bool = False, min_step_size: ArrayLike = None, max_step_size: ArrayLike = None, ) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): @@ -488,6 +490,7 @@ def euler_maruyama_step( time: Current time scalar tensor. step_size: Time increment dt. noise: Mapping of variable names to dW noise tensors. + use_adaptive_step_size: Whether to use adaptive step sizing (not used here). min_step_size: Minimum allowed step size (not used here). max_step_size: Maximum allowed step size (not used here). @@ -610,21 +613,21 @@ def shark_step( return y_full, t + h, h_new -def integrate_stochastic( +def integrate_stochastic_old( drift_fn: Callable, diffusion_fn: Callable, state: Dict[str, ArrayLike], start_time: ArrayLike, stop_time: ArrayLike, seed: keras.random.SeedGenerator, - min_steps: int = 10, - max_steps: int = 10_000, steps: int | Literal["adaptive"] = 100, method: str = "euler_maruyama", score_fn: Callable = None, corrector_steps: int = 0, noise_schedule=None, step_size_factor: float = 0.1, + min_steps: int = 10, + max_steps: int = 1_000, **kwargs, ) -> Union[Dict[str, ArrayLike], Tuple[Dict[str, ArrayLike], Dict[str, Sequence[ArrayLike]]]]: """ @@ -641,8 +644,6 @@ def integrate_stochastic( start_time: Starting time for integration. stop_time: Ending time for integration. steps: Number of integration steps. seed: Random seed for noise generation. - min_steps: Minimum number of steps for adaptive integration. - max_steps: Maximum number of steps for adaptive integration. steps: Number of steps or 'adaptive' for adaptive step sizing. Only 'shark' method supports adaptive steps. method: Integration method to use, e.g., 'euler_maruyama' or 'shark'. score_fn: Optional score function for predictor-corrector sampling. @@ -650,6 +651,8 @@ def integrate_stochastic( corrector_steps: Number of corrector steps to take after each predictor step. noise_schedule: Noise schedule object for computing lambda_t and alpha_t in corrector. step_size_factor: Scaling factor for corrector step size. 
+ min_steps: Minimum number of steps for adaptive integration. + max_steps: Maximum number of steps for adaptive integration. **kwargs: Additional arguments to pass to the step function. Returns: Final state dictionary after integration. @@ -767,3 +770,297 @@ def body(_i, _loop_state): final_state, final_time, last_step = keras.ops.fori_loop(0, loop_steps, body, (state, start_time, initial_step)) return final_state + + +def _apply_corrector( + new_state: StateDict, + new_time: ArrayLike, + i: ArrayLike, + corrector_steps: int, + score_fn: Optional[Callable], + step_size_factor: float, + corrector_noise_history: Dict[str, ArrayLike], + noise_schedule=None, +) -> StateDict: + """Helper function to apply corrector steps.""" + if corrector_steps <= 0: + return new_state + + # Ensures score_fn and noise_schedule are present if needed, though checked in integrate_stochastic + if score_fn is None or noise_schedule is None: + return new_state # Should not happen if checks are passed + + for j in range(corrector_steps): + score = score_fn(new_time, **filter_kwargs(new_state, score_fn)) + _z_corr = {k: corrector_noise_history[k][i, j] for k in new_state.keys()} + + log_snr_t = noise_schedule.get_log_snr(t=new_time, training=False) + alpha_t, _ = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + + for k in new_state.keys(): + if k in score: + # Calculate required norms for Langevin step + z_norm = keras.ops.norm(_z_corr[k], axis=-1, keepdims=True) + score_norm = keras.ops.norm(score[k], axis=-1, keepdims=True) + score_norm = keras.ops.maximum(score_norm, 1e-8) + + # Compute step size 'e' for the Langevin update + e = 2.0 * alpha_t * (step_size_factor * z_norm / score_norm) ** 2 + + # Annealed Langevin Dynamics update + new_state[k] = new_state[k] + e * score[k] + keras.ops.sqrt(2.0 * e) * _z_corr[k] + return new_state + + +def integrate_stochastic_fixed( + step_fn: Callable, + state: StateDict, + start_time: ArrayLike, + stop_time: ArrayLike, + steps: int, + z_history: Dict[str, ArrayLike], + corrector_steps: int, + score_fn: Optional[Callable], + step_size_factor: float, + corrector_noise_history: Dict[str, ArrayLike], + noise_schedule=None, +) -> StateDict: + """ + Performs fixed-step SDE integration. 
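    For orientation, a minimal call pattern for the dispatching entry point
    `integrate_stochastic` defined further below; the drift and diffusion functions
    here are toy placeholders assumed for illustration, not part of the library:

        import keras
        from bayesflow.utils import integrate_stochastic

        def drift_fn(t, x):
            # toy linear drift (placeholder)
            return {"x": -x}

        def diffusion_fn(t, x):
            # toy additive diffusion (placeholder)
            return {"x": keras.ops.convert_to_tensor([0.3])}

        state = {"x": keras.ops.ones((128,))}
        seed = keras.random.SeedGenerator(0)

        # fixed number of Euler-Maruyama steps
        out_fixed = integrate_stochastic(
            drift_fn=drift_fn, diffusion_fn=diffusion_fn, state=state,
            start_time=0.0, stop_time=1.0, steps=200, seed=seed, method="euler_maruyama",
        )

        # adaptive step sizes (only supported by the 'shark' method)
        out_adaptive = integrate_stochastic(
            drift_fn=drift_fn, diffusion_fn=diffusion_fn, state=state,
            start_time=0.0, stop_time=1.0, steps="adaptive", seed=seed, method="shark",
            min_steps=10, max_steps=1_000,
        )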
+ """ + initial_step = (stop_time - start_time) / float(steps) + + def body_fixed(_i, _loop_state): + _current_state, _current_time, _current_step = _loop_state + + # Determine step size: either the constant size or the remainder to reach stop_time + remaining = stop_time - _current_time + sign = keras.ops.sign(remaining) + dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), keras.ops.abs(remaining)) + dt = sign * dt_mag + + # Generate noise increment scaled by sqrt(dt) + sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt)) + _noise_i = {k: z_history[k][_i] * sqrt_dt for k in _current_state.keys()} + + new_state, new_time, new_step = step_fn( + state=_current_state, + time=_current_time, + step_size=dt, + noise=_noise_i, + use_adaptive_step_size=False, + ) + + new_state = _apply_corrector( + new_state=new_state, + new_time=new_time, + i=_i, + corrector_steps=corrector_steps, + score_fn=score_fn, + noise_schedule=noise_schedule, + step_size_factor=step_size_factor, + corrector_noise_history=corrector_noise_history, + ) + return new_state, new_time, initial_step + + # Execute the fixed loop + final_state, final_time, _ = keras.ops.fori_loop(0, steps, body_fixed, (state, start_time, initial_step)) + return final_state + + +def integrate_stochastic_adaptive( + step_fn: Callable, + state: StateDict, + start_time: ArrayLike, + stop_time: ArrayLike, + max_steps: int, + initial_step: ArrayLike, + z_history: Dict[str, ArrayLike], + bridge_history: Dict[str, ArrayLike], + corrector_steps: int, + score_fn: Optional[Callable], + step_size_factor: float, + corrector_noise_history: Dict[str, ArrayLike], + noise_schedule=None, +) -> StateDict: + """ + Performs adaptive-step SDE integration. + """ + initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step) + + def cond(i, current_state, current_time, current_step): + # We use a small epsilon check for floating point equality + time_reached = keras.ops.all(keras.ops.isclose(current_time, stop_time)) + return keras.ops.logical_and(keras.ops.less(i, max_steps), keras.ops.logical_not(time_reached)) + + def body_adaptive(_i, _current_state, _current_time, _current_step): + # Step Size Control + remaining = stop_time - _current_time + sign = keras.ops.sign(remaining) + # Ensure the next step does not overshoot the stop_time + dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), keras.ops.abs(remaining)) + dt = sign * dt_mag + + sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt)) + _noise_i = {k: z_history[k][_i] * sqrt_dt for k in _current_state.keys()} + _bridge = {k: bridge_history[k][_i] for k in _current_state.keys()} + + new_state, new_time, new_step = step_fn( + state=_current_state, + time=_current_time, + step_size=dt, + noise=_noise_i, + bridge_aux=_bridge, + use_adaptive_step_size=True, + ) + + new_state = _apply_corrector( + new_state=new_state, + new_time=new_time, + i=_i, + corrector_steps=corrector_steps, + score_fn=score_fn, + noise_schedule=noise_schedule, + step_size_factor=step_size_factor, + corrector_noise_history=corrector_noise_history, + ) + + return _i + 1, new_state, new_time, new_step + + # Execute the adaptive loop + _, final_state, _, _ = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) + return final_state + + +def integrate_stochastic( + drift_fn: Callable, + diffusion_fn: Callable, + state: StateDict, + start_time: ArrayLike, + stop_time: ArrayLike, + seed: keras.random.SeedGenerator, + steps: int | Literal["adaptive"] = 100, + method: str = "euler_maruyama", + score_fn: Callable = None, + 
corrector_steps: int = 0, + noise_schedule=None, + step_size_factor: float = 0.1, + min_steps: int = 10, + max_steps: int = 10_000, + **kwargs, +) -> StateDict: + """ + Integrates a stochastic differential equation from start_time to stop_time. + + Dispatches to fixed-step or adaptive-step integration logic. + + Args: + drift_fn: Function that computes the drift term. + diffusion_fn: Function that computes the diffusion term. + state: Dictionary containing the initial state. + start_time: Starting time for integration. + stop_time: Ending time for integration. steps: Number of integration steps. + seed: Random seed for noise generation. + steps: Number of steps or 'adaptive' for adaptive step sizing. Only 'shark' method supports adaptive steps. + method: Integration method to use, e.g., 'euler_maruyama' or 'shark'. + score_fn: Optional score function for predictor-corrector sampling. + corrector_steps: Number of corrector steps to take after each predictor step. + noise_schedule: Noise schedule object for computing alpha_t in corrector. + step_size_factor: Scaling factor for corrector step size. + min_steps: Minimum number of steps for adaptive integration. + max_steps: Maximum number of steps for adaptive integration. + **kwargs: Additional arguments to pass to the step function. + + Returns: Final state dictionary after integration. + """ + is_adaptive = isinstance(steps, str) and steps in ["adaptive", "dynamic"] + if is_adaptive: + if start_time is None or stop_time is None: + raise ValueError("Please provide start_time and stop_time for adaptive integration.") + if min_steps <= 0 or max_steps <= 0 or max_steps < min_steps: + raise ValueError("min_steps and max_steps must be positive, and max_steps >= min_steps.") + if method != "shark": + raise ValueError("Adaptive step size is only supported for the 'shark' method.") + + loop_steps = max_steps + initial_step = (stop_time - start_time) / float(min_steps) + span_mag = keras.ops.abs(stop_time - start_time) + min_step_size = span_mag / keras.ops.cast(max_steps, span_mag.dtype) + max_step_size = span_mag / keras.ops.cast(min_steps, span_mag.dtype) + else: + if steps <= 0: + raise ValueError("Number of steps must be positive.") + loop_steps = int(steps) + initial_step = (stop_time - start_time) / float(loop_steps) + # For fixed step, min/max step size are just the fixed step size + min_step_size, max_step_size = initial_step, initial_step + + match method: + case "euler_maruyama": + step_fn_raw = euler_maruyama_step + case "shark": + step_fn_raw = shark_step + case other: + raise TypeError(f"Invalid integration method: {other!r}") + + # Partial the step function with common arguments + step_fn = partial( + step_fn_raw, + drift_fn=drift_fn, + diffusion_fn=diffusion_fn, + min_step_size=min_step_size, + max_step_size=max_step_size, + **kwargs, + ) + + # Pre-generate standard normals for the predictor step (up to max_steps) + z_history = {} + bridge_history = {} + for key, val in state.items(): + shape = keras.ops.shape(val) + z_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) + if is_adaptive and method == "shark": + # Only required for SHARK adaptive step (Brownian Bridge aux noise) + bridge_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) + + # Pre-generate corrector noise if requested + corrector_noise_history = {} + if corrector_steps > 0: + if score_fn is None or noise_schedule is None: + raise ValueError("Please provide both score_fn and 
noise_schedule when using corrector_steps > 0.") + for key, val in state.items(): + shape = keras.ops.shape(val) + corrector_noise_history[key] = keras.random.normal( + (loop_steps, corrector_steps, *shape), dtype=keras.ops.dtype(val), seed=seed + ) + + if is_adaptive: + return integrate_stochastic_adaptive( + step_fn=step_fn, + state=state, + start_time=start_time, + stop_time=stop_time, + max_steps=max_steps, + initial_step=initial_step, + z_history=z_history, + bridge_history=bridge_history, + corrector_steps=corrector_steps, + score_fn=score_fn, + noise_schedule=noise_schedule, + step_size_factor=step_size_factor, + corrector_noise_history=corrector_noise_history, + ) + else: + return integrate_stochastic_fixed( + step_fn=step_fn, + state=state, + start_time=start_time, + stop_time=stop_time, + steps=loop_steps, + z_history=z_history, + corrector_steps=corrector_steps, + score_fn=score_fn, + noise_schedule=noise_schedule, + step_size_factor=step_size_factor, + corrector_noise_history=corrector_noise_history, + ) From de57eaf100106ae525091e79ff5248e93fca3099 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 24 Nov 2025 19:36:06 +0100 Subject: [PATCH 070/101] refactor stochastic integrator --- bayesflow/utils/integrate.py | 177 ++--------------------------------- 1 file changed, 10 insertions(+), 167 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 428e4e3b5..3f9cdee33 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -613,165 +613,6 @@ def shark_step( return y_full, t + h, h_new -def integrate_stochastic_old( - drift_fn: Callable, - diffusion_fn: Callable, - state: Dict[str, ArrayLike], - start_time: ArrayLike, - stop_time: ArrayLike, - seed: keras.random.SeedGenerator, - steps: int | Literal["adaptive"] = 100, - method: str = "euler_maruyama", - score_fn: Callable = None, - corrector_steps: int = 0, - noise_schedule=None, - step_size_factor: float = 0.1, - min_steps: int = 10, - max_steps: int = 1_000, - **kwargs, -) -> Union[Dict[str, ArrayLike], Tuple[Dict[str, ArrayLike], Dict[str, Sequence[ArrayLike]]]]: - """ - Integrates a stochastic differential equation from start_time to stop_time. - - When score_fn is provided, performs predictor-corrector sampling where: - - Predictor: reverse diffusion SDE solver - - Corrector: annealed Langevin dynamics with step size e = sqrt(dim) - - Args: - drift_fn: Function that computes the drift term. - diffusion_fn: Function that computes the diffusion term. - state: Dictionary containing the initial state. - start_time: Starting time for integration. - stop_time: Ending time for integration. steps: Number of integration steps. - seed: Random seed for noise generation. - steps: Number of steps or 'adaptive' for adaptive step sizing. Only 'shark' method supports adaptive steps. - method: Integration method to use, e.g., 'euler_maruyama' or 'shark'. - score_fn: Optional score function for predictor-corrector sampling. - Should take (time, **state) and return score dict. - corrector_steps: Number of corrector steps to take after each predictor step. - noise_schedule: Noise schedule object for computing lambda_t and alpha_t in corrector. - step_size_factor: Scaling factor for corrector step size. - min_steps: Minimum number of steps for adaptive integration. - max_steps: Maximum number of steps for adaptive integration. - **kwargs: Additional arguments to pass to the step function. - - Returns: Final state dictionary after integration. 
- """ - use_adaptive = False - if isinstance(steps, str) and steps in ["adaptive", "dynamic"]: - if start_time is None or stop_time is None: - raise ValueError( - "Please provide start_time and stop_time for the integration, was " - f"'start_time={start_time}', 'stop_time={stop_time}'." - ) - if min_steps <= 0 or max_steps <= 0: - raise ValueError("min_steps and max_steps must be positive.") - if max_steps < min_steps: - raise ValueError("max_steps must be greater or equal to min_steps.") - use_adaptive = True - loop_steps = max_steps - initial_step = (stop_time - start_time) / float(min_steps) - - span_mag = keras.ops.abs(stop_time - start_time) - min_step_size = span_mag / keras.ops.cast(max_steps, span_mag.dtype) - max_step_size = span_mag / keras.ops.cast(min_steps, span_mag.dtype) - elif isinstance(steps, int): - if steps <= 0: - raise ValueError("Number of steps must be positive.") - use_adaptive = False - loop_steps = steps - initial_step = (stop_time - start_time) / float(steps) - min_step_size, max_step_size = initial_step, initial_step - else: - raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})") - - match method: - case "euler_maruyama": - step_fn = euler_maruyama_step - if use_adaptive: - raise ValueError("Adaptive step size is not supported for Euler Maruyama method.") - case "shark": - step_fn = shark_step - case other: - raise TypeError(f"Invalid integration method: {other!r}") - - step_fn = partial( - step_fn, - drift_fn=drift_fn, - diffusion_fn=diffusion_fn, - min_step_size=min_step_size, - max_step_size=max_step_size, - **kwargs, - ) - - # pre generate standard normals scale by sqrt(dt) inside the loop using the current dt - z_history = {} - bridge_history = {} - for key, val in state.items(): - shape = keras.ops.shape(val) - z_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) - if method == "shark" and use_adaptive: - bridge_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) - - # pre generate corrector noise if requested - corrector_noise_history = {} - if corrector_steps > 0: - if score_fn is None or noise_schedule is None: - raise ValueError("Please provide both score_fn and noise_schedule when using corrector_steps > 0.") - for key, val in state.items(): - shape = keras.ops.shape(val) - corrector_noise_history[key] = keras.random.normal( - (loop_steps, corrector_steps, *shape), dtype=keras.ops.dtype(val), seed=seed - ) - - def body(_i, _loop_state): - _current_state, _current_time, _current_step = _loop_state - remaining = stop_time - _current_time - sign = keras.ops.sign(remaining) - dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), keras.ops.abs(remaining)) - dt = sign * dt_mag - - sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt)) - _noise_i = {k: z_history[k][_i] * sqrt_dt for k in _current_state.keys()} - - if method == "shark" and use_adaptive: - _bridge = {k: bridge_history[k][_i] for k in _current_state.keys()} - out = step_fn( - state=_current_state, - time=_current_time, - step_size=dt, - noise=_noise_i, - bridge_aux=_bridge, - use_adaptive_step_size=True, - ) - else: - out = step_fn(state=_current_state, time=_current_time, step_size=dt, noise=_noise_i) - new_state, new_time, new_step = out - - # corrector - if corrector_steps > 0: - for j in range(corrector_steps): - score = score_fn(new_time, **filter_kwargs(new_state, score_fn)) - _z_corr = {k: corrector_noise_history[k][_i, j] for k in new_state.keys()} - - log_snr_t = 
noise_schedule.get_log_snr(t=new_time, training=False) - alpha_t, _ = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) - - for k in new_state.keys(): - if k in score: - z_norm = keras.ops.norm(_z_corr[k], axis=-1, keepdims=True) - score_norm = keras.ops.norm(score[k], axis=-1, keepdims=True) - score_norm = keras.ops.maximum(score_norm, 1e-8) - - e = 2.0 * alpha_t * (step_size_factor * z_norm / score_norm) ** 2 - new_state[k] = new_state[k] + e * score[k] + keras.ops.sqrt(2.0 * e) * _z_corr[k] - - return new_state, new_time, new_step - - final_state, final_time, last_step = keras.ops.fori_loop(0, loop_steps, body, (state, start_time, initial_step)) - return final_state - - def _apply_corrector( new_state: StateDict, new_time: ArrayLike, @@ -886,20 +727,21 @@ def integrate_stochastic_adaptive( """ Performs adaptive-step SDE integration. """ - initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step) + initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step, 0) def cond(i, current_state, current_time, current_step): # We use a small epsilon check for floating point equality time_reached = keras.ops.all(keras.ops.isclose(current_time, stop_time)) return keras.ops.logical_and(keras.ops.less(i, max_steps), keras.ops.logical_not(time_reached)) - def body_adaptive(_i, _current_state, _current_time, _current_step): + def body_adaptive(_i, _current_state, _current_time, _current_step, _counter): # Step Size Control remaining = stop_time - _current_time sign = keras.ops.sign(remaining) # Ensure the next step does not overshoot the stop_time dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), keras.ops.abs(remaining)) dt = sign * dt_mag + _counter += 1 sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt)) _noise_i = {k: z_history[k][_i] * sqrt_dt for k in _current_state.keys()} @@ -925,10 +767,11 @@ def body_adaptive(_i, _current_state, _current_time, _current_step): corrector_noise_history=corrector_noise_history, ) - return _i + 1, new_state, new_time, new_step + return _i + 1, new_state, new_time, new_step, _counter # Execute the adaptive loop - _, final_state, _, _ = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) + _, final_state, _, counter = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) + logging.debug("Finished integration after {} steps.", counter) return final_state @@ -941,12 +784,12 @@ def integrate_stochastic( seed: keras.random.SeedGenerator, steps: int | Literal["adaptive"] = 100, method: str = "euler_maruyama", + min_steps: int = 10, + max_steps: int = 10_000, score_fn: Callable = None, corrector_steps: int = 0, noise_schedule=None, step_size_factor: float = 0.1, - min_steps: int = 10, - max_steps: int = 10_000, **kwargs, ) -> StateDict: """ @@ -963,12 +806,12 @@ def integrate_stochastic( seed: Random seed for noise generation. steps: Number of steps or 'adaptive' for adaptive step sizing. Only 'shark' method supports adaptive steps. method: Integration method to use, e.g., 'euler_maruyama' or 'shark'. + min_steps: Minimum number of steps for adaptive integration. + max_steps: Maximum number of steps for adaptive integration. score_fn: Optional score function for predictor-corrector sampling. corrector_steps: Number of corrector steps to take after each predictor step. noise_schedule: Noise schedule object for computing alpha_t in corrector. step_size_factor: Scaling factor for corrector step size. - min_steps: Minimum number of steps for adaptive integration. 
- max_steps: Maximum number of steps for adaptive integration. **kwargs: Additional arguments to pass to the step function. Returns: Final state dictionary after integration. From dde5451e5026636a0657530b0264ac72dee6dfe1 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 24 Nov 2025 19:38:25 +0100 Subject: [PATCH 071/101] refactor stochastic integrator --- bayesflow/utils/integrate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 3f9cdee33..0249bf131 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -729,7 +729,7 @@ def integrate_stochastic_adaptive( """ initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step, 0) - def cond(i, current_state, current_time, current_step): + def cond(i, current_state, current_time, current_step, counter): # We use a small epsilon check for floating point equality time_reached = keras.ops.all(keras.ops.isclose(current_time, stop_time)) return keras.ops.logical_and(keras.ops.less(i, max_steps), keras.ops.logical_not(time_reached)) @@ -770,8 +770,8 @@ def body_adaptive(_i, _current_state, _current_time, _current_step, _counter): return _i + 1, new_state, new_time, new_step, _counter # Execute the adaptive loop - _, final_state, _, counter = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) - logging.debug("Finished integration after {} steps.", counter) + _, final_state, _, final_counter = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) + logging.debug("Finished integration after {} steps.", final_counter) return final_state From 3d2c80ea125ee84fc9090295448835861ba279f2 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 24 Nov 2025 19:44:14 +0100 Subject: [PATCH 072/101] fix adaptive --- bayesflow/utils/integrate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 0249bf131..9d4683dc7 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -770,7 +770,7 @@ def body_adaptive(_i, _current_state, _current_time, _current_step, _counter): return _i + 1, new_state, new_time, new_step, _counter # Execute the adaptive loop - _, final_state, _, final_counter = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) + _, final_state, _, _, final_counter = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) logging.debug("Finished integration after {} steps.", final_counter) return final_state From ed5e89fd1706eee0e3b5636d5a362d3707dc46e0 Mon Sep 17 00:00:00 2001 From: arrjon Date: Tue, 25 Nov 2025 17:16:40 +0100 Subject: [PATCH 073/101] fix Tsit5 --- bayesflow/utils/integrate.py | 58 +++++++++++------------------- tests/test_utils/test_integrate.py | 2 +- 2 files changed, 21 insertions(+), 39 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 9d4683dc7..6ae7ed0b8 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -31,30 +31,19 @@ def euler_step( k1 = fn(time, **filter_kwargs(state, fn)) if use_adaptive_step_size: - # Use Heun's method (RK2) as embedded pair for proper error estimation - intermediate_state = state.copy() - for key, delta in k1.items(): - intermediate_state[key] = state[key] + step_size * delta + # Euler step + y_euler = {k: state[k] + step_size * k1[k] for k in state} - k2 = fn(time + step_size, **filter_kwargs(intermediate_state, fn)) + # Heun slope + k2 = fn(time + step_size, 
**filter_kwargs(y_euler, fn)) - # check all keys are equal - if set(k1.keys()) != set(k2.keys()): - raise ValueError("Keys of the deltas do not match. Please return zero for unchanged variables.") + # error = (h/2) (k2 - k1) + err_state = {k: 0.5 * step_size * (k2[k] - k1[k]) for k in state} - # Heun's (RK2) solution - heun_state = state.copy() - for key in k1.keys(): - heun_state[key] = state[key] + 0.5 * step_size * (k1[key] + k2[key]) - - # Error estimate: difference between Euler and Heun - intermediate_error = keras.ops.stack( - [keras.ops.norm(heun_state[key] - intermediate_state[key], ord=2, axis=-1) for key in k1] - ) - - max_error = keras.ops.max(intermediate_error) - new_step_size = step_size * keras.ops.sqrt(tolerance / (max_error + 1e-9)) + err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()]) + err = keras.ops.max(err_norm) + new_step_size = step_size * keras.ops.clip(0.9 * (tolerance / (err + 1e-12)) ** 0.5, 0.2, 5.0) new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) else: new_step_size = step_size @@ -177,7 +166,7 @@ def tsit5_step( k5 = fn( time + h * c5, **add_scaled( - state, [k1, k2, k3, k4], [4.325279681768730, -11.74888356406283, 7.495539342889836, -0.09249506636175525], h + state, [k1, k2, k3, k4], [5.325864828439257, -11.74888356406283, 7.495539342889836, -0.09249506636175525], h ), ) k6 = fn( @@ -203,26 +192,19 @@ def tsit5_step( ) if use_adaptive_step_size: - # 7th stage evaluation k7 = fn(time + h, **filter_kwargs(new_state, fn)) - # 4th order embedded solution: b_hat coefficients - y4 = {} - for key in state.keys(): - y4[key] = state[key] + h * ( - 0.001780011052226 * k1[key] - + 0.000816434459657 * k2[key] - - 0.007880878010262 * k3[key] - + 0.144711007173263 * k4[key] - - 0.582357165452555 * k5[key] - + 0.458082105929187 * k6[key] - + (1.0 / 66.0) * k7[key] - ) - - # Error estimate err_state = {} for key in state.keys(): - err_state[key] = new_state[key] - y4[key] + err_state[key] = h * ( + -0.00178001105222577714 * k1[key] + - 0.0008164344596567469 * k2[key] + + 0.007880878010261995 * k3[key] + - 0.1447110071732629 * k4[key] + + 0.5823571654525552 * k5[key] + - 0.45808210592918697 * k6[key] + + 0.015151515151515152 * k7[key] + ) err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()]) err = keras.ops.max(err_norm) @@ -230,7 +212,7 @@ def tsit5_step( new_step_size = h * keras.ops.clip(0.9 * (tolerance / (err + 1e-12)) ** 0.2, 0.2, 5.0) new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) else: - new_step_size = h + new_step_size = step_size new_time = time + h return new_state, new_time, new_step_size diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index d328f9476..765032c43 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -64,8 +64,8 @@ def fn(t, x): result_adaptive = integrate( fn, initial_state, start_time=0.0, stop_time=T_final, steps="adaptive", method=method, max_steps=1_000 )["x"] - np.testing.assert_allclose(result, analytical_result, atol=atol, rtol=0.1) + np.testing.assert_allclose(result, analytical_result, atol=atol, rtol=0.1) np.testing.assert_allclose(result_adaptive, analytical_result, atol=atol, rtol=0.01) From c4b52a770c90d33c852911bb3387a1fd0b29e258 Mon Sep 17 00:00:00 2001 From: arrjon Date: Tue, 25 Nov 2025 20:18:34 +0100 Subject: [PATCH 074/101] fix sampler --- bayesflow/utils/integrate.py | 39 +++++++++--------------------- 
tests/test_utils/test_integrate.py | 11 ++++++--- 2 files changed, 19 insertions(+), 31 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 6ae7ed0b8..a79c98b18 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -28,33 +28,17 @@ def euler_step( max_step_size: ArrayLike = float("inf"), use_adaptive_step_size: bool = False, ) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): - k1 = fn(time, **filter_kwargs(state, fn)) - if use_adaptive_step_size: - # Euler step - y_euler = {k: state[k] + step_size * k1[k] for k in state} - - # Heun slope - k2 = fn(time + step_size, **filter_kwargs(y_euler, fn)) - - # error = (h/2) (k2 - k1) - err_state = {k: 0.5 * step_size * (k2[k] - k1[k]) for k in state} - - err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()]) - err = keras.ops.max(err_norm) + raise ValueError("Adaptive step size not supported for Euler method.") - new_step_size = step_size * keras.ops.clip(0.9 * (tolerance / (err + 1e-12)) ** 0.5, 0.2, 5.0) - new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) - else: - new_step_size = step_size + k1 = fn(time, **filter_kwargs(state, fn)) new_state = state.copy() for key in k1.keys(): new_state[key] = state[key] + step_size * k1[key] - new_time = time + step_size - return new_state, new_time, new_step_size + return new_state, new_time, step_size def add_scaled(state, ks, coeffs, h): @@ -224,7 +208,7 @@ def integrate_fixed( start_time: ArrayLike, stop_time: ArrayLike, steps: int, - method: str = "rk45", + method: str, **kwargs, ) -> dict[str, ArrayLike]: if steps <= 0: @@ -263,17 +247,15 @@ def integrate_adaptive( state: dict[str, ArrayLike], start_time: ArrayLike, stop_time: ArrayLike, - min_steps: int = 10, - max_steps: int = 1000, - method: str = "rk45", + min_steps: int, + max_steps: int, + method: str, **kwargs, ) -> dict[str, ArrayLike]: if max_steps <= min_steps: raise ValueError("Maximum number of steps must be greater than minimum number of steps.") match method: - case "euler": - step_fn = euler_step case "rk45": step_fn = rk45_step case "tsit5": @@ -339,7 +321,7 @@ def integrate_scheduled( fn: Callable, state: dict[str, ArrayLike], steps: Tensor | np.ndarray, - method: str = "rk45", + method: str, **kwargs, ) -> dict[str, ArrayLike]: match method: @@ -422,7 +404,7 @@ def integrate( min_steps: int = 10, max_steps: int = 10_000, steps: int | Literal["adaptive"] | Tensor | np.ndarray = 100, - method: str = "euler", + method: str = "rk45", **kwargs, ) -> dict[str, ArrayLike]: if isinstance(steps, str) and steps in ["adaptive", "dynamic"]: @@ -480,6 +462,9 @@ def euler_maruyama_step( new_state: Updated state after one Euler-Maruyama step. new_time: time + dt. 
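    Note on conventions: the surrounding integrators pre-scale standard normal draws by
    sqrt(|dt|), so the `noise` argument here already represents the Brownian increment dW.
    A minimal scalar sketch of the update performed below, with a toy drift and diffusion
    assumed purely for illustration:

        import keras

        def f(t, x):
            # toy drift (assumed for illustration)
            return -x

        def g(t, x):
            # toy additive diffusion (assumed for illustration)
            return 0.5

        x = keras.ops.convert_to_tensor(1.0)
        t, dt = 0.0, 0.01
        dW = keras.random.normal(()) * keras.ops.sqrt(keras.ops.convert_to_tensor(dt))  # increment carries sqrt(dt)
        x_next = x + f(t, x) * dt + g(t, x) * dW  # one Euler-Maruyama update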
""" + if use_adaptive_step_size: + raise ValueError("Adaptive step size not supported for Euler method.") + # Compute drift and diffusion drift = drift_fn(time, **filter_kwargs(state, drift_fn)) diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn)) diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index 765032c43..4f76cc5da 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -61,12 +61,15 @@ def fn(t, x): analytical_result = 1.0 + T_final**2 result = integrate(fn, initial_state, start_time=0.0, stop_time=T_final, steps=num_steps, method=method)["x"] - result_adaptive = integrate( - fn, initial_state, start_time=0.0, stop_time=T_final, steps="adaptive", method=method, max_steps=1_000 - )["x"] + if method == "euler": + result_adaptive = result + else: + result_adaptive = integrate( + fn, initial_state, start_time=0.0, stop_time=T_final, steps="adaptive", method=method, max_steps=1_000 + )["x"] np.testing.assert_allclose(result, analytical_result, atol=atol, rtol=0.1) - np.testing.assert_allclose(result_adaptive, analytical_result, atol=atol, rtol=0.01) + np.testing.assert_allclose(result_adaptive, analytical_result, atol=atol, rtol=0.1) @pytest.mark.parametrize( From ac22af55e96b6c46ec1a398bd9e865bbb92b9163 Mon Sep 17 00:00:00 2001 From: arrjon Date: Thu, 27 Nov 2025 15:04:45 +0100 Subject: [PATCH 075/101] updated stochastic solvers --- bayesflow/utils/integrate.py | 328 +++++++++++++++++++++-------- tests/test_utils/test_integrate.py | 6 +- 2 files changed, 246 insertions(+), 88 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index a79c98b18..bc7c62ae3 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -440,6 +440,7 @@ def euler_maruyama_step( time: ArrayLike, step_size: ArrayLike, noise: dict[str, ArrayLike], + noise_aux: dict[str, ArrayLike] = None, use_adaptive_step_size: bool = False, min_step_size: ArrayLike = None, max_step_size: ArrayLike = None, @@ -454,6 +455,7 @@ def euler_maruyama_step( time: Current time scalar tensor. step_size: Time increment dt. noise: Mapping of variable names to dW noise tensors. + noise_aux: Mapping of variable names to auxiliary noise (not used here). use_adaptive_step_size: Whether to use adaptive step sizing (not used here). min_step_size: Minimum allowed step size (not used here). max_step_size: Maximum allowed step size (not used here). @@ -483,101 +485,244 @@ def euler_maruyama_step( return new_state, time + step_size, step_size +def sea_step( + drift_fn: Callable, + diffusion_fn: Callable, + state: dict[str, ArrayLike], + time: ArrayLike, + step_size: ArrayLike, + noise: dict[str, ArrayLike], + noise_aux: dict[str, ArrayLike] = None, + use_adaptive_step_size: bool = False, + min_step_size: ArrayLike = None, + max_step_size: ArrayLike = None, +) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): + """ + Performs a single shifted Euler step for SDEs with additive noise [1]. + + Compared to Euler-Maruyama, this evaluates the drift at a shifted state, + which improves the local error and the global error constant for additive noise. + + The scheme is + X_{n+1} = X_n + f(t_n, X_n + 0.5 * g(t_n) * ΔW_n) * h + g(t_n) * ΔW_n + + [1] Foster et al., "High order splitting methods for SDEs satisfying a commutativity condition" (2023) + Args: + drift_fn: Function computing the drift term f(t, **state). + diffusion_fn: Function computing the diffusion term g(t, **state). 
+ state: Current state, mapping variable names to tensors. + time: Current time scalar tensor. + step_size: Time increment dt. + noise: Mapping of variable names to dW noise tensors. + noise_aux: Mapping of variable names to auxiliary noise (not used here). + use_adaptive_step_size: Whether to use adaptive step sizing (not used here). + min_step_size: Minimum allowed step size (not used here). + max_step_size: Maximum allowed step size (not used here). + + Returns: + new_state: Updated state after one SEA step. + new_time: time + dt. + """ + if use_adaptive_step_size: + raise ValueError("Adaptive step size not supported for Euler method.") + + # Compute diffusion (assumed additive or weakly state dependent) + diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn)) + + # Check noise keys + if set(diffusion.keys()) != set(noise.keys()): + raise ValueError("Keys of diffusion terms and noise do not match.") + + # Build shifted state: X_shift = X + 0.5 * g * ΔW + shifted_state = {} + for key, x in state.items(): + if key in diffusion: + shifted_state[key] = x + 0.5 * diffusion[key] * noise[key] + else: + shifted_state[key] = x + + # Drift evaluated at shifted state + drift_shifted = drift_fn(time, **filter_kwargs(shifted_state, drift_fn)) + + # Final update + new_state = {} + for key, d in drift_shifted.items(): + base = state[key] + step_size * d + if key in diffusion: + base = base + diffusion[key] * noise[key] + new_state[key] = base + + return new_state, time + step_size, step_size + + def shark_step( drift_fn: Callable, diffusion_fn: Callable, state: Dict[str, ArrayLike], time: ArrayLike, step_size: ArrayLike, - noise: Dict[str, ArrayLike], - min_step_size: ArrayLike, - max_step_size: ArrayLike, + noise: Dict[str, ArrayLike], # w_k = ΔW_k (already scaled by sqrt(|h|)) + noise_aux: Dict[str, ArrayLike], # Z_k ~ N(0,1), used to build H_k use_adaptive_step_size: bool = False, - tolerance: ArrayLike = 1e-3, - half_noises: Optional[Tuple[Dict[str, ArrayLike], Dict[str, ArrayLike]]] = None, - bridge_aux: Optional[Dict[str, ArrayLike]] = None, - validate_split: bool = True, + min_step_size: ArrayLike = -float("inf"), + max_step_size: ArrayLike = float("inf"), + tolerance: float = 1e-3, ) -> Union[Tuple[Dict[str, ArrayLike], ArrayLike], Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike]]: """ - Shifted Additive noise Runge Kutta for additive SDEs. + Shifted Additive noise Runge Kutta (SHARK) for additive SDEs [1]. Makes two evaluations of the drift and diffusion + per step and has a strong order 1.5. + + SHARK method as specified: + + 1) ỹ_k = y_k + g(y_k) H_k + 2) ỹ_{k+5/6} = ỹ_k + (5/6)[ f(ỹ_k) h + g(ỹ_k) W_k ] + 3) y_{k+1} = y_k + + (2/5) f(ỹ_k) h + + (3/5) f(ỹ_{k+5/6}) h + + g(ỹ_k) ( 2/5 W_k + 6/5 H_k ) + + g(ỹ_{k+5/6}) ( 3/5 W_k - 6/5 H_k ) + + with + H_k = 0.5 * |h| * W_k + (|h| ** 1.5) / (2 * sqrt(3)) * Z_k + + [1] Foster et al., "High order splitting methods for SDEs satisfying a commutativity condition" (2023) + + Args: + drift_fn: Function computing the drift term f(t, **state). + diffusion_fn: Function computing the diffusion term g(t, **state). + state: Current state, mapping variable names to tensors. + time: Current time scalar tensor. + step_size: Time increment dt. + noise: Mapping of variable names to dW noise tensors. + noise_aux: Mapping of variable names to auxiliary noise. + use_adaptive_step_size: Whether to use adaptive step sizing (not used here). + min_step_size: Minimum allowed step size (not used here). 
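A minimal NumPy sketch of the shifted-Euler (SEA) update described above, X_{n+1} = X_n + f(t_n, X_n + 0.5*g*ΔW_n)*h + g*ΔW_n, on a scalar additive-noise SDE; the drift, diffusion constant, and step count are illustrative only:

import numpy as np

rng = np.random.default_rng(1)

def f(t, x):        # drift (illustrative)
    return -x

sigma = 0.5         # additive diffusion g(t)
x, t, h = 1.0, 0.0, 1e-2

for _ in range(100):
    dW = rng.normal() * np.sqrt(h)
    x_shift = x + 0.5 * sigma * dW           # drift is evaluated at the shifted state
    x = x + f(t, x_shift) * h + sigma * dW   # SEA update
    t = t + h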
+ max_step_size: Maximum allowed step size (not used here). + tolerance: Tolerance for adaptive step sizing. + + Returns: + new_state: Updated state after one SHARK step. + new_time: time + dt. """ - # direction aware handling h = step_size t = time - h_sign = keras.ops.sign(h) - h_mag = keras.ops.abs(h) - # full step: midpoint drift, diffusion at midpoint time - k1 = drift_fn(t, **filter_kwargs(state, drift_fn)) - mid_state = {k: state[k] + 0.5 * h * k1[k] for k in state} - k2 = drift_fn(t + 0.5 * h, **filter_kwargs(mid_state, drift_fn)) - g_mid = diffusion_fn(t + 0.5 * h, **filter_kwargs(state, diffusion_fn)) + # Magnitude of the time step for stochastic scaling + h_mag = keras.ops.abs(h) + h_sign = keras.ops.sign(h) + sqrt_h_mag = keras.ops.sqrt(h_mag) + inv_sqrt3 = keras.ops.cast(1.0 / np.sqrt(3.0), dtype=keras.ops.dtype(h_mag)) + + # g(y_k) + g0 = diffusion_fn(t, **filter_kwargs(state, diffusion_fn)) + + # Build H_k from w_k and Z_k + H = {} + for k in state.keys(): + if k in g0: + w_k = noise[k] # already scaled by sqrt(|h|) + z_k = noise_aux[k] # standard normal + term1 = 0.5 * h_mag * w_k + term2 = 0.5 * h_mag * sqrt_h_mag * inv_sqrt3 * z_k + H[k] = term1 + term2 + else: + H[k] = keras.ops.zeros_like(state[k]) + + # === 1) shifted initial state === + y_tilde_k = {} + for k in state.keys(): + if k in g0: + y_tilde_k[k] = state[k] + g0[k] * H[k] + else: + y_tilde_k[k] = state[k] + + # === evaluate drift and diffusion at ỹ_k === + f_tilde_k = drift_fn(t, **filter_kwargs(y_tilde_k, drift_fn)) + g_tilde_k = diffusion_fn(t, **filter_kwargs(y_tilde_k, diffusion_fn)) + + # === 2) internal stage at 5/6 === + y_tilde_mid = {} + for k in state.keys(): + drift_part = (5.0 / 6.0) * f_tilde_k[k] * h + if k in g_tilde_k: + sto_part = (5.0 / 6.0) * g_tilde_k[k] * noise[k] + else: + sto_part = keras.ops.zeros_like(state[k]) + y_tilde_mid[k] = y_tilde_k[k] + drift_part + sto_part + + # === evaluate drift and diffusion at ỹ_(k+5/6) === + f_tilde_mid = drift_fn(t + 5.0 / 6.0 * h, **filter_kwargs(y_tilde_mid, drift_fn)) + g_tilde_mid = diffusion_fn(t + 5.0 / 6.0 * h, **filter_kwargs(y_tilde_mid, diffusion_fn)) + + # === 3) final update === + new_state = {} + for k in state.keys(): + # deterministic weights + det = state[k] + (2.0 / 5.0) * f_tilde_k[k] * h + (3.0 / 5.0) * f_tilde_mid[k] * h + + # stochastic parts + sto1 = ( + g_tilde_k[k] * ((2.0 / 5.0) * noise[k] + (6.0 / 5.0) * H[k]) + if k in g_tilde_k + else keras.ops.zeros_like(det) + ) + sto2 = ( + g_tilde_mid[k] * ((3.0 / 5.0) * noise[k] - (6.0 / 5.0) * H[k]) + if k in g_tilde_mid + else keras.ops.zeros_like(det) + ) - det_full = {k: state[k] + h * k2[k] for k in state} - sto_full = {k: g_mid[k] * noise[k] for k in g_mid} - y_full = {k: det_full[k] + sto_full.get(k, keras.ops.zeros_like(det_full[k])) for k in det_full} + new_state[k] = det + sto1 + sto2 if not use_adaptive_step_size: - return y_full, t + h, h - - # prepare two half step noises without drawing randomness here - if half_noises is not None: - dW1, dW2 = half_noises - if set(dW1.keys()) != set(noise.keys()) or set(dW2.keys()) != set(noise.keys()): - raise ValueError("half_noises must have the same keys as noise") - if validate_split: - sum_diff = {k: dW1[k] + dW2[k] - noise[k] for k in noise} - parts = [] - for v in sum_diff.values(): - if not hasattr(v, "shape") or len(v.shape) == 0: - v = keras.ops.reshape(v, (1,)) - parts.append(keras.ops.norm(v, ord=2, axis=-1)) - if float(keras.ops.max(keras.ops.stack(parts))) > 1e-6: - raise ValueError("half_noises do not sum to provided 
noise") + return new_state, t + h, h + + # embedded lower order solution y_low + # here: one stage strong order one method using y_tilde_k + y_low = {} + for k in state.keys(): + det_low = state[k] + f_tilde_k[k] * h + if k in g0: + sto_low = g0[k] * noise[k] + else: + sto_low = keras.ops.zeros_like(det_low) + y_low[k] = det_low + sto_low + + # error estimate as max over components of RMS norm + err_list = [] + for k in state.keys(): + diff = new_state[k] - y_low[k] + sq = keras.ops.square(diff) + mean_sq = keras.ops.mean(sq) + err_k = keras.ops.sqrt(mean_sq) + err_list.append(err_k) + + if len(err_list) == 0: + err = keras.ops.zeros_like(h_mag) else: - if bridge_aux is None: - raise ValueError("Provide either half_noises or bridge_aux when use_adaptive_step_size is True") - if set(bridge_aux.keys()) != set(noise.keys()): - raise ValueError("bridge_aux must have the same keys as noise") - sqrt_h = keras.ops.sqrt(h_mag + 1e-12) # use magnitude - dW1 = {k: 0.5 * noise[k] + 0.5 * sqrt_h * bridge_aux[k] for k in noise} - dW2 = {k: noise[k] - dW1[k] for k in noise} - - half = 0.5 * h - - # first half step - k1h = drift_fn(t, **filter_kwargs(state, drift_fn)) - mid1 = {k: state[k] + 0.5 * half * k1h[k] for k in state} - k2h = drift_fn(t + 0.5 * half, **filter_kwargs(mid1, drift_fn)) - g_q1 = diffusion_fn(t + 0.5 * half, **filter_kwargs(state, diffusion_fn)) - y_half = {k: state[k] + half * k2h[k] + g_q1.get(k, 0) * dW1.get(k, 0) for k in state} - - # second half step - k1h2 = drift_fn(t + half, **filter_kwargs(y_half, drift_fn)) - mid2 = {k: y_half[k] + 0.5 * half * k1h2[k] for k in y_half} - k2h2 = drift_fn(t + 1.5 * half, **filter_kwargs(mid2, drift_fn)) - g_q2 = diffusion_fn(t + 1.5 * half, **filter_kwargs(state, diffusion_fn)) - y_twohalf = {k: y_half[k] + half * k2h2[k] + g_q2.get(k, 0) * dW2.get(k, 0) for k in y_half} - - # error estimate - parts = [] - for k in y_full: - v = y_full[k] - y_twohalf[k] - if not hasattr(v, "shape") or len(v.shape) == 0: - v = keras.ops.reshape(v, (1,)) - parts.append(keras.ops.norm(v, ord=2, axis=-1)) - err = keras.ops.max(keras.ops.stack(parts)) - - # controller for strong order one on additive noise - factor = 0.9 * (tolerance / (err + 1e-12)) ** (2.0 / 3.0) - h_prop = h * keras.ops.clip(factor, 0.2, 5.0) - - # clip by magnitude bounds then restore original sign - mag = keras.ops.abs(h_prop) - mag_new = keras.ops.clip(mag, min_step_size, max_step_size) - h_new = h_sign * mag_new - - return y_full, t + h, h_new + err = err_list[0] + for e_k in err_list[1:]: + err = keras.ops.maximum(err, e_k) + + tiny = keras.ops.cast(1e12, dtype=keras.ops.dtype(h_mag)) + safety = keras.ops.cast(0.9, dtype=keras.ops.dtype(h_mag)) + # effective order between one and one point five + exponent = keras.ops.cast(0.5, dtype=keras.ops.dtype(h_mag)) + + factor = safety * keras.ops.power(tolerance / (err + tiny), exponent) + + # clamp factor + factor_min = keras.ops.cast(0.2, dtype=keras.ops.dtype(h_mag)) + factor_max = keras.ops.cast(5.0, dtype=keras.ops.dtype(h_mag)) + factor = keras.ops.minimum(keras.ops.maximum(factor, factor_min), factor_max) + + new_h_mag = h_mag * factor + new_h_mag = keras.ops.maximum(new_h_mag, min_step_size) + new_h_mag = keras.ops.minimum(new_h_mag, max_step_size) + + new_h = h_sign * new_h_mag + + return new_state, t + h, new_h def _apply_corrector( @@ -627,6 +772,7 @@ def integrate_stochastic_fixed( stop_time: ArrayLike, steps: int, z_history: Dict[str, ArrayLike], + z_extra_history: Dict[str, ArrayLike], corrector_steps: int, score_fn: 
Optional[Callable], step_size_factor: float, @@ -650,12 +796,17 @@ def body_fixed(_i, _loop_state): # Generate noise increment scaled by sqrt(dt) sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt)) _noise_i = {k: z_history[k][_i] * sqrt_dt for k in _current_state.keys()} + if len(z_extra_history) == 0: + _noise_extra_i = None + else: + _noise_extra_i = {k: z_extra_history[k][_i] for k in _current_state.keys()} new_state, new_time, new_step = step_fn( state=_current_state, time=_current_time, step_size=dt, noise=_noise_i, + noise_aux=_noise_extra_i, use_adaptive_step_size=False, ) @@ -684,7 +835,7 @@ def integrate_stochastic_adaptive( max_steps: int, initial_step: ArrayLike, z_history: Dict[str, ArrayLike], - bridge_history: Dict[str, ArrayLike], + z_extra_history: Dict[str, ArrayLike], corrector_steps: int, score_fn: Optional[Callable], step_size_factor: float, @@ -712,14 +863,17 @@ def body_adaptive(_i, _current_state, _current_time, _current_step, _counter): sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt)) _noise_i = {k: z_history[k][_i] * sqrt_dt for k in _current_state.keys()} - _bridge = {k: bridge_history[k][_i] for k in _current_state.keys()} + if len(z_extra_history) == 0: + _noise_extra_i = None + else: + _noise_extra_i = {k: z_extra_history[k][_i] for k in _current_state.keys()} new_state, new_time, new_step = step_fn( state=_current_state, time=_current_time, step_size=dt, noise=_noise_i, - bridge_aux=_bridge, + noise_aux=_noise_extra_i, use_adaptive_step_size=True, ) @@ -808,6 +962,8 @@ def integrate_stochastic( match method: case "euler_maruyama": step_fn_raw = euler_maruyama_step + case "sea": + step_fn_raw = sea_step case "shark": step_fn_raw = shark_step case other: @@ -825,13 +981,12 @@ def integrate_stochastic( # Pre-generate standard normals for the predictor step (up to max_steps) z_history = {} - bridge_history = {} + z_extra_history = {} for key, val in state.items(): shape = keras.ops.shape(val) z_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) - if is_adaptive and method == "shark": - # Only required for SHARK adaptive step (Brownian Bridge aux noise) - bridge_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) + if method == "shark": + z_extra_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) # Pre-generate corrector noise if requested corrector_noise_history = {} @@ -853,7 +1008,7 @@ def integrate_stochastic( max_steps=max_steps, initial_step=initial_step, z_history=z_history, - bridge_history=bridge_history, + z_extra_history=z_extra_history, corrector_steps=corrector_steps, score_fn=score_fn, noise_schedule=noise_schedule, @@ -868,6 +1023,7 @@ def integrate_stochastic( stop_time=stop_time, steps=loop_steps, z_history=z_history, + z_extra_history=z_extra_history, corrector_steps=corrector_steps, score_fn=score_fn, noise_schedule=noise_schedule, diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index 4f76cc5da..160e7f228 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -76,6 +76,7 @@ def fn(t, x): "method,use_adapt", [ ("euler_maruyama", False), + ("sea", False), ("shark", False), ("shark", True), ], @@ -96,7 +97,7 @@ def test_additive_OU_weak_means_and_vars(method, use_adapt): T = 1.0 # batch of trajectories - N = 20000 # large enough to control sampling error + N = 10000 # large enough to control sampling error seed = keras.random.SeedGenerator(42) def 
drift_fn(t, x): @@ -136,6 +137,7 @@ def diffusion_fn(t, x): "method,use_adapt", [ ("euler_maruyama", False), + ("sea", False), ("shark", False), ("shark", True), ], @@ -149,7 +151,7 @@ def test_zero_noise_reduces_to_deterministic(method, use_adapt): x0 = 0.9 T = 1.25 steps = 200 if not use_adapt else "adaptive" - seed = keras.random.SeedGenerator(999) + seed = keras.random.SeedGenerator(0) def drift_fn(t, x): return {"x": a * x} From 44570cfe4dd313776ab01eabc0e66388d151bd72 Mon Sep 17 00:00:00 2001 From: arrjon Date: Thu, 27 Nov 2025 16:11:47 +0100 Subject: [PATCH 076/101] add Langevin --- .../diffusion_model/diffusion_model.py | 7 +- bayesflow/utils/__init__.py | 2 +- bayesflow/utils/integrate.py | 146 ++++++++++++++++-- tests/test_utils/test_integrate.py | 73 +++++++++ 4 files changed, 207 insertions(+), 21 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 659641f51..dfd1c28e4 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -16,6 +16,7 @@ integrate_stochastic, logging, tensor_utils, + STOCHASTIC_METHODS, ) from bayesflow.utils.serialization import serialize, deserialize, serializable @@ -408,7 +409,7 @@ def _forward( integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs - if integrate_kwargs["method"] in ["euler_maruyama", "shark"]: + if integrate_kwargs["method"] in STOCHASTIC_METHODS: raise ValueError("Stochastic methods are not supported for forward integration.") if density: @@ -458,7 +459,7 @@ def _inverse( integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs if density: - if integrate_kwargs["method"] in ["euler_maruyama", "shark"]: + if integrate_kwargs["method"] in STOCHASTIC_METHODS: raise ValueError("Stochastic methods are not supported for density computation.") def deltas(time, xz): @@ -477,7 +478,7 @@ def deltas(time, xz): return x, log_density state = {"xz": z} - if integrate_kwargs["method"] in ["euler_maruyama", "shark"]: + if integrate_kwargs["method"] in STOCHASTIC_METHODS: def deltas(time, xz): return { diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py index a8d28a50a..25b7dd920 100644 --- a/bayesflow/utils/__init__.py +++ b/bayesflow/utils/__init__.py @@ -47,7 +47,7 @@ ) from .hparam_utils import find_batch_size, find_memory_budget -from .integrate import integrate, integrate_stochastic +from .integrate import integrate, integrate_stochastic, DETERMINISTIC_METHODS, STOCHASTIC_METHODS from .io import ( pickle_load, diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index bc7c62ae3..dddae5f9c 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -18,6 +18,10 @@ StateDict = Dict[str, ArrayLike] +DETERMINISTIC_METHODS = ["euler", "rk45", "tsit5"] +STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "langevin"] + + def euler_step( fn: Callable, state: dict[str, ArrayLike], @@ -731,11 +735,14 @@ def _apply_corrector( i: ArrayLike, corrector_steps: int, score_fn: Optional[Callable], - step_size_factor: float, corrector_noise_history: Dict[str, ArrayLike], + step_size_factor: float = 0.01, noise_schedule=None, ) -> StateDict: - """Helper function to apply corrector steps.""" + """Helper function to apply corrector steps [1]. 
+ + [1] Song et al., "Score-Based Generative Modeling through Stochastic Differential Equations" (2020) + """ if corrector_steps <= 0: return new_state @@ -773,10 +780,10 @@ def integrate_stochastic_fixed( steps: int, z_history: Dict[str, ArrayLike], z_extra_history: Dict[str, ArrayLike], - corrector_steps: int, score_fn: Optional[Callable], step_size_factor: float, corrector_noise_history: Dict[str, ArrayLike], + corrector_steps: int = 0, noise_schedule=None, ) -> StateDict: """ @@ -793,7 +800,7 @@ def body_fixed(_i, _loop_state): dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), keras.ops.abs(remaining)) dt = sign * dt_mag - # Generate noise increment scaled by sqrt(dt) + # Generate noise increment sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt)) _noise_i = {k: z_history[k][_i] * sqrt_dt for k in _current_state.keys()} if len(z_extra_history) == 0: @@ -836,10 +843,10 @@ def integrate_stochastic_adaptive( initial_step: ArrayLike, z_history: Dict[str, ArrayLike], z_extra_history: Dict[str, ArrayLike], - corrector_steps: int, score_fn: Optional[Callable], step_size_factor: float, corrector_noise_history: Dict[str, ArrayLike], + corrector_steps: int = 0, noise_schedule=None, ) -> StateDict: """ @@ -896,6 +903,89 @@ def body_adaptive(_i, _current_state, _current_time, _current_step, _counter): return final_state +def integrate_langevin( + state: StateDict, + start_time: ArrayLike, + stop_time: ArrayLike, + steps: int, + z_history: Dict[str, ArrayLike], + score_fn: Callable, + noise_schedule, + corrector_noise_history: Dict[str, ArrayLike], + step_size_factor: float = 0.01, + corrector_steps: int = 0, +) -> StateDict: + """ + Annealed Langevin dynamics using the given score_fn and noise_schedule [1]. + + At each step i with time t_i, performs for every state component k: + state_k <- state_k + e * score_k + sqrt(2 * e) * z + + Times are stepped linearly from start_time to stop_time. 
+ + [1] Song et al., "Generative Modeling by Estimating Gradients of the Data Distribution" (2020) + """ + + if steps <= 0: + raise ValueError("Number of Langevin steps must be positive.") + if score_fn is None or noise_schedule is None: + raise ValueError("score_fn and noise_schedule must be provided.") + # basic shape check + for k, v in state.items(): + if k not in z_history: + raise ValueError(f"Missing noise for key {k!r} in z_history.") + if keras.ops.shape(z_history[k])[0] < steps: + raise ValueError(f"z_history[{k!r}] has fewer than {steps} steps.") + + # Linear time grid + dt = (stop_time - start_time) / float(steps) + effective_factor = step_size_factor * 100 / np.sqrt(steps) + + def body(_i, loop_state): + current_state, current_time = loop_state + t = current_time + + # score at current time + score = score_fn(t, **filter_kwargs(current_state, score_fn)) + + # noise schedule + log_snr_t = noise_schedule.get_log_snr(t=t, training=False) + _, sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + + new_state: StateDict = {} + for k in current_state.keys(): + s_k = score.get(k, None) + if s_k is None: + new_state[k] = current_state[k] + continue + + e = effective_factor * sigma_t**2 + new_state[k] = current_state[k] + e * s_k + keras.ops.sqrt(2.0 * e) * z_history[k][_i] + + new_time = current_time + dt + + new_state = _apply_corrector( + new_state=new_state, + new_time=new_time, + i=_i, + corrector_steps=corrector_steps, + score_fn=score_fn, + noise_schedule=noise_schedule, + step_size_factor=step_size_factor, + corrector_noise_history=corrector_noise_history, + ) + + return new_state, new_time + + final_state, _ = keras.ops.fori_loop( + 0, + steps, + body, + (state, start_time), + ) + return final_state + + def integrate_stochastic( drift_fn: Callable, diffusion_fn: Callable, @@ -910,7 +1000,7 @@ def integrate_stochastic( score_fn: Callable = None, corrector_steps: int = 0, noise_schedule=None, - step_size_factor: float = 0.1, + step_size_factor: float = 0.01, **kwargs, ) -> StateDict: """ @@ -938,6 +1028,7 @@ def integrate_stochastic( Returns: Final state dictionary after integration. 
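As a small illustration of the scaling used in the annealed Langevin loop above, the sketch below computes the per-iteration step size e_i = (step_size_factor * 100 / sqrt(steps)) * sigma(t_i)^2 on a linear time grid; the sigma(t) function and all numbers are assumed for illustration and are not the library's noise schedule:

import numpy as np

steps = 250
step_size_factor = 0.01
start_time, stop_time = 1.0, 0.0

def sigma(t):                   # assumed toy schedule for illustration
    return 0.1 + 0.9 * t

effective_factor = step_size_factor * 100 / np.sqrt(steps)
times = np.linspace(start_time, stop_time, steps, endpoint=False)
e = effective_factor * sigma(times) ** 2   # Langevin step size per iteration
# the step size tracks sigma(t)^2 along the time grid
print(e[0], e[-1])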
""" is_adaptive = isinstance(steps, str) and steps in ["adaptive", "dynamic"] + if is_adaptive: if start_time is None or stop_time is None: raise ValueError("Please provide start_time and stop_time for adaptive integration.") @@ -959,6 +1050,17 @@ def integrate_stochastic( # For fixed step, min/max step size are just the fixed step size min_step_size, max_step_size = initial_step, initial_step + # Pre-generate corrector noise if requested + corrector_noise_history = {} + if corrector_steps > 0: + if score_fn is None or noise_schedule is None: + raise ValueError("Please provide both score_fn and noise_schedule when using corrector_steps > 0.") + for key, val in state.items(): + shape = keras.ops.shape(val) + corrector_noise_history[key] = keras.random.normal( + (loop_steps, corrector_steps, *shape), dtype=keras.ops.dtype(val), seed=seed + ) + match method: case "euler_maruyama": step_fn_raw = euler_maruyama_step @@ -966,6 +1068,27 @@ def integrate_stochastic( step_fn_raw = sea_step case "shark": step_fn_raw = shark_step + case "langevin": + if is_adaptive: + raise ValueError("Langevin sampling does not support adaptive steps.") + + z_history = {} + for key, val in state.items(): + shape = keras.ops.shape(val) + z_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) + + return integrate_langevin( + state=state, + start_time=start_time, + stop_time=stop_time, + steps=loop_steps, + z_history=z_history, + score_fn=score_fn, + noise_schedule=noise_schedule, + step_size_factor=step_size_factor, + corrector_steps=corrector_steps, + corrector_noise_history=corrector_noise_history, + ) case other: raise TypeError(f"Invalid integration method: {other!r}") @@ -988,17 +1111,6 @@ def integrate_stochastic( if method == "shark": z_extra_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) - # Pre-generate corrector noise if requested - corrector_noise_history = {} - if corrector_steps > 0: - if score_fn is None or noise_schedule is None: - raise ValueError("Please provide both score_fn and noise_schedule when using corrector_steps > 0.") - for key, val in state.items(): - shape = keras.ops.shape(val) - corrector_noise_history[key] = keras.random.normal( - (loop_steps, corrector_steps, *shape), dtype=keras.ops.dtype(val), seed=seed - ) - if is_adaptive: return integrate_stochastic_adaptive( step_fn=step_fn, diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index 160e7f228..c846679cf 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -175,3 +175,76 @@ def diffusion_fn(t, x): exact = x0 * np.exp(a * T) np.testing.assert_allclose(np.array(out).mean(), exact, atol=TOL_DET, rtol=0.1) + + +@pytest.mark.parametrize("steps", [500]) +def test_langevin_gaussian_sampling(steps): + """ + Test annealed Langevin dynamics on a 1D Gaussian target. + + Target distribution: N(mu, sigma^2), with score + ∇_x log p(x) = -(x - mu) / sigma^2 + + We verify that the empirical mean and variance after Langevin sampling + match the target within a loose tolerance (to allow for Monte Carlo noise). 
+ """ + # target parameters + mu = 0.3 + sigma = 0.7 + + # number of particles + N = 20000 + start_time = 0.0 + stop_time = 1.0 + + # tolerances for mean and variance + tol_mean = 5e-2 + tol_var = 5e-2 + + # initial state: broad Gaussian, independent of target + seed = keras.random.SeedGenerator(42) + x0 = keras.random.normal((N,), dtype="float32", seed=seed) + initial_state = {"x": x0} + + # simple dummy noise schedule: constant alpha + class DummyNoiseSchedule: + def get_log_snr(self, t, training=False): + return keras.ops.zeros_like(t) + + def get_alpha_sigma(self, log_snr_t): + alpha_t = keras.ops.ones_like(log_snr_t) + sigma_t = keras.ops.ones_like(log_snr_t) + return alpha_t, sigma_t + + noise_schedule = DummyNoiseSchedule() + + # score of the target Gaussian + def score_fn(t, x): + s = -(x - mu) / (sigma**2) + return {"x": s} + + # run Langevin + final_state = integrate_stochastic( + drift_fn=None, + diffusion_fn=None, + score_fn=score_fn, + noise_schedule=noise_schedule, + state=initial_state, + start_time=start_time, + stop_time=stop_time, + steps=steps, + seed=seed, + method="langevin", + max_steps=1_000, + corrector_steps=1, + ) + + xT = np.array(final_state["x"]) + emp_mean = float(xT.mean()) + emp_var = float(xT.var()) + + exp_mean = mu + exp_var = sigma**2 + + np.testing.assert_allclose(emp_mean, exp_mean, atol=tol_mean, rtol=0.0) + np.testing.assert_allclose(emp_var, exp_var, atol=tol_var, rtol=0.0) From 531c6109afc58b70f4be9559d03b2bd241b84a37 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 28 Nov 2025 08:55:32 +0100 Subject: [PATCH 077/101] add Langevin --- .../diffusion_model/diffusion_model.py | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index dfd1c28e4..0e38ea4f1 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -489,18 +489,17 @@ def diffusion(time, xz): return {"xz": self.diffusion_term(xz, time=time, training=training)} score_fn = None - if "corrector_steps" in integrate_kwargs: - if integrate_kwargs["corrector_steps"] > 0: - - def score_fn(time, xz): - return { - "xz": self.score( - xz, - time=time, - conditions=conditions, - training=training, - ) - } + if "corrector_steps" in integrate_kwargs or integrate_kwargs.get("method") == "langevin": + + def score_fn(time, xz): + return { + "xz": self.score( + xz, + time=time, + conditions=conditions, + training=training, + ) + } state = integrate_stochastic( drift_fn=deltas, From fdeeb2f291c54c14b157bf4b20a7e99701c534a7 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 28 Nov 2025 09:32:21 +0100 Subject: [PATCH 078/101] add adaptive step size --- bayesflow/utils/integrate.py | 228 ++++++++++++++++++----------- tests/test_utils/test_integrate.py | 4 + 2 files changed, 146 insertions(+), 86 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index dddae5f9c..83fc6b74c 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -27,10 +27,8 @@ def euler_step( state: dict[str, ArrayLike], time: ArrayLike, step_size: ArrayLike, - tolerance: ArrayLike = 1e-6, - min_step_size: ArrayLike = -float("inf"), - max_step_size: ArrayLike = float("inf"), use_adaptive_step_size: bool = False, + **kwargs, ) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): if use_adaptive_step_size: raise ValueError("Adaptive step size not supported for Euler method.") @@ 
-437,6 +435,38 @@ def integrate( raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})") +def adaptive_step_size_controller(state, drift, adaptive_factor, min_step_size, max_step_size): + """ + Adaptive step size controller based on [1]. + + Adaptive step sizing uses: + h = max(1, ||x||**2) / max(1, ||f(x)||**2) * adaptive_factor + + + [1] Fang & Giles, Adaptive Euler-Maruyama Method for SDEs with Non-Globally Lipschitz Drift Coefficients (2020) + + Returns + ------- + New step size. + """ + state_norms = [] + drift_norms = [] + for key in state.keys(): + state_norms.append(keras.ops.norm(state[key], ord=2, axis=-1)) + drift_norms.append(keras.ops.norm(drift[key], ord=2, axis=-1)) + state_norm = keras.ops.stack(state_norms) + drift_norm = keras.ops.stack(drift_norms) + max_state_norm = keras.ops.maximum( + keras.ops.cast(1.0, dtype=keras.ops.dtype(state_norm)), keras.ops.max(state_norm) ** 2 + ) + max_drift_norm = keras.ops.maximum( + keras.ops.cast(1.0, dtype=keras.ops.dtype(drift_norm)), keras.ops.max(drift_norm) ** 2 + ) + new_step_size = max_state_norm / max_drift_norm * adaptive_factor + new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) + return new_step_size + + def euler_maruyama_step( drift_fn: Callable, diffusion_fn: Callable, @@ -444,10 +474,11 @@ def euler_maruyama_step( time: ArrayLike, step_size: ArrayLike, noise: dict[str, ArrayLike], - noise_aux: dict[str, ArrayLike] = None, use_adaptive_step_size: bool = False, - min_step_size: ArrayLike = None, - max_step_size: ArrayLike = None, + min_step_size: float = -float("inf"), + max_step_size: float = float("inf"), + adaptive_factor: float = 1.0, + **kwargs, ) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): """ Performs a single Euler-Maruyama step for stochastic differential equations. @@ -459,18 +490,15 @@ def euler_maruyama_step( time: Current time scalar tensor. step_size: Time increment dt. noise: Mapping of variable names to dW noise tensors. - noise_aux: Mapping of variable names to auxiliary noise (not used here). - use_adaptive_step_size: Whether to use adaptive step sizing (not used here). - min_step_size: Minimum allowed step size (not used here). - max_step_size: Maximum allowed step size (not used here). + use_adaptive_step_size: Whether to use adaptive step sizing. + min_step_size: Minimum allowed step size. + max_step_size: Maximum allowed step size. + adaptive_factor: Factor to compute adaptive step size (0 < step_size_factor < 1). Returns: new_state: Updated state after one Euler-Maruyama step. new_time: time + dt. 
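A scalar NumPy rendition of the step-size controller introduced here, h = max(1, ||x||^2) / max(1, ||f(x)||^2) * adaptive_factor; the state and drift values below are made up to show the step shrinking when the drift norm dominates the state norm:

import numpy as np

def adaptive_step(state, drift, adaptive_factor=0.01, min_h=1e-4, max_h=1.0):
    # largest L2 norm over the last axis, across all state / drift entries
    state_norm = max(np.linalg.norm(v, axis=-1).max() for v in state.values())
    drift_norm = max(np.linalg.norm(v, axis=-1).max() for v in drift.values())
    h = max(1.0, state_norm**2) / max(1.0, drift_norm**2) * adaptive_factor
    return float(np.clip(h, min_h, max_h))

state = {"x": np.array([[0.5, -0.2]])}
drift = {"x": np.array([[8.0, 3.0]])}   # large drift -> small step
print(adaptive_step(state, drift))      # ~ adaptive_factor / ||f||^2 here, since ||x|| < 1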
""" - if use_adaptive_step_size: - raise ValueError("Adaptive step size not supported for Euler method.") - # Compute drift and diffusion drift = drift_fn(time, **filter_kwargs(state, drift_fn)) diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn)) @@ -486,7 +514,17 @@ def euler_maruyama_step( base = base + diffusion[key] * noise[key] new_state[key] = base - return new_state, time + step_size, step_size + new_step_size = step_size + if use_adaptive_step_size: + new_step_size = adaptive_step_size_controller( + state=state, + drift=drift, + adaptive_factor=adaptive_factor, + min_step_size=min_step_size, + max_step_size=max_step_size, + ) + + return new_state, time + step_size, new_step_size def sea_step( @@ -496,10 +534,11 @@ def sea_step( time: ArrayLike, step_size: ArrayLike, noise: dict[str, ArrayLike], - noise_aux: dict[str, ArrayLike] = None, use_adaptive_step_size: bool = False, - min_step_size: ArrayLike = None, - max_step_size: ArrayLike = None, + min_step_size: ArrayLike = -float("inf"), + max_step_size: ArrayLike = float("inf"), + adaptive_factor: ArrayLike = 1.0, + **kwargs, ) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): """ Performs a single shifted Euler step for SDEs with additive noise [1]. @@ -518,18 +557,15 @@ def sea_step( time: Current time scalar tensor. step_size: Time increment dt. noise: Mapping of variable names to dW noise tensors. - noise_aux: Mapping of variable names to auxiliary noise (not used here). - use_adaptive_step_size: Whether to use adaptive step sizing (not used here). - min_step_size: Minimum allowed step size (not used here). - max_step_size: Maximum allowed step size (not used here). + use_adaptive_step_size: Whether to use adaptive step sizing. + min_step_size: Minimum allowed step size. + max_step_size: Maximum allowed step size. + adaptive_factor: Factor to compute adaptive step size (0 < step_size_factor < 1). Returns: new_state: Updated state after one SEA step. new_time: time + dt. """ - if use_adaptive_step_size: - raise ValueError("Adaptive step size not supported for Euler method.") - # Compute diffusion (assumed additive or weakly state dependent) diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn)) @@ -556,7 +592,17 @@ def sea_step( base = base + diffusion[key] * noise[key] new_state[key] = base - return new_state, time + step_size, step_size + new_step_size = step_size + if use_adaptive_step_size: + new_step_size = adaptive_step_size_controller( + state=state, + drift=drift_shifted, + adaptive_factor=adaptive_factor, + min_step_size=min_step_size, + max_step_size=max_step_size, + ) + + return new_state, time + step_size, new_step_size def shark_step( @@ -570,7 +616,7 @@ def shark_step( use_adaptive_step_size: bool = False, min_step_size: ArrayLike = -float("inf"), max_step_size: ArrayLike = float("inf"), - tolerance: float = 1e-3, + adaptive_factor: ArrayLike = 1.0, ) -> Union[Tuple[Dict[str, ArrayLike], ArrayLike], Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike]]: """ Shifted Additive noise Runge Kutta (SHARK) for additive SDEs [1]. Makes two evaluations of the drift and diffusion @@ -599,10 +645,10 @@ def shark_step( step_size: Time increment dt. noise: Mapping of variable names to dW noise tensors. noise_aux: Mapping of variable names to auxiliary noise. - use_adaptive_step_size: Whether to use adaptive step sizing (not used here). - min_step_size: Minimum allowed step size (not used here). - max_step_size: Maximum allowed step size (not used here). - tolerance: Tolerance for adaptive step sizing. 
+ use_adaptive_step_size: Whether to use adaptive step sizing. + min_step_size: Minimum allowed step size. + max_step_size: Maximum allowed step size. + adaptive_factor: Factor to compute adaptive step size (0 < step_size_factor < 1). Returns: new_state: Updated state after one SHARK step. @@ -613,7 +659,7 @@ def shark_step( # Magnitude of the time step for stochastic scaling h_mag = keras.ops.abs(h) - h_sign = keras.ops.sign(h) + # h_sign = keras.ops.sign(h) sqrt_h_mag = keras.ops.sqrt(h_mag) inv_sqrt3 = keras.ops.cast(1.0 / np.sqrt(3.0), dtype=keras.ops.dtype(h_mag)) @@ -678,55 +724,67 @@ def shark_step( new_state[k] = det + sto1 + sto2 - if not use_adaptive_step_size: - return new_state, t + h, h + # if not use_adaptive_step_size: + # return new_state, t + h, h - # embedded lower order solution y_low - # here: one stage strong order one method using y_tilde_k - y_low = {} - for k in state.keys(): - det_low = state[k] + f_tilde_k[k] * h - if k in g0: - sto_low = g0[k] * noise[k] - else: - sto_low = keras.ops.zeros_like(det_low) - y_low[k] = det_low + sto_low - - # error estimate as max over components of RMS norm - err_list = [] - for k in state.keys(): - diff = new_state[k] - y_low[k] - sq = keras.ops.square(diff) - mean_sq = keras.ops.mean(sq) - err_k = keras.ops.sqrt(mean_sq) - err_list.append(err_k) - - if len(err_list) == 0: - err = keras.ops.zeros_like(h_mag) - else: - err = err_list[0] - for e_k in err_list[1:]: - err = keras.ops.maximum(err, e_k) - - tiny = keras.ops.cast(1e12, dtype=keras.ops.dtype(h_mag)) - safety = keras.ops.cast(0.9, dtype=keras.ops.dtype(h_mag)) - # effective order between one and one point five - exponent = keras.ops.cast(0.5, dtype=keras.ops.dtype(h_mag)) - - factor = safety * keras.ops.power(tolerance / (err + tiny), exponent) - - # clamp factor - factor_min = keras.ops.cast(0.2, dtype=keras.ops.dtype(h_mag)) - factor_max = keras.ops.cast(5.0, dtype=keras.ops.dtype(h_mag)) - factor = keras.ops.minimum(keras.ops.maximum(factor, factor_min), factor_max) - - new_h_mag = h_mag * factor - new_h_mag = keras.ops.maximum(new_h_mag, min_step_size) - new_h_mag = keras.ops.minimum(new_h_mag, max_step_size) - - new_h = h_sign * new_h_mag + new_step_size = h + if use_adaptive_step_size: + new_step_size = adaptive_step_size_controller( + state=state, + drift=f_tilde_k, + adaptive_factor=adaptive_factor, + min_step_size=min_step_size, + max_step_size=max_step_size, + ) - return new_state, t + h, new_h + return new_state, t + h, new_step_size + + # # embedded lower order solution y_low + # # here: one stage strong order one method using y_tilde_k + # y_low = {} + # for k in state.keys(): + # det_low = state[k] + f_tilde_k[k] * h + # if k in g0: + # sto_low = g0[k] * noise[k] + # else: + # sto_low = keras.ops.zeros_like(det_low) + # y_low[k] = det_low + sto_low + # + # # error estimate as max over components of RMS norm + # err_list = [] + # for k in state.keys(): + # diff = new_state[k] - y_low[k] + # sq = keras.ops.square(diff) + # mean_sq = keras.ops.mean(sq) + # err_k = keras.ops.sqrt(mean_sq) + # err_list.append(err_k) + # + # if len(err_list) == 0: + # err = keras.ops.zeros_like(h_mag) + # else: + # err = err_list[0] + # for e_k in err_list[1:]: + # err = keras.ops.maximum(err, e_k) + # + # tiny = keras.ops.cast(1e12, dtype=keras.ops.dtype(h_mag)) + # safety = keras.ops.cast(0.9, dtype=keras.ops.dtype(h_mag)) + # # effective order between one and one point five + # exponent = keras.ops.cast(0.5, dtype=keras.ops.dtype(h_mag)) + # + # factor = safety * 
keras.ops.power(tolerance / (err + tiny), exponent) + # + # # clamp factor + # factor_min = keras.ops.cast(0.2, dtype=keras.ops.dtype(h_mag)) + # factor_max = keras.ops.cast(5.0, dtype=keras.ops.dtype(h_mag)) + # factor = keras.ops.minimum(keras.ops.maximum(factor, factor_min), factor_max) + # + # new_h_mag = h_mag * factor + # new_h_mag = keras.ops.maximum(new_h_mag, min_step_size) + # new_h_mag = keras.ops.minimum(new_h_mag, max_step_size) + # + # new_h = h_sign * new_h_mag + # + # return new_state, t + h, new_h def _apply_corrector( @@ -736,7 +794,7 @@ def _apply_corrector( corrector_steps: int, score_fn: Optional[Callable], corrector_noise_history: Dict[str, ArrayLike], - step_size_factor: float = 0.01, + step_size_factor: ArrayLike = 0.01, noise_schedule=None, ) -> StateDict: """Helper function to apply corrector steps [1]. @@ -764,7 +822,7 @@ def _apply_corrector( score_norm = keras.ops.norm(score[k], axis=-1, keepdims=True) score_norm = keras.ops.maximum(score_norm, 1e-8) - # Compute step size 'e' for the Langevin update + # Compute step size for the Langevin update e = 2.0 * alpha_t * (step_size_factor * z_norm / score_norm) ** 2 # Annealed Langevin Dynamics update @@ -781,7 +839,7 @@ def integrate_stochastic_fixed( z_history: Dict[str, ArrayLike], z_extra_history: Dict[str, ArrayLike], score_fn: Optional[Callable], - step_size_factor: float, + step_size_factor: ArrayLike, corrector_noise_history: Dict[str, ArrayLike], corrector_steps: int = 0, noise_schedule=None, @@ -844,7 +902,7 @@ def integrate_stochastic_adaptive( z_history: Dict[str, ArrayLike], z_extra_history: Dict[str, ArrayLike], score_fn: Optional[Callable], - step_size_factor: float, + step_size_factor: ArrayLike, corrector_noise_history: Dict[str, ArrayLike], corrector_steps: int = 0, noise_schedule=None, @@ -912,7 +970,7 @@ def integrate_langevin( score_fn: Callable, noise_schedule, corrector_noise_history: Dict[str, ArrayLike], - step_size_factor: float = 0.01, + step_size_factor: ArrayLike = 0.01, corrector_steps: int = 0, ) -> StateDict: """ @@ -1000,7 +1058,7 @@ def integrate_stochastic( score_fn: Callable = None, corrector_steps: int = 0, noise_schedule=None, - step_size_factor: float = 0.01, + step_size_factor: ArrayLike = 0.01, **kwargs, ) -> StateDict: """ @@ -1034,8 +1092,6 @@ def integrate_stochastic( raise ValueError("Please provide start_time and stop_time for adaptive integration.") if min_steps <= 0 or max_steps <= 0 or max_steps < min_steps: raise ValueError("min_steps and max_steps must be positive, and max_steps >= min_steps.") - if method != "shark": - raise ValueError("Adaptive step size is only supported for the 'shark' method.") loop_steps = max_steps initial_step = (stop_time - start_time) / float(min_steps) diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index c846679cf..ba286d857 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -76,7 +76,9 @@ def fn(t, x): "method,use_adapt", [ ("euler_maruyama", False), + ("euler_maruyama", True), ("sea", False), + ("sea", True), ("shark", False), ("shark", True), ], @@ -137,7 +139,9 @@ def diffusion_fn(t, x): "method,use_adapt", [ ("euler_maruyama", False), + ("euler_maruyama", True), ("sea", False), + ("sea", True), ("shark", False), ("shark", True), ], From f45c2dcf69991b0e4c63cd33251e8df1e7f31e69 Mon Sep 17 00:00:00 2001 From: arrjon Date: Sat, 29 Nov 2025 11:15:06 +0100 Subject: [PATCH 079/101] tune adaptive step size --- bayesflow/utils/integrate.py | 70 
+++++++----------------------------- 1 file changed, 13 insertions(+), 57 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 83fc6b74c..984fd3267 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -435,7 +435,13 @@ def integrate( raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})") -def adaptive_step_size_controller(state, drift, adaptive_factor, min_step_size, max_step_size): +def adaptive_step_size_controller( + state, + drift, + adaptive_factor: ArrayLike = 1.0, + min_step_size: float = -float("inf"), + max_step_size: float = float("inf"), +) -> ArrayLike: """ Adaptive step size controller based on [1]. @@ -477,7 +483,7 @@ def euler_maruyama_step( use_adaptive_step_size: bool = False, min_step_size: float = -float("inf"), max_step_size: float = float("inf"), - adaptive_factor: float = 1.0, + adaptive_factor: float = 0.1, **kwargs, ) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): """ @@ -493,7 +499,7 @@ def euler_maruyama_step( use_adaptive_step_size: Whether to use adaptive step sizing. min_step_size: Minimum allowed step size. max_step_size: Maximum allowed step size. - adaptive_factor: Factor to compute adaptive step size (0 < step_size_factor < 1). + adaptive_factor: Factor to compute adaptive step size (0 < adaptive_factor < 1). Returns: new_state: Updated state after one Euler-Maruyama step. @@ -537,7 +543,7 @@ def sea_step( use_adaptive_step_size: bool = False, min_step_size: ArrayLike = -float("inf"), max_step_size: ArrayLike = float("inf"), - adaptive_factor: ArrayLike = 1.0, + adaptive_factor: ArrayLike = 0.1, **kwargs, ) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): """ @@ -560,7 +566,7 @@ def sea_step( use_adaptive_step_size: Whether to use adaptive step sizing. min_step_size: Minimum allowed step size. max_step_size: Maximum allowed step size. - adaptive_factor: Factor to compute adaptive step size (0 < step_size_factor < 1). + adaptive_factor: Factor to compute adaptive step size (0 < adaptive_factor < 1). Returns: new_state: Updated state after one SEA step. @@ -616,7 +622,7 @@ def shark_step( use_adaptive_step_size: bool = False, min_step_size: ArrayLike = -float("inf"), max_step_size: ArrayLike = float("inf"), - adaptive_factor: ArrayLike = 1.0, + adaptive_factor: ArrayLike = 0.1, ) -> Union[Tuple[Dict[str, ArrayLike], ArrayLike], Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike]]: """ Shifted Additive noise Runge Kutta (SHARK) for additive SDEs [1]. Makes two evaluations of the drift and diffusion @@ -648,7 +654,7 @@ def shark_step( use_adaptive_step_size: Whether to use adaptive step sizing. min_step_size: Minimum allowed step size. max_step_size: Maximum allowed step size. - adaptive_factor: Factor to compute adaptive step size (0 < step_size_factor < 1). + adaptive_factor: Factor to compute adaptive step size (0 < adaptive_factor < 1). Returns: new_state: Updated state after one SHARK step. 
@@ -724,9 +730,6 @@ def shark_step( new_state[k] = det + sto1 + sto2 - # if not use_adaptive_step_size: - # return new_state, t + h, h - new_step_size = h if use_adaptive_step_size: new_step_size = adaptive_step_size_controller( @@ -739,53 +742,6 @@ def shark_step( return new_state, t + h, new_step_size - # # embedded lower order solution y_low - # # here: one stage strong order one method using y_tilde_k - # y_low = {} - # for k in state.keys(): - # det_low = state[k] + f_tilde_k[k] * h - # if k in g0: - # sto_low = g0[k] * noise[k] - # else: - # sto_low = keras.ops.zeros_like(det_low) - # y_low[k] = det_low + sto_low - # - # # error estimate as max over components of RMS norm - # err_list = [] - # for k in state.keys(): - # diff = new_state[k] - y_low[k] - # sq = keras.ops.square(diff) - # mean_sq = keras.ops.mean(sq) - # err_k = keras.ops.sqrt(mean_sq) - # err_list.append(err_k) - # - # if len(err_list) == 0: - # err = keras.ops.zeros_like(h_mag) - # else: - # err = err_list[0] - # for e_k in err_list[1:]: - # err = keras.ops.maximum(err, e_k) - # - # tiny = keras.ops.cast(1e12, dtype=keras.ops.dtype(h_mag)) - # safety = keras.ops.cast(0.9, dtype=keras.ops.dtype(h_mag)) - # # effective order between one and one point five - # exponent = keras.ops.cast(0.5, dtype=keras.ops.dtype(h_mag)) - # - # factor = safety * keras.ops.power(tolerance / (err + tiny), exponent) - # - # # clamp factor - # factor_min = keras.ops.cast(0.2, dtype=keras.ops.dtype(h_mag)) - # factor_max = keras.ops.cast(5.0, dtype=keras.ops.dtype(h_mag)) - # factor = keras.ops.minimum(keras.ops.maximum(factor, factor_min), factor_max) - # - # new_h_mag = h_mag * factor - # new_h_mag = keras.ops.maximum(new_h_mag, min_step_size) - # new_h_mag = keras.ops.minimum(new_h_mag, max_step_size) - # - # new_h = h_sign * new_h_mag - # - # return new_state, t + h, new_h - def _apply_corrector( new_state: StateDict, From 9fd77074325a9f50a74e990ed19b8a8aa76dc08b Mon Sep 17 00:00:00 2001 From: arrjon Date: Sun, 30 Nov 2025 17:18:40 +0100 Subject: [PATCH 080/101] add Gotta Go Fast SDE sampler --- bayesflow/utils/integrate.py | 327 ++++++++++++++++++++--------- tests/test_utils/test_integrate.py | 109 ++++++++-- 2 files changed, 315 insertions(+), 121 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 984fd3267..c9f951982 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -19,7 +19,7 @@ DETERMINISTIC_METHODS = ["euler", "rk45", "tsit5"] -STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "langevin"] +STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "langevin", "fast_adaptive"] def euler_step( @@ -27,12 +27,8 @@ def euler_step( state: dict[str, ArrayLike], time: ArrayLike, step_size: ArrayLike, - use_adaptive_step_size: bool = False, **kwargs, ) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): - if use_adaptive_step_size: - raise ValueError("Adaptive step size not supported for Euler method.") - k1 = fn(time, **filter_kwargs(state, fn)) new_state = state.copy() @@ -82,10 +78,6 @@ def rk45_step( **add_scaled(state, [k1, k2, k3, k4, k5], [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], h), ) - # check all keys are equal - if not all(set(k.keys()) == set(k1.keys()) for k in [k2, k3, k4, k5, k6]): - raise ValueError("Keys of the deltas do not match. 
Please return zero for unchanged variables.") - # 5th order solution new_state = {} for key in k1.keys(): @@ -262,6 +254,8 @@ def integrate_adaptive( step_fn = rk45_step case "tsit5": step_fn = tsit5_step + case "euler": + raise ValueError("Adaptive step sizing is not supported for the 'euler' method.") case str() as name: raise ValueError(f"Unknown integration method name: {name!r}") case other: @@ -438,12 +432,12 @@ def integrate( def adaptive_step_size_controller( state, drift, - adaptive_factor: ArrayLike = 1.0, + adaptive_factor: ArrayLike, min_step_size: float = -float("inf"), max_step_size: float = float("inf"), ) -> ArrayLike: """ - Adaptive step size controller based on [1]. + Adaptive step size controller based on [1]. Similar to a tamed explicit Euler method when used in Euler-Maruyama. Adaptive step sizing uses: h = max(1, ||x||**2) / max(1, ||f(x)||**2) * adaptive_factor @@ -483,9 +477,12 @@ def euler_maruyama_step( use_adaptive_step_size: bool = False, min_step_size: float = -float("inf"), max_step_size: float = float("inf"), - adaptive_factor: float = 0.1, + adaptive_factor: float = 0.01, **kwargs, -) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): +) -> Union[ + Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike], + Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike, Dict[str, ArrayLike]], +]: """ Performs a single Euler-Maruyama step for stochastic differential equations. @@ -509,19 +506,9 @@ def euler_maruyama_step( drift = drift_fn(time, **filter_kwargs(state, drift_fn)) diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn)) - # Check noise keys - if set(diffusion.keys()) != set(noise.keys()): - raise ValueError("Keys of diffusion terms and noise do not match.") - - new_state = {} - for key, d in drift.items(): - base = state[key] + step_size * d - if key in diffusion: # stochastic update - base = base + diffusion[key] * noise[key] - new_state[key] = base - new_step_size = step_size if use_adaptive_step_size: + sign_step = keras.ops.sign(step_size) new_step_size = adaptive_step_size_controller( state=state, drift=drift, @@ -529,8 +516,168 @@ def euler_maruyama_step( min_step_size=min_step_size, max_step_size=max_step_size, ) + new_step_size = sign_step * keras.ops.abs(new_step_size) + + sqrt_step_size = keras.ops.sqrt(keras.ops.abs(new_step_size)) + + new_state = {} + for key, d in drift.items(): + base = state[key] + new_step_size * d + if key in diffusion: + base = base + diffusion[key] * sqrt_step_size * noise[key] + new_state[key] = base + + if use_adaptive_step_size: + return new_state, time + new_step_size, new_step_size, state + return new_state, time + new_step_size, new_step_size + + +def fast_adaptive_step( + drift_fn: Callable, + diffusion_fn: Callable, + state: dict[str, ArrayLike], + time: ArrayLike, + step_size: ArrayLike, + noise: dict[str, ArrayLike], + last_state: dict[str, ArrayLike] = None, + use_adaptive_step_size: bool = True, + min_step_size: float = -float("inf"), + max_step_size: float = float("inf"), + e_abs: float = 0.01, + e_rel: float = 0.01, + r: float = 0.9, + adapt_safety: float = 0.9, + **kwargs, +) -> Union[ + Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike], + Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike, Dict[str, ArrayLike]], +]: + """ + Performs a single adaptive step for stochastic differential equations based on [1]. + + Based on + + This method uses a predictor-corrector approach with error estimation: + 1. Take an Euler-Maruyama step (predictor) + 2. Take another Euler-Maruyama step from the predicted state + 3. 
Average the two predictions (corrector) + 4. Estimate error and adapt step size + + When step_size reaches min_step_size, steps are always accepted regardless of + error to ensure progress and termination within max_steps. + + [1] Jolicoeur-Martineau et al. (2021) "Gotta Go Fast When Generating Data with Score-Based Models" + + Args: + drift_fn: Function computing the drift term f(t, **state). + diffusion_fn: Function computing the diffusion term g(t, **state). + state: Current state, mapping variable names to tensors. + time: Current time scalar tensor. + step_size: Time increment dt. + noise: Mapping of variable names to dW noise tensors (pre-scaled by sqrt(dt)). + last_state: Previous state for error estimation. + use_adaptive_step_size: Whether to adapt step size. + min_step_size: Minimum allowed step size. + max_step_size: Maximum allowed step size. + e_abs: Absolute error tolerance. + e_rel: Relative error tolerance. + r: Order of the method for step size adaptation. + adapt_safety: Safety factor for step size adaptation. + **kwargs: Additional arguments passed to drift_fn and diffusion_fn. + + Returns: + new_state: Updated state after one adaptive step. + new_time: time + dt (or time if step rejected). + new_step_size: Adapted step size for next iteration. + """ + state_euler, time_mid, _ = euler_maruyama_step( + drift_fn=drift_fn, + diffusion_fn=diffusion_fn, + state=state, + time=time, + step_size=step_size, + min_step_size=min_step_size, + max_step_size=max_step_size, + noise=noise, + use_adaptive_step_size=False, + ) + + # Compute drift and diffusion at new state, but update from old state + drift_mid = drift_fn(time_mid, **filter_kwargs(state_euler, drift_fn)) + diffusion_mid = diffusion_fn(time_mid, **filter_kwargs(state_euler, diffusion_fn)) + sqrt_step_size = keras.ops.sqrt(keras.ops.abs(step_size)) + + state_euler_mid = {} + for key, d in drift_mid.items(): + base = state[key] + step_size * d + if key in diffusion_mid: + base = base + diffusion_mid[key] * sqrt_step_size * noise[key] + state_euler_mid[key] = base + + # average the two predictions + state_heun = {} + for key in state.keys(): + state_heun[key] = 0.5 * (state_euler[key] + state_euler_mid[key]) + + # Error estimation + if use_adaptive_step_size: + # Check if we're at minimum step size - if so, force acceptance + at_min_step = keras.ops.less_equal(step_size, min_step_size) + + # Compute error tolerance for each component + e_abs_tensor = keras.ops.cast(e_abs, dtype=keras.ops.dtype(list(state.values())[0])) + e_rel_tensor = keras.ops.cast(e_rel, dtype=keras.ops.dtype(list(state.values())[0])) + + max_error = keras.ops.cast(0.0, dtype=keras.ops.dtype(list(state.values())[0])) + + for key in state.keys(): + # Local error estimate: difference between Heun and first Euler step + error_estimate = keras.ops.abs(state_heun[key] - state_euler[key]) - return new_state, time + step_size, new_step_size + # Tolerance threshold + delta = keras.ops.maximum( + e_abs_tensor, + e_rel_tensor * keras.ops.maximum(keras.ops.abs(state_euler[key]), keras.ops.abs(last_state[key])), + ) + + # Normalized error + normalized_error = error_estimate / (delta + 1e-10) + + # Maximum error across all components and batch dimensions + component_max_error = keras.ops.max(normalized_error) + max_error = keras.ops.maximum(max_error, component_max_error) + + error_scale = 1 # 1/sqrt(n_params) + E2 = error_scale * max_error + + # Accept step if error is acceptable OR if at minimum step size + error_acceptable = keras.ops.less_equal(E2, 
keras.ops.cast(1.0, dtype=keras.ops.dtype(E2))) + accepted = keras.ops.logical_or(error_acceptable, at_min_step) + + # Adapt step size for next iteration (only if not at minimum) + # Ensure E2 is not zero to avoid division issues + E2_safe = keras.ops.maximum(E2, 1e-10) + + # New step size based on error estimate + adapt_factor = adapt_safety * keras.ops.power(E2_safe, -r) + new_step_candidate = step_size * adapt_factor + + # Clamp to valid range + sign_step = keras.ops.sign(step_size) + new_step_size = keras.ops.minimum(keras.ops.maximum(new_step_candidate, min_step_size), max_step_size) + new_step_size = sign_step * keras.ops.abs(new_step_size) + + # Return appropriate state based on acceptance + new_state = keras.ops.cond(accepted, lambda: state_heun, lambda: state) + + new_time = keras.ops.cond(accepted, lambda: time_mid, lambda: time) + + prev_state = keras.ops.cond(accepted, lambda: state_euler, lambda: state) + + return new_state, new_time, new_step_size, prev_state + + else: + return state_heun, time_mid, step_size def sea_step( @@ -540,12 +687,8 @@ def sea_step( time: ArrayLike, step_size: ArrayLike, noise: dict[str, ArrayLike], - use_adaptive_step_size: bool = False, - min_step_size: ArrayLike = -float("inf"), - max_step_size: ArrayLike = float("inf"), - adaptive_factor: ArrayLike = 0.1, **kwargs, -) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): +) -> Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike]: """ Performs a single shifted Euler step for SDEs with additive noise [1]. @@ -563,10 +706,6 @@ def sea_step( time: Current time scalar tensor. step_size: Time increment dt. noise: Mapping of variable names to dW noise tensors. - use_adaptive_step_size: Whether to use adaptive step sizing. - min_step_size: Minimum allowed step size. - max_step_size: Maximum allowed step size. - adaptive_factor: Factor to compute adaptive step size (0 < adaptive_factor < 1). Returns: new_state: Updated state after one SEA step. 
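To make the acceptance and adaptation rule of the fast_adaptive step concrete, here is a scalar NumPy sketch of the error test and step-size update; the tolerances, order r, safety factor, and toy numbers are illustrative:

import numpy as np

def accept_and_adapt(x_euler, x_heun, x_prev, h,
                     e_abs=0.01, e_rel=0.01, r=0.9, safety=0.9,
                     min_h=1e-4, max_h=1.0):
    # mixed absolute / relative tolerance per component
    delta = np.maximum(e_abs, e_rel * np.maximum(np.abs(x_euler), np.abs(x_prev)))
    E = np.max(np.abs(x_heun - x_euler) / (delta + 1e-10))   # normalized error
    accepted = bool(E <= 1.0) or (h <= min_h)                # always accept at the minimum step
    new_h = float(np.clip(h * safety * max(E, 1e-10) ** (-r), min_h, max_h))
    return accepted, new_h

ok, h_next = accept_and_adapt(x_euler=np.array([1.02]),
                              x_heun=np.array([1.015]),
                              x_prev=np.array([1.0]),
                              h=0.05)
print(ok, h_next)   # small error -> step accepted and enlarged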
@@ -574,16 +713,13 @@ def sea_step( """ # Compute diffusion (assumed additive or weakly state dependent) diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn)) - - # Check noise keys - if set(diffusion.keys()) != set(noise.keys()): - raise ValueError("Keys of diffusion terms and noise do not match.") + sqrt_step_size = keras.ops.sqrt(keras.ops.abs(step_size)) # Build shifted state: X_shift = X + 0.5 * g * ΔW shifted_state = {} for key, x in state.items(): if key in diffusion: - shifted_state[key] = x + 0.5 * diffusion[key] * noise[key] + shifted_state[key] = x + 0.5 * diffusion[key] * sqrt_step_size * noise[key] else: shifted_state[key] = x @@ -595,20 +731,10 @@ def sea_step( for key, d in drift_shifted.items(): base = state[key] + step_size * d if key in diffusion: - base = base + diffusion[key] * noise[key] + base = base + diffusion[key] * sqrt_step_size * noise[key] new_state[key] = base - new_step_size = step_size - if use_adaptive_step_size: - new_step_size = adaptive_step_size_controller( - state=state, - drift=drift_shifted, - adaptive_factor=adaptive_factor, - min_step_size=min_step_size, - max_step_size=max_step_size, - ) - - return new_state, time + step_size, new_step_size + return new_state, time + step_size, step_size def shark_step( @@ -617,13 +743,10 @@ def shark_step( state: Dict[str, ArrayLike], time: ArrayLike, step_size: ArrayLike, - noise: Dict[str, ArrayLike], # w_k = ΔW_k (already scaled by sqrt(|h|)) - noise_aux: Dict[str, ArrayLike], # Z_k ~ N(0,1), used to build H_k - use_adaptive_step_size: bool = False, - min_step_size: ArrayLike = -float("inf"), - max_step_size: ArrayLike = float("inf"), - adaptive_factor: ArrayLike = 0.1, -) -> Union[Tuple[Dict[str, ArrayLike], ArrayLike], Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike]]: + noise: Dict[str, ArrayLike], + noise_aux: Dict[str, ArrayLike], + **kwargs, +) -> Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike]: """ Shifted Additive noise Runge Kutta (SHARK) for additive SDEs [1]. Makes two evaluations of the drift and diffusion per step and has a strong order 1.5. @@ -651,10 +774,6 @@ def shark_step( step_size: Time increment dt. noise: Mapping of variable names to dW noise tensors. noise_aux: Mapping of variable names to auxiliary noise. - use_adaptive_step_size: Whether to use adaptive step sizing. - min_step_size: Minimum allowed step size. - max_step_size: Maximum allowed step size. - adaptive_factor: Factor to compute adaptive step size (0 < adaptive_factor < 1). Returns: new_state: Updated state after one SHARK step. 
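As a plain-NumPy reference for the hunk above, the SEA update at this point in the series shifts the state by half the rescaled noise increment before evaluating the drift. The helper name sea_step_1d and the Ornstein-Uhlenbeck example are illustrative only, not library code.

import numpy as np


def sea_step_1d(drift_fn, diffusion_fn, x, t, h, z):
    # z is a standard normal; the Brownian increment dW = sqrt(|h|) * z is built
    # inside the step, matching the convention introduced in this patch
    g = diffusion_fn(t)
    dW = np.sqrt(abs(h)) * z
    x_shifted = x + 0.5 * g * dW  # shifted state: X + 0.5 * g * dW
    return x + drift_fn(t, x_shifted) * h + g * dW


# Example: Ornstein-Uhlenbeck dynamics dX = a X dt + sigma dW with a = -1, sigma = 0.5
rng = np.random.default_rng(0)
x, t, h = 1.2, 0.0, 0.005
for _ in range(200):
    x = sea_step_1d(lambda t, x: -1.0 * x, lambda t: 0.5, x, t, h, rng.standard_normal())
    t += h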
@@ -676,7 +795,7 @@ def shark_step( H = {} for k in state.keys(): if k in g0: - w_k = noise[k] # already scaled by sqrt(|h|) + w_k = sqrt_h_mag * noise[k] z_k = noise_aux[k] # standard normal term1 = 0.5 * h_mag * w_k term2 = 0.5 * h_mag * sqrt_h_mag * inv_sqrt3 * z_k @@ -701,7 +820,7 @@ def shark_step( for k in state.keys(): drift_part = (5.0 / 6.0) * f_tilde_k[k] * h if k in g_tilde_k: - sto_part = (5.0 / 6.0) * g_tilde_k[k] * noise[k] + sto_part = (5.0 / 6.0) * g_tilde_k[k] * sqrt_h_mag * noise[k] else: sto_part = keras.ops.zeros_like(state[k]) y_tilde_mid[k] = y_tilde_k[k] + drift_part + sto_part @@ -718,29 +837,19 @@ def shark_step( # stochastic parts sto1 = ( - g_tilde_k[k] * ((2.0 / 5.0) * noise[k] + (6.0 / 5.0) * H[k]) + g_tilde_k[k] * ((2.0 / 5.0) * sqrt_h_mag * noise[k] + (6.0 / 5.0) * H[k]) if k in g_tilde_k else keras.ops.zeros_like(det) ) sto2 = ( - g_tilde_mid[k] * ((3.0 / 5.0) * noise[k] - (6.0 / 5.0) * H[k]) + g_tilde_mid[k] * ((3.0 / 5.0) * sqrt_h_mag * noise[k] - (6.0 / 5.0) * H[k]) if k in g_tilde_mid else keras.ops.zeros_like(det) ) new_state[k] = det + sto1 + sto2 - new_step_size = h - if use_adaptive_step_size: - new_step_size = adaptive_step_size_controller( - state=state, - drift=f_tilde_k, - adaptive_factor=adaptive_factor, - min_step_size=min_step_size, - max_step_size=max_step_size, - ) - - return new_state, t + h, new_step_size + return new_state, t + h, h def _apply_corrector( @@ -792,6 +901,8 @@ def integrate_stochastic_fixed( start_time: ArrayLike, stop_time: ArrayLike, steps: int, + min_step_size: ArrayLike, + max_step_size: ArrayLike, z_history: Dict[str, ArrayLike], z_extra_history: Dict[str, ArrayLike], score_fn: Optional[Callable], @@ -809,14 +920,13 @@ def body_fixed(_i, _loop_state): _current_state, _current_time, _current_step = _loop_state # Determine step size: either the constant size or the remainder to reach stop_time - remaining = stop_time - _current_time - sign = keras.ops.sign(remaining) - dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), keras.ops.abs(remaining)) + remaining = keras.ops.abs(stop_time - _current_time) + sign = keras.ops.sign(_current_step) + dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), remaining) dt = sign * dt_mag # Generate noise increment - sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt)) - _noise_i = {k: z_history[k][_i] * sqrt_dt for k in _current_state.keys()} + _noise_i = {k: z_history[k][_i] for k in _current_state.keys()} if len(z_extra_history) == 0: _noise_extra_i = None else: @@ -826,6 +936,8 @@ def body_fixed(_i, _loop_state): state=_current_state, time=_current_time, step_size=dt, + min_step_size=min_step_size, + max_step_size=keras.ops.minimum(max_step_size, remaining), noise=_noise_i, noise_aux=_noise_extra_i, use_adaptive_step_size=False, @@ -854,6 +966,8 @@ def integrate_stochastic_adaptive( start_time: ArrayLike, stop_time: ArrayLike, max_steps: int, + min_step_size: ArrayLike, + max_step_size: ArrayLike, initial_step: ArrayLike, z_history: Dict[str, ArrayLike], z_extra_history: Dict[str, ArrayLike], @@ -866,33 +980,34 @@ def integrate_stochastic_adaptive( """ Performs adaptive-step SDE integration. 
""" - initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step, 0) + initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step, 0, state) - def cond(i, current_state, current_time, current_step, counter): - # We use a small epsilon check for floating point equality - time_reached = keras.ops.all(keras.ops.isclose(current_time, stop_time)) - return keras.ops.logical_and(keras.ops.less(i, max_steps), keras.ops.logical_not(time_reached)) + def cond(i, current_state, current_time, current_step, counter, last_state): + # time remaining after the next step + time_remaining = keras.ops.sign(stop_time - start_time) * (stop_time - (current_time + current_step)) + return keras.ops.logical_and(keras.ops.all(time_remaining > 0), keras.ops.less(i, max_steps)) - def body_adaptive(_i, _current_state, _current_time, _current_step, _counter): + def body_adaptive(_i, _current_state, _current_time, _current_step, _counter, _last_state): # Step Size Control - remaining = stop_time - _current_time - sign = keras.ops.sign(remaining) + remaining = keras.ops.abs(stop_time - _current_time) + sign = keras.ops.sign(_current_step) # Ensure the next step does not overshoot the stop_time - dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), keras.ops.abs(remaining)) + dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), remaining) dt = sign * dt_mag _counter += 1 - sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt)) - _noise_i = {k: z_history[k][_i] * sqrt_dt for k in _current_state.keys()} - if len(z_extra_history) == 0: - _noise_extra_i = None - else: + _noise_i = {k: z_history[k][_i] for k in _current_state.keys()} + _noise_extra_i = None + if len(z_extra_history) > 0: _noise_extra_i = {k: z_extra_history[k][_i] for k in _current_state.keys()} - new_state, new_time, new_step = step_fn( + new_state, new_time, new_step, _new_current_state = step_fn( state=_current_state, + last_state=_last_state, time=_current_time, step_size=dt, + min_step_size=min_step_size, + max_step_size=keras.ops.minimum(max_step_size, remaining), noise=_noise_i, noise_aux=_noise_extra_i, use_adaptive_step_size=True, @@ -909,10 +1024,10 @@ def body_adaptive(_i, _current_state, _current_time, _current_step, _counter): corrector_noise_history=corrector_noise_history, ) - return _i + 1, new_state, new_time, new_step, _counter + return _i + 1, new_state, new_time, new_step, _counter, _new_current_state # Execute the adaptive loop - _, final_state, _, _, final_counter = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) + _, final_state, _, _, final_counter, _ = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) logging.debug("Finished integration after {} steps.", final_counter) return final_state @@ -1009,7 +1124,7 @@ def integrate_stochastic( seed: keras.random.SeedGenerator, steps: int | Literal["adaptive"] = 100, method: str = "euler_maruyama", - min_steps: int = 10, + min_steps: int = 20, max_steps: int = 10_000, score_fn: Callable = None, corrector_steps: int = 0, @@ -1078,8 +1193,14 @@ def integrate_stochastic( step_fn_raw = euler_maruyama_step case "sea": step_fn_raw = sea_step + if is_adaptive: + raise ValueError("SEA SDE solver does not support adaptive steps.") case "shark": step_fn_raw = shark_step + if is_adaptive: + raise ValueError("SHARK SDE solver does not support adaptive steps.") + case "fast_adaptive": + step_fn_raw = fast_adaptive_step case "langevin": if is_adaptive: raise ValueError("Langevin sampling does not support adaptive steps.") @@ 
-1109,8 +1230,6 @@ def integrate_stochastic( step_fn_raw, drift_fn=drift_fn, diffusion_fn=diffusion_fn, - min_step_size=min_step_size, - max_step_size=max_step_size, **kwargs, ) @@ -1130,6 +1249,8 @@ def integrate_stochastic( start_time=start_time, stop_time=stop_time, max_steps=max_steps, + min_step_size=min_step_size, + max_step_size=max_step_size, initial_step=initial_step, z_history=z_history, z_extra_history=z_extra_history, @@ -1145,6 +1266,8 @@ def integrate_stochastic( state=state, start_time=start_time, stop_time=stop_time, + min_step_size=min_step_size, + max_step_size=max_step_size, steps=loop_steps, z_history=z_history, z_extra_history=z_extra_history, diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index ba286d857..8a6de502c 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -78,42 +78,44 @@ def fn(t, x): ("euler_maruyama", False), ("euler_maruyama", True), ("sea", False), - ("sea", True), ("shark", False), - ("shark", True), + ("fast_adaptive", False), + ("fast_adaptive", True), ], ) -def test_additive_OU_weak_means_and_vars(method, use_adapt): +def test_forward_additive_ou_weak_means_and_vars(method, use_adapt): """ - Ornstein Uhlenbeck with additive noise + Ornstein-Uhlenbeck with additive noise, integrated FORWARD in time. + This serves as a sanity check that forward integration still works correctly. + + Forward SDE: dX = a X dt + sigma dW - Exact at time T: - E[X_T] = x0 * exp(a T) - Var[X_T] = sigma^2 * (exp(2 a T) - 1) / (2 a) - We verify weak accuracy by matching empirical mean and variance. + + Exact at time T starting from X(0) = x_0: + E[X(T)] = x_0 * exp(a T) + Var[X(T)] = sigma^2 * (exp(2 a T) - 1) / (2 a) """ # SDE parameters a = -1.0 sigma = 0.5 - x0 = 1.2 + x_0 = 1.2 # initial condition at time 0 T = 1.0 # batch of trajectories - N = 10000 # large enough to control sampling error + N = 10000 seed = keras.random.SeedGenerator(42) def drift_fn(t, x): return {"x": a * x} def diffusion_fn(t, x): - # additive noise, independent of state return {"x": keras.ops.convert_to_tensor([sigma])} - initial_state = {"x": keras.ops.ones((N,)) * x0} + initial_state = {"x": keras.ops.ones((N,)) * x_0} steps = 200 if not use_adapt else "adaptive" - # expected mean and variance - exp_mean = x0 * np.exp(a * T) + # Expected mean and variance at t=T + exp_mean = x_0 * np.exp(a * T) exp_var = sigma**2 * (np.exp(2.0 * a * T) - 1.0) / (2.0 * a) out = integrate_stochastic( @@ -128,9 +130,78 @@ def diffusion_fn(t, x): max_steps=1_000, ) - xT = np.array(out["x"]) - emp_mean = float(xT.mean()) - emp_var = float(xT.var()) + x_T = np.array(out["x"]) + emp_mean = float(x_T.mean()) + emp_var = float(x_T.var()) + + np.testing.assert_allclose(emp_mean, exp_mean, atol=TOL_MEAN, rtol=0.0) + np.testing.assert_allclose(emp_var, exp_var, atol=TOL_VAR, rtol=0.0) + + +@pytest.mark.parametrize( + "method,use_adapt", + [ + ("euler_maruyama", False), + ("euler_maruyama", True), + ("sea", False), + ("shark", False), + ("fast_adaptive", False), + ("fast_adaptive", True), + ], +) +def test_backward_additive_ou_weak_means_and_vars(method, use_adapt): + """ + Ornstein-Uhlenbeck with additive noise, integrated BACKWARD in time. + + When integrating from t=T back to t=0 with initial condition X(T) = x_T, + we get X(0) which should satisfy: + E[X(0)] = x_T * exp(-a T) (-a because we go backward) + Var[X(0)] = sigma^2 * (exp(-2 a T) - 1) / (-2 a) + + We verify weak accuracy by matching empirical mean and variance. 
+ """ + # SDE parameters + a = -1.0 + sigma = 0.5 + x_T = 1.2 # initial condition at time T + T = 1.0 + + # batch of trajectories + N = 10000 # large enough to control sampling error + seed = keras.random.SeedGenerator(42) + + def drift_fn(t, x): + return {"x": a * x} + + def diffusion_fn(t, x): + # additive noise, independent of state + return {"x": keras.ops.convert_to_tensor([sigma])} + + # Start at time T with value x_T + initial_state = {"x": keras.ops.ones((N,)) * x_T} + steps = 200 if not use_adapt else "adaptive" + + # Expected mean and variance at t=0 after integrating backward from t=T + # For backward integration, the effective drift coefficient changes sign + exp_mean = x_T * np.exp(-a * T) + exp_var = sigma**2 * (np.exp(-2.0 * a * T) - 1.0) / (-2.0 * a) + + out = integrate_stochastic( + drift_fn=drift_fn, + diffusion_fn=diffusion_fn, + state=initial_state, + start_time=T, + stop_time=0.0, + steps=steps, + seed=seed, + method=method, + max_steps=1_000, + ) + + x_0 = np.array(out["x"]) + emp_mean = float(x_0.mean()) + emp_var = float(x_0.var()) + np.testing.assert_allclose(emp_mean, exp_mean, atol=TOL_MEAN, rtol=0.0) np.testing.assert_allclose(emp_var, exp_var, atol=TOL_VAR, rtol=0.0) @@ -141,9 +212,9 @@ def diffusion_fn(t, x): ("euler_maruyama", False), ("euler_maruyama", True), ("sea", False), - ("sea", True), ("shark", False), - ("shark", True), + ("fast_adaptive", False), + ("fast_adaptive", True), ], ) def test_zero_noise_reduces_to_deterministic(method, use_adapt): From 5c5abd361efc34267f05805838b8c1522ffaa45d Mon Sep 17 00:00:00 2001 From: arrjon Date: Sun, 30 Nov 2025 18:33:45 +0100 Subject: [PATCH 081/101] improve adaptive ODE samplers --- bayesflow/utils/integrate.py | 358 ++++++++++++++++------------- tests/test_utils/test_integrate.py | 27 ++- 2 files changed, 220 insertions(+), 165 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index c9f951982..f0e049094 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -24,11 +24,11 @@ def euler_step( fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, time: ArrayLike, step_size: ArrayLike, **kwargs, -) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): +) -> Tuple[StateDict, ArrayLike, None, ArrayLike]: k1 = fn(time, **filter_kwargs(state, fn)) new_state = state.copy() @@ -36,7 +36,7 @@ def euler_step( new_state[key] = state[key] + step_size * k1[key] new_time = time + step_size - return new_state, new_time, step_size + return new_state, new_time, None, 0.0 def add_scaled(state, ks, coeffs, h): @@ -51,21 +51,19 @@ def add_scaled(state, ks, coeffs, h): def rk45_step( fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, time: ArrayLike, - last_step_size: ArrayLike, - tolerance: ArrayLike = 1e-6, - min_step_size: ArrayLike = -float("inf"), - max_step_size: ArrayLike = float("inf"), - use_adaptive_step_size: bool = False, -) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): + step_size: ArrayLike, + k1: StateDict = None, + use_adaptive_step_size: bool = True, +) -> Tuple[StateDict, ArrayLike, StateDict | None, ArrayLike]: """ Dormand-Prince 5(4) method with embedded error estimation. 
""" - step_size = last_step_size h = step_size - k1 = fn(time, **filter_kwargs(state, fn)) + if k1 is None: # reuse k1 if available + k1 = fn(time, **filter_kwargs(state, fn)) k2 = fn(time + h * (1 / 5), **add_scaled(state, [k1], [1 / 5], h)) k3 = fn(time + h * (3 / 10), **add_scaled(state, [k1, k2], [3 / 40, 9 / 40], h)) k4 = fn(time + h * (4 / 5), **add_scaled(state, [k1, k2, k3], [44 / 45, -56 / 15, 32 / 9], h)) @@ -85,48 +83,42 @@ def rk45_step( 35 / 384 * k1[key] + 500 / 1113 * k3[key] + 125 / 192 * k4[key] - 2187 / 6784 * k5[key] + 11 / 84 * k6[key] ) - if use_adaptive_step_size: - k7 = fn(time + h, **filter_kwargs(new_state, fn)) - - # 4th order embedded solution - err_state = {} - for key in k1.keys(): - y4 = state[key] + h * ( - 5179 / 57600 * k1[key] - + 7571 / 16695 * k3[key] - + 393 / 640 * k4[key] - - 92097 / 339200 * k5[key] - + 187 / 2100 * k6[key] - + 1 / 40 * k7[key] - ) - err_state[key] = new_state[key] - y4 + new_time = time + h + if not use_adaptive_step_size: + return new_state, new_time, None, 0.0 - err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()]) - err = keras.ops.max(err_norm) + k7 = fn(time + h, **filter_kwargs(new_state, fn)) - new_step_size = h * keras.ops.clip(0.9 * (tolerance / (err + 1e-12)) ** 0.2, 0.2, 5.0) - new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) - else: - new_step_size = step_size + # 4th order embedded solution + err_state = {} + for key in k1.keys(): + y4 = state[key] + h * ( + 5179 / 57600 * k1[key] + + 7571 / 16695 * k3[key] + + 393 / 640 * k4[key] + - 92097 / 339200 * k5[key] + + 187 / 2100 * k6[key] + + 1 / 40 * k7[key] + ) + err_state[key] = new_state[key] - y4 - new_time = time + h - return new_state, new_time, new_step_size + err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()]) + err = keras.ops.max(err_norm) + + return new_state, new_time, k7, err def tsit5_step( fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, time: ArrayLike, - last_step_size: ArrayLike, - tolerance: ArrayLike = 1e-6, - min_step_size: ArrayLike = -float("inf"), - max_step_size: ArrayLike = float("inf"), - use_adaptive_step_size: bool = False, -): + step_size: ArrayLike, + k1: StateDict = None, + use_adaptive_step_size: bool = True, +) -> Tuple[StateDict, ArrayLike, StateDict | None, ArrayLike]: """ Implements a single step of the Tsitouras 5/4 Runge-Kutta method. 
""" - step_size = last_step_size h = step_size # Butcher tableau coefficients @@ -135,7 +127,8 @@ def tsit5_step( c4 = 0.9 c5 = 0.9800255409045097 - k1 = fn(time, **filter_kwargs(state, fn)) + if k1 is None: # reuse k1 if available + k1 = fn(time, **filter_kwargs(state, fn)) k2 = fn(time + h * c2, **add_scaled(state, [k1], [0.161], h)) k3 = fn(time + h * c3, **add_scaled(state, [k1, k2], [-0.0084806554923570, 0.3354806554923570], h)) k4 = fn( @@ -169,42 +162,39 @@ def tsit5_step( + 2.324710524099774 * k6[key] ) - if use_adaptive_step_size: - k7 = fn(time + h, **filter_kwargs(new_state, fn)) + new_time = time + h + if not use_adaptive_step_size: + return new_state, new_time, None, 0.0 - err_state = {} - for key in state.keys(): - err_state[key] = h * ( - -0.00178001105222577714 * k1[key] - - 0.0008164344596567469 * k2[key] - + 0.007880878010261995 * k3[key] - - 0.1447110071732629 * k4[key] - + 0.5823571654525552 * k5[key] - - 0.45808210592918697 * k6[key] - + 0.015151515151515152 * k7[key] - ) + k7 = fn(time + h, **filter_kwargs(new_state, fn)) - err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()]) - err = keras.ops.max(err_norm) + err_state = {} + for key in state.keys(): + err_state[key] = h * ( + -0.00178001105222577714 * k1[key] + - 0.0008164344596567469 * k2[key] + + 0.007880878010261995 * k3[key] + - 0.1447110071732629 * k4[key] + + 0.5823571654525552 * k5[key] + - 0.45808210592918697 * k6[key] + + 0.015151515151515152 * k7[key] + ) - new_step_size = h * keras.ops.clip(0.9 * (tolerance / (err + 1e-12)) ** 0.2, 0.2, 5.0) - new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) - else: - new_step_size = step_size + err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()]) + err = keras.ops.max(err_norm) - new_time = time + h - return new_state, new_time, new_step_size + return new_state, new_time, k7, err def integrate_fixed( fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, start_time: ArrayLike, stop_time: ArrayLike, steps: int, method: str, **kwargs, -) -> dict[str, ArrayLike]: +) -> StateDict: if steps <= 0: raise ValueError("Number of steps must be positive.") @@ -227,7 +217,7 @@ def integrate_fixed( def body(_loop_var, _loop_state): _state, _time = _loop_state - _state, _time, _ = step_fn(_state, _time, step_size) + _state, _time, _, _ = step_fn(_state, _time, step_size) return _state, _time @@ -236,6 +226,37 @@ def body(_loop_var, _loop_state): return state +def integrate_scheduled( + fn: Callable, + state: StateDict, + steps: Tensor | np.ndarray, + method: str, + **kwargs, +) -> StateDict: + match method: + case "euler": + step_fn = euler_step + case "rk45": + step_fn = rk45_step + case "tsit5": + step_fn = tsit5_step + case str() as name: + raise ValueError(f"Unknown integration method name: {name!r}") + case other: + raise TypeError(f"Invalid integration method: {other!r}") + + step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False) + + def body(_loop_var, _loop_state): + _time = steps[_loop_var] + step_size = steps[_loop_var + 1] - steps[_loop_var] + _loop_state, _, _, _ = step_fn(_loop_state, _time, step_size) + return _loop_state + + state = keras.ops.fori_loop(0, len(steps) - 1, body, state) + return state + + def integrate_adaptive( fn: Callable, state: dict[str, ArrayLike], @@ -261,98 +282,106 @@ def integrate_adaptive( case other: raise TypeError(f"Invalid integration method: {other!r}") + tolerance = 
keras.ops.convert_to_tensor(kwargs.get("tolerance", 1e-6), dtype="float32") step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=True) - def cond(_state, _time, _step_size, _step): - # while step < min_steps or time_remaining > 0 and step < max_steps + # Initial (conservative) step size guess + total_time = stop_time - start_time + step_size0 = keras.ops.convert_to_tensor(total_time / max_steps, dtype="float32") - # time remaining after the next step - time_remaining = keras.ops.abs(stop_time - (_time + _step_size)) + # Track step count as scalar tensor + step0 = keras.ops.convert_to_tensor(0.0, dtype="float32") + count_not_accepted = 0 + + # "First Same As Last" (FSAL) property + k1_0 = fn(start_time, **filter_kwargs(state, fn)) + + def cond(_state, _time, _step_size, _step, _k1, _count_not_accepted): + time_remaining = keras.ops.sign(stop_time - start_time) * (stop_time - (_time + _step_size)) + step_lt_min = keras.ops.less(_step, float(min_steps)) + step_lt_max = keras.ops.less(_step, float(max_steps)) return keras.ops.logical_or( - keras.ops.all(_step < min_steps), - keras.ops.logical_and(keras.ops.all(time_remaining > 0), keras.ops.all(_step < max_steps)), + step_lt_min, + keras.ops.logical_and(keras.ops.all(time_remaining > 0), step_lt_max), ) - def body(_state, _time, _step_size, _step): - _step = _step + 1 - - # time remaining after the next step - time_remaining = stop_time - (_time + _step_size) + def body(_state, _time, _step_size, _step, _k1, _count_not_accepted): + # Time remaining from current point + time_remaining = stop_time - _time + # Per-step min/max step sizes (like original code) min_step_size = time_remaining / (max_steps - _step) max_step_size = time_remaining / keras.ops.maximum(min_steps - _step, 1.0) - # reorder - min_step_size, max_step_size = ( - keras.ops.minimum(min_step_size, max_step_size), - keras.ops.maximum(min_step_size, max_step_size), - ) - - _state, _time, _step_size = step_fn( - _state, _time, _step_size, min_step_size=min_step_size, max_step_size=max_step_size + # Ensure ordering: min_step_size <= max_step_size + lower = keras.ops.minimum(min_step_size, max_step_size) + upper = keras.ops.maximum(min_step_size, max_step_size) + min_step_size = lower + max_step_size = upper + h = keras.ops.clip(_step_size, min_step_size, max_step_size) + + # Take one trial step + new_state, new_time, new_k1, err = step_fn( + state=_state, + time=_time, + step_size=h, + k1=_k1, ) - return _state, _time, _step_size, _step - - # select initial step size conservatively - step_size = (stop_time - start_time) / max_steps - - step = 0 - time = start_time - - state, time, step_size, step = keras.ops.while_loop(cond, body, [state, time, step_size, step]) - - # do the last step - step_size = stop_time - time - state, _, _ = step_fn(state, time, step_size) - step = step + 1 - - logging.debug("Finished integration after {} steps.", step) + new_step_size = h * keras.ops.clip(0.9 * (tolerance / (err + 1e-12)) ** 0.2, 0.2, 5.0) + new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) - return state + # Error control: reject if err > tolerance + too_big = keras.ops.greater(err, tolerance) + at_min = keras.ops.less_equal( + keras.ops.abs(h), + keras.ops.abs(min_step_size), + ) + accepted = keras.ops.logical_or(keras.ops.logical_not(too_big), at_min) + updated_state = keras.ops.cond(accepted, lambda: new_state, lambda: _state) + updated_time = keras.ops.cond(accepted, lambda: new_time, lambda: _time) + updated_k1 = keras.ops.cond(accepted, lambda: 
new_k1, lambda: _k1) -def integrate_scheduled( - fn: Callable, - state: dict[str, ArrayLike], - steps: Tensor | np.ndarray, - method: str, - **kwargs, -) -> dict[str, ArrayLike]: - match method: - case "euler": - step_fn = euler_step - case "rk45": - step_fn = rk45_step - case "tsit5": - step_fn = tsit5_step - case str() as name: - raise ValueError(f"Unknown integration method name: {name!r}") - case other: - raise TypeError(f"Invalid integration method: {other!r}") + # Step counter: increment only on accepted steps + updated_step = _step + keras.ops.where(accepted, 1.0, 0.0) + _count_not_accepted = _count_not_accepted + 1 if not accepted else _count_not_accepted - step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False) + # For the next iteration, always use the new suggested step size + return updated_state, updated_time, new_step_size, updated_step, updated_k1, _count_not_accepted - def body(_loop_var, _loop_state): - _time = steps[_loop_var] - step_size = steps[_loop_var + 1] - steps[_loop_var] + # Run the adaptive loop + state, time, step_size, step, k1, count_not_accepted = keras.ops.while_loop( + cond, + body, + [state, start_time, step_size0, step0, k1_0, count_not_accepted], + ) - _loop_state, _, _ = step_fn(_loop_state, _time, step_size) - return _loop_state + # Final step to hit stop_time exactly + time_diff = stop_time - time + time_remaining = keras.ops.sign(stop_time - start_time) * time_diff + if keras.ops.all(time_remaining > 0): + state, time, _, _ = step_fn( + state=state, + time=time, + step_size=time_diff, + k1=k1, + ) + step = step + 1.0 - state = keras.ops.fori_loop(0, len(steps) - 1, body, state) + logging.debug(f"Finished integration after {step} steps with {count_not_accepted} rejected steps.") return state def integrate_scipy( fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, start_time: ArrayLike, stop_time: ArrayLike, scipy_kwargs: dict | None = None, **kwargs, -) -> dict[str, ArrayLike]: +) -> StateDict: import scipy.integrate scipy_kwargs = scipy_kwargs or {} @@ -394,7 +423,7 @@ def scipy_wrapper_fn(time, x): def integrate( fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, start_time: ArrayLike | None = None, stop_time: ArrayLike | None = None, min_steps: int = 10, @@ -402,7 +431,7 @@ def integrate( steps: int | Literal["adaptive"] | Tensor | np.ndarray = 100, method: str = "rk45", **kwargs, -) -> dict[str, ArrayLike]: +) -> StateDict: if isinstance(steps, str) and steps in ["adaptive", "dynamic"]: if start_time is None or stop_time is None: raise ValueError( @@ -429,7 +458,10 @@ def integrate( raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})") -def adaptive_step_size_controller( +############ SDE Solvers ############# + + +def stochastic_adaptive_step_size_controller( state, drift, adaptive_factor: ArrayLike, @@ -470,19 +502,16 @@ def adaptive_step_size_controller( def euler_maruyama_step( drift_fn: Callable, diffusion_fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, time: ArrayLike, step_size: ArrayLike, - noise: dict[str, ArrayLike], + noise: StateDict, use_adaptive_step_size: bool = False, min_step_size: float = -float("inf"), max_step_size: float = float("inf"), adaptive_factor: float = 0.01, **kwargs, -) -> Union[ - Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike], - Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike, Dict[str, ArrayLike]], -]: +) -> Union[Tuple[StateDict, ArrayLike, ArrayLike], Tuple[StateDict, ArrayLike, ArrayLike, StateDict]]: """ Performs a single 
Euler-Maruyama step for stochastic differential equations. @@ -509,7 +538,7 @@ def euler_maruyama_step( new_step_size = step_size if use_adaptive_step_size: sign_step = keras.ops.sign(step_size) - new_step_size = adaptive_step_size_controller( + new_step_size = stochastic_adaptive_step_size_controller( state=state, drift=drift, adaptive_factor=adaptive_factor, @@ -535,11 +564,11 @@ def euler_maruyama_step( def fast_adaptive_step( drift_fn: Callable, diffusion_fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, time: ArrayLike, step_size: ArrayLike, - noise: dict[str, ArrayLike], - last_state: dict[str, ArrayLike] = None, + noise: StateDict, + last_state: StateDict = None, use_adaptive_step_size: bool = True, min_step_size: float = -float("inf"), max_step_size: float = float("inf"), @@ -549,8 +578,8 @@ def fast_adaptive_step( adapt_safety: float = 0.9, **kwargs, ) -> Union[ - Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike], - Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike, Dict[str, ArrayLike]], + Tuple[StateDict, ArrayLike, ArrayLike], + Tuple[StateDict, ArrayLike, ArrayLike, StateDict], ]: """ Performs a single adaptive step for stochastic differential equations based on [1]. @@ -683,12 +712,12 @@ def fast_adaptive_step( def sea_step( drift_fn: Callable, diffusion_fn: Callable, - state: dict[str, ArrayLike], + state: StateDict, time: ArrayLike, step_size: ArrayLike, - noise: dict[str, ArrayLike], + noise: StateDict, **kwargs, -) -> Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike]: +) -> Tuple[StateDict, ArrayLike, ArrayLike]: """ Performs a single shifted Euler step for SDEs with additive noise [1]. @@ -740,13 +769,13 @@ def sea_step( def shark_step( drift_fn: Callable, diffusion_fn: Callable, - state: Dict[str, ArrayLike], + state: StateDict, time: ArrayLike, step_size: ArrayLike, - noise: Dict[str, ArrayLike], - noise_aux: Dict[str, ArrayLike], + noise: StateDict, + noise_aux: StateDict, **kwargs, -) -> Tuple[Dict[str, ArrayLike], ArrayLike, ArrayLike]: +) -> Tuple[StateDict, ArrayLike, ArrayLike]: """ Shifted Additive noise Runge Kutta (SHARK) for additive SDEs [1]. Makes two evaluations of the drift and diffusion per step and has a strong order 1.5. 
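For reference, the pair (ΔW_k, H_k) that SHARK consumes can be drawn from two independent standard normals, following the construction used in shark_step above (w_k and the auxiliary Z_k). The sketch below is a scalar NumPy illustration with made-up names, not the library implementation.

import numpy as np


def sample_increment_and_area(h, rng):
    # dW is the Brownian increment over a step of size h; H is the auxiliary
    # area-like term built from a second, independent standard normal
    z, z_aux = rng.standard_normal(2)
    dW = np.sqrt(abs(h)) * z
    H = 0.5 * abs(h) * dW + 0.5 * abs(h) * np.sqrt(abs(h)) * z_aux / np.sqrt(3.0)
    return dW, H


rng = np.random.default_rng(1)
dW, H = sample_increment_and_area(0.01, rng)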
@@ -858,7 +887,7 @@ def _apply_corrector( i: ArrayLike, corrector_steps: int, score_fn: Optional[Callable], - corrector_noise_history: Dict[str, ArrayLike], + corrector_noise_history: StateDict, step_size_factor: ArrayLike = 0.01, noise_schedule=None, ) -> StateDict: @@ -903,11 +932,11 @@ def integrate_stochastic_fixed( steps: int, min_step_size: ArrayLike, max_step_size: ArrayLike, - z_history: Dict[str, ArrayLike], - z_extra_history: Dict[str, ArrayLike], + z_history: StateDict, + z_extra_history: StateDict, score_fn: Optional[Callable], step_size_factor: ArrayLike, - corrector_noise_history: Dict[str, ArrayLike], + corrector_noise_history: StateDict, corrector_steps: int = 0, noise_schedule=None, ) -> StateDict: @@ -969,11 +998,11 @@ def integrate_stochastic_adaptive( min_step_size: ArrayLike, max_step_size: ArrayLike, initial_step: ArrayLike, - z_history: Dict[str, ArrayLike], - z_extra_history: Dict[str, ArrayLike], + z_history: StateDict, + z_extra_history: StateDict, score_fn: Optional[Callable], step_size_factor: ArrayLike, - corrector_noise_history: Dict[str, ArrayLike], + corrector_noise_history: StateDict, corrector_steps: int = 0, noise_schedule=None, ) -> StateDict: @@ -1027,8 +1056,9 @@ def body_adaptive(_i, _current_state, _current_time, _current_step, _counter, _l return _i + 1, new_state, new_time, new_step, _counter, _new_current_state # Execute the adaptive loop - _, final_state, _, _, final_counter, _ = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) - logging.debug("Finished integration after {} steps.", final_counter) + _, final_state, final_time, _, final_counter, _ = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) + + logging.debug(f"Finished integration after {final_counter} steps at {final_time}.") return final_state @@ -1037,10 +1067,10 @@ def integrate_langevin( start_time: ArrayLike, stop_time: ArrayLike, steps: int, - z_history: Dict[str, ArrayLike], + z_history: StateDict, score_fn: Callable, noise_schedule, - corrector_noise_history: Dict[str, ArrayLike], + corrector_noise_history: StateDict, step_size_factor: ArrayLike = 0.01, corrector_steps: int = 0, ) -> StateDict: diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index 8a6de502c..3f83a3a2d 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -56,7 +56,7 @@ def fn(t, x): return {"x": keras.ops.convert_to_tensor([2.0 * t])} initial_state = {"x": keras.ops.convert_to_tensor([1.0])} - T_final = 2.0 + T_final = 1.0 num_steps = 100 analytical_result = 1.0 + T_final**2 @@ -72,6 +72,31 @@ def fn(t, x): np.testing.assert_allclose(result_adaptive, analytical_result, atol=atol, rtol=0.1) +@pytest.mark.parametrize( + "method, atol", [("euler", TOLERANCE_EULER), ("rk45", TOLERANCE_ADAPTIVE), ("tsit5", TOLERANCE_ADAPTIVE)] +) +def test_analytical_backward_integration(method, atol): + T_final = 1.0 + + def fn(t, x): + return {"x": keras.ops.convert_to_tensor([2.0 * t])} + + num_steps = 100 + analytical_result = 1.0 + initial_state = {"x": keras.ops.convert_to_tensor([1.0 + T_final**2])} + + result = integrate(fn, initial_state, start_time=T_final, stop_time=0.0, steps=num_steps, method=method)["x"] + if method == "euler": + result_adaptive = result + else: + result_adaptive = integrate( + fn, initial_state, start_time=T_final, stop_time=0.0, steps="adaptive", method=method, max_steps=1_000 + )["x"] + + np.testing.assert_allclose(result, analytical_result, atol=atol, rtol=0.1) + 
np.testing.assert_allclose(result_adaptive, analytical_result, atol=atol, rtol=0.1) + + @pytest.mark.parametrize( "method,use_adapt", [ From dd021bb782cf97ae359bd91de08d244495fec07e Mon Sep 17 00:00:00 2001 From: arrjon Date: Sun, 30 Nov 2025 18:57:53 +0100 Subject: [PATCH 082/101] fix schedule test --- tests/test_utils/test_integrate.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index 3f83a3a2d..8d4a06d7f 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -13,23 +13,20 @@ TOL_DET = 1e-3 -def test_scheduled_integration(): - import keras - from bayesflow.utils import integrate - +@pytest.mark.parametrize("method", ["euler", "rk45", "tsit5"]) +def test_scheduled_integration(method): def fn(t, x): return {"x": t**2} - steps = keras.ops.convert_to_tensor([0.0, 0.5, 1.0]) - approximate_result = 0.0 + 0.5**2 * 0.5 - result = integrate(fn, {"x": 0.0}, steps=steps)["x"] - assert result == approximate_result + def analytical_result(t): + return (t**3) / 3.0 + steps = keras.ops.arange(0.0, 1.0 + 1e-6, 0.01) + result = integrate(fn, {"x": 0.0}, steps=steps, method=method)["x"] + np.testing.assert_allclose(result, analytical_result(steps[-1]), atol=1e-1, rtol=1e-1) -def test_scipy_integration(): - import keras - from bayesflow.utils import integrate +def test_scipy_integration(): def fn(t, x): return {"x": keras.ops.exp(t)} From 1fe2c60212b90686cc0997ce1b87c4bab17b6fe3 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 1 Dec 2025 10:39:08 +0100 Subject: [PATCH 083/101] improved defaults --- bayesflow/utils/integrate.py | 84 ++++++++++++++++-------------- tests/test_utils/test_integrate.py | 13 ++--- 2 files changed, 53 insertions(+), 44 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index f0e049094..e299fef7b 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -19,7 +19,7 @@ DETERMINISTIC_METHODS = ["euler", "rk45", "tsit5"] -STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "langevin", "fast_adaptive"] +STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "two_step_adaptive", "langevin"] def euler_step( @@ -509,7 +509,6 @@ def euler_maruyama_step( use_adaptive_step_size: bool = False, min_step_size: float = -float("inf"), max_step_size: float = float("inf"), - adaptive_factor: float = 0.01, **kwargs, ) -> Union[Tuple[StateDict, ArrayLike, ArrayLike], Tuple[StateDict, ArrayLike, ArrayLike, StateDict]]: """ @@ -525,7 +524,6 @@ def euler_maruyama_step( use_adaptive_step_size: Whether to use adaptive step sizing. min_step_size: Minimum allowed step size. max_step_size: Maximum allowed step size. - adaptive_factor: Factor to compute adaptive step size (0 < adaptive_factor < 1). Returns: new_state: Updated state after one Euler-Maruyama step. 
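For a scalar SDE dX = f(t, X) dt + g(t, X) dW, the Euler-Maruyama update documented above amounts to the following sketch; names and the Ornstein-Uhlenbeck example (the same process used in the tests) are illustrative only.

import numpy as np


def euler_maruyama_step_1d(drift_fn, diffusion_fn, x, t, h, z):
    # One Euler-Maruyama step with Brownian increment dW = sqrt(|h|) * z
    dW = np.sqrt(abs(h)) * z
    return x + drift_fn(t, x) * h + diffusion_fn(t, x) * dW, t + h


# Example: one step of dX = -X dt + 0.5 dW starting from x = 1.2
rng = np.random.default_rng(0)
x, t = euler_maruyama_step_1d(lambda t, x: -x, lambda t, x: 0.5, 1.2, 0.0, 0.01, rng.standard_normal())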
@@ -541,7 +539,7 @@ def euler_maruyama_step( new_step_size = stochastic_adaptive_step_size_controller( state=state, drift=drift, - adaptive_factor=adaptive_factor, + adaptive_factor=max_step_size, min_step_size=min_step_size, max_step_size=max_step_size, ) @@ -561,7 +559,7 @@ def euler_maruyama_step( return new_state, time + new_step_size, new_step_size -def fast_adaptive_step( +def two_step_adaptive_step( drift_fn: Callable, diffusion_fn: Callable, state: StateDict, @@ -572,8 +570,8 @@ def fast_adaptive_step( use_adaptive_step_size: bool = True, min_step_size: float = -float("inf"), max_step_size: float = float("inf"), - e_abs: float = 0.01, - e_rel: float = 0.01, + e_rel: float = 0.1, + e_abs: float = None, r: float = 0.9, adapt_safety: float = 0.9, **kwargs, @@ -608,8 +606,8 @@ def fast_adaptive_step( use_adaptive_step_size: Whether to adapt step size. min_step_size: Minimum allowed step size. max_step_size: Maximum allowed step size. - e_abs: Absolute error tolerance. e_rel: Relative error tolerance. + e_abs: Absolute error tolerance. Default assumes standardized targets. r: Order of the method for step size adaptation. adapt_safety: Safety factor for step size adaptation. **kwargs: Additional arguments passed to drift_fn and diffusion_fn. @@ -650,6 +648,8 @@ def fast_adaptive_step( # Error estimation if use_adaptive_step_size: + if e_abs is None: + e_abs = 0.02576 # 1% of 99% CI of standardized unit variance # Check if we're at minimum step size - if so, force acceptance at_min_step = keras.ops.less_equal(step_size, min_step_size) @@ -709,13 +709,33 @@ def fast_adaptive_step( return state_heun, time_mid, step_size +def compute_levy_area( + state: StateDict, diffusion: StateDict, noise: StateDict, noise_aux: StateDict, step_size: ArrayLike +) -> StateDict: + step_size_abs = keras.ops.abs(step_size) + sqrt_step_size = keras.ops.sqrt(step_size_abs) + inv_sqrt3 = keras.ops.cast(1.0 / np.sqrt(3.0), dtype=keras.ops.dtype(step_size_abs)) + + # Build Lévy area H_k from w_k and Z_k + H = {} + for k in state.keys(): + if k in diffusion: + term1 = 0.5 * step_size_abs * noise[k] + term2 = 0.5 * step_size_abs * sqrt_step_size * inv_sqrt3 * noise_aux[k] + H[k] = term1 + term2 + else: + H[k] = keras.ops.zeros_like(state[k]) + return H + + def sea_step( drift_fn: Callable, diffusion_fn: Callable, state: StateDict, time: ArrayLike, step_size: ArrayLike, - noise: StateDict, + noise: StateDict, # standard normals + noise_aux: StateDict, # standard normals **kwargs, ) -> Tuple[StateDict, ArrayLike, ArrayLike]: """ @@ -725,7 +745,7 @@ def sea_step( which improves the local error and the global error constant for additive noise. The scheme is - X_{n+1} = X_n + f(t_n, X_n + 0.5 * g(t_n) * ΔW_n) * h + g(t_n) * ΔW_n + X_{n+1} = X_n + f(t_n, X_n + g(t_n) * (0.5 * ΔW_n + ΔH_n) * h + g(t_n) * ΔW_n [1] Foster et al., "High order splitting methods for SDEs satisfying a commutativity condition" (2023) Args: @@ -735,20 +755,23 @@ def sea_step( time: Current time scalar tensor. step_size: Time increment dt. noise: Mapping of variable names to dW noise tensors. + noise_aux: Mapping of variable names to auxiliary noise. Returns: new_state: Updated state after one SEA step. new_time: time + dt. 
""" - # Compute diffusion (assumed additive or weakly state dependent) + # Compute diffusion diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn)) sqrt_step_size = keras.ops.sqrt(keras.ops.abs(step_size)) - # Build shifted state: X_shift = X + 0.5 * g * ΔW + la = compute_levy_area(state=state, diffusion=diffusion, noise=noise, noise_aux=noise_aux, step_size=step_size) + + # Build shifted state: X_shift = X + g * (0.5 * ΔW + ΔH) shifted_state = {} for key, x in state.items(): if key in diffusion: - shifted_state[key] = x + 0.5 * diffusion[key] * sqrt_step_size * noise[key] + shifted_state[key] = x + diffusion[key] * (0.5 * sqrt_step_size * noise[key] + la[key]) else: shifted_state[key] = x @@ -810,33 +833,18 @@ def shark_step( """ h = step_size t = time - - # Magnitude of the time step for stochastic scaling h_mag = keras.ops.abs(h) - # h_sign = keras.ops.sign(h) sqrt_h_mag = keras.ops.sqrt(h_mag) - inv_sqrt3 = keras.ops.cast(1.0 / np.sqrt(3.0), dtype=keras.ops.dtype(h_mag)) - # g(y_k) - g0 = diffusion_fn(t, **filter_kwargs(state, diffusion_fn)) + diffusion = diffusion_fn(t, **filter_kwargs(state, diffusion_fn)) - # Build H_k from w_k and Z_k - H = {} - for k in state.keys(): - if k in g0: - w_k = sqrt_h_mag * noise[k] - z_k = noise_aux[k] # standard normal - term1 = 0.5 * h_mag * w_k - term2 = 0.5 * h_mag * sqrt_h_mag * inv_sqrt3 * z_k - H[k] = term1 + term2 - else: - H[k] = keras.ops.zeros_like(state[k]) + la = compute_levy_area(state=state, diffusion=diffusion, noise=noise, noise_aux=noise_aux, step_size=step_size) # === 1) shifted initial state === y_tilde_k = {} for k in state.keys(): - if k in g0: - y_tilde_k[k] = state[k] + g0[k] * H[k] + if k in diffusion: + y_tilde_k[k] = state[k] + diffusion[k] * la[k] else: y_tilde_k[k] = state[k] @@ -866,12 +874,12 @@ def shark_step( # stochastic parts sto1 = ( - g_tilde_k[k] * ((2.0 / 5.0) * sqrt_h_mag * noise[k] + (6.0 / 5.0) * H[k]) + g_tilde_k[k] * ((2.0 / 5.0) * sqrt_h_mag * noise[k] + (6.0 / 5.0) * la[k]) if k in g_tilde_k else keras.ops.zeros_like(det) ) sto2 = ( - g_tilde_mid[k] * ((3.0 / 5.0) * sqrt_h_mag * noise[k] - (6.0 / 5.0) * H[k]) + g_tilde_mid[k] * ((3.0 / 5.0) * sqrt_h_mag * noise[k] - (6.0 / 5.0) * la[k]) if k in g_tilde_mid else keras.ops.zeros_like(det) ) @@ -1154,7 +1162,7 @@ def integrate_stochastic( seed: keras.random.SeedGenerator, steps: int | Literal["adaptive"] = 100, method: str = "euler_maruyama", - min_steps: int = 20, + min_steps: int = 10, max_steps: int = 10_000, score_fn: Callable = None, corrector_steps: int = 0, @@ -1229,8 +1237,8 @@ def integrate_stochastic( step_fn_raw = shark_step if is_adaptive: raise ValueError("SHARK SDE solver does not support adaptive steps.") - case "fast_adaptive": - step_fn_raw = fast_adaptive_step + case "two_step_adaptive": + step_fn_raw = two_step_adaptive_step case "langevin": if is_adaptive: raise ValueError("Langevin sampling does not support adaptive steps.") @@ -1269,7 +1277,7 @@ def integrate_stochastic( for key, val in state.items(): shape = keras.ops.shape(val) z_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) - if method == "shark": + if method in ["sea", "shark"]: z_extra_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) if is_adaptive: diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index 8d4a06d7f..ceaa5851c 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -101,8 +101,8 @@ def 
fn(t, x): ("euler_maruyama", True), ("sea", False), ("shark", False), - ("fast_adaptive", False), - ("fast_adaptive", True), + ("two_step_adaptive", False), + ("two_step_adaptive", True), ], ) def test_forward_additive_ou_weak_means_and_vars(method, use_adapt): @@ -167,8 +167,8 @@ def diffusion_fn(t, x): ("euler_maruyama", True), ("sea", False), ("shark", False), - ("fast_adaptive", False), - ("fast_adaptive", True), + ("two_step_adaptive", False), + ("two_step_adaptive", True), ], ) def test_backward_additive_ou_weak_means_and_vars(method, use_adapt): @@ -218,6 +218,7 @@ def diffusion_fn(t, x): seed=seed, method=method, max_steps=1_000, + min_steps=100, ) x_0 = np.array(out["x"]) @@ -235,8 +236,8 @@ def diffusion_fn(t, x): ("euler_maruyama", True), ("sea", False), ("shark", False), - ("fast_adaptive", False), - ("fast_adaptive", True), + ("two_step_adaptive", False), + ("two_step_adaptive", True), ], ) def test_zero_noise_reduces_to_deterministic(method, use_adapt): From a771e3230b76d079ebc3d257b752cc16c94c0b92 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 1 Dec 2025 10:43:43 +0100 Subject: [PATCH 084/101] improved defaults --- bayesflow/utils/integrate.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index e299fef7b..34217908f 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -906,10 +906,6 @@ def _apply_corrector( if corrector_steps <= 0: return new_state - # Ensures score_fn and noise_schedule are present if needed, though checked in integrate_stochastic - if score_fn is None or noise_schedule is None: - return new_state # Should not happen if checks are passed - for j in range(corrector_steps): score = score_fn(new_time, **filter_kwargs(new_state, score_fn)) _z_corr = {k: corrector_noise_history[k][i, j] for k in new_state.keys()} From a7adea2bcb9cb086139153630ff4acb3aff9e7e6 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 1 Dec 2025 21:26:24 +0100 Subject: [PATCH 085/101] improved initial step size --- bayesflow/utils/integrate.py | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 34217908f..b8ed29afb 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -284,12 +284,7 @@ def integrate_adaptive( tolerance = keras.ops.convert_to_tensor(kwargs.get("tolerance", 1e-6), dtype="float32") step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=True) - - # Initial (conservative) step size guess - total_time = stop_time - start_time - step_size0 = keras.ops.convert_to_tensor(total_time / max_steps, dtype="float32") - - # Track step count as scalar tensor + initial_step = (stop_time - start_time) / float(min_steps) step0 = keras.ops.convert_to_tensor(0.0, dtype="float32") count_not_accepted = 0 @@ -308,18 +303,10 @@ def cond(_state, _time, _step_size, _step, _k1, _count_not_accepted): def body(_state, _time, _step_size, _step, _k1, _count_not_accepted): # Time remaining from current point - time_remaining = stop_time - _time - - # Per-step min/max step sizes (like original code) + time_remaining = keras.ops.abs(stop_time - _time) min_step_size = time_remaining / (max_steps - _step) max_step_size = time_remaining / keras.ops.maximum(min_steps - _step, 1.0) - - # Ensure ordering: min_step_size <= max_step_size - lower = keras.ops.minimum(min_step_size, max_step_size) - upper = keras.ops.maximum(min_step_size, max_step_size) - min_step_size = lower - 
max_step_size = upper - h = keras.ops.clip(_step_size, min_step_size, max_step_size) + h = keras.ops.sign(_step_size) * keras.ops.clip(keras.ops.abs(_step_size), min_step_size, max_step_size) # Take one trial step new_state, new_time, new_k1, err = step_fn( @@ -330,7 +317,9 @@ def body(_state, _time, _step_size, _step, _k1, _count_not_accepted): ) new_step_size = h * keras.ops.clip(0.9 * (tolerance / (err + 1e-12)) ** 0.2, 0.2, 5.0) - new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) + new_step_size = keras.ops.sign(new_step_size) * keras.ops.clip( + keras.ops.abs(new_step_size), min_step_size, max_step_size + ) # Error control: reject if err > tolerance too_big = keras.ops.greater(err, tolerance) @@ -355,7 +344,7 @@ def body(_state, _time, _step_size, _step, _k1, _count_not_accepted): state, time, step_size, step, k1, count_not_accepted = keras.ops.while_loop( cond, body, - [state, start_time, step_size0, step0, k1_0, count_not_accepted], + [state, start_time, initial_step, step0, k1_0, count_not_accepted], ) # Final step to hit stop_time exactly From 08853fbe73fbd55c39b118d69f92a76706ded506 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 1 Dec 2025 21:59:58 +0100 Subject: [PATCH 086/101] improved initial step size --- bayesflow/utils/integrate.py | 9 ++++----- tests/test_utils/test_integrate.py | 1 - 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index b8ed29afb..d947b54c0 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -640,7 +640,7 @@ def two_step_adaptive_step( if e_abs is None: e_abs = 0.02576 # 1% of 99% CI of standardized unit variance # Check if we're at minimum step size - if so, force acceptance - at_min_step = keras.ops.less_equal(step_size, min_step_size) + at_min_step = keras.ops.less_equal(keras.ops.abs(step_size), min_step_size) # Compute error tolerance for each component e_abs_tensor = keras.ops.cast(e_abs, dtype=keras.ops.dtype(list(state.values())[0])) @@ -681,9 +681,8 @@ def two_step_adaptive_step( new_step_candidate = step_size * adapt_factor # Clamp to valid range - sign_step = keras.ops.sign(step_size) - new_step_size = keras.ops.minimum(keras.ops.maximum(new_step_candidate, min_step_size), max_step_size) - new_step_size = sign_step * keras.ops.abs(new_step_size) + new_step_size = keras.ops.clip(keras.ops.abs(new_step_candidate), min_step_size, max_step_size) + new_step_size = keras.ops.sign(step_size) * new_step_size # Return appropriate state based on acceptance new_state = keras.ops.cond(accepted, lambda: state_heun, lambda: state) @@ -1147,7 +1146,7 @@ def integrate_stochastic( seed: keras.random.SeedGenerator, steps: int | Literal["adaptive"] = 100, method: str = "euler_maruyama", - min_steps: int = 10, + min_steps: int = 50, max_steps: int = 10_000, score_fn: Callable = None, corrector_steps: int = 0, diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index ceaa5851c..78adae35f 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -218,7 +218,6 @@ def diffusion_fn(t, x): seed=seed, method=method, max_steps=1_000, - min_steps=100, ) x_0 = np.array(out["x"]) From 23a69eac0be53ed05bcb2e78849056af88ef8849 Mon Sep 17 00:00:00 2001 From: arrjon Date: Tue, 2 Dec 2025 12:57:48 +0100 Subject: [PATCH 087/101] check nan in integrate --- bayesflow/utils/integrate.py | 70 +++++++++++++++++++++++++----- tests/test_utils/test_integrate.py | 1 - 2 files changed, 59 
insertions(+), 12 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index d947b54c0..2c027b3cc 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -22,6 +22,13 @@ STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "two_step_adaptive", "langevin"] +def _check_all_nans(state: StateDict): + all_nans_flags = [] + for v in state.values(): + all_nans_flags.append(keras.ops.all(keras.ops.isnan(v))) + return keras.ops.all(keras.ops.stack(all_nans_flags)) + + def euler_step( fn: Callable, state: StateDict, @@ -218,7 +225,8 @@ def integrate_fixed( def body(_loop_var, _loop_state): _state, _time = _loop_state _state, _time, _, _ = step_fn(_state, _time, step_size) - + if _check_all_nans(_state): + raise RuntimeError(f"All values are NaNs in state during integration at {_time}.") return _state, _time state, time = keras.ops.fori_loop(0, steps, body, (state, time)) @@ -251,6 +259,9 @@ def body(_loop_var, _loop_state): _time = steps[_loop_var] step_size = steps[_loop_var + 1] - steps[_loop_var] _loop_state, _, _, _ = step_fn(_loop_state, _time, step_size) + + if _check_all_nans(_loop_state): + raise RuntimeError(f"All values are NaNs in state during integration at {_time}.") return _loop_state state = keras.ops.fori_loop(0, len(steps) - 1, body, state) @@ -296,10 +307,12 @@ def cond(_state, _time, _step_size, _step, _k1, _count_not_accepted): step_lt_min = keras.ops.less(_step, float(min_steps)) step_lt_max = keras.ops.less(_step, float(max_steps)) - return keras.ops.logical_or( - step_lt_min, - keras.ops.logical_and(keras.ops.all(time_remaining > 0), step_lt_max), + all_nans = _check_all_nans(_state) + + end_now = keras.ops.logical_or( + step_lt_min, keras.ops.logical_and(keras.ops.all(time_remaining > 0), step_lt_max) ) + return keras.ops.logical_and(~all_nans, end_now) def body(_state, _time, _step_size, _step, _k1, _count_not_accepted): # Time remaining from current point @@ -347,6 +360,9 @@ def body(_state, _time, _step_size, _step, _k1, _count_not_accepted): [state, start_time, initial_step, step0, k1_0, count_not_accepted], ) + if _check_all_nans(state): + raise RuntimeError(f"All values are NaNs in state during integration at {time}.") + # Final step to hit stop_time exactly time_diff = stop_time - time time_remaining = keras.ops.sign(stop_time - start_time) * time_diff @@ -974,6 +990,9 @@ def body_fixed(_i, _loop_state): step_size_factor=step_size_factor, corrector_noise_history=corrector_noise_history, ) + all_nans = _check_all_nans(new_state) + if all_nans: + raise RuntimeError(f"All values are NaNs in state during integration at {_current_time}.") return new_state, new_time, initial_step # Execute the fixed loop @@ -1004,9 +1023,10 @@ def integrate_stochastic_adaptive( initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step, 0, state) def cond(i, current_state, current_time, current_step, counter, last_state): - # time remaining after the next step time_remaining = keras.ops.sign(stop_time - start_time) * (stop_time - (current_time + current_step)) - return keras.ops.logical_and(keras.ops.all(time_remaining > 0), keras.ops.less(i, max_steps)) + all_nans = _check_all_nans(current_state) + end_now = keras.ops.logical_and(keras.ops.all(time_remaining > 0), keras.ops.less(i, max_steps)) + return keras.ops.logical_and(~all_nans, end_now) def body_adaptive(_i, _current_state, _current_time, _current_step, _counter, _last_state): # Step Size Control @@ -1048,9 +1068,36 @@ def body_adaptive(_i, 
_current_state, _current_time, _current_step, _counter, _l return _i + 1, new_state, new_time, new_step, _counter, _new_current_state # Execute the adaptive loop - _, final_state, final_time, _, final_counter, _ = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) + _, final_state, final_time, _, final_counter, final_k1 = keras.ops.while_loop( + cond, body_adaptive, initial_loop_state + ) + + if _check_all_nans(final_state): + raise RuntimeError(f"All values are NaNs in state during integration at {final_time}.") + + # Final step to hit stop_time exactly + time_diff = stop_time - final_time + time_remaining = keras.ops.sign(stop_time - start_time) * time_diff + if keras.ops.all(time_remaining > 0): + noise_final = {k: z_history[k][-1] for k in final_state.keys()} + noise_extra_final = None + if len(z_extra_history) > 0: + noise_extra_final = {k: z_extra_history[k][-1] for k in final_state.keys()} + + final_state, _, _ = step_fn( + state=final_state, + time=final_time, + step_size=time_diff, + last_state=final_k1, + min_step_size=min_step_size, + max_step_size=time_remaining, + noise=noise_final, + noise_aux=noise_extra_final, + use_adaptive_step_size=False, + ) + final_counter = final_counter + 1 - logging.debug(f"Finished integration after {final_counter} steps at {final_time}.") + logging.debug(f"Finished integration after {final_counter}.") return final_state @@ -1094,13 +1141,12 @@ def integrate_langevin( def body(_i, loop_state): current_state, current_time = loop_state - t = current_time # score at current time - score = score_fn(t, **filter_kwargs(current_state, score_fn)) + score = score_fn(current_time, **filter_kwargs(current_state, score_fn)) # noise schedule - log_snr_t = noise_schedule.get_log_snr(t=t, training=False) + log_snr_t = noise_schedule.get_log_snr(t=current_time, training=False) _, sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) new_state: StateDict = {} @@ -1125,6 +1171,8 @@ def body(_i, loop_state): step_size_factor=step_size_factor, corrector_noise_history=corrector_noise_history, ) + if _check_all_nans(new_state): + raise RuntimeError(f"All values are NaNs in state during integration at {current_time}.") return new_state, new_time diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index 78adae35f..44a6fc60f 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -202,7 +202,6 @@ def diffusion_fn(t, x): # Start at time T with value x_T initial_state = {"x": keras.ops.ones((N,)) * x_T} steps = 200 if not use_adapt else "adaptive" - # Expected mean and variance at t=0 after integrating backward from t=T # For backward integration, the effective drift coefficient changes sign exp_mean = x_T * np.exp(-a * T) From b9e8c964cd31bd1b821bb039c3658ce63c75fbd7 Mon Sep 17 00:00:00 2001 From: arrjon Date: Tue, 2 Dec 2025 16:10:44 +0100 Subject: [PATCH 088/101] set default --- bayesflow/utils/integrate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 2c027b3cc..f0a36b06c 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -631,7 +631,7 @@ def two_step_adaptive_step( min_step_size=min_step_size, max_step_size=max_step_size, noise=noise, - use_adaptive_step_size=False, + use_adaptive_step_size=True, ) # Compute drift and diffusion at new state, but update from old state From be78470602ad64d91c43feb9110becd115136897 Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 3 Dec 2025 
16:43:13 +0100 Subject: [PATCH 089/101] update model defaults --- bayesflow/networks/diffusion_model/diffusion_model.py | 6 +++--- bayesflow/networks/flow_matching/flow_matching.py | 4 ++-- bayesflow/utils/integrate.py | 8 ++++++-- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 0e38ea4f1..8cbce1e87 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -40,13 +40,13 @@ class DiffusionModel(InferenceNetwork): "activation": "mish", "kernel_initializer": "he_normal", "residual": True, - "dropout": 0.0, + "dropout": 0.05, "spectral_normalization": False, } INTEGRATE_DEFAULT_CONFIG = { - "method": "rk45", - "steps": 100, + "method": "two_step_adaptive", + "steps": "adaptive", } def __init__( diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index fa74089a4..485cbbd9e 100644 --- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -53,8 +53,8 @@ class FlowMatching(InferenceNetwork): } INTEGRATE_DEFAULT_CONFIG = { - "method": "rk45", - "steps": 100, + "method": "tsit5", + "steps": "adaptive", } def __init__( diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index f0a36b06c..16bca18a2 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -65,7 +65,9 @@ def rk45_step( use_adaptive_step_size: bool = True, ) -> Tuple[StateDict, ArrayLike, StateDict | None, ArrayLike]: """ - Dormand-Prince 5(4) method with embedded error estimation. + Dormand-Prince 5(4) method with embedded error estimation [1]. + + Dormand (1996), Numerical Methods for Differential Equations: A Computational Approach """ h = step_size @@ -124,7 +126,9 @@ def tsit5_step( use_adaptive_step_size: bool = True, ) -> Tuple[StateDict, ArrayLike, StateDict | None, ArrayLike]: """ - Implements a single step of the Tsitouras 5/4 Runge-Kutta method. + Implements a single step of the Tsitouras 5/4 Runge-Kutta method [1]. 
+ + [1] Tsitouras (2011), Runge--Kutta pairs of order 5(4) satisfying only the first column simplifying assumption """ h = step_size From ad276063e5bcddf0c6bd03c85355b708e4b478e5 Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 3 Dec 2025 17:46:29 +0100 Subject: [PATCH 090/101] make loop jax compatible --- bayesflow/utils/integrate.py | 109 ++++++++++++++++++++--------------- 1 file changed, 64 insertions(+), 45 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 16bca18a2..db0ebc813 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -224,17 +224,22 @@ def integrate_fixed( step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False) step_size = (stop_time - start_time) / steps - time = start_time - - def body(_loop_var, _loop_state): - _state, _time = _loop_state - _state, _time, _, _ = step_fn(_state, _time, step_size) - if _check_all_nans(_state): - raise RuntimeError(f"All values are NaNs in state during integration at {_time}.") - return _state, _time + def cond(_loop_var, _loop_state, _loop_time): + all_nans = _check_all_nans(_loop_state) + end_now = keras.ops.less(_loop_var, steps) + return keras.ops.logical_and(~all_nans, end_now) - state, time = keras.ops.fori_loop(0, steps, body, (state, time)) + def body(_loop_var, _loop_state, _loop_time): + _loop_state, _loop_time, _, _ = step_fn(_loop_state, _loop_time, step_size) + return _loop_var + 1, _loop_state, _loop_time + _, state, _ = keras.ops.while_loop( + cond, + body, + [0, state, start_time], + ) + if _check_all_nans(state): + raise RuntimeError("All values are NaNs in state during integration.") return state @@ -259,16 +264,25 @@ def integrate_scheduled( step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False) + def cond(_loop_var, _loop_state): + all_nans = _check_all_nans(_loop_state) + end_now = keras.ops.less(_loop_var, len(steps) - 1) + return keras.ops.logical_and(~all_nans, end_now) + def body(_loop_var, _loop_state): _time = steps[_loop_var] step_size = steps[_loop_var + 1] - steps[_loop_var] _loop_state, _, _, _ = step_fn(_loop_state, _time, step_size) + return _loop_var + 1, _loop_state - if _check_all_nans(_loop_state): - raise RuntimeError(f"All values are NaNs in state during integration at {_time}.") - return _loop_state + _, state = keras.ops.while_loop( + cond, + body, + [0, state], + ) - state = keras.ops.fori_loop(0, len(steps) - 1, body, state) + if _check_all_nans(state): + raise RuntimeError("All values are NaNs in state during integration.") return state @@ -635,7 +649,7 @@ def two_step_adaptive_step( min_step_size=min_step_size, max_step_size=max_step_size, noise=noise, - use_adaptive_step_size=True, + use_adaptive_step_size=False, ) # Compute drift and diffusion at new state, but update from old state @@ -957,9 +971,12 @@ def integrate_stochastic_fixed( """ initial_step = (stop_time - start_time) / float(steps) - def body_fixed(_i, _loop_state): - _current_state, _current_time, _current_step = _loop_state + def cond(_loop_var, _loop_state, _loop_time, _loop_step): + all_nans = _check_all_nans(_loop_state) + end_now = keras.ops.less(_loop_var, steps) + return keras.ops.logical_and(~all_nans, end_now) + def body(_i, _current_state, _current_time, _current_step): # Determine step size: either the constant size or the remainder to reach stop_time remaining = keras.ops.abs(stop_time - _current_time) sign = keras.ops.sign(_current_step) @@ -994,13 +1011,16 @@ def body_fixed(_i, _loop_state): 
step_size_factor=step_size_factor, corrector_noise_history=corrector_noise_history, ) - all_nans = _check_all_nans(new_state) - if all_nans: - raise RuntimeError(f"All values are NaNs in state during integration at {_current_time}.") - return new_state, new_time, initial_step + return _i + 1, new_state, new_time, initial_step + + _, final_state, final_time, _ = keras.ops.while_loop( + cond, + body, + [0, state, start_time, initial_step], + ) + if _check_all_nans(final_state): + raise RuntimeError(f"All values are NaNs in state during integration at {final_time}.") - # Execute the fixed loop - final_state, final_time, _ = keras.ops.fori_loop(0, steps, body_fixed, (state, start_time, initial_step)) return final_state @@ -1024,22 +1044,21 @@ def integrate_stochastic_adaptive( """ Performs adaptive-step SDE integration. """ - initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step, 0, state) + initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step, state) - def cond(i, current_state, current_time, current_step, counter, last_state): + def cond(i, current_state, current_time, current_step, last_state): time_remaining = keras.ops.sign(stop_time - start_time) * (stop_time - (current_time + current_step)) all_nans = _check_all_nans(current_state) end_now = keras.ops.logical_and(keras.ops.all(time_remaining > 0), keras.ops.less(i, max_steps)) return keras.ops.logical_and(~all_nans, end_now) - def body_adaptive(_i, _current_state, _current_time, _current_step, _counter, _last_state): + def body_adaptive(_i, _current_state, _current_time, _current_step, _last_state): # Step Size Control remaining = keras.ops.abs(stop_time - _current_time) sign = keras.ops.sign(_current_step) # Ensure the next step does not overshoot the stop_time dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), remaining) dt = sign * dt_mag - _counter += 1 _noise_i = {k: z_history[k][_i] for k in _current_state.keys()} _noise_extra_i = None @@ -1069,12 +1088,10 @@ def body_adaptive(_i, _current_state, _current_time, _current_step, _counter, _l corrector_noise_history=corrector_noise_history, ) - return _i + 1, new_state, new_time, new_step, _counter, _new_current_state + return _i + 1, new_state, new_time, new_step, _new_current_state # Execute the adaptive loop - _, final_state, final_time, _, final_counter, final_k1 = keras.ops.while_loop( - cond, body_adaptive, initial_loop_state - ) + final_counter, final_state, final_time, _, final_k1 = keras.ops.while_loop(cond, body_adaptive, initial_loop_state) if _check_all_nans(final_state): raise RuntimeError(f"All values are NaNs in state during integration at {final_time}.") @@ -1143,27 +1160,30 @@ def integrate_langevin( dt = (stop_time - start_time) / float(steps) effective_factor = step_size_factor * 100 / np.sqrt(steps) - def body(_i, loop_state): - current_state, current_time = loop_state + def cond(_loop_var, _loop_state, _loop_time): + all_nans = _check_all_nans(_loop_state) + end_now = keras.ops.less(_loop_var, steps) + return keras.ops.logical_and(~all_nans, end_now) + def body(_i, _loop_state, _loop_time): # score at current time - score = score_fn(current_time, **filter_kwargs(current_state, score_fn)) + score = score_fn(_loop_time, **filter_kwargs(_loop_state, score_fn)) # noise schedule - log_snr_t = noise_schedule.get_log_snr(t=current_time, training=False) + log_snr_t = noise_schedule.get_log_snr(t=_loop_time, training=False) _, sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) 
new_state: StateDict = {} - for k in current_state.keys(): + for k in _loop_state.keys(): s_k = score.get(k, None) if s_k is None: - new_state[k] = current_state[k] + new_state[k] = _loop_state[k] continue e = effective_factor * sigma_t**2 - new_state[k] = current_state[k] + e * s_k + keras.ops.sqrt(2.0 * e) * z_history[k][_i] + new_state[k] = _loop_state[k] + e * s_k + keras.ops.sqrt(2.0 * e) * z_history[k][_i] - new_time = current_time + dt + new_time = _loop_time + dt new_state = _apply_corrector( new_state=new_state, @@ -1175,17 +1195,16 @@ def body(_i, loop_state): step_size_factor=step_size_factor, corrector_noise_history=corrector_noise_history, ) - if _check_all_nans(new_state): - raise RuntimeError(f"All values are NaNs in state during integration at {current_time}.") - return new_state, new_time + return _i + 1, new_state, new_time - final_state, _ = keras.ops.fori_loop( - 0, - steps, + _, final_state, final_time = keras.ops.while_loop( + cond, body, - (state, start_time), + (0, state, start_time), ) + if _check_all_nans(final_state): + raise RuntimeError(f"All values are NaNs in state during integration at {final_time}.") return final_state From 5a1a3fa4a90395e21657b93cf69b138270c43ea1 Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 3 Dec 2025 18:05:48 +0100 Subject: [PATCH 091/101] filter kwargs --- bayesflow/utils/integrate.py | 49 +++++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index db0ebc813..41af3ecc1 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -73,16 +73,21 @@ def rk45_step( if k1 is None: # reuse k1 if available k1 = fn(time, **filter_kwargs(state, fn)) - k2 = fn(time + h * (1 / 5), **add_scaled(state, [k1], [1 / 5], h)) - k3 = fn(time + h * (3 / 10), **add_scaled(state, [k1, k2], [3 / 40, 9 / 40], h)) - k4 = fn(time + h * (4 / 5), **add_scaled(state, [k1, k2, k3], [44 / 45, -56 / 15, 32 / 9], h)) + k2 = fn(time + h * (1 / 5), **filter_kwargs(add_scaled(state, [k1], [1 / 5], h), fn)) + k3 = fn(time + h * (3 / 10), **filter_kwargs(add_scaled(state, [k1, k2], [3 / 40, 9 / 40], h), fn)) + k4 = fn(time + h * (4 / 5), **filter_kwargs(add_scaled(state, [k1, k2, k3], [44 / 45, -56 / 15, 32 / 9], h), fn)) k5 = fn( time + h * (8 / 9), - **add_scaled(state, [k1, k2, k3, k4], [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], h), + **filter_kwargs( + add_scaled(state, [k1, k2, k3, k4], [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], h), fn + ), ) k6 = fn( time + h, - **add_scaled(state, [k1, k2, k3, k4, k5], [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], h), + **filter_kwargs( + add_scaled(state, [k1, k2, k3, k4, k5], [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], h), + fn, + ), ) # 5th order solution @@ -140,24 +145,38 @@ def tsit5_step( if k1 is None: # reuse k1 if available k1 = fn(time, **filter_kwargs(state, fn)) - k2 = fn(time + h * c2, **add_scaled(state, [k1], [0.161], h)) - k3 = fn(time + h * c3, **add_scaled(state, [k1, k2], [-0.0084806554923570, 0.3354806554923570], h)) + k2 = fn(time + h * c2, **filter_kwargs(add_scaled(state, [k1], [0.161], h), fn)) + k3 = fn( + time + h * c3, **filter_kwargs(add_scaled(state, [k1, k2], [-0.0084806554923570, 0.3354806554923570], h), fn) + ) k4 = fn( - time + h * c4, **add_scaled(state, [k1, k2, k3], [2.897153057105494, -6.359448489975075, 4.362295432869581], h) + time + h * c4, + **filter_kwargs( + add_scaled(state, [k1, k2, k3], 
[2.897153057105494, -6.359448489975075, 4.362295432869581], h), fn + ), ) k5 = fn( time + h * c5, - **add_scaled( - state, [k1, k2, k3, k4], [5.325864828439257, -11.74888356406283, 7.495539342889836, -0.09249506636175525], h + **filter_kwargs( + add_scaled( + state, + [k1, k2, k3, k4], + [5.325864828439257, -11.74888356406283, 7.495539342889836, -0.09249506636175525], + h, + ), + fn, ), ) k6 = fn( time + h, - **add_scaled( - state, - [k1, k2, k3, k4, k5], - [5.86145544294270, -12.92096931784711, 8.159367898576159, -0.07158497328140100, -0.02826905039406838], - h, + **filter_kwargs( + add_scaled( + state, + [k1, k2, k3, k4, k5], + [5.86145544294270, -12.92096931784711, 8.159367898576159, -0.07158497328140100, -0.02826905039406838], + h, + ), + fn, ), ) From e5857083a3470313ccb9d2e8edcaa6ab80693e88 Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 3 Dec 2025 18:35:13 +0100 Subject: [PATCH 092/101] fix density computation --- .../diffusion_model/diffusion_model.py | 12 +++++ .../networks/flow_matching/flow_matching.py | 23 +++++++-- bayesflow/utils/integrate.py | 50 +++++++------------ 3 files changed, 49 insertions(+), 36 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 8cbce1e87..c2f5b5fde 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -413,6 +413,12 @@ def _forward( raise ValueError("Stochastic methods are not supported for forward integration.") if density: + if integrate_kwargs["steps"] == "adaptive": + logging.warning( + "Using adaptive integration for density estimation can lead to " + "problems with autodiff. Switching to 200 fixed steps instead." + ) + integrate_kwargs["steps"] = 200 def deltas(time, xz): v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training) @@ -461,6 +467,12 @@ def _inverse( if density: if integrate_kwargs["method"] in STOCHASTIC_METHODS: raise ValueError("Stochastic methods are not supported for density computation.") + if integrate_kwargs["steps"] == "adaptive": + logging.warning( + "Using adaptive integration for density estimation can lead to " + "problems with autodiff. Switching to 200 fixed steps instead." + ) + integrate_kwargs["steps"] = 200 def deltas(time, xz): v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training) diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index 485cbbd9e..ea581a7c5 100644 --- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Sequence import keras @@ -236,14 +237,21 @@ def f(x): def _forward( self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs ) -> Tensor | tuple[Tensor, Tensor]: + integrate_kwargs = self.integrate_kwargs | kwargs if density: + if integrate_kwargs["steps"] == "adaptive": + logging.warning( + "Using adaptive integration for density estimation can lead to " + "problems with autodiff. Switching to 200 fixed steps instead." 
+ ) + integrate_kwargs["steps"] = 200 def deltas(time, xz): v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training) return {"xz": v, "trace": trace} state = {"xz": x, "trace": keras.ops.zeros(keras.ops.shape(x)[:-1] + (1,), dtype=keras.ops.dtype(x))} - state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs)) + state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **integrate_kwargs) z = state["xz"] log_density = self.base_distribution.log_prob(z) + keras.ops.squeeze(state["trace"], axis=-1) @@ -254,7 +262,7 @@ def deltas(time, xz): return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)} state = {"xz": x} - state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs)) + state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **integrate_kwargs) z = state["xz"] @@ -263,14 +271,21 @@ def deltas(time, xz): def _inverse( self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs ) -> Tensor | tuple[Tensor, Tensor]: + integrate_kwargs = self.integrate_kwargs | kwargs if density: + if integrate_kwargs["steps"] == "adaptive": + logging.warning( + "Using adaptive integration for density estimation can lead to " + "problems with autodiff. Switching to 200 fixed steps instead." + ) + integrate_kwargs["steps"] = 200 def deltas(time, xz): v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training) return {"xz": v, "trace": trace} state = {"xz": z, "trace": keras.ops.zeros(keras.ops.shape(z)[:-1] + (1,), dtype=keras.ops.dtype(z))} - state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs)) + state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **integrate_kwargs) x = state["xz"] log_density = self.base_distribution.log_prob(z) - keras.ops.squeeze(state["trace"], axis=-1) @@ -281,7 +296,7 @@ def deltas(time, xz): return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)} state = {"xz": z} - state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs)) + state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **integrate_kwargs) x = state["xz"] diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 41af3ecc1..b00a45325 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -22,13 +22,6 @@ STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "two_step_adaptive", "langevin"] -def _check_all_nans(state: StateDict): - all_nans_flags = [] - for v in state.values(): - all_nans_flags.append(keras.ops.all(keras.ops.isnan(v))) - return keras.ops.all(keras.ops.stack(all_nans_flags)) - - def euler_step( fn: Callable, state: StateDict, @@ -243,22 +236,17 @@ def integrate_fixed( step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False) step_size = (stop_time - start_time) / steps - def cond(_loop_var, _loop_state, _loop_time): - all_nans = _check_all_nans(_loop_state) - end_now = keras.ops.less(_loop_var, steps) - return keras.ops.logical_and(~all_nans, end_now) - - def body(_loop_var, _loop_state, _loop_time): - _loop_state, _loop_time, _, _ = step_fn(_loop_state, _loop_time, step_size) - return _loop_var + 1, _loop_state, _loop_time + def body(_loop_var, _loop_state): + _state, _time = _loop_state + _state, _time, _, _ = step_fn(_state, _time, step_size) + return _state, _time - _, state, _ = 
keras.ops.while_loop( - cond, + state, _ = keras.ops.fori_loop( + 0, + steps, body, - [0, state, start_time], + (state, start_time), ) - if _check_all_nans(state): - raise RuntimeError("All values are NaNs in state during integration.") return state @@ -283,25 +271,18 @@ def integrate_scheduled( step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False) - def cond(_loop_var, _loop_state): - all_nans = _check_all_nans(_loop_state) - end_now = keras.ops.less(_loop_var, len(steps) - 1) - return keras.ops.logical_and(~all_nans, end_now) - def body(_loop_var, _loop_state): _time = steps[_loop_var] step_size = steps[_loop_var + 1] - steps[_loop_var] _loop_state, _, _, _ = step_fn(_loop_state, _time, step_size) - return _loop_var + 1, _loop_state + return _loop_state - _, state = keras.ops.while_loop( - cond, + state = keras.ops.fori_loop( + 0, + keras.ops.shape(steps)[0] - 1, body, - [0, state], + state, ) - - if _check_all_nans(state): - raise RuntimeError("All values are NaNs in state during integration.") return state @@ -501,6 +482,11 @@ def integrate( ############ SDE Solvers ############# +def _check_all_nans(state: StateDict): + all_nans_flags = [] + for v in state.values(): + all_nans_flags.append(keras.ops.all(keras.ops.isnan(v))) + return keras.ops.all(keras.ops.stack(all_nans_flags)) def stochastic_adaptive_step_size_controller( From ac07af288a4162cd7a79340167c173bcf8f1a875 Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 3 Dec 2025 19:00:27 +0100 Subject: [PATCH 093/101] fix jax all nans --- bayesflow/utils/integrate.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index b00a45325..a208cb8e1 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -11,6 +11,7 @@ from bayesflow.types import Tensor from bayesflow.utils import filter_kwargs from bayesflow.utils.logging import warning +from keras import backend as K from . import logging @@ -22,6 +23,15 @@ STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "two_step_adaptive", "langevin"] +def _check_all_nans(state: StateDict): + if K.backend() == "jax": + return False # JAX backend does not support checks of the state variables + all_nans_flags = [] + for v in state.values(): + all_nans_flags.append(keras.ops.all(keras.ops.isnan(v))) + return keras.ops.all(keras.ops.stack(all_nans_flags)) + + def euler_step( fn: Callable, state: StateDict, @@ -482,11 +492,6 @@ def integrate( ############ SDE Solvers ############# -def _check_all_nans(state: StateDict): - all_nans_flags = [] - for v in state.values(): - all_nans_flags.append(keras.ops.all(keras.ops.isnan(v))) - return keras.ops.all(keras.ops.stack(all_nans_flags)) def stochastic_adaptive_step_size_controller( From 4c9d44b6540ffc373046989d5cc716e7790a1cfb Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 3 Dec 2025 21:16:42 +0100 Subject: [PATCH 094/101] fix jax all nans --- bayesflow/utils/integrate.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index a208cb8e1..3f5581894 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -11,7 +11,6 @@ from bayesflow.types import Tensor from bayesflow.utils import filter_kwargs from bayesflow.utils.logging import warning -from keras import backend as K from . 
import logging @@ -24,8 +23,6 @@ def _check_all_nans(state: StateDict): - if K.backend() == "jax": - return False # JAX backend does not support checks of the state variables all_nans_flags = [] for v in state.values(): all_nans_flags.append(keras.ops.all(keras.ops.isnan(v))) return keras.ops.all(keras.ops.stack(all_nans_flags)) @@ -376,7 +373,7 @@ def body(_state, _time, _step_size, _step, _k1, _count_not_accepted): # Step counter: increment only on accepted steps updated_step = _step + keras.ops.where(accepted, 1.0, 0.0) - _count_not_accepted = _count_not_accepted + 1 if not accepted else _count_not_accepted + _count_not_accepted = _count_not_accepted + keras.ops.where(accepted, 0.0, 1.0) # For the next iteration, always use the new suggested step size return updated_state, updated_time, new_step_size, updated_step, updated_k1, _count_not_accepted From 87297455b8be90dbe9bad92cfd710e9dde549600 Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 3 Dec 2025 21:17:58 +0100 Subject: [PATCH 095/101] fix jax all nans --- bayesflow/utils/integrate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 3f5581894..03c3aff8a 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -12,7 +12,7 @@ from bayesflow.utils import filter_kwargs from bayesflow.utils.logging import warning -from . import logging +import logging ArrayLike = int | float | Tensor StateDict = Dict[str, ArrayLike] From adefa7b5dcece15c61cdd714f87c67b515fcf212 Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 3 Dec 2025 22:34:35 +0100 Subject: [PATCH 096/101] relax tols in tests --- tests/test_utils/test_integrate.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index 44a6fc60f..b3214c229 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -8,9 +8,8 @@ TOLERANCE_EULER = 1e-3 # Euler with fixed steps requires a larger tolerance # tolerances for SDE tests -TOL_MEAN = 3e-2 +TOL_MEAN = 5e-2 TOL_VAR = 5e-2 -TOL_DET = 1e-3 @pytest.mark.parametrize("method", ["euler", "rk45", "tsit5"]) @@ -123,7 +122,6 @@ def test_forward_additive_ou_weak_means_and_vars(method, use_adapt): x_0 = 1.2 # initial condition at time 0 T = 1.0 - # batch of trajectories N = 10000 seed = keras.random.SeedGenerator(42) @@ -149,15 +147,14 @@ def diffusion_fn(t, x): steps=steps, seed=seed, method=method, - max_steps=1_000, ) x_T = np.array(out["x"]) emp_mean = float(x_T.mean()) emp_var = float(x_T.var()) - np.testing.assert_allclose(emp_mean, exp_mean, atol=TOL_MEAN, rtol=0.0) - np.testing.assert_allclose(emp_var, exp_var, atol=TOL_VAR, rtol=0.0) + np.testing.assert_allclose(emp_mean, exp_mean, atol=TOL_MEAN) + np.testing.assert_allclose(emp_var, exp_var, atol=TOL_VAR) @pytest.mark.parametrize( @@ -188,8 +185,7 @@ def test_backward_additive_ou_weak_means_and_vars(method, use_adapt): x_T = 1.2 # initial condition at time T T = 1.0 - # batch of trajectories - N = 10000 # large enough to control sampling error + N = 10000 seed = keras.random.SeedGenerator(42) def drift_fn(t, x): @@ -216,15 +212,14 @@ def diffusion_fn(t, x): steps=steps, seed=seed, method=method, - max_steps=1_000, ) x_0 = np.array(out["x"]) emp_mean = float(x_0.mean()) emp_var = float(x_0.var()) - np.testing.assert_allclose(emp_mean, exp_mean, atol=TOL_MEAN, rtol=0.0) - np.testing.assert_allclose(emp_var, exp_var, atol=TOL_VAR, rtol=0.0) + np.testing.assert_allclose(emp_mean, exp_mean, atol=TOL_MEAN) +
np.testing.assert_allclose(emp_var, exp_var, atol=TOL_VAR) @pytest.mark.parametrize( @@ -270,7 +265,7 @@ def diffusion_fn(t, x): )["x"] exact = x0 * np.exp(a * T) - np.testing.assert_allclose(np.array(out).mean(), exact, atol=TOL_DET, rtol=0.1) + np.testing.assert_allclose(np.array(out).mean(), exact, atol=1e-3, rtol=0.1) @pytest.mark.parametrize("steps", [500]) From f9823f8c71307cb3d69d09a6e8772b16bcff3d59 Mon Sep 17 00:00:00 2001 From: arrjon Date: Thu, 4 Dec 2025 16:09:23 +0100 Subject: [PATCH 097/101] enable density computation with adaptive step size solvers --- .../networks/diffusion_model/diffusion_model.py | 12 ------------ bayesflow/networks/flow_matching/flow_matching.py | 13 ------------- 2 files changed, 25 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index c2f5b5fde..8cbce1e87 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -413,12 +413,6 @@ def _forward( raise ValueError("Stochastic methods are not supported for forward integration.") if density: - if integrate_kwargs["steps"] == "adaptive": - logging.warning( - "Using adaptive integration for density estimation can lead to " - "problems with autodiff. Switching to 200 fixed steps instead." - ) - integrate_kwargs["steps"] = 200 def deltas(time, xz): v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training) @@ -467,12 +461,6 @@ def _inverse( if density: if integrate_kwargs["method"] in STOCHASTIC_METHODS: raise ValueError("Stochastic methods are not supported for density computation.") - if integrate_kwargs["steps"] == "adaptive": - logging.warning( - "Using adaptive integration for density estimation can lead to " - "problems with autodiff. Switching to 200 fixed steps instead." - ) - integrate_kwargs["steps"] = 200 def deltas(time, xz): v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training) diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index ea581a7c5..808cee681 100644 --- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -1,4 +1,3 @@ -import logging from collections.abc import Sequence import keras @@ -239,12 +238,6 @@ def _forward( ) -> Tensor | tuple[Tensor, Tensor]: integrate_kwargs = self.integrate_kwargs | kwargs if density: - if integrate_kwargs["steps"] == "adaptive": - logging.warning( - "Using adaptive integration for density estimation can lead to " - "problems with autodiff. Switching to 200 fixed steps instead." - ) - integrate_kwargs["steps"] = 200 def deltas(time, xz): v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training) @@ -273,12 +266,6 @@ def _inverse( ) -> Tensor | tuple[Tensor, Tensor]: integrate_kwargs = self.integrate_kwargs | kwargs if density: - if integrate_kwargs["steps"] == "adaptive": - logging.warning( - "Using adaptive integration for density estimation can lead to " - "problems with autodiff. Switching to 200 fixed steps instead." 
- ) - integrate_kwargs["steps"] = 200 def deltas(time, xz): v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training) From 1129b1c4c5c528d91de5aa9e9b83596248167dd2 Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 10 Dec 2025 21:38:38 +0100 Subject: [PATCH 098/101] merge new samplers --- .../diffusion_model/compositional_diffusion_model.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/bayesflow/networks/diffusion_model/compositional_diffusion_model.py b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py index 171184314..2ba1197f7 100644 --- a/bayesflow/networks/diffusion_model/compositional_diffusion_model.py +++ b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py @@ -32,14 +32,13 @@ class CompositionalDiffusionModel(DiffusionModel): "activation": "mish", "kernel_initializer": "he_normal", "residual": True, - "dropout": 0.0, + "dropout": 0.05, "spectral_normalization": False, } INTEGRATE_DEFAULT_CONFIG = { - "method": "euler_maruyama", - "corrector_steps": 1, - "steps": 100, + "method": "two_step_adaptive", + "steps": "adaptive", } def __init__( From 77e950c23bf3832c1b518d3d3646305f48d66d7a Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 10 Dec 2025 22:09:54 +0100 Subject: [PATCH 099/101] merge new samplers --- .../compositional_diffusion_model.py | 38 ++++++++----------- bayesflow/workflows/basic_workflow.py | 2 +- 2 files changed, 17 insertions(+), 23 deletions(-) diff --git a/bayesflow/networks/diffusion_model/compositional_diffusion_model.py b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py index 2ba1197f7..3d26639ab 100644 --- a/bayesflow/networks/diffusion_model/compositional_diffusion_model.py +++ b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py @@ -5,11 +5,7 @@ from keras import ops from bayesflow.types import Tensor -from bayesflow.utils import ( - expand_right_as, - integrate, - integrate_stochastic, -) +from bayesflow.utils import expand_right_as, integrate, integrate_stochastic, STOCHASTIC_METHODS from bayesflow.utils.serialization import serializable from .diffusion_model import DiffusionModel from .schedules.noise_schedule import NoiseSchedule @@ -318,7 +314,7 @@ def _inverse_compositional( z = z / ops.sqrt(ops.cast(scale_latent, dtype=ops.dtype(z))) if density: - if integrate_kwargs["method"] == "euler_maruyama": + if integrate_kwargs["method"] in STOCHASTIC_METHODS: raise ValueError("Stochastic methods are not supported for density computation.") def deltas(time, xz): @@ -346,7 +342,7 @@ def deltas(time, xz): state = {"xz": z} - if integrate_kwargs["method"] == "euler_maruyama": + if integrate_kwargs["method"] in STOCHASTIC_METHODS: def deltas(time, xz): return { @@ -365,20 +361,19 @@ def diffusion(time, xz): return {"xz": self.diffusion_term(xz, time=time, training=training)} score_fn = None - if "corrector_steps" in integrate_kwargs: - if integrate_kwargs["corrector_steps"] > 0: - - def score_fn(time, xz): - return { - "xz": self.compositional_score( - xz, - time=time, - conditions=conditions, - compute_prior_score=compute_prior_score, - mini_batch_size=mini_batch_size, - training=training, - ) - } + if "corrector_steps" in integrate_kwargs or integrate_kwargs.get("method") == "langevin": + + def score_fn(time, xz): + return { + "xz": self.compositional_score( + xz, + time=time, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + } state = integrate_stochastic( 
drift_fn=deltas, @@ -390,7 +385,6 @@ def score_fn(time, xz): **integrate_kwargs, ) else: - integrate_kwargs.pop("corrector_steps", None) def deltas(time, xz): return { diff --git a/bayesflow/workflows/basic_workflow.py b/bayesflow/workflows/basic_workflow.py index 75763ecec..a4216066a 100644 --- a/bayesflow/workflows/basic_workflow.py +++ b/bayesflow/workflows/basic_workflow.py @@ -291,7 +291,7 @@ def compositional_sample( *, num_samples: int, conditions: Mapping[str, np.ndarray], - compute_prior_score: Callable[[Mapping[str, np.ndarray]], np.ndarray], + compute_prior_score: Callable[[Mapping[str, np.ndarray]], Mapping[str, np.ndarray]], **kwargs, ) -> dict[str, np.ndarray]: """ From 322fb4b81643e273970bfb1ecaf4943062bea3c4 Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 10 Dec 2025 22:12:25 +0100 Subject: [PATCH 100/101] type hint --- bayesflow/workflows/basic_workflow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesflow/workflows/basic_workflow.py b/bayesflow/workflows/basic_workflow.py index a4216066a..3c8953804 100644 --- a/bayesflow/workflows/basic_workflow.py +++ b/bayesflow/workflows/basic_workflow.py @@ -291,7 +291,7 @@ def compositional_sample( *, num_samples: int, conditions: Mapping[str, np.ndarray], - compute_prior_score: Callable[[Mapping[str, np.ndarray]], Mapping[str, np.ndarray]], + compute_prior_score: Callable[Mapping[str, np.ndarray], Mapping[str, np.ndarray]], **kwargs, ) -> dict[str, np.ndarray]: """ @@ -307,7 +307,7 @@ def compositional_sample( NumPy arrays containing the adapted simulated variables. Keys used as summary or inference conditions during training should be present. Should have shape (n_datasets, n_compositional_conditions, ...). - compute_prior_score : Callable[[Mapping[str, np.ndarray]], np.ndarray] + compute_prior_score : Callable[dict[str, np.ndarray], dict[str, np.ndarray]] A function that computes the log probability of samples under the prior distribution. **kwargs : dict, optional Additional keyword arguments passed to the approximator's sampling function. From 4753236d3e233a1a39057cd6a0c5b9593f7c40e8 Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 10 Dec 2025 22:13:14 +0100 Subject: [PATCH 101/101] type hint --- bayesflow/workflows/basic_workflow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesflow/workflows/basic_workflow.py b/bayesflow/workflows/basic_workflow.py index 3c8953804..6500125a7 100644 --- a/bayesflow/workflows/basic_workflow.py +++ b/bayesflow/workflows/basic_workflow.py @@ -291,7 +291,7 @@ def compositional_sample( *, num_samples: int, conditions: Mapping[str, np.ndarray], - compute_prior_score: Callable[Mapping[str, np.ndarray], Mapping[str, np.ndarray]], + compute_prior_score: Callable[[Mapping[str, np.ndarray]], Mapping[str, np.ndarray]], **kwargs, ) -> dict[str, np.ndarray]: """ @@ -307,7 +307,7 @@ def compositional_sample( NumPy arrays containing the adapted simulated variables. Keys used as summary or inference conditions during training should be present. Should have shape (n_datasets, n_compositional_conditions, ...). - compute_prior_score : Callable[dict[str, np.ndarray], dict[str, np.ndarray]] + compute_prior_score : Callable[[dict[str, np.ndarray]], dict[str, np.ndarray]] A function that computes the log probability of samples under the prior distribution. **kwargs : dict, optional Additional keyword arguments passed to the approximator's sampling function.
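
For illustration, a prior score callable matching the signature above could look like the following minimal sketch. It assumes, as the parameter name suggests, that the callable returns the gradient of the log prior for each inference variable; the variable name "theta", the standard Gaussian prior, and the fitted workflow object in the commented usage are illustrative placeholders, not part of the patches.

import numpy as np

def compute_prior_score(samples: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    # Score (gradient of the log prior) of a standard Gaussian prior:
    # d/d_theta log N(theta | 0, 1) = -theta. Shapes are preserved, so the
    # result matches samples of shape (n_datasets, num_samples, ...).
    return {"theta": -samples["theta"]}

# Hypothetical usage with a fitted BasicWorkflow instance:
# posterior_samples = workflow.compositional_sample(
#     num_samples=500,
#     conditions=conditions,
#     compute_prior_score=compute_prior_score,
# )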