From ea27f61992c85b1ae26d03f4593137075d458656 Mon Sep 17 00:00:00 2001 From: Oliver Hennigh Date: Thu, 21 May 2026 16:01:25 -0700 Subject: [PATCH] Fix Warp mesh Poisson residual conflicts --- .../_warp_impl/_kernels.py | 35 +++++++++++++++++++ .../mesh_poisson_disk_sample/_warp_impl/op.py | 32 ++++++++++++++++- 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/physicsnemo/nn/functional/geometry/mesh_poisson_disk_sample/_warp_impl/_kernels.py b/physicsnemo/nn/functional/geometry/mesh_poisson_disk_sample/_warp_impl/_kernels.py index 3aec265e39..970219a57f 100644 --- a/physicsnemo/nn/functional/geometry/mesh_poisson_disk_sample/_warp_impl/_kernels.py +++ b/physicsnemo/nn/functional/geometry/mesh_poisson_disk_sample/_warp_impl/_kernels.py @@ -208,6 +208,41 @@ def _commit_accepted_candidates( accepted_radii[accepted_idx] = candidate_radii[candidate_idx] +# Deterministically prune any conflicts left by the parallel candidate pass. +@wp.kernel +def _mark_accepted_conflicts( + hashgrid_id: wp.uint64, + accepted_positions: wp.array(dtype=wp.vec3f), + accepted_radii: wp.array(dtype=wp.float32), + accepted_alive: wp.array(dtype=wp.int32), + search_radius: wp.float32, +): + sample_idx = wp.tid() + if accepted_alive[sample_idx] == 0: + return + + sample_position = accepted_positions[sample_idx] + sample_radius = accepted_radii[sample_idx] + + neighbor_idx = int(0) + query = wp.hash_grid_query(hashgrid_id, sample_position, search_radius) + while wp.hash_grid_query_next(query, neighbor_idx): + if neighbor_idx >= sample_idx: + continue + if neighbor_idx >= accepted_positions.shape[0]: + continue + + neighbor_radius = accepted_radii[neighbor_idx] + min_radius = wp.min(sample_radius, neighbor_radius) + if _points_too_close( + sample_position, + accepted_positions[neighbor_idx], + min_radius, + ): + accepted_alive[sample_idx] = 0 + return + + # Compute Yuksel sample-elimination contribution for one pairwise distance. @wp.func def _wse_pair_weight( diff --git a/physicsnemo/nn/functional/geometry/mesh_poisson_disk_sample/_warp_impl/op.py b/physicsnemo/nn/functional/geometry/mesh_poisson_disk_sample/_warp_impl/op.py index ae0b280ab3..5e47ef0f0b 100644 --- a/physicsnemo/nn/functional/geometry/mesh_poisson_disk_sample/_warp_impl/op.py +++ b/physicsnemo/nn/functional/geometry/mesh_poisson_disk_sample/_warp_impl/op.py @@ -31,6 +31,7 @@ _count_wse_neighbors, _generate_surface_candidates, _initialize_wse_weights_from_csr, + _mark_accepted_conflicts, _mark_wse_deleted_batch, _reject_candidates_vs_accepted, _resolve_candidate_conflicts, @@ -851,7 +852,36 @@ def _run_dart_throwing_pass( pass_seed=random_seed, pass_limit=max_points, ) - return accepted_positions[:final_count].contiguous() + if final_count <= 1: + return accepted_positions[:final_count].contiguous() + + final_positions = accepted_positions[:final_count] + final_radii = accepted_radii[:final_count] + final_alive = torch.ones( + (final_count,), + device=mesh_vertices.device, + dtype=torch.int32, + ) + final_search_radius = max(min_distance, adaptive_max_radius) + accepted_grid.build( + points=wp.from_torch(final_positions, dtype=wp.vec3f), + radius=final_search_radius, + ) + wp.launch( + kernel=_mark_accepted_conflicts, + dim=final_count, + inputs=[ + accepted_grid.id, + wp.from_torch(final_positions, dtype=wp.vec3f, return_ctype=True), + wp.from_torch(final_radii, dtype=wp.float32, return_ctype=True), + wp.from_torch(final_alive, dtype=wp.int32, return_ctype=True), + float(final_search_radius), + ], + device=wp_launch_device, + stream=wp_launch_stream, + ) + kept_indices = torch.nonzero(final_alive != 0, as_tuple=False).squeeze(1) + return final_positions.index_select(0, kept_indices).contiguous() # Public alias used by the FunctionSpec wrapper.