@@ -150,7 +150,12 @@ def get_session():
150150
151151
152152class AtlanClient (BaseSettings ):
153- _current_client_tls : ClassVar [local ] = local () # Thread-local storage (TLS)
153+ _current_client_ctx : ClassVar [ContextVar ] = ContextVar (
154+ "_current_client_ctx" , default = None
155+ )
156+ _401_has_retried_ctx : ClassVar [ContextVar ] = ContextVar (
157+ "_401_has_retried_ctx" , default = False
158+ )
154159 base_url : Union [Literal ["INTERNAL" ], HttpUrl ]
155160 api_key : str
156161 connect_timeout : float = 30.0 # 30 secs
@@ -190,37 +195,24 @@ class AtlanClient(BaseSettings):
190195 class Config :
191196 env_prefix = "atlan_"
192197
193- @classmethod
194- def init_for_multithreading (cls , client : AtlanClient ):
195- """
196- Prepares the given client for use in multi-threaded environments.
197-
198- This sets the thread-local context and resets internal retry flags
199- to ensure correct behavior when using the client across multiple threads.
200- """
201- AtlanClient .set_current_client (client )
202- client ._401_tls .has_retried = False
203-
204198 @classmethod
205199 def set_current_client (cls , client : AtlanClient ):
206200 """
207- Sets the current client to thread-local storage (TLS)
201+ Sets the current client context
208202 """
209203 if not isinstance (client , AtlanClient ):
210204 raise ErrorCode .MISSING_ATLAN_CLIENT .exception_with_parameters ()
211- cls ._current_client_tls . client = client
205+ cls ._current_client_ctx . set ( client )
212206
213207 @classmethod
214208 def get_current_client (cls ) -> AtlanClient :
215209 """
216- Retrieves the current client
210+ Retrieves the current client context
217211 """
218- if (
219- not hasattr (cls ._current_client_tls , "client" )
220- or not cls ._current_client_tls .client
221- ):
212+ client = cls ._current_client_ctx .get ()
213+ if not client :
222214 raise ErrorCode .NO_ATLAN_CLIENT_AVAILABLE .exception_with_parameters ()
223- return cls . _current_client_tls . client
215+ return client
224216
225217 def __init__ (self , ** data ):
226218 super ().__init__ (** data )
@@ -233,7 +225,8 @@ def __init__(self, **data):
233225 adapter = HTTPAdapter (max_retries = self .retry )
234226 session .mount (HTTPS_PREFIX , adapter )
235227 session .mount (HTTP_PREFIX , adapter )
236- AtlanClient .init_for_multithreading (self )
228+ AtlanClient .set_current_client (self )
229+ self ._401_has_retried_ctx .set (False )
237230
238231 @property
239232 def admin (self ) -> AdminClient :
@@ -482,11 +475,11 @@ def _call_api_internal(
482475 # - But if the next response is != 401 (e.g. 403), and `has_retried = True`,
483476 # then we should reset `has_retried = False` so that future 401s can trigger a new token refresh.
484477 if (
485- self ._401_tls . has_retried
478+ self ._401_has_retried_ctx . get ()
486479 and response .status_code
487480 != ErrorCode .AUTHENTICATION_PASSTHROUGH .http_error_code
488481 ):
489- self ._401_tls . has_retried = False
482+ self ._401_has_retried_ctx . set ( False )
490483
491484 if response .status_code == api .expected_status :
492485 try :
@@ -571,7 +564,7 @@ def _call_api_internal(
571564 # on authentication failure (token may have expired)
572565 if (
573566 self ._user_id
574- and not self ._401_tls . has_retried
567+ and not self ._401_has_retried_ctx . get ()
575568 and response .status_code
576569 == ErrorCode .AUTHENTICATION_PASSTHROUGH .http_error_code
577570 ):
@@ -732,7 +725,7 @@ def _handle_401_token_refresh(
732725 )
733726 raise
734727 self .api_key = new_token
735- self ._401_tls . has_retried = True
728+ self ._401_has_retried_ctx . set ( True )
736729 params ["headers" ]["authorization" ] = f"Bearer { self .api_key } "
737730 self ._request_params ["headers" ]["authorization" ] = f"Bearer { self .api_key } "
738731 LOGGER .debug ("Successfully completed 401 automatic token refresh." )
0 commit comments