@@ -2637,6 +2637,335 @@ def hyp1f1(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
26372637)
26382638
26392639
2640+ def _binom (n , k ):
2641+ a = lax .lgamma (n + 1.0 )
2642+ b = lax .lgamma (n - k + 1.0 )
2643+ c = lax .lgamma (k + 1.0 )
2644+
2645+ return lax .exp (a - b - c )
2646+
2647+
2648+ def _poch (q , n ):
2649+ """
2650+ `jax.scipy.special.poch` does not allow for non-positive integer q.
2651+ """
2652+ def body (i , state ):
2653+ q , prod = state
2654+
2655+ prod *= q + i
2656+
2657+ return q , prod
2658+
2659+ return lax .cond (
2660+ n == 0 ,
2661+ lambda : jnp .array (1 , dtype = q .dtype ),
2662+ lambda : lax .fori_loop (1. , n , body , (q , q ))[1 ]
2663+ )
2664+
2665+
2666+ def _hyp2f1_terminal (a , b , c , x ):
2667+ """
2668+ The Taylor series representation of the 2F1 hypergeometric function
2669+ terminates when either a or b is a non-positive integer. See Eq. 4.1 and
2670+ Taylor Series Method (a) from PEARSON, OLVER & PORTER 2014
2671+ https://doi.org/10.48550/arXiv.1407.7786
2672+ """
2673+ # Ensure that between a and b, the negative integer parameter with the greater
2674+ # absolute value - that still has a magnitude less than the absolute value of
2675+ # c if c is non-positive - is used for the upper limit in the loop.
2676+ temp = a
2677+ a = jnp .where (
2678+ jnp .logical_and (
2679+ b < a ,
2680+ jnp .logical_and (
2681+ b % 1 == 0 ,
2682+ jnp .logical_not (
2683+ jnp .logical_and (
2684+ c % 1 == 0 ,
2685+ jnp .logical_and (
2686+ c <= 0 ,
2687+ c > b
2688+ )
2689+ )
2690+ )
2691+ )
2692+ ), b , a
2693+ )
2694+ b = jnp .where (
2695+ jnp .logical_and (
2696+ b < temp ,
2697+ jnp .logical_and (
2698+ b % 1 == 0 ,
2699+ jnp .logical_not (
2700+ jnp .logical_and (
2701+ c % 1 == 0 ,
2702+ jnp .logical_and (
2703+ c <= 0 ,
2704+ c > b
2705+ )
2706+ )
2707+ )
2708+ )
2709+ ), temp , b
2710+ )
2711+
2712+ def body (i , sum ):
2713+ sum += (- 1 ) ** i * _binom (jnp .abs (a ), i ) / _poch (c , i ) * _poch (b , i ) * x ** i
2714+
2715+ return sum
2716+
2717+ return lax .fori_loop (0. , jnp .abs (a ) + 1 , body , 0. )
2718+
2719+
2720+ def _hyp2f1_serie (a , b , c , x ):
2721+ """
2722+ Compute the 2F1 hypergeometric function using the Taylor expansion.
2723+ See Eq. 4.1 from PEARSON, OLVER & PORTER 2014
2724+ https://doi.org/10.48550/arXiv.1407.7786
2725+ """
2726+ precision = jnp .finfo (jnp .float32 ).eps
2727+
2728+ s = 1 - x
2729+
2730+ neg_int_a = jnp .logical_and (a <= 0 , a % 1 == 0 )
2731+ neg_int_b = jnp .logical_and (b <= 0 , b % 1 == 0 )
2732+ neg_int_c = jnp .logical_and (c <= 0 , c % 1 == 0 )
2733+
2734+ def body (state ):
2735+ serie , k , term = state
2736+ serie += term
2737+ term = _poch (a , k ) / _poch (c , k ) * _poch (b , k ) / factorial (k ) * x ** k
2738+ k += 1
2739+
2740+ return serie , k , term
2741+
2742+ def cond (state ):
2743+ serie , k , term = state
2744+
2745+ return (k < 250 ) & (lax .abs (term ) / lax .abs (serie ) > precision )
2746+
2747+ init = 0. , 1. , 1.
2748+
2749+ return lax .while_loop (cond , body , init )[0 ]
2750+
2751+
2752+ def _hyp2f1_terminal_or_serie (a , b , c , x ):
2753+ """
2754+ Check for recurrence relations along with whether or not the series
2755+ terminates. True recursion is not possible; however, the recurrence
2756+ relation may still be approximated.
2757+ See 4.6.1. Recurrence Relations from PEARSON, OLVER & PORTER 2014
2758+ https://doi.org/10.48550/arXiv.1407.7786
2759+ """
2760+ neg_int_a = jnp .logical_and (a <= 0 , a % 1 == 0 )
2761+ neg_int_b = jnp .logical_and (b <= 0 , b % 1 == 0 )
2762+ neg_int_c = jnp .logical_and (c <= 0 , c % 1 == 0 )
2763+ neg_int_a_or_b = jnp .logical_or (neg_int_a , neg_int_b )
2764+ not_neg_int_a_or_b = jnp .logical_not (neg_int_a_or_b )
2765+
2766+ s = 1 - x
2767+ d = c - a - b
2768+
2769+ index = jnp .where (
2770+ jnp .logical_and (
2771+ neg_int_c ,
2772+ jnp .logical_and (
2773+ jnp .logical_not (jnp .logical_and (neg_int_a , a > c )),
2774+ jnp .logical_not (jnp .logical_and (neg_int_b , b > c ))
2775+ )
2776+ ), 0 ,
2777+ jnp .where (jnp .logical_and (x < - 0.5 , not_neg_int_a_or_b ),
2778+ jnp .where (b > a , 1 , 2 ),
2779+ jnp .where (jnp .logical_and (x > 0.9 , not_neg_int_a_or_b ),
2780+ jnp .where (d % 1 != 0 , 3 , 4 ),
2781+ jnp .where (jnp .logical_and (jnp .logical_not (neg_int_c ), neg_int_a_or_b ), 5 , 3 ))))
2782+
2783+ return lax .select_n (index ,
2784+ jnp .array (jnp .inf , dtype = x .dtype ),
2785+ s ** (- a ) * _hyp2f1_serie (a , c - b , c , - x / s ),
2786+ s ** (- b ) * _hyp2f1_serie (c - a , b , c , - x / s ),
2787+ _hyp2f1_serie (a , b , c , x ),
2788+ _hyp2f1_digamma_transform (a , b , c , x ),
2789+ _hyp2f1_terminal (a , b , c , x ))
2790+
2791+
2792+ def _hyp2f1_gamma_transform (a , b , c , x ):
2793+ """
2794+ Gamma transformations of the 2F1 hypergeometric function.
2795+ """
2796+
2797+ def transform_1 ():
2798+ """
2799+ See Eq. 4.10 and Analytic Continuation Formulas from PEARSON, OLVER & PORTER 2014
2800+ https://doi.org/10.48550/arXiv.1407.7786
2801+ """
2802+ p = _hyp2f1_serie (a , 1 - c + a , 1 - b + a , 1 / x )
2803+ q = _hyp2f1_serie (b , 1 - c + b , 1 - a + b , 1 / x )
2804+ p *= (- x ) ** (- a )
2805+ q *= (- x ) ** (- b )
2806+ t1 = gamma (c )
2807+ s = t1 * gamma (b - a ) / (gamma (b ) * gamma (c - a ))
2808+ y = t1 * gamma (a - b ) / (gamma (a ) * gamma (c - b ))
2809+
2810+ return s * p + y * q
2811+
2812+ def transform_2 ():
2813+ """
2814+ See 4.1 Properties of F from PEARSON, OLVER & PORTER 2014
2815+ https://doi.org/10.48550/arXiv.1407.7786
2816+ """
2817+ return gamma (c ) * gamma (c - a - b ) / (gamma (c - a ) * gamma (c - b ))
2818+
2819+ return jnp .where (
2820+ x < - 2 ,
2821+ transform_1 (),
2822+ transform_2 ()
2823+ )
2824+
2825+
2826+ def _hyp2f1_digamma_transform (a , b , c , x ):
2827+ """
2828+ Digamma transformation of the 2F1 hypergeometric function.
2829+ See AMS55 #15.3.10, #15.3.11, #15.3.12
2830+ """
2831+ precision = jnp .finfo (jnp .float32 ).eps
2832+
2833+ d = c - a - b
2834+ s = 1 - x
2835+ id = jnp .round (d )
2836+
2837+ e = jnp .where (id >= 0 , d , - d )
2838+ d1 = jnp .where (id >= 0 , d , 0. )
2839+ d2 = jnp .where (id >= 0 , 0. , d )
2840+ aid = jnp .where (id >= 0 , id , - id ).astype ('int32' )
2841+
2842+ ax = jnp .log (s )
2843+
2844+ y = digamma (1.0 ) + digamma (1.0 + e ) - digamma (a + d1 ) - digamma (b + d1 ) - ax
2845+ y /= gamma (e + 1.0 )
2846+
2847+ p = (a + d1 ) * (b + d1 ) * s / gamma (e + 2.0 )
2848+
2849+ def cond (state ):
2850+ _ , _ , _ , _ , _ , _ , q , _ , _ , t , y = state
2851+
2852+ return jnp .logical_and (
2853+ t < 250 ,
2854+ jnp .logical_or (y == 0 , jnp .abs (q / y ) > precision )
2855+ )
2856+
2857+ def body (state ):
2858+ a , ax , b , d1 , e , p , q , r , s , t , y = state
2859+
2860+ r = digamma (1.0 + t ) + digamma (1.0 + t + e ) - digamma (a + t + d1 ) \
2861+ - digamma (b + t + d1 ) - ax
2862+ q = p * r
2863+ y += q
2864+ p *= s * (a + t + d1 ) / (t + 1.0 )
2865+ p *= (b + t + d1 ) / (t + 1.0 + e )
2866+ t += 1.0
2867+
2868+ return a , ax , b , d1 , e , p , q , r , s , t , y
2869+
2870+ init = a , ax , b , d1 , e , p , y , 0.0 , s , 1.0 , y
2871+ _ , _ , _ , _ , _ , _ , q , r , _ , _ , y = lax .while_loop (cond , body , init )
2872+
2873+ def compute_sum (y ):
2874+ y1 = 1.0
2875+ t = 0.0
2876+ p = 1.0
2877+
2878+ def for_body (i , state ):
2879+ a , b , d2 , e , p , s , t , y1 = state
2880+
2881+ r = 1.0 - e + t
2882+ p *= s * (a + t + d2 ) * (b + t + d2 ) / r
2883+ t += 1.0
2884+ p /= t
2885+ y1 += p
2886+
2887+ return a , b , d2 , e , p , s , t , y1
2888+
2889+ init_val = a , b , d2 , e , p , s , t , y1
2890+ y1 = lax .fori_loop (1 , aid , for_body , init_val )[- 1 ]
2891+
2892+ p = gamma (c )
2893+ y1 *= gamma (e ) * p / (gamma (a + d1 ) * gamma (b + d1 ))
2894+ y *= p / (gamma (a + d2 ) * gamma (b + d2 ))
2895+
2896+ y = jnp .where ((aid & 1 ) != 0 , - y , y )
2897+ q = s ** id
2898+
2899+ return jnp .where (id > 0 , y * q + y1 , y + y1 * q )
2900+
2901+ return jnp .where (
2902+ id == 0 ,
2903+ y * gamma (c ) / (gamma (a ) * gamma (b )),
2904+ compute_sum (y )
2905+ )
2906+
2907+
2908+ @jit
2909+ @jnp .vectorize
2910+ def hyp2f1 (a : ArrayLike , b : ArrayLike , c : ArrayLike , x : ArrayLike ) -> Array :
2911+ r"""The 2F1 hypergeometric function.
2912+
2913+ JAX implementation of :obj:`scipy.special.hyp2f1`.
2914+
2915+ .. math::
2916+
2917+ \mathrm{hyp2f1}(a, b, c, x) = {}_2F_1(a; b; c; x) = \sum_{k=0}^\infty \frac{(a)_k(b)_k}{(c)_k}\frac{x^k}{k!}
2918+
2919+ where :math:`(\cdot)_k` is the Pochammer symbol.
2920+
2921+ The JAX version only accepts positive and real inputs. Values of
2922+ ``a``, ``b``, ``c``, and ``x`` leading to high values of 2F1 may
2923+ lead to erroneous results; consider enabling double precision in this case.
2924+
2925+ Args:
2926+ a: arraylike, real-valued
2927+ b: arraylike, real-valued
2928+ c: arraylike, real-valued
2929+ x: arraylike, real-valued
2930+
2931+ Returns:
2932+ array of 2F1 values.
2933+ """
2934+ # This is backed by https://doi.org/10.48550/arXiv.1407.7786
2935+ a , b , c , x = promote_args_inexact ('hyp2f1' , a , b , c , x )
2936+
2937+ d = c - a - b
2938+ s = 1 - x
2939+
2940+ neg_int_ca = jnp .logical_and (c - a <= 0 , (c - a ) % 1 == 0 )
2941+ neg_int_cb = jnp .logical_and (c - b <= 0 , (c - b ) % 1 == 0 )
2942+ neg_int_ca_or_cb = jnp .logical_or (neg_int_ca , neg_int_cb )
2943+
2944+ index = jnp .where (jnp .logical_or (x == 0 , jnp .logical_and (jnp .logical_or (a == 0 , b == 0 ), c != 0 )), 0 ,
2945+ jnp .where (c == 0 , 2 ,
2946+ jnp .where (jnp .logical_and (d <= - 1 , jnp .logical_not (jnp .logical_and (d % 1 != 0 , s < 0 ))), 1 ,
2947+ jnp .where (jnp .logical_and (d <= 0 , x == 1 ), 2 ,
2948+ jnp .where (jnp .logical_and (x < 1 , b == c ), 3 ,
2949+ jnp .where (jnp .logical_and (x < 1 , a == c ), 4 ,
2950+ jnp .where (x > 1 , 2 ,
2951+ jnp .where (x == 1 ,
2952+ jnp .where (neg_int_ca_or_cb ,
2953+ jnp .where (d >= 0 , 5 , 2 ),
2954+ jnp .where (d <= 0 , 2 , 6 )),
2955+ jnp .where (d < 0 , 7 ,
2956+ jnp .where (neg_int_ca_or_cb , 5 , 7 ))))))))))
2957+
2958+ return lax .select_n (index ,
2959+ jnp .array (1 , dtype = x .dtype ),
2960+ s ** d * _hyp2f1_terminal_or_serie (c - a , c - b , c , x ),
2961+ jnp .array (jnp .inf , dtype = x .dtype ),
2962+ s ** (- a ),
2963+ s ** (- b ),
2964+ s ** d * _hyp2f1_serie (c - a , c - b , c , x ),
2965+ _hyp2f1_gamma_transform (a , b , c , x ),
2966+ _hyp2f1_terminal_or_serie (a , b , c , x ))
2967+
2968+
26402969def softmax (x : ArrayLike ,
26412970 / ,
26422971 * ,
0 commit comments