@@ -65,35 +65,71 @@ def data_for_grouping():
6565 return DecimalArray ([b , b , na , na , a , a , b , c ])
6666
6767
68- class TestDtype (base .BaseDtypeTests ):
69- pass
68+ class TestDecimalArray (base .ExtensionTests ):
69+ def _get_expected_exception (
70+ self , op_name : str , obj , other
71+ ) -> type [Exception ] | None :
72+ return None
7073
74+ def _supports_reduction (self , obj , op_name : str ) -> bool :
75+ return True
7176
72- class TestInterface (base .BaseInterfaceTests ):
73- pass
77+ def check_reduce (self , s , op_name , skipna ):
78+ if op_name == "count" :
79+ return super ().check_reduce (s , op_name , skipna )
80+ else :
81+ result = getattr (s , op_name )(skipna = skipna )
82+ expected = getattr (np .asarray (s ), op_name )()
83+ tm .assert_almost_equal (result , expected )
7484
85+ def test_reduce_series_numeric (self , data , all_numeric_reductions , skipna , request ):
86+ if all_numeric_reductions in ["kurt" , "skew" , "sem" , "median" ]:
87+ mark = pytest .mark .xfail (raises = NotImplementedError )
88+ request .node .add_marker (mark )
89+ super ().test_reduce_series_numeric (data , all_numeric_reductions , skipna )
7590
76- class TestConstructors (base .BaseConstructorsTests ):
77- pass
91+ def test_reduce_frame (self , data , all_numeric_reductions , skipna , request ):
92+ op_name = all_numeric_reductions
93+ if op_name in ["skew" , "median" ]:
94+ mark = pytest .mark .xfail (raises = NotImplementedError )
95+ request .node .add_marker (mark )
7896
97+ return super ().test_reduce_frame (data , all_numeric_reductions , skipna )
7998
80- class TestReshaping (base .BaseReshapingTests ):
81- pass
99+ def test_compare_scalar (self , data , comparison_op ):
100+ ser = pd .Series (data )
101+ self ._compare_other (ser , data , comparison_op , 0.5 )
82102
103+ def test_compare_array (self , data , comparison_op ):
104+ ser = pd .Series (data )
83105
84- class TestGetitem (base .BaseGetitemTests ):
85- def test_take_na_value_other_decimal (self ):
86- arr = DecimalArray ([decimal .Decimal ("1.0" ), decimal .Decimal ("2.0" )])
87- result = arr .take ([0 , - 1 ], allow_fill = True , fill_value = decimal .Decimal ("-1.0" ))
88- expected = DecimalArray ([decimal .Decimal ("1.0" ), decimal .Decimal ("-1.0" )])
89- tm .assert_extension_array_equal (result , expected )
106+ alter = np .random .default_rng (2 ).choice ([- 1 , 0 , 1 ], len (data ))
107+ # Randomly double, halve or keep same value
108+ other = pd .Series (data ) * [decimal .Decimal (pow (2.0 , i )) for i in alter ]
109+ self ._compare_other (ser , data , comparison_op , other )
90110
111+ def test_arith_series_with_array (self , data , all_arithmetic_operators ):
112+ op_name = all_arithmetic_operators
113+ ser = pd .Series (data )
114+
115+ context = decimal .getcontext ()
116+ divbyzerotrap = context .traps [decimal .DivisionByZero ]
117+ invalidoptrap = context .traps [decimal .InvalidOperation ]
118+ context .traps [decimal .DivisionByZero ] = 0
119+ context .traps [decimal .InvalidOperation ] = 0
91120
92- class TestIndex (base .BaseIndexTests ):
93- pass
121+ # Decimal supports ops with int, but not float
122+ other = pd .Series ([int (d * 100 ) for d in data ])
123+ self .check_opname (ser , op_name , other )
124+
125+ if "mod" not in op_name :
126+ self .check_opname (ser , op_name , ser * 2 )
94127
128+ self .check_opname (ser , op_name , 0 )
129+ self .check_opname (ser , op_name , 5 )
130+ context .traps [decimal .DivisionByZero ] = divbyzerotrap
131+ context .traps [decimal .InvalidOperation ] = invalidoptrap
95132
96- class TestMissing (base .BaseMissingTests ):
97133 def test_fillna_frame (self , data_missing ):
98134 msg = "ExtensionArray.fillna added a 'copy' keyword"
99135 with tm .assert_produces_warning (
@@ -141,59 +177,6 @@ def test_fillna_series_method(self, data_missing, fillna_method):
141177 ):
142178 super ().test_fillna_series_method (data_missing , fillna_method )
143179
144-
145- class Reduce :
146- def _supports_reduction (self , obj , op_name : str ) -> bool :
147- return True
148-
149- def check_reduce (self , s , op_name , skipna ):
150- if op_name == "count" :
151- return super ().check_reduce (s , op_name , skipna )
152- else :
153- result = getattr (s , op_name )(skipna = skipna )
154- expected = getattr (np .asarray (s ), op_name )()
155- tm .assert_almost_equal (result , expected )
156-
157- def test_reduction_without_keepdims (self ):
158- # GH52788
159- # test _reduce without keepdims
160-
161- class DecimalArray2 (DecimalArray ):
162- def _reduce (self , name : str , * , skipna : bool = True , ** kwargs ):
163- # no keepdims in signature
164- return super ()._reduce (name , skipna = skipna )
165-
166- arr = DecimalArray2 ([decimal .Decimal (2 ) for _ in range (100 )])
167-
168- ser = pd .Series (arr )
169- result = ser .agg ("sum" )
170- expected = decimal .Decimal (200 )
171- assert result == expected
172-
173- df = pd .DataFrame ({"a" : arr , "b" : arr })
174- with tm .assert_produces_warning (FutureWarning ):
175- result = df .agg ("sum" )
176- expected = pd .Series ({"a" : 200 , "b" : 200 }, dtype = object )
177- tm .assert_series_equal (result , expected )
178-
179-
180- class TestReduce (Reduce , base .BaseReduceTests ):
181- def test_reduce_series_numeric (self , data , all_numeric_reductions , skipna , request ):
182- if all_numeric_reductions in ["kurt" , "skew" , "sem" , "median" ]:
183- mark = pytest .mark .xfail (raises = NotImplementedError )
184- request .node .add_marker (mark )
185- super ().test_reduce_series_numeric (data , all_numeric_reductions , skipna )
186-
187- def test_reduce_frame (self , data , all_numeric_reductions , skipna , request ):
188- op_name = all_numeric_reductions
189- if op_name in ["skew" , "median" ]:
190- mark = pytest .mark .xfail (raises = NotImplementedError )
191- request .node .add_marker (mark )
192-
193- return super ().test_reduce_frame (data , all_numeric_reductions , skipna )
194-
195-
196- class TestMethods (base .BaseMethodsTests ):
197180 def test_fillna_copy_frame (self , data_missing , using_copy_on_write ):
198181 warn = FutureWarning if not using_copy_on_write else None
199182 msg = "ExtensionArray.fillna added a 'copy' keyword"
@@ -226,27 +209,31 @@ def test_value_counts(self, all_data, dropna, request):
226209
227210 tm .assert_series_equal (result , expected )
228211
229-
230- class TestCasting (base .BaseCastingTests ):
231- pass
232-
233-
234- class TestGroupby (base .BaseGroupbyTests ):
235- pass
236-
237-
238- class TestSetitem (base .BaseSetitemTests ):
239- pass
240-
241-
242- class TestPrinting (base .BasePrintingTests ):
243212 def test_series_repr (self , data ):
244213 # Overriding this base test to explicitly test that
245214 # the custom _formatter is used
246215 ser = pd .Series (data )
247216 assert data .dtype .name in repr (ser )
248217 assert "Decimal: " in repr (ser )
249218
219+ @pytest .mark .xfail (
220+ reason = "Looks like the test (incorrectly) implicitly assumes int/bool dtype"
221+ )
222+ def test_invert (self , data ):
223+ super ().test_invert (data )
224+
225+ @pytest .mark .xfail (reason = "Inconsistent array-vs-scalar behavior" )
226+ @pytest .mark .parametrize ("ufunc" , [np .positive , np .negative , np .abs ])
227+ def test_unary_ufunc_dunder_equivalence (self , data , ufunc ):
228+ super ().test_unary_ufunc_dunder_equivalence (data , ufunc )
229+
230+
231+ def test_take_na_value_other_decimal ():
232+ arr = DecimalArray ([decimal .Decimal ("1.0" ), decimal .Decimal ("2.0" )])
233+ result = arr .take ([0 , - 1 ], allow_fill = True , fill_value = decimal .Decimal ("-1.0" ))
234+ expected = DecimalArray ([decimal .Decimal ("1.0" ), decimal .Decimal ("-1.0" )])
235+ tm .assert_extension_array_equal (result , expected )
236+
250237
251238def test_series_constructor_coerce_data_to_extension_dtype ():
252239 dtype = DecimalDtype ()
@@ -305,53 +292,6 @@ def test_astype_dispatches(frame):
305292 assert result .dtype .context .prec == ctx .prec
306293
307294
308- class TestArithmeticOps (base .BaseArithmeticOpsTests ):
309- series_scalar_exc = None
310- frame_scalar_exc = None
311- series_array_exc = None
312-
313- def _get_expected_exception (
314- self , op_name : str , obj , other
315- ) -> type [Exception ] | None :
316- return None
317-
318- def test_arith_series_with_array (self , data , all_arithmetic_operators ):
319- op_name = all_arithmetic_operators
320- s = pd .Series (data )
321-
322- context = decimal .getcontext ()
323- divbyzerotrap = context .traps [decimal .DivisionByZero ]
324- invalidoptrap = context .traps [decimal .InvalidOperation ]
325- context .traps [decimal .DivisionByZero ] = 0
326- context .traps [decimal .InvalidOperation ] = 0
327-
328- # Decimal supports ops with int, but not float
329- other = pd .Series ([int (d * 100 ) for d in data ])
330- self .check_opname (s , op_name , other )
331-
332- if "mod" not in op_name :
333- self .check_opname (s , op_name , s * 2 )
334-
335- self .check_opname (s , op_name , 0 )
336- self .check_opname (s , op_name , 5 )
337- context .traps [decimal .DivisionByZero ] = divbyzerotrap
338- context .traps [decimal .InvalidOperation ] = invalidoptrap
339-
340-
341- class TestComparisonOps (base .BaseComparisonOpsTests ):
342- def test_compare_scalar (self , data , comparison_op ):
343- s = pd .Series (data )
344- self ._compare_other (s , data , comparison_op , 0.5 )
345-
346- def test_compare_array (self , data , comparison_op ):
347- s = pd .Series (data )
348-
349- alter = np .random .default_rng (2 ).choice ([- 1 , 0 , 1 ], len (data ))
350- # Randomly double, halve or keep same value
351- other = pd .Series (data ) * [decimal .Decimal (pow (2.0 , i )) for i in alter ]
352- self ._compare_other (s , data , comparison_op , other )
353-
354-
355295class DecimalArrayWithoutFromSequence (DecimalArray ):
356296 """Helper class for testing error handling in _from_sequence."""
357297
0 commit comments