@@ -2322,6 +2322,72 @@ def _apply_mode_reorder(self, sort_inds_2d):
23222322
23232323 return self .updated_copy (** modify_data )
23242324
2325+ def _apply_mode_subset (self , subset_inds_2d : np .ndarray ) -> ModeSolverData :
2326+ """Return copy of self containing only the selected modes.
2327+
2328+ Parameters
2329+ ----------
2330+ subset_inds_2d : np.ndarray
2331+ Array of shape ``(num_freqs, num_modes_keep)`` containing the indices of the original
2332+ modes to retain at each frequency.
2333+
2334+ Returns
2335+ -------
2336+ :class:`.ModeSolverData`
2337+ Copy of self with only the retained modes.
2338+ """
2339+
2340+ subset_inds_2d = np .asarray (subset_inds_2d , dtype = int )
2341+ if subset_inds_2d .ndim != 2 :
2342+ raise DataError (
2343+ "subset_inds_2d must be a 2D array of shape (num_freqs, num_modes_keep)."
2344+ )
2345+
2346+ num_freqs , num_keep = subset_inds_2d .shape
2347+ if num_keep == 0 :
2348+ raise DataError ("Cannot create a mode subset with zero modes." )
2349+
2350+ num_modes_full = self .n_eff ["mode_index" ].size
2351+
2352+ modify_data = {}
2353+ new_mode_index_coord = np .arange (num_keep )
2354+
2355+ for key , data in self .data_arrs .items ():
2356+ if "mode_index" not in data .dims or "f" not in data .dims :
2357+ continue
2358+
2359+ dims_orig = tuple (data .dims )
2360+ coords_out = {
2361+ k : (v .values if hasattr (v , "values" ) else np .asarray (v ))
2362+ for k , v in data .coords .items ()
2363+ }
2364+
2365+ f_axis = data .get_axis_num ("f" )
2366+ m_axis = data .get_axis_num ("mode_index" )
2367+ src_order = (
2368+ [f_axis ] + [ax for ax in range (data .ndim ) if ax not in (f_axis , m_axis )] + [m_axis ]
2369+ )
2370+
2371+ arr = np .moveaxis (data .data , src_order , range (data .ndim ))
2372+ nf , nm = arr .shape [0 ], arr .shape [- 1 ]
2373+ if nf != num_freqs or nm != num_modes_full :
2374+ raise DataError (
2375+ "subset_inds_2d shape does not match array shape in _apply_mode_subset."
2376+ )
2377+
2378+ arr2 = arr .reshape (nf , - 1 , nm )
2379+ inds = subset_inds_2d [:, None , :]
2380+ arr2_subset = np .take_along_axis (arr2 , inds , axis = 2 )
2381+ arr_subset = arr2_subset .reshape (arr .shape [:- 1 ] + (num_keep ,))
2382+ arr_subset = np .moveaxis (arr_subset , range (data .ndim ), src_order )
2383+
2384+ coords_out ["mode_index" ] = new_mode_index_coord
2385+ coords_out ["f" ] = data .coords ["f" ].values
2386+
2387+ modify_data [key ] = DataArray (arr_subset , coords = coords_out , dims = dims_orig )
2388+
2389+ return self .updated_copy (** modify_data )
2390+
23252391 def sort_modes (
23262392 self , sort_spec : Optional [ModeSortSpec ] = None , track_freq : Optional [TrackFreq ] = None
23272393 ) -> ModeSolverData :
@@ -2356,8 +2422,9 @@ def sort_modes(
23562422 num_freqs = self .n_eff ["f" ].size
23572423 num_modes = self .n_eff ["mode_index" ].size
23582424 all_inds = np .arange (num_modes )
2359- identity = np .arange (num_modes )
2360- sort_inds_2d = np .tile (identity , (num_freqs , 1 ))
2425+ drop_modes = getattr (sort_spec , "drop_modes" , False )
2426+ if drop_modes and sort_spec .filter_key is None :
2427+ raise ValidationError ("ModeSortSpec.drop_modes requires 'filter_key' to be set." )
23612428
23622429 # Helper to compute ordered indices within a subset
23632430 def _order_indices (indices , vals_all ):
@@ -2376,12 +2443,13 @@ def _order_indices(indices, vals_all):
23762443 filter_metric = getattr (self , sort_spec .filter_key )
23772444 if sort_spec .sort_key is not None :
23782445 sort_metric = getattr (self , sort_spec .sort_key )
2446+ identity = np .arange (num_modes )
2447+ sort_inds_2d = np .tile (identity , (num_freqs , 1 ))
23792448
23802449 for ifreq in range (num_freqs ):
23812450 # Build groups according to filter if requested
23822451 if filter_metric is not None :
23832452 vals_filt = filter_metric .isel (f = ifreq ).values
2384- # Boolean mask for modes in the first group
23852453 if sort_spec .filter_order == "over" :
23862454 mask_first = vals_filt >= sort_spec .filter_reference
23872455 else :
@@ -2406,11 +2474,10 @@ def _order_indices(indices, vals_all):
24062474
24072475 sort_inds_2d [ifreq , : len (sort_inds )] = sort_inds
24082476
2409- # If all rows are identity, skip
24102477 if np .all (sort_inds_2d == np .tile (identity , (num_freqs , 1 ))):
24112478 data_sorted = self
24122479 else :
2413- data_sorted = self ._apply_mode_reorder (sort_inds_2d ) # this creates a copy
2480+ data_sorted = self ._apply_mode_reorder (sort_inds_2d )
24142481 data_sorted = data_sorted .updated_copy (
24152482 path = "monitor/mode_spec" , sort_spec = sort_spec , deep = False , validate = False
24162483 )
@@ -2422,6 +2489,38 @@ def _order_indices(indices, vals_all):
24222489 if track_freq and num_freqs > 1 :
24232490 data_sorted = data_sorted .overlap_sort (track_freq )
24242491
2492+ if drop_modes :
2493+ # Re-evaluate the filter after sorting/tracking so modes are dropped consistently.
2494+ filter_metric_sorted = getattr (data_sorted , sort_spec .filter_key )
2495+ masks_after = []
2496+ for ifreq in range (num_freqs ):
2497+ vals = filter_metric_sorted .isel (f = ifreq ).values
2498+ if sort_spec .filter_order == "over" :
2499+ mask = vals >= sort_spec .filter_reference
2500+ else :
2501+ mask = vals <= sort_spec .filter_reference
2502+ masks_after .append (mask )
2503+
2504+ keep_mask = np .all (np .stack (masks_after , axis = 0 ), axis = 0 )
2505+ if not np .any (keep_mask ):
2506+ raise ValidationError (
2507+ "Filtering removes all modes; relax the filter threshold or disable drop_modes."
2508+ )
2509+
2510+ num_modes_sorted = filter_metric_sorted .sizes ["mode_index" ]
2511+ if keep_mask .sum () < num_modes_sorted :
2512+ keep_inds = np .where (keep_mask )[0 ]
2513+ subset_inds_2d = np .tile (keep_inds , (num_freqs , 1 ))
2514+ data_subset = data_sorted ._apply_mode_subset (subset_inds_2d )
2515+ mspec = data_subset .monitor .mode_spec
2516+ mspec_updated = mspec .updated_copy (num_modes = keep_inds .size , validate = False )
2517+ monitor_updated = data_subset .monitor .updated_copy (
2518+ mode_spec = mspec_updated , validate = False
2519+ )
2520+ data_sorted = data_subset .updated_copy (
2521+ monitor = monitor_updated , deep = False , validate = False
2522+ )
2523+
24252524 return data_sorted
24262525
24272526
0 commit comments