Skip to content

Commit d88d2a3

Browse files
authored
refactoring: ♻️ injectables.py
1 parent 3709268 commit d88d2a3

2 files changed

Lines changed: 94 additions & 90 deletions

File tree

injection/_core/injectables.py

Lines changed: 38 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
runtime_checkable,
1414
)
1515

16-
from injection._core.common.asynchronous import Caller, create_semaphore
16+
from injection._core.common.asynchronous import Caller
17+
from injection._core.common.asynchronous import (
18+
create_semaphore as _create_async_semaphore,
19+
)
1720
from injection._core.scope import Scope, get_active_scopes, get_scope
1821
from injection.exceptions import InjectionError
1922

@@ -39,12 +42,8 @@ def get_instance(self) -> T:
3942

4043

4144
@dataclass(repr=False, eq=False, frozen=True, slots=True)
42-
class BaseInjectable[R, T](Injectable[T], ABC):
43-
factory: Caller[..., R]
44-
45-
46-
class SimpleInjectable[T](BaseInjectable[T, T]):
47-
__slots__ = ()
45+
class SimpleInjectable[T](Injectable[T]):
46+
factory: Caller[..., T]
4847

4948
async def aget_instance(self) -> T:
5049
return await self.factory.acall()
@@ -53,13 +52,13 @@ def get_instance(self) -> T:
5352
return self.factory.call()
5453

5554

56-
@dataclass(repr=False, eq=False, frozen=True, slots=True)
57-
class CachedInjectable[R, T](BaseInjectable[R, T], ABC):
58-
__semaphore: AsyncContextManager[Any] = field(
59-
default_factory=partial(create_semaphore, 1),
60-
init=False,
61-
hash=False,
62-
)
55+
class CacheLogic[T]:
56+
__slots__ = ("__semaphore",)
57+
58+
__semaphore: AsyncContextManager[Any]
59+
60+
def __init__(self) -> None:
61+
self.__semaphore = _create_async_semaphore(1)
6362

6463
async def aget_or_create[K](
6564
self,
@@ -90,32 +89,37 @@ def get_or_create[K](
9089
return instance
9190

9291

93-
class SingletonInjectable[T](CachedInjectable[T, T]):
94-
__slots__ = ("__dict__",)
92+
@dataclass(repr=False, eq=False, frozen=True, slots=True)
93+
class SingletonInjectable[T](Injectable[T]):
94+
factory: Caller[..., T]
95+
cache: MutableMapping[str, T] = field(default_factory=dict)
96+
logic: CacheLogic[T] = field(default_factory=CacheLogic)
9597

9698
__key: ClassVar[str] = "$instance"
9799

98100
@property
99101
def is_locked(self) -> bool:
100-
return self.__key in self.__cache
101-
102-
@property
103-
def __cache(self) -> MutableMapping[str, Any]:
104-
return self.__dict__
102+
return self.__key in self.cache
105103

106104
async def aget_instance(self) -> T:
107-
return await self.aget_or_create(self.__cache, self.__key, self.factory.acall)
105+
return await self.logic.aget_or_create(
106+
self.cache,
107+
self.__key,
108+
self.factory.acall,
109+
)
108110

109111
def get_instance(self) -> T:
110-
return self.get_or_create(self.__cache, self.__key, self.factory.call)
112+
return self.logic.get_or_create(self.cache, self.__key, self.factory.call)
111113

112114
def unlock(self) -> None:
113-
self.__cache.pop(self.__key, None)
115+
self.cache.pop(self.__key, None)
114116

115117

116118
@dataclass(repr=False, eq=False, frozen=True, slots=True)
117-
class ScopedInjectable[R, T](CachedInjectable[R, T], ABC):
119+
class ScopedInjectable[R, T](Injectable[T], ABC):
120+
factory: Caller[..., R]
118121
scope_name: str
122+
logic: CacheLogic[T] = field(default_factory=CacheLogic)
119123

120124
@property
121125
def is_locked(self) -> bool:
@@ -130,26 +134,26 @@ def build(self, scope: Scope) -> T:
130134
raise NotImplementedError
131135

132136
async def aget_instance(self) -> T:
133-
scope = self.get_scope()
137+
scope = self.__get_scope()
134138
factory = partial(self.abuild, scope)
135-
return await self.aget_or_create(scope.cache, self, factory)
139+
return await self.logic.aget_or_create(scope.cache, self, factory)
136140

137141
def get_instance(self) -> T:
138-
scope = self.get_scope()
142+
scope = self.__get_scope()
139143
factory = partial(self.build, scope)
140-
return self.get_or_create(scope.cache, self, factory)
141-
142-
def get_scope(self) -> Scope:
143-
return get_scope(self.scope_name)
144+
return self.logic.get_or_create(scope.cache, self, factory)
144145

145146
def setdefault(self, instance: T) -> T:
146-
scope = self.get_scope()
147-
return self.get_or_create(scope.cache, self, lambda: instance)
147+
scope = self.__get_scope()
148+
return self.logic.get_or_create(scope.cache, self, lambda: instance)
148149

149150
def unlock(self) -> None:
150151
if self.is_locked:
151152
raise RuntimeError(f"To unlock, close the `{self.scope_name}` scope.")
152153

154+
def __get_scope(self) -> Scope:
155+
return get_scope(self.scope_name)
156+
153157

154158
class AsyncCMScopedInjectable[T](ScopedInjectable[AsyncContextManager[T], T]):
155159
__slots__ = ()

0 commit comments

Comments
 (0)