1212
1313
1414def _get_array_list (arr , others ):
15- if isinstance (others [0 ], (list , np .ndarray )):
15+ if len ( others ) and isinstance (others [0 ], (list , np .ndarray )):
1616 arrays = [arr ] + list (others )
1717 else :
1818 arrays = [arr , others ]
@@ -88,12 +88,15 @@ def _length_check(others):
8888 return n
8989
9090
91- def _na_map (f , arr , na_result = np .nan ):
91+ def _na_map (f , arr , na_result = np .nan , dtype = object ):
9292 # should really _check_ for NA
93- return _map (f , arr , na_mask = True , na_value = na_result )
93+ return _map (f , arr , na_mask = True , na_value = na_result , dtype = dtype )
9494
9595
96- def _map (f , arr , na_mask = False , na_value = np .nan ):
96+ def _map (f , arr , na_mask = False , na_value = np .nan , dtype = object ):
97+ if not len (arr ):
98+ return np .ndarray (0 , dtype = dtype )
99+
97100 if isinstance (arr , Series ):
98101 arr = arr .values
99102 if not isinstance (arr , np .ndarray ):
@@ -108,7 +111,7 @@ def g(x):
108111 return f (x )
109112 except (TypeError , AttributeError ):
110113 return na_value
111- return _map (g , arr )
114+ return _map (g , arr , dtype = dtype )
112115 if na_value is not np .nan :
113116 np .putmask (result , mask , na_value )
114117 if result .dtype == object :
@@ -146,7 +149,7 @@ def str_count(arr, pat, flags=0):
146149 """
147150 regex = re .compile (pat , flags = flags )
148151 f = lambda x : len (regex .findall (x ))
149- return _na_map (f , arr )
152+ return _na_map (f , arr , dtype = int )
150153
151154
152155def str_contains (arr , pat , case = True , flags = 0 , na = np .nan , regex = True ):
@@ -187,7 +190,7 @@ def str_contains(arr, pat, case=True, flags=0, na=np.nan, regex=True):
187190 f = lambda x : bool (regex .search (x ))
188191 else :
189192 f = lambda x : pat in x
190- return _na_map (f , arr , na )
193+ return _na_map (f , arr , na , dtype = bool )
191194
192195
193196def str_startswith (arr , pat , na = np .nan ):
@@ -206,7 +209,7 @@ def str_startswith(arr, pat, na=np.nan):
206209 startswith : array (boolean)
207210 """
208211 f = lambda x : x .startswith (pat )
209- return _na_map (f , arr , na )
212+ return _na_map (f , arr , na , dtype = bool )
210213
211214
212215def str_endswith (arr , pat , na = np .nan ):
@@ -225,7 +228,7 @@ def str_endswith(arr, pat, na=np.nan):
225228 endswith : array (boolean)
226229 """
227230 f = lambda x : x .endswith (pat )
228- return _na_map (f , arr , na )
231+ return _na_map (f , arr , na , dtype = bool )
229232
230233
231234def str_lower (arr ):
@@ -375,6 +378,7 @@ def str_match(arr, pat, case=True, flags=0, na=np.nan, as_indexer=False):
375378 # and is basically useless, so we will not warn.
376379
377380 if (not as_indexer ) and regex .groups > 0 :
381+ dtype = object
378382 def f (x ):
379383 m = regex .match (x )
380384 if m :
@@ -383,9 +387,10 @@ def f(x):
383387 return []
384388 else :
385389 # This is the new behavior of str_match.
390+ dtype = bool
386391 f = lambda x : bool (regex .match (x ))
387392
388- return _na_map (f , arr , na )
393+ return _na_map (f , arr , na , dtype = dtype )
389394
390395
391396def _get_single_group_name (rx ):
@@ -409,6 +414,9 @@ def str_extract(arr, pat, flags=0):
409414 Returns
410415 -------
411416 extracted groups : Series (one group) or DataFrame (multiple groups)
417+ Note that dtype of the result is always object, even when no match is
418+ found and the result is a Series or DataFrame containing only NaN
419+ values.
412420
413421 Examples
414422 --------
@@ -461,13 +469,17 @@ def f(x):
461469 if regex .groups == 1 :
462470 result = Series ([f (val )[0 ] for val in arr ],
463471 name = _get_single_group_name (regex ),
464- index = arr .index )
472+ index = arr .index , dtype = object )
465473 else :
466474 names = dict (zip (regex .groupindex .values (), regex .groupindex .keys ()))
467475 columns = [names .get (1 + i , i ) for i in range (regex .groups )]
468- result = DataFrame ([f (val ) for val in arr ],
469- columns = columns ,
470- index = arr .index )
476+ if arr .empty :
477+ result = DataFrame (columns = columns , dtype = object )
478+ else :
479+ result = DataFrame ([f (val ) for val in arr ],
480+ columns = columns ,
481+ index = arr .index ,
482+ dtype = object )
471483 return result
472484
473485
@@ -536,7 +548,7 @@ def str_len(arr):
536548 -------
537549 lengths : array
538550 """
539- return _na_map (len , arr )
551+ return _na_map (len , arr , dtype = int )
540552
541553
542554def str_findall (arr , pat , flags = 0 ):
0 commit comments