Skip to content

Commit 47a16ec

Browse files
committed
implement hyp2f1
1 parent f346fd0 commit 47a16ec

File tree

3 files changed

+334
-0
lines changed

3 files changed

+334
-0
lines changed

jax/_src/scipy/special.py

Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2637,6 +2637,335 @@ def hyp1f1(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
26372637
)
26382638

26392639

2640+
def _binom(n, k):
2641+
a = lax.lgamma(n + 1.0)
2642+
b = lax.lgamma(n - k + 1.0)
2643+
c = lax.lgamma(k + 1.0)
2644+
2645+
return lax.exp(a - b - c)
2646+
2647+
2648+
def _poch(q, n):
2649+
"""
2650+
`jax.scipy.special.poch` does not allow for non-positive integer q.
2651+
"""
2652+
def body(i, state):
2653+
q, prod = state
2654+
2655+
prod *= q + i
2656+
2657+
return q, prod
2658+
2659+
return lax.cond(
2660+
n == 0,
2661+
lambda: jnp.array(1, dtype=q.dtype),
2662+
lambda: lax.fori_loop(1., n, body, (q, q))[1]
2663+
)
2664+
2665+
2666+
def _hyp2f1_terminal(a, b, c, x):
2667+
"""
2668+
The Taylor series representation of the 2F1 hypergeometric function
2669+
terminates when either a or b is a non-positive integer. See Eq. 4.1 and
2670+
Taylor Series Method (a) from PEARSON, OLVER & PORTER 2014
2671+
https://doi.org/10.48550/arXiv.1407.7786
2672+
"""
2673+
# Ensure that between a and b, the negative integer parameter with the greater
2674+
# absolute value - that still has a magnitude less than the absolute value of
2675+
# c if c is non-positive - is used for the upper limit in the loop.
2676+
temp = a
2677+
a = jnp.where(
2678+
jnp.logical_and(
2679+
b < a,
2680+
jnp.logical_and(
2681+
b % 1 == 0,
2682+
jnp.logical_not(
2683+
jnp.logical_and(
2684+
c % 1 == 0,
2685+
jnp.logical_and(
2686+
c <= 0,
2687+
c > b
2688+
)
2689+
)
2690+
)
2691+
)
2692+
), b, a
2693+
)
2694+
b = jnp.where(
2695+
jnp.logical_and(
2696+
b < temp,
2697+
jnp.logical_and(
2698+
b % 1 == 0,
2699+
jnp.logical_not(
2700+
jnp.logical_and(
2701+
c % 1 == 0,
2702+
jnp.logical_and(
2703+
c <= 0,
2704+
c > b
2705+
)
2706+
)
2707+
)
2708+
)
2709+
), temp, b
2710+
)
2711+
2712+
def body(i, sum):
2713+
sum += (-1) ** i * _binom(jnp.abs(a), i) / _poch(c, i) * _poch(b, i) * x ** i
2714+
2715+
return sum
2716+
2717+
return lax.fori_loop(0., jnp.abs(a) + 1, body, 0.)
2718+
2719+
2720+
def _hyp2f1_serie(a, b, c, x):
2721+
"""
2722+
Compute the 2F1 hypergeometric function using the Taylor expansion.
2723+
See Eq. 4.1 from PEARSON, OLVER & PORTER 2014
2724+
https://doi.org/10.48550/arXiv.1407.7786
2725+
"""
2726+
precision = jnp.finfo(jnp.float32).eps
2727+
2728+
s = 1 - x
2729+
2730+
neg_int_a = jnp.logical_and(a <= 0, a % 1 == 0)
2731+
neg_int_b = jnp.logical_and(b <= 0, b % 1 == 0)
2732+
neg_int_c = jnp.logical_and(c <= 0, c % 1 == 0)
2733+
2734+
def body(state):
2735+
serie, k, term = state
2736+
serie += term
2737+
term = _poch(a, k) / _poch(c, k) * _poch(b, k) / factorial(k) * x ** k
2738+
k += 1
2739+
2740+
return serie, k, term
2741+
2742+
def cond(state):
2743+
serie, k, term = state
2744+
2745+
return (k < 250) & (lax.abs(term) / lax.abs(serie) > precision)
2746+
2747+
init = 0., 1., 1.
2748+
2749+
return lax.while_loop(cond, body, init)[0]
2750+
2751+
2752+
def _hyp2f1_terminal_or_serie(a, b, c, x):
2753+
"""
2754+
Check for recurrence relations along with whether or not the series
2755+
terminates. True recursion is not possible; however, the recurrence
2756+
relation may still be approximated.
2757+
See 4.6.1. Recurrence Relations from PEARSON, OLVER & PORTER 2014
2758+
https://doi.org/10.48550/arXiv.1407.7786
2759+
"""
2760+
neg_int_a = jnp.logical_and(a <= 0, a % 1 == 0)
2761+
neg_int_b = jnp.logical_and(b <= 0, b % 1 == 0)
2762+
neg_int_c = jnp.logical_and(c <= 0, c % 1 == 0)
2763+
neg_int_a_or_b = jnp.logical_or(neg_int_a, neg_int_b)
2764+
not_neg_int_a_or_b = jnp.logical_not(neg_int_a_or_b)
2765+
2766+
s = 1 - x
2767+
d = c - a - b
2768+
2769+
index = jnp.where(
2770+
jnp.logical_and(
2771+
neg_int_c,
2772+
jnp.logical_and(
2773+
jnp.logical_not(jnp.logical_and(neg_int_a, a > c)),
2774+
jnp.logical_not(jnp.logical_and(neg_int_b, b > c))
2775+
)
2776+
), 0,
2777+
jnp.where(jnp.logical_and(x < -0.5, not_neg_int_a_or_b),
2778+
jnp.where(b > a, 1, 2),
2779+
jnp.where(jnp.logical_and(x > 0.9, not_neg_int_a_or_b),
2780+
jnp.where(d % 1 != 0, 3, 4),
2781+
jnp.where(jnp.logical_and(jnp.logical_not(neg_int_c), neg_int_a_or_b), 5, 3))))
2782+
2783+
return lax.select_n(index,
2784+
jnp.array(jnp.inf, dtype=x.dtype),
2785+
s ** (-a) * _hyp2f1_serie(a, c - b, c, -x / s),
2786+
s ** (-b) * _hyp2f1_serie(c - a, b, c, -x / s),
2787+
_hyp2f1_serie(a, b, c, x),
2788+
_hyp2f1_digamma_transform(a, b, c, x),
2789+
_hyp2f1_terminal(a, b, c, x))
2790+
2791+
2792+
def _hyp2f1_gamma_transform(a, b, c, x):
2793+
"""
2794+
Gamma transformations of the 2F1 hypergeometric function.
2795+
"""
2796+
2797+
def transform_1():
2798+
"""
2799+
See Eq. 4.10 and Analytic Continuation Formulas from PEARSON, OLVER & PORTER 2014
2800+
https://doi.org/10.48550/arXiv.1407.7786
2801+
"""
2802+
p = _hyp2f1_serie(a, 1 - c + a, 1 - b + a, 1 / x)
2803+
q = _hyp2f1_serie(b, 1 - c + b, 1 - a + b, 1 / x)
2804+
p *= (-x) ** (-a)
2805+
q *= (-x) ** (-b)
2806+
t1 = gamma(c)
2807+
s = t1 * gamma(b - a) / (gamma(b) * gamma(c - a))
2808+
y = t1 * gamma(a - b) / (gamma(a) * gamma(c - b))
2809+
2810+
return s * p + y * q
2811+
2812+
def transform_2():
2813+
"""
2814+
See 4.1 Properties of F from PEARSON, OLVER & PORTER 2014
2815+
https://doi.org/10.48550/arXiv.1407.7786
2816+
"""
2817+
return gamma(c) * gamma(c - a - b) / (gamma(c - a) * gamma(c - b))
2818+
2819+
return jnp.where(
2820+
x < -2,
2821+
transform_1(),
2822+
transform_2()
2823+
)
2824+
2825+
2826+
def _hyp2f1_digamma_transform(a, b, c, x):
2827+
"""
2828+
Digamma transformation of the 2F1 hypergeometric function.
2829+
See AMS55 #15.3.10, #15.3.11, #15.3.12
2830+
"""
2831+
precision = jnp.finfo(jnp.float32).eps
2832+
2833+
d = c - a - b
2834+
s = 1 - x
2835+
id = jnp.round(d)
2836+
2837+
e = jnp.where(id >= 0, d, -d)
2838+
d1 = jnp.where(id >= 0, d, 0.)
2839+
d2 = jnp.where(id >= 0, 0., d)
2840+
aid = jnp.where(id >= 0, id, -id).astype('int32')
2841+
2842+
ax = jnp.log(s)
2843+
2844+
y = digamma(1.0) + digamma(1.0 + e) - digamma(a + d1) - digamma(b + d1) - ax
2845+
y /= gamma(e + 1.0)
2846+
2847+
p = (a + d1) * (b + d1) * s / gamma(e + 2.0)
2848+
2849+
def cond(state):
2850+
_, _, _, _, _, _, q, _, _, t, y = state
2851+
2852+
return jnp.logical_and(
2853+
t < 250,
2854+
jnp.logical_or(y == 0, jnp.abs(q / y) > precision)
2855+
)
2856+
2857+
def body(state):
2858+
a, ax, b, d1, e, p, q, r, s, t, y = state
2859+
2860+
r = digamma(1.0 + t) + digamma(1.0 + t + e) - digamma(a + t + d1) \
2861+
- digamma(b + t + d1) - ax
2862+
q = p * r
2863+
y += q
2864+
p *= s * (a + t + d1) / (t + 1.0)
2865+
p *= (b + t + d1) / (t + 1.0 + e)
2866+
t += 1.0
2867+
2868+
return a, ax, b, d1, e, p, q, r, s, t, y
2869+
2870+
init = a, ax, b, d1, e, p, y, 0.0, s, 1.0, y
2871+
_, _, _, _, _, _, q, r, _, _, y = lax.while_loop(cond, body, init)
2872+
2873+
def compute_sum(y):
2874+
y1 = 1.0
2875+
t = 0.0
2876+
p = 1.0
2877+
2878+
def for_body(i, state):
2879+
a, b, d2, e, p, s, t, y1 = state
2880+
2881+
r = 1.0 - e + t
2882+
p *= s * (a + t + d2) * (b + t + d2) / r
2883+
t += 1.0
2884+
p /= t
2885+
y1 += p
2886+
2887+
return a, b, d2, e, p, s, t, y1
2888+
2889+
init_val = a, b, d2, e, p, s, t, y1
2890+
y1 = lax.fori_loop(1, aid, for_body, init_val)[-1]
2891+
2892+
p = gamma(c)
2893+
y1 *= gamma(e) * p / (gamma(a + d1) * gamma(b + d1))
2894+
y *= p / (gamma(a + d2) * gamma(b + d2))
2895+
2896+
y = jnp.where((aid & 1) != 0, -y, y)
2897+
q = s ** id
2898+
2899+
return jnp.where(id > 0, y * q + y1, y + y1 * q)
2900+
2901+
return jnp.where(
2902+
id == 0,
2903+
y * gamma(c) / (gamma(a) * gamma(b)),
2904+
compute_sum(y)
2905+
)
2906+
2907+
2908+
@jit
2909+
@jnp.vectorize
2910+
def hyp2f1(a: ArrayLike, b: ArrayLike, c: ArrayLike, x: ArrayLike) -> Array:
2911+
r"""The 2F1 hypergeometric function.
2912+
2913+
JAX implementation of :obj:`scipy.special.hyp2f1`.
2914+
2915+
.. math::
2916+
2917+
\mathrm{hyp2f1}(a, b, c, x) = {}_2F_1(a; b; c; x) = \sum_{k=0}^\infty \frac{(a)_k(b)_k}{(c)_k}\frac{x^k}{k!}
2918+
2919+
where :math:`(\cdot)_k` is the Pochammer symbol.
2920+
2921+
The JAX version only accepts positive and real inputs. Values of
2922+
``a``, ``b``, ``c``, and ``x`` leading to high values of 2F1 may
2923+
lead to erroneous results; consider enabling double precision in this case.
2924+
2925+
Args:
2926+
a: arraylike, real-valued
2927+
b: arraylike, real-valued
2928+
c: arraylike, real-valued
2929+
x: arraylike, real-valued
2930+
2931+
Returns:
2932+
array of 2F1 values.
2933+
"""
2934+
# This is backed by https://doi.org/10.48550/arXiv.1407.7786
2935+
a, b, c, x = promote_args_inexact('hyp2f1', a, b, c, x)
2936+
2937+
d = c - a - b
2938+
s = 1 - x
2939+
2940+
neg_int_ca = jnp.logical_and(c - a <= 0, (c - a) % 1 == 0)
2941+
neg_int_cb = jnp.logical_and(c - b <= 0, (c - b) % 1 == 0)
2942+
neg_int_ca_or_cb = jnp.logical_or(neg_int_ca, neg_int_cb)
2943+
2944+
index = jnp.where(jnp.logical_or(x == 0, jnp.logical_and(jnp.logical_or(a == 0, b == 0), c != 0)), 0,
2945+
jnp.where(c == 0, 2,
2946+
jnp.where(jnp.logical_and(d <= -1, jnp.logical_not(jnp.logical_and(d % 1 != 0, s < 0))), 1,
2947+
jnp.where(jnp.logical_and(d <= 0, x == 1), 2,
2948+
jnp.where(jnp.logical_and(x < 1, b == c), 3,
2949+
jnp.where(jnp.logical_and(x < 1, a == c), 4,
2950+
jnp.where(x > 1, 2,
2951+
jnp.where(x == 1,
2952+
jnp.where(neg_int_ca_or_cb,
2953+
jnp.where(d >= 0, 5, 2),
2954+
jnp.where(d <= 0, 2, 6)),
2955+
jnp.where(d < 0, 7,
2956+
jnp.where(neg_int_ca_or_cb, 5, 7))))))))))
2957+
2958+
return lax.select_n(index,
2959+
jnp.array(1, dtype=x.dtype),
2960+
s ** d * _hyp2f1_terminal_or_serie(c - a, c - b, c, x),
2961+
jnp.array(jnp.inf, dtype=x.dtype),
2962+
s ** (-a),
2963+
s ** (-b),
2964+
s ** d * _hyp2f1_serie(c - a, c - b, c, x),
2965+
_hyp2f1_gamma_transform(a, b, c, x),
2966+
_hyp2f1_terminal_or_serie(a, b, c, x))
2967+
2968+
26402969
def softmax(x: ArrayLike,
26412970
/,
26422971
*,

jax/scipy/special.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
gammaln as gammaln,
3838
gammasgn as gammasgn,
3939
hyp1f1 as hyp1f1,
40+
hyp2f1 as hyp2f1,
4041
i0 as i0,
4142
i0e as i0e,
4243
i1 as i1,

tests/lax_scipy_special_functions_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,10 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t
157157
"hyp1f1", 3, float_dtypes,
158158
functools.partial(jtu.rand_uniform, low=0.5, high=30), True
159159
),
160+
op_record(
161+
"hyp2f1", 4, float_dtypes,
162+
functools.partial(jtu.rand_uniform, low=0.5, high=30), False
163+
),
160164
op_record("log_softmax", 1, float_dtypes, jtu.rand_default, True),
161165
op_record("softmax", 1, float_dtypes, jtu.rand_default, True),
162166
]

0 commit comments

Comments
 (0)