Skip to content

Commit 1f1e8d8

Browse files
authored
refactor: Instantiation depends on the module requesting the instance
1 parent 2d699e4 commit 1f1e8d8

8 files changed

Lines changed: 368 additions & 278 deletions

File tree

injection/__init__.pyi

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@ from ._core.asfunction import AsFunctionWrappedType as _AsFunctionWrappedType
99
from ._core.common.invertible import Invertible as _Invertible
1010
from ._core.common.type import InputType as _InputType
1111
from ._core.common.type import TypeInfo as _TypeInfo
12-
from ._core.module import InjectableFactory as _InjectableFactory
13-
from ._core.module import ModeStr, PriorityStr
12+
from ._core.locator import InjectableFactory as _InjectableFactory
13+
from ._core.locator import ModeStr
14+
from ._core.module import PriorityStr
1415
from ._core.scope import ScopeKindStr
1516

1617
type _Decorator[T] = Callable[[T], T]

injection/_core/asfunction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def decorator(wp: AsFunctionWrappedType[P, T]) -> Callable[P, T]:
2323
factory: Caller[..., Callable[P, T]] = module.make_injected_function(
2424
wp,
2525
threadsafe=threadsafe,
26-
).__inject_metadata__
26+
).__injection_metadata__
2727

2828
wrapper: Callable[P, T]
2929

injection/_core/common/asynchronous.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,17 @@ async def acall(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
6464

6565
def call(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
6666
return self.callable(*args, **kwargs)
67+
68+
69+
@runtime_checkable
70+
class HiddenCaller[**P, T](Protocol):
71+
__slots__ = ()
72+
73+
@property
74+
@abstractmethod
75+
def __injection_hidden_caller__(self) -> Caller[P, T]:
76+
raise NotImplementedError
77+
78+
@abstractmethod
79+
def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
80+
raise NotImplementedError

injection/_core/locator.py

Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
from collections.abc import Awaitable, Callable, Collection, Iterable, Iterator
5+
from contextlib import suppress
6+
from dataclasses import dataclass, field
7+
from enum import StrEnum
8+
from inspect import iscoroutinefunction
9+
from typing import (
10+
Any,
11+
ContextManager,
12+
Literal,
13+
NamedTuple,
14+
Protocol,
15+
Self,
16+
runtime_checkable,
17+
)
18+
from weakref import WeakKeyDictionary
19+
20+
from injection._core.common.asynchronous import (
21+
AsyncCaller,
22+
Caller,
23+
HiddenCaller,
24+
SyncCaller,
25+
)
26+
from injection._core.common.event import Event, EventChannel, EventListener
27+
from injection._core.common.type import InputType
28+
from injection._core.injectables import Injectable
29+
from injection.exceptions import NoInjectable, SkipInjectable
30+
31+
32+
@dataclass(frozen=True, slots=True)
33+
class LocatorEvent(Event, ABC):
34+
locator: Locator
35+
36+
37+
@dataclass(frozen=True, slots=True)
38+
class LocatorDependenciesUpdated[T](LocatorEvent):
39+
classes: Collection[InputType[T]]
40+
mode: Mode
41+
42+
def __str__(self) -> str:
43+
length = len(self.classes)
44+
formatted_types = ", ".join(f"`{cls}`" for cls in self.classes)
45+
return (
46+
f"{length} dependenc{'ies' if length > 1 else 'y'} have been "
47+
f"updated{f': {formatted_types}' if formatted_types else ''}."
48+
)
49+
50+
51+
type InjectableFactory[T] = Callable[[Caller[..., T]], Injectable[T]]
52+
53+
type Recipe[**P, T] = Callable[P, T] | Callable[P, Awaitable[T]]
54+
55+
56+
class InjectionProvider(ABC):
57+
__slots__ = ("__weakref__",)
58+
59+
@abstractmethod
60+
def make_injected_function[**P, T](
61+
self,
62+
wrapped: Callable[P, T],
63+
/,
64+
) -> Callable[P, T]:
65+
raise NotImplementedError
66+
67+
68+
@runtime_checkable
69+
class InjectableBroker[T](Protocol):
70+
__slots__ = ()
71+
72+
@abstractmethod
73+
def get(self, provider: InjectionProvider) -> Injectable[T] | None:
74+
raise NotImplementedError
75+
76+
@abstractmethod
77+
def request(self, provider: InjectionProvider) -> Injectable[T]:
78+
raise NotImplementedError
79+
80+
81+
@dataclass(repr=False, eq=False, frozen=True, slots=True)
82+
class DynamicInjectableBroker[T](InjectableBroker[T]):
83+
injectable_factory: InjectableFactory[T]
84+
recipe: Recipe[..., T]
85+
cache: WeakKeyDictionary[InjectionProvider, Injectable[T]] = field(
86+
default_factory=WeakKeyDictionary,
87+
init=False,
88+
)
89+
90+
def get(self, provider: InjectionProvider) -> Injectable[T] | None:
91+
return self.cache.get(provider)
92+
93+
def request(self, provider: InjectionProvider) -> Injectable[T]:
94+
with suppress(KeyError):
95+
return self.cache[provider]
96+
97+
injectable = _make_injectable(
98+
self.injectable_factory,
99+
provider.make_injected_function(self.recipe), # type: ignore[misc]
100+
)
101+
self.cache[provider] = injectable
102+
return injectable
103+
104+
105+
@dataclass(repr=False, eq=False, frozen=True, slots=True)
106+
class StaticInjectableBroker[T](InjectableBroker[T]):
107+
value: Injectable[T]
108+
109+
def get(self, provider: InjectionProvider) -> Injectable[T] | None:
110+
return self.value
111+
112+
def request(self, provider: InjectionProvider) -> Injectable[T]:
113+
return self.value
114+
115+
@classmethod
116+
def from_factory(
117+
cls,
118+
injectable_factory: InjectableFactory[T],
119+
recipe: Recipe[..., T],
120+
) -> Self:
121+
return cls(_make_injectable(injectable_factory, recipe))
122+
123+
124+
class Mode(StrEnum):
125+
FALLBACK = "fallback"
126+
NORMAL = "normal"
127+
OVERRIDE = "override"
128+
129+
@property
130+
def rank(self) -> int:
131+
return tuple(type(self)).index(self)
132+
133+
@classmethod
134+
def get_default(cls) -> Mode:
135+
return cls.NORMAL
136+
137+
138+
type ModeStr = Literal["fallback", "normal", "override"]
139+
140+
141+
class Record[T](NamedTuple):
142+
broker: InjectableBroker[T]
143+
mode: Mode
144+
145+
146+
@dataclass(repr=False, eq=False, frozen=True, kw_only=True, slots=True)
147+
class Updater[T]:
148+
classes: Collection[InputType[T]]
149+
broker: InjectableBroker[T]
150+
mode: Mode
151+
152+
def make_record(self) -> Record[T]:
153+
return Record(self.broker, self.mode)
154+
155+
156+
@dataclass(repr=False, frozen=True, slots=True)
157+
class Locator:
158+
__records: dict[InputType[Any], Record[Any]] = field(
159+
default_factory=dict,
160+
init=False,
161+
)
162+
__channel: EventChannel = field(
163+
default_factory=EventChannel,
164+
init=False,
165+
)
166+
167+
def __contains__(self, cls: InputType[Any], /) -> bool:
168+
return cls in self.__records
169+
170+
@property
171+
def __brokers(self) -> frozenset[InjectableBroker[Any]]:
172+
return frozenset(record.broker for record in self.__records.values())
173+
174+
def is_locked(self, provider: InjectionProvider) -> bool:
175+
return any(
176+
injectable.is_locked for injectable in self.__iter_injectables(provider)
177+
)
178+
179+
def request[T](
180+
self,
181+
cls: InputType[T],
182+
/,
183+
provider: InjectionProvider,
184+
) -> Injectable[T]:
185+
try:
186+
record = self.__records[cls]
187+
except KeyError as exc:
188+
raise NoInjectable(cls) from exc
189+
else:
190+
return record.broker.request(provider)
191+
192+
def update[T](self, updater: Updater[T]) -> Self:
193+
record = updater.make_record()
194+
records = dict(self.__prepare_for_updating(updater.classes, record))
195+
196+
if records:
197+
event = LocatorDependenciesUpdated(self, records.keys(), record.mode)
198+
199+
with self.dispatch(event):
200+
self.__records.update(records)
201+
202+
return self
203+
204+
def unlock(self, provider: InjectionProvider) -> None:
205+
for injectable in self.__iter_injectables(provider):
206+
injectable.unlock()
207+
208+
async def all_ready(self, provider: InjectionProvider) -> None:
209+
for injectable in self.__iter_injectables(provider):
210+
if injectable.is_locked:
211+
continue
212+
213+
with suppress(SkipInjectable):
214+
await injectable.aget_instance()
215+
216+
def add_listener(self, listener: EventListener) -> Self:
217+
self.__channel.add_listener(listener)
218+
return self
219+
220+
def dispatch(self, event: Event) -> ContextManager[None]:
221+
return self.__channel.dispatch(event)
222+
223+
def __iter_injectables(
224+
self,
225+
provider: InjectionProvider,
226+
) -> Iterator[Injectable[Any]]:
227+
for broker in self.__brokers:
228+
injectable = broker.get(provider)
229+
230+
if injectable is None:
231+
continue
232+
233+
yield injectable
234+
235+
def __prepare_for_updating[T](
236+
self,
237+
classes: Iterable[InputType[T]],
238+
record: Record[T],
239+
) -> Iterator[tuple[InputType[T], Record[T]]]:
240+
for cls in classes:
241+
try:
242+
existing = self.__records[cls]
243+
except KeyError:
244+
...
245+
else:
246+
if not self.__keep_new_record(record, existing, cls):
247+
continue
248+
249+
yield cls, record
250+
251+
@staticmethod
252+
def __keep_new_record[T](
253+
new: Record[T],
254+
existing: Record[T],
255+
cls: InputType[T],
256+
) -> bool:
257+
new_mode, existing_mode = new.mode, existing.mode
258+
259+
if new_mode == Mode.OVERRIDE:
260+
return True
261+
262+
elif new_mode == existing_mode:
263+
raise RuntimeError(f"An injectable already exists for the class `{cls}`.")
264+
265+
return new_mode.rank > existing_mode.rank
266+
267+
268+
def _extract_caller[**P, T](
269+
function: Callable[P, T] | Callable[P, Awaitable[T]],
270+
) -> Caller[P, T]:
271+
if iscoroutinefunction(function):
272+
return AsyncCaller(function)
273+
274+
elif isinstance(function, HiddenCaller):
275+
return function.__injection_hidden_caller__
276+
277+
return SyncCaller(function) # type: ignore[arg-type]
278+
279+
280+
def _make_injectable[T](
281+
injectable_factory: InjectableFactory[T],
282+
recipe: Recipe[..., T],
283+
) -> Injectable[T]:
284+
return injectable_factory(_extract_caller(recipe))

0 commit comments

Comments
 (0)