Closed
Labels
enhancement (New feature or request)
Description
The `scipy.special.hyp2f1` function was added here, but because it is implemented with `lax.while_loop`, it does not work with autodiff: `lax.while_loop` is not reverse-mode differentiable, so `jax.grad` cannot be applied to it. I think it would be beneficial to implement a custom derivative, as has already been done for the `scipy.special.hyp1f1` function:
Lines 2642 to 2646 in d39f29c
```python
hyp1f1.defjvps(
    lambda a_dot, primal_out, a, b, x: _hyp1f1_a_derivative(a, b, x) * a_dot,
    lambda b_dot, primal_out, a, b, x: _hyp1f1_b_derivative(a, b, x) * b_dot,
    lambda x_dot, primal_out, a, b, x: _hyp1f1_x_derivative(a, b, x) * x_dot
)
```
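A minimal sketch of the same idea applied to hyp2f1, assuming a toy implementation that only sums the Gauss series for |x| < 1 with generic parameters; the real `jax.scipy.special.hyp2f1` covers more of the domain and its internals are not reproduced here. `hyp2f1` and `_hyp2f1_series` below are hypothetical names. The x-derivative uses the standard identity d/dx 2F1(a, b; c; x) = (ab/c) 2F1(a+1, b+1; c+1; x), while the a, b and c derivatives differentiate the series term by term via d/da log (a)_n = sum over k < n of 1/(a + k).

```python
import jax
import jax.numpy as jnp
from jax import lax

# Assumption for this sketch: float64, so the truncated series is accurate to ~1e-13.
jax.config.update("jax_enable_x64", True)


def _hyp2f1_series(a, b, c, x, max_terms=1000, tol=1e-14):
    """Hypothetical helper: sums the Gauss series for 2F1(a, b; c; x) and its
    partial derivatives w.r.t. a, b and c. Only valid for |x| < 1 and generic
    parameters (no poles, no analytic continuation)."""
    a, b, c, x = jnp.broadcast_arrays(
        *(jnp.asarray(v, dtype=float) for v in (a, b, c, x)))

    def cond(state):
        n, term = state[0], state[1]
        return (n < max_terms) & jnp.any(jnp.abs(term) > tol)

    def body(state):
        n, term, f, sa, sb, sc, da, db, dc = state
        # Add t_n and its derivative contributions; sa holds
        # d/da log((a)_n) = sum_{k<n} 1 / (a + k), likewise sb and sc.
        f = f + term
        da = da + term * sa
        db = db + term * sb
        dc = dc - term * sc  # (c)_n sits in the denominator, hence the minus sign
        sa = sa + 1.0 / (a + n)
        sb = sb + 1.0 / (b + n)
        sc = sc + 1.0 / (c + n)
        # t_{n+1} = t_n * (a + n)(b + n) / ((c + n)(n + 1)) * x
        term = term * (a + n) * (b + n) / ((c + n) * (n + 1)) * x
        return n + 1, term, f, sa, sb, sc, da, db, dc

    zero = jnp.zeros_like(x)
    init = (jnp.int32(0), jnp.ones_like(x)) + (zero,) * 7
    _, _, f, _, _, _, da, db, dc = lax.while_loop(cond, body, init)
    return f, da, db, dc


@jax.custom_jvp
def hyp2f1(a, b, c, x):
    """Hypothetical series-only 2F1 (not jax.scipy.special.hyp2f1)."""
    return _hyp2f1_series(a, b, c, x)[0]


@hyp2f1.defjvp
def _hyp2f1_jvp(primals, tangents):
    a, b, c, x = primals
    a_dot, b_dot, c_dot, x_dot = tangents
    # A single while_loop yields the value and all three parameter derivatives.
    f, da, db, dc = _hyp2f1_series(a, b, c, x)
    # Known identity: d/dx 2F1(a, b; c; x) = (a*b/c) * 2F1(a+1, b+1; c+1; x)
    dx = a * b / c * _hyp2f1_series(a + 1, b + 1, c + 1, x)[0]
    # The tangent is linear in the input tangents, so reverse mode can transpose it;
    # the while_loop only ever runs on primal values.
    return f, da * a_dot + db * b_dot + dc * c_dot + dx * x_dot
```

For example, `jax.grad(hyp2f1, argnums=(0, 1, 2, 3))(0.5, 1.5, 1.5, 0.3)` can be checked against the closed form 2F1(a, b; b; x) = (1 - x)^(-a). Using `defjvp` instead of the per-argument `defjvps` form quoted above lets one loop produce all three parameter derivatives; reverse-mode second derivatives would likely still fail, because the rule itself calls `lax.while_loop`.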
Judging by this open issue in PyTensor, there is a real need for it.