Skip to content

Commit 3575b88

Browse files
committed
V9 client and model improvements
Adds validate_arguments decorators to async SSO client, fixes SSO model fields, improves asset/batch client v9 compatibility, and various model enhancements (badge, entity, persona, schema, etc.). Includes v9 validate module and pkg utilities. Made-with: Cursor
1 parent 339127f commit 3575b88

65 files changed

Lines changed: 548 additions & 408 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

pyatlan_v9/client/admin.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from pyatlan.client.common import AdminGetAdminEvents, AdminGetKeycloakEvents, ApiCaller
99
from pyatlan.errors import ErrorCode
10-
from pyatlan.validate import validate_arguments
1110
from pyatlan_v9.model.keycloak_events import (
1211
AdminEvent,
1312
AdminEventRequest,
@@ -16,6 +15,7 @@
1615
KeycloakEventRequest,
1716
KeycloakEventResponse,
1817
)
18+
from pyatlan_v9.validate import validate_arguments
1919

2020

2121
class V9AdminClient:
@@ -72,9 +72,7 @@ def get_admin_events(self, admin_request: AdminEventRequest) -> AdminEventRespon
7272
:raises AtlanError: on any API communication issue
7373
"""
7474
endpoint, query_params = AdminGetAdminEvents.prepare_request(admin_request)
75-
raw_json = self._client._call_api(
76-
endpoint, query_params=query_params
77-
)
75+
raw_json = self._client._call_api(endpoint, query_params=query_params)
7876
if raw_json:
7977
events = msgspec.convert(raw_json, list[AdminEvent], strict=False)
8078
else:

pyatlan_v9/client/aio/admin.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
AsyncApiCaller,
1212
)
1313
from pyatlan.errors import ErrorCode
14-
from pyatlan.validate import validate_arguments
1514
from pyatlan_v9.model.aio.keycloak_events import (
1615
AsyncAdminEventResponse,
1716
AsyncKeycloakEventResponse,
@@ -22,6 +21,7 @@
2221
KeycloakEvent,
2322
KeycloakEventRequest,
2423
)
24+
from pyatlan_v9.validate import validate_arguments
2525

2626

2727
class V9AsyncAdminClient:
@@ -80,9 +80,7 @@ async def get_admin_events(
8080
:raises AtlanError: on any API communication issue
8181
"""
8282
endpoint, query_params = AdminGetAdminEvents.prepare_request(admin_request)
83-
raw_json = await self._client._call_api(
84-
endpoint, query_params=query_params
85-
)
83+
raw_json = await self._client._call_api(endpoint, query_params=query_params)
8684
if raw_json:
8785
events = msgspec.convert(raw_json, list[AdminEvent], strict=False)
8886
else:

pyatlan_v9/client/aio/asset.py

Lines changed: 120 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
TYPE_CHECKING,
1111
Awaitable,
1212
Callable,
13-
Dict,
1413
List,
1514
Optional,
1615
Type,
@@ -22,14 +21,13 @@
2221

2322
import msgspec
2423
from tenacity import (
25-
RetryError,
2624
retry,
2725
retry_if_exception_type,
2826
stop_after_attempt,
2927
wait_exponential,
30-
wait_fixed,
3128
)
3229

30+
from pyatlan.client.asset import CategoryHierarchy
3331
from pyatlan.client.common import (
3432
DeleteByGuid,
3533
FindCategoryFastByName,
@@ -60,16 +58,13 @@
6058
UpdateCertificate,
6159
UpdateCustomMetadataAttributes,
6260
)
63-
from pyatlan.client.constants import (
64-
BULK_UPDATE,
65-
DELETE_ENTITIES_BY_GUIDS,
66-
)
67-
from pyatlan.errors import AtlanError, ErrorCode, NotFoundError, PermissionError
61+
from pyatlan.client.constants import BULK_UPDATE, DELETE_ENTITIES_BY_GUIDS
62+
from pyatlan.errors import ErrorCode, NotFoundError, PermissionError
6863
from pyatlan.model.aio import AsyncIndexSearchResults, AsyncLineageListResults
6964
from pyatlan.model.fields.atlan_fields import AtlanField
7065
from pyatlan.utils import unflatten_custom_metadata_for_entity
71-
from pyatlan.validate import validate_arguments
72-
66+
from pyatlan_v9.model.aggregation import Aggregations
67+
from pyatlan_v9.model.aio.core import AsyncAtlanRequest
7368
from pyatlan_v9.model.assets import (
7469
Asset,
7570
AtlasGlossary,
@@ -81,34 +76,31 @@
8176
Persona,
8277
Purpose,
8378
)
84-
from pyatlan_v9.model.enums import (
85-
AtlanConnectorType,
86-
AtlanDeleteType,
87-
CertificateStatus,
88-
DataQualityScheduleType,
89-
EntityStatus,
90-
SaveSemantic,
91-
)
92-
from pyatlan_v9.model.aggregation import Aggregations
93-
from pyatlan_v9.model.aio.core import AsyncAtlanRequest
9479
from pyatlan_v9.model.core import (
9580
Announcement,
96-
AssetRequest,
9781
AtlanRequest,
9882
AtlanTag,
9983
AtlanTagName,
10084
BulkRequest,
10185
)
10286
from pyatlan_v9.model.custom_metadata import CustomMetadataDict
87+
from pyatlan_v9.model.enums import (
88+
AtlanConnectorType,
89+
AtlanDeleteType,
90+
CertificateStatus,
91+
DataQualityScheduleType,
92+
EntityStatus,
93+
SaveSemantic,
94+
SortOrder,
95+
)
10396
from pyatlan_v9.model.lineage import LineageListRequest
10497
from pyatlan_v9.model.response import AssetMutationResponse, MutatedEntities
10598
from pyatlan_v9.model.search import IndexSearchRequest, Query
10699
from pyatlan_v9.model.transform import from_atlas_format
107-
108-
from pyatlan.client.asset import CategoryHierarchy
100+
from pyatlan_v9.validate import validate_arguments
109101

110102
if TYPE_CHECKING:
111-
from pyatlan.client.common import AsyncApiCaller
103+
pass
112104

113105
LOGGER = logging.getLogger(__name__)
114106

@@ -129,7 +121,9 @@ def _custom_metadata_payload(custom_metadata_request):
129121
root_payload = getattr(custom_metadata_request, "__root__", None)
130122
if root_payload is not None:
131123
return root_payload
132-
if hasattr(custom_metadata_request, "dict") and callable(custom_metadata_request.dict):
124+
if hasattr(custom_metadata_request, "dict") and callable(
125+
custom_metadata_request.dict
126+
):
133127
return custom_metadata_request.dict(by_alias=True, exclude_none=True)
134128
return custom_metadata_request
135129

@@ -148,19 +142,13 @@ def _parse_mutation_response(raw_json: dict) -> AssetMutationResponse:
148142
if me_raw := raw_json.get("mutatedEntities"):
149143
mutated = MutatedEntities(
150144
CREATE=(
151-
_parse_entities_v9(me_raw["CREATE"])
152-
if me_raw.get("CREATE")
153-
else None
145+
_parse_entities_v9(me_raw["CREATE"]) if me_raw.get("CREATE") else None
154146
),
155147
UPDATE=(
156-
_parse_entities_v9(me_raw["UPDATE"])
157-
if me_raw.get("UPDATE")
158-
else None
148+
_parse_entities_v9(me_raw["UPDATE"]) if me_raw.get("UPDATE") else None
159149
),
160150
DELETE=(
161-
_parse_entities_v9(me_raw["DELETE"])
162-
if me_raw.get("DELETE")
163-
else None
151+
_parse_entities_v9(me_raw["DELETE"]) if me_raw.get("DELETE") else None
164152
),
165153
PARTIAL_UPDATE=(
166154
_parse_entities_v9(me_raw["PARTIAL_UPDATE"])
@@ -187,15 +175,58 @@ def _parse_aggregations_v9(raw: dict) -> Aggregations:
187175
AggregationMetricResult,
188176
)
189177

178+
def _parse_nested(bucket_dict: dict) -> "Aggregations | None":
179+
"""Parse nested aggregation results inside a bucket, recursively."""
180+
nested: dict = {}
181+
known_keys = {
182+
"key",
183+
"doc_count",
184+
"key_as_string",
185+
"max_matching_length",
186+
"to",
187+
"to_as_string",
188+
"from",
189+
"from_as_string",
190+
}
191+
for k, v in bucket_dict.items():
192+
if k in known_keys or not isinstance(v, dict):
193+
continue
194+
try:
195+
if "buckets" in v:
196+
result = msgspec.convert(v, AggregationBucketResult, strict=False)
197+
raw_inner = v.get("buckets", [])
198+
for i, inner in enumerate(result.buckets):
199+
if i < len(raw_inner):
200+
try:
201+
inner.nested_results = _parse_nested(raw_inner[i])
202+
except Exception:
203+
pass
204+
nested[k] = result
205+
elif "hits" in v:
206+
nested[k] = msgspec.convert(v, AggregationHitsResult, strict=False)
207+
elif "value" in v:
208+
nested[k] = msgspec.convert(
209+
v, AggregationMetricResult, strict=False
210+
)
211+
except Exception:
212+
pass
213+
return Aggregations(data=nested) if nested else None
214+
190215
parsed: dict = {}
191216
for key, value in raw.items():
192217
if not isinstance(value, dict):
193218
continue
194219
try:
195220
if "buckets" in value:
196-
parsed[key] = msgspec.convert(
197-
value, AggregationBucketResult, strict=False
198-
)
221+
result = msgspec.convert(value, AggregationBucketResult, strict=False)
222+
raw_buckets = value.get("buckets", [])
223+
for i, bucket in enumerate(result.buckets):
224+
if i < len(raw_buckets):
225+
try:
226+
bucket.nested_results = _parse_nested(raw_buckets[i])
227+
except Exception:
228+
pass
229+
parsed[key] = result
199230
elif "hits" in value:
200231
parsed[key] = msgspec.convert(
201232
value, AggregationHitsResult, strict=False
@@ -324,8 +355,12 @@ def _make_bulk_request_payload(entities: list, client) -> dict:
324355

325356
async def _make_bulk_request_payload_async(entities: list, client) -> dict:
326357
"""Async version: serialize entities into API-ready dict with tag retranslation."""
358+
from pyatlan_v9.client.asset import _normalize_meanings_for_mutation
359+
327360
bulk = BulkRequest(entities=entities)
328361
request_dict = bulk.to_dict()
362+
for entity in request_dict.get("entities", []):
363+
_normalize_meanings_for_mutation(entity)
329364
async_request = AsyncAtlanRequest(instance=request_dict, client=client)
330365
await async_request.retranslate()
331366
return async_request.translated
@@ -368,7 +403,9 @@ def __init__(self, client):
368403
# Search
369404
# ------------------------------------------------------------------
370405

371-
async def search(self, criteria: IndexSearchRequest, bulk=False) -> V9AsyncIndexSearchResults:
406+
async def search(
407+
self, criteria: IndexSearchRequest, bulk=False
408+
) -> V9AsyncIndexSearchResults:
372409
"""
373410
Search for assets using the provided criteria.
374411
@@ -627,9 +664,7 @@ async def get_by_guid(
627664
guid, min_ext_info, ignore_relationships
628665
)
629666
raw_json = await self._client._call_api(endpoint_path, query_params)
630-
return _process_get_response_v9(
631-
raw_json, guid, asset_type, by_guid=True
632-
)
667+
return _process_get_response_v9(raw_json, guid, asset_type, by_guid=True)
633668

634669
@validate_arguments
635670
async def retrieve_minimal(
@@ -715,10 +750,10 @@ async def save(
715750
asset.validate_required()
716751
await asset.flush_custom_metadata_async(client=self._client)
717752

718-
request_payload = await _make_bulk_request_payload_async(
719-
entities, self._client
753+
request_payload = await _make_bulk_request_payload_async(entities, self._client)
754+
raw_json = await self._client._call_api(
755+
BULK_UPDATE, query_params, request_payload
720756
)
721-
raw_json = await self._client._call_api(BULK_UPDATE, query_params, request_payload)
722757
response = _parse_mutation_response(raw_json)
723758

724759
if connections_created := response.assets_created(Connection):
@@ -845,10 +880,10 @@ async def save_replacing_cm(
845880
asset.validate_required()
846881
await asset.flush_custom_metadata_async(self._client)
847882

848-
request_payload = await _make_bulk_request_payload_async(
849-
entities, self._client
883+
request_payload = await _make_bulk_request_payload_async(entities, self._client)
884+
raw_json = await self._client._call_api(
885+
BULK_UPDATE, query_params, request_payload
850886
)
851-
raw_json = await self._client._call_api(BULK_UPDATE, query_params, request_payload)
852887
return _parse_mutation_response(raw_json)
853888

854889
@validate_arguments
@@ -991,10 +1026,10 @@ async def _restore_asset(self, asset: Asset) -> AssetMutationResponse:
9911026
for restored in entities:
9921027
await restored.flush_custom_metadata_async(self._client)
9931028

994-
request_payload = await _make_bulk_request_payload_async(
995-
entities, self._client
1029+
request_payload = await _make_bulk_request_payload_async(entities, self._client)
1030+
raw_json = await self._client._call_api(
1031+
BULK_UPDATE, query_params, request_payload
9961032
)
997-
raw_json = await self._client._call_api(BULK_UPDATE, query_params, request_payload)
9981033
return _parse_mutation_response(raw_json)
9991034

10001035
# ------------------------------------------------------------------
@@ -1573,7 +1608,21 @@ async def _manage_terms(
15731608
updated_asset = asset_type.updater(
15741609
qualified_name=first_result.qualified_name, name=first_result.name
15751610
)
1576-
processed_terms = ManageTerms.process_terms_with_semantic(terms, save_semantic)
1611+
processed_terms: list[AtlasGlossaryTerm] = []
1612+
for term in terms:
1613+
if getattr(term, "guid", None):
1614+
processed_terms.append(
1615+
AtlasGlossaryTerm.ref_by_guid(
1616+
guid=term.guid, semantic=save_semantic
1617+
)
1618+
)
1619+
elif getattr(term, "qualified_name", None):
1620+
processed_terms.append(
1621+
AtlasGlossaryTerm.ref_by_qualified_name(
1622+
qualified_name=term.qualified_name,
1623+
semantic=save_semantic,
1624+
)
1625+
)
15771626
updated_asset.assigned_terms = processed_terms
15781627
response = await self.save(entity=updated_asset)
15791628
return ManageTerms.process_save_response(response, asset_type, updated_asset)
@@ -1863,10 +1912,27 @@ async def get_hierarchy(
18631912
:param related_attributes: attributes to retrieve for each related asset in the hierarchy
18641913
:returns: a traversable category hierarchy
18651914
"""
1915+
from pyatlan.model.search import Term as SearchTerm
1916+
from pyatlan_v9.model.fluent_search import FluentSearch
1917+
18661918
GetHierarchy.validate_glossary(glossary)
1867-
request = GetHierarchy.prepare_search_request(
1868-
glossary, attributes, related_attributes
1869-
)
1919+
if attributes is None:
1920+
attributes = []
1921+
if related_attributes is None:
1922+
related_attributes = []
1923+
search = (
1924+
FluentSearch.select()
1925+
.where(AtlasGlossaryCategory.ANCHOR.eq(glossary.qualified_name))
1926+
.where(SearchTerm.with_type_name("AtlasGlossaryCategory"))
1927+
.include_on_results(AtlasGlossaryCategory.PARENT_CATEGORY)
1928+
.page_size(20)
1929+
.sort(AtlasGlossaryCategory.NAME.order(SortOrder.ASCENDING))
1930+
)
1931+
for field in attributes:
1932+
search = search.include_on_results(field)
1933+
for field in related_attributes:
1934+
search = search.include_on_relations(field)
1935+
request = search.to_request()
18701936
response = await self.search(request)
18711937
return await GetHierarchy.process_async_search_results(response, glossary)
18721938

0 commit comments

Comments
 (0)