@@ -279,8 +279,6 @@ class ScalarType(CType, HasDataType, HasShape):
279279 Analogous to TensorType, but for zero-dimensional objects.
280280 Maps directly to C primitives.
281281
282- TODO: refactor to be named ScalarType for consistency with TensorType.
283-
284282 """
285283
286284 __props__ = ("dtype" ,)
@@ -350,11 +348,14 @@ def c_element_type(self):
350348 return self .dtype_specs ()[1 ]
351349
352350 def c_headers (self , c_compiler = None , ** kwargs ):
353- l = ["<math.h>" ]
354- # These includes are needed by ScalarType and TensorType,
355- # we declare them here and they will be re-used by TensorType
356- l .append ("<numpy/arrayobject.h>" )
357- l .append ("<numpy/arrayscalars.h>" )
351+ l = [
352+ "<math.h>" ,
353+ # These includes are needed by ScalarType and TensorType,
354+ # we declare them here and they will be re-used by TensorType
355+ "<numpy/arrayobject.h>" ,
356+ "<numpy/arrayscalars.h>" ,
357+ "<numpy/npy_2_complexcompat.h>" ,
358+ ]
358359 if config .lib__amblibm and c_compiler .supports_amdlibm :
359360 l += ["<amdlibm.h>" ]
360361 return l
@@ -396,8 +397,8 @@ def dtype_specs(self):
396397 "float16" : (np .float16 , "npy_float16" , "Float16" ),
397398 "float32" : (np .float32 , "npy_float32" , "Float32" ),
398399 "float64" : (np .float64 , "npy_float64" , "Float64" ),
399- "complex128" : (np .complex128 , "pytensor_complex128 " , "Complex128" ),
400- "complex64" : (np .complex64 , "pytensor_complex64 " , "Complex64" ),
400+ "complex128" : (np .complex128 , "npy_complex128 " , "Complex128" ),
401+ "complex64" : (np .complex64 , "npy_complex64 " , "Complex64" ),
401402 "bool" : (np .bool_ , "npy_bool" , "Bool" ),
402403 "uint8" : (np .uint8 , "npy_uint8" , "UInt8" ),
403404 "int8" : (np .int8 , "npy_int8" , "Int8" ),
@@ -506,171 +507,11 @@ def c_sync(self, name, sub):
506507 def c_cleanup (self , name , sub ):
507508 return ""
508509
509- def c_support_code (self , ** kwargs ):
510- if self .dtype .startswith ("complex" ):
511- cplx_types = ["pytensor_complex64" , "pytensor_complex128" ]
512- real_types = [
513- "npy_int8" ,
514- "npy_int16" ,
515- "npy_int32" ,
516- "npy_int64" ,
517- "npy_float32" ,
518- "npy_float64" ,
519- ]
520- # If the 'int' C type is not exactly the same as an existing
521- # 'npy_intX', some C code may not compile, e.g. when assigning
522- # the value 0 (cast to 'int' in C) to an PyTensor_complex64.
523- if np .dtype ("intc" ).num not in [np .dtype (d [4 :]).num for d in real_types ]:
524- # In that case we add the 'int' type to the real types.
525- real_types .append ("int" )
526-
527- template = """
528- struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s
529- {
530- typedef pytensor_complex%(nbits)s complex_type;
531- typedef npy_float%(half_nbits)s scalar_type;
532-
533- complex_type operator +(const complex_type &y) const {
534- complex_type ret;
535- ret.real = this->real + y.real;
536- ret.imag = this->imag + y.imag;
537- return ret;
538- }
539-
540- complex_type operator -() const {
541- complex_type ret;
542- ret.real = -this->real;
543- ret.imag = -this->imag;
544- return ret;
545- }
546- bool operator ==(const complex_type &y) const {
547- return (this->real == y.real) && (this->imag == y.imag);
548- }
549- bool operator ==(const scalar_type &y) const {
550- return (this->real == y) && (this->imag == 0);
551- }
552- complex_type operator -(const complex_type &y) const {
553- complex_type ret;
554- ret.real = this->real - y.real;
555- ret.imag = this->imag - y.imag;
556- return ret;
557- }
558- complex_type operator *(const complex_type &y) const {
559- complex_type ret;
560- ret.real = this->real * y.real - this->imag * y.imag;
561- ret.imag = this->real * y.imag + this->imag * y.real;
562- return ret;
563- }
564- complex_type operator /(const complex_type &y) const {
565- complex_type ret;
566- scalar_type y_norm_square = y.real * y.real + y.imag * y.imag;
567- ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square;
568- ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
569- return ret;
570- }
571- template <typename T>
572- complex_type& operator =(const T& y);
573-
574- pytensor_complex%(nbits)s() {}
575-
576- template <typename T>
577- pytensor_complex%(nbits)s(const T& y) { *this = y; }
578-
579- template <typename TR, typename TI>
580- pytensor_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; }
581- };
582- """
583-
584- def operator_eq_real (mytype , othertype ):
585- return f"""
586- template <> { mytype } & { mytype } ::operator=<{ othertype } >(const { othertype } & y)
587- {{ this->real=y; this->imag=0; return *this; }}
588- """
589-
590- def operator_eq_cplx (mytype , othertype ):
591- return f"""
592- template <> { mytype } & { mytype } ::operator=<{ othertype } >(const { othertype } & y)
593- {{ this->real=y.real; this->imag=y.imag; return *this; }}
594- """
595-
596- operator_eq = "" .join (
597- operator_eq_real (ctype , rtype )
598- for ctype in cplx_types
599- for rtype in real_types
600- ) + "" .join (
601- operator_eq_cplx (ctype1 , ctype2 )
602- for ctype1 in cplx_types
603- for ctype2 in cplx_types
604- )
605-
606- # We are not using C++ generic templating here, because this would
607- # generate two different functions for adding a complex64 and a
608- # complex128, one returning a complex64, the other a complex128,
609- # and the compiler complains it is ambiguous.
610- # Instead, we generate code for known and safe types only.
611-
612- def operator_plus_real (mytype , othertype ):
613- return f"""
614- const { mytype } operator+(const { mytype } &x, const { othertype } &y)
615- {{ return { mytype } (x.real+y, x.imag); }}
616-
617- const { mytype } operator+(const { othertype } &y, const { mytype } &x)
618- {{ return { mytype } (x.real+y, x.imag); }}
619- """
620-
621- operator_plus = "" .join (
622- operator_plus_real (ctype , rtype )
623- for ctype in cplx_types
624- for rtype in real_types
625- )
626-
627- def operator_minus_real (mytype , othertype ):
628- return f"""
629- const { mytype } operator-(const { mytype } &x, const { othertype } &y)
630- {{ return { mytype } (x.real-y, x.imag); }}
631-
632- const { mytype } operator-(const { othertype } &y, const { mytype } &x)
633- {{ return { mytype } (y-x.real, -x.imag); }}
634- """
635-
636- operator_minus = "" .join (
637- operator_minus_real (ctype , rtype )
638- for ctype in cplx_types
639- for rtype in real_types
640- )
641-
642- def operator_mul_real (mytype , othertype ):
643- return f"""
644- const { mytype } operator*(const { mytype } &x, const { othertype } &y)
645- {{ return { mytype } (x.real*y, x.imag*y); }}
646-
647- const { mytype } operator*(const { othertype } &y, const { mytype } &x)
648- {{ return { mytype } (x.real*y, x.imag*y); }}
649- """
650-
651- operator_mul = "" .join (
652- operator_mul_real (ctype , rtype )
653- for ctype in cplx_types
654- for rtype in real_types
655- )
656-
657- return (
658- template % dict (nbits = 64 , half_nbits = 32 )
659- + template % dict (nbits = 128 , half_nbits = 64 )
660- + operator_eq
661- + operator_plus
662- + operator_minus
663- + operator_mul
664- )
665-
666- else :
667- return ""
668-
669510 def c_init_code (self , ** kwargs ):
670511 return ["import_array();" ]
671512
672513 def c_code_cache_version (self ):
673- return (13 , np .__version__ )
514+ return (14 , np .__version__ )
674515
675516 def get_shape_info (self , obj ):
676517 return obj .itemsize
0 commit comments