Skip to content

Commit b77e9ce

Browse files
committed
APP-8817: Fixed get_client() (for both sync and async client errors)
1 parent e22558d commit b77e9ce

5 files changed

Lines changed: 217 additions & 19 deletions

File tree

pyatlan/client/aio/client.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,7 @@ async def from_token_guid( # type: ignore[override]
174174

175175
# Step 1: Initialize base client and get Atlan-Argo credentials
176176
# Note: Using empty api_key as we're bootstrapping authentication
177-
client = cls(base_url=final_base_url)
178-
# Explicitly set api_key to empty string to avoid
179-
# httpx.LocalProtocolError: Illegal header value b'Bearer '
180-
client.api_key = ""
177+
client = cls(base_url=final_base_url, api_key="")
181178
client_info = ImpersonateUser.get_client_info(
182179
client_id=client_id, client_secret=client_secret
183180
)
@@ -873,6 +870,11 @@ async def _upload_file(self, api, file=None, filename=None):
873870
self._api_logger(api, path)
874871
return await self._call_api_internal(api, path, params, binary_data=post_data)
875872

873+
def update_headers(self, header: dict[str, str]):
874+
"""Update headers for the async session."""
875+
if self._async_session:
876+
self._async_session.headers.update(header)
877+
876878
async def aclose(self):
877879
"""Close async resources"""
878880
if self._async_session:

pyatlan/client/atlan.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -167,11 +167,11 @@ class Config:
167167

168168
def __init__(self, **data):
169169
super().__init__(**data)
170-
self._request_params = {
171-
"headers": {
172-
"authorization": f"Bearer {self.api_key}",
173-
}
174-
}
170+
self._request_params = (
171+
{"headers": {"authorization": f"Bearer {self.api_key}"}}
172+
if self.api_key
173+
else {"headers": {}}
174+
)
175175
# Configure httpx client with the provided retry settings
176176
self._session = httpx.Client(
177177
transport=RetryTransport(retry=self.retry),
@@ -377,10 +377,7 @@ def from_token_guid(
377377

378378
# Step 1: Initialize base client and get Atlan-Argo credentials
379379
# Note: Using empty api_key as we're bootstrapping authentication
380-
client = AtlanClient(base_url=final_base_url)
381-
# Explicitly set api_key to empty string to avoid
382-
# httpx.LocalProtocolError: Illegal header value b'Bearer '
383-
client.api_key = ""
380+
client = AtlanClient(base_url=final_base_url, api_key="")
384381
client_info = ImpersonateUser.get_client_info(
385382
client_id=client_id, client_secret=client_secret
386383
)

pyatlan/pkg/utils.py

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,19 @@
44
import logging
55
import os
66
import sys
7-
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
7+
from typing import Any, Dict, List, Mapping, Optional, Sequence, TypeVar, Union
88

99
from pydantic.v1 import parse_obj_as, parse_raw_as
1010

11+
from pyatlan.client.aio import AsyncAtlanClient
1112
from pyatlan.client.atlan import AtlanClient
1213
from pyatlan.pkg.models import RuntimeConfig
1314

1415
LOGGER = logging.getLogger(__name__)
1516

17+
# Type variable for client types
18+
ClientType = TypeVar("ClientType", AtlanClient, AsyncAtlanClient)
19+
1620
# Try to import OpenTelemetry libraries
1721
try:
1822
from opentelemetry.exporter.otlp.proto.grpc._log_exporter import ( # type:ignore
@@ -64,13 +68,16 @@ def _is_valid_type(self, value: Any) -> bool:
6468
OTEL_IMPORTS_AVAILABLE = False
6569

6670

67-
def get_client(impersonate_user_id: str) -> AtlanClient:
71+
def get_client(
72+
impersonate_user_id: str, set_pkg_headers: Optional[bool] = False
73+
) -> AtlanClient:
6874
"""
6975
Set up the default Atlan client, based on environment variables.
7076
This will use an API token if found in ATLAN_API_KEY, and will fallback to attempting to impersonate a user if
7177
ATLAN_API_KEY is empty.
7278
7379
:param impersonate_user_id: unique identifier (GUID) of a user or API token to impersonate
80+
:param set_pkg_headers: whether to set package headers on the client (default is False)
7481
:returns: an initialized client
7582
"""
7683
base_url = os.environ.get("ATLAN_BASE_URL", "INTERNAL")
@@ -94,6 +101,46 @@ def get_client(impersonate_user_id: str) -> AtlanClient:
94101
client = AtlanClient(base_url=base_url, api_key=api_key)
95102
if user_id:
96103
client._user_id = user_id
104+
if set_pkg_headers:
105+
client = set_package_headers(client)
106+
return client
107+
108+
109+
async def get_client_async(
110+
impersonate_user_id: str, set_pkg_headers: Optional[bool] = False
111+
):
112+
"""
113+
Set up the default async Atlan client, based on environment variables.
114+
This will use an API token if found in ATLAN_API_KEY, and will fallback to attempting to impersonate a user if
115+
ATLAN_API_KEY is empty.
116+
117+
:param impersonate_user_id: unique identifier (GUID) of a user or API token to impersonate
118+
:param set_pkg_headers: whether to set package headers on the client (default is False)
119+
:returns: an initialized async client
120+
"""
121+
base_url = os.environ.get("ATLAN_BASE_URL", "INTERNAL")
122+
api_token = os.environ.get("ATLAN_API_KEY", "")
123+
user_id = os.environ.get("ATLAN_USER_ID", impersonate_user_id)
124+
125+
if api_token:
126+
LOGGER.info("Using provided API token for authentication.")
127+
api_key = api_token
128+
elif user_id:
129+
LOGGER.info("No API token found, attempting to impersonate user: %s", user_id)
130+
client = AsyncAtlanClient(base_url=base_url, api_key="")
131+
api_key = await client.impersonate.user(user_id=user_id)
132+
else:
133+
LOGGER.info(
134+
"No API token or impersonation user, attempting short-lived escalation."
135+
)
136+
client = AsyncAtlanClient(base_url=base_url, api_key="")
137+
api_key = await client.impersonate.escalate()
138+
139+
client = AsyncAtlanClient(base_url=base_url, api_key=api_key)
140+
if user_id:
141+
client._user_id = user_id
142+
if set_pkg_headers:
143+
client = set_package_headers(client)
97144
return client
98145

99146

@@ -110,12 +157,12 @@ def set_package_ops(run_time_config: RuntimeConfig) -> AtlanClient:
110157
return client
111158

112159

113-
def set_package_headers(client: AtlanClient) -> AtlanClient:
160+
def set_package_headers(client: ClientType) -> ClientType:
114161
"""
115-
Configure the AtlanClient with package headers from environment variables.
162+
Configure the AtlanClient or AsyncAtlanClient with package headers from environment variables.
116163
117-
:param client: AtlanClient instance to configure
118-
:returns: updated AtlanClient instance.
164+
:param client: AtlanClient or AsyncAtlanClient instance to configure
165+
:returns: updated client instance of the same type.
119166
"""
120167

121168
if (agent := os.environ.get("X_ATLAN_AGENT")) and (

tests/integration/aio/test_client.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55

66
import pytest
77
import pytest_asyncio
8+
from httpx import Headers
89
from pydantic.v1 import StrictStr
910

11+
from pyatlan import __version__ as VERSION
1012
from pyatlan.client.aio.client import AsyncAtlanClient
1113
from pyatlan.client.atlan import DEFAULT_RETRY
1214
from pyatlan.client.common.audit import LOGGER as AUDIT_LOGGER
@@ -42,6 +44,8 @@
4244
Term,
4345
)
4446
from pyatlan.model.user import UserMinimalResponse
47+
from pyatlan.pkg.utils import get_client_async
48+
from pyatlan.utils import get_python_version
4549
from tests.integration.aio.utils import (
4650
async_search_with_retry,
4751
create_database_async,
@@ -1535,6 +1539,80 @@ async def test_client_401_token_refresh(
15351539
assert client.api_key != expired_api_token
15361540
assert results and results.count >= 1
15371541

1542+
# Verify similar results with get_client_async()
1543+
# Setting ATLAN_API_KEY to empty string to force impersonation
1544+
monkeypatch.setenv("ATLAN_API_KEY", "")
1545+
assert expired_token_user_id
1546+
client = await get_client_async(impersonate_user_id=expired_token_user_id)
1547+
results = await (
1548+
FluentSearch()
1549+
.where(CompoundQuery.active_assets())
1550+
.where(CompoundQuery.asset_type(AtlasGlossary))
1551+
.page_size(100)
1552+
.execute_async(client=client)
1553+
)
1554+
1555+
# Confirm the API key has been updated and results are returned
1556+
assert client.api_key != expired_api_token
1557+
assert results and results.count >= 1
1558+
1559+
# Verify package headers are set correctly
1560+
expected_common_headers = Headers(
1561+
{
1562+
"User-Agent": f"Atlan-PythonSDK/{VERSION}",
1563+
"Accept-Encoding": "gzip, deflate",
1564+
"Accept": "*/*",
1565+
"Connection": "keep-alive",
1566+
"x-atlan-agent": "sdk",
1567+
"x-atlan-agent-id": "python",
1568+
"x-atlan-client-origin": "product_sdk",
1569+
"x-atlan-python-version": get_python_version(),
1570+
"x-atlan-client-type": "async",
1571+
}
1572+
)
1573+
1574+
# Clear package environment variables to test default headers
1575+
for var in [
1576+
"X_ATLAN_AGENT",
1577+
"X_ATLAN_AGENT_ID",
1578+
"X_ATLAN_AGENT_PACKAGE_NAME",
1579+
"X_ATLAN_AGENT_WORKFLOW_ID",
1580+
]:
1581+
monkeypatch.delenv(var, raising=False)
1582+
1583+
client = await get_client_async(
1584+
impersonate_user_id=expired_token_user_id, set_pkg_headers=False
1585+
)
1586+
assert client._async_session is not None
1587+
assert expected_common_headers == client._async_session.headers
1588+
1589+
# Set package environment variables to test package headers
1590+
monkeypatch.setenv("X_ATLAN_AGENT", "agent_value")
1591+
monkeypatch.setenv("X_ATLAN_AGENT_ID", "agent_id_value")
1592+
monkeypatch.setenv("X_ATLAN_AGENT_PACKAGE_NAME", "package_name_value")
1593+
monkeypatch.setenv("X_ATLAN_AGENT_WORKFLOW_ID", "workflow_id_value")
1594+
1595+
expected = Headers(
1596+
{
1597+
"User-Agent": f"Atlan-PythonSDK/{VERSION}",
1598+
"Accept-Encoding": "gzip, deflate",
1599+
"Accept": "*/*",
1600+
"Connection": "keep-alive",
1601+
"x-atlan-client-origin": "product_sdk",
1602+
"x-atlan-python-version": get_python_version(),
1603+
"x-atlan-client-type": "async",
1604+
"x-atlan-agent": "agent_value",
1605+
"x-atlan-agent-id": "agent_id_value",
1606+
"x-atlan-agent-package-name": "package_name_value",
1607+
"x-atlan-agent-workflow-id": "workflow_id_value",
1608+
}
1609+
)
1610+
client = await get_client_async(
1611+
impersonate_user_id=expired_token_user_id, set_pkg_headers=True
1612+
)
1613+
assert client._async_session is not None
1614+
assert expected == client._async_session.headers
1615+
15381616

15391617
async def test_client_init_from_token_guid(
15401618
client: AsyncAtlanClient, token: ApiToken, argo_fake_token: ApiToken, monkeypatch

tests/integration/test_client.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
from unittest.mock import patch
55

66
import pytest
7+
from httpx import Headers
78
from pydantic.v1 import StrictStr
89

10+
from pyatlan import __version__ as VERSION
911
from pyatlan.client.atlan import DEFAULT_RETRY, AtlanClient
1012
from pyatlan.client.common.audit import LOGGER as AUDIT_LOGGER
1113
from pyatlan.client.common.search_log import LOGGER as SEARCH_LOG_LOGGER
@@ -46,6 +48,8 @@
4648
Term,
4749
)
4850
from pyatlan.model.user import UserMinimalResponse
51+
from pyatlan.pkg.utils import get_client
52+
from pyatlan.utils import get_python_version
4953
from tests.integration.client import TestId
5054
from tests.integration.lineage_test import create_database, delete_asset
5155
from tests.integration.requests_test import create_token, delete_token
@@ -1627,6 +1631,76 @@ def test_client_401_token_refresh(
16271631
assert client.api_key != expired_api_token
16281632
assert results and results.count >= 1
16291633

1634+
# Verify similar results with get_client()
1635+
# Setting ATLAN_API_KEY to empty string to force impersonation
1636+
monkeypatch.setenv("ATLAN_API_KEY", "")
1637+
assert expired_token_user_id
1638+
client = get_client(impersonate_user_id=expired_token_user_id)
1639+
results = (
1640+
FluentSearch()
1641+
.where(CompoundQuery.active_assets())
1642+
.where(CompoundQuery.asset_type(AtlasGlossary))
1643+
.page_size(100)
1644+
.execute(client=client)
1645+
)
1646+
1647+
# Confirm the API key has been updated and results are returned
1648+
assert client.api_key != expired_api_token
1649+
assert results and results.count >= 1
1650+
1651+
# Verify package headers are set correctly
1652+
expected_common_headers = Headers(
1653+
{
1654+
"User-Agent": f"Atlan-PythonSDK/{VERSION}",
1655+
"Accept-Encoding": "gzip, deflate",
1656+
"Accept": "*/*",
1657+
"Connection": "keep-alive",
1658+
"x-atlan-agent": "sdk",
1659+
"x-atlan-agent-id": "python",
1660+
"x-atlan-client-origin": "product_sdk",
1661+
"x-atlan-python-version": get_python_version(),
1662+
"x-atlan-client-type": "sync",
1663+
}
1664+
)
1665+
1666+
# Clear package environment variables to test default headers
1667+
for var in [
1668+
"X_ATLAN_AGENT",
1669+
"X_ATLAN_AGENT_ID",
1670+
"X_ATLAN_AGENT_PACKAGE_NAME",
1671+
"X_ATLAN_AGENT_WORKFLOW_ID",
1672+
]:
1673+
monkeypatch.delenv(var, raising=False)
1674+
1675+
client = get_client(
1676+
impersonate_user_id=expired_token_user_id, set_pkg_headers=False
1677+
)
1678+
assert expected_common_headers == client._session.headers
1679+
1680+
# Set package environment variables to test package headers
1681+
monkeypatch.setenv("X_ATLAN_AGENT", "agent_value")
1682+
monkeypatch.setenv("X_ATLAN_AGENT_ID", "agent_id_value")
1683+
monkeypatch.setenv("X_ATLAN_AGENT_PACKAGE_NAME", "package_name_value")
1684+
monkeypatch.setenv("X_ATLAN_AGENT_WORKFLOW_ID", "workflow_id_value")
1685+
1686+
expected = Headers(
1687+
{
1688+
"User-Agent": f"Atlan-PythonSDK/{VERSION}",
1689+
"Accept-Encoding": "gzip, deflate",
1690+
"Accept": "*/*",
1691+
"Connection": "keep-alive",
1692+
"x-atlan-client-origin": "product_sdk",
1693+
"x-atlan-python-version": get_python_version(),
1694+
"x-atlan-client-type": "sync",
1695+
"x-atlan-agent": "agent_value",
1696+
"x-atlan-agent-id": "agent_id_value",
1697+
"x-atlan-agent-package-name": "package_name_value",
1698+
"x-atlan-agent-workflow-id": "workflow_id_value",
1699+
}
1700+
)
1701+
client = get_client(impersonate_user_id=expired_token_user_id, set_pkg_headers=True)
1702+
assert expected == client._session.headers
1703+
16301704

16311705
def test_client_init_from_token_guid(
16321706
client: AtlanClient, token: ApiToken, argo_fake_token: ApiToken, monkeypatch

0 commit comments

Comments
 (0)