Skip to content

Commit a1c5b7b

Browse files
authored
Merge pull request #92 from fzyukio/master
Allow provider to be a context manager (sync/async)
2 parents 5be9189 + f5272ca commit a1c5b7b

File tree

2 files changed

+150
-3
lines changed

2 files changed

+150
-3
lines changed

src/inject/__init__.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ def my_config(binder):
7373
inject.configure(my_config)
7474
7575
"""
76+
import contextlib
77+
7678
from inject._version import __version__
7779

7880
import inspect
@@ -156,7 +158,10 @@ def bind_to_constructor(self, cls: Binding, constructor: Constructor) -> 'Binder
156158
return self
157159

158160
def bind_to_provider(self, cls: Binding, provider: Provider) -> 'Binder':
159-
"""Bind a class to a callable instance provider executed for each injection."""
161+
"""
162+
Bind a class to a callable instance provider executed for each injection.
163+
A provider can be a normal function or a context manager. Both sync and async are supported.
164+
"""
160165
self._check_class(cls)
161166
if provider is None:
162167
raise InjectorException('Provider cannot be None, key=%s' % cls)
@@ -323,6 +328,35 @@ class _ParametersInjection(Generic[T]):
323328
def __init__(self, **kwargs: Any) -> None:
324329
self._params = kwargs
325330

331+
@staticmethod
332+
def _aggregate_sync_stack(
333+
sync_stack: contextlib.ExitStack,
334+
provided_params: frozenset[str],
335+
kwargs: dict[str, Any]
336+
) -> None:
337+
"""Extracts context managers, aggregate them in an ExitStack and swap out the param value with results of
338+
running __enter__(). The result is equivalent to using `with` multiple times """
339+
executed_kwargs = {
340+
param: sync_stack.enter_context(inst)
341+
for param, inst in kwargs.items()
342+
if param not in provided_params and isinstance(inst, contextlib._GeneratorContextManager)
343+
}
344+
kwargs.update(executed_kwargs)
345+
346+
@staticmethod
347+
async def _aggregate_async_stack(
348+
async_stack: contextlib.AsyncExitStack,
349+
provided_params: frozenset[str],
350+
kwargs: dict[str, Any]
351+
) -> None:
352+
"""Similar to _aggregate_sync_stack, but for async context managers"""
353+
executed_kwargs = {
354+
param: await async_stack.enter_async_context(inst)
355+
for param, inst in kwargs.items()
356+
if param not in provided_params and isinstance(inst, contextlib._AsyncGeneratorContextManager)
357+
}
358+
kwargs.update(executed_kwargs)
359+
326360
def __call__(self, func: Callable[..., Union[Awaitable[T], T]]) -> Callable[..., Union[Awaitable[T], T]]:
327361
if sys.version_info.major == 2:
328362
arg_names = inspect.getargspec(func).args
@@ -340,7 +374,11 @@ async def async_injection_wrapper(*args: Any, **kwargs: Any) -> T:
340374
kwargs[param] = instance(cls)
341375
async_func = cast(Callable[..., Awaitable[T]], func)
342376
try:
343-
return await async_func(*args, **kwargs)
377+
with contextlib.ExitStack() as sync_stack:
378+
async with contextlib.AsyncExitStack() as async_stack:
379+
self._aggregate_sync_stack(sync_stack, provided_params, kwargs)
380+
await self._aggregate_async_stack(async_stack, provided_params, kwargs)
381+
return await async_func(*args, **kwargs)
344382
except TypeError as previous_error:
345383
raise ConstructorTypeError(func, previous_error)
346384

@@ -355,7 +393,9 @@ def injection_wrapper(*args: Any, **kwargs: Any) -> T:
355393
kwargs[param] = instance(cls)
356394
sync_func = cast(Callable[..., T], func)
357395
try:
358-
return sync_func(*args, **kwargs)
396+
with contextlib.ExitStack() as sync_stack:
397+
self._aggregate_sync_stack(sync_stack, provided_params, kwargs)
398+
return sync_func(*args, **kwargs)
359399
except TypeError as previous_error:
360400
raise ConstructorTypeError(func, previous_error)
361401
return injection_wrapper

test/test_context_manager.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import contextlib
2+
3+
import inject
4+
from test import BaseTestInject
5+
6+
7+
class Destroyable:
8+
def __init__(self):
9+
self.started = True
10+
11+
def destroy(self):
12+
self.started = False
13+
14+
15+
class MockFile(Destroyable):
16+
...
17+
18+
19+
class MockConnection(Destroyable):
20+
...
21+
22+
23+
class MockFoo(Destroyable):
24+
...
25+
26+
27+
@contextlib.contextmanager
28+
def get_file_sync():
29+
obj = MockFile()
30+
yield obj
31+
obj.destroy()
32+
33+
34+
@contextlib.contextmanager
35+
def get_conn_sync():
36+
obj = MockConnection()
37+
yield obj
38+
obj.destroy()
39+
40+
41+
@contextlib.contextmanager
42+
def get_foo_sync():
43+
obj = MockFoo()
44+
yield obj
45+
obj.destroy()
46+
47+
48+
@contextlib.asynccontextmanager
49+
async def get_file_async():
50+
obj = MockFile()
51+
yield obj
52+
obj.destroy()
53+
54+
55+
@contextlib.asynccontextmanager
56+
async def get_conn_async():
57+
obj = MockConnection()
58+
yield obj
59+
obj.destroy()
60+
61+
62+
class TestContextManagerFunctional(BaseTestInject):
63+
64+
def test_provider_as_context_manager_sync(self):
65+
def config(binder):
66+
binder.bind_to_provider(MockFile, get_file_sync)
67+
binder.bind(int, 100)
68+
binder.bind_to_provider(str, lambda: "Hello")
69+
binder.bind_to_provider(MockConnection, get_conn_sync)
70+
71+
inject.configure(config)
72+
73+
@inject.autoparams()
74+
def mock_func(conn: MockConnection, name: str, f: MockFile, number: int):
75+
assert f.started
76+
assert conn.started
77+
assert name == "Hello"
78+
assert number == 100
79+
return f, conn
80+
81+
f_, conn_ = mock_func()
82+
assert not f_.started
83+
assert not conn_.started
84+
85+
def test_provider_as_context_manager_async(self):
86+
def config(binder):
87+
binder.bind_to_provider(MockFile, get_file_async)
88+
binder.bind(int, 100)
89+
binder.bind_to_provider(str, lambda: "Hello")
90+
binder.bind_to_provider(MockConnection, get_conn_async)
91+
binder.bind_to_provider(MockFoo, get_foo_sync)
92+
93+
inject.configure(config)
94+
95+
@inject.autoparams()
96+
async def mock_func(conn: MockConnection, name: str, f: MockFile, number: int, foo: MockFoo):
97+
assert f.started
98+
assert conn.started
99+
assert foo.started
100+
assert name == "Hello"
101+
assert number == 100
102+
return f, conn, foo
103+
104+
f_, conn_, foo_ = self.run_async(mock_func())
105+
assert not f_.started
106+
assert not conn_.started
107+
assert not foo_.started

0 commit comments

Comments
 (0)