@@ -115,60 +115,6 @@ def generic_aggregate(
115115 return result
116116
117117
118- def _normalize_dtype (dtype : DTypeLike , array_dtype : np .dtype , fill_value = None ) -> np .dtype :
119- if dtype is None :
120- dtype = array_dtype
121- if dtype is np .floating :
122- # mean, std, var always result in floating
123- # but we preserve the array's dtype if it is floating
124- if array_dtype .kind in "fcmM" :
125- dtype = array_dtype
126- else :
127- dtype = np .dtype ("float64" )
128- elif not isinstance (dtype , np .dtype ):
129- dtype = np .dtype (dtype )
130- if fill_value not in [None , dtypes .INF , dtypes .NINF , dtypes .NA ]:
131- dtype = np .result_type (dtype , fill_value )
132- return dtype
133-
134-
135- def _maybe_promote_int (dtype ) -> np .dtype :
136- # https://numpy.org/doc/stable/reference/generated/numpy.prod.html
137- # The dtype of a is used by default unless a has an integer dtype of less precision
138- # than the default platform integer.
139- if not isinstance (dtype , np .dtype ):
140- dtype = np .dtype (dtype )
141- if dtype .kind == "i" :
142- dtype = np .result_type (dtype , np .intp )
143- elif dtype .kind == "u" :
144- dtype = np .result_type (dtype , np .uintp )
145- return dtype
146-
147-
148- def _get_fill_value (dtype , fill_value ):
149- """Returns dtype appropriate infinity. Returns +Inf equivalent for None."""
150- if fill_value in [None , dtypes .NA ] and dtype .kind in "US" :
151- return ""
152- if fill_value == dtypes .INF or fill_value is None :
153- return dtypes .get_pos_infinity (dtype , max_for_int = True )
154- if fill_value == dtypes .NINF :
155- return dtypes .get_neg_infinity (dtype , min_for_int = True )
156- if fill_value == dtypes .NA :
157- if np .issubdtype (dtype , np .floating ) or np .issubdtype (dtype , np .complexfloating ):
158- return np .nan
159- # This is madness, but npg checks that fill_value is compatible
160- # with array dtype even if the fill_value is never used.
161- elif np .issubdtype (dtype , np .integer ):
162- return dtypes .get_neg_infinity (dtype , min_for_int = True )
163- elif np .issubdtype (dtype , np .timedelta64 ):
164- return np .timedelta64 ("NaT" )
165- elif np .issubdtype (dtype , np .datetime64 ):
166- return np .datetime64 ("NaT" )
167- else :
168- return None
169- return fill_value
170-
171-
172118def _atleast_1d (inp , min_length : int = 1 ):
173119 if xrutils .is_scalar (inp ):
174120 inp = (inp ,) * min_length
@@ -646,7 +592,7 @@ def last(self) -> AlignedArrays:
646592 # TODO: automate?
647593 engine = "flox" ,
648594 dtype = self .array .dtype ,
649- fill_value = _get_fill_value (self .array .dtype , dtypes .NA ),
595+ fill_value = dtypes . _get_fill_value (self .array .dtype , dtypes .NA ),
650596 expected_groups = None ,
651597 )
652598 return AlignedArrays (array = reduced ["intermediates" ][0 ], group_idx = reduced ["groups" ])
@@ -829,7 +775,9 @@ def _initialize_aggregation(
829775 np .dtype (dtype ) if dtype is not None and not isinstance (dtype , np .dtype ) else dtype
830776 )
831777
832- final_dtype = _normalize_dtype (dtype_ or agg .dtype_init ["final" ], array_dtype , fill_value )
778+ final_dtype = dtypes ._normalize_dtype (
779+ dtype_ or agg .dtype_init ["final" ], array_dtype , fill_value
780+ )
833781 if agg .name not in [
834782 "first" ,
835783 "last" ,
@@ -841,14 +789,14 @@ def _initialize_aggregation(
841789 "nanmax" ,
842790 "topk" ,
843791 ]:
844- final_dtype = _maybe_promote_int (final_dtype )
792+ final_dtype = dtypes . _maybe_promote_int (final_dtype )
845793 agg .dtype = {
846794 "user" : dtype , # Save to automatically choose an engine
847795 "final" : final_dtype ,
848796 "numpy" : (final_dtype ,),
849797 "intermediate" : tuple (
850798 (
851- _normalize_dtype (int_dtype , np .result_type (array_dtype , final_dtype ), int_fv )
799+ dtypes . _normalize_dtype (int_dtype , np .result_type (array_dtype , final_dtype ), int_fv )
852800 if int_dtype is None
853801 else np .dtype (int_dtype )
854802 )
@@ -863,10 +811,10 @@ def _initialize_aggregation(
863811 # Replace sentinel fill values according to dtype
864812 agg .fill_value ["user" ] = fill_value
865813 agg .fill_value ["intermediate" ] = tuple (
866- _get_fill_value (dt , fv )
814+ dtypes . _get_fill_value (dt , fv )
867815 for dt , fv in zip (agg .dtype ["intermediate" ], agg .fill_value ["intermediate" ])
868816 )
869- agg .fill_value [func ] = _get_fill_value (agg .dtype ["final" ], agg .fill_value [func ])
817+ agg .fill_value [func ] = dtypes . _get_fill_value (agg .dtype ["final" ], agg .fill_value [func ])
870818
871819 fv = fill_value if fill_value is not None else agg .fill_value [agg .name ]
872820 if _is_arg_reduction (agg ):
0 commit comments