Skip to content

Commit 4196dd4

Browse files
feat: Add Index.__eq__ for consts, aligned objects
1 parent 56e5033 commit 4196dd4

File tree

4 files changed

+88
-0
lines changed

4 files changed

+88
-0
lines changed

bigframes/core/indexes/base.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,58 @@ def item(self):
754754
# Docstring is in third_party/bigframes_vendored/pandas/core/indexes/base.py
755755
return self.to_series().peek(2).item()
756756

757+
def __eq__(self, other) -> Index:
    """Return the result of applying ``ops.eq_op`` between this index and ``other``.

    All operand handling is delegated to ``_apply_binop``; whatever that
    helper returns for ``other`` is returned unchanged.
    """
    equality_op = ops.eq_op
    return self._apply_binop(other, equality_op)
759+
760+
def _apply_binop(self, other, op: ops.BinaryOp) -> Index:
    """Apply the binary operator ``op`` between this index and ``other``.

    Supported operands:

    * ``bigframes.series.Series`` or ``Index``: converted to an ``Index``,
      row-joined with this index, and ``op`` is applied to the paired
      columns of the two expressions.
    * A local scalar (``bigframes.dtypes.LOCAL_SCALAR_TYPES``) when this
      index has exactly one level: ``op`` is applied between the single
      index column and the scalar as a constant.
    * A tuple with one entry per level: ``op`` is applied between each
      index column and the matching tuple entry as a constant.

    Args:
        other: Right-hand operand.
        op: Binary operator used to build the result expression(s).

    Returns:
        An ``Index`` built from the operator results, or ``NotImplemented``
        when ``other`` is none of the supported operand kinds.

    Raises:
        ValueError: If ``other`` is a Series/Index with a different number
            of levels, or if ``try_row_join`` cannot align the two
            expressions.
    """
    # TODO: Handle local objects, or objects not implicitly alignable? Gets ambiguous with partial ordering though
    if isinstance(other, (bigframes.series.Series, Index)):
        other = Index(other)
        if other.nlevels != self.nlevels:
            raise ValueError("Dimensions do not match")

        lexpr = self._block.expr
        rexpr = other._block.expr
        join_result = lexpr.try_row_join(rexpr)
        if join_result is None:
            raise ValueError("Cannot align objects")

        # lmap / rmap translate each side's original column ids to the ids
        # used in the joined expression.
        expr, (lmap, rmap) = join_result

        expr, res_ids = expr.compute_values(
            [
                op.as_expr(lmap[lid], rmap[rid])
                for lid, rid in zip(lexpr.column_ids, rexpr.column_ids)
            ]
        )
        # Only the freshly computed columns are kept, and they become the
        # (unnamed) levels of the resulting index.
        return Index(
            blocks.Block(
                expr.select_columns(res_ids),
                index_columns=res_ids,
                column_labels=[],
                index_labels=[None] * len(res_ids),
            )
        )

    if isinstance(other, bigframes.dtypes.LOCAL_SCALAR_TYPES) and self.nlevels == 1:
        # ``result_id`` rather than ``id`` so the builtin is not shadowed.
        block, result_id = self._block.project_expr(
            op.as_expr(self._block.index_columns[0], ex.const(other))
        )
        return Index(block.select_column(result_id))

    if isinstance(other, tuple) and len(other) == self.nlevels:
        # Pair each index level with the tuple entry at the same position.
        block = self._block.project_exprs(
            [
                op.as_expr(level_id, ex.const(value))
                for level_id, value in zip(self._block.index_columns, other)
            ],
            labels=[None] * self.nlevels,
            drop=True,
        )
        return Index(block)

    # Unsupported operand: let Python try the reflected operation.
    return NotImplemented
808+
757809

758810
def _should_create_datetime_index(block: blocks.Block) -> bool:
759811
if len(block.index.dtypes) != 1:

bigframes/core/indexes/multi.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,14 @@ def from_arrays(
4646
pd_index = pandas.MultiIndex.from_arrays(arrays, sortorder, names)
4747
# Index.__new__ should detect multiple levels and properly create a multiindex
4848
return cast(MultiIndex, Index(pd_index))
49+
50+
def __eq__(self, other) -> Index:
    """Compare this MultiIndex to ``other`` and combine the per-level results.

    ``_apply_binop`` is called with ``ops.eq_op``; its result block is then
    reduced with ``agg_ops.all_op`` along ``axis=1`` (``dropna=False``) and
    wrapped in an ``Index``.  Presumably this yields one boolean per row
    that is true only when every level matched -- confirm against
    ``Block.aggregate_all_and_stack``.

    The method was previously spelled ``__eg__``, which Python never calls
    for ``==``, so this override was dead code.

    Returns:
        The combined comparison as an ``Index``, or ``NotImplemented`` when
        ``_apply_binop`` does not support ``other``.
    """
    # Function-scope imports are kept exactly as the original had them.
    import bigframes.operations as ops
    import bigframes.operations.aggregations as agg_ops

    eq_result = self._apply_binop(other, ops.eq_op)
    if eq_result is NotImplemented:
        # Propagate so Python can fall back to the reflected comparison
        # instead of failing on ``NotImplemented._block``.
        return NotImplemented
    return Index(
        eq_result._block.aggregate_all_and_stack(
            agg_ops.all_op, axis=1, dropna=False
        )
    )


# Backward-compatible alias for the previously misspelled method name.
__eg__ = __eq__

tests/system/small/test_index.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,3 +668,20 @@ def test_custom_index_setitem_error():
668668

669669
with pytest.raises(TypeError, match="Index does not support mutable operations"):
670670
custom_index[2] = 999
671+
672+
673+
def test_index_eq_const(scalars_df_index, scalars_pandas_df_index):
    """``index == scalar`` should give the same per-row values as pandas."""
    bf_result = (scalars_df_index.index == 3).to_pandas()
    pd_result = scalars_pandas_df_index.index == 3
    # ``==`` between two pandas indexes is element-wise, so asserting on it
    # directly does not check equality (and is ambiguous for more than one
    # row).  Compare the materialised values instead.
    assert bf_result.tolist() == pd.Index(pd_result).tolist()
677+
678+
679+
def test_index_eq_aligned_index(scalars_df_index, scalars_pandas_df_index):
    """``index == index`` on aligned objects should match the pandas values."""
    bf_result = (
        bpd.Index(scalars_df_index.int64_col)
        == bpd.Index(scalars_df_index.int64_col.abs())
    ).to_pandas()
    pd_result = pd.Index(scalars_pandas_df_index.int64_col) == pd.Index(
        scalars_pandas_df_index.int64_col.abs()
    )
    # ``==`` between two pandas indexes is element-wise, so asserting on it
    # directly does not check equality (and is ambiguous for more than one
    # row).  Compare the materialised values instead.
    assert bf_result.tolist() == pd.Index(pd_result).tolist()

tests/system/small/test_multiindex.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,3 +1474,11 @@ def test_multi_index_contains(scalars_df_index, scalars_pandas_df_index, key):
14741474
pd_result = key in scalars_pandas_df_index.set_index(col_name).index
14751475

14761476
assert bf_result == pd_result
1477+
1478+
1479+
def test_multiindex_eq_const(scalars_df_index, scalars_pandas_df_index):
    """``multiindex == tuple`` should give the same per-row values as pandas."""
    col_name = ["int64_col", "bool_col"]
    bf_result = scalars_df_index.set_index(col_name).index == (2, False)
    pd_result = scalars_pandas_df_index.set_index(col_name).index == (2, False)

    # Materialise the bigframes result and compare values; ``==`` between the
    # two results would itself be an element-wise comparison, not a check.
    assert bf_result.to_pandas().tolist() == pd.Index(pd_result).tolist()

0 commit comments

Comments
 (0)