diff --git a/pytensor/scalar/c_code/incbet.c b/pytensor/scalar/c_code/incbet.c new file mode 100644 index 0000000000..491da20089 --- /dev/null +++ b/pytensor/scalar/c_code/incbet.c @@ -0,0 +1,409 @@ +/* incbet.c + * + * Incomplete beta integral + * + * + * SYNOPSIS: + * + * double a, b, x, y, incbet(); + * + * y = incbet( a, b, x ); + * + * + * DESCRIPTION: + * + * Returns incomplete beta integral of the arguments, evaluated + * from zero to x. The function is defined as + * + * x + * - - + * | (a+b) | | a-1 b-1 + * ----------- | t (1-t) dt. + * - - | | + * | (a) | (b) - + * 0 + * + * The domain of definition is 0 <= x <= 1. In this + * implementation a and b are restricted to positive values. + * The integral from x to 1 may be obtained by the symmetry + * relation + * + * 1 - incbet( a, b, x ) = incbet( b, a, 1-x ). + * + * The integral is evaluated by a continued fraction expansion + * or, when b*x is small, by a power series. + * + * ACCURACY: + * + * Tested at uniformly distributed random points (a,b,x) with a and b + * in "domain" and x between 0 and 1. + * Relative error + * arithmetic domain # trials peak rms + * IEEE 0,5 10000 6.9e-15 4.5e-16 + * IEEE 0,85 250000 2.2e-13 1.7e-14 + * IEEE 0,1000 30000 5.3e-12 6.3e-13 + * IEEE 0,10000 250000 9.3e-11 7.1e-12 + * IEEE 0,100000 10000 8.7e-10 4.8e-11 + * Outputs smaller than the IEEE gradual underflow threshold + * were excluded from these statistics. + * + * ERROR MESSAGES: + * message condition value returned + * incbet domain x<0, x>1 0.0 + * incbet underflow 0.0 + */ + + +/* +Cephes Math Library, Release 2.8: June, 2000 +Copyright 1984, 1995, 2000 by Stephen L. Moshier +*/ + +#include "mconf.h" + +#ifdef DEC +#define MAXGAM 34.84425627277176174 +#else +#define MAXGAM 171.624376956302725 +#endif + +extern double MACHEP, MINLOG, MAXLOG; +#ifdef ANSIPROT +extern double gamma ( double ); +extern double lgam ( double ); +extern double exp ( double ); +extern double log ( double ); +extern double pow ( double, double ); +extern double fabs ( double ); +static double incbcf(double, double, double); +static double incbd(double, double, double); +static double pseries(double, double, double); +#else +double gamma(), lgam(), exp(), log(), pow(), fabs(); +static double incbcf(), incbd(), pseries(); +#endif + +static double big = 4.503599627370496e15; +static double biginv = 2.22044604925031308085e-16; + + +double incbet( aa, bb, xx ) +double aa, bb, xx; +{ +double a, b, t, x, xc, w, y; +int flag; + +if( aa <= 0.0 || bb <= 0.0 ) + goto domerr; + +if( (xx <= 0.0) || ( xx >= 1.0) ) + { + if( xx == 0.0 ) + return(0.0); + if( xx == 1.0 ) + return( 1.0 ); +domerr: + mtherr( "incbet", DOMAIN ); + return( 0.0 ); + } + +flag = 0; +if( (bb * xx) <= 1.0 && xx <= 0.95) + { + t = pseries(aa, bb, xx); + goto done; + } + +w = 1.0 - xx; + +/* Reverse a and b if x is greater than the mean. */ +if( xx > (aa/(aa+bb)) ) + { + flag = 1; + a = bb; + b = aa; + xc = xx; + x = w; + } +else + { + a = aa; + b = bb; + xc = w; + x = xx; + } + +if( flag == 1 && (b * x) <= 1.0 && x <= 0.95) + { + t = pseries(a, b, x); + goto done; + } + +/* Choose expansion for better convergence. */ +y = x * (a+b-2.0) - (a-1.0); +if( y < 0.0 ) + w = incbcf( a, b, x ); +else + w = incbd( a, b, x ) / xc; + +/* Multiply w by the factor + a b _ _ _ + x (1-x) | (a+b) / ( a | (a) | (b) ) . */ + +y = a * log(x); +t = b * log(xc); +if( (a+b) < MAXGAM && fabs(y) < MAXLOG && fabs(t) < MAXLOG ) + { + t = pow(xc,b); + t *= pow(x,a); + t /= a; + t *= w; + t *= gamma(a+b) / (gamma(a) * gamma(b)); + goto done; + } +/* Resort to logarithms. */ +y += t + lgam(a+b) - lgam(a) - lgam(b); +y += log(w/a); +if( y < MINLOG ) + t = 0.0; +else + t = exp(y); + +done: + +if( flag == 1 ) + { + if( t <= MACHEP ) + t = 1.0 - MACHEP; + else + t = 1.0 - t; + } +return( t ); +} + +/* Continued fraction expansion #1 + * for incomplete beta integral + */ + +static double incbcf( a, b, x ) +double a, b, x; +{ +double xk, pk, pkm1, pkm2, qk, qkm1, qkm2; +double k1, k2, k3, k4, k5, k6, k7, k8; +double r, t, ans, thresh; +int n; + +k1 = a; +k2 = a + b; +k3 = a; +k4 = a + 1.0; +k5 = 1.0; +k6 = b - 1.0; +k7 = k4; +k8 = a + 2.0; + +pkm2 = 0.0; +qkm2 = 1.0; +pkm1 = 1.0; +qkm1 = 1.0; +ans = 1.0; +r = 1.0; +n = 0; +thresh = 3.0 * MACHEP; +do + { + + xk = -( x * k1 * k2 )/( k3 * k4 ); + pk = pkm1 + pkm2 * xk; + qk = qkm1 + qkm2 * xk; + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + + xk = ( x * k5 * k6 )/( k7 * k8 ); + pk = pkm1 + pkm2 * xk; + qk = qkm1 + qkm2 * xk; + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + + if( qk != 0 ) + r = pk/qk; + if( r != 0 ) + { + t = fabs( (ans - r)/r ); + ans = r; + } + else + t = 1.0; + + if( t < thresh ) + goto cdone; + + k1 += 1.0; + k2 += 1.0; + k3 += 2.0; + k4 += 2.0; + k5 += 1.0; + k6 -= 1.0; + k7 += 2.0; + k8 += 2.0; + + if( (fabs(qk) + fabs(pk)) > big ) + { + pkm2 *= biginv; + pkm1 *= biginv; + qkm2 *= biginv; + qkm1 *= biginv; + } + if( (fabs(qk) < biginv) || (fabs(pk) < biginv) ) + { + pkm2 *= big; + pkm1 *= big; + qkm2 *= big; + qkm1 *= big; + } + } +while( ++n < 300 ); + +cdone: +return(ans); +} + + +/* Continued fraction expansion #2 + * for incomplete beta integral + */ + +static double incbd( a, b, x ) +double a, b, x; +{ +double xk, pk, pkm1, pkm2, qk, qkm1, qkm2; +double k1, k2, k3, k4, k5, k6, k7, k8; +double r, t, ans, z, thresh; +int n; + +k1 = a; +k2 = b - 1.0; +k3 = a; +k4 = a + 1.0; +k5 = 1.0; +k6 = a + b; +k7 = a + 1.0;; +k8 = a + 2.0; + +pkm2 = 0.0; +qkm2 = 1.0; +pkm1 = 1.0; +qkm1 = 1.0; +z = x / (1.0-x); +ans = 1.0; +r = 1.0; +n = 0; +thresh = 3.0 * MACHEP; +do + { + + xk = -( z * k1 * k2 )/( k3 * k4 ); + pk = pkm1 + pkm2 * xk; + qk = qkm1 + qkm2 * xk; + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + + xk = ( z * k5 * k6 )/( k7 * k8 ); + pk = pkm1 + pkm2 * xk; + qk = qkm1 + qkm2 * xk; + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + + if( qk != 0 ) + r = pk/qk; + if( r != 0 ) + { + t = fabs( (ans - r)/r ); + ans = r; + } + else + t = 1.0; + + if( t < thresh ) + goto cdone; + + k1 += 1.0; + k2 -= 1.0; + k3 += 2.0; + k4 += 2.0; + k5 += 1.0; + k6 += 1.0; + k7 += 2.0; + k8 += 2.0; + + if( (fabs(qk) + fabs(pk)) > big ) + { + pkm2 *= biginv; + pkm1 *= biginv; + qkm2 *= biginv; + qkm1 *= biginv; + } + if( (fabs(qk) < biginv) || (fabs(pk) < biginv) ) + { + pkm2 *= big; + pkm1 *= big; + qkm2 *= big; + qkm1 *= big; + } + } +while( ++n < 300 ); +cdone: +return(ans); +} + +/* Power series for incomplete beta integral. + Use when b*x is small and x not too close to 1. */ + +static double pseries( a, b, x ) +double a, b, x; +{ +double s, t, u, v, n, t1, z, ai; + +ai = 1.0 / a; +u = (1.0 - b) * x; +v = u / (a + 1.0); +t1 = v; +t = u; +n = 2.0; +s = 0.0; +z = MACHEP * ai; +while( fabs(v) > z ) + { + u = (n - b) * x / n; + t *= u; + v = t / (a + n); + s += v; + n += 1.0; + } +s += t1; +s += ai; + +u = a * log(x); +if( (a+b) < MAXGAM && fabs(u) < MAXLOG ) + { + t = gamma(a+b)/(gamma(a)*gamma(b)); + s = s * t * pow(x,a); + } +else + { + t = lgam(a+b) - lgam(a) - lgam(b) + u + log(s); + if( t < MINLOG ) + s = 0.0; + else + s = exp(t); + } +return(s); +} diff --git a/pytensor/scalar/c_code/mconf.h b/pytensor/scalar/c_code/mconf.h new file mode 100644 index 0000000000..ab31ae787e --- /dev/null +++ b/pytensor/scalar/c_code/mconf.h @@ -0,0 +1,199 @@ +/* mconf.h + * + * Common include file for math routines + * + * + * + * SYNOPSIS: + * + * #include "mconf.h" + * + * + * + * DESCRIPTION: + * + * This file contains definitions for error codes that are + * passed to the common error handling routine mtherr() + * (which see). + * + * The file also includes a conditional assembly definition + * for the type of computer arithmetic (IEEE, DEC, Motorola + * IEEE, or UNKnown). + * + * For Digital Equipment PDP-11 and VAX computers, certain + * IBM systems, and others that use numbers with a 56-bit + * significand, the symbol DEC should be defined. In this + * mode, most floating point constants are given as arrays + * of octal integers to eliminate decimal to binary conversion + * errors that might be introduced by the compiler. + * + * For little-endian computers, such as IBM PC, that follow the + * IEEE Standard for Binary Floating Point Arithmetic (ANSI/IEEE + * Std 754-1985), the symbol IBMPC should be defined. These + * numbers have 53-bit significands. In this mode, constants + * are provided as arrays of hexadecimal 16 bit integers. + * + * Big-endian IEEE format is denoted MIEEE. On some RISC + * systems such as Sun SPARC, double precision constants + * must be stored on 8-byte address boundaries. Since integer + * arrays may be aligned differently, the MIEEE configuration + * may fail on such machines. + * + * To accommodate other types of computer arithmetic, all + * constants are also provided in a normal decimal radix + * which one can hope are correctly converted to a suitable + * format by the available C language compiler. To invoke + * this mode, define the symbol UNK. + * + * An important difference among these modes is a predefined + * set of machine arithmetic constants for each. The numbers + * MACHEP (the machine roundoff error), MAXNUM (largest number + * represented), and several other parameters are preset by + * the configuration symbol. Check the file const.c to + * ensure that these values are correct for your computer. + * + * Configurations NANS, INFINITIES, MINUSZERO, and DENORMAL + * may fail on many systems. Verify that they are supposed + * to work on your computer. + */ +/* +Cephes Math Library Release 2.3: June, 1995 +Copyright 1984, 1987, 1989, 1995 by Stephen L. Moshier +*/ + + +/* Define if the `long double' type works. */ +#define HAVE_LONG_DOUBLE 1 + +/* Define as the return type of signal handlers (int or void). */ +#define RETSIGTYPE void + +/* Define if you have the ANSI C header files. */ +#define STDC_HEADERS 1 + +/* Define if your processor stores words with the most significant + byte first (like Motorola and SPARC, unlike Intel and VAX). */ +/* #undef WORDS_BIGENDIAN */ + +/* Define if floating point words are bigendian. */ +/* #undef FLOAT_WORDS_BIGENDIAN */ + +/* The number of bytes in a int. */ +#define SIZEOF_INT 4 + +/* Define if you have the header file. */ +#define HAVE_STRING_H 1 + +/* Name of package */ +#define PACKAGE "cephes" + +/* Version number of package */ +#define VERSION "2.7" + +/* Constant definitions for math error conditions + */ + +#define DOMAIN 1 /* argument domain error */ +#define SING 2 /* argument singularity */ +#define OVERFLOW 3 /* overflow range error */ +#define UNDERFLOW 4 /* underflow range error */ +#define TLOSS 5 /* total loss of precision */ +#define PLOSS 6 /* partial loss of precision */ + +#define EDOM 33 +#define ERANGE 34 +/* Complex numeral. */ +typedef struct + { + double r; + double i; + } cmplx; + +#ifdef HAVE_LONG_DOUBLE +/* Long double complex numeral. */ +typedef struct + { + long double r; + long double i; + } cmplxl; +#endif + + +/* Type of computer arithmetic */ + +/* PDP-11, Pro350, VAX: + */ +/* #define DEC 1 */ + +/* Intel IEEE, low order words come first: + */ +/* #define IBMPC 1 */ + +/* Motorola IEEE, high order words come first + * (Sun 680x0 workstation): + */ +/* #define MIEEE 1 */ + +/* UNKnown arithmetic, invokes coefficients given in + * normal decimal format. Beware of range boundary + * problems (MACHEP, MAXLOG, etc. in const.c) and + * roundoff problems in pow.c: + * (Sun SPARCstation) + */ +#define UNK 1 + +/* If you define UNK, then be sure to set BIGENDIAN properly. */ +#ifdef FLOAT_WORDS_BIGENDIAN +#define BIGENDIAN 1 +#else +#define BIGENDIAN 0 +#endif +/* Define this `volatile' if your compiler thinks + * that floating point arithmetic obeys the associative + * and distributive laws. It will defeat some optimizations + * (but probably not enough of them). + * + * #define VOLATILE volatile + */ +#define VOLATILE + +/* For 12-byte long doubles on an i386, pad a 16-bit short 0 + * to the end of real constants initialized by integer arrays. + * + * #define XPD 0, + * + * Otherwise, the type is 10 bytes long and XPD should be + * defined blank (e.g., Microsoft C). + * + * #define XPD + */ +#define XPD 0, + +/* Define to support tiny denormal numbers, else undefine. */ +#define DENORMAL 1 + +/* Define to ask for infinity support, else undefine. */ +#define INFINITIES 1 + +/* Define to ask for support of numbers that are Not-a-Number, + else undefine. This may automatically define INFINITIES in some files. */ +#define NANS 1 + +/* Define to distinguish between -0.0 and +0.0. */ +#define MINUSZERO 1 + +/* Define 1 for ANSI C atan2() function + See atan.c and clog.c. */ +#define ANSIC 1 + +/* Get ANSI function prototypes, if you want them. */ +#if 1 +/* #ifdef __STDC__ */ +#define ANSIPROT 1 +int mtherr ( char *, int ); +#else +int mtherr(); +#endif + +/* Variable for error reporting. See mtherr.c. */ +extern int merror; diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index edf03a393d..59ac158a83 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -1495,8 +1495,28 @@ def grad(self, inp, grads): ), ] - def c_code(self, *args, **kwargs): - raise NotImplementedError() + def c_header_dirs(self, **kwargs): + # Using the mconf_.h header for betainc.c + res = [ + *super().c_header_dirs(**kwargs), + os.path.join(os.path.dirname(__file__), "c_code"), + ] + return res + + + def c_support_code(self, **kwargs): + with open(os.path.join(os.path.dirname(__file__), "c_code", "incbet.c")) as f: + raw = f.read() + return raw + + def c_code(self, node, name, inp, out, sub): + (a, b, x) = inp + (z,) = out + if node.inputs[0].type in float_types: + dtype = "npy_" + node.outputs[0].dtype + return f"{z} = ({dtype}) incbet({a}, {b}, {x});" + + raise NotImplementedError("type not supported", type) betainc = BetaInc(upgrade_to_float_no_complex, name="betainc")