Skip to content

Commit 8e474e9

Browse files
committed
refactor: split client into deploy and download
1 parent ace93c5 commit 8e474e9

3 files changed

Lines changed: 339 additions & 349 deletions

File tree

Lines changed: 0 additions & 343 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,6 @@
33
import requests
44
import hashlib
55
import json
6-
from tqdm import tqdm
7-
from SPARQLWrapper import SPARQLWrapper, JSON
8-
from hashlib import sha256
9-
import os
10-
11-
from databusclient.api.utils import get_databus_id_parts_from_uri, get_json_ld_from_databus
126

137
__debug = False
148

@@ -491,340 +485,3 @@ def deploy_from_metadata(
491485
print(f"Deployed {len(metadata)} file(s):")
492486
for entry in metadata:
493487
print(f" - {entry['url']}")
494-
495-
496-
def __download_file__(url, filename, vault_token_file=None, databus_key=None, auth_url=None, client_id=None) -> None:
    """
    Download a file from the internet with a progress bar using tqdm.

    Parameters:
    - url: the URL of the file to download
    - filename: the local file path where the file should be saved
    - vault_token_file: Path to Vault refresh token file
    - databus_key: Databus API key for protected downloads
    - auth_url: Keycloak token endpoint URL
    - client_id: Client ID for token exchange

    Steps:
    1. Try a direct GET without an Authorization header.
    2. If the server responds 401 Unauthorized with a "WWW-Authenticate: Bearer"
       challenge, fetch a Vault access token and retry with an Authorization header.
    3. If the server responds 401 without a Bearer challenge, retry with the
       Databus API key in the X-API-KEY header.

    Raises ValueError when the required credential is missing, and IOError
    when the downloaded size disagrees with the Content-Length header.
    A 404 response is reported as a warning and the file is skipped.
    """
    print(f"Download file: {url}")
    dirpath = os.path.dirname(filename)
    if dirpath:
        os.makedirs(dirpath, exist_ok=True)  # Create the necessary directories

    # --- 1. Get redirect URL by requesting HEAD ---
    # requests.head() does not follow redirects by default, so a Location
    # header reveals the real download URL. timeout added so a dead server
    # cannot hang the download forever.
    response = requests.head(url, stream=True, timeout=30)
    # Check for redirect and update URL if necessary
    if response.headers.get("Location") and response.status_code in [301, 302, 303, 307, 308]:
        url = response.headers.get("Location")
        print("Redirects url: ", url)

    # --- 2. Try direct GET ---
    response = requests.get(url, stream=True, allow_redirects=True, timeout=30)
    www = response.headers.get('WWW-Authenticate', '')  # Bearer challenge marker, if any

    # Vault token required if 401 Unauthorized with Bearer challenge
    if response.status_code == 401 and "bearer" in www.lower():
        print(f"Authentication required for {url}")
        if not vault_token_file:
            raise ValueError("Vault token file not given for protected download")

        # --- 3. Fetch Vault token ---
        vault_token = __get_vault_access__(url, vault_token_file, auth_url, client_id)
        headers = {"Authorization": f"Bearer {vault_token}"}

        # --- 4. Retry with token ---
        response = requests.get(url, headers=headers, stream=True, timeout=30)

    # Databus API key required if only 401 Unauthorized
    elif response.status_code == 401:
        print(f"API key required for {url}")
        if not databus_key:
            raise ValueError("Databus API key not given for protected download")

        headers = {"X-API-KEY": databus_key}
        response = requests.get(url, headers=headers, stream=True, timeout=30)

    try:
        response.raise_for_status()  # Raise if still failing
    except requests.exceptions.HTTPError as e:
        if response.status_code == 404:
            print(f"WARNING: Skipping file {url} because it was not found (404).")
            return
        else:
            raise e

    total_size_in_bytes = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 KiB

    progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
    with open(filename, 'wb') as file:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
            file.write(data)
    progress_bar.close()

    # TODO: could be a problem of github raw / openfaas
    if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
        raise IOError("Downloaded size does not match Content-Length header")
572-
573-
574-
def __get_vault_access__(download_url: str,
                         token_file: str,
                         auth_url: str,
                         client_id: str) -> str:
    """
    Get Vault access token for a protected databus download.

    Parameters:
    - download_url: URL of the protected file; its host is used as the
      token-exchange audience
    - token_file: path to a file holding the Keycloak refresh token
      (the REFRESH_TOKEN environment variable takes precedence)
    - auth_url: Keycloak token endpoint URL
    - client_id: client ID used for the refresh and token-exchange requests

    Returns the Vault access token string.

    Raises FileNotFoundError when no refresh token is available and
    requests.HTTPError when either token request fails.
    """
    # 1. Load refresh token (environment variable wins over the token file)
    refresh_token = os.environ.get("REFRESH_TOKEN")
    if not refresh_token:
        if not os.path.exists(token_file):
            raise FileNotFoundError(f"Vault token file not found: {token_file}")
        with open(token_file, "r") as f:
            refresh_token = f.read().strip()
        if len(refresh_token) < 80:
            print(f"Warning: token from {token_file} is short (<80 chars)")

    # 2. Refresh token -> access token
    # timeout added: without one a dead auth server hangs the download forever
    resp = requests.post(auth_url, data={
        "client_id": client_id,
        "grant_type": "refresh_token",
        "refresh_token": refresh_token
    }, timeout=30)
    resp.raise_for_status()
    access_token = resp.json()["access_token"]

    # 3. Extract host as audience
    # Remove protocol prefix
    if download_url.startswith("https://"):
        host_part = download_url[len("https://"):]
    elif download_url.startswith("http://"):
        host_part = download_url[len("http://"):]
    else:
        host_part = download_url
    audience = host_part.split("/")[0]  # host is before first "/"

    # 4. Access token -> Vault token (OAuth 2.0 token exchange, RFC 8693)
    resp = requests.post(auth_url, data={
        "client_id": client_id,
        "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
        "subject_token": access_token,
        "audience": audience
    }, timeout=30)
    resp.raise_for_status()
    vault_token = resp.json()["access_token"]

    print(f"Using Vault access token for {download_url}")
    return vault_token
622-
623-
624-
def __query_sparql__(endpoint_url, query, databus_key=None) -> dict:
    """
    Run a SPARQL query against an endpoint and return the JSON results.

    Parameters:
    - endpoint_url: the URL of the SPARQL endpoint
    - query: the SPARQL query string
    - databus_key: Optional API key for authentication

    Returns:
    - Dictionary containing the query results
    """
    wrapper = SPARQLWrapper(endpoint_url)
    wrapper.method = 'POST'
    wrapper.setReturnFormat(JSON)
    wrapper.setQuery(query)
    if databus_key is not None:
        wrapper.setCustomHttpHeaders({"X-API-KEY": databus_key})
    return wrapper.query().convert()
644-
645-
646-
def __handle_databus_file_query__(endpoint_url, query, databus_key=None) -> List[str]:
    """
    Run a databus file query and collect the values of its single result variable.

    Fix: the original was a generator despite its List[str] annotation; it now
    returns a real list, which stays iterable for every existing caller.

    Parameters:
    - endpoint_url: the URL of the SPARQL endpoint
    - query: the SPARQL query string
    - databus_key: Optional API key for authentication

    Stops collecting (with an error message) if a result row binds more than
    one variable, matching the original short-circuit behavior.
    """
    result_dict = __query_sparql__(endpoint_url, query, databus_key=databus_key)
    values: List[str] = []
    for binding in result_dict['results']['bindings']:
        if len(binding.keys()) > 1:
            print("Error multiple bindings in query response")
            break
        values.append(binding[next(iter(binding.keys()))]['value'])
    return values
655-
656-
657-
def __handle_databus_artifact_version__(json_str: str) -> List[str]:
658-
"""
659-
Parse the JSON-LD of a databus artifact version to extract download URLs.
660-
Don't get downloadURLs directly from the JSON-LD, but follow the "file" links to count access to databus accurately.
661-
662-
Returns a list of download URLs.
663-
"""
664-
665-
databusIdUrl = []
666-
json_dict = json.loads(json_str)
667-
graph = json_dict.get("@graph", [])
668-
for node in graph:
669-
if node.get("@type") == "Part":
670-
id = node.get("file")
671-
databusIdUrl.append(id)
672-
return databusIdUrl
673-
674-
675-
def __get_databus_latest_version_of_artifact__(json_str: str) -> str:
676-
"""
677-
Parse the JSON-LD of a databus artifact to extract URLs of the latest version.
678-
679-
Returns download URL of latest version of the artifact.
680-
"""
681-
json_dict = json.loads(json_str)
682-
versions = json_dict.get("databus:hasVersion")
683-
684-
# Single version case {}
685-
if isinstance(versions, dict):
686-
versions = [versions]
687-
# Multiple versions case [{}, {}]
688-
689-
version_urls = [v["@id"] for v in versions if "@id" in v]
690-
if not version_urls:
691-
raise ValueError("No versions found in artifact JSON-LD")
692-
693-
version_urls.sort(reverse=True) # Sort versions in descending order
694-
return version_urls[0] # Return the latest version URL
695-
696-
697-
def __get_databus_artifacts_of_group__(json_str: str) -> List[str]:
    """
    Parse the JSON-LD of a databus group to extract URLs of all artifacts.

    Returns a list of artifact URLs (entries whose URI has no version part).
    """
    json_dict = json.loads(json_str)
    artifacts = json_dict.get("databus:hasArtifact", [])
    # Single artifact case {} — normalize to a list, consistent with
    # __get_databus_latest_version_of_artifact__; iterating a bare dict
    # would yield its keys and break item.get("@id") below.
    if isinstance(artifacts, dict):
        artifacts = [artifacts]

    result = []
    for item in artifacts:
        uri = item.get("@id")
        if not uri:
            continue
        _, _, _, _, version, _ = get_databus_id_parts_from_uri(uri)
        # Keep only artifact-level URIs; versioned URIs are not artifacts.
        if version is None:
            result.append(uri)
    return result
715-
716-
717-
def wsha256(raw: str):
    """Return the hexadecimal SHA-256 digest of *raw* encoded as UTF-8."""
    digest = sha256()
    digest.update(raw.encode('utf-8'))
    return digest.hexdigest()
719-
720-
721-
def __handle_databus_collection__(uri: str, databus_key: str | None = None) -> str:
    """Fetch the SPARQL query text stored behind a databus collection URI."""
    request_headers = {"Accept": "text/sparql"}
    if databus_key is not None:
        request_headers["X-API-KEY"] = databus_key
    response = requests.get(uri, headers=request_headers, timeout=30)
    return response.text
727-
728-
729-
def __download_list__(urls: List[str],
                      localDir: str,
                      vault_token_file: str = None,
                      databus_key: str = None,
                      auth_url: str = None,
                      client_id: str = None) -> None:
    """
    Download every URL in *urls* into *localDir*.

    When localDir is None, a per-file directory is derived for each URL from
    its databus URI parts (account/group/artifact/version, "latest" when the
    version is absent) under the current working directory.
    """
    target_dir = localDir
    for url in urls:
        if localDir is None:
            _host, account, group, artifact, version, _file = get_databus_id_parts_from_uri(url)
            version_segment = version if version is not None else "latest"
            target_dir = os.path.join(os.getcwd(), account, group, artifact, version_segment)
            print(f"Local directory not given, using {target_dir}")

        basename = url.split("/")[-1]
        destination = os.path.join(target_dir, basename)
        print("\n")
        __download_file__(url=url, filename=destination, vault_token_file=vault_token_file, databus_key=databus_key, auth_url=auth_url, client_id=client_id)
        print("\n")
747-
748-
749-
def download(
        localDir: str,
        endpoint: str,
        databusURIs: List[str],
        token=None,
        databus_key=None,
        auth_url=None,
        client_id=None
) -> None:
    """
    Download datasets to local storage from databus registry. If download is on vault, vault token will be used for downloading protected files.
    ------
    localDir: the local directory
    endpoint: the databus endpoint URL
    databusURIs: identifiers to access databus registered datasets
    token: Path to Vault refresh token file
    databus_key: Databus API key for protected downloads
    auth_url: Keycloak token endpoint URL
    client_id: Client ID for token exchange
    """
    # TODO: make pretty
    for databusURI in databusURIs:
        host, account, group, artifact, version, file = get_databus_id_parts_from_uri(databusURI)

        # dataID or databus collection
        if databusURI.startswith("http://") or databusURI.startswith("https://"):
            # Auto-detect sparql endpoint from databusURI if not given -> no need to specify endpoint (--databus)
            if endpoint is None:
                endpoint = f"https://{host}/sparql"
                print(f"SPARQL endpoint {endpoint}")

            # databus collection
            if group == "collections":
                query = __handle_databus_collection__(databusURI, databus_key=databus_key)
                # Fix: forward databus_key so protected collection queries work,
                # consistent with the plain-query branch below.
                res = __handle_databus_file_query__(endpoint, query, databus_key=databus_key)
                __download_list__(res, localDir, vault_token_file=token, databus_key=databus_key, auth_url=auth_url, client_id=client_id)
            # databus file
            elif file is not None:
                __download_list__([databusURI], localDir, vault_token_file=token, databus_key=databus_key, auth_url=auth_url, client_id=client_id)
            # databus artifact version
            elif version is not None:
                json_str = get_json_ld_from_databus(databusURI, databus_key=databus_key)
                res = __handle_databus_artifact_version__(json_str)
                __download_list__(res, localDir, vault_token_file=token, databus_key=databus_key, auth_url=auth_url, client_id=client_id)
            # databus artifact
            elif artifact is not None:
                json_str = get_json_ld_from_databus(databusURI, databus_key=databus_key)
                latest = __get_databus_latest_version_of_artifact__(json_str)
                print(f"No version given, using latest version: {latest}")
                json_str = get_json_ld_from_databus(latest, databus_key=databus_key)
                res = __handle_databus_artifact_version__(json_str)
                __download_list__(res, localDir, vault_token_file=token, databus_key=databus_key, auth_url=auth_url, client_id=client_id)

            # databus group
            elif group is not None:
                json_str = get_json_ld_from_databus(databusURI, databus_key=databus_key)
                artifacts = __get_databus_artifacts_of_group__(json_str)
                for artifact_uri in artifacts:
                    print(f"Processing artifact {artifact_uri}")
                    json_str = get_json_ld_from_databus(artifact_uri, databus_key=databus_key)
                    latest = __get_databus_latest_version_of_artifact__(json_str)
                    print(f"No version given, using latest version: {latest}")
                    json_str = get_json_ld_from_databus(latest, databus_key=databus_key)
                    res = __handle_databus_artifact_version__(json_str)
                    __download_list__(res, localDir, vault_token_file=token, databus_key=databus_key, auth_url=auth_url, client_id=client_id)

            # databus account
            elif account is not None:
                print("accountId not supported yet")  # TODO
            else:
                print("dataId not supported yet")  # TODO add support for other DatabusIds
        # query in local file
        elif databusURI.startswith("file://"):
            print("query in file not supported yet")
        # query as argument
        else:
            print("QUERY {}", databusURI.replace("\n", " "))
            if endpoint is None:  # endpoint is required for queries (--databus)
                raise ValueError("No endpoint given for query")
            res = __handle_databus_file_query__(endpoint, databusURI, databus_key=databus_key)
            __download_list__(res, localDir, vault_token_file=token, databus_key=databus_key, auth_url=auth_url, client_id=client_id)

0 commit comments

Comments
 (0)