diff --git a/warp/_src/fem/adaptivity.py b/warp/_src/fem/adaptivity.py index 388ffd2f17..78c43252f3 100644 --- a/warp/_src/fem/adaptivity.py +++ b/warp/_src/fem/adaptivity.py @@ -65,6 +65,7 @@ def adaptive_nanogrid_from_hierarchy( device=device, inputs=[l, voxel_offsets[l], grid_voxels, merged_ijks], ) + grid_voxels.release() # Allocate merged grid grid_info = grids[0].get_grid_info() @@ -74,6 +75,7 @@ def adaptive_nanogrid_from_hierarchy( translation=grid_info.translation, device=device, ) + merged_ijks.release() # Get unique voxel and corresponding level cell_count = cell_grid.get_voxel_count() @@ -89,6 +91,7 @@ def adaptive_nanogrid_from_hierarchy( dim=cell_count, inputs=[level_count, cell_grid_ids, cell_ijk, cell_level], ) + cell_ijk.release() cell_grid, cell_level = enforce_nanogrid_grading( cell_grid, cell_level, level_count=level_count, grading=grading, temporary_store=temporary_store @@ -170,10 +173,17 @@ def adaptive_nanogrid_from_field( fine_level, ], ) + cell_refinement.release() + + prev_cell_ijk = cell_ijk + prev_cell_level = cell_level # Fine is the new coarse cell_ijk = fine_ijk cell_level = fine_level + prev_cell_ijk.release() + prev_cell_level.release() + fine_count.release() wp.launch(_adjust_refined_ijk, dim=fine_shape, device=device, inputs=[cell_ijk, cell_level]) @@ -195,6 +205,8 @@ def adaptive_nanogrid_from_field( device=device, inputs=[fine_grid.id, cell_ijk, cell_level, fine_level], ) + cell_ijk.release() + cell_level.release() fine_grid, fine_level = enforce_nanogrid_grading( fine_grid, fine_level, level_count=level_count, grading=grading, temporary_store=temporary_store @@ -262,6 +274,8 @@ def enforce_nanogrid_grading( # Add new coordinates fine_shape = int(fine_count.numpy()[0]) if fine_shape == cell_count: + cell_ijk.release() + refinement.release() break fine_ijk = cache.borrow_temporary(temporary_store, shape=fine_shape, dtype=wp.vec3i, device=device) @@ -280,6 +294,8 @@ def enforce_nanogrid_grading( fine_level, ], ) + cell_ijk.release() + refinement.release() # Rebuild grid and levels cell_grid = wp.Volume.allocate_by_voxels( @@ -292,7 +308,10 @@ def enforce_nanogrid_grading( device=device, inputs=[cell_grid.id, fine_ijk, fine_level, cell_level], ) + fine_ijk.release() + fine_level.release() + fine_count.release() return cell_grid, cell_level diff --git a/warp/_src/fem/domain.py b/warp/_src/fem/domain.py index 1a81fd0a10..8ce5a5db77 100644 --- a/warp/_src/fem/domain.py +++ b/warp/_src/fem/domain.py @@ -455,13 +455,21 @@ def __init__( if element_indices is None: if element_mask is None: raise ValueError("Either 'element_mask' or 'element_indices' should be provided") - element_indices, _ = utils.masked_indices(mask=element_mask, temporary_store=temporary_store) + element_indices, element_global_to_local = utils.masked_indices( + mask=element_mask, temporary_store=temporary_store + ) element_indices = element_indices.detach() + element_global_to_local.release() + owns_element_indices = True elif element_mask is not None: raise ValueError("Only one of 'element_mask' and 'element_indices' should be provided") + else: + # If Temporary are passed, then they are not owned by the class, hence cannot be released by the class + owns_element_indices = False self._domain = domain self._element_indices = element_indices + self._owns_element_indices = owns_element_indices self.ElementIndexArg = self._make_element_index_arg() self.element_index = self._make_element_index() @@ -479,6 +487,13 @@ def __init__( self.element_partition_lookup = self._domain.element_partition_lookup self.element_normal = self._domain.element_normal + def __del__(self): + if getattr(self, "_owns_element_indices", False): + element_indices = getattr(self, "_element_indices", None) + if element_indices is not None and hasattr(element_indices, "release"): + element_indices.release() + self._element_indices = None + @property def name(self) -> str: return f"{self._domain.name}_Subdomain" diff --git a/warp/_src/fem/geometry/adaptive_nanogrid.py b/warp/_src/fem/geometry/adaptive_nanogrid.py index 075c2f6b9d..830c9c7822 100644 --- a/warp/_src/fem/geometry/adaptive_nanogrid.py +++ b/warp/_src/fem/geometry/adaptive_nanogrid.py @@ -417,8 +417,10 @@ def _build_face_grid(self, temporary_store: Optional[cache.TemporaryStore] = Non boundary_face_mask, ], ) - boundary_face_indices, _ = utils.masked_indices(boundary_face_mask) - self._boundary_face_indices = boundary_face_indices.detach() + boundary_face_indices, boundary_face_global_to_local = utils.masked_indices(boundary_face_mask) + self._replace_boundary_face_indices(boundary_face_indices.detach()) + boundary_face_global_to_local.release() + boundary_face_mask.release() def _ensure_stacked_edge_grid(self): if self._stacked_edge_grid is None: @@ -565,6 +567,7 @@ def _build_node_grid(cell_ijk, cell_level, cell_grid: wp.Volume, temporary_store node_grid = wp.Volume.allocate_by_voxels( cell_nodes.flatten(), voxel_size=cell_grid.get_voxel_size()[0], device=cell_ijk.device ) + cell_nodes.release() return node_grid @@ -577,6 +580,7 @@ def _build_cell_face_grid(cell_ijk, cell_level, grid: wp.Volume, temporary_store face_grid = wp.Volume.allocate_by_voxels( cell_faces.flatten(), voxel_size=grid.get_voxel_size()[0], device=cell_ijk.device ) + cell_faces.release() return face_grid @@ -634,6 +638,9 @@ def _build_completed_face_grid( face_grid = wp.Volume.allocate_by_voxels( cat_face_ijk.flatten(), voxel_size=cell_face_grid.get_voxel_size(), device=device ) + cat_face_ijk.release() + cell_face_ijk.release() + additional_face_count.release() return face_grid @@ -651,6 +658,7 @@ def _build_stacked_face_grid(cell_ijk, cell_level, grid: wp.Volume, temporary_st face_grid = wp.Volume.allocate_by_voxels( cell_faces.flatten(), voxel_size=grid.get_voxel_size()[0], device=cell_ijk.device ) + cell_faces.release() return face_grid @@ -668,6 +676,7 @@ def _build_stacked_edge_grid(cell_ijk, cell_level, grid: wp.Volume, temporary_st edge_grid = wp.Volume.allocate_by_voxels( cell_edges.flatten(), voxel_size=grid.get_voxel_size()[0], device=cell_ijk.device ) + cell_edges.release() return edge_grid diff --git a/warp/_src/fem/geometry/hexmesh.py b/warp/_src/fem/geometry/hexmesh.py index ea0d215eed..19e57e4561 100644 --- a/warp/_src/fem/geometry/hexmesh.py +++ b/warp/_src/fem/geometry/hexmesh.py @@ -157,8 +157,9 @@ def __init__( self._face_vertex_indices: wp.array = None self._face_hex_indices: wp.array = None self._face_hex_face_orientation: wp.array = None - self._vertex_hex_offsets: wp.array = None - self._vertex_hex_indices: wp.array = None + self._vertex_hex_offsets: wp.array = None # owned temporary reused between rebuilds + self._vertex_hex_indices: wp.array = None # owned temporary reused between rebuilds + self._boundary_face_indices: wp.array = None # owned temporary reused between rebuilds self._hex_edge_indices: wp.array = None self._edge_count = 0 self._build_topology(temporary_store) @@ -183,6 +184,9 @@ def __init__( if build_bvh: self.build_bvh(self.positions.device) + def __del__(self): + self._release_owned_temporaries() + def cell_count(self): return self.hex_vertex_indices.shape[0] @@ -447,8 +451,8 @@ def _build_topology(self, temporary_store: TemporaryStore): vertex_hex_offsets, vertex_hex_indices = compress_node_indices( self.vertex_count(), self.hex_vertex_indices, temporary_store=temporary_store ) - self._vertex_hex_offsets = vertex_hex_offsets.detach() - self._vertex_hex_indices = vertex_hex_indices.detach() + self._replace_owned_array("_vertex_hex_offsets", vertex_hex_offsets.detach()) + self._replace_owned_array("_vertex_hex_indices", vertex_hex_indices.detach()) vertex_start_face_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count()) vertex_start_face_count.zero_() @@ -546,8 +550,24 @@ def _build_topology(self, temporary_store: TemporaryStore): ], ) - boundary_face_indices, _ = masked_indices(boundary_mask) - self._boundary_face_indices = boundary_face_indices.detach() + boundary_face_indices, boundary_face_global_to_local = masked_indices(boundary_mask) + self._replace_owned_array("_boundary_face_indices", boundary_face_indices.detach()) + boundary_face_global_to_local.release() + boundary_mask.release() + + def _replace_owned_array(self, attr_name: str, new_value): + if hasattr(self, attr_name): + old_value = getattr(self, attr_name) + if old_value is not None and old_value is not new_value and hasattr(old_value, "release"): + old_value.release() + setattr(self, attr_name, new_value) + + def _release_owned_temporaries(self): + for attr in ("_vertex_hex_offsets", "_vertex_hex_indices", "_boundary_face_indices"): + value = getattr(self, attr, None) + if value is not None and hasattr(value, "release"): + value.release() + setattr(self, attr, None) def _compute_hex_edges(self, temporary_store: Optional[TemporaryStore] = None): from warp._src.fem.utils import host_read_at_index diff --git a/warp/_src/fem/geometry/nanogrid.py b/warp/_src/fem/geometry/nanogrid.py index 7f98e83a14..cbc0b5228a 100644 --- a/warp/_src/fem/geometry/nanogrid.py +++ b/warp/_src/fem/geometry/nanogrid.py @@ -59,6 +59,9 @@ def __init__( self._cell_grid_info = cell_grid.get_grid_info() self._init_transform() + def __del__(self): + self._release_owned_temporaries() + def reference_cell(self) -> Element: return Element.CUBE @@ -126,6 +129,17 @@ def fill_side_index_arg(self, arg: SideIndexArg, device): def boundary_side_index(args: SideIndexArg, boundary_side_index: int): return args.boundary_face_indices[boundary_side_index] + def _replace_boundary_face_indices(self, new_value): + if self._boundary_face_indices is not None and self._boundary_face_indices is not new_value: + if hasattr(self._boundary_face_indices, "release"): + self._boundary_face_indices.release() + self._boundary_face_indices = new_value + + def _release_owned_temporaries(self): + if self._boundary_face_indices is not None and hasattr(self._boundary_face_indices, "release"): + self._boundary_face_indices.release() + self._boundary_face_indices = None + def make_filtered_cell_lookup(grid_geo, filter_func: wp.Function = None): suffix = f"{grid_geo.name}{filter_func.key if filter_func is not None else ''}" @@ -539,8 +553,10 @@ def _build_face_grid(self, temporary_store: Optional[cache.TemporaryStore] = Non device=device, inputs=[self._cell_grid.id, self._face_ijk, self._face_flags, boundary_face_mask], ) - boundary_face_indices, _ = utils.masked_indices(boundary_face_mask) - self._boundary_face_indices = boundary_face_indices.detach() + boundary_face_indices, boundary_face_global_to_local = utils.masked_indices(boundary_face_mask) + self._replace_boundary_face_indices(boundary_face_indices.detach()) + boundary_face_global_to_local.release() + boundary_face_mask.release() def _build_edge_grid(self, temporary_store: Optional[cache.TemporaryStore] = None): self._edge_grid = _build_edge_grid(self._cell_ijk, self._cell_grid, temporary_store) @@ -608,6 +624,7 @@ def _build_node_grid(cell_ijk, grid: wp.Volume, temporary_store: cache.Temporary node_grid = wp.Volume.allocate_by_voxels( cell_nodes.flatten(), voxel_size=grid.get_voxel_size(), device=cell_ijk.device ) + cell_nodes.release() return node_grid @@ -620,6 +637,7 @@ def _build_face_grid(cell_ijk, grid: wp.Volume, temporary_store: cache.Temporary face_grid = wp.Volume.allocate_by_voxels( cell_faces.flatten(), voxel_size=grid.get_voxel_size(), device=cell_ijk.device ) + cell_faces.release() return face_grid @@ -632,6 +650,7 @@ def _build_edge_grid(cell_ijk, grid: wp.Volume, temporary_store: cache.Temporary edge_grid = wp.Volume.allocate_by_voxels( cell_edges.flatten(), voxel_size=grid.get_voxel_size(), device=cell_ijk.device ) + cell_edges.release() return edge_grid diff --git a/warp/_src/fem/geometry/partition.py b/warp/_src/fem/geometry/partition.py index 7532aa240c..2300892e6b 100644 --- a/warp/_src/fem/geometry/partition.py +++ b/warp/_src/fem/geometry/partition.py @@ -177,9 +177,12 @@ def __init__( ): super().__init__(geometry) - self._partition_side_indices: wp.array = None - self._boundary_side_indices: wp.array = None - self._frontier_side_indices: wp.array = None + self._partition_side_indices: wp.array = None # owned temporary reused between rebuilds + self._boundary_side_indices: wp.array = None # owned temporary reused between rebuilds + self._frontier_side_indices: wp.array = None # owned temporary reused between rebuilds + + def __del__(self): + self._release_owned_temporaries() @cached_property def SideArg(self): @@ -236,9 +239,8 @@ def compute_side_indices_from_cells( self.side_arg_value.invalidate(self) if max_side_count == 0: - self._partition_side_indices = cache.borrow_temporary(temporary_store, dtype=int, shape=(0,), device=device) - self._boundary_side_indices = self._partition_side_indices - self._frontier_side_indices = self._partition_side_indices + empty = cache.borrow_temporary(temporary_store, dtype=int, shape=(0,), device=device) + self._set_partition_side_arrays(empty, empty, empty) return cell_arg_type = next(iter(cell_inclusion_test_func.input_types.values())) @@ -307,29 +309,54 @@ def count_sides( ) # Convert counts to indices - self._partition_side_indices, _ = masked_indices( + new_partition_side_indices, partition_side_global_to_local = masked_indices( partition_side_mask, max_index_count=max_side_count, local_to_global=self._partition_side_indices, temporary_store=temporary_store, ) - self._boundary_side_indices, _ = masked_indices( + self._replace_owned_array("_partition_side_indices", new_partition_side_indices) + partition_side_global_to_local.release() + new_boundary_side_indices, boundary_side_global_to_local = masked_indices( boundary_side_mask, max_index_count=max_side_count, local_to_global=self._boundary_side_indices, temporary_store=temporary_store, ) - self._frontier_side_indices, _ = masked_indices( + self._replace_owned_array("_boundary_side_indices", new_boundary_side_indices) + boundary_side_global_to_local.release() + new_frontier_side_indices, frontier_side_global_to_local = masked_indices( frontier_side_mask, max_index_count=max_side_count, local_to_global=self._frontier_side_indices, temporary_store=temporary_store, ) + self._replace_owned_array("_frontier_side_indices", new_frontier_side_indices) + frontier_side_global_to_local.release() partition_side_mask.release() boundary_side_mask.release() frontier_side_mask.release() + def _set_partition_side_arrays(self, partition, boundary, frontier): + self._replace_owned_array("_partition_side_indices", partition) + self._replace_owned_array("_boundary_side_indices", boundary) + self._replace_owned_array("_frontier_side_indices", frontier) + + def _replace_owned_array(self, attr_name: str, new_value): + if hasattr(self, attr_name): + old_value = getattr(self, attr_name) + if old_value is not None and old_value is not new_value and hasattr(old_value, "release"): + old_value.release() + setattr(self, attr_name, new_value) + + def _release_owned_temporaries(self): + for attr in ("_partition_side_indices", "_boundary_side_indices", "_frontier_side_indices"): + value = getattr(self, attr, None) + if value is not None and hasattr(value, "release"): + value.release() + setattr(self, attr, None) + @wp.func def side_to_cell_arg(side_arg: Any): return side_arg.cell_arg @@ -426,6 +453,11 @@ def __init__( self.rebuild(cell_mask, temporary_store) + def __del__(self): + self._replace_owned_array("_cells", None) + self._replace_owned_array("_partition_cells", None) + super().__del__() + def rebuild( self, cell_mask: "wp.array(dtype=int)", @@ -442,13 +474,15 @@ def rebuild( """ self.cell_arg_value.invalidate(self) - self._cells, self._partition_cells = masked_indices( + new_cells, new_partition_cells = masked_indices( cell_mask, local_to_global=self._cells, global_to_local=self._partition_cells, max_index_count=self._max_cell_count, temporary_store=temporary_store, ) + self._replace_owned_array("_cells", new_cells) + self._replace_owned_array("_partition_cells", new_partition_cells) super().compute_side_indices_from_cells( self.cell_arg_value(cell_mask.device), diff --git a/warp/_src/fem/geometry/quadmesh.py b/warp/_src/fem/geometry/quadmesh.py index c1f1391eec..061647e6e8 100644 --- a/warp/_src/fem/geometry/quadmesh.py +++ b/warp/_src/fem/geometry/quadmesh.py @@ -68,8 +68,9 @@ def __init__( self._edge_vertex_indices: wp.array = None self._edge_quad_indices: wp.array = None - self._vertex_quad_offsets: wp.array = None - self._vertex_quad_indices: wp.array = None + self._vertex_quad_offsets: wp.array = None # owned temporary reused between rebuilds + self._vertex_quad_indices: wp.array = None # owned temporary reused between rebuilds + self._boundary_edge_indices: wp.array = None # owned temporary reused between rebuilds self._build_topology(temporary_store) # Flip edges so that normals point away from inner cell @@ -88,6 +89,9 @@ def __init__( if build_bvh: self.build_bvh(self.positions.device) + def __del__(self): + self._release_owned_temporaries() + def cell_count(self): return self.quad_vertex_indices.shape[0] @@ -204,8 +208,8 @@ def _build_topology(self, temporary_store: TemporaryStore): vertex_quad_offsets, vertex_quad_indices = compress_node_indices( self.vertex_count(), self.quad_vertex_indices, temporary_store=temporary_store ) - self._vertex_quad_offsets = vertex_quad_offsets.detach() - self._vertex_quad_indices = vertex_quad_indices.detach() + self._replace_owned_array("_vertex_quad_offsets", vertex_quad_offsets.detach()) + self._replace_owned_array("_vertex_quad_indices", vertex_quad_indices.detach()) vertex_start_edge_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count()) vertex_start_edge_count.zero_() @@ -279,11 +283,28 @@ def _build_topology(self, temporary_store: TemporaryStore): vertex_edge_ends.release() vertex_edge_quads.release() - boundary_edge_indices, _ = masked_indices(boundary_mask, temporary_store=temporary_store) - self._boundary_edge_indices = boundary_edge_indices.detach() + boundary_edge_indices, boundary_edge_global_to_local = masked_indices( + boundary_mask, temporary_store=temporary_store + ) + self._replace_owned_array("_boundary_edge_indices", boundary_edge_indices.detach()) + boundary_edge_global_to_local.release() boundary_mask.release() + def _replace_owned_array(self, attr_name: str, new_value): + if hasattr(self, attr_name): + old_value = getattr(self, attr_name) + if old_value is not None and old_value is not new_value and hasattr(old_value, "release"): + old_value.release() + setattr(self, attr_name, new_value) + + def _release_owned_temporaries(self): + for attr in ("_vertex_quad_offsets", "_vertex_quad_indices", "_boundary_edge_indices"): + value = getattr(self, attr, None) + if value is not None and hasattr(value, "release"): + value.release() + setattr(self, attr, None) + @wp.kernel def _count_starting_edges_kernel( quad_vertex_indices: wp.array2d(dtype=int), vertex_start_edge_count: wp.array(dtype=int) diff --git a/warp/_src/fem/geometry/tetmesh.py b/warp/_src/fem/geometry/tetmesh.py index f5a405bc65..c94fcd7608 100644 --- a/warp/_src/fem/geometry/tetmesh.py +++ b/warp/_src/fem/geometry/tetmesh.py @@ -78,8 +78,9 @@ def __init__( self._face_vertex_indices: wp.array = None self._face_tet_indices: wp.array = None - self._vertex_tet_offsets: wp.array = None - self._vertex_tet_indices: wp.array = None + self._vertex_tet_offsets: wp.array = None # owned temporary reused between rebuilds + self._vertex_tet_indices: wp.array = None # owned temporary reused between rebuilds + self._boundary_face_indices: wp.array = None # owned temporary reused between rebuilds self._tet_edge_indices: wp.array = None self._edge_count = 0 self._build_topology(temporary_store) @@ -92,6 +93,9 @@ def __init__( if build_bvh: self.build_bvh(self.positions.device) + def __del__(self): + self._release_owned_temporaries() + def cell_count(self): return self.tet_vertex_indices.shape[0] @@ -318,8 +322,8 @@ def _build_topology(self, temporary_store: TemporaryStore): vertex_tet_offsets, vertex_tet_indices = compress_node_indices( self.vertex_count(), self.tet_vertex_indices, temporary_store=temporary_store ) - self._vertex_tet_offsets = vertex_tet_offsets.detach() - self._vertex_tet_indices = vertex_tet_indices.detach() + self._replace_owned_array("_vertex_tet_offsets", vertex_tet_offsets.detach()) + self._replace_owned_array("_vertex_tet_indices", vertex_tet_indices.detach()) vertex_start_face_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count()) vertex_start_face_count.zero_() @@ -401,8 +405,24 @@ def _build_topology(self, temporary_store: TemporaryStore): inputs=[self._face_vertex_indices, self._face_tet_indices, self.tet_vertex_indices, self.positions], ) - boundary_face_indices, _ = masked_indices(boundary_mask) - self._boundary_face_indices = boundary_face_indices.detach() + boundary_face_indices, boundary_face_global_to_local = masked_indices(boundary_mask) + self._replace_owned_array("_boundary_face_indices", boundary_face_indices.detach()) + boundary_face_global_to_local.release() + boundary_mask.release() + + def _replace_owned_array(self, attr_name: str, new_value): + if hasattr(self, attr_name): + old_value = getattr(self, attr_name) + if old_value is not None and old_value is not new_value and hasattr(old_value, "release"): + old_value.release() + setattr(self, attr_name, new_value) + + def _release_owned_temporaries(self): + for attr in ("_vertex_tet_offsets", "_vertex_tet_indices", "_boundary_face_indices"): + value = getattr(self, attr, None) + if value is not None and hasattr(value, "release"): + value.release() + setattr(self, attr, None) def _compute_tet_edges(self, temporary_store: Optional[TemporaryStore] = None): from warp._src.fem.utils import host_read_at_index diff --git a/warp/_src/fem/geometry/trimesh.py b/warp/_src/fem/geometry/trimesh.py index a65604f1d6..04c3d0a108 100644 --- a/warp/_src/fem/geometry/trimesh.py +++ b/warp/_src/fem/geometry/trimesh.py @@ -75,6 +75,9 @@ def __init__( self._edge_vertex_indices: wp.array = None self._edge_tri_indices: wp.array = None + self._vertex_tri_offsets: wp.array = None # owned temporary reused across rebuilds for vertex adjacency + self._vertex_tri_indices: wp.array = None # owned temporary reused across rebuilds for vertex adjacency + self._boundary_edge_indices: wp.array = None # owned temporary reused to expose boundary sides self._build_topology(temporary_store) # Flip edges so that normals point away from inner cell @@ -92,6 +95,9 @@ def __init__( if build_bvh: self.build_bvh(self.positions.device) + def __del__(self): + self._release_owned_temporaries() + def cell_count(self): return self.tri_vertex_indices.shape[0] @@ -209,8 +215,8 @@ def _build_topology(self, temporary_store: TemporaryStore): vertex_tri_offsets, vertex_tri_indices = compress_node_indices( self.vertex_count(), self.tri_vertex_indices, temporary_store=temporary_store ) - self._vertex_tri_offsets = vertex_tri_offsets.detach() - self._vertex_tri_indices = vertex_tri_indices.detach() + self._replace_owned_array("_vertex_tri_offsets", vertex_tri_offsets.detach()) + self._replace_owned_array("_vertex_tri_indices", vertex_tri_indices.detach()) vertex_start_edge_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count()) vertex_start_edge_count.zero_() @@ -282,11 +288,27 @@ def _build_topology(self, temporary_store: TemporaryStore): vertex_edge_ends.release() vertex_edge_tris.release() - boundary_edge_indices, _ = masked_indices(boundary_mask, temporary_store=temporary_store) - self._boundary_edge_indices = boundary_edge_indices.detach() + boundary_edge_indices, boundary_edge_global_to_local = masked_indices( + boundary_mask, temporary_store=temporary_store + ) + self._replace_owned_array("_boundary_edge_indices", boundary_edge_indices.detach()) + boundary_edge_global_to_local.release() boundary_mask.release() + def _replace_owned_array(self, attr_name: str, new_value): + old_value = getattr(self, attr_name, None) + if old_value is not None and old_value is not new_value and hasattr(old_value, "release"): + old_value.release() + setattr(self, attr_name, new_value) + + def _release_owned_temporaries(self): + for attr in ("_vertex_tri_offsets", "_vertex_tri_indices", "_boundary_edge_indices"): + value = getattr(self, attr, None) + if value is not None and hasattr(value, "release"): + value.release() + setattr(self, attr, None) + @wp.kernel def _count_starting_edges_kernel( tri_vertex_indices: wp.array2d(dtype=int), vertex_start_edge_count: wp.array(dtype=int) diff --git a/warp/_src/fem/integrate.py b/warp/_src/fem/integrate.py index 9695eb3121..00605598b1 100644 --- a/warp/_src/fem/integrate.py +++ b/warp/_src/fem/integrate.py @@ -1324,13 +1324,16 @@ def _launch_integrate_kernel( if output == accumulate_array: return output if output is None: - return accumulate_array.numpy()[0] + result = accumulate_array.numpy()[0] + _release_temporary(accumulate_array) + return result if add_to_output: # accumulate dtype is distinct from output dtype array_axpy(x=accumulate_array, y=output) else: array_cast(in_array=accumulate_array, out_array=output) + _release_temporary(accumulate_array) return output test_arg = test.space_restriction.node_arg_value(device=device) @@ -1350,6 +1353,7 @@ def _launch_integrate_kernel( f"Incompatible output type {type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}" ) + # Result is handed back to the caller; they must release it once finished. output = cache.borrow_temporary( temporary_store=temporary_store, shape=output_shape, @@ -1496,9 +1500,9 @@ def as_2d_array(array): dtype=output_dtype, device=device, ) - triplet_cols = triplet_cols_temp.array - triplet_rows = triplet_rows_temp.array - triplet_values = triplet_values_temp.array + triplet_cols = triplet_cols_temp + triplet_rows = triplet_rows_temp + triplet_values = triplet_values_temp if nodal: wp.launch( @@ -2367,9 +2371,9 @@ def _launch_interpolate_kernel( shape=(nnz, *dest.block_shape), device=device, ) - triplet_cols = triplet_cols_temp.array - triplet_rows = triplet_rows_temp.array - triplet_values = triplet_values_temp.array + triplet_cols = triplet_cols_temp + triplet_rows = triplet_rows_temp + triplet_values = triplet_values_temp triplet_rows.fill_(-1) trial_partition_arg = trial.space_partition.partition_arg_value(device) @@ -2396,6 +2400,10 @@ def _launch_interpolate_kernel( bsr_set_from_triplets(dest, triplet_rows, triplet_cols, triplet_values, **(bsr_options or {})) + triplet_values_temp.release() + triplet_rows_temp.release() + triplet_cols_temp.release() + @integrand def _identity_field(field: Field, s: Sample): @@ -2512,3 +2520,6 @@ def interpolate( bsr_options=bsr_options, device=device, ) +def _release_temporary(array): + if array is not None and hasattr(array, "release"): + array.release() diff --git a/warp/_src/fem/quadrature/pic_quadrature.py b/warp/_src/fem/quadrature/pic_quadrature.py index 1cfe03c560..41f292fda4 100644 --- a/warp/_src/fem/quadrature/pic_quadrature.py +++ b/warp/_src/fem/quadrature/pic_quadrature.py @@ -61,6 +61,13 @@ def __init__( super().__init__(domain) self._requires_grad = requires_grad + self._cell_particle_offsets: wp.array = None # owned temporary reused for per-cell particle ranges + self._cell_particle_indices: wp.array = None # owned temporary reused for element-to-particle lookup + self._cell_count: wp.array = None # owned temporary reused to cache active-cell count on device + self._cell_index_temp: wp.array = None # owned temporary reused when binning positions + self._particle_coords_temp: wp.array = None # owned temporary reused when binning positions + self._particle_fraction_temp: wp.array = None # owned temporary reused when computing fractions + self._bin_particles(positions, measures, max_dist=max_dist, temporary_store=temporary_store) self._max_particles_per_cell: int = None @@ -68,6 +75,9 @@ def __init__( def name(self): return self.__class__.__name__ + def __del__(self): + self._release_owned_temporaries() + @Quadrature.domain.setter def domain(self, domain: GeometryDomain): # Allow changing the quadrature domain as long as underlying geometry and element kind are the same @@ -213,13 +223,15 @@ def bin_particles( else: cell_coords[p] = cell_coordinates(cell_arg_value, sample.element_index, positions[p]) - self._cell_index_temp = borrow_temporary(temporary_store, shape=positions.shape, dtype=int, device=device) - self.cell_indices = self._cell_index_temp.array + cell_index_temp = borrow_temporary(temporary_store, shape=positions.shape, dtype=int, device=device) + self._replace_owned_array("_cell_index_temp", cell_index_temp) + self.cell_indices = self._cell_index_temp - self._particle_coords_temp = borrow_temporary( + particle_coords_temp = borrow_temporary( temporary_store, shape=positions.shape, dtype=Coords, device=device, requires_grad=self._requires_grad ) - self.particle_coords = self._particle_coords_temp.array + self._replace_owned_array("_particle_coords_temp", particle_coords_temp) + self.particle_coords = self._particle_coords_temp wp.launch( dim=positions.shape[0], @@ -242,25 +254,35 @@ def bin_particles( if self.cell_indices.shape != self.particle_coords.shape: raise ValueError("Cell index and coordinates arrays must have the same shape") - self._cell_index_temp = None - self._particle_coords_temp = None + self._replace_owned_array("_cell_index_temp", None) + self._replace_owned_array("_particle_coords_temp", None) - self._cell_particle_offsets, self._cell_particle_indices, self._cell_count, _ = compress_node_indices( + ( + cell_particle_offsets, + cell_particle_indices, + cell_count, + unused_unique_nodes, + ) = compress_node_indices( self.domain.geometry_element_count(), self.cell_indices, return_unique_nodes=True, temporary_store=temporary_store, ) + self._replace_owned_array("_cell_particle_offsets", cell_particle_offsets.detach()) + self._replace_owned_array("_cell_particle_indices", cell_particle_indices.detach()) + self._replace_owned_array("_cell_count", cell_count.detach()) + unused_unique_nodes.release() self._compute_fraction(self.cell_indices, measures, temporary_store) def _compute_fraction(self, cell_index, measures, temporary_store: TemporaryStore): device = cell_index.device - self._particle_fraction_temp = borrow_temporary( + particle_fraction_temp = borrow_temporary( temporary_store, shape=cell_index.shape, dtype=float, device=device, requires_grad=self._requires_grad ) - self._particle_fraction = self._particle_fraction_temp.array + self._replace_owned_array("_particle_fraction_temp", particle_fraction_temp) + self._particle_fraction = self._particle_fraction_temp if measures is None: # Split fraction uniformly over all particles in cell @@ -312,6 +334,28 @@ def compute_fraction( device=device, ) + def _replace_owned_array(self, attr_name: str, new_value): + if hasattr(self, attr_name): + old_value = getattr(self, attr_name) + if old_value is not None and old_value is not new_value and hasattr(old_value, "release"): + old_value.release() + setattr(self, attr_name, new_value) + + def _release_owned_temporaries(self): + for attr in ( + "_cell_particle_offsets", + "_cell_particle_indices", + "_cell_count", + "_cell_index_temp", + "_particle_coords_temp", + "_particle_fraction_temp", + ): + if hasattr(self, attr): + value = getattr(self, attr) + if value is not None and hasattr(value, "release"): + value.release() + setattr(self, attr, None) + @wp.kernel def _max_particles_per_cell_kernel(offsets: wp.array(dtype=int), max_count: wp.array(dtype=int)): cell = wp.tid() diff --git a/warp/_src/fem/space/partition.py b/warp/_src/fem/space/partition.py index 044233a114..9c1ef712bb 100644 --- a/warp/_src/fem/space/partition.py +++ b/warp/_src/fem/space/partition.py @@ -73,6 +73,13 @@ def __str__(self) -> str: def name(self) -> str: return f"{self.__class__.__name__}" + def __del__(self): + """Return any cached temporaries we own back to the TemporaryStore.""" + self._release_owned_temporaries() + + def _release_owned_temporaries(self): + pass + class WholeSpacePartition(SpacePartition): @wp.struct @@ -123,6 +130,11 @@ def name(self) -> str: def _iota_kernel(indices: wp.array(dtype=int)): indices[wp.tid()] = wp.tid() + def _release_owned_temporaries(self): + if self._node_indices is not None and hasattr(self._node_indices, "release"): + self._node_indices.release() + self._node_indices = None + class NodeCategory: OWNED_INTERIOR = wp.constant(0) @@ -162,11 +174,11 @@ def __init__( self._with_halo = with_halo self._category_offsets: wp.array = None - """Offsets for each node category""" + """Offsets for each node category (owned temporary reused between rebuilds)""" self._node_indices: wp.array = None - """Mapping from local partition node indices to global space node indices""" + """Mapping from local partition node indices to global space node indices (owned temporary reused between rebuilds)""" self._space_to_partition: wp.array = None - """Mapping from global space node indices to local partition node indices""" + """Mapping from global space node indices to local partition node indices (owned temporary reused between rebuilds)""" self.rebuild(device, temporary_store) @@ -343,6 +355,8 @@ def _finalize_node_indices( # Compute global to local indices if self._space_to_partition is None or self._space_to_partition.shape != node_indices.shape: + if self._space_to_partition is not None: + self._space_to_partition.release() self._space_to_partition = cache.borrow_temporary_like(node_indices, temporary_store) wp.launch( @@ -357,12 +371,27 @@ def _finalize_node_indices( # Copy to shrunk-to-fit array if self._node_indices is None or self._node_indices.shape[0] != self.node_count(): + if self._node_indices is not None: + self._node_indices.release() self._node_indices = cache.borrow_temporary( temporary_store, shape=(self.node_count(),), dtype=int, device=device ) wp.copy(dest=self._node_indices, src=node_indices, count=self.node_count()) node_indices.release() + category_offsets.release() + + def _release_owned_temporaries(self): + super()._release_owned_temporaries() + if self._category_offsets is not None and hasattr(self._category_offsets, "release"): + self._category_offsets.release() + self._category_offsets = None + if self._node_indices is not None and hasattr(self._node_indices, "release"): + self._node_indices.release() + self._node_indices = None + if self._space_to_partition is not None and hasattr(self._space_to_partition, "release"): + self._space_to_partition.release() + self._space_to_partition = None @wp.kernel def _scatter_partition_indices( diff --git a/warp/_src/fem/space/restriction.py b/warp/_src/fem/space/restriction.py index aed7f2688c..f9bc66c893 100644 --- a/warp/_src/fem/space/restriction.py +++ b/warp/_src/fem/space/restriction.py @@ -51,19 +51,25 @@ def __init__( self.domain = domain self._node_count_dev: wp.array = None - """Number of unique partition node indices""" + """Number of unique partition node indices (owned temporary borrowed from cache; released once synchronized or on destruction)""" self._dof_partition_indices: wp.array = None - """Array of unique partition node indices""" + """Array of unique partition node indices (owned temporary borrowed from cache; released on resize or destruction)""" self._dof_partition_element_offsets: wp.array = None - """Mapping from partition node to offset in the per-node element indices array""" + """Mapping from partition node to offset in the per-node element indices array (owned temporary borrowed from cache; released on resize/destruction)""" self._dof_element_indices: wp.array = None - """Concatenation of neighboring elements indices for each partition node""" + """Concatenation of neighboring elements indices for each partition node (owned temporary borrowed from cache; reused across rebuilds)""" self._dof_indices_in_element: wp.array = None - """Concatenation of node index in element for each partition node""" + """Concatenation of node index in element for each partition node (owned temporary borrowed from cache; reused across rebuilds)""" self.rebuild(device=device, temporary_store=temporary_store) + def __del__(self): + # SpaceRestriction instances are not expected to participate in reference cycles, so explicitly + # releasing owned temporaries here is safe and prevents holding onto cache buffers until GC runs. + if hasattr(self, "_release_owned_temporaries"): + self._release_owned_temporaries() + def rebuild(self, device: Optional = None, temporary_store: Optional[cache.TemporaryStore] = None): max_nodes_per_element = self.space_topology.MAX_NODES_PER_ELEMENT @@ -115,10 +121,10 @@ def fill_element_node_indices( # Build compressed map from node to element indices flattened_node_indices = element_node_indices.flatten() ( - self._dof_partition_element_offsets, + new_partition_element_offsets, node_array_indices, - self._node_count_dev, - self._dof_partition_indices, + new_node_count_dev, + new_partition_indices, ) = compress_node_indices( self.space_partition.node_count(), flattened_node_indices, @@ -129,10 +135,17 @@ def fill_element_node_indices( temporary_store=temporary_store, ) + self._replace_owned_temporary("_dof_partition_element_offsets", new_partition_element_offsets) + self._replace_owned_temporary("_node_count_dev", new_node_count_dev) + self._replace_owned_temporary("_dof_partition_indices", new_partition_indices) + # Extract element index and index in element if self._dof_element_indices is None or self._dof_element_indices.shape != flattened_node_indices.shape: - self._dof_element_indices = cache.borrow_temporary_like(flattened_node_indices, temporary_store) - self._dof_indices_in_element = cache.borrow_temporary_like(flattened_node_indices, temporary_store) + new_dof_element_indices = cache.borrow_temporary_like(flattened_node_indices, temporary_store) + new_dof_indices_in_element = cache.borrow_temporary_like(flattened_node_indices, temporary_store) + + self._replace_owned_temporary("_dof_element_indices", new_dof_element_indices) + self._replace_owned_temporary("_dof_indices_in_element", new_dof_indices_in_element) wp.launch( kernel=SpaceRestriction._split_vertex_element_index, @@ -147,6 +160,7 @@ def fill_element_node_indices( ) node_array_indices.release() + element_node_indices.release() # Upper bound on node count, use `node_count_sync` to get the actual value self._node_count = min(self.space_partition.node_count(), self._dof_partition_indices.shape[0]) @@ -155,6 +169,8 @@ def node_count_sync(self) -> int: """Ensures that the node count is synchronized with the device and returns it""" if self._node_count_dev is not None: self._node_count = int(host_read_at_index(self._node_count_dev, index=0)) + if hasattr(self._node_count_dev, "release"): + self._node_count_dev.release() self._node_count_dev = None return self.node_count() @@ -219,3 +235,22 @@ def _split_vertex_element_index( element_index = idx // vertex_per_element vertex_element_index[wp.tid()] = element_index vertex_index_in_element[wp.tid()] = idx - vertex_per_element * element_index + + def _replace_owned_temporary(self, attr_name: str, new_value): + """Return previously owned temporaries to the cache before overwriting them.""" + if hasattr(self, attr_name): + old_value = getattr(self, attr_name) + if old_value is not None and old_value is not new_value and hasattr(old_value, "release"): + old_value.release() + setattr(self, attr_name, new_value) + + def _release_owned_temporaries(self): + for attr in ( + "_dof_partition_element_offsets", + "_dof_partition_indices", + "_dof_element_indices", + "_dof_indices_in_element", + "_node_count_dev", + ): + if hasattr(self, attr): + self._replace_owned_temporary(attr, None) diff --git a/warp/_src/fem/utils.py b/warp/_src/fem/utils.py index a430022e63..0a2c5c3c88 100644 --- a/warp/_src/fem/utils.py +++ b/warp/_src/fem/utils.py @@ -124,10 +124,12 @@ def compress_node_indices( # Build prefix sum of number of elements per node node_element_counts = cache.borrow_temporary(temporary_store, shape=index_count, dtype=int) - if unique_node_indices is None or unique_node_indices.shape != node_element_counts.shape: + owns_unique_node_indices = unique_node_indices is None or unique_node_indices.shape != node_element_counts.shape + if owns_unique_node_indices: unique_node_indices = cache.borrow_temporary_like(node_element_counts, temporary_store) - if unique_node_count is None or unique_node_count.shape != (1,): + owns_unique_node_count = unique_node_count is None or unique_node_count.shape != (1,) + if owns_unique_node_count: unique_node_count = cache.borrow_temporary(temporary_store, shape=(1,), dtype=int) runlength_encode( @@ -156,6 +158,10 @@ def compress_node_indices( node_element_counts.release() if not return_unique_nodes: + if owns_unique_node_indices: + unique_node_indices.release() + if owns_unique_node_count: + unique_node_count.release() return node_offsets, sorted_array_indices return node_offsets, sorted_array_indices, unique_node_count, unique_node_indices