Skip to content

Commit 7ced97d

Browse files
author
T. Koskamp
committed
Update indices property from groupby
1 parent a709b22 commit 7ced97d

File tree

2 files changed

+32
-39
lines changed

2 files changed

+32
-39
lines changed

pandas/core/groupby/groupby.py

Lines changed: 14 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,7 @@ def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]:
636636
return self._grouper.indices
637637

638638
@final
639-
def _get_indices(self, names):
639+
def _get_indices(self, name):
640640
"""
641641
Safe get multiple indices, translate keys for
642642
datelike to underlying repr.
@@ -649,28 +649,27 @@ def get_converter(s):
649649
return lambda key: Timestamp(key)
650650
elif isinstance(s, np.datetime64):
651651
return lambda key: Timestamp(key).asm8
652-
elif isna(s):
653-
return lambda key: np.nan
654652
else:
655653
return lambda key: key
656654

657-
if len(names) == 0:
658-
return []
655+
if isna(name):
656+
return self.indices.get(np.nan, [])
657+
if isinstance(name, tuple):
658+
name = tuple(np.nan if isna(comp) else comp for comp in name)
659659

660660
if len(self.indices) > 0:
661661
index_sample = next(iter(self.indices))
662662
else:
663663
index_sample = None # Dummy sample
664664

665-
name_sample = names[0]
666665
if isinstance(index_sample, tuple):
667-
if not isinstance(name_sample, tuple):
666+
if not isinstance(name, tuple):
668667
msg = "must supply a tuple to get_group with multiple grouping keys"
669668
raise ValueError(msg)
670-
if not len(name_sample) == len(index_sample):
669+
if not len(name) == len(index_sample):
671670
try:
672671
# If the original grouper was a tuple
673-
return [self.indices[name] for name in names]
672+
return self.indices[name]
674673
except KeyError as err:
675674
# turns out it wasn't a tuple
676675
msg = (
@@ -679,41 +678,20 @@ def get_converter(s):
679678
)
680679
raise ValueError(msg) from err
681680

682-
has_nan = any(isna(n) for n in name_sample)
683-
684-
sample = name_sample if has_nan else index_sample
685-
converters = (get_converter(s) for s in sample)
686-
687-
names = (
688-
tuple(f(n) for f, n in zip(converters, name, strict=True))
689-
for name in names
690-
)
691-
692-
indices = self.indices
693-
if not self.dropna and has_nan:
694-
indices = {}
695-
for k, v in self.indices.items():
696-
k = tuple(np.nan if isna(e) else e for e in k)
697-
indices[k] = v
681+
converters = (get_converter(s) for s in index_sample)
682+
name = tuple(f(n) for f, n in zip(converters, name, strict=True))
698683
else:
699-
has_nan = isna(name_sample)
700-
701-
convert_sample = name_sample if has_nan else index_sample
702-
converter = get_converter(convert_sample)
703-
names = (converter(name) for name in names)
704-
705-
indices = self.indices
706-
if not self.dropna and has_nan:
707-
indices = {np.nan if isna(k) else k: v for k, v in indices.items()}
684+
converter = get_converter(index_sample)
685+
name = converter(name)
708686

709-
return [indices.get(name, []) for name in names]
687+
return self.indices.get(name, [])
710688

711689
@final
712690
def _get_index(self, name):
713691
"""
714692
Safe get index, translate keys for datelike to underlying repr.
715693
"""
716-
return self._get_indices([name])[0]
694+
return self._get_indices(name)
717695

718696
@final
719697
@cache_readonly

pandas/core/groupby/ops.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -652,9 +652,24 @@ def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]:
652652
"""dict {group name -> group indices}"""
653653
if len(self.groupings) == 1 and isinstance(self.result_index, CategoricalIndex):
654654
# This shows unused categories in indices GH#38642
655-
return self.groupings[0].indices
656-
codes_list = [ping.codes for ping in self.groupings]
657-
return get_indexer_dict(codes_list, self.levels)
655+
result = self.groupings[0].indices
656+
else:
657+
codes_list = [ping.codes for ping in self.groupings]
658+
result = get_indexer_dict(codes_list, self.levels)
659+
if not self.dropna:
660+
has_mi = isinstance(self.result_index, MultiIndex)
661+
if not has_mi and self.result_index.hasnans:
662+
result = {
663+
np.nan if isna(key) else key: value for key, value in result.items()
664+
}
665+
elif has_mi:
666+
# MultiIndex has no efficient way to tell if there are NAs
667+
result = {
668+
tuple(np.nan if isna(comp) else comp for comp in key): value
669+
for key, value in result.items()
670+
}
671+
672+
return result
658673

659674
@final
660675
@cache_readonly

0 commit comments

Comments
 (0)