@@ -238,7 +238,7 @@ def _init_degrees(
238238 mmax : int = None ,
239239 used_l : np .ndarray = None ,
240240 used_m : np .ndarray = None ,
241- ) -> tuple [np .ndarray , np .ndarray , int , int , int , int ]:
241+ ) -> tuple [xr . Dataset , np .ndarray , np .ndarray , int , int , int , int ]:
242242 """
243243 Initialize spherical harmonic degrees and orders to be used.
244244
@@ -257,6 +257,8 @@ def _init_degrees(
257257
258258 Returns
259259 -------
260+ sub_data : xr.Dataset
261+ Reduced dataset with selected degrees and orders.
260262 used_l: np.ndarray
261263 Degrees to use.
262264 used_m: np.ndarray
@@ -276,7 +278,8 @@ def _init_degrees(
276278 mmin = int (data .m .min ()) if mmin is None else mmin
277279 used_l = np .arange (lmin , lmax + 1 ) if used_l is None else used_l
278280 used_m = np .arange (mmin , mmax + 1 ) if used_m is None else used_m
279- return used_l , used_m , lmin , lmax , mmin , mmax
281+ sub_data = data .sel (l = used_l , m = used_m )
282+ return sub_data , used_l , used_m , lmin , lmax , mmin , mmax
280283
281284
282285def sh_to_grid (
@@ -304,6 +307,8 @@ def sh_to_grid(
304307 include_elastic : bool = True ,
305308 plm : xr .DataArray = None ,
306309 normalization_plm : Literal ["4pi" , "ortho" , "schmidt" ] = "4pi" ,
310+ use_dask : bool = False ,
311+ chunks_plm : dict | None = None ,
307312 ** kwargs ,
308313) -> xr .DataArray :
309314 """
@@ -374,6 +379,11 @@ def sh_to_grid(
374379 Either '4pi', 'ortho', or 'schmidt' for 4pi normalized, orthonormalized, or Schmidt semi-normalized SH
375380 functions, respectively. Default is '4pi'.
376381
382+ use_dask : bool, optional
383+ If True, use dask to chunk plm for memory optimization. Default is False.
384+ chunks_plm : dict, optional
385+ Define the chunking of plm when use_dask is True. Default is None, which set the chunking to {'latitude': 1}.
386+
377387 **kwargs :
378388 Supplementary parameters used by the function l_factor_conv to modify defaults constants used in the computation
379389 for the unit conversion. These parameters include (see :func:`l_factor_conv` documentation for more details) :
@@ -386,7 +396,7 @@ def sh_to_grid(
386396 """
387397 # add mask in output variable
388398
389- used_l , used_m , lmin , lmax , mmin , mmax = _init_degrees (
399+ sub_data , used_l , used_m , lmin , lmax , mmin , mmax = _init_degrees (
390400 data , lmin , lmax , mmin , mmax , used_l , used_m
391401 )
392402 used_l , use_czero_coef , force_mass_conservation = _handle_mass_conservation (
@@ -415,22 +425,22 @@ def sh_to_grid(
415425 plm = compute_plm (
416426 lmax ,
417427 np .cos (geocentric_colat ),
428+ latitude = latitude ,
418429 mmax = mmax ,
419430 normalization = normalization_plm ,
431+ use_dask = use_dask ,
432+ chunks = chunks_plm ,
420433 )
421434 else :
422435 plm = compute_plm (
423- lmax , sin_latitude , mmax = mmax , normalization = normalization_plm
436+ lmax ,
437+ sin_latitude ,
438+ latitude = latitude ,
439+ mmax = mmax ,
440+ normalization = normalization_plm ,
441+ use_dask = use_dask ,
442+ chunks = chunks_plm ,
424443 )
425- plm = xr .DataArray (
426- plm ,
427- dims = ["l" , "m" , "latitude" ],
428- coords = {
429- "l" : np .arange (lmax + 1 ),
430- "m" : np .arange (mmax + 1 ),
431- "latitude" : latitude ,
432- },
433- )
434444
435445 else :
436446 # Verify plm integrity
@@ -463,7 +473,7 @@ def sh_to_grid(
463473 include_elastic = include_elastic ,
464474 ellipsoidal_earth = ellipsoidal_earth ,
465475 geocentric_colat = geocentric_colat ,
466- attrs = data .attrs ,
476+ attrs = sub_data .attrs ,
467477 ** kwargs ,
468478 )
469479
@@ -484,14 +494,14 @@ def sh_to_grid(
484494
485495 # summation over all spherical harmonic degrees
486496 if not errors :
487- d_clm = (plm_lfactor * data . sel ( l = used_l , m = used_m ) .clm ).sum (dim = "l" )
488- d_slm = (plm_lfactor * data . sel ( l = used_l , m = used_m ) .slm ).sum (dim = "l" )
497+ d_clm = (plm_lfactor * sub_data .clm ).sum (dim = "l" )
498+ d_slm = (plm_lfactor * sub_data .slm ).sum (dim = "l" )
489499
490500 # Final calcul on the grid
491501 xgrid = c_cos .dot (d_clm ) + s_sin .dot (d_slm )
492502 else :
493- d_clm = (plm_lfactor ** 2 * data . sel ( l = used_l , m = used_m ) .clm ** 2 ).sum (dim = "l" )
494- d_slm = (plm_lfactor ** 2 * data . sel ( l = used_l , m = used_m ) .slm ** 2 ).sum (dim = "l" )
503+ d_clm = (plm_lfactor ** 2 * sub_data .clm ** 2 ).sum (dim = "l" )
504+ d_slm = (plm_lfactor ** 2 * sub_data .slm ** 2 ).sum (dim = "l" )
495505
496506 # Final calcul of sigma on the grid
497507 xgrid = np .sqrt ((c_cos ** 2 ).dot (d_clm ) + (s_sin ** 2 ).dot (d_slm ))
@@ -518,17 +528,17 @@ def sh_to_grid(
518528 # restore C0 mass
519529 if use_czero_coef :
520530 lfactor_zero = l_factor_conv (
521- np .array ([0 ]), unit = unit , attrs = data .attrs , ** kwargs
531+ np .array ([0 ]), unit = unit , attrs = sub_data .attrs , ** kwargs
522532 )[0 ]
523- xgrid = xgrid + (lfactor_zero * data .clm .sel (l = 0 , m = 0 )).values
533+ xgrid = xgrid + (lfactor_zero * sub_data .clm .sel (l = 0 , m = 0 )).values
524534
525535 xgrid = xgrid .transpose ("latitude" , "longitude" , ...)
526536
527537 xgrid .attrs = {"units" : unit , "max_degree" : int (lmax )}
528- if "radius" in data .attrs :
529- xgrid .attrs ["radius" ] = data .attrs ["radius" ]
530- if "earth_gravity_constant" in data .attrs :
531- xgrid .attrs ["earth_gravity_constant" ] = data .attrs ["earth_gravity_constant" ]
538+ if "radius" in sub_data .attrs :
539+ xgrid .attrs ["radius" ] = sub_data .attrs ["radius" ]
540+ if "earth_gravity_constant" in sub_data .attrs :
541+ xgrid .attrs ["earth_gravity_constant" ] = sub_data .attrs ["earth_gravity_constant" ]
532542
533543 return xgrid
534544
@@ -546,6 +556,8 @@ def grid_to_sh(
546556 include_elastic : bool = True ,
547557 plm : xr .DataArray | None = None ,
548558 normalization_plm : Literal ["4pi" , "ortho" , "schmidt" ] = "4pi" ,
559+ use_dask : bool = False ,
560+ chunks_plm : dict | None = None ,
549561 ** kwargs ,
550562) -> xr .Dataset :
551563 """
@@ -591,6 +603,11 @@ def grid_to_sh(
591603 4pi normalized, orthonormalized, or Schmidt semi-normalized SH functions, respectively. Default is '4pi'.
592604 Output SH coefficient will be normalized according to this parameter.
593605
606+ use_dask : bool, optional
607+ If True, use dask to chunk plm for memory optimization. Default is False.
608+ chunks_plm : dict, optional
609+ Define the chunking of plm when use_dask is True. Default is None, which set the chunking to {'latitude': 1}.
610+
594611 **kwargs :
595612 Supplementary parameters used by the function l_factor_conv to modify defaults constants used in the computation
596613 for the unit conversion. These parameters include (see :func:`l_factor_conv` documentation for more details) :
@@ -651,22 +668,21 @@ def grid_to_sh(
651668 plm = compute_plm (
652669 lmax ,
653670 np .cos (geocentric_colat ),
671+ latitude = grid .cf ["latitude" ],
654672 mmax = mmax ,
655673 normalization = normalization_plm ,
674+ use_dask = use_dask ,
675+ chunks = chunks_plm ,
656676 )
657677 else :
658678 plm = compute_plm (
659- lmax , sin_latitude , mmax = mmax , normalization = normalization_plm
679+ lmax , sin_latitude ,
680+ latitude = grid .cf ["latitude" ],
681+ mmax = mmax ,
682+ normalization = normalization_plm ,
683+ use_dask = use_dask ,
684+ chunks = chunks_plm ,
660685 )
661- plm = xr .DataArray (
662- plm ,
663- dims = ["l" , "m" , "latitude" ],
664- coords = {
665- "l" : np .arange (lmax + 1 ),
666- "m" : np .arange (lmax + 1 ),
667- "latitude" : grid .cf ["latitude" ],
668- },
669- )
670686
671687 else :
672688 # Verify plm integrity
@@ -736,10 +752,13 @@ def grid_to_sh(
736752def compute_plm (
737753 lmax : int ,
738754 z : np .ndarray ,
755+ latitude : np .ndarray = None ,
739756 mmax : int = None ,
740757 normalization : Literal ["4pi" , "ortho" , "schmidt" ] = "4pi" ,
741758 dtype : complex | float | type [complex ] | type [float ] = np .float128 ,
742- ) -> np .ndarray :
759+ use_dask : bool = False ,
760+ chunks : dict | None = None ,
761+ ) -> xr .DataArray :
743762 """
744763 Compute all the associated Legendre functions up to a maximum degree and
745764 order using the recursion relation from [Holmes2002]_
@@ -751,6 +770,8 @@ def compute_plm(
751770 Maximum degree of legrendre functions.
752771 z : np.ndarray
753772 Argument of the associated Legendre functions.
773+ latitude : np.ndarray, optional
774+ Latitude values in degrees. Default is None and latitude is made from z.
754775 mmax : int or NoneType, optional
755776 Maximum order of associated legrendre functions.
756777 normalization : {'4pi', 'ortho', 'schmidt'}, optional
@@ -759,10 +780,15 @@ def compute_plm(
759780 dtype : dtype, optional
760781 Data type of the output array. Default is np.float128.
761782
783+ use_dask : bool, optional
784+ If True, use dask to chunk plm for memory optimization. Default is False.
785+ chunks : dict, optional
786+ Define the chunking of plm when use_dask is True. Default is None, which set the chunking to {'latitude': 1}.
787+
762788 Returns
763789 -------
764- plm : np.ndarray
765- Fully-normalized Legendre functions as a 3D array with "l", "m" and z dimensions.
790+ plm : xr.DataArray
791+ Fully-normalized Legendre functions as a DataArray with "l", "m" and "latitude" dimensions.
766792
767793 References
768794 ----------
@@ -789,6 +815,9 @@ def compute_plm(
789815 # if default mmax, set mmax to be maximal degree
790816 mmax = lmax if mmax is None else mmax
791817
818+ # if default latitude, set it from z
819+ latitude = z if latitude is None else latitude
820+
792821 f1 , f2 , norm_p10 , norm_4pi = _compute_factors (lmax , normalization )
793822
794823 # scale factor based on Holmes2002
@@ -851,8 +880,24 @@ def compute_plm(
851880 ind = np .tril_indices (lmax + 1 )
852881 plm [ind ] = p
853882
883+ plm = xr .DataArray (
884+ plm [:, : mmax + 1 , :],
885+ dims = ["l" , "m" , "latitude" ],
886+ coords = {
887+ "l" : np .arange (lmax + 1 ),
888+ "m" : np .arange (mmax + 1 ),
889+ "latitude" : latitude , #grid.cf["latitude"]
890+ },
891+ )
892+
893+ # Chunking plm for dask usage and memory optimization
894+ if use_dask :
895+ if chunks is None :
896+ chunks = {"latitude" : 1 }
897+ plm = plm .chunk (chunks )
898+
854899 # return the legendre polynomials and truncating orders to mmax
855- return plm [:, : mmax + 1 , :]
900+ return plm
856901
857902
858903def mid_month_grace_estimate (
0 commit comments