@@ -15,6 +15,17 @@ class DecimalDtype(ExtensionDtype):
1515 name = 'decimal'
1616 na_value = decimal .Decimal ('NaN' )
1717
18+ def __init__ (self , context = None ):
19+ self .context = context or decimal .getcontext ()
20+
21+ def __eq__ (self , other ):
22+ if isinstance (other , type (self )):
23+ return self .context == other .context
24+ return super (DecimalDtype , self ).__eq__ (other )
25+
26+ def __repr__ (self ):
27+ return 'DecimalDtype(context={})' .format (self .context )
28+
1829 @classmethod
1930 def construct_array_type (cls ):
2031 """Return the array type associated with this dtype
@@ -35,13 +46,12 @@ def construct_from_string(cls, string):
3546
3647
3748class DecimalArray (ExtensionArray , ExtensionScalarOpsMixin ):
38- dtype = DecimalDtype ()
3949
40- def __init__ (self , values , dtype = None , copy = False ):
50+ def __init__ (self , values , dtype = None , copy = False , context = None ):
4151 for val in values :
42- if not isinstance (val , self . dtype . type ):
52+ if not isinstance (val , decimal . Decimal ):
4353 raise TypeError ("All values must be of type " +
44- str (self . dtype . type ))
54+ str (decimal . Decimal ))
4555 values = np .asarray (values , dtype = object )
4656
4757 self ._data = values
@@ -51,6 +61,11 @@ def __init__(self, values, dtype=None, copy=False):
5161 # those aliases are currently not working due to assumptions
5262 # in internal code (GH-20735)
5363 # self._values = self.values = self.data
64+ self ._dtype = DecimalDtype (context )
65+
66+ @property
67+ def dtype (self ):
68+ return self ._dtype
5469
5570 @classmethod
5671 def _from_sequence (cls , scalars , dtype = None , copy = False ):
@@ -82,6 +97,11 @@ def copy(self, deep=False):
8297 return type (self )(self ._data .copy ())
8398 return type (self )(self )
8499
100+ def astype (self , dtype , copy = True ):
101+ if isinstance (dtype , type (self .dtype )):
102+ return type (self )(self ._data , context = dtype .context )
103+ return super (DecimalArray , self ).astype (dtype , copy )
104+
85105 def __setitem__ (self , key , value ):
86106 if pd .api .types .is_list_like (value ):
87107 value = [decimal .Decimal (v ) for v in value ]
0 commit comments