4848from pyatlan .client .file import FileClient
4949from pyatlan .client .group import GroupClient
5050from pyatlan .client .impersonate import ImpersonationClient
51+ from pyatlan .client .oauth import OAuthTokenManager
5152from pyatlan .client .open_lineage import OpenLineageClient
5253from pyatlan .client .query import QueryClient
5354from pyatlan .client .role import RoleClient
@@ -127,7 +128,9 @@ def log_response(response, *args, **kwargs):
127128
128129class AtlanClient (BaseSettings ):
129130 base_url : Union [Literal ["INTERNAL" ], HttpUrl ]
130- api_key : str
131+ api_key : Optional [str ] = None
132+ oauth_client_id : Optional [str ] = None
133+ oauth_client_secret : Optional [str ] = None
131134 connect_timeout : float = 30.0 # 30 secs
132135 read_timeout : float = 900.0 # 15 mins
133136 retry : Retry = DEFAULT_RETRY
@@ -137,6 +140,7 @@ class AtlanClient(BaseSettings):
137140 _session : httpx .Client = PrivateAttr ()
138141 _request_params : dict = PrivateAttr ()
139142 _user_id : Optional [str ] = PrivateAttr (default = None )
143+ _oauth_token_manager : Optional [Any ] = PrivateAttr (default = None )
140144 _workflow_client : Optional [WorkflowClient ] = PrivateAttr (default = None )
141145 _credential_client : Optional [CredentialClient ] = PrivateAttr (default = None )
142146 _admin_client : Optional [AdminClient ] = PrivateAttr (default = None )
@@ -172,11 +176,33 @@ class Config:
172176
173177 def __init__ (self , ** data ):
174178 super ().__init__ (** data )
175- self ._request_params = (
176- {"headers" : {"authorization" : f"Bearer { self .api_key } " }}
177- if self .api_key and self .api_key .strip ()
178- else {"headers" : {}}
179- )
179+
180+ if self .oauth_client_id and self .oauth_client_secret and self .api_key is None :
181+ LOGGER .debug ("API KEY not provided. Using OAuth flow for authentication" )
182+
183+ final_base_url = self .base_url or os .environ .get (
184+ "ATLAN_BASE_URL" , "INTERNAL"
185+ )
186+ final_oauth_client_id = self .oauth_client_id or os .environ .get (
187+ "ATLAN_OAUTH_CLIENT_ID"
188+ )
189+ final_oauth_client_secret = self .oauth_client_secret or os .environ .get (
190+ "ATLAN_OAUTH_CLIENT_SECRET"
191+ )
192+ self ._oauth_token_manager = OAuthTokenManager (
193+ base_url = final_base_url ,
194+ client_id = final_oauth_client_id ,
195+ client_secret = final_oauth_client_secret ,
196+ connect_timeout = self .connect_timeout ,
197+ read_timeout = self .read_timeout ,
198+ )
199+ self ._request_params = {"headers" : {}}
200+ else :
201+ self ._request_params = (
202+ {"headers" : {"authorization" : f"Bearer { self .api_key } " }}
203+ if self .api_key and self .api_key .strip ()
204+ else {"headers" : {}}
205+ )
180206
181207 # Build proxy/SSL configuration with environment variable fallback
182208 transport_kwargs = self ._build_transport_proxy_config (data )
@@ -691,7 +717,7 @@ def _call_api_internal(
691717 # Retry with impersonation (if _user_id is present)
692718 # on authentication failure (token may have expired)
693719 if (
694- self ._user_id
720+ ( self ._user_id or self . _oauth_token_manager )
695721 and not self ._401_has_retried .get ()
696722 and response .status_code
697723 == ErrorCode .AUTHENTICATION_PASSTHROUGH .http_error_code
@@ -813,6 +839,9 @@ def _create_params(
813839 self , api : API , query_params , request_obj , exclude_unset : bool = True
814840 ):
815841 params = copy .deepcopy (self ._request_params )
842+ if self ._oauth_token_manager :
843+ token = self ._oauth_token_manager .get_token ()
844+ params ["headers" ]["authorization" ] = f"Bearer { token } "
816845 params ["headers" ]["Accept" ] = api .consumes
817846 params ["headers" ]["content-type" ] = api .produces
818847 if query_params is not None :
@@ -846,6 +875,21 @@ def _handle_401_token_refresh(
846875
847876 returns: HTTP response received after retrying the request with the refreshed token
848877 """
878+ if self ._oauth_token_manager :
879+ self ._oauth_token_manager .invalidate_token ()
880+ token = self ._oauth_token_manager .get_token ()
881+ params ["headers" ]["authorization" ] = f"Bearer { token } "
882+ self ._401_has_retried .set (True )
883+ LOGGER .debug ("Successfully refreshed OAuth token after 401." )
884+ return self ._call_api_internal (
885+ api ,
886+ path ,
887+ params ,
888+ binary_data = binary_data ,
889+ download_file_path = download_file_path ,
890+ text_response = text_response ,
891+ )
892+
849893 try :
850894 new_token = self .impersonate .user (user_id = self ._user_id )
851895 except Exception as e :
0 commit comments