Skip to content

Commit 6bd8de7

Browse files
authored
Merge pull request #767 from atlanhq/APP-8642
APP-8642 : Added support for OAuth client to pyatlan
2 parents a31cfb0 + fb69b17 commit 6bd8de7

12 files changed

Lines changed: 1942 additions & 14 deletions

File tree

.env.example

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,8 @@
11
ATLAN_BASE_URL=your_tenant_base_url
2+
3+
#API KEY based authentication
24
ATLAN_API_KEY=your_api_key
5+
6+
#OAuth based authentication
7+
ATLAN_OAUTH_CLIENT_ID=your_oauth_client_id
8+
ATLAN_OAUTH_CLIENT_SECRET=your_oauth_client_secret

pyatlan/client/aio/client.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from contextlib import _AsyncGeneratorContextManager
1717
from http import HTTPStatus
1818
from types import SimpleNamespace
19-
from typing import Optional
19+
from typing import Any, Optional
2020

2121
import httpx
2222
from httpx_retries.retry import Retry
@@ -41,6 +41,7 @@
4141
from pyatlan.client.aio.file import AsyncFileClient
4242
from pyatlan.client.aio.group import AsyncGroupClient
4343
from pyatlan.client.aio.impersonate import AsyncImpersonationClient
44+
from pyatlan.client.aio.oauth import AsyncOAuthTokenManager
4445
from pyatlan.client.aio.open_lineage import AsyncOpenLineageClient
4546
from pyatlan.client.aio.query import AsyncQueryClient
4647
from pyatlan.client.aio.role import AsyncRoleClient
@@ -90,6 +91,7 @@ class AsyncAtlanClient(AtlanClient):
9091
"""
9192

9293
_async_session: Optional[httpx.AsyncClient] = PrivateAttr(default=None)
94+
_async_oauth_token_manager: Optional[Any] = PrivateAttr(default=None)
9395
_async_admin_client: Optional[AsyncAdminClient] = PrivateAttr(default=None)
9496
_async_asset_client: Optional[AsyncAssetClient] = PrivateAttr(default=None)
9597
_async_audit_client: Optional[AsyncAuditClient] = PrivateAttr(default=None)
@@ -133,6 +135,31 @@ class AsyncAtlanClient(AtlanClient):
133135
def __init__(self, **kwargs):
134136
# Initialize sync client (handles all validation, env vars, etc.)
135137
super().__init__(**kwargs)
138+
if self.oauth_client_id and self.oauth_client_secret and self.api_key is None:
139+
LOGGER.debug(
140+
"API Key not provided. Using Async OAuth flow for authentication"
141+
)
142+
if self._oauth_token_manager:
143+
LOGGER.debug("Sync oauth flow open. Closing it for Async oauth flow")
144+
self._oauth_token_manager.close()
145+
self._oauth_token_manager = None
146+
147+
final_base_url = self.base_url or os.environ.get(
148+
"ATLAN_BASE_URL", "INTERNAL"
149+
)
150+
final_oauth_client_id = self.oauth_client_id or os.environ.get(
151+
"ATLAN_OAUTH_CLIENT_ID"
152+
)
153+
final_oauth_client_secret = self.oauth_client_secret or os.environ.get(
154+
"ATLAN_OAUTH_CLIENT_SECRET"
155+
)
156+
self._async_oauth_token_manager = AsyncOAuthTokenManager(
157+
base_url=final_base_url,
158+
client_id=final_oauth_client_id,
159+
client_secret=final_oauth_client_secret,
160+
connect_timeout=self.connect_timeout,
161+
read_timeout=self.read_timeout,
162+
)
136163

137164
# Build proxy/SSL configuration (reuse from sync client)
138165
transport_kwargs = self._build_transport_proxy_config(kwargs)
@@ -438,6 +465,9 @@ async def _create_params(
438465
Async version of _create_params that uses AsyncAtlanRequest for AtlanObject instances.
439466
"""
440467
params = copy.deepcopy(self._request_params)
468+
if self._async_oauth_token_manager:
469+
token = await self._async_oauth_token_manager.get_token()
470+
params["headers"]["authorization"] = f"Bearer {token}"
441471
params["headers"]["Accept"] = api.consumes
442472
params["headers"]["content-type"] = api.produces
443473
if query_params is not None:
@@ -687,7 +717,7 @@ async def _handle_error_response(
687717

688718
# Retry with impersonation (if _user_id is present) on authentication failure
689719
if (
690-
self._user_id
720+
(self._user_id or self._async_oauth_token_manager)
691721
and not self._401_has_retried.get()
692722
and response.status_code
693723
== ErrorCode.AUTHENTICATION_PASSTHROUGH.http_error_code
@@ -746,6 +776,21 @@ async def _handle_401_token_refresh(
746776
Async version of token refresh and retry logic.
747777
Handles token refresh and retries the API request upon a 401 Unauthorized response.
748778
"""
779+
if self._async_oauth_token_manager:
780+
await self._async_oauth_token_manager.invalidate_token()
781+
token = await self._async_oauth_token_manager.get_token()
782+
params["headers"]["authorization"] = f"Bearer {token}"
783+
self._401_has_retried.set(True)
784+
LOGGER.debug("Successfully refreshed OAuth token after 401.")
785+
return await self._call_api_internal(
786+
api,
787+
path,
788+
params,
789+
binary_data=binary_data,
790+
download_file_path=download_file_path,
791+
text_response=text_response,
792+
)
793+
749794
try:
750795
# Use sync impersonation call since it's a quick API call
751796
new_token = await self.impersonate.user(user_id=self._user_id)

pyatlan/client/aio/oauth.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# Copyright 2025 Atlan Pte. Ltd.
3+
import asyncio
4+
import time
5+
from typing import Optional
6+
from urllib.parse import urljoin
7+
8+
import httpx
9+
from authlib.oauth2.rfc6749 import OAuth2Token
10+
11+
from pyatlan.client.constants import GET_OAUTH_CLIENT
12+
from pyatlan.utils import API
13+
14+
15+
class AsyncOAuthTokenManager:
16+
"""
17+
Manages OAuth tokens for asynchronous HTTP clients.
18+
19+
:param base_url: Base URL of the Atlan tenant.
20+
:param client_id: OAuth client ID.
21+
:param client_secret: OAuth client secret.
22+
:param http_client: Optional asynchronous HTTP client to use.
23+
:param connect_timeout: Timeout for establishing connections.
24+
:param read_timeout: Timeout for reading data.
25+
:param write_timeout: Timeout for writing data.
26+
:param pool_timeout: Timeout for acquiring a connection from the pool.
27+
"""
28+
29+
def __init__(
30+
self,
31+
base_url: str,
32+
client_id: str,
33+
client_secret: str,
34+
http_client: Optional[httpx.AsyncClient] = None,
35+
connect_timeout: float = 30.0,
36+
read_timeout: float = 900.0,
37+
write_timeout: float = 30.0,
38+
pool_timeout: float = 30.0,
39+
):
40+
self.base_url = base_url
41+
self.client_id = client_id
42+
self.client_secret = client_secret
43+
self.token_url = self._create_path(GET_OAUTH_CLIENT)
44+
self._lock = asyncio.Lock()
45+
self._http_client = http_client or httpx.AsyncClient(
46+
timeout=httpx.Timeout(
47+
connect=connect_timeout,
48+
read=read_timeout,
49+
write=write_timeout,
50+
pool=pool_timeout,
51+
)
52+
)
53+
self._token: Optional[OAuth2Token] = None
54+
self._owns_client = http_client is None
55+
56+
async def get_token(self) -> str:
57+
"""
58+
Retrieves a valid OAuth token, refreshing it if necessary.
59+
"""
60+
async with self._lock:
61+
if self._token and not self._token.is_expired():
62+
return str(self._token["access_token"])
63+
64+
response = await self._http_client.post(
65+
self.token_url,
66+
json={
67+
"clientId": self.client_id,
68+
"clientSecret": self.client_secret,
69+
},
70+
headers={"Content-Type": "application/json"},
71+
)
72+
response.raise_for_status()
73+
74+
data = response.json()
75+
access_token = data.get("accessToken") or data.get("access_token")
76+
77+
if not access_token:
78+
raise ValueError(
79+
f"OAuth token response missing 'accessToken' field. "
80+
f"Response keys: {list(data.keys())}"
81+
)
82+
83+
expires_in = data.get("expiresIn") or data.get("expires_in", 600)
84+
85+
self._token = OAuth2Token(
86+
{
87+
"access_token": access_token,
88+
"token_type": data.get("tokenType")
89+
or data.get("token_type", "Bearer"),
90+
"expires_in": expires_in,
91+
"expires_at": int(time.time()) + expires_in,
92+
}
93+
)
94+
95+
return access_token
96+
97+
async def invalidate_token(self):
98+
"""
99+
Invalidates the current OAuth token.
100+
"""
101+
async with self._lock:
102+
self._token = None
103+
104+
def _create_path(self, api: API):
105+
"""
106+
Creates the full URL for the given API endpoint.
107+
"""
108+
if self.base_url == "INTERNAL":
109+
return urljoin(api.endpoint.service, api.path)
110+
else:
111+
base_with_prefix = urljoin(self.base_url, api.endpoint.prefix)
112+
return urljoin(base_with_prefix, api.path)
113+
114+
async def aclose(self):
115+
"""
116+
Closes the underlying HTTP client if owned by this manager.
117+
"""
118+
if self._owns_client:
119+
await self._http_client.aclose()

pyatlan/client/atlan.py

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from pyatlan.client.file import FileClient
4949
from pyatlan.client.group import GroupClient
5050
from pyatlan.client.impersonate import ImpersonationClient
51+
from pyatlan.client.oauth import OAuthTokenManager
5152
from pyatlan.client.open_lineage import OpenLineageClient
5253
from pyatlan.client.query import QueryClient
5354
from pyatlan.client.role import RoleClient
@@ -127,7 +128,9 @@ def log_response(response, *args, **kwargs):
127128

128129
class AtlanClient(BaseSettings):
129130
base_url: Union[Literal["INTERNAL"], HttpUrl]
130-
api_key: str
131+
api_key: Optional[str] = None
132+
oauth_client_id: Optional[str] = None
133+
oauth_client_secret: Optional[str] = None
131134
connect_timeout: float = 30.0 # 30 secs
132135
read_timeout: float = 900.0 # 15 mins
133136
retry: Retry = DEFAULT_RETRY
@@ -137,6 +140,7 @@ class AtlanClient(BaseSettings):
137140
_session: httpx.Client = PrivateAttr()
138141
_request_params: dict = PrivateAttr()
139142
_user_id: Optional[str] = PrivateAttr(default=None)
143+
_oauth_token_manager: Optional[Any] = PrivateAttr(default=None)
140144
_workflow_client: Optional[WorkflowClient] = PrivateAttr(default=None)
141145
_credential_client: Optional[CredentialClient] = PrivateAttr(default=None)
142146
_admin_client: Optional[AdminClient] = PrivateAttr(default=None)
@@ -172,11 +176,33 @@ class Config:
172176

173177
def __init__(self, **data):
174178
super().__init__(**data)
175-
self._request_params = (
176-
{"headers": {"authorization": f"Bearer {self.api_key}"}}
177-
if self.api_key and self.api_key.strip()
178-
else {"headers": {}}
179-
)
179+
180+
if self.oauth_client_id and self.oauth_client_secret and self.api_key is None:
181+
LOGGER.debug("API KEY not provided. Using OAuth flow for authentication")
182+
183+
final_base_url = self.base_url or os.environ.get(
184+
"ATLAN_BASE_URL", "INTERNAL"
185+
)
186+
final_oauth_client_id = self.oauth_client_id or os.environ.get(
187+
"ATLAN_OAUTH_CLIENT_ID"
188+
)
189+
final_oauth_client_secret = self.oauth_client_secret or os.environ.get(
190+
"ATLAN_OAUTH_CLIENT_SECRET"
191+
)
192+
self._oauth_token_manager = OAuthTokenManager(
193+
base_url=final_base_url,
194+
client_id=final_oauth_client_id,
195+
client_secret=final_oauth_client_secret,
196+
connect_timeout=self.connect_timeout,
197+
read_timeout=self.read_timeout,
198+
)
199+
self._request_params = {"headers": {}}
200+
else:
201+
self._request_params = (
202+
{"headers": {"authorization": f"Bearer {self.api_key}"}}
203+
if self.api_key and self.api_key.strip()
204+
else {"headers": {}}
205+
)
180206

181207
# Build proxy/SSL configuration with environment variable fallback
182208
transport_kwargs = self._build_transport_proxy_config(data)
@@ -691,7 +717,7 @@ def _call_api_internal(
691717
# Retry with impersonation (if _user_id is present)
692718
# on authentication failure (token may have expired)
693719
if (
694-
self._user_id
720+
(self._user_id or self._oauth_token_manager)
695721
and not self._401_has_retried.get()
696722
and response.status_code
697723
== ErrorCode.AUTHENTICATION_PASSTHROUGH.http_error_code
@@ -813,6 +839,9 @@ def _create_params(
813839
self, api: API, query_params, request_obj, exclude_unset: bool = True
814840
):
815841
params = copy.deepcopy(self._request_params)
842+
if self._oauth_token_manager:
843+
token = self._oauth_token_manager.get_token()
844+
params["headers"]["authorization"] = f"Bearer {token}"
816845
params["headers"]["Accept"] = api.consumes
817846
params["headers"]["content-type"] = api.produces
818847
if query_params is not None:
@@ -846,6 +875,21 @@ def _handle_401_token_refresh(
846875
847876
returns: HTTP response received after retrying the request with the refreshed token
848877
"""
878+
if self._oauth_token_manager:
879+
self._oauth_token_manager.invalidate_token()
880+
token = self._oauth_token_manager.get_token()
881+
params["headers"]["authorization"] = f"Bearer {token}"
882+
self._401_has_retried.set(True)
883+
LOGGER.debug("Successfully refreshed OAuth token after 401.")
884+
return self._call_api_internal(
885+
api,
886+
path,
887+
params,
888+
binary_data=binary_data,
889+
download_file_path=download_file_path,
890+
text_response=text_response,
891+
)
892+
849893
try:
850894
new_token = self.impersonate.user(user_id=self._user_id)
851895
except Exception as e:

pyatlan/client/constants.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,14 @@
8888
GET_WHOAMI_USER = API(
8989
WHOAMI_API, HTTPMethod.GET, HTTPStatus.OK, endpoint=EndPoint.HERACLES
9090
)
91+
92+
# oauth client authentication
93+
GET_OAUTH_CLIENT = API(
94+
"oauth-clients/token",
95+
HTTPMethod.POST,
96+
HTTPStatus.OK,
97+
endpoint=EndPoint.HERACLES,
98+
)
9199
# SQL parsing APIs
92100
PARSE_QUERY = API(
93101
f"{QUERY_API}/parse", HTTPMethod.POST, HTTPStatus.OK, endpoint=EndPoint.HEKA

0 commit comments

Comments
 (0)