33import asyncio
44import time
55from typing import Optional
6+ from urllib .parse import urljoin
67
78import httpx
89from authlib .oauth2 .rfc6749 import OAuth2Token
1213
1314
1415class AsyncOAuthTokenManager :
16+ """
17+ Manages OAuth tokens for asynchronous HTTP clients.
18+
19+ :param base_url: Base URL of the Atlan tenant.
20+ :param client_id: OAuth client ID.
21+ :param client_secret: OAuth client secret.
22+ :param http_client: Optional asynchronous HTTP client to use.
23+ :param connect_timeout: Timeout for establishing connections.
24+ :param read_timeout: Timeout for reading data.
25+ :param write_timeout: Timeout for writing data.
26+ :param pool_timeout: Timeout for acquiring a connection from the pool.
27+ """
28+
1529 def __init__ (
1630 self ,
1731 base_url : str ,
@@ -20,6 +34,8 @@ def __init__(
2034 http_client : Optional [httpx .AsyncClient ] = None ,
2135 connect_timeout : float = 30.0 ,
2236 read_timeout : float = 900.0 ,
37+ write_timeout : float = 30.0 ,
38+ pool_timeout : float = 30.0 ,
2339 ):
2440 self .base_url = base_url
2541 self .client_id = client_id
@@ -28,13 +44,19 @@ def __init__(
2844 self ._lock = asyncio .Lock ()
2945 self ._http_client = http_client or httpx .AsyncClient (
3046 timeout = httpx .Timeout (
31- connect = connect_timeout , read = read_timeout , write = 30.0 , pool = 30.0
47+ connect = connect_timeout ,
48+ read = read_timeout ,
49+ write = write_timeout ,
50+ pool = pool_timeout ,
3251 )
3352 )
3453 self ._token : Optional [OAuth2Token ] = None
3554 self ._owns_client = http_client is None
3655
3756 async def get_token (self ) -> str :
57+ """
58+ Retrieves a valid OAuth token, refreshing it if necessary.
59+ """
3860 async with self ._lock :
3961 if self ._token and not self ._token .is_expired ():
4062 return str (self ._token ["access_token" ])
@@ -73,18 +95,25 @@ async def get_token(self) -> str:
7395 return access_token
7496
7597 async def invalidate_token (self ):
98+ """
99+ Invalidates the current OAuth token.
100+ """
76101 async with self ._lock :
77102 self ._token = None
78103
79104 def _create_path (self , api : API ):
80- from urllib .parse import urljoin
81-
105+ """
106+ Creates the full URL for the given API endpoint.
107+ """
82108 if self .base_url == "INTERNAL" :
83109 return urljoin (api .endpoint .service , api .path )
84110 else :
85111 base_with_prefix = urljoin (self .base_url , api .endpoint .prefix )
86112 return urljoin (base_with_prefix , api .path )
87113
88114 async def aclose (self ):
115+ """
116+ Closes the underlying HTTP client if owned by this manager.
117+ """
89118 if self ._owns_client :
90119 await self ._http_client .aclose ()
0 commit comments