
Add custom derivative for scipy.special.hyp2f1 #30195

@mattbahr


The `scipy.special.hyp2f1` function was added to `jax.scipy.special`, but because its implementation uses `lax.while_loop`, it does not support reverse-mode autodiff (`jax.grad`): `lax.while_loop` can be differentiated in forward mode but has no reverse-mode rule. I think it would be beneficial to implement a custom derivative, similar to what has already been done for the `scipy.special.hyp1f1` function.
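The limitation can be reproduced with a minimal sketch. It assumes a toy Gauss-series implementation (the hypothetical `hyp2f1_while` below, for scalar float32 inputs with |x| < 1). This is not JAX's actual `hyp2f1`. Forward-mode `jax.jvp` differentiates through the `lax.while_loop`, while `jax.grad` raises a `ValueError`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def hyp2f1_while(a, b, c, x, tol=1e-12, max_terms=500):
    """Hypothetical toy 2F1 summed with lax.while_loop (scalar inputs, |x| < 1)."""
    a, b, c, x = (jnp.asarray(v, dtype=jnp.float32) for v in (a, b, c, x))

    def cond(state):
        n, term, _ = state
        # Stop once the latest term is negligible or the term budget is spent.
        return (n < max_terms) & (jnp.abs(term) > tol)

    def body(state):
        n, term, total = state
        # Ratio of consecutive series terms: (a+n)(b+n) / ((c+n)(n+1)) * x
        term = term * (a + n) * (b + n) / ((c + n) * (n + 1)) * x
        return n + 1, term, total + term

    one = jnp.ones_like(x)
    _, _, total = lax.while_loop(cond, body, (jnp.array(0, jnp.int32), one, one))
    return total


f = lambda x: hyp2f1_while(2.0, 3.0, 3.0, x)  # equals (1 - x)**-2 because b == c

# Forward mode can differentiate through lax.while_loop ...
_, dfdx = jax.jvp(f, (jnp.float32(0.3),), (jnp.float32(1.0),))

# ... but reverse mode cannot, because while_loop has no transpose rule.
try:
    jax.grad(f)(0.3)
    grad_works = True
except ValueError:
    grad_works = False
```

Here `dfdx` should be close to 2 * 0.7**-3, and `grad_works` ends up `False`.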

jax/jax/_src/scipy/special.py

Lines 2642 to 2646 in d39f29c

```python
hyp1f1.defjvps(
    lambda a_dot, primal_out, a, b, x: _hyp1f1_a_derivative(a, b, x) * a_dot,
    lambda b_dot, primal_out, a, b, x: _hyp1f1_b_derivative(a, b, x) * b_dot,
    lambda x_dot, primal_out, a, b, x: _hyp1f1_x_derivative(a, b, x) * x_dot
)
```
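A custom derivative for `hyp2f1` could follow the same `defjvps` pattern. The sketch below is illustrative only. It assumes a toy series via a hypothetical `_hyp2f1_series` helper, with scalar inputs, |x| < 1, and none of a, b, c a non-positive integer. The x-rule uses the standard identity d/dx 2F1(a, b; c; x) = (ab/c) 2F1(a+1, b+1; c+1; x). The a-, b- and c-rules sum the term-wise parameter derivatives of the series, playing a role similar to `_hyp1f1_a_derivative` and its siblings.

```python
import jax
import jax.numpy as jnp
from jax import lax


def _hyp2f1_series(a, b, c, x, tol=1e-12, max_terms=500):
    """Hypothetical helper: (2F1, d/da, d/db, d/dc) from one while_loop.

    Term n is t_n = (a)_n (b)_n / ((c)_n n!) x**n. Since
    d/da (a)_n = (a)_n * sum_{k<n} 1/(a+k), dt_n/da = t_n * S_a(n);
    likewise for b, and with a minus sign for c (it sits in the denominator).
    """
    a, b, c, x = (jnp.asarray(v, dtype=jnp.float32) for v in (a, b, c, x))
    zero, one = jnp.zeros_like(x), jnp.ones_like(x)

    def cond(state):
        n, term = state[0], state[1]
        return (n < max_terms) & (jnp.abs(term) > tol)

    def body(state):
        n, term, f, sa, sb, sc, fa, fb, fc = state
        # Running harmonic sums: S_p(n+1) = S_p(n) + 1/(p+n).
        sa, sb, sc = sa + 1 / (a + n), sb + 1 / (b + n), sc + 1 / (c + n)
        # Next series term from the ratio of consecutive terms.
        term = term * (a + n) * (b + n) / ((c + n) * (n + 1)) * x
        return (n + 1, term, f + term, sa, sb, sc,
                fa + term * sa, fb + term * sb, fc - term * sc)

    init = (jnp.array(0, jnp.int32), one, one) + (zero,) * 6
    _, _, f, _, _, _, fa, fb, fc = lax.while_loop(cond, body, init)
    return f, fa, fb, fc


@jax.custom_jvp
def hyp2f1(a, b, c, x):
    return _hyp2f1_series(a, b, c, x)[0]


# One JVP rule per argument, mirroring hyp1f1.defjvps.
hyp2f1.defjvps(
    lambda a_dot, primal_out, a, b, c, x: _hyp2f1_series(a, b, c, x)[1] * a_dot,
    lambda b_dot, primal_out, a, b, c, x: _hyp2f1_series(a, b, c, x)[2] * b_dot,
    lambda c_dot, primal_out, a, b, c, x: _hyp2f1_series(a, b, c, x)[3] * c_dot,
    # d/dx 2F1(a, b; c; x) = (a b / c) 2F1(a+1, b+1; c+1; x)
    lambda x_dot, primal_out, a, b, c, x: (
        a * b / c * _hyp2f1_series(a + 1, b + 1, c + 1, x)[0] * x_dot),
)

# Usage: with b == c, 2F1(a, b; b; x) = (1 - x)**-a, so the value, d/da and
# d/dx have closed forms, and d/db + d/dc should vanish.
val = hyp2f1(2.0, 3.0, 3.0, 0.3)
ga, gb, gc, gx = jax.grad(hyp2f1, argnums=(0, 1, 2, 3))(2.0, 3.0, 3.0, 0.3)
```

Because each rule only runs `lax.while_loop` on primal values, the output tangent is a plain product with the input tangent, which JAX can transpose; that is what lets `jax.grad` work here. The sketch recomputes the series once per rule, so a real implementation would probably share that work.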

An open issue in the PyTensor project suggests there is demand for this feature.


Labels: enhancement (New feature or request)
