From 038b7c91402b75fb3b988730db1844a2ce18a625 Mon Sep 17 00:00:00 2001 From: Miguel de la Varga Date: Thu, 20 Nov 2025 15:34:27 +0100 Subject: [PATCH 1/4] [ENH] Improve tensor handling and device management in backend_tensor.py - Updated tensor creation to use pinned memory and non-blocking transfer for improved GPU performance. - Introduced `_zeros`, `_ones`, and `_eye` wrapper functions for better consistency in tensor initialization on the specified device. - Refined the `_wrap_pytorch_functions` method to streamline tensor operations and ensure compatibility with the device settings. - Enabled stricter CUDA checks by updating conditions for GPU availability. --- gempy_engine/core/backend_tensor.py | 26 ++++++++++++++++++++------ gempy_engine/core/utils.py | 12 ++++++------ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/gempy_engine/core/backend_tensor.py b/gempy_engine/core/backend_tensor.py index c79263c..9e4d310 100644 --- a/gempy_engine/core/backend_tensor.py +++ b/gempy_engine/core/backend_tensor.py @@ -128,13 +128,13 @@ def _change_backend(cls, engine_backend: AvailableBackends, use_pykeops: bool = # Check if CUDA is available if not pytorch_copy.cuda.is_available(): raise RuntimeError("GPU requested but CUDA is not available in PyTorch") - if False: # * (Miguel) this slows down the code a lot + if True: # * (Miguel) this slows down the code a lot # Check if CUDA device is available if not pytorch_copy.cuda.device_count(): raise RuntimeError("GPU requested but no CUDA device is available in PyTorch") # Set default device to CUDA cls.device = pytorch_copy.device("cuda") - pytorch_copy.set_default_device("cuda") + # pytorch_copy.set_default_device("cuda") print(f"GPU enabled. Using device: {cls.device}") print(f"GPU device count: {pytorch_copy.cuda.device_count()}") print(f"Current GPU device: {pytorch_copy.cuda.current_device()}") @@ -166,7 +166,7 @@ def describe_conf(cls): @classmethod def _wrap_pytorch_functions(cls): - from torch import sum, repeat_interleave, isclose + from torch import sum, repeat_interleave, isclose, zeros as torch_zeros, eye as torch_eye, ones as torch_ones import torch def _sum(tensor, axis=None, dtype=None, keepdims=False): @@ -192,7 +192,9 @@ def _array(array_like, dtype=None): if not array_like.flags.c_contiguous: array_like = numpy.ascontiguousarray(array_like) - return torch.tensor(array_like, dtype=dtype) + + # return torch.tensor(array_like, dtype=dtype) + return torch.tensor(array_like, dtype=dtype).pin_memory().to(cls.device, non_blocking=True) def _concatenate(tensors, axis=0, dtype=None): # Switch if tensor is numpy array or a torch tensor @@ -225,7 +227,7 @@ def _packbits(tensor, axis=None, bitorder="big"): # Pad with zeros if we don't have multiples of 8 rows if n_rows % 8 != 0: padding_rows = 8 - (n_rows % 8) - padding = torch.zeros(padding_rows, n_cols, dtype=torch.uint8, device=tensor.device) + padding = torch_zeros((padding_rows, n_cols), dtype=torch.uint8, device=tensor.device) tensor = torch.cat([tensor, padding], dim=0) # Reshape to group every 8 rows together: (n_output_rows, 8, n_cols) @@ -254,7 +256,7 @@ def _packbits(tensor, axis=None, bitorder="big"): # Pad with zeros if needed if n_cols % 8 != 0: padding_cols = 8 - (n_cols % 8) - padding = torch.zeros(n_rows, padding_cols, dtype=torch.uint8, device=tensor.device) + padding = torch_zeros((n_rows, padding_cols), dtype=torch.uint8, device=tensor.device) tensor = torch.cat([tensor, padding], dim=1) # Reshape: (n_rows, n_output_cols, 8) @@ -294,6 +296,15 @@ def 
_fill_diagonal(tensor, value): diagonal_indices = torch.arange(min(tensor.size(0), tensor.size(1))) tensor[diagonal_indices, diagonal_indices] = value return tensor + + def _zeros(shape, dtype=None, device=None): + return torch_zeros(shape, dtype=dtype, device=cls.device) + + def _ones(shape, dtype=None, device=None): + return torch_ones(shape, dtype=dtype, device=cls.device) + + def _eye(n, dtype=None, device=None): + return torch_eye(n, dtype=dtype, device=cls.device) cls.tfnp.sum = _sum cls.tfnp.repeat = _repeat @@ -324,6 +335,9 @@ def _fill_diagonal(tensor, value): atol=atol, equal_nan=equal_nan ) + cls.tfnp.zeros = _zeros + cls.tfnp.eye = _eye + cls.tfnp.ones = _ones @classmethod def _wrap_pykeops_functions(cls): diff --git a/gempy_engine/core/utils.py b/gempy_engine/core/utils.py index a7e4398..7d14c24 100644 --- a/gempy_engine/core/utils.py +++ b/gempy_engine/core/utils.py @@ -4,8 +4,8 @@ from ..core.backend_tensor import BackendTensor -def cast_type_inplace(data_instance: Any, requires_grad:bool = False): - """Converts all numpy arrays to the global dtype""" +def cast_type_inplace(data_instance: Any, requires_grad: bool = False): + """Converts all numpy arrays to the global dtype""" for key, val in data_instance.__dict__.items(): if type(val) != np.ndarray: continue match BackendTensor.engine_backend: @@ -15,11 +15,11 @@ def cast_type_inplace(data_instance: Any, requires_grad:bool = False): # tensor = BackendTensor.t.from_numpy(val.astype(BackendTensor.dtype)) # if (BackendTensor.use_gpu): # tensor = tensor.cuda() - - tensor = BackendTensor.tfnp.array(val, dtype=BackendTensor.dtype_obj) - tensor.requires_grad = requires_grad + import torch + if isinstance(val, torch.Tensor): + continue + tensor = torch.tensor(val, dtype=BackendTensor.dtype_obj, requires_grad=requires_grad).pin_memory().to(BackendTensor.device, non_blocking=True) data_instance.__dict__[key] = tensor - def gempy_profiler_decorator(func): From 6908ef523b559a0ab3e3bf5761ea5742defdfcc5 Mon Sep 17 00:00:00 2001 From: Miguel de la Varga Date: Thu, 20 Nov 2025 16:43:00 +0100 Subject: [PATCH 2/4] [ENH] Add `keops_enabled` parameter to improve kernel constructor modularity and enhance batch processing support - Introduced the `keops_enabled` parameter across various modules to enable conditional usage of PyKeOps for optimized computations. - Added `_interpolate_stack_batched.py` for GPU-accelerated batched interpolation with CUDA streams, minimizing memory overhead and improving throughput. - Updated tensor creation logic in `backend_tensor.py` to include `pykeops_eval_enabled` for enhanced flexibility in method selection. - Refactored multiple constructor methods to propagate `keops_enabled`, ensuring consistent conditional logic for tensor handling and backend compatibility. - Improved fault data initialization and dependency handling in interpolation pipelines for better parallel computation. 
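For readers skimming the diff below: the new `keops_enabled` argument simply decides, per constructor, whether the broadcast coordinate tensors get wrapped as pykeops LazyTensors or stay dense. A minimal sketch of that switch, assuming pykeops is installed (the helper name `_cast` is illustrative, not part of gempy_engine, and inputs are expected to be pre-broadcast to (N, 1, D) / (1, M, D) as in `_structs.py`):

import torch

def _cast(tensor: torch.Tensor, keops_enabled: bool):
    # When enabled, wrap the dense tensor in a LazyTensor so later reductions
    # are evaluated symbolically instead of materialising full N x M matrices.
    if keops_enabled:
        from pykeops.torch import LazyTensor
        return LazyTensor(tensor.contiguous())
    return tensor

Passing the flag explicitly, instead of reading the global `BackendTensor.pykeops_enabled`, is what lets the covariance assembly and the evaluation path make different choices within a single run.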
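The batched interpolation in `_interpolate_stack_batched.py` relies on one CUDA stream per stack plus events to order fault dependencies. Stripped of the gempy_engine data structures, the scheduling pattern looks roughly like this (a sketch only; `solve_stack` and `deps` are placeholders for the per-stack kriging solve and its dependency indices):

import torch

def run_stacks(solve_stack, deps, n_stacks):
    streams = [torch.cuda.Stream() for _ in range(n_stacks)]
    events = [torch.cuda.Event() for _ in range(n_stacks)]
    results = [None] * n_stacks
    for i in range(n_stacks):
        with torch.cuda.stream(streams[i]):
            # Wait only on the stacks this one actually reads from (e.g. fault drifts).
            for d in deps[i]:
                if d < i:
                    streams[i].wait_event(events[d])
            results[i] = solve_stack(i)   # kernels queued asynchronously on stream i
            events[i].record(streams[i])  # signal completion to dependent stacks
    torch.cuda.synchronize()              # drain all streams before returning to Python
    return results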
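The same reasoning applies to the pinned-memory transfer adopted in PATCH 1/4 above: arrays are staged in page-locked host memory so the host-to-device copy can run asynchronously. A standalone sketch, assuming CUDA is available (the helper name `to_device` is not part of the library):

import numpy as np
import torch

def to_device(array: np.ndarray, device: torch.device, dtype=torch.float32) -> torch.Tensor:
    # Mirror the library's check: normalise to a C-contiguous buffer first.
    if not array.flags.c_contiguous:
        array = np.ascontiguousarray(array)
    host = torch.tensor(array, dtype=dtype)
    if device.type == "cuda":
        # pin_memory() keeps the staging copy page-locked so the host-to-device
        # transfer can overlap other work when non_blocking=True.
        return host.pin_memory().to(device, non_blocking=True)
    return host

Note that `non_blocking=True` only pays off when the source is pinned; on a CPU-only device the function simply returns the host tensor, which mirrors the `use_gpu` guard introduced later in PATCH 4/4.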
--- .../_interpolate_stack_batched.py | 367 ++++++++++++++++++ .../_multi_scalar_field_manager.py | 30 +- gempy_engine/core/backend_tensor.py | 3 +- gempy_engine/core/utils.py | 5 +- .../_covariance_assembler.py | 5 +- .../kernel_constructor/_kernels_assembler.py | 5 +- .../modules/kernel_constructor/_structs.py | 34 +- .../_vectors_preparation.py | 45 ++- 8 files changed, 445 insertions(+), 49 deletions(-) create mode 100644 gempy_engine/API/interp_single/_interpolate_stack_batched.py diff --git a/gempy_engine/API/interp_single/_interpolate_stack_batched.py b/gempy_engine/API/interp_single/_interpolate_stack_batched.py new file mode 100644 index 0000000..61788cc --- /dev/null +++ b/gempy_engine/API/interp_single/_interpolate_stack_batched.py @@ -0,0 +1,367 @@ +from typing import List + +import numpy as np +import torch + +from ._interp_single_feature import input_preprocess, _interpolate_external_function +from ...core.backend_tensor import BackendTensor +from ...core.data import TensorsStructure +from ...core.data.input_data_descriptor import InputDataDescriptor +from ...core.data.interpolation_input import InterpolationInput +from ...core.data.kernel_classes.faults import FaultsData +from ...core.data.options import InterpolationOptions +from ...core.data.scalar_field_output import ScalarFieldOutput +from ...core.data.stack_relation_type import StackRelationType +# ... existing code ... +from ...modules.evaluator.generic_evaluator import generic_evaluator +from ...modules.evaluator.symbolic_evaluator import symbolic_evaluator +from ...modules.kernel_constructor import kernel_constructor_interface as kernel_constructor + + +def _interpolate_stack_batched(root_data_descriptor: InputDataDescriptor, root_interpolation_input: InterpolationInput, + options: InterpolationOptions) -> List[ScalarFieldOutput]: + """ + Optimized interpolation using CUDA streams. + Solves each stack one-by-one in its own stream to maximize GPU throughput + without memory overhead of padding/stacking matrices. + """ + stack_structure = root_data_descriptor.stack_structure + n_stacks = stack_structure.n_stacks + + # Result holder + all_scalar_fields_outputs: List[ScalarFieldOutput | None] = [None] * n_stacks + + # Shared memory for fault interactions (pre-allocated on GPU) + xyz_to_interpolate_size: int = root_interpolation_input.grid.len_all_grids + root_interpolation_input.surface_points.n_points + all_stack_values_block: torch.Tensor = BackendTensor.t.zeros( + (n_stacks, xyz_to_interpolate_size), + dtype=BackendTensor.dtype_obj, + device=BackendTensor.device + ) + + # Create a stream for each stack to allow concurrent execution + streams = [torch.cuda.Stream() for _ in range(n_stacks)] + # Events to signal when a stack is fully computed (for dependencies) + stack_done_events = [torch.cuda.Event() for _ in range(n_stacks)] + BackendTensor.pykeops_enable = False + + for i in range(n_stacks): + stream = streams[i] + with torch.cuda.stream(stream): + # === 1. Python Setup (Runs on CPU) === + stack_structure.stack_number = i + tensor_struct_i = TensorsStructure.from_tensor_structure_subset(root_data_descriptor, i) + interpolation_input_i = InterpolationInput.from_interpolation_input_subset( + all_interpolation_input=root_interpolation_input, + stack_structure=stack_structure + ) + + # === 2. Dependency Handling (GPU Synchronization) === + # If this stack depends on faults, we must wait for those specific stacks to finish. 
+ active_faults = stack_structure.active_faults_relations + if active_faults is not None: + # active_faults is typically a boolean mask or list of indices + if hasattr(active_faults, 'dtype') and active_faults.dtype == bool: + dep_indices = np.where(active_faults)[0] + elif isinstance(active_faults, (list, tuple, np.ndarray)): + dep_indices = active_faults + else: + dep_indices = [] + + # Make current stream wait for dependency events + for dep_idx in dep_indices: + if dep_idx < i: # Sanity check + stream.wait_event(stack_done_events[dep_idx]) + + # Now it is safe to read from all_stack_values_block + fault_data = interpolation_input_i.fault_values or FaultsData() + if interpolation_input_i.fault_values: + fault_data.fault_values_everywhere = all_stack_values_block[stack_structure.active_faults_relations] + # Slice data for Surface Points (SP) + fv_on_all_sp = fault_data.fault_values_everywhere[:, interpolation_input_i.grid.len_all_grids:] + fault_data.fault_values_on_sp = fv_on_all_sp[:, interpolation_input_i.slice_feature] + interpolation_input_i.fault_values = fault_data + + # === 3. Execution Pipeline (Queued on GPU) === + solver_input = input_preprocess(tensor_struct_i, interpolation_input_i) + + if stack_structure.interp_function is None: + # --- A. Kriging Solve --- + # Prepare Matrices (Kernel Construction) + A_mat = kernel_constructor.yield_covariance(solver_input, options.kernel_options) + b_vec = kernel_constructor.yield_b_vector(solver_input.ori_internal, A_mat.shape[0]) + + # Solve System (Async GPU call) + # No padding needed, we solve exact size + weights = torch.linalg.solve(A_mat, b_vec) + + # Evaluate Field + if BackendTensor.pykeops_eval_enabled: + exported_fields = symbolic_evaluator(solver_input, weights, options) + else: + exported_fields = generic_evaluator(solver_input, weights, options) + + # Metadata + exported_fields.set_structure_values( + reference_sp_position=tensor_struct_i.reference_sp_position, + slice_feature=interpolation_input_i.slice_feature, + grid_size=interpolation_input_i.grid.len_all_grids + ) + exported_fields.debug = solver_input.debug + + else: + # --- B. 
External Function --- + weights = None + xyz = interpolation_input_i.grid.values + exported_fields = _interpolate_external_function( + stack_structure.interp_function, xyz + ) + exported_fields.set_structure_values( + reference_sp_position=None, + slice_feature=None, + grid_size=xyz.shape[0] + ) + + # --- Post-Processing --- + if stack_structure.segmentation_function is not None: + sigmoid_slope = stack_structure.segmentation_function(solver_input.xyz_to_interpolate) + else: + sigmoid_slope = options.sigmoid_slope + + from ...modules.activator import activator_interface + values_block = activator_interface.activate_formation_block(exported_fields, interpolation_input_i.unit_values, sigmoid_slope=sigmoid_slope) + + output = ScalarFieldOutput( + weights=weights, + grid=interpolation_input_i.grid, + exported_fields=exported_fields, + values_block=values_block, + stack_relation=interpolation_input_i.stack_relation + ) + all_scalar_fields_outputs[i] = output + + # Update Shared Block (in-place GPU write) + if interpolation_input_i.stack_relation is StackRelationType.FAULT: + val_min = BackendTensor.t.min(output.values_on_all_xyz, axis=1).reshape(-1, 1) + shifted_vals = (output.values_on_all_xyz - val_min) + + if fault_data.finite_faults_defined: + finite_fault_scalar = fault_data.finite_fault_data.apply(points=solver_input.xyz_to_interpolate) + fault_scalar_field = shifted_vals * finite_fault_scalar + else: + fault_scalar_field = shifted_vals + + all_stack_values_block[i, :] = fault_scalar_field + else: + all_stack_values_block[i, :] = output.values_on_all_xyz + + # Record that this stack is finished + stack_done_events[i].record(stream) + + # Wait for all streams to finish before returning results to Python + torch.cuda.synchronize() + + return all_scalar_fields_outputs + + + +def _interpolate_stack_batched_(root_data_descriptor: InputDataDescriptor, root_interpolation_input: InterpolationInput, + options: InterpolationOptions) -> List[ScalarFieldOutput]: + """Optimized batched interpolation for PyTorch backend.""" + stack_structure = root_data_descriptor.stack_structure + n_stacks = stack_structure.n_stacks + + all_scalar_fields_outputs: List[ScalarFieldOutput | None] = [None] * n_stacks + + xyz_to_interpolate_size: int = root_interpolation_input.grid.len_all_grids + root_interpolation_input.surface_points.n_points + # Pre-allocate final values block on GPU + all_stack_values_block: torch.Tensor = BackendTensor.t.zeros( + (n_stacks, xyz_to_interpolate_size), + dtype=BackendTensor.dtype_obj, + device=BackendTensor.device + ) + + # 1. Prepare Data and Matrices for all stacks + solvable_stacks_indices = [] + solver_inputs = [] + A_matrices = [] + b_vectors = [] + interp_inputs_i = [] + + for i in range(n_stacks): + stack_structure.stack_number = i + tensor_struct_i = TensorsStructure.from_tensor_structure_subset(root_data_descriptor, i) + interpolation_input_i = InterpolationInput.from_interpolation_input_subset( + all_interpolation_input=root_interpolation_input, + stack_structure=stack_structure + ) + + # Handle Fault Data + fault_data = interpolation_input_i.fault_values or FaultsData() + # Note: all_stack_values_block is updated in-place later, but for faults we need previous results + # Since faults are sequential dependencies (usually), we might need synchronization if faults depend on previous stacks + # However, here we grab the slice. In batched mode, if stack J depends on stack I, and we compute parallel, it's an issue. + # But usually fault dependency is strictly structural. 
+ # For safety, if there is a FAULT relation, we might need to be careful. + # Assuming standard GemPy stack logic where faults are processed but their mask effect is in 'combine'. + # The 'fault_values_everywhere' comes from previous iterations in the original code. + # If we batch solve, we must ensure dependencies are met. + # Actually, 'fault_values_everywhere' reads from 'all_stack_values_block'. + # If stack I is a fault, its values are needed for stack J? + # Only if stack J is interpolated using stack I as a drift/fault drift. + # Current GemPy v3 usually treats faults via 'combine' mostly, but 'fault_drift' exists. + # If fault drift is active, we cannot fully parallelize without dependency. + # We will proceed assuming independent kriging systems or that fault values are updated iteratively. + + # For batching A/b construction, we don't need fault values yet (unless they affect drifts). + # If they affect drifts, they are needed in 'input_preprocess'. + + if interpolation_input_i.fault_values: + # In batched mode, we might be reading zeros if previous stacks haven't finished. + # If strict dependency exists, we must fallback or synchronize. + # For now, we proceed with the data prep. + fault_data.fault_values_everywhere = all_stack_values_block[stack_structure.active_faults_relations] + fv_on_all_sp = fault_data.fault_values_everywhere[:, interpolation_input_i.grid.len_all_grids:] + fault_data.fault_values_on_sp = fv_on_all_sp[:, interpolation_input_i.slice_feature] + interpolation_input_i.fault_values = fault_data + + solver_input = input_preprocess(tensor_struct_i, interpolation_input_i) + + # Store inputs + interp_inputs_i.append(interpolation_input_i) + solver_inputs.append(solver_input) + + # If external function, skip solver prep + if stack_structure.interp_function is None: + solvable_stacks_indices.append(i) + # Compute Covariance and b vector (Kernel Construction) + # This is done per stack as they have different sizes/configs + A_mat = kernel_constructor.yield_covariance(solver_input, options.kernel_options) + b_vec = kernel_constructor.yield_b_vector(solver_input.ori_internal, A_mat.shape[0]) + A_matrices.append(A_mat) + b_vectors.append(b_vec) + + # 2. Batch Solve + weights_map = {} + if len(solvable_stacks_indices) > 0: + # Pad and stack + max_size = max(m.shape[0] for m in A_matrices) + padded_A = [] + padded_b = [] + + for A, b in zip(A_matrices, b_vectors): + s = A.shape[0] + pad = max_size - s + if pad > 0: + # Pad A with Identity logic + # A_padded = | A 0 | + # | 0 I | + # F.pad: (left, right, top, bottom) + A_p = torch.nn.functional.pad(A, (0, pad, 0, pad), value=0.0) + # Add Identity to the diagonal of the padded area + if pad > 0: + indices = torch.arange(s, max_size, device=A.device) + A_p[indices, indices] = 1.0 + + b_p = torch.nn.functional.pad(b, (0, pad), value=0.0) + else: + A_p = A + b_p = b + + padded_A.append(A_p) + padded_b.append(b_p) + + big_A = torch.stack(padded_A) + big_b = torch.stack(padded_b) + + # Solve all at once + # options.kernel_options.optimizing_condition_number logic is skipped here for speed + all_weights_padded = torch.linalg.solve(big_A, big_b) + + # Unpack + for idx, real_idx in enumerate(solvable_stacks_indices): + original_size = A_matrices[idx].shape[0] + weights_map[real_idx] = all_weights_padded[idx, :original_size] + + # 3. 
Evaluate and Store (Streamed) + streams = [torch.cuda.Stream() for _ in range(n_stacks)] + + for i in range(n_stacks): + with torch.cuda.stream(streams[i]): + current_solver_input = solver_inputs[i] + current_interp_input = interp_inputs_i[i] + current_stack_struct = TensorsStructure.from_tensor_structure_subset(root_data_descriptor, i).stack_structure # Re-get to be safe + + if i in weights_map: + # Solved Kriging + weights = weights_map[i] + BackendTensor.pykeops_enabled = BackendTensor.use_pykeops + if BackendTensor.pykeops_enabled: + exported_fields = symbolic_evaluator(current_solver_input, weights, options) + else: + exported_fields = generic_evaluator(current_solver_input, weights, options) + + # Set structure values + exported_fields.set_structure_values( + reference_sp_position=TensorsStructure.from_tensor_structure_subset(root_data_descriptor, i).reference_sp_position, + slice_feature=current_interp_input.slice_feature, + grid_size=current_interp_input.grid.len_all_grids + ) + exported_fields.debug = current_solver_input.debug + else: + # External Function + weights = None + xyz = current_interp_input.grid.values + exported_fields = _interpolate_external_function( + root_data_descriptor.stack_structure.interp_function, xyz + ) + exported_fields.set_structure_values( + reference_sp_position=None, + slice_feature=None, + grid_size=xyz.shape[0] + ) + + # Segmentation + if root_data_descriptor.stack_structure.segmentation_function is not None: + sigmoid_slope = root_data_descriptor.stack_structure.segmentation_function(current_solver_input.xyz_to_interpolate) + else: + sigmoid_slope = options.sigmoid_slope + + # Activate block + # Note: We are inside a stream. + from ...modules.activator import activator_interface + values_block = activator_interface.activate_formation_block(exported_fields, current_interp_input.unit_values, sigmoid_slope=sigmoid_slope) + + output = ScalarFieldOutput( + weights=weights, + grid=current_interp_input.grid, + exported_fields=exported_fields, + values_block=values_block, + stack_relation=current_interp_input.stack_relation + ) + all_scalar_fields_outputs[i] = output + + # Update all_stack_values_block + # Note: This might need synchronization if future stacks read this. + # Since we solved all weights already, the only dependency is drift. + # If drift depends on previous scalar fields, this design assumes data was ready at step 1. 
+ + if current_interp_input.stack_relation is StackRelationType.FAULT: + fault_input = current_interp_input.fault_values + val_min = BackendTensor.t.min(output.values_on_all_xyz, axis=1).reshape(-1, 1) + shifted_vals = (output.values_on_all_xyz - val_min) + + if fault_input.finite_faults_defined: + finite_fault_scalar = fault_input.finite_fault_data.apply(points=current_solver_input.xyz_to_interpolate) + fault_scalar_field = shifted_vals * finite_fault_scalar + else: + fault_scalar_field = shifted_vals + + all_stack_values_block[i, :] = fault_scalar_field + else: + all_stack_values_block[i, :] = output.values_on_all_xyz + + # Synchronize all streams + torch.cuda.synchronize() + + return all_scalar_fields_outputs diff --git a/gempy_engine/API/interp_single/_multi_scalar_field_manager.py b/gempy_engine/API/interp_single/_multi_scalar_field_manager.py index f3083d9..4e20822 100644 --- a/gempy_engine/API/interp_single/_multi_scalar_field_manager.py +++ b/gempy_engine/API/interp_single/_multi_scalar_field_manager.py @@ -1,23 +1,22 @@ -import warnings -from typing import List, Iterable, Optional +from typing import List, Optional import numpy as np from numpy import ndarray -from ...core.data.internal_structs import SolverInput +from ._interp_single_feature import interpolate_feature, input_preprocess +from ._interpolate_stack_batched import _interpolate_stack_batched from ...core.backend_tensor import BackendTensor -from ...core.data.kernel_classes.faults import FaultsData -from ...core.data.exported_structs import CombinedScalarFieldsOutput -from ...core.data.interp_output import InterpOutput -from ...core.data.scalar_field_output import ScalarFieldOutput +from ...core.data import TensorsStructure from ...core.data.exported_fields import ExportedFields +from ...core.data.exported_structs import CombinedScalarFieldsOutput from ...core.data.input_data_descriptor import InputDataDescriptor -from ...core.data.stack_relation_type import StackRelationType -from ...core.data import TensorsStructure +from ...core.data.internal_structs import SolverInput +from ...core.data.interp_output import InterpOutput from ...core.data.interpolation_input import InterpolationInput +from ...core.data.kernel_classes.faults import FaultsData from ...core.data.options import InterpolationOptions - -from ._interp_single_feature import interpolate_feature, input_preprocess +from ...core.data.scalar_field_output import ScalarFieldOutput +from ...core.data.stack_relation_type import StackRelationType # @off @@ -26,7 +25,14 @@ def interpolate_all_fields(interpolation_input: InterpolationInput, options: Int data_descriptor: InputDataDescriptor) -> List[InterpOutput]: """Interpolate all scalar fields given a xyz array of points""" - all_scalar_fields_outputs: List[ScalarFieldOutput] = _interpolate_stack(data_descriptor, interpolation_input, options) + # Check if we can use the optimized batched path + can_batch = (BackendTensor.engine_backend == BackendTensor.engine_backend.PYTORCH and + options.cache_mode == InterpolationOptions.CacheMode.NO_CACHE) + + if can_batch: + all_scalar_fields_outputs: List[ScalarFieldOutput] = _interpolate_stack_batched(data_descriptor, interpolation_input, options) + else: + all_scalar_fields_outputs: List[ScalarFieldOutput] = _interpolate_stack(data_descriptor, interpolation_input, options) combined_scalar_output: List[CombinedScalarFieldsOutput] = _combine_scalar_fields( all_scalar_fields_outputs = all_scalar_fields_outputs, diff --git a/gempy_engine/core/backend_tensor.py 
b/gempy_engine/core/backend_tensor.py index 9e4d310..f9d9376 100644 --- a/gempy_engine/core/backend_tensor.py +++ b/gempy_engine/core/backend_tensor.py @@ -22,6 +22,7 @@ class BackendTensor: engine_backend: AvailableBackends pykeops_enabled: bool = False + pykeops_eval_enabled: bool = True use_pykeops: bool = False use_gpu: bool = True dtype: str = DEFAULT_TENSOR_DTYPE @@ -37,7 +38,7 @@ class BackendTensor: @classmethod def get_backend_string(cls) -> str: - match (cls.use_gpu, cls.pykeops_enabled): + match (cls.use_gpu, cls.pykeops_eval_enabled): case (True, True): return "GPU" case (False, True): diff --git a/gempy_engine/core/utils.py b/gempy_engine/core/utils.py index 7d14c24..aeb5549 100644 --- a/gempy_engine/core/utils.py +++ b/gempy_engine/core/utils.py @@ -18,7 +18,10 @@ def cast_type_inplace(data_instance: Any, requires_grad: bool = False): import torch if isinstance(val, torch.Tensor): continue - tensor = torch.tensor(val, dtype=BackendTensor.dtype_obj, requires_grad=requires_grad).pin_memory().to(BackendTensor.device, non_blocking=True) + # tensor = torch.tensor(val, dtype=BackendTensor.dtype_obj, requires_grad=requires_grad).pin_memory().to(BackendTensor.device, non_blocking=True) + tensor = BackendTensor.tfnp.array(val, dtype=BackendTensor.dtype_obj) + # tensor = torch.from_numpy(val).pin_memory().to(device, non_blocking=True) + tensor.requires_grad = requires_grad data_instance.__dict__[key] = tensor diff --git a/gempy_engine/modules/kernel_constructor/_covariance_assembler.py b/gempy_engine/modules/kernel_constructor/_covariance_assembler.py index c58f8db..adf1619 100644 --- a/gempy_engine/modules/kernel_constructor/_covariance_assembler.py +++ b/gempy_engine/modules/kernel_constructor/_covariance_assembler.py @@ -48,7 +48,7 @@ def get_covariance(c_o, dm, k_a, k_p_ref, k_p_rest, k_ref_ref, k_ref_rest, k_res def _get_cov_grad(dm, k_a, k_p_ref, nugget): cov_grad = dm.hu * dm.hv / (dm.r_ref_ref ** 2 + 1e-5) * (- k_p_ref + k_a) - k_p_ref * dm.perp_matrix # C grad_nugget = nugget[0, 0] - if BackendTensor.pykeops_enabled is False: + if not BackendTensor.pykeops_enabled: eye = BackendTensor.t.array(np.eye(cov_grad.shape[0], dtype=BackendTensor.dtype)) nugget_selector = eye * dm.perp_matrix nugget_matrix = nugget_selector * grad_nugget @@ -145,7 +145,8 @@ def _get_faults_terms(ki: KernelInput) -> np.ndarray: y_size=cov_size, n_drift_eq=fault_n, drift_start_post_x=cov_size - fault_n, - drift_start_post_y=cov_size - fault_n + drift_start_post_y=cov_size - fault_n, + keops_enabled=BackendTensor.pykeops_enabled ) selector = (selector_components.sel_ui * (selector_components.sel_vj + 1)).sum(axis=-1) diff --git a/gempy_engine/modules/kernel_constructor/_kernels_assembler.py b/gempy_engine/modules/kernel_constructor/_kernels_assembler.py index 0b4229b..595fac5 100644 --- a/gempy_engine/modules/kernel_constructor/_kernels_assembler.py +++ b/gempy_engine/modules/kernel_constructor/_kernels_assembler.py @@ -2,7 +2,7 @@ from ._covariance_assembler import get_covariance from ._internalDistancesMatrices import InternalDistancesMatrices from ._structs import KernelInput, CartesianSelector, OrientationSurfacePointsCoords -from ...core.backend_tensor import BackendTensor as bt +from ...core.backend_tensor import BackendTensor as bt, BackendTensor from ...core.data.kernel_classes.kernel_functions import KernelFunction from ...core.data.options import KernelOptions @@ -86,7 +86,8 @@ def create_scalar_kernel(ki: KernelInput, options: KernelOptions) -> tensor_type y_size=j_size, n_drift_eq=fault_n, 
drift_start_post_x=cov_size - fault_n, - drift_start_post_y=j_size + drift_start_post_y=j_size, + keops_enabled=BackendTensor.pykeops_eval_enabled ) selector = bt.t.sum(selector_components.sel_ui * (selector_components.sel_vj + 1), axis=-1) diff --git a/gempy_engine/modules/kernel_constructor/_structs.py b/gempy_engine/modules/kernel_constructor/_structs.py index a1afbc5..7d31272 100644 --- a/gempy_engine/modules/kernel_constructor/_structs.py +++ b/gempy_engine/modules/kernel_constructor/_structs.py @@ -24,12 +24,12 @@ def _upgrade_kernel_input_to_keops_tensor_pytorch(struct_data_instance): if key == "n_faults_i": continue if (val.is_contiguous() is False): raise ValueError("Input tensors are not contiguous") - + struct_data_instance.__dict__[key] = LazyTensor(val.type(BackendTensor.dtype_obj)) -def _cast_tensors(data_class_instance): - match (BackendTensor.engine_backend, BackendTensor.pykeops_enabled): +def _cast_tensors(data_class_instance, pykeops_enabled): + match (BackendTensor.engine_backend, pykeops_enabled): case (AvailableBackends.numpy, True): _upgrade_kernel_input_to_keops_tensor_numpy(data_class_instance) case (AvailableBackends.PYTORCH, False): @@ -48,7 +48,7 @@ class OrientationSurfacePointsCoords: diprest_i: tensor_types = field(default_factory=lambda: np.empty((0, 1, 3))) diprest_j: tensor_types = field(default_factory=lambda: np.empty((1, 0, 3))) - def __init__(self, x_ref: np.ndarray, y_ref: np.ndarray, x_rest: np.ndarray, y_rest: np.ndarray): + def __init__(self, x_ref: np.ndarray, y_ref: np.ndarray, x_rest: np.ndarray, y_rest: np.ndarray, keops_enabled: bool): def _assembly(x, y) -> Tuple[np.ndarray, np.ndarray]: dips_points0 = x[:, None, :] # i dips_points1 = y[None, :, :] # j @@ -57,7 +57,7 @@ def _assembly(x, y) -> Tuple[np.ndarray, np.ndarray]: self.dip_ref_i, self.dip_ref_j = _assembly(x_ref, y_ref) self.diprest_i, self.diprest_j = _assembly(x_rest, y_rest) - _cast_tensors(self) + _cast_tensors(self, keops_enabled) @dataclass @@ -75,7 +75,7 @@ def __init__(self, x_degree_1: np.ndarray, y_degree_1: np.ndarray, x_degree_2: np.ndarray, y_degree_2: np.ndarray, x_degree_2b: np.ndarray, y_degree_2b: np.ndarray, - selector_degree_2: np.ndarray): + selector_degree_2: np.ndarray, keops_enabled: bool): self.dips_ug_ai = x_degree_1[:, None, :] self.dips_ug_aj = y_degree_1[None, :, :] self.dips_ug_bi = x_degree_2[:, None, :] @@ -85,7 +85,7 @@ def __init__(self, self.selector_ci = selector_degree_2[:, None, :] self.selector_cj = selector_degree_2[None, :, :] - _cast_tensors(self) + _cast_tensors(self, keops_enabled) @dataclass @@ -98,7 +98,7 @@ class PointsDrift: dipsPoints_ui_bj2: tensor_types = field(default_factory=lambda: np.empty((1, 0, 3))) def __init__(self, x_degree_1: np.ndarray, y_degree_1: np.ndarray, x_degree_2a: np.ndarray, - y_degree_2a: np.ndarray, x_degree_2b: np.ndarray, y_degree_2b: np.ndarray): + y_degree_2a: np.ndarray, x_degree_2b: np.ndarray, y_degree_2b: np.ndarray, keops_enabled): self.dipsPoints_ui_ai = x_degree_1[:, None, :] self.dipsPoints_ui_aj = y_degree_1[None, :, :] self.dipsPoints_ui_bi1 = x_degree_2a[:, None, :] @@ -106,7 +106,7 @@ def __init__(self, x_degree_1: np.ndarray, y_degree_1: np.ndarray, x_degree_2a: self.dipsPoints_ui_bi2 = x_degree_2b[:, None, :] self.dipsPoints_ui_bj2 = y_degree_2b[None, :, :] - _cast_tensors(self) + _cast_tensors(self, keops_enabled) @dataclass @@ -116,13 +116,13 @@ class FaultDrift: n_faults_i: int = 0 - def __init__(self, x_degree_1: np.ndarray, y_degree_1: np.ndarray, ): + def __init__(self, x_degree_1: 
np.ndarray, y_degree_1: np.ndarray, keops_enabled: bool): self.faults_i = x_degree_1[:, None, :] self.faults_j = y_degree_1[None, :, :] self.n_faults_i = x_degree_1.shape[1] - _cast_tensors(self) + _cast_tensors(self, keops_enabled) @dataclass @@ -137,14 +137,13 @@ class CartesianSelector: h_sel_rest_i: tensor_types = field(default_factory=lambda: np.empty((0, 1, 3))) h_sel_rest_j: tensor_types = field(default_factory=lambda: np.empty((1, 0, 3))) - # is_gradient: bool = False (June) This seems to be unused def __init__(self, x_sel_hu, y_sel_hu, x_sel_hv, y_sel_hv, x_sel_h_ref, y_sel_h_ref, x_sel_h_rest, y_sel_h_rest, - is_gradient=False): + keops_enabled: bool): self.hu_sel_i = x_sel_hu[:, None, :] self.hu_sel_j = y_sel_hu[None, :, :] @@ -157,15 +156,16 @@ def __init__(self, self.h_sel_rest_i = x_sel_h_rest[:, None, :] self.h_sel_rest_j = y_sel_h_rest[None, :, :] - _cast_tensors(self) + _cast_tensors(self, keops_enabled) @dataclass class DriftMatrixSelector: sel_ui: tensor_types = field(default_factory=lambda: np.empty((0, 1, 3))) sel_vj: tensor_types = field(default_factory=lambda: np.empty((1, 0, 3))) - - def __init__(self, x_size: int, y_size: int, n_drift_eq: int, drift_start_post_x: int, drift_start_post_y: int): + + def __init__(self, x_size: int, y_size: int, n_drift_eq: int, drift_start_post_x: int, drift_start_post_y: int, + keops_enabled: bool): sel_i = np.zeros((x_size, 2), dtype=BackendTensor.dtype) sel_j = np.zeros((y_size, 2), dtype=BackendTensor.dtype) @@ -185,7 +185,7 @@ def __init__(self, x_size: int, y_size: int, n_drift_eq: int, drift_start_post_x self.sel_ui = sel_i[:, None, :] self.sel_vj = sel_j[None, :, :] - _cast_tensors(self) + _cast_tensors(self, keops_enabled) @dataclass diff --git a/gempy_engine/modules/kernel_constructor/_vectors_preparation.py b/gempy_engine/modules/kernel_constructor/_vectors_preparation.py index 8b10c10..373fd87 100644 --- a/gempy_engine/modules/kernel_constructor/_vectors_preparation.py +++ b/gempy_engine/modules/kernel_constructor/_vectors_preparation.py @@ -38,7 +38,9 @@ def cov_vectors_preparation(interp_input: SolverInput, kernel_options: KernelOpt y_size=matrices_sizes.cov_size, drift_start_post_x=matrices_sizes.ori_size + matrices_sizes.sp_size, drift_start_post_y=matrices_sizes.ori_size + matrices_sizes.sp_size, - n_drift_eq=matrices_sizes.uni_drift_size) + n_drift_eq=matrices_sizes.uni_drift_size, + keops_enabled=BackendTensor.pykeops_enabled + ) if matrices_sizes.faults_size > 0: fault_vector_ref, fault_vector_rest = _assembly_fault_tensors(options, faults_val, matrices_sizes.ori_size) @@ -108,7 +110,9 @@ def evaluation_vectors_preparations(interp_input: SolverInput, kernel_options: K y_size=matrices_sizes.grid_size, drift_start_post_x=matrices_sizes.ori_size + matrices_sizes.sp_size, drift_start_post_y=matrices_sizes.grid_size, - n_drift_eq=matrices_sizes.uni_drift_size) + n_drift_eq=matrices_sizes.uni_drift_size, + keops_enabled = True + ) return KernelInput( ori_sp_matrices=orientations_sp_matrices, @@ -131,7 +135,7 @@ def _assembly_dips_points_tensors(matrices_size: MatricesSizes, ori_, sp_) -> Or dips_rest_coord = assembly_dips_points_tensor(ori_.dip_positions_tiled, sp_.rest_surface_points, matrices_size) orientations_sp_matrices = OrientationSurfacePointsCoords(dips_ref_coord, dips_ref_coord, dips_rest_coord, - dips_rest_coord) # When we create que core covariance these are the repeated since the distance are with themselves + dips_rest_coord, keops_enabled=BackendTensor.pykeops_enabled) # When we create que core covariance 
these are the repeated since the distance are with themselves return orientations_sp_matrices @@ -141,7 +145,8 @@ def _assembly_dips_points_grid_tensors(grid, matrices_size: MatricesSizes, ori_, dips_ref_coord = assembly_dips_points_tensor(ori_.dip_positions_tiled, sp_.ref_surface_points, matrices_size) dips_rest_coord = assembly_dips_points_tensor(ori_.dip_positions_tiled, sp_.rest_surface_points, matrices_size) - orientations_sp_matrices = OrientationSurfacePointsCoords(dips_ref_coord, grid, dips_rest_coord, grid) # When we create que core covariance this are the repeated since the distance are with themselves + orientations_sp_matrices = OrientationSurfacePointsCoords(dips_ref_coord, grid, dips_rest_coord, grid, + keops_enabled=BackendTensor.pykeops_eval_enabled) # When we create que core covariance this are the repeated since the distance are with themselves return orientations_sp_matrices @@ -153,7 +158,8 @@ def _assembly_cartesian_selector_tensors(matrices_sizes: MatricesSizes): x_sel_hu=sel_hu_input, y_sel_hu=sel_hv_input, x_sel_hv=sel_hv_input, y_sel_hv=sel_hu_input, x_sel_h_ref=sel_hu_points_input, y_sel_h_ref=sel_hu_points_input, - x_sel_h_rest=sel_hu_points_input, y_sel_h_rest=sel_hu_points_input + x_sel_h_rest=sel_hu_points_input, y_sel_h_rest=sel_hu_points_input, + keops_enabled=BackendTensor.pykeops_enabled ) return cartesian_selector @@ -167,7 +173,9 @@ def _assembly_cartesian_selector_grid(matrices_sizes, axis=None): x_sel_hu=sel_hu_input, y_sel_hu=sel_hu_grid, x_sel_hv=sel_hv_input, y_sel_hv=sel_hv_grid, x_sel_h_ref=sel_hu_points_input, y_sel_h_ref=sel_hu_points_grid, - x_sel_h_rest=sel_hu_points_input, y_sel_h_rest=sel_hu_points_grid) + x_sel_h_rest=sel_hu_points_input, y_sel_h_rest=sel_hu_points_grid, + keops_enabled=BackendTensor.pykeops_eval_enabled + ) return cartesian_selector @@ -179,14 +187,17 @@ def _assembly_drift_tensors(options: KernelOptions, ori_: OrientationsInternals, dips_ug_d1, dips_ug_d1, dips_ug_d2a, dips_ug_d2a, dips_ug_d2b, dips_ug_d2b, - second_degree_selector + second_degree_selector, + keops_enabled=BackendTensor.pykeops_enabled ) dips_ref_d1, dips_ref_d2a, dips_ref_d2b = assembly_dips_points_coords(sp_.ref_surface_points, matrices_sizes, options) dips_rest_d1, dips_rest_d2a, dips_rest_d2b = assembly_dips_points_coords(sp_.rest_surface_points, matrices_sizes, options) - dips_ref_ui = PointsDrift(dips_ref_d1, dips_ref_d1, dips_ref_d2a, dips_ref_d2a, dips_ref_d2b, dips_ref_d2b) - dips_rest_ui = PointsDrift(dips_rest_d1, dips_rest_d1, dips_rest_d2a, dips_rest_d2a, dips_rest_d2b, dips_rest_d2b) + dips_ref_ui = PointsDrift(dips_ref_d1, dips_ref_d1, dips_ref_d2a, dips_ref_d2a, dips_ref_d2b, dips_ref_d2b, + keops_enabled=BackendTensor.pykeops_enabled) + dips_rest_ui = PointsDrift(dips_rest_d1, dips_rest_d1, dips_rest_d2a, dips_rest_d2a, dips_rest_d2b, dips_rest_d2b, + keops_enabled=BackendTensor.pykeops_enabled) return dips_ref_ui, dips_rest_ui, dips_ug @@ -206,15 +217,17 @@ def _assembly_drift_grid_tensors(grid: np.ndarray, options: KernelOptions, matri x_degree_1=dips_ug_d1, y_degree_1=grid_1, x_degree_2=dips_ug_d2a, y_degree_2=grid * grid_1, x_degree_2b=dips_ug_d2b * sel, y_degree_2b=grid, - selector_degree_2=second_degree_selector) + selector_degree_2=second_degree_selector, + keops_enabled=BackendTensor.pykeops_eval_enabled + ) # endregion # region UI dips_ref_d1, dips_ref_d2a, dips_ref_d2b = assembly_dips_points_coords(sp_.ref_surface_points, matrices_size, options) dips_rest_d1, dips_rest_d2a, dips_rest_d2b = 
assembly_dips_points_coords(sp_.rest_surface_points, matrices_size, options) - dips_ref_ui = PointsDrift(dips_ref_d1, grid, dips_ref_d2a, grid, dips_ref_d2b, grid) - dips_rest_ui = PointsDrift(dips_rest_d1, grid, dips_rest_d2a, grid, dips_rest_d2b, grid) + dips_ref_ui = PointsDrift(dips_ref_d1, grid, dips_ref_d2a, grid, dips_ref_d2b, grid, keops_enabled=BackendTensor.pykeops_eval_enabled) + dips_rest_ui = PointsDrift(dips_rest_d1, grid, dips_rest_d2a, grid, dips_rest_d2b, grid, keops_enabled=BackendTensor.pykeops_eval_enabled) # endregion return dips_ref_ui, dips_rest_ui, dips_ug @@ -224,14 +237,18 @@ def _assembly_fault_grid_tensors(fault_values_on_grid, options: KernelOptions, f fault_vector_ref, fault_vector_rest = _assembly_fault_internals(faults_val, options, ori_size) fault_drift = FaultDrift( x_degree_1=fault_vector_ref, - y_degree_1=BackendTensor.t.ascontiguousarray(fault_values_on_grid.T) + y_degree_1=BackendTensor.t.ascontiguousarray(fault_values_on_grid.T), + keops_enabled=BackendTensor.pykeops_eval_enabled ) return fault_drift def _assembly_fault_tensors(options, faults_val: FaultsData, ori_size: int) -> Tuple[FaultDrift, FaultDrift]: fault_vector_ref, fault_vector_rest = _assembly_fault_internals(faults_val, options, ori_size) - return FaultDrift(fault_vector_ref, fault_vector_ref), FaultDrift(fault_vector_rest, fault_vector_rest) + return ( + FaultDrift(fault_vector_ref, fault_vector_ref, keops_enabled=BackendTensor.pykeops_enabled), + FaultDrift(fault_vector_rest, fault_vector_rest, keops_enabled=BackendTensor.pykeops_enabled) + ) def _assembly_fault_internals(faults_val, options, ori_size): From d7e8b82c03c48eb68c58831de96003a94b658a6c Mon Sep 17 00:00:00 2001 From: Miguel de la Varga Date: Thu, 20 Nov 2025 17:09:22 +0100 Subject: [PATCH 3/4] [WIP] Towards batching --- .../_interpolate_stack_batched.py | 346 +++++------------- .../_multi_scalar_field_manager.py | 2 +- 2 files changed, 86 insertions(+), 262 deletions(-) diff --git a/gempy_engine/API/interp_single/_interpolate_stack_batched.py b/gempy_engine/API/interp_single/_interpolate_stack_batched.py index 61788cc..7b8e94b 100644 --- a/gempy_engine/API/interp_single/_interpolate_stack_batched.py +++ b/gempy_engine/API/interp_single/_interpolate_stack_batched.py @@ -1,4 +1,6 @@ -from typing import List +# ... existing code ... + +from typing import List, Tuple import numpy as np import torch @@ -8,59 +10,92 @@ from ...core.data import TensorsStructure from ...core.data.input_data_descriptor import InputDataDescriptor from ...core.data.interpolation_input import InterpolationInput -from ...core.data.kernel_classes.faults import FaultsData from ...core.data.options import InterpolationOptions from ...core.data.scalar_field_output import ScalarFieldOutput from ...core.data.stack_relation_type import StackRelationType -# ... existing code ... from ...modules.evaluator.generic_evaluator import generic_evaluator from ...modules.evaluator.symbolic_evaluator import symbolic_evaluator from ...modules.kernel_constructor import kernel_constructor_interface as kernel_constructor +# TODO: [ ] Batch only pykeops evaluations. +# TODO: [ ] To speed up the interpolation, we should try pykeops solver with fall back + def _interpolate_stack_batched(root_data_descriptor: InputDataDescriptor, root_interpolation_input: InterpolationInput, options: InterpolationOptions) -> List[ScalarFieldOutput]: """ - Optimized interpolation using CUDA streams. 
- Solves each stack one-by-one in its own stream to maximize GPU throughput - without memory overhead of padding/stacking matrices. + Optimized batched interpolation using Split-Loop Pipelining and CUDA Streams. + + Strategy: + 1. CPU Phase: Pre-process all stacks (Python overhead, CPU prep, Data transfer initiation). + 2. GPU Phase: Launch Kernel Assembly, Solve, and Evaluation into parallel streams. + + This avoids the O(N^3) cost of padding matrices and prevents CPU prep from stalling the GPU. """ + BackendTensor.pykeops_enabled = False stack_structure = root_data_descriptor.stack_structure n_stacks = stack_structure.n_stacks - # Result holder + # Result container all_scalar_fields_outputs: List[ScalarFieldOutput | None] = [None] * n_stacks - # Shared memory for fault interactions (pre-allocated on GPU) + # Shared memory for results (Fault interactions need this) xyz_to_interpolate_size: int = root_interpolation_input.grid.len_all_grids + root_interpolation_input.surface_points.n_points + + # Allocate on GPU once all_stack_values_block: torch.Tensor = BackendTensor.t.zeros( (n_stacks, xyz_to_interpolate_size), dtype=BackendTensor.dtype_obj, device=BackendTensor.device ) - # Create a stream for each stack to allow concurrent execution + # === Phase 1: CPU Preparation Loop === + # We prepare all python objects and initiate data transfers here. + # This ensures that when we start launching GPU kernels, we don't stop for CPU work. + prepared_stacks: List[Tuple[int, InterpolationInput, object]] = [] + for i in range(n_stacks): + stack_structure.stack_number = i + tensor_struct_i = TensorsStructure.from_tensor_structure_subset(root_data_descriptor, i) + interpolation_input_i = InterpolationInput.from_interpolation_input_subset( + all_interpolation_input=root_interpolation_input, + stack_structure=stack_structure + ) + + # Handle Faults Dependencies + # Note: In a fully pipelined approach, we can't read the *results* of previous stacks yet. + # However, the 'fault_values_everywhere' usually comes from the Shared Tensor 'all_stack_values_block'. + # As long as we enforce stream dependencies later, we can set up the views/pointers here. + + if interpolation_input_i.fault_values: + fault_data = interpolation_input_i.fault_values + # Create views into the shared block + fault_data.fault_values_everywhere = all_stack_values_block[stack_structure.active_faults_relations] + + # We need to be careful with slicing here. + # The slicing operation itself is fast/metadata only on Tensors. + fv_on_all_sp = fault_data.fault_values_everywhere[:, interpolation_input_i.grid.len_all_grids:] + fault_data.fault_values_on_sp = fv_on_all_sp[:, interpolation_input_i.slice_feature] + interpolation_input_i.fault_values = fault_data + + # Heavy CPU work: Prepare SolverInput (converts numpy to tensors, etc) + solver_input = input_preprocess(tensor_struct_i, interpolation_input_i) + + prepared_stacks.append((i, interpolation_input_i, solver_input)) + + # === Phase 2: GPU Execution Loop === + # Create streams and events streams = [torch.cuda.Stream() for _ in range(n_stacks)] - # Events to signal when a stack is fully computed (for dependencies) - stack_done_events = [torch.cuda.Event() for _ in range(n_stacks)] - BackendTensor.pykeops_enable = False + events = [torch.cuda.Event() for _ in range(n_stacks)] - for i in range(n_stacks): + for i, interpolation_input_i, solver_input in prepared_stacks: stream = streams[i] - with torch.cuda.stream(stream): - # === 1. 
Python Setup (Runs on CPU) === - stack_structure.stack_number = i - tensor_struct_i = TensorsStructure.from_tensor_structure_subset(root_data_descriptor, i) - interpolation_input_i = InterpolationInput.from_interpolation_input_subset( - all_interpolation_input=root_interpolation_input, - stack_structure=stack_structure - ) - # === 2. Dependency Handling (GPU Synchronization) === - # If this stack depends on faults, we must wait for those specific stacks to finish. - active_faults = stack_structure.active_faults_relations + with torch.cuda.stream(stream): + # 1. Synchronization: Wait for dependencies (Faults) + active_faults = root_data_descriptor.stack_structure.active_faults_relations if active_faults is not None: - # active_faults is typically a boolean mask or list of indices + # Find which stacks we depend on + # active_faults is likely a boolean mask for previous stacks if hasattr(active_faults, 'dtype') and active_faults.dtype == bool: dep_indices = np.where(active_faults)[0] elif isinstance(active_faults, (list, tuple, np.ndarray)): @@ -68,42 +103,31 @@ def _interpolate_stack_batched(root_data_descriptor: InputDataDescriptor, root_i else: dep_indices = [] - # Make current stream wait for dependency events for dep_idx in dep_indices: - if dep_idx < i: # Sanity check - stream.wait_event(stack_done_events[dep_idx]) - - # Now it is safe to read from all_stack_values_block - fault_data = interpolation_input_i.fault_values or FaultsData() - if interpolation_input_i.fault_values: - fault_data.fault_values_everywhere = all_stack_values_block[stack_structure.active_faults_relations] - # Slice data for Surface Points (SP) - fv_on_all_sp = fault_data.fault_values_everywhere[:, interpolation_input_i.grid.len_all_grids:] - fault_data.fault_values_on_sp = fv_on_all_sp[:, interpolation_input_i.slice_feature] - interpolation_input_i.fault_values = fault_data - - # === 3. Execution Pipeline (Queued on GPU) === - solver_input = input_preprocess(tensor_struct_i, interpolation_input_i) - - if stack_structure.interp_function is None: - # --- A. Kriging Solve --- - # Prepare Matrices (Kernel Construction) + if dep_idx < i: # Can only wait on previous stacks + stream.wait_event(events[dep_idx]) + + # 2. Compute or Evaluate + if root_data_descriptor.stack_structure.interp_function is None: + # --- A. 
Kriging Solve (GPU) --- + + # Construct Covariance Matrix (O(N^2)) + # This is now inside the stream, so it runs in parallel with other stacks A_mat = kernel_constructor.yield_covariance(solver_input, options.kernel_options) b_vec = kernel_constructor.yield_b_vector(solver_input.ori_internal, A_mat.shape[0]) - # Solve System (Async GPU call) - # No padding needed, we solve exact size + # Solve System (O(N^3)) weights = torch.linalg.solve(A_mat, b_vec) - # Evaluate Field + # Evaluate Field (O(M*N)) if BackendTensor.pykeops_eval_enabled: exported_fields = symbolic_evaluator(solver_input, weights, options) else: exported_fields = generic_evaluator(solver_input, weights, options) - # Metadata + # Post-process results exported_fields.set_structure_values( - reference_sp_position=tensor_struct_i.reference_sp_position, + reference_sp_position=TensorsStructure.from_tensor_structure_subset(root_data_descriptor, i).reference_sp_position, slice_feature=interpolation_input_i.slice_feature, grid_size=interpolation_input_i.grid.len_all_grids ) @@ -114,7 +138,7 @@ def _interpolate_stack_batched(root_data_descriptor: InputDataDescriptor, root_i weights = None xyz = interpolation_input_i.grid.values exported_fields = _interpolate_external_function( - stack_structure.interp_function, xyz + root_data_descriptor.stack_structure.interp_function, xyz ) exported_fields.set_structure_values( reference_sp_position=None, @@ -122,14 +146,16 @@ def _interpolate_stack_batched(root_data_descriptor: InputDataDescriptor, root_i grid_size=xyz.shape[0] ) - # --- Post-Processing --- - if stack_structure.segmentation_function is not None: - sigmoid_slope = stack_structure.segmentation_function(solver_input.xyz_to_interpolate) + # 3. Segmentation & Activation + if root_data_descriptor.stack_structure.segmentation_function is not None: + sigmoid_slope = root_data_descriptor.stack_structure.segmentation_function(solver_input.xyz_to_interpolate) else: sigmoid_slope = options.sigmoid_slope from ...modules.activator import activator_interface - values_block = activator_interface.activate_formation_block(exported_fields, interpolation_input_i.unit_values, sigmoid_slope=sigmoid_slope) + values_block = activator_interface.activate_formation_block( + exported_fields, interpolation_input_i.unit_values, sigmoid_slope=sigmoid_slope + ) output = ScalarFieldOutput( weights=weights, @@ -140,8 +166,9 @@ def _interpolate_stack_batched(root_data_descriptor: InputDataDescriptor, root_i ) all_scalar_fields_outputs[i] = output - # Update Shared Block (in-place GPU write) + # 4. 
Update Shared Block (In-place GPU write) if interpolation_input_i.stack_relation is StackRelationType.FAULT: + fault_data = interpolation_input_i.fault_values val_min = BackendTensor.t.min(output.values_on_all_xyz, axis=1).reshape(-1, 1) shifted_vals = (output.values_on_all_xyz - val_min) @@ -155,213 +182,10 @@ def _interpolate_stack_batched(root_data_descriptor: InputDataDescriptor, root_i else: all_stack_values_block[i, :] = output.values_on_all_xyz - # Record that this stack is finished - stack_done_events[i].record(stream) - - # Wait for all streams to finish before returning results to Python - torch.cuda.synchronize() - - return all_scalar_fields_outputs - - - -def _interpolate_stack_batched_(root_data_descriptor: InputDataDescriptor, root_interpolation_input: InterpolationInput, - options: InterpolationOptions) -> List[ScalarFieldOutput]: - """Optimized batched interpolation for PyTorch backend.""" - stack_structure = root_data_descriptor.stack_structure - n_stacks = stack_structure.n_stacks - - all_scalar_fields_outputs: List[ScalarFieldOutput | None] = [None] * n_stacks - - xyz_to_interpolate_size: int = root_interpolation_input.grid.len_all_grids + root_interpolation_input.surface_points.n_points - # Pre-allocate final values block on GPU - all_stack_values_block: torch.Tensor = BackendTensor.t.zeros( - (n_stacks, xyz_to_interpolate_size), - dtype=BackendTensor.dtype_obj, - device=BackendTensor.device - ) - - # 1. Prepare Data and Matrices for all stacks - solvable_stacks_indices = [] - solver_inputs = [] - A_matrices = [] - b_vectors = [] - interp_inputs_i = [] - - for i in range(n_stacks): - stack_structure.stack_number = i - tensor_struct_i = TensorsStructure.from_tensor_structure_subset(root_data_descriptor, i) - interpolation_input_i = InterpolationInput.from_interpolation_input_subset( - all_interpolation_input=root_interpolation_input, - stack_structure=stack_structure - ) - - # Handle Fault Data - fault_data = interpolation_input_i.fault_values or FaultsData() - # Note: all_stack_values_block is updated in-place later, but for faults we need previous results - # Since faults are sequential dependencies (usually), we might need synchronization if faults depend on previous stacks - # However, here we grab the slice. In batched mode, if stack J depends on stack I, and we compute parallel, it's an issue. - # But usually fault dependency is strictly structural. - # For safety, if there is a FAULT relation, we might need to be careful. - # Assuming standard GemPy stack logic where faults are processed but their mask effect is in 'combine'. - # The 'fault_values_everywhere' comes from previous iterations in the original code. - # If we batch solve, we must ensure dependencies are met. - # Actually, 'fault_values_everywhere' reads from 'all_stack_values_block'. - # If stack I is a fault, its values are needed for stack J? - # Only if stack J is interpolated using stack I as a drift/fault drift. - # Current GemPy v3 usually treats faults via 'combine' mostly, but 'fault_drift' exists. - # If fault drift is active, we cannot fully parallelize without dependency. - # We will proceed assuming independent kriging systems or that fault values are updated iteratively. - - # For batching A/b construction, we don't need fault values yet (unless they affect drifts). - # If they affect drifts, they are needed in 'input_preprocess'. - - if interpolation_input_i.fault_values: - # In batched mode, we might be reading zeros if previous stacks haven't finished. 
- # If strict dependency exists, we must fallback or synchronize. - # For now, we proceed with the data prep. - fault_data.fault_values_everywhere = all_stack_values_block[stack_structure.active_faults_relations] - fv_on_all_sp = fault_data.fault_values_everywhere[:, interpolation_input_i.grid.len_all_grids:] - fault_data.fault_values_on_sp = fv_on_all_sp[:, interpolation_input_i.slice_feature] - interpolation_input_i.fault_values = fault_data - - solver_input = input_preprocess(tensor_struct_i, interpolation_input_i) - - # Store inputs - interp_inputs_i.append(interpolation_input_i) - solver_inputs.append(solver_input) - - # If external function, skip solver prep - if stack_structure.interp_function is None: - solvable_stacks_indices.append(i) - # Compute Covariance and b vector (Kernel Construction) - # This is done per stack as they have different sizes/configs - A_mat = kernel_constructor.yield_covariance(solver_input, options.kernel_options) - b_vec = kernel_constructor.yield_b_vector(solver_input.ori_internal, A_mat.shape[0]) - A_matrices.append(A_mat) - b_vectors.append(b_vec) - - # 2. Batch Solve - weights_map = {} - if len(solvable_stacks_indices) > 0: - # Pad and stack - max_size = max(m.shape[0] for m in A_matrices) - padded_A = [] - padded_b = [] - - for A, b in zip(A_matrices, b_vectors): - s = A.shape[0] - pad = max_size - s - if pad > 0: - # Pad A with Identity logic - # A_padded = | A 0 | - # | 0 I | - # F.pad: (left, right, top, bottom) - A_p = torch.nn.functional.pad(A, (0, pad, 0, pad), value=0.0) - # Add Identity to the diagonal of the padded area - if pad > 0: - indices = torch.arange(s, max_size, device=A.device) - A_p[indices, indices] = 1.0 - - b_p = torch.nn.functional.pad(b, (0, pad), value=0.0) - else: - A_p = A - b_p = b - - padded_A.append(A_p) - padded_b.append(b_p) - - big_A = torch.stack(padded_A) - big_b = torch.stack(padded_b) - - # Solve all at once - # options.kernel_options.optimizing_condition_number logic is skipped here for speed - all_weights_padded = torch.linalg.solve(big_A, big_b) - - # Unpack - for idx, real_idx in enumerate(solvable_stacks_indices): - original_size = A_matrices[idx].shape[0] - weights_map[real_idx] = all_weights_padded[idx, :original_size] - - # 3. 
Evaluate and Store (Streamed) - streams = [torch.cuda.Stream() for _ in range(n_stacks)] - - for i in range(n_stacks): - with torch.cuda.stream(streams[i]): - current_solver_input = solver_inputs[i] - current_interp_input = interp_inputs_i[i] - current_stack_struct = TensorsStructure.from_tensor_structure_subset(root_data_descriptor, i).stack_structure # Re-get to be safe - - if i in weights_map: - # Solved Kriging - weights = weights_map[i] - BackendTensor.pykeops_enabled = BackendTensor.use_pykeops - if BackendTensor.pykeops_enabled: - exported_fields = symbolic_evaluator(current_solver_input, weights, options) - else: - exported_fields = generic_evaluator(current_solver_input, weights, options) - - # Set structure values - exported_fields.set_structure_values( - reference_sp_position=TensorsStructure.from_tensor_structure_subset(root_data_descriptor, i).reference_sp_position, - slice_feature=current_interp_input.slice_feature, - grid_size=current_interp_input.grid.len_all_grids - ) - exported_fields.debug = current_solver_input.debug - else: - # External Function - weights = None - xyz = current_interp_input.grid.values - exported_fields = _interpolate_external_function( - root_data_descriptor.stack_structure.interp_function, xyz - ) - exported_fields.set_structure_values( - reference_sp_position=None, - slice_feature=None, - grid_size=xyz.shape[0] - ) - - # Segmentation - if root_data_descriptor.stack_structure.segmentation_function is not None: - sigmoid_slope = root_data_descriptor.stack_structure.segmentation_function(current_solver_input.xyz_to_interpolate) - else: - sigmoid_slope = options.sigmoid_slope - - # Activate block - # Note: We are inside a stream. - from ...modules.activator import activator_interface - values_block = activator_interface.activate_formation_block(exported_fields, current_interp_input.unit_values, sigmoid_slope=sigmoid_slope) - - output = ScalarFieldOutput( - weights=weights, - grid=current_interp_input.grid, - exported_fields=exported_fields, - values_block=values_block, - stack_relation=current_interp_input.stack_relation - ) - all_scalar_fields_outputs[i] = output - - # Update all_stack_values_block - # Note: This might need synchronization if future stacks read this. - # Since we solved all weights already, the only dependency is drift. - # If drift depends on previous scalar fields, this design assumes data was ready at step 1. - - if current_interp_input.stack_relation is StackRelationType.FAULT: - fault_input = current_interp_input.fault_values - val_min = BackendTensor.t.min(output.values_on_all_xyz, axis=1).reshape(-1, 1) - shifted_vals = (output.values_on_all_xyz - val_min) - - if fault_input.finite_faults_defined: - finite_fault_scalar = fault_input.finite_fault_data.apply(points=current_solver_input.xyz_to_interpolate) - fault_scalar_field = shifted_vals * finite_fault_scalar - else: - fault_scalar_field = shifted_vals - - all_stack_values_block[i, :] = fault_scalar_field - else: - all_stack_values_block[i, :] = output.values_on_all_xyz + # 5. 
Record Event (Stack Finished) + events[i].record(stream) - # Synchronize all streams + # Wait for everything to finish torch.cuda.synchronize() return all_scalar_fields_outputs diff --git a/gempy_engine/API/interp_single/_multi_scalar_field_manager.py b/gempy_engine/API/interp_single/_multi_scalar_field_manager.py index 4e20822..30cbc06 100644 --- a/gempy_engine/API/interp_single/_multi_scalar_field_manager.py +++ b/gempy_engine/API/interp_single/_multi_scalar_field_manager.py @@ -29,7 +29,7 @@ def interpolate_all_fields(interpolation_input: InterpolationInput, options: Int can_batch = (BackendTensor.engine_backend == BackendTensor.engine_backend.PYTORCH and options.cache_mode == InterpolationOptions.CacheMode.NO_CACHE) - if can_batch: + if can_batch and True: all_scalar_fields_outputs: List[ScalarFieldOutput] = _interpolate_stack_batched(data_descriptor, interpolation_input, options) else: all_scalar_fields_outputs: List[ScalarFieldOutput] = _interpolate_stack(data_descriptor, interpolation_input, options) From bbf0394b79b8c87e0111eee1d6f9688d1bbf0980 Mon Sep 17 00:00:00 2001 From: Miguel de la Varga Date: Fri, 21 Nov 2025 11:37:44 +0100 Subject: [PATCH 4/4] [ENH] A few extra optimizations --- .../_multi_scalar_field_manager.py | 2 +- gempy_engine/core/backend_tensor.py | 7 +- .../data/kernel_classes/kernel_functions.py | 148 +++++++++++++----- .../_vectors_preparation.py | 2 +- 4 files changed, 118 insertions(+), 41 deletions(-) diff --git a/gempy_engine/API/interp_single/_multi_scalar_field_manager.py b/gempy_engine/API/interp_single/_multi_scalar_field_manager.py index 30cbc06..d2b962b 100644 --- a/gempy_engine/API/interp_single/_multi_scalar_field_manager.py +++ b/gempy_engine/API/interp_single/_multi_scalar_field_manager.py @@ -29,7 +29,7 @@ def interpolate_all_fields(interpolation_input: InterpolationInput, options: Int can_batch = (BackendTensor.engine_backend == BackendTensor.engine_backend.PYTORCH and options.cache_mode == InterpolationOptions.CacheMode.NO_CACHE) - if can_batch and True: + if can_batch and False: all_scalar_fields_outputs: List[ScalarFieldOutput] = _interpolate_stack_batched(data_descriptor, interpolation_input, options) else: all_scalar_fields_outputs: List[ScalarFieldOutput] = _interpolate_stack(data_descriptor, interpolation_input, options) diff --git a/gempy_engine/core/backend_tensor.py b/gempy_engine/core/backend_tensor.py index f9d9376..17f4a93 100644 --- a/gempy_engine/core/backend_tensor.py +++ b/gempy_engine/core/backend_tensor.py @@ -193,9 +193,10 @@ def _array(array_like, dtype=None): if not array_like.flags.c_contiguous: array_like = numpy.ascontiguousarray(array_like) - - # return torch.tensor(array_like, dtype=dtype) - return torch.tensor(array_like, dtype=dtype).pin_memory().to(cls.device, non_blocking=True) + if cls.use_gpu: + return torch.tensor(array_like, dtype=dtype).pin_memory().to(cls.device, non_blocking=True) + else: + return torch.tensor(array_like, dtype=dtype) def _concatenate(tensors, axis=0, dtype=None): # Switch if tensor is numpy array or a torch tensor diff --git a/gempy_engine/core/data/kernel_classes/kernel_functions.py b/gempy_engine/core/data/kernel_classes/kernel_functions.py index f9ad0d1..0214e1f 100644 --- a/gempy_engine/core/data/kernel_classes/kernel_functions.py +++ b/gempy_engine/core/data/kernel_classes/kernel_functions.py @@ -9,63 +9,134 @@ dtype = BackendTensor.dtype +from dataclasses import dataclass +from enum import Enum +from typing import Callable +import torch -def cubic_function(r, a): - a = float(a) - return 
1 - 7 * (r / a) ** 2 + 35 * r ** 3 / (4 * a ** 3) - 7 * r ** 5 / (2 * a ** 5) + 3 * r ** 7 / (4 * a ** 7)
 
+from gempy_engine.core.backend_tensor import BackendTensor
 
+# We define JIT-compiled versions for GPU/PyTorch performance.
+# These fuse all element-wise operations into a single kernel execution.
 
-def cubic_function_p_div_r(r, a):
-    a = float(a)
-    return (-14 / a ** 2) + 105 * r / (4 * a ** 3) - 35 * r ** 3 / (2 * a ** 5) + 21 * r ** 5 / (4 * a ** 7)
+@torch.jit.script
+def cubic_function(r: torch.Tensor, a: float) -> torch.Tensor:
+    # Horner-style evaluation for stability and fewer ops:
+    # 1 - 7x^2 + 35/4 x^3 - 7/2 x^5 + 3/4 x^7
+    # where x = r/a
+    # Pre-calculated coefficients
+    c2 = -7.0
+    c3 = 8.75  # 35/4
+    c5 = -3.5  # -7/2
+    c7 = 0.75  # 3/4
 
-def cubic_function_a(r, a):
-    a = float(a)
-    return 7 * (9 * r ** 5 - 20 * a ** 2 * r ** 3 + 15 * a ** 4 * r - 4 * a ** 5) / (2 * a ** 7)
+    x = r / a
+    x2 = x * x
+    # Grouped to minimise the number of multiplications:
+    # 1 + x^2 * (-7 + x * (35/4 + x^2 * (-7/2 + x^2 * 3/4)))
+    # which is the same polynomial as the original expression above.
+    return 1.0 + x2 * (c2 + x * (c3 + x2 * (c5 + x2 * c7)))
 
-def exp_function(sq_r, a):
-    a = float(a)
-    return BackendTensor.tfnp.exp(-(sq_r / (2 * a ** 2)))
+@torch.jit.script
+def cubic_function_p_div_r(r: torch.Tensor, a: float) -> torch.Tensor:
+    # (-14 / a^2) + 105 r / (4 a^3) - 35 r^3 / (2 a^5) + 21 r^5 / (4 a^7)
+    a_inv = 1.0 / a
+    a2_inv = a_inv * a_inv
+    x = r * a_inv
+    x2 = x * x
 
-def exp_function_p_div_r(sq_r, a):
-    a = float(a)
-    return -(1 / (a ** 2) * BackendTensor.tfnp.exp(-(sq_r / (2 * a ** 2))))
+    # Coefficients: 105/4 = 26.25, 35/2 = 17.5, 21/4 = 5.25
+    # Term by term:
+    # term1 = -14/a^2
+    # term2 = 26.25 * r / a^3
+    # term3 = -17.5 * r^3 / a^5
+    # term4 = 5.25 * r^5 / a^7
+    # Grouped with x = r/a:
+    # a^-2 * (-14 + x * (26.25 + x^2 * (-17.5 + 5.25 * x^2)))
+    return a2_inv * (-14.0 + x * (26.25 + x2 * (-17.5 + 5.25 * x2)))
 
-def exp_function_a(sq_r, a):
-    a = float(a)
-    first_term = BackendTensor.tfnp.divide(sq_r, (a ** 4))  # ! This term is almost always zero. I thnk we can just remove it
-    second_term = 1 / (a ** 2)
-    third_term = BackendTensor.tfnp.exp(-(sq_r / (2 * a ** 2)))
-    return (first_term - second_term) * third_term
+@torch.jit.script
+def cubic_function_a(r: torch.Tensor, a: float) -> torch.Tensor:
+    # Left as the raw expression: the JIT fuses it anyway, and regrouping by hand risks sign errors.
+    # 7 * (9 * r^5 - 20 * a^2 * r^3 + 15 * a^4 * r - 4 * a^5) / (2 * a^7)
+
+    # Float literals keep the TorchScript arithmetic in floating point.
+    return 7.0 * (9.0 * r ** 5 - 20.0 * (a ** 2) * (r ** 3) + 15.0 * (a ** 4) * r - 4.0 * (a ** 5)) / (2.0 * (a ** 7))
+
+
+@torch.jit.script
+def exp_function(sq_r: torch.Tensor, a: float) -> torch.Tensor:
+    # exp(-(r^2 / (2 a^2)))
+    return torch.exp(-(sq_r / (2.0 * a * a)))
+
+
+@torch.jit.script
+def exp_function_p_div_r(sq_r: torch.Tensor, a: float) -> torch.Tensor:
+    # -(1 / a^2) * exp(...)
+    val = torch.exp(-(sq_r / (2.0 * a * a)))
+    return -(1.0 / (a * a)) * val
 
 
-square_root_3 = 1.73205080757
+@torch.jit.script
+def exp_function_a(sq_r: torch.Tensor, a: float) -> torch.Tensor:
+    # (sq_r / a^4 - 1/a^2) * exp(...)
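+    # Second radial derivative d2C/dr2 of the exponential (Gaussian) kernel; it fills the
+    # `second_derivative` slot of the exponential KernelFunction below and consumes the
+    # squared distance directly (consume_sq_distance=True), so no sqrt is needed upstream.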
+ a2 = a * a + a4 = a2 * a2 + term1 = sq_r / a4 + term2 = 1.0 / a2 + term3 = torch.exp(-(sq_r / (2.0 * a2))) + return (term1 - term2) * term3 + + +square_root_3 = 1.73205080757 sqrt5 = 2.2360679775 +@torch.jit.script +def matern_function_5_2(r: torch.Tensor, a: float) -> torch.Tensor: + # (1 + sqrt5 * r/a + 5/3 * r^2/a^2) * exp(-sqrt5 * r/a) + # a is float. + # Precompute constants + s5 = 2.2360679775 + + # Common term x = r/a + x = r / a + s5_x = s5 * x + + # Polynomial part: 1 + s5_x + (5/3) * x^2 + poly = 1.0 + s5_x + (1.6666666667 * x * x) -def matern_function_5_2(r, a, nu=5 / 2): - # Using nu=5/2 for the Matern kernel + return poly * torch.exp(-s5_x) - a = float(a) - sqrt5_r_over_ell = sqrt5 * r / a - return (1 + sqrt5_r_over_ell + (5 * r ** 2) / (3 * a ** 2)) * BackendTensor.tfnp.exp(-sqrt5_r_over_ell) +@torch.jit.script +def matern_function_5_2_p_div_r(r: torch.Tensor, a: float) -> torch.Tensor: + # -(5 * exp(...) * (a + sqrt5 * r)) / (3 * a^3) + s5 = 2.2360679775 + x = r / a -def matern_function_5_2_p_div_r(r, a, nu=5 / 2): - a = float(a) - sqrt5_r_over_ell = sqrt5 * r / a - return -(5 * BackendTensor.tfnp.exp(-sqrt5_r_over_ell) * (a + sqrt5 * r)) / (3 * a ** 3) + term_exp = torch.exp(-s5 * x) + numerator = -5.0 * term_exp * (a + s5 * r) + denominator = 3.0 * (a * a * a) + return numerator / denominator -def matern_function_5_2_a(r, a, nu=5 / 2): - a = float(a) - sqrt5_r_over_ell = sqrt5 * r / a - return -5 * BackendTensor.tfnp.exp(-sqrt5_r_over_ell) * (a ** 2 + sqrt5 * a * r - 5 * r ** 2) / (3 * a ** 4) + +@torch.jit.script +def matern_function_5_2_a(r: torch.Tensor, a: float) -> torch.Tensor: + s5 = 2.2360679775 + x = r / a + term_exp = torch.exp(-s5 * x) + + # (a^2 + sqrt5 * a * r - 5 * r^2) + poly = (a * a) + (s5 * a * r) - (5.0 * r * r) + + return -5.0 * term_exp * poly / (3.0 * (a * a * a * a)) @dataclass @@ -73,10 +144,15 @@ class KernelFunction: base_function: Callable derivative_div_r: Callable second_derivative: Callable - consume_sq_distance: bool # * Some kernels can be expressed as a function of the squared distance + consume_sq_distance: bool class AvailableKernelFunctions(Enum): + # We plug in the JIT functions here. + # Note: For NumPy compatibility, GemPy usually handles backend switching elsewhere. + # Since we are optimizing for GPU/PyTorch, providing JIT functions here is safe + # provided the inputs 'r' are Tensors. + cubic = KernelFunction(cubic_function, cubic_function_p_div_r, cubic_function_a, consume_sq_distance=False) exponential = KernelFunction(exp_function, exp_function_p_div_r, exp_function_a, consume_sq_distance=True) matern_5_2 = KernelFunction(matern_function_5_2, matern_function_5_2_p_div_r, matern_function_5_2_a, consume_sq_distance=False) diff --git a/gempy_engine/modules/kernel_constructor/_vectors_preparation.py b/gempy_engine/modules/kernel_constructor/_vectors_preparation.py index 373fd87..1dc6cce 100644 --- a/gempy_engine/modules/kernel_constructor/_vectors_preparation.py +++ b/gempy_engine/modules/kernel_constructor/_vectors_preparation.py @@ -210,7 +210,7 @@ def _assembly_drift_grid_tensors(grid: np.ndarray, options: KernelOptions, matri grid_1 = BackendTensor.t.zeros_like(grid) grid_1[:, axis] = 1 - sel = np.ones(options.number_dimensions) + sel = BackendTensor.t.ones(options.number_dimensions) sel[axis] = 0 dips_ug = OrientationsDrift(