@@ -73,6 +73,8 @@ def my_config(binder):
7373    inject.configure(my_config) 
7474
7575""" 
76+ import  contextlib 
77+ 
7678from  inject ._version  import  __version__ 
7779
7880import  inspect 
@@ -156,7 +158,10 @@ def bind_to_constructor(self, cls: Binding, constructor: Constructor) -> 'Binder
156158        return  self 
157159
158160    def  bind_to_provider (self , cls : Binding , provider : Provider ) ->  'Binder' :
159-         """Bind a class to a callable instance provider executed for each injection.""" 
161+         """ 
162+         Bind a class to a callable instance provider executed for each injection. 
163+         A provider can be a normal function or a context manager. Both sync and async are supported. 
164+         """ 
160165        self ._check_class (cls )
161166        if  provider  is  None :
162167            raise  InjectorException ('Provider cannot be None, key=%s'  %  cls )
@@ -323,6 +328,35 @@ class _ParametersInjection(Generic[T]):
323328    def  __init__ (self , ** kwargs : Any ) ->  None :
324329        self ._params  =  kwargs 
325330
331+     @staticmethod  
332+     def  _aggregate_sync_stack (
333+             sync_stack : contextlib .ExitStack ,
334+             provided_params : frozenset [str ],
335+             kwargs : dict [str , Any ]
336+     ) ->  None :
337+         """Extracts context managers, aggregate them in an ExitStack and swap out the param value with results of 
338+         running __enter__(). The result is equivalent to using `with` multiple times """ 
339+         executed_kwargs  =  {
340+             param : sync_stack .enter_context (inst )
341+             for  param , inst  in  kwargs .items ()
342+             if  param  not  in provided_params  and  isinstance (inst , contextlib ._GeneratorContextManager )
343+         }
344+         kwargs .update (executed_kwargs )
345+ 
346+     @staticmethod  
347+     async  def  _aggregate_async_stack (
348+             async_stack : contextlib .AsyncExitStack ,
349+             provided_params : frozenset [str ],
350+             kwargs : dict [str , Any ]
351+     ) ->  None :
352+         """Similar to _aggregate_sync_stack, but for async context managers""" 
353+         executed_kwargs  =  {
354+             param : await  async_stack .enter_async_context (inst )
355+             for  param , inst  in  kwargs .items ()
356+             if  param  not  in provided_params  and  isinstance (inst , contextlib ._AsyncGeneratorContextManager )
357+         }
358+         kwargs .update (executed_kwargs )
359+ 
326360    def  __call__ (self , func : Callable [..., Union [Awaitable [T ], T ]]) ->  Callable [..., Union [Awaitable [T ], T ]]:
327361        if  sys .version_info .major  ==  2 :
328362            arg_names  =  inspect .getargspec (func ).args 
@@ -340,7 +374,11 @@ async def async_injection_wrapper(*args: Any, **kwargs: Any) -> T:
340374                        kwargs [param ] =  instance (cls )
341375                async_func  =  cast (Callable [..., Awaitable [T ]], func )
342376                try :
343-                     return  await  async_func (* args , ** kwargs )
377+                     with  contextlib .ExitStack () as  sync_stack :
378+                         async  with  contextlib .AsyncExitStack () as  async_stack :
379+                             self ._aggregate_sync_stack (sync_stack , provided_params , kwargs )
380+                             await  self ._aggregate_async_stack (async_stack , provided_params , kwargs )
381+                             return  await  async_func (* args , ** kwargs )
344382                except  TypeError  as  previous_error :
345383                    raise  ConstructorTypeError (func , previous_error )
346384
@@ -355,7 +393,9 @@ def injection_wrapper(*args: Any, **kwargs: Any) -> T:
355393                    kwargs [param ] =  instance (cls )
356394            sync_func  =  cast (Callable [..., T ], func )
357395            try :
358-                 return  sync_func (* args , ** kwargs )
396+                 with  contextlib .ExitStack () as  sync_stack :
397+                     self ._aggregate_sync_stack (sync_stack , provided_params , kwargs )
398+                     return  sync_func (* args , ** kwargs )
359399            except  TypeError  as  previous_error :
360400                raise  ConstructorTypeError (func , previous_error )
361401        return  injection_wrapper 
0 commit comments