1313 runtime_checkable ,
1414)
1515
16- from injection ._core .common .asynchronous import Caller , create_semaphore
16+ from injection ._core .common .asynchronous import Caller
17+ from injection ._core .common .asynchronous import (
18+ create_semaphore as _create_async_semaphore ,
19+ )
1720from injection ._core .scope import Scope , get_active_scopes , get_scope
1821from injection .exceptions import InjectionError
1922
@@ -39,12 +42,8 @@ def get_instance(self) -> T:
3942
4043
4144@dataclass (repr = False , eq = False , frozen = True , slots = True )
42- class BaseInjectable [R , T ](Injectable [T ], ABC ):
43- factory : Caller [..., R ]
44-
45-
46- class SimpleInjectable [T ](BaseInjectable [T , T ]):
47- __slots__ = ()
45+ class SimpleInjectable [T ](Injectable [T ]):
46+ factory : Caller [..., T ]
4847
4948 async def aget_instance (self ) -> T :
5049 return await self .factory .acall ()
@@ -53,13 +52,13 @@ def get_instance(self) -> T:
5352 return self .factory .call ()
5453
5554
56- @ dataclass ( repr = False , eq = False , frozen = True , slots = True )
57- class CachedInjectable [ R , T ]( BaseInjectable [ R , T ], ABC ):
58- __semaphore : AsyncContextManager [ Any ] = field (
59- default_factory = partial ( create_semaphore , 1 ),
60- init = False ,
61- hash = False ,
62- )
55+ class CacheLogic [ T ]:
56+ __slots__ = ( "__semaphore" ,)
57+
58+ __semaphore : AsyncContextManager [ Any ]
59+
60+ def __init__ ( self ) -> None :
61+ self . __semaphore = _create_async_semaphore ( 1 )
6362
6463 async def aget_or_create [K ](
6564 self ,
@@ -90,32 +89,37 @@ def get_or_create[K](
9089 return instance
9190
9291
93- class SingletonInjectable [T ](CachedInjectable [T , T ]):
94- __slots__ = ("__dict__" ,)
92+ @dataclass (repr = False , eq = False , frozen = True , slots = True )
93+ class SingletonInjectable [T ](Injectable [T ]):
94+ factory : Caller [..., T ]
95+ cache : MutableMapping [str , T ] = field (default_factory = dict )
96+ logic : CacheLogic [T ] = field (default_factory = CacheLogic )
9597
9698 __key : ClassVar [str ] = "$instance"
9799
98100 @property
99101 def is_locked (self ) -> bool :
100- return self .__key in self .__cache
101-
102- @property
103- def __cache (self ) -> MutableMapping [str , Any ]:
104- return self .__dict__
102+ return self .__key in self .cache
105103
106104 async def aget_instance (self ) -> T :
107- return await self .aget_or_create (self .__cache , self .__key , self .factory .acall )
105+ return await self .logic .aget_or_create (
106+ self .cache ,
107+ self .__key ,
108+ self .factory .acall ,
109+ )
108110
109111 def get_instance (self ) -> T :
110- return self .get_or_create (self .__cache , self .__key , self .factory .call )
112+ return self .logic . get_or_create (self .cache , self .__key , self .factory .call )
111113
112114 def unlock (self ) -> None :
113- self .__cache .pop (self .__key , None )
115+ self .cache .pop (self .__key , None )
114116
115117
116118@dataclass (repr = False , eq = False , frozen = True , slots = True )
117- class ScopedInjectable [R , T ](CachedInjectable [R , T ], ABC ):
119+ class ScopedInjectable [R , T ](Injectable [T ], ABC ):
120+ factory : Caller [..., R ]
118121 scope_name : str
122+ logic : CacheLogic [T ] = field (default_factory = CacheLogic )
119123
120124 @property
121125 def is_locked (self ) -> bool :
@@ -130,26 +134,26 @@ def build(self, scope: Scope) -> T:
130134 raise NotImplementedError
131135
132136 async def aget_instance (self ) -> T :
133- scope = self .get_scope ()
137+ scope = self .__get_scope ()
134138 factory = partial (self .abuild , scope )
135- return await self .aget_or_create (scope .cache , self , factory )
139+ return await self .logic . aget_or_create (scope .cache , self , factory )
136140
137141 def get_instance (self ) -> T :
138- scope = self .get_scope ()
142+ scope = self .__get_scope ()
139143 factory = partial (self .build , scope )
140- return self .get_or_create (scope .cache , self , factory )
141-
142- def get_scope (self ) -> Scope :
143- return get_scope (self .scope_name )
144+ return self .logic .get_or_create (scope .cache , self , factory )
144145
145146 def setdefault (self , instance : T ) -> T :
146- scope = self .get_scope ()
147- return self .get_or_create (scope .cache , self , lambda : instance )
147+ scope = self .__get_scope ()
148+ return self .logic . get_or_create (scope .cache , self , lambda : instance )
148149
149150 def unlock (self ) -> None :
150151 if self .is_locked :
151152 raise RuntimeError (f"To unlock, close the `{ self .scope_name } ` scope." )
152153
154+ def __get_scope (self ) -> Scope :
155+ return get_scope (self .scope_name )
156+
153157
154158class AsyncCMScopedInjectable [T ](ScopedInjectable [AsyncContextManager [T ], T ]):
155159 __slots__ = ()
0 commit comments