"""Source code for component_injector."""

import contextvars
import functools
import inspect
from dataclasses import dataclass
from types import TracebackType
from typing import (
    Any,
    Callable,
    Dict,
    Optional,
    Type,
    TypeVar,
    cast,
    Set,
    List,
    Iterable,
    Awaitable,
    Tuple,
)

# Public API: only ``Injector`` is exported by ``from component_injector import *``.
__all__ = ["Injector"]


T = TypeVar("T")
# Maps a component type to the component instance stored for it. It is used
# unparameterized in annotations in this module, so ``T`` is effectively Any.
ComponentMap = Dict[Type[T], T]


@dataclass
class Factory:
    """
    A registration that can produce components for one or more types.

    ``factory`` is the zero-argument callable that builds the component, or
    ``None`` when a ready-made component was registered directly via
    ``Injector.register``. ``resolved_types`` holds every type the built
    component is stored under once materialized. ``context`` is the scope the
    component is materialized into; ``None`` means the scope that is active
    when the component is requested.
    """

    factory: Optional[Callable[[], Any]]
    resolved_types: Set[Type]
    context: Optional["Context"] = None


# Maps a type to the Factory responsible for it. A factory registered with
# ``bases=True`` can appear under several keys (one per resolved type).
FactoryMap = Dict[Type[T], Factory]

# Sentinel written into a ComponentStack layer to mark a key as deleted in that
# layer; it hides any value stored for the same key in lower layers.
UNSET = object()


class ComponentStack:
    """
    A stack of component layers that is searched from the top down.

    Writes always go to the top layer (``self.layer``). Deleting a key writes
    the ``UNSET`` marker into the top layer, which hides any value for that
    key in lower layers without modifying those layers.
    """

    __slots__ = ["_layers", "layer"]

    def __init__(self, layers: Optional[List[ComponentMap]] = None) -> None:
        """
        :param layers: Layers ordered from top (index 0) to bottom. Defaults to
            a single empty layer.
        """
        if layers is None:
            layers = [{}]
        self._layers = layers
        self.layer = layers[0]

    def stack(self) -> "ComponentStack":
        """
        Return a new stack with a fresh empty layer on top.

        The existing layer dicts are shared, not copied, so the new stack sees
        later writes made to this stack's layers.
        """
        return ComponentStack([{}, *self._layers])

    def __getitem__(self, key: Type[T]) -> T:
        """
        Return the topmost component stored for ``key``.

        :raises KeyError: If no layer has ``key``, or the topmost layer that has
            it holds the ``UNSET`` deletion marker.
        """
        for layer in self._layers:
            if key in layer:
                value = layer[key]
                if value is UNSET:
                    # A deletion marker shadows every lower layer.
                    raise KeyError(key)
                return cast(T, value)
        raise KeyError(key)

    def __contains__(self, key: object) -> bool:
        """
        Return whether ``key`` currently resolves to a component.

        Without this method, Python's ``in`` operator falls back to the legacy
        sequence protocol (``self[0]``, ``self[1]``, ...) because the class
        defines ``__getitem__``. That fallback raises ``KeyError(0)`` instead of
        answering the membership question.
        """
        for layer in self._layers:
            if key in layer:
                # Same shadowing rule as __getitem__: the topmost hit decides.
                return layer[key] is not UNSET
        return False

    def __setitem__(self, key: Type[T], value: T) -> None:
        """Store ``value`` for ``key`` in the top layer."""
        self.layer[key] = value

    def __delitem__(self, key: Type[T]) -> None:
        """
        Hide ``key`` from this stack by writing the ``UNSET`` marker into the top
        layer. This never raises, even if ``key`` was never stored.
        """
        self.layer[key] = UNSET

    def update(self, values: ComponentMap) -> None:
        """Store every item of ``values`` in the top layer."""
        self.layer.update(values)


@dataclass
class ContextData:
    """The state owned by one scope: its factory registry and component stack."""

    factories: FactoryMap
    components: ComponentStack

    def stack(self) -> "ContextData":
        """
        Derive the data for a child scope.

        The child gets a shallow copy of the factory registry, so registrations
        made in the child stay local to it. It also gets a component stack with
        a fresh layer on top of this scope's (shared) layers.
        """
        child_factories = self.factories.copy()
        child_components = self.components.stack()
        return ContextData(factories=child_factories, components=child_components)


class Context:
    """
    One injector scope: a factory registry plus a stack of components.

    Every context derived from the same root shares one
    ``contextvars.ContextVar`` that records which context is active. The
    :attr:`components` and :attr:`factories` properties therefore resolve
    against the *active* context, which is not necessarily ``self``.
    """

    __slots__ = ["_current_context", "_token_stack", "_data"]

    _data: "ContextData"

    def __init__(self, other: Optional["Context"] = None) -> None:
        """
        Create a root context, or a child scope when ``other`` is given.

        :param other: A context of an existing injector. The new context shares
            its context variables and stacks on top of the context that is
            active when this constructor runs (not necessarily ``other``).
        """
        if other is None:
            self._current_context = contextvars.ContextVar("Context", default=self)
            # The tokens returned by ``ContextVar.set`` are kept in a ContextVar
            # too, so every thread / asyncio task has its own LIFO stack. With
            # a plain list shared by all tasks, one task could pop a token that
            # another task created, and ``reset`` would then raise ValueError.
            # That happens, e.g., when two tasks await a factory inside
            # ``with context:``.
            self._token_stack = contextvars.ContextVar("ContextTokens", default=())
            self._data = ContextData({}, ComponentStack())
        else:
            self._current_context = other._current_context
            self._token_stack = other._token_stack
            self._data = self.current._data.stack()

    def __enter__(self) -> "Context":
        """Make this context the active one and return it."""
        token = self._current_context.set(self)
        self._token_stack.set(self._token_stack.get() + (token,))
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Restore the context that was active before the matching ``__enter__``."""
        tokens: Tuple[Any, ...] = self._token_stack.get()
        token = tokens[-1]  # IndexError if never entered, as before
        self._token_stack.set(tokens[:-1])
        self._current_context.reset(token)

    @property
    def current(self) -> "Context":
        """The context active in the calling thread / task (the root by default)."""
        return self._current_context.get()

    @property
    def components(self) -> "ComponentStack":
        """The component stack of the *active* context."""
        return self.current._data.components

    @property
    def factories(self) -> "FactoryMap":
        """The factory registry of the *active* context."""
        return self.current._data.factories


class Injector:
    """
    Provides a basic injector namespace. It's common to use one injector per
    project.
    """

    __slots__ = ["_context"]

    def __init__(self) -> None:
        # Root scope; every scope created by :meth:`scope` derives from it.
        self._context = Context()

    def _register_type_factory(
        self,
        type_: Type[T],
        factory_function: Optional[Callable[[], T]],
        *,
        bases: bool = True,
        overwrite_bases: bool = True,
        persistent: bool = True,
    ) -> Factory:
        """
        Register ``factory_function`` for ``type_`` in the active scope.

        :param type_: The exact type the factory produces.
        :param factory_function: Callable building the component, or ``None``
            when the caller stores a ready-made component itself.
        :param bases: Also register the factory for the classes in ``type_``'s
            MRO.
        :param overwrite_bases: Replace existing registrations (and cached
            components) for those base classes.
        :param persistent: Bind the factory to the active scope, so the
            component is materialized there rather than in whatever scope is
            active when it is requested.
        :return: The new :class:`Factory`. Its ``resolved_types`` lists every
            type it was registered under.
        """
        factories: FactoryMap = self._context.factories
        components: ComponentStack = self._context.components

        if persistent:
            factory = Factory(factory_function, {type_}, self._context.current)
        else:
            factory = Factory(factory_function, {type_})

        factories[type_] = factory
        # Always hide a component already cached for this exact type. Otherwise
        # the stale component would shadow the new registration whenever
        # ``bases`` or ``overwrite_bases`` is False.
        del components[type_]

        if bases:
            # ``type_.mro()`` fails for metaclasses such as ``type`` itself: it
            # resolves to the unbound ``type.mro`` descriptor. So read
            # ``__mro__`` for real classes, and keep ``mro()`` for anything
            # else (e.g. typing aliases that forward it to their origin class).
            mro = type_.__mro__ if isinstance(type_, type) else type_.mro()
            for base in mro:
                if base is type_ or not inspect.isclass(base):
                    continue
                if overwrite_bases or base not in factories:
                    factory.resolved_types.add(base)
                    factories[base] = factory
                    if overwrite_bases:
                        del components[base]

        return factory

    def register_factory(
        self,
        factory: Callable[[], Any],
        *,
        bases: bool = True,
        overwrite_bases: bool = True,
        persistent: bool = False,
    ) -> None:
        """
        Register a new factory function with the injector. Note that the
        factory function's return type annotation should be set to the type of
        the component you want to inject.

        :param factory: The factory function. Will be called without arguments
            and should return the instantiated component. If the factory
            returns an Awaitable, it can only be used to inject into coroutine
            functions.
        :param bases: Besides registering the exact component type, also
            register for all of the component's base classes. Defaults to
            `True`.
        :param overwrite_bases: If any of the component's base classes are
            already registered with the injector, overwrite those
            registrations. Defaults to `True`.
        :param persistent: When materializing the component using the
            factory, insert the component into the scope where the factory is
            registered instead of the current scope. Defaults to `False`.
        :raises AssertionError: If ``factory`` is not a class and has no return
            type annotation.
        """
        if inspect.isclass(factory):
            type_ = cast(Type[Any], factory)
        else:
            type_ = inspect.signature(factory).return_annotation
            # Raise explicitly instead of using ``assert``, so the check
            # survives ``python -O``. Without it the factory would be
            # registered under ``inspect.Signature.empty`` and, through that
            # class's MRO, under ``object``.
            if type_ is inspect.Signature.empty:
                raise AssertionError(
                    "Please add a return type annotation to your factory function."
                )
        self._register_type_factory(
            type_,
            factory,
            bases=bases,
            overwrite_bases=overwrite_bases,
            persistent=persistent,
        )

    def _get_factory_context(self, factory: Factory) -> Context:
        """Return the scope that ``factory``'s component is materialized into."""
        if factory.context is not None:
            # Persistent factory: always materialize where it was registered.
            return factory.context
        return self._context.current

    def _cache_component(self, factory: Factory, component: Any) -> None:
        """Store ``component`` in the active scope under every type ``factory`` resolves."""
        self._context.components.update(
            dict.fromkeys(factory.resolved_types, component)
        )

    def register(
        self, component: Any, *, bases: bool = True, overwrite_bases: bool = True
    ) -> None:
        """
        Register a new component with the injector.

        :param component: The component to register with the injector.
        :param bases: Besides registering the exact component type, also
            register for all of the component's base classes. Defaults to
            `True`.
        :param overwrite_bases: If any of the component's base classes are
            already registered with the injector, overwrite those
            registrations. Defaults to `True`.
        """
        factory = self._register_type_factory(
            type(component), None, bases=bases, overwrite_bases=overwrite_bases
        )
        with self._get_factory_context(factory):
            self._cache_component(factory, component)

    def get_component(self, type_: Type[T]) -> T:
        """
        Get a component from the injector's current scope. Materialize it
        using a factory if necessary. Note that it is an error to use this
        function to get a component that has a factory that returns an
        `Awaitable`.

        :param type_: The type of the component to return.
        :return: The materialized component.
        :raises KeyError: If no component or factory is registered for
            ``type_``.
        :raises AssertionError: If the factory returns an awaitable.
        """
        try:
            return self._context.components[type_]
        except KeyError:
            pass

        factory = self._context.factories[type_]
        # Internal invariant: only register() stores a None factory, and it
        # caches the component at the same time.
        assert factory.factory is not None
        with self._get_factory_context(factory):
            component = factory.factory()
            if inspect.isawaitable(component):
                if inspect.iscoroutine(component):
                    # Close it so Python doesn't also warn "never awaited".
                    component.close()
                # Explicit raise so ``python -O`` can't cache an un-awaited
                # coroutine as the component.
                raise AssertionError("Using an awaitable factory in synchronous code.")
            self._cache_component(factory, component)
        return cast(T, component)

    async def get_component_async(self, type_: Type[T]) -> T:
        """
        Get a component from the injector's current scope. Materialize it
        using a factory if necessary. Use this method if the component's
        factory function returns an `Awaitable`.

        :param type_: The type of the component to return.
        :return: The materialized component.
        :raises KeyError: If no component or factory is registered for
            ``type_``.
        """
        try:
            return self._context.components[type_]
        except KeyError:
            pass

        factory = self._context.factories[type_]
        assert factory.factory is not None  # see get_component
        # The factory's scope stays active across the await, so components the
        # factory requests resolve in that scope.
        with self._get_factory_context(factory):
            component_or_awaitable = factory.factory()
            if inspect.isawaitable(component_or_awaitable):
                component = await cast(Awaitable[T], component_or_awaitable)
            else:
                component = cast(T, component_or_awaitable)
            self._cache_component(factory, component)
        return component

    def scope(self) -> Context:
        """
        Return a context manager that you can use to enter a new scope. When
        leaving the scope, any components or factories added to the injector
        will be forgotten.

        :return: The scope context object. You can use this to re-enter this
            scope at a later time if needed.
        """
        return Context(self._context)

    def inject(self, f: Callable[..., T]) -> Callable[..., T]:
        """
        This decorator will connect the injector to a function or method.

        When the resulting function is called, the provided arguments are
        checked against the function's signature. Any missing argument whose
        annotation has a registered component or factory is filled in.
        ``*args`` / ``**kwargs`` parameters are never injected.

        :param f: The function or method to inject components into.
        :return: The decorated function. A coroutine function gets a coroutine
            wrapper that resolves components with :meth:`get_component_async`.
        """
        sig = inspect.signature(f)
        # A single component can't stand in for a variadic parameter:
        # BoundArguments would try to unpack it as the *args tuple / **kwargs
        # dict.
        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

        def bind_arguments(
            args: Iterable[Any], kwargs: Dict[str, Any]
        ) -> Tuple[inspect.BoundArguments, Dict[str, Any]]:
            """Bind the caller's arguments and map parameter names to inject to their types."""
            factories = self._context.factories
            bound = sig.bind_partial(*args, **kwargs)
            components: Dict[str, Any] = {}
            for name, param in sig.parameters.items():
                if (
                    name in bound.arguments
                    or param.kind in variadic
                    or param.annotation is inspect.Parameter.empty
                ):
                    continue
                if param.annotation in factories:
                    components[name] = param.annotation
            return bound, components

        @functools.wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> T:
            bound, bind_components = bind_arguments(args, kwargs)
            bound.arguments.update(
                {
                    name: self.get_component(type_)
                    for name, type_ in bind_components.items()
                }
            )
            bound.apply_defaults()
            return f(*bound.args, **bound.kwargs)

        @functools.wraps(f)
        async def async_wrapper(*args: Any, **kwargs: Any) -> T:
            bound, bind_components = bind_arguments(args, kwargs)
            bound.arguments.update(
                {
                    name: await self.get_component_async(type_)
                    for name, type_ in bind_components.items()
                }
            )
            bound.apply_defaults()
            return await cast(Awaitable[T], f(*bound.args, **bound.kwargs))

        if inspect.iscoroutinefunction(f):
            return async_wrapper
        return wrapper