File tree Expand file tree Collapse file tree 2 files changed +37
-2
lines changed Expand file tree Collapse file tree 2 files changed +37
-2
lines changed Original file line number Diff line number Diff line change @@ -19,7 +19,7 @@ def is_array_api_obj(x):
1919 """
2020 return _is_numpy_array (x ) or hasattr (x , '__array_namespace__' )
2121
22- def get_namespace (* xs ):
22+ def get_namespace (* xs , _use_compat = True ):
2323 """
2424 Get the array API compatible namespace for the arrays `xs`.
2525
@@ -30,7 +30,10 @@ def get_namespace(*xs):
3030 if hasattr (x , '__array_namespace__' ):
3131 namespaces .add (x .__array_namespace__ )
3232 elif _is_numpy_array (x ):
33- namespaces .add (compat_namespace )
33+ if _use_compat :
34+ namespaces .add (compat_namespace )
35+ else :
36+ namespaces .add (np )
3437 else :
3538 # TODO: Support Python scalars?
3639 raise ValueError ("The input is not a supported array type" )
Original file line number Diff line number Diff line change 1+ """
2+ Internal helpers
3+ """
4+
5+ from functools import wraps
6+ from inspect import signature
7+
8+ from ._helpers import get_namespace
9+
10+ def get_xp (f ):
11+ """
12+ Decorator to automatically replace xp with the corresponding array module
13+
14+ Use like
15+
16+ @get_xp
17+ def func(x, /, xp, kwarg=None):
18+ return xp.func(x, kwarg=kwarg)
19+
20+ Note that xp must be able to be passed as a keyword argument.
21+ """
22+ @wraps (f )
23+ def inner (* args , ** kwargs ):
24+ xp = get_namespace (* args , _use_compat = False )
25+ return f (* args , xp = xp , ** kwargs )
26+
27+ sig = signature (f )
28+ new_sig = sig .replace (parameters = [sig .parameters [i ] for i in sig .parameters if i != 'xp' ])
29+
30+ inner .__signature__ = new_sig
31+
32+ return inner
You can’t perform that action at this time.
0 commit comments