|
8 | 8 | from textwrap import dedent |
9 | 9 | from typing import ( |
10 | 10 | TYPE_CHECKING, |
11 | | - Hashable, |
12 | 11 | Literal, |
13 | | - Sequence, |
14 | 12 | cast, |
15 | | - final, |
16 | 13 | ) |
17 | 14 | import warnings |
18 | 15 |
|
|
29 | 26 | ArrayLike, |
30 | 27 | AxisInt, |
31 | 28 | DtypeObj, |
32 | | - IndexLabel, |
33 | 29 | TakeIndexer, |
34 | 30 | npt, |
35 | 31 | ) |
|
97 | 93 |
|
98 | 94 | from pandas import ( |
99 | 95 | Categorical, |
100 | | - DataFrame, |
101 | 96 | Index, |
102 | 97 | Series, |
103 | 98 | ) |
@@ -1167,227 +1162,6 @@ def checked_add_with_arr( |
1167 | 1162 | return result |
1168 | 1163 |
|
1169 | 1164 |
|
1170 | | -# --------------- # |
1171 | | -# select n # |
1172 | | -# --------------- # |
1173 | | - |
1174 | | - |
1175 | | -class SelectN: |
1176 | | - def __init__(self, obj, n: int, keep: str) -> None: |
1177 | | - self.obj = obj |
1178 | | - self.n = n |
1179 | | - self.keep = keep |
1180 | | - |
1181 | | - if self.keep not in ("first", "last", "all"): |
1182 | | - raise ValueError('keep must be either "first", "last" or "all"') |
1183 | | - |
1184 | | - def compute(self, method: str) -> DataFrame | Series: |
1185 | | - raise NotImplementedError |
1186 | | - |
1187 | | - @final |
1188 | | - def nlargest(self): |
1189 | | - return self.compute("nlargest") |
1190 | | - |
1191 | | - @final |
1192 | | - def nsmallest(self): |
1193 | | - return self.compute("nsmallest") |
1194 | | - |
1195 | | - @final |
1196 | | - @staticmethod |
1197 | | - def is_valid_dtype_n_method(dtype: DtypeObj) -> bool: |
1198 | | - """ |
1199 | | - Helper function to determine if dtype is valid for |
1200 | | - nsmallest/nlargest methods |
1201 | | - """ |
1202 | | - return ( |
1203 | | - not is_complex_dtype(dtype) |
1204 | | - if is_numeric_dtype(dtype) |
1205 | | - else needs_i8_conversion(dtype) |
1206 | | - ) |
1207 | | - |
1208 | | - |
1209 | | -class SelectNSeries(SelectN): |
1210 | | - """ |
1211 | | - Implement n largest/smallest for Series |
1212 | | -
|
1213 | | - Parameters |
1214 | | - ---------- |
1215 | | - obj : Series |
1216 | | - n : int |
1217 | | - keep : {'first', 'last'}, default 'first' |
1218 | | -
|
1219 | | - Returns |
1220 | | - ------- |
1221 | | - nordered : Series |
1222 | | - """ |
1223 | | - |
1224 | | - def compute(self, method: str) -> Series: |
1225 | | - from pandas.core.reshape.concat import concat |
1226 | | - |
1227 | | - n = self.n |
1228 | | - dtype = self.obj.dtype |
1229 | | - if not self.is_valid_dtype_n_method(dtype): |
1230 | | - raise TypeError(f"Cannot use method '{method}' with dtype {dtype}") |
1231 | | - |
1232 | | - if n <= 0: |
1233 | | - return self.obj[[]] |
1234 | | - |
1235 | | - dropped = self.obj.dropna() |
1236 | | - nan_index = self.obj.drop(dropped.index) |
1237 | | - |
1238 | | - # slow method |
1239 | | - if n >= len(self.obj): |
1240 | | - ascending = method == "nsmallest" |
1241 | | - return self.obj.sort_values(ascending=ascending).head(n) |
1242 | | - |
1243 | | - # fast method |
1244 | | - new_dtype = dropped.dtype |
1245 | | - arr = _ensure_data(dropped.values) |
1246 | | - if method == "nlargest": |
1247 | | - arr = -arr |
1248 | | - if is_integer_dtype(new_dtype): |
1249 | | - # GH 21426: ensure reverse ordering at boundaries |
1250 | | - arr -= 1 |
1251 | | - |
1252 | | - elif is_bool_dtype(new_dtype): |
1253 | | - # GH 26154: ensure False is smaller than True |
1254 | | - arr = 1 - (-arr) |
1255 | | - |
1256 | | - if self.keep == "last": |
1257 | | - arr = arr[::-1] |
1258 | | - |
1259 | | - nbase = n |
1260 | | - narr = len(arr) |
1261 | | - n = min(n, narr) |
1262 | | - |
1263 | | - # arr passed into kth_smallest must be contiguous. We copy |
1264 | | - # here because kth_smallest will modify its input |
1265 | | - kth_val = algos.kth_smallest(arr.copy(order="C"), n - 1) |
1266 | | - (ns,) = np.nonzero(arr <= kth_val) |
1267 | | - inds = ns[arr[ns].argsort(kind="mergesort")] |
1268 | | - |
1269 | | - if self.keep != "all": |
1270 | | - inds = inds[:n] |
1271 | | - findex = nbase |
1272 | | - else: |
1273 | | - if len(inds) < nbase <= len(nan_index) + len(inds): |
1274 | | - findex = len(nan_index) + len(inds) |
1275 | | - else: |
1276 | | - findex = len(inds) |
1277 | | - |
1278 | | - if self.keep == "last": |
1279 | | - # reverse indices |
1280 | | - inds = narr - 1 - inds |
1281 | | - |
1282 | | - return concat([dropped.iloc[inds], nan_index]).iloc[:findex] |
1283 | | - |
1284 | | - |
1285 | | -class SelectNFrame(SelectN): |
1286 | | - """ |
1287 | | - Implement n largest/smallest for DataFrame |
1288 | | -
|
1289 | | - Parameters |
1290 | | - ---------- |
1291 | | - obj : DataFrame |
1292 | | - n : int |
1293 | | - keep : {'first', 'last'}, default 'first' |
1294 | | - columns : list or str |
1295 | | -
|
1296 | | - Returns |
1297 | | - ------- |
1298 | | - nordered : DataFrame |
1299 | | - """ |
1300 | | - |
1301 | | - def __init__(self, obj: DataFrame, n: int, keep: str, columns: IndexLabel) -> None: |
1302 | | - super().__init__(obj, n, keep) |
1303 | | - if not is_list_like(columns) or isinstance(columns, tuple): |
1304 | | - columns = [columns] |
1305 | | - |
1306 | | - columns = cast(Sequence[Hashable], columns) |
1307 | | - columns = list(columns) |
1308 | | - self.columns = columns |
1309 | | - |
1310 | | - def compute(self, method: str) -> DataFrame: |
1311 | | - from pandas.core.api import Index |
1312 | | - |
1313 | | - n = self.n |
1314 | | - frame = self.obj |
1315 | | - columns = self.columns |
1316 | | - |
1317 | | - for column in columns: |
1318 | | - dtype = frame[column].dtype |
1319 | | - if not self.is_valid_dtype_n_method(dtype): |
1320 | | - raise TypeError( |
1321 | | - f"Column {repr(column)} has dtype {dtype}, " |
1322 | | - f"cannot use method {repr(method)} with this dtype" |
1323 | | - ) |
1324 | | - |
1325 | | - def get_indexer(current_indexer, other_indexer): |
1326 | | - """ |
1327 | | - Helper function to concat `current_indexer` and `other_indexer` |
1328 | | - depending on `method` |
1329 | | - """ |
1330 | | - if method == "nsmallest": |
1331 | | - return current_indexer.append(other_indexer) |
1332 | | - else: |
1333 | | - return other_indexer.append(current_indexer) |
1334 | | - |
1335 | | - # Below we save and reset the index in case index contains duplicates |
1336 | | - original_index = frame.index |
1337 | | - cur_frame = frame = frame.reset_index(drop=True) |
1338 | | - cur_n = n |
1339 | | - indexer = Index([], dtype=np.int64) |
1340 | | - |
1341 | | - for i, column in enumerate(columns): |
1342 | | - # For each column we apply method to cur_frame[column]. |
1343 | | - # If it's the last column or if we have the number of |
1344 | | - # results desired we are done. |
1345 | | - # Otherwise there are duplicates of the largest/smallest |
1346 | | - # value and we need to look at the rest of the columns |
1347 | | - # to determine which of the rows with the largest/smallest |
1348 | | - # value in the column to keep. |
1349 | | - series = cur_frame[column] |
1350 | | - is_last_column = len(columns) - 1 == i |
1351 | | - values = getattr(series, method)( |
1352 | | - cur_n, keep=self.keep if is_last_column else "all" |
1353 | | - ) |
1354 | | - |
1355 | | - if is_last_column or len(values) <= cur_n: |
1356 | | - indexer = get_indexer(indexer, values.index) |
1357 | | - break |
1358 | | - |
1359 | | - # Now find all values which are equal to |
1360 | | - # the (nsmallest: largest)/(nlargest: smallest) |
1361 | | - # from our series. |
1362 | | - border_value = values == values[values.index[-1]] |
1363 | | - |
1364 | | - # Some of these values are among the top-n |
1365 | | - # some aren't. |
1366 | | - unsafe_values = values[border_value] |
1367 | | - |
1368 | | - # These values are definitely among the top-n |
1369 | | - safe_values = values[~border_value] |
1370 | | - indexer = get_indexer(indexer, safe_values.index) |
1371 | | - |
1372 | | - # Go on and separate the unsafe_values on the remaining |
1373 | | - # columns. |
1374 | | - cur_frame = cur_frame.loc[unsafe_values.index] |
1375 | | - cur_n = n - len(indexer) |
1376 | | - |
1377 | | - frame = frame.take(indexer) |
1378 | | - |
1379 | | - # Restore the index on frame |
1380 | | - frame.index = original_index.take(indexer) |
1381 | | - |
1382 | | - # If there is only one column, the frame is already sorted. |
1383 | | - if len(columns) == 1: |
1384 | | - return frame |
1385 | | - |
1386 | | - ascending = method == "nsmallest" |
1387 | | - |
1388 | | - return frame.sort_values(columns, ascending=ascending, kind="mergesort") |
1389 | | - |
1390 | | - |
1391 | 1165 | # ---- # |
1392 | 1166 | # take # |
1393 | 1167 | # ---- # |
|
0 commit comments