1818 Sequence ,
1919 TypeVar ,
2020 cast ,
21+ overload ,
2122)
2223
2324import numpy as np
2425
2526from pandas ._libs import lib
2627from pandas ._typing import (
2728 ArrayLike ,
29+ AstypeArg ,
2830 Dtype ,
2931 FillnaOptions ,
3032 PositionalIndexer ,
@@ -520,9 +522,21 @@ def nbytes(self) -> int:
520522 # Additional Methods
521523 # ------------------------------------------------------------------------
522524
523- def astype (self , dtype , copy = True ):
525+ @overload
526+ def astype (self , dtype : npt .DTypeLike , copy : bool = ...) -> np .ndarray :
527+ ...
528+
529+ @overload
530+ def astype (self , dtype : ExtensionDtype , copy : bool = ...) -> ExtensionArray :
531+ ...
532+
533+ @overload
534+ def astype (self , dtype : AstypeArg , copy : bool = ...) -> ArrayLike :
535+ ...
536+
537+ def astype (self , dtype : AstypeArg , copy : bool = True ) -> ArrayLike :
524538 """
525- Cast to a NumPy array with 'dtype'.
539+ Cast to a NumPy array or ExtensionArray with 'dtype'.
526540
527541 Parameters
528542 ----------
@@ -535,8 +549,10 @@ def astype(self, dtype, copy=True):
535549
536550 Returns
537551 -------
538- array : ndarray
539- NumPy ndarray with 'dtype' for its dtype.
552+ array : np.ndarray or ExtensionArray
553+ An ExtensionArray if dtype is StringDtype,
554+ or same as that of underlying array.
555+ Otherwise a NumPy ndarray with 'dtype' for its dtype.
540556 """
541557 from pandas .core .arrays .string_ import StringDtype
542558
@@ -552,7 +568,11 @@ def astype(self, dtype, copy=True):
552568 # allow conversion to StringArrays
553569 return dtype .construct_array_type ()._from_sequence (self , copy = False )
554570
555- return np .array (self , dtype = dtype , copy = copy )
571+ # error: Argument "dtype" to "array" has incompatible type
572+ # "Union[ExtensionDtype, dtype[Any]]"; expected "Union[dtype[Any], None, type,
573+ # _SupportsDType, str, Union[Tuple[Any, int], Tuple[Any, Union[int,
574+ # Sequence[int]]], List[Any], _DTypeDict, Tuple[Any, Any]]]"
575+ return np .array (self , dtype = dtype , copy = copy ) # type: ignore[arg-type]
556576
557577 def isna (self ) -> np .ndarray | ExtensionArraySupportsAnyAll :
558578 """
@@ -863,6 +883,8 @@ def searchsorted(
863883 # 2. Values between the values in the `data_for_sorting` fixture
864884 # 3. Missing values.
865885 arr = self .astype (object )
886+ if isinstance (value , ExtensionArray ):
887+ value = value .astype (object )
866888 return arr .searchsorted (value , side = side , sorter = sorter )
867889
868890 def equals (self , other : object ) -> bool :
0 commit comments