diff --git a/topostats/grains.py b/topostats/grains.py index 781fddf9c46..4ee03fd1797 100644 --- a/topostats/grains.py +++ b/topostats/grains.py @@ -377,21 +377,31 @@ def area_thresholding(self, image: npt.NDArray, area_thresholds: tuple) -> npt.N upper_size_limit = image.size * self.pixel_to_nm_scaling**2 if lower_size_limit is None: lower_size_limit = 0 - # Get array of grain numbers (discounting zero) - uniq = np.delete(np.unique(image), 0) - grain_count = 0 LOGGER.debug( f"[{self.filename}] : Area thresholding grains | Thresholds: L: {(lower_size_limit / self.pixel_to_nm_scaling**2):.2f}," f"U: {(upper_size_limit / self.pixel_to_nm_scaling**2):.2f} px^2, L: {lower_size_limit:.2f}, U: {upper_size_limit:.2f} nm^2." ) - for grain_no in uniq: # Calculate grian area in nm^2 - grain_area = np.sum(image_cp == grain_no) * (self.pixel_to_nm_scaling**2) - # Compare area in nm^2 to area thresholds - if grain_area > upper_size_limit or grain_area < lower_size_limit: - image_cp[image_cp == grain_no] = 0 - else: - grain_count += 1 - image_cp[image_cp == grain_no] = grain_count + + grain_counts = np.bincount(image_cp.ravel()) + grain_counts = grain_counts[1:] + # Calculate areas in nm^2 + grain_areas = grain_counts * (self.pixel_to_nm_scaling**2) + + # Create a mask for valid grains + valid_grains = (grain_areas >= lower_size_limit) & (grain_areas <= upper_size_limit) + + # Create a new mapping for valid grain numbers + new_indices = np.arange(1, valid_grains.sum() + 1) # New indices for valid grains + valid_grain_numbers = np.where(valid_grains)[0] + 1 # Original grain numbers that are valid + + # Step 1: Create a boolean mask for valid grains + valid_mask = np.isin(image_cp, valid_grain_numbers) + # Step 2: Set invalid values to 0 + image_cp[~valid_mask] = 0 # Invert the mask to find invalid values + # Map old grain numbers to new ones + for new_idx, old_idx in enumerate(valid_grain_numbers): + image_cp[image_cp == old_idx] = new_indices[new_idx] + return image_cp def colour_regions(self, image: npt.NDArray, **kwargs) -> npt.NDArray: