22from functools import wraps
33import re
44import textwrap
5- from typing import TYPE_CHECKING , Any , Callable , Dict , List
5+ from typing import TYPE_CHECKING , Any , Callable , Dict , List , Type , Union
66import warnings
77
88import numpy as np
@@ -142,7 +142,7 @@ def _map_stringarray(
142142 The value to use for missing values. By default, this is
143143 the original value (NA).
144144 dtype : Dtype
145- The result dtype to use. Specifying this aviods an intermediate
145+ The result dtype to use. Specifying this avoids an intermediate
146146 object-dtype allocation.
147147
148148 Returns
@@ -152,14 +152,20 @@ def _map_stringarray(
152152 an ndarray.
153153
154154 """
155- from pandas .arrays import IntegerArray , StringArray
155+ from pandas .arrays import IntegerArray , StringArray , BooleanArray
156156
157157 mask = isna (arr )
158158
159159 assert isinstance (arr , StringArray )
160160 arr = np .asarray (arr )
161161
162- if is_integer_dtype (dtype ):
162+ if is_integer_dtype (dtype ) or is_bool_dtype (dtype ):
163+ constructor : Union [Type [IntegerArray ], Type [BooleanArray ]]
164+ if is_integer_dtype (dtype ):
165+ constructor = IntegerArray
166+ else :
167+ constructor = BooleanArray
168+
163169 na_value_is_na = isna (na_value )
164170 if na_value_is_na :
165171 na_value = 1
@@ -169,21 +175,20 @@ def _map_stringarray(
169175 mask .view ("uint8" ),
170176 convert = False ,
171177 na_value = na_value ,
172- dtype = np .dtype ("int64" ),
178+ dtype = np .dtype (dtype ),
173179 )
174180
175181 if not na_value_is_na :
176182 mask [:] = False
177183
178- return IntegerArray (result , mask )
184+ return constructor (result , mask )
179185
180186 elif is_string_dtype (dtype ) and not is_object_dtype (dtype ):
181187 # i.e. StringDtype
182188 result = lib .map_infer_mask (
183189 arr , func , mask .view ("uint8" ), convert = False , na_value = na_value
184190 )
185191 return StringArray (result )
186- # TODO: BooleanArray
187192 else :
188193 # This is when the result type is object. We reach this when
189194 # -> We know the result type is truly object (e.g. .encode returns bytes
@@ -299,7 +304,7 @@ def str_count(arr, pat, flags=0):
299304 """
300305 regex = re .compile (pat , flags = flags )
301306 f = lambda x : len (regex .findall (x ))
302- return _na_map (f , arr , dtype = int )
307+ return _na_map (f , arr , dtype = "int64" )
303308
304309
305310def str_contains (arr , pat , case = True , flags = 0 , na = np .nan , regex = True ):
@@ -1365,7 +1370,7 @@ def str_find(arr, sub, start=0, end=None, side="left"):
13651370 else :
13661371 f = lambda x : getattr (x , method )(sub , start , end )
13671372
1368- return _na_map (f , arr , dtype = int )
1373+ return _na_map (f , arr , dtype = "int64" )
13691374
13701375
13711376def str_index (arr , sub , start = 0 , end = None , side = "left" ):
@@ -1385,7 +1390,7 @@ def str_index(arr, sub, start=0, end=None, side="left"):
13851390 else :
13861391 f = lambda x : getattr (x , method )(sub , start , end )
13871392
1388- return _na_map (f , arr , dtype = int )
1393+ return _na_map (f , arr , dtype = "int64" )
13891394
13901395
13911396def str_pad (arr , width , side = "left" , fillchar = " " ):
@@ -3210,7 +3215,7 @@ def rindex(self, sub, start=0, end=None):
32103215 len ,
32113216 docstring = _shared_docs ["len" ],
32123217 forbidden_types = None ,
3213- dtype = int ,
3218+ dtype = "int64" ,
32143219 returns_string = False ,
32153220 )
32163221
0 commit comments