diff --git a/.claude/CLAUDE.md b/.claude/CLAUDE.md new file mode 100644 index 00000000..ebac26f3 --- /dev/null +++ b/.claude/CLAUDE.md @@ -0,0 +1 @@ +See @../AGENTS.md for documentation on how to configure and use Claude agents in this workspace. \ No newline at end of file diff --git a/.claude/commands/code-review.md b/.claude/commands/code-review.md new file mode 100644 index 00000000..119cc753 --- /dev/null +++ b/.claude/commands/code-review.md @@ -0,0 +1,181 @@ +--- +description: Smart code review for the Skyflow Python SDK. Pass "full review" to scan the entire skyflow/ tree, a file/directory path to review that target, or nothing to review files changed on the current branch. +constraints: + - "NEVER edit, create, or delete any file under skyflow/generated/. Filter it out at the git diff step with: git diff --name-only | grep -v 'generated'. If a finding relates to generated code, report it as an observation only." +--- + +You are a senior engineer reviewing the Skyflow Python SDK — a Python client library for Skyflow's data privacy vault. + +## Mode selection — pick exactly one + +Inspect `$ARGUMENTS` and choose the review mode: + +| Argument | Mode | +|---|---| +| `full review` (case-insensitive) | **Full** — scan all files under `skyflow/` recursively | +| A file or directory path (starts with `/`, `./`, `skyflow/`, `tests/`, `samples/`, etc.) | **Path** — review only that file or directory | +| Empty / anything else | **Branch** — review files changed on the current branch vs `main` | + +### Full mode +Scan all files under `skyflow/` recursively, grouped by layer: +``` +skyflow/vault/controller/ +skyflow/vault/data/ +skyflow/vault/tokens/ +skyflow/vault/detect/ +skyflow/vault/connection/ +skyflow/vault/client/ +skyflow/utils/validations/ +skyflow/utils/ +skyflow/service_account/ +skyflow/client/ +skyflow/error/ +skyflow/__init__.py +``` +Read each file fully before reporting findings. 
Work layer by layer — controllers first, then validators, then data models, then utilities. + +### Path mode +Restrict the scan to the path given in `$ARGUMENTS` (file or directory). Read every file under that path before reporting. + +### Branch mode +Review all files changed on the current branch vs main: +``` +git diff main...HEAD --name-only | grep -v 'generated' +git diff main...HEAD +git log main...HEAD --oneline +``` +Summarise what the branch does in 2–3 sentences. List files grouped by layer: data models / controllers / validation / service_account / tests / samples / exports / docs. + +> **IMPORTANT — Generated code boundary** +> `skyflow/generated/` contains Fern-generated REST client code. **Never modify any file inside `skyflow/generated/`**. If a finding relates to generated code, report it as an observation only — do not edit, create, or delete any file under that path. + +--- + +## What to review + +### Basic checks +- Identify issues and unhandled edge cases that can break the code at runtime. + +### 1. Request / Response pattern +- Every public operation must use dedicated classes: `XxxRequest`, `XxxResponse` +- Request classes must declare all fields with explicit types (use `Optional[T]` for optional fields, never bare `None` as the sole annotation) +- Response objects must be plain data containers — no business logic, no API calls inside them +- Flag any controller method that accepts or returns plain `dict` instead of a typed class +- Mutable default arguments in `__init__` are a Python footgun — `def __init__(self, items=[])` creates a single list at definition time that is shared by every instance constructed without that argument; use `None` and assign in the body instead + +### 2. 
Validation completeness +- Every public controller method must call its `validate_xxx_request()` function from `skyflow/utils/validations/_validations.py` **before** any API call +- Validators must raise `SkyflowError` with an error code from `SkyflowMessages.ErrorCodes` +- Validators must call `log_error_log(SkyflowMessages.ErrorLogs.xxx.value)` before raising +- Check for missing edge cases: `None` inputs, empty strings, wrong types, empty lists, negative numbers +- No truthy guard `if not x:` for values where `0`, `""`, `False`, or `[]` could be intentionally valid — use `x is None` or `x is not None` instead +- Consistent null-guard style across validators in the changed files — no mixing of `if not x`, `if x is None`, `if x == None` + +### 3. Error handling +- All methods that call the REST API must be wrapped in `try/except Exception` that calls `handle_exception(e, logger)` or raises `SkyflowError` +- Never swallow errors silently with a bare `except: pass` +- `except` blocks that re-raise must add value (logging, wrapping) — a bare `except Exception as e: raise e` with no added context should be flagged +- `SkyflowError` must be raised with the correct error code for the failure type (INVALID_INPUT for validation, server codes for API failures) +- Bare `except:` (catching `BaseException`) must be replaced with `except Exception:` + +### 4. Concurrency and I/O patterns +- File reads must use context managers (`with open(...) as f:`) — never leave file handles open +- `open()` for binary content must use mode `"rb"` not `"r"` +- Blocking I/O (file reads, network calls) inside a loop must be flagged if the loop runs at high volume — suggest batching +- No module-level side effects that execute on import (network calls, file I/O, env reads at import time) + +### 5. 
Python quality +- No mutable default arguments: `def f(x=[])`, `def f(x={})` — use `None` and initialise in the body +- No `from module import *` in non-`__init__.py` files +- No bare `except:` — always catch a specific exception type +- No `global` or `nonlocal` in controller / validator code +- `isinstance()` checks must account for subclass relationships — `bool` is a subclass of `int`, so `isinstance(x, int)` silently accepts `True` and `False`; flag integer checks that accept booleans when they should not, and checks that reject booleans when they should be accepted +- f-strings preferred over `%` formatting and `.format()` for new code; flag inconsistency within the same file +- `Optional[T]` must be imported from `typing` — do not use `T | None` union syntax unless the minimum Python version is ≥ 3.10 + +### 6. State and side effects +- Controller instances must be stateless per-call — no `self.xxx = ` inside method bodies; use local variables instead +- Cached/memoized state (e.g. `_cached_headers`) must not mix per-call and cross-call data +- Class-level variables shared across instances are a source of subtle bugs — flag any mutable class-level variable + +### 7. Exports and public API surface +- All public types and classes must be exported from the appropriate `__init__.py` +- Internal helpers (prefixed with `_`) must not appear in `__init__.py` exports +- Circular imports must not exist — flag any `from skyflow.x import y` inside `skyflow/y/__init__.py` that creates a cycle + +### 8. Logging +- Use `log_info(message, logger)` for informational messages and `log_error_log(message, logger)` for errors — never `print()` or `logging.xxx()` directly +- Sensitive values (tokens, credentials, private keys, PII) must never appear in log messages or error strings +- Every public method entry and success path must have a corresponding `log_info` call matching the `SkyflowMessages.Info` enum + +### 9. 
Naming conventions (Python) + +| Identifier type | Required style | Example | +|---|---|---| +| Variable / parameter / method | `snake_case` | `vault_id`, `token_uri`, `get_records` | +| Constant / module-level value | `UPPER_SNAKE_CASE` | `SKY_META_DATA_HEADER`, `CTX_KEY_REGEX` | +| Class / Exception / Enum | `PascalCase` | `InsertRequest`, `SkyflowError`, `RedactionType` | +| Private method / attribute | `_snake_case` | `_validate_ctx`, `_cached_headers` | +| Source file | `snake_case.py` or `_snake_case.py` for internals | `_file_upload_request.py`, `_validations.py` | + +**Acronym rule — all-lowercase in identifiers, not ALL-CAPS:** +- `id` not `ID` (e.g. `skyflow_id`, not `skyflow_ID`) +- `uri` not `URI` (e.g. `token_uri`, not `token_URI`) +- `url` not `URL` (e.g. `callback_url`, not `callback_URL`) +- `api` not `API` (e.g. `api_key`, not `API_key`) +- Exception: class names follow PascalCase title-casing: `SkyflowId`, `TokenUri`; standalone environment variable names follow OS convention (`SKYFLOW_ID`, `TOKEN_URI`) + +**What to flag:** +- Any public field, method, or parameter name that uses ALL-CAPS for an acronym in a snake_case context +- Any `camelCase` field or method name in the public API +- Any class, exception, or enum name that is not `PascalCase` +- Any constant that is not `UPPER_SNAKE_CASE` +- Mixed conventions within the same class or module + +### 10. Type annotations +- All public method signatures must have parameter type annotations and a return type annotation +- `Any` usage outside `skyflow/generated/` must be justified — flag unannotated parameters that default to implicit `Any` +- `dict` without subscript (`dict` not `dict[str, Any]`) in a public signature must be flagged +- `Optional[T]` parameters must default to `None` — `def f(x: Optional[str])` without `= None` is a bug + +### 11. 
Function size and complexity +- Flag any function exceeding 50 lines — include the actual line count +- Flag nesting deeper than 3 levels — suggest early returns or extracted helpers +- Flag long `if valid: … entire body …` blocks where an inverted guard + early return would flatten the code +- Flag validator functions that duplicate logic already present in another validator — suggest extracting a shared helper + +### 12. Cross-cutting consistency +- All `XxxRequest` classes must declare optional fields with `Optional[T] = None`, not bare `= None` +- All validators must call `log_error_log` before raising `SkyflowError` — flag any that raise directly without logging +- Controller methods must use a consistent call style — flag mixing of direct API calls and `.with_raw_response` within the same controller +- Credential field access must use `CredentialField` constants, not hardcoded strings like `"clientID"` or `"tokenURI"` + +--- + +## Output format + +Group findings by file. For each issue: + +``` +[SEVERITY] file:line — Description + Problem: <what is wrong and when it breaks> + Fix: <concrete change to make> +``` + +Severity levels: **CRITICAL** | **BUG** | **EDGE CASE** | **QUALITY** + +End with: + +**Summary table** +| Severity | Count | +|---|---| +| CRITICAL | n | +| BUG | n | +| EDGE CASE | n | +| QUALITY | n | +| **Total** | n | + +**Top 5 highest-priority fixes** (by risk, not count) + +**Verdict**: `APPROVE` | `APPROVE WITH FIXES` | `REQUEST CHANGES` +One sentence explaining the verdict. diff --git a/.claude/commands/code-security.md b/.claude/commands/code-security.md new file mode 100644 index 00000000..09c74082 --- /dev/null +++ b/.claude/commands/code-security.md @@ -0,0 +1,114 @@ +--- +description: Security audit for the Skyflow Python SDK — credentials, injection, file I/O, deps, auth lifecycle, HTTP safety. +constraints: + - "NEVER edit, create, or delete any file under skyflow/generated/. Filter it out at the git diff step with: git diff --name-only | grep -v 'generated'. 
If a finding relates to generated code, report it as an observation only." +--- + +You are a security engineer auditing the Skyflow Python SDK — a Python library that handles sensitive PII and credentials. Perform a security review on the target: $ARGUMENTS + +If no argument is given, scan all files changed on the current branch vs main: +``` +git diff main...HEAD --name-only | grep -v 'generated' +``` + +> **IMPORTANT — Generated code boundary** +> `skyflow/generated/` contains Fern-generated REST client code. **Never modify any file inside `skyflow/generated/`**. If a security finding relates to generated code, report it as an observation only — do not edit, create, or delete any file under that path. + +Read each target file fully before reporting findings. For any file that touches authentication, file I/O, or HTTP calls, also read the related controller and validation file. + +--- + +## Security checks + +### 1. Credential and token exposure +- Bearer tokens, API keys, private keys, and service account credentials must **never** appear in: + - Log messages at any level (`log_info`, `log_error_log`, or any `print()`) + - `SkyflowError` message strings or detail fields + - Response objects returned to callers + - File names or paths written to disk +- Check that bearer tokens returned by `generate_bearer_token()` / `get_service_account_token()` are not stored as instance attributes on controllers or at module level — tokens must be used immediately or cached only with a validated expiry check +- Verify that credentials passed in `credentials` dicts or `CredentialField` values are not echoed back in any response, log entry, or raised exception +- Check that `except` blocks do not accidentally log the full exception object (`str(e)`, `repr(e)`) when the exception may contain auth headers or credential fields +- Private keys read from disk or passed as strings must not remain referenced after use — check they are not bound to a long-lived variable on `self` + +### 2. 
Input validation and injection +- All user-supplied strings passed to file system APIs (`open()`, `os.path.exists()`, `os.path.isfile()`, `os.makedirs()`, `shutil.*`) must be validated before use: + - Path traversal: reject strings containing `../`, `..\\`, or null bytes (`\x00`) + - Absolute paths escaping an expected base directory — verify with `os.path.realpath()` and confirm the result is still under the expected root +- Regex patterns supplied by users (e.g. `allow_regex_list`, `restrict_regex_list` in detect/deidentify requests) must be compiled inside a `try/except re.error` block before use — a malformed or catastrophic pattern causes `re.error` or ReDoS +- `base64.b64decode()` on data from API responses or user input must have a size check before decoding — an unbounded payload can cause OOM; also call with `validate=True` to reject non-base64 characters +- Ensure no user-controlled string reaches `eval()`, `exec()`, `subprocess.run/Popen` with `shell=True`, `os.system()`, or `__import__()` +- `pickle.loads()` / `pickle.load()` must never be called on data from external sources (API responses, user files, environment variables) — pickle is an arbitrary code execution vector +- `yaml.load()` must always specify `Loader=yaml.SafeLoader` — bare `yaml.load()` executes arbitrary Python +- SQL-like query strings passed to the vault API (`QueryRequest`) must be treated as opaque — verify no string interpolation of user data occurs before the SDK forwards the query to the API + +### 3. File system security +- User-supplied directory paths (e.g. 
`output_directory` in deidentify file requests) must be validated to be an existing, accessible directory using `os.path.isdir()` — the SDK must not call `os.makedirs()` on arbitrary user-supplied paths +- Output file paths must be constructed with `os.path.join(validated_base_dir, sanitized_file_name)` — never raw string concatenation or f-strings with unsanitized components +- File name components that come from API responses (extensions, type names, `processed_file_extension`) must be sanitized before use in a path — they could contain `../`, leading dots, or special characters if the response is tampered with +- File handles must always be closed via `with open(...) as f:` — bare `open()` without a context manager leaks handles on exception paths +- Temporary files must use `tempfile.NamedTemporaryFile` or `tempfile.mkstemp()`, never manually constructed paths like `"/tmp/" + user_value` +- Temporary or intermediate files written during processing must be cleaned up on both success and error paths (use `try/finally`) + +### 4. Dependency and supply chain +- Run `pip-audit` (or `safety check`) and report all HIGH and CRITICAL severity findings with CVE IDs: + ``` + pip install pip-audit && pip-audit -r requirements.txt + ``` +- Check `requirements.txt` for unpinned dependencies — floating `>=` without an upper bound on security-sensitive packages (`cryptography`, `PyJWT`, `requests`, `httpx`, `urllib3`) is risky; prefer `~=` (compatible release) or exact pins +- Flag any `importlib.import_module()` or `__import__()` call where the module name is derived from user input or environment variables +- Check that `cryptography` and `PyJWT` are recent enough to avoid known CVEs (check pinned versions against current release) + +### 5. 
Error message information leakage +- `SkyflowError` messages surfaced to callers must not include: internal file system paths, raw Python tracebacks (`traceback.format_exc()`), server-side query details, or internal service hostnames +- `except` blocks that call `log_error_log` must log only a controlled message string from `SkyflowMessages.ErrorLogs` — never `str(e)` or `repr(e)` from an exception that may contain credentials or server internals +- HTTP response bodies from API errors must be stripped of internal server details (stack traces, internal hostnames, database names) before being wrapped in `SkyflowError` +- Ensure `SkyflowError` fields (`message`, `http_code`, `grpc_code`) contain only information safe to surface to SDK consumers — no raw server error objects + +### 6. Authentication and token lifecycle +- Bearer tokens fetched via `generate_bearer_token()` must be checked with `is_expired()` immediately before the API call — the token could expire between retrieval and use (TOCTOU) +- `is_expired()` may decode the token without signature verification (`verify_signature: False`) solely to read its expiry — the unverified claims must never be used for actual authentication decisions +- JWT signing in `get_signed_jwt()` must use `RS256` or `ES256` — flag any path where `algorithm` could be set to `HS256` with a user-supplied or weak secret +- API key validation must check both the `sky-` prefix and the exact length (42 characters per `ApiKey.LENGTH`) — a check on only one criterion is insufficient +- Check that no credential type silently falls through to a less-secure fallback (e.g. ignoring an invalid token and falling back to no auth) without raising or logging a warning +- Service account credentials read from disk must validate that the file is not world-readable (`os.stat(path).st_mode`) — a credentials file with mode `0o644` is readable by all local users + +### 7. 
HTTP and network security +- All HTTP calls made through `httpx` or `requests` must use HTTPS — flag any hardcoded `http://` URL or any URL assembled from user-supplied components without a scheme check +- SSL verification must never be disabled: `verify=False` in `httpx.Client()` or `requests.get(..., verify=False)` bypasses certificate validation and enables MITM attacks +- Auth headers (`Authorization`, `X-Skyflow-Authorization`) must not be logged, echoed in error messages, or forwarded to redirect targets +- `x-request-id` and other response headers must be read only — verify they are never interpolated into subsequent outgoing requests (header injection) +- HTTP clients must be configured with connection and read timeouts — an absent timeout means a hung server stalls the calling thread indefinitely; check for `timeout=` parameters on all `httpx` / `requests` calls +- Check that the SDK does not silently retry on `401 Unauthorized` without first refreshing the token — silent retry with a stale token masks credential rotation failures + +### 8. Secrets in source and configuration +- Check all files for accidentally committed secrets: JWT private keys, API keys matching `sky-[A-Za-z0-9]{38}`, PEM blocks (`-----BEGIN RSA PRIVATE KEY-----`), or base64 blobs > 200 chars in non-test files +- `.env` files and `credentials.json` must be listed in `.gitignore` — verify they are not tracked +- Test fixture files (`tests/`) must use obviously fake credentials (e.g. dummy PEM blocks, placeholder strings) — real-looking keys in tests are a supply-chain risk if the repo becomes public +- Environment variable reads (`os.environ.get(...)`) for credential material must have a documented fallback policy — silently returning `None` and continuing is an authentication bypass + +### 9. 
Generated code (`skyflow/generated/`) +- Do not modify — but verify it uses HTTPS only, applies auth headers from the SDK layer, and does not hardcode environment-specific endpoints +- Flag if the generated client disables SSL verification or sets a permissive timeout of `None` +- Flag if the generated client silently retries on auth failure without surfacing the error to the caller + +--- + +## Output format + +For each finding: + +``` +[SEVERITY] file:line — Finding title + Risk: <what an attacker or failure can cause> + Trigger: <input or condition that reaches the flaw> + Fix: <concrete remediation> + CWE: <CWE identifier, e.g. CWE-22> +``` + +Severity: **CRITICAL** | **HIGH** | **MEDIUM** | **LOW** | **INFO** + +End with: +1. Overall risk rating (Critical / High / Medium / Low) +2. Top 3 highest-priority fixes +3. Whether the code is safe to ship as-is diff --git a/.claude/commands/code-smell.md b/.claude/commands/code-smell.md new file mode 100644 index 00000000..acfebb12 --- /dev/null +++ b/.claude/commands/code-smell.md @@ -0,0 +1,147 @@ +--- +description: Code smell and anti-pattern detection for the Skyflow Python SDK — duplication, dead code, magic values, naming, complexity, Python-specific pitfalls. +constraints: + - "NEVER edit, create, or delete any file under skyflow/generated/. Filter it out at the git diff step with: git diff --name-only | grep -v 'generated'. If a smell is found in generated code, report it as an observation only." +--- + +You are a Python code quality engineer. Identify code smells, anti-patterns, and maintainability issues in the Skyflow Python SDK. Target: $ARGUMENTS + +If no argument is given, scan all files changed on the current branch vs main: +``` +git diff main...HEAD --name-only | grep -v 'generated' +``` + +> **IMPORTANT — Generated code boundary** +> `skyflow/generated/` contains Fern-generated REST client code. **Never modify any file inside `skyflow/generated/`**. If a smell is found in generated code, report it as an observation only — do not edit, create, or delete any file under that path. + +Read each target file fully before reporting findings. 
+ +--- + +## Code smell categories + +### 1. Python anti-patterns +- **Mutable default arguments** — `def f(self, items=[], config={})` creates a single shared object across all calls; replace with `None` and initialise inside the body (`if items is None: items = []`) +- **Bare `except:`** without an exception type — catches `KeyboardInterrupt`, `SystemExit`, and `GeneratorExit`, which should never be silently swallowed; replace with `except Exception:` +- **`except Exception: pass`** or **`except Exception: continue`** — silently swallows all errors; at minimum log the exception before swallowing +- **`except Exception as e: raise e`** with no added context — no-op catch that obscures the real traceback; remove it or add logging before re-raising +- **Truthy guard on values where `0`, `""`, `False`, or `[]` are valid** — `if not x:` treats these as falsy when they may be intentional; use `if x is None:` or `if x is not None:` +- **`== None` / `!= None`** — use `is None` / `is not None`; `==` calls `__eq__` and can be overridden by any object +- **`from module import *`** outside `__init__.py` — pollutes the namespace and makes dependency tracking impossible; list explicit names +- **`global` or `nonlocal` inside controller / validator code** — hidden state mutation; use return values or class attributes instead +- **String concatenation in a loop** (`result += s`) — quadratic time; use `"".join(parts)` instead +- **`type(x) == SomeType`** instead of `isinstance(x, SomeType)` — misses subclasses; use `isinstance` unless an exact type check is intentional + +### 2. I/O and resource management +- **`open()` without a context manager** — file handle is left open on exception; always use `with open(...) 
as f:` +- **`open()` for binary content with mode `"r"`** instead of `"rb"` — silently corrupts binary data on Windows and may crash on non-UTF-8 content +- **Blocking I/O inside a hot loop** — file reads, network calls, or `time.sleep()` inside a tight loop without batching; flag any loop that makes one network call per item when a batch API exists +- **No `finally` / context manager for cleanup** — temporary files, open handles, or partially-written outputs not cleaned up on the error path +- **`time.sleep()` in non-test code without a comment** — an unexplained blocking wait, often a polling loop without a timeout ceiling; flag and ask whether exponential backoff or a bounded retry is more appropriate + +### 3. Duplication +- Identical or near-identical validation blocks (type check → empty check → element check) copy-pasted across multiple `validate_xxx` functions — extract a reusable `_validate_required_field()` helper +- The same `try: … except: raise SkyflowError(SkyflowMessages.Error.XXX.value, invalid_input_error_code)` wrapper repeated for identical failure modes — consolidate into a shared utility +- `_build_xxx_body()` / `_build_xxx_records()` methods that follow the same structure but differ only by request type — consider a shared builder parameterised by the differing fields +- Log calls (`log_info(SkyflowMessages.Info.XXX_TRIGGERED.value, logger)`) copy-pasted at the start of every controller method with no variation — consider a decorator or base-class hook for entry/exit logging +- Credential field extraction blocks (read `privateKey`, `clientID`, `keyID`, `tokenURI` with individual `try/except` per field) repeated across functions — extract a `_extract_credential_fields(credentials)` helper + +### 4. 
Dead / no-op code +- Variables assigned but never read after assignment +- Imports that are never referenced — run `flake8 --select=F401` or `ruff check --select=F401` to identify +- `if False:` / `if True:` branches — dead by construction +- `return None` at the end of a function that already returns `None` implicitly — redundant +- `# TODO`, `# FIXME`, or `# HACK` comments without a linked ticket — untracked TODOs rot permanently; flag all occurrences +- Code guarded by a Python version check that can never be `False` given the SDK's stated minimum Python version +- `pass` in an `except` block with a comment explaining the exception is expected — convert to `except SpecificError:` with a short inline explanation, or raise if the exception should not be swallowed + +### 5. Magic values +- Hardcoded status strings (`"IN_PROGRESS"`, `"SUCCESS"`, `"FAILED"`, `"UNKNOWN"`) in controller logic — use the corresponding enum values from `DetectStatus` or equivalent +- Hardcoded numeric literals for timeouts, retry counts, or max sizes with no named constant — define in `constants.py` with an `UPPER_SNAKE_CASE` name and a comment explaining the constraint +- Hardcoded credential key strings (`"clientID"`, `"tokenURI"`, `"privateKey"`) outside the `_normalize_credentials` mapping — all credential field access must go through `CredentialField` constants +- Hardcoded grant type string `"urn:ietf:params:oauth:grant-type:jwt-bearer"` repeated at multiple call sites — define once in the `JWT` constants class +- Hardcoded algorithm string `"RS256"` outside `JWT.ALGORITHM_RS256` — use the constant everywhere +- Hardcoded integer HTTP codes (200, 400, 500) in business logic — use `HttpStatusCode` constants + +### 6. 
Naming and clarity +- Vague variable names (`data`, `obj`, `temp`, `res`, `req`, `result`) in non-trivial logic — name after the domain concept (`insert_response`, `token_claims`, `credentials_dict`) +- Single-letter loop variables (`i`, `k`, `v`) outside trivial numeric ranges — use descriptive names (`token`, `record`, `vault_id`) +- Method names that imply a return value but only produce a side effect, or vice versa — align name with behaviour +- Typos in public class, method, or constant names — these are breaking changes once shipped; flag immediately + +**Python naming conventions — flag any violation:** + +| Identifier type | Required style | Example | +|---|---|---| +| Variable / parameter / method | `snake_case` | `vault_id`, `token_uri`, `get_records` | +| Constant / module-level value | `UPPER_SNAKE_CASE` | `SKY_META_DATA_HEADER`, `CTX_KEY_REGEX` | +| Class / Exception / Enum | `PascalCase` | `InsertRequest`, `SkyflowError`, `RedactionType` | +| Private method / attribute | `_snake_case` | `_validate_ctx`, `_cached_headers` | +| Source file | `snake_case.py` or `_snake_case.py` for internals | `_file_upload_request.py` | + +**Acronym rule — all-lowercase in snake_case identifiers:** +- `id` not `ID` (e.g. `skyflow_id`, not `skyflow_ID`) +- `uri` not `URI` (e.g. `token_uri`, not `token_URI`) +- `url` not `URL` (e.g. `callback_url`, not `callback_URL`) +- `api` not `API` (e.g. `api_key`, not `API_key`) +- Exception: class names use PascalCase title-casing (`SkyflowId`, `TokenUri`); standalone environment variable names follow OS convention (`SKYFLOW_ID`, `TOKEN_URI`) +- Mixed conventions within the same class or module — flag as inconsistency +- Any `camelCase` field, method, or local variable in SDK source (not in external credential dicts) — belongs in JavaScript/TypeScript, not Python + +### 7. 
Single responsibility +- Controller methods that mix HTTP transport with business logic (polling, file writing, response transformation) — each concern belongs in its own private method +- Validator functions that both validate and transform/normalise data — validation must only raise or pass; normalisation is a separate step (e.g. `_normalize_credentials`) +- Response parsers that also write to the file system — parsing and I/O must be separate functions +- `__init__` methods that perform network I/O or file reads — construction must be side-effect free; defer to an explicit `initialize()` or lazy-init method +- A single `if/elif` chain longer than 5 branches dispatching on a string type — extract a dispatch table (`dict` mapping type → handler) or split into separate methods + +### 8. Consistency across the SDK +- Some controller methods call `self.__initialize()` before the API call; others do not — this must be consistent for all public methods that require an initialised client +- Some validators call `log_error_log` before raising `SkyflowError`; others raise directly without logging — every validator must log before it raises +- Credential field access mixes `CredentialField` constants with hardcoded strings — use constants everywhere; mixing makes refactoring dangerous +- Some `XxxRequest` classes use `Optional[T] = None` for optional fields; others use bare `= None` without a type annotation — all optional fields must be annotated `Optional[T] = None` +- Required parameters placed after optional (`= None`) parameters in `__init__` — required parameters must always come first; optional at the end +- `log_info` call style varies (keyword vs positional `logger` argument) — pick one style per function signature and apply consistently +- HTTP response parsing uses `.with_raw_response` in some controllers and not in others — document the reason or standardise + +### 9. 
Validator completeness +- Validators that raise `SkyflowError` without calling `log_error_log(SkyflowMessages.ErrorLogs.xxx.value)` first — without the log there is no trace of the failure in production +- Validators that use `if not x:` instead of `if x is None:` — the truthy check silently rejects `0`, `""`, and `False` as invalid when they may be legitimate values +- Inconsistent null-guard style across validators (`if not x`, `if x is None`, `if x == None`) — pick one style and apply it everywhere so edge cases are not handled in some paths but missed in others +- Validators that check a field's presence but not its type, or its type but not its emptiness — a complete field check is: present → correct type → non-empty / valid value +- Missing conditional validation for optional fields that carry constraints when provided (e.g. a string that must be a valid URL when non-`None`) — the validator must include an `if field is not None:` branch for such fields + +### 10. Function size and complexity +- Any function exceeding 50 lines — flag it with its actual line count; a function that long almost always violates single responsibility and is hard to test in isolation +- Nesting deeper than 3 levels (`if` inside `if` inside `for` inside `if`) — deep nesting hides edge cases; extract inner blocks into named helpers or use early returns +- **Missing guard / early return** — when a function opens with a long `if valid: … entire body …` block instead of `if not valid: raise/return`, the happy path is buried; invert the guard and return early +- Long `try` blocks spanning more than 10 lines — it is unclear which statement actually raised; split into one `try/except` per operation +- Chained ternary expressions (`a if cond1 else b if cond2 else c if cond3 else d`) — unreadable beyond two levels; use an explicit `if/elif/else` block + +### 11. 
Type annotation completeness +- Public method signatures missing parameter type annotations or a return type annotation — all public API surface must be fully annotated +- `dict` without subscript (`dict` instead of `dict[str, Any]`) in a public method signature — annotate key and value types explicitly +- `Optional[T]` parameters that do not default to `None` — `def f(x: Optional[str])` without `= None` is misleading and forces callers to pass `None` explicitly +- `Any` usage outside `skyflow/generated/` without a comment explaining why — every `Any` is a type-safety hole; flag all unannotated occurrences +- Inconsistent annotation style: some files use `Optional[T]`, others use `T | None` union syntax — pick one style for the minimum supported Python version and apply it throughout + +### 12. Error handling completeness +- `except Exception as e: raise SkyflowError(...)` that does not preserve the original error text — include `str(e)` or chain with `raise SkyflowError(...) from e` so the original context survives into production logs +- `except Exception` that catches a broad class but only handles one specific subtype, silently ignoring all others — add an `else: raise` or handle all expected subtypes explicitly +- Missing `finally` block when a resource (file, network connection, lock) is acquired inside a `try` — a naked `try/except` that does not release on the error path leaks the resource + +--- + +## Output format + +Group findings by category. For each smell: + +``` +[CATEGORY] file:line — Smell name + Why it's a problem: <1 sentence> + Suggested fix: +``` + +End with: +1. Top 5 highest-impact refactors (effort vs. benefit) +2. 
Estimated tech-debt score per category (High / Medium / Low) diff --git a/.claude/commands/sdk-sample.md b/.claude/commands/sdk-sample.md new file mode 100644 index 00000000..5e449697 --- /dev/null +++ b/.claude/commands/sdk-sample.md @@ -0,0 +1,145 @@ +--- +description: Generate a new Skyflow Python SDK sample file demonstrating a specific feature with proper patterns, error handling, and idiomatic style. +constraints: + - "NEVER edit, create, or delete any file under skyflow/generated/. Samples must import only from the public 'skyflow' package, never directly from skyflow/generated/." +--- + +Create a new Skyflow Python SDK sample. Feature to demonstrate: $ARGUMENTS + +> **IMPORTANT — Generated code boundary** +> `skyflow/generated/` contains Fern-generated REST client code. **Never modify any file inside `skyflow/generated/`**. Samples must import only from the public `skyflow` package, never directly from `skyflow/generated/`. + +Read the following before writing anything: +1. An existing sample in `samples/vault_api/` that is closest to the requested feature, to understand structure and style +2. `skyflow/__init__.py` to see what is exported from the top-level package +3. The relevant request class in `skyflow/vault/data/`, `skyflow/vault/tokens/`, `skyflow/vault/detect/`, or `skyflow/service_account/` to understand the constructor signature + +--- + +## Sample requirements + +### File location +- Vault CRUD operations → `samples/vault_api/.py` +- Token operations (tokenize, detokenize) → `samples/vault_api/_records.py` +- Connection operations → `samples/vault_api/invoke_connection.py` (extend existing) +- Service account / auth → `samples/service_account/_example.py` +- Detect / PII operations → `samples/detect_api/.py` + +### Required structure (in this order) +1. **Import block** — only from the public `skyflow` package; never from `skyflow.generated` +2. **Module-level docstring** — triple-quoted, listing what the sample demonstrates (numbered list) +3. 
**`cred` dict** — service account credential fields using their camelCase key names (`clientID`, `clientName`, `tokenURI`, `keyID`, `privateKey`), with placeholder strings as values +4. **`skyflow_credentials` dict** — wrap `cred` as `{'credentials_string': json.dumps(cred)}` +5. **`credentials` dict** — bearer token credential: `{'token': '<YOUR_BEARER_TOKEN>'}` (shown as the per-vault override) +6. **`primary_vault_config` dict** — `vault_id`, `cluster_id`, `env: Env.PROD`, `credentials` +7. **`skyflow_client`** built with the fluent builder pattern: `.builder().add_vault_config(...).add_skyflow_credentials(...).set_log_level(LogLevel.ERROR).build()` +8. **Request object** constructed using the appropriate `XxxRequest` class with keyword arguments +9. **SDK call** through `skyflow_client.vault(vault_id).operation(request)` +10. **`print()` on success** — print the response object +11. **`except SkyflowError`** block with `http_code`, `message`, `details` +12. **`except Exception`** block for unexpected errors +13. **Top-level call** to the `perform_xxx()` function at the bottom of the file + +### Code style rules +- Function name: `perform_<operation>()` (snake_case, descriptive of the operation) +- Placeholder values: `'<YOUR_VAULT_ID>'`, `'<YOUR_CLUSTER_ID>'`, `'<YOUR_CLIENT_ID>'`, `'<YOUR_BEARER_TOKEN>'`, `'<YOUR_CREDENTIALS_FILE_PATH>'`, etc. +- Step comments: number each logical phase as `# Step 1: Configure Credentials`, `# Step 2: Configure Vault`, `# Step 3: Configure & Initialize Skyflow Client`, `# Step 4: Prepare Data`, `# Step 5: Perform <Operation>` +- Every non-obvious argument gets a short inline comment explaining WHY (not what): e.g. 
`return_tokens=True # Get tokens for inserted data` +- Keep the sample under 100 lines — if more is needed, split into focused examples +- No `print()` on success other than printing the response +- No hardcoded real credentials, private keys, or tokens — use `'<PLACEHOLDER>'` strings throughout + +### Template +```python +import json +from skyflow.error import SkyflowError +from skyflow import Env +from skyflow import Skyflow, LogLevel +from skyflow.vault.data import XxxRequest # replace with actual import(s) + +""" + * Skyflow Example + * + * This example demonstrates how to: + * 1. Configure Skyflow client credentials + * 2. Set up vault configuration + * 3. Create a request + * 4. Handle response and errors +""" + +def perform_<operation>(): + try: + # Step 1: Configure Credentials + cred = { + 'clientID': '<YOUR_CLIENT_ID>', # Client identifier + 'clientName': '<YOUR_CLIENT_NAME>', # Client name + 'tokenURI': '<YOUR_TOKEN_URI>', # Token URI + 'keyID': '<YOUR_KEY_ID>', # Key identifier + 'privateKey': '<YOUR_PRIVATE_KEY>', # Private key for authentication + } + + skyflow_credentials = { + 'credentials_string': json.dumps(cred) # Service account credentials + } + + credentials = { + 'token': '<YOUR_BEARER_TOKEN>' # Bearer token for this vault + } + + # Step 2: Configure Vault + primary_vault_config = { + 'vault_id': '<YOUR_VAULT_ID>', # Primary vault + 'cluster_id': '<YOUR_CLUSTER_ID>', # Cluster ID from your vault URL + 'env': Env.PROD, # Deployment environment (PROD by default) + 'credentials': credentials # Per-vault authentication method + } + + # Step 3: Configure & Initialize Skyflow Client + skyflow_client = ( + Skyflow.builder() + .add_vault_config(primary_vault_config) + .add_skyflow_credentials(skyflow_credentials) # Fallback if no vault-level credentials + .set_log_level(LogLevel.ERROR) # Logging verbosity + .build() + ) + + # Step 4: Prepare Data + # ... build request object ... 
+ request = XxxRequest( + # keyword arguments here + ) + + # Step 5: Perform <Operation> + response = skyflow_client.vault(primary_vault_config.get('vault_id')).<operation>(request) + + # Handle Successful Response + print('<Operation> successful: ', response) + + except SkyflowError as error: + print('Skyflow Specific Error: ', { + 'code': error.http_code, + 'message': error.message, + 'details': error.details + }) + except Exception as error: + print('Unexpected Error:', error) + + +# Invoke the function +perform_<operation>() +``` + +### Service account samples (auth only) +For `samples/service_account/` samples that do not use the vault client, omit the `Skyflow` builder block and instead import directly from `skyflow.service_account`. Keep a module-level `bearer_token = ''` variable (declared with `global bearer_token` inside the function that reads it) and check it with an `is_expired()` guard before calling `generate_bearer_token()` or `generate_bearer_token_from_creds()`. Use `file_path = '<YOUR_CREDENTIALS_FILE_PATH>'` as the placeholder. + +### Detect API samples +For `samples/detect_api/` samples, import request classes from `skyflow.vault.detect`. The client accessor is `skyflow_client.detect()`, followed by the detect controller method (e.g. `skyflow_client.detect().deidentify_file(request)`). Follow the same five-step structure. + +--- + +After creating the file: +1. Run a syntax check on the sample (this only parses the file — it does not import or execute it, so import errors are not detected): + ``` + python -c "import ast; ast.parse(open('samples/<category>/<sample_name>.py').read()); print('Syntax OK')" + ``` +2. Report the file path created and any import or syntax issues to fix diff --git a/.claude/commands/test.md b/.claude/commands/test.md new file mode 100644 index 00000000..231143bf --- /dev/null +++ b/.claude/commands/test.md @@ -0,0 +1,129 @@ +--- +description: Full quality pipeline for the Skyflow Python SDK — lint, tests, coverage, and edge case analysis. +constraints: + - "NEVER edit, create, or delete any file under skyflow/generated/. Always filter generated files before passing any file list to lint: git diff --name-only | grep -E '\\.py$' | grep -v 'generated'. 
If analysis touches generated code, report it as an observation only." +--- + +Run the full Skyflow Python SDK quality pipeline and report results. Target (optional — specific test file, module path, or keyword pattern): $ARGUMENTS + +> **IMPORTANT — Generated code boundary** +> `skyflow/generated/` contains Fern-generated REST client code. **Never modify any file inside `skyflow/generated/`**. If analysis or test generation touches generated code, report it as an observation only — do not edit, create, or delete any file under that path. + +Execute the following steps in order. Stop and report clearly if any step fails. + +--- + +## Step 1 — Lint + +Lint only the files that have changed, not the entire repo. + +First, get the list of changed `.py` files (staged and unstaged, relative to `HEAD`), excluding generated code: +```bash +git diff --name-only HEAD | grep -E '\.py$' | grep -v 'generated' +``` +If that returns nothing (e.g. the working tree is clean because every change is already committed), fall back to the files changed on the branch: +```bash +git diff --name-only main...HEAD | grep -E '\.py$' | grep -v 'generated' +``` +If `$ARGUMENTS` is a file path, use that file directly instead. + +Then run ruff on those files: +```bash +.venv/bin/python -m ruff check <changed_files> +``` + +If ruff is not installed, fall back to flake8: +```bash +.venv/bin/python -m flake8 <changed_files> +``` + +Report any violations including file name and line number. A clean lint is required before proceeding. + +--- + +## Step 2 — Tests + +If `$ARGUMENTS` is provided and is a file or directory path, run only matching tests: +```bash +.venv/bin/python -m pytest "$ARGUMENTS" --import-mode=importlib -v +``` + +If `$ARGUMENTS` is a keyword pattern (no path separators), use `-k`: +```bash +.venv/bin/python -m pytest tests/ --import-mode=importlib -k "$ARGUMENTS" -v +``` + +Otherwise run the full suite: +```bash +.venv/bin/python -m pytest tests/ --import-mode=importlib -q +``` + +Report the count of passed, failed, and skipped tests. If any test fails, print the failure message in full and stop. 
+ +--- + +## Step 3 — Coverage analysis + +Run the test suite under coverage, scoped to the `skyflow/` package and excluding generated code (pytest-cov has no `--cov-omit` flag, so use `coverage run --omit`): +```bash +.venv/bin/python -m coverage run \ + --source=skyflow \ + --omit="skyflow/generated/*" \ + -m pytest tests/ --import-mode=importlib -q +.venv/bin/python -m coverage report --show-missing +``` + +After tests complete, analyze the coverage output: +- Report overall line coverage % +- **Line coverage must be 100%** for all modules under `skyflow/` (excluding `skyflow/generated/`) — flag every file below 100% +- Flag any file with branch coverage below 80% if branch coverage is enabled +- For every flagged file, list the exact uncovered line numbers and what scenario they represent + +--- + +## Step 4 — Edge case analysis and test generation + +For every file under `skyflow/` that has line coverage below 100%, **or** for the target module if `$ARGUMENTS` is provided: + +### 4a — Identify uncovered edge cases + +Read the source file and its corresponding test file in `tests/`. For each uncovered branch or line, identify the missing scenario. Focus on: +- `None` inputs to public methods +- Empty strings, empty lists, zero or negative numbers +- Wrong types passed where a specific type is expected +- Error paths (`SkyflowError` raises) that are never triggered in tests +- Boundary conditions (exactly at limit vs. 
one over) +- Credential fields missing, malformed, or using snake_case vs camelCase variants +- Validator branches that check presence and type separately but only one branch is tested + +List each gap as: +``` +UNCOVERED: : +``` + +### 4b — Write the missing unit tests + +For each identified gap, write a concrete test case using the project's existing conventions: +- Test files live in `tests/vault/controller/`, `tests/service_account/`, `tests/utils/`, or `tests/vault/` — pick the nearest existing test file +- Follow the `class TestXxx(unittest.TestCase)` structure used throughout the project +- Mock external dependencies (API calls, `generate_bearer_token`, `AuthClient`) using `@patch` decorators, matching the pattern in existing tests +- Each test must: arrange inputs → call the method → assert the outcome (return value or raised `SkyflowError`) +- Use `self.assertRaises(SkyflowError)` with `.exception.message` assertions to match existing test style +- Do **not** remove or modify existing tests + +Output new tests as a ready-to-paste code block, clearly labelled with the target test file path. + +--- + +## Step 5 — Report + +Produce a summary table: + +| Step | Status | Notes | +|------|--------|-------| +| Lint | PASS / FAIL | violation count | +| Tests | PASS / FAIL | X passed, Y failed, Z skipped | +| Coverage | % | files below 100% threshold listed | +| Edge cases | n found | n tests written | + +End with: **READY TO MERGE** or **NEEDS FIXES**, followed by a bullet list of what must be addressed before merging. 
diff --git a/.claude/settings.json b/.claude/settings.json new file mode 100644 index 00000000..73f068e0 --- /dev/null +++ b/.claude/settings.json @@ -0,0 +1,42 @@ +{ + "hooks": { + "PostToolUse": [ + { + "matcher": "Edit|Write", + "hooks": [ + { + "type": "command", + "command": "python3 -c \"\nimport sys, json, subprocess\nd = json.load(sys.stdin)\nf = d.get('tool_input', {}).get('file_path', d.get('file_path', ''))\nif f and f.endswith('.py') and 'generated' not in f:\n subprocess.run(['.venv/bin/python', '-m', 'ruff', 'check', '--fix', f], capture_output=True)\n subprocess.run(['.venv/bin/python', '-m', 'ruff', 'format', f], capture_output=True)\n\"" + } + ] + } + ], + "PreToolUse": [ + { + "matcher": "Edit|Write", + "hooks": [ + { + "type": "command", + "command": "python3 -c \"import sys,json; d=json.load(sys.stdin); p=d.get('tool_input',{}).get('file_path',d.get('file_path','')); banned='skyflow/generated'; (sys.stderr.write('BLOCKED: Fern-generated code — do not edit manually\\n'), sys.exit(2)) if banned in p else sys.exit(0)\"" + } + ] + } + ], + "Stop": [ + { + "hooks": [ + { + "type": "command", + "command": "notify-send 'Claude Code' 'Claude finished' 2>/dev/null || true" + } + ] + } + ] + }, + "permissions": { + "deny": [ + "Edit(skyflow/generated/**)", + "Write(skyflow/generated/**)" + ] + } +} diff --git a/.claude/skills/pr-review/SKILL.md b/.claude/skills/pr-review/SKILL.md new file mode 100644 index 00000000..50f53593 --- /dev/null +++ b/.claude/skills/pr-review/SKILL.md @@ -0,0 +1,222 @@ +--- +name: pr-review +description: Full pull-request review for the Skyflow Python SDK. Covers code quality, SDK patterns, tests, docs, security, breaking-change detection, naming conventions, edge cases, and coverage. +constraints: + - "NEVER edit, create, or delete any file under skyflow/generated/. Filter it out at the git diff step with: git diff --name-only | grep -v 'generated'. If a finding relates to generated code, report it as an observation only." 
+--- + +You are a senior engineer conducting a pull-request review for the Skyflow Python SDK — a Python client library for Skyflow's data privacy vault. + +Review the target PR or branch: $ARGUMENTS + +If no argument is given, diff the current branch against main: +``` +git diff main...HEAD +git diff main...HEAD --name-only +git log main...HEAD --oneline +``` + +> **IMPORTANT — Generated code boundary** +> `skyflow/generated/` contains Fern-generated REST client code. **Never modify any file inside `skyflow/generated/`**. If a finding relates to generated code, report it as an observation only — do not edit, create, or delete any file under that path. + +--- + +## Step 1 — Understand the change + +- Summarise what the PR does in 2–3 sentences. +- List every file changed, grouped by layer: data model / controller / validation / service_account / tests / samples / exports / docs. +- Identify the primary change type: **new feature** | **bug fix** | **refactor** | **docs** | **dependency update**. + +--- + +## Step 2 — Breaking-change detection + +- [ ] No public class or method removed from `skyflow/__init__.py` +- [ ] No public method signature changed (parameter added/removed/reordered, type narrowed, default removed) +- [ ] No `XxxRequest` / `XxxResponse` field renamed or removed +- [ ] No error code value changed (string identity in `SkyflowMessages.Error` or `SkyflowMessages.ErrorCodes`) +- [ ] No `LogLevel` or `Env` enum value changed +- [ ] No `CredentialField` or `RedactionType` enum value changed +- [ ] If a breaking change exists: is it intentional and documented in `CHANGELOG.md`? + +Flag any breaking change as **BREAKING** even if intentional. 
+ +--- + +## Step 3 — SDK pattern compliance + +### Request / Response +- [ ] Every new public operation uses `XxxRequest` and `XxxResponse` dedicated classes +- [ ] Request classes declare all fields with explicit `Optional[T] = None` for optional fields — never bare `= None` without a type annotation +- [ ] Required parameters come before optional ones in `__init__` signatures +- [ ] Response objects are plain data containers — no business logic, no API calls inside them +- [ ] No controller method that accepts or returns a plain `dict` instead of a typed class + +### Validation +- [ ] Every public controller method calls its `validate_xxx_request()` from `skyflow/utils/validations/_validations.py` **before** any API call +- [ ] Validators raise `SkyflowError` with a code from `SkyflowMessages.ErrorCodes` +- [ ] `log_error_log(SkyflowMessages.ErrorLogs.xxx.value, logger)` called **before** raising — never after +- [ ] Edge cases covered: `None` inputs, empty strings, empty lists, wrong types, negative numbers +- [ ] No truthy guard `if not x:` for values where `0`, `""`, `False`, or `[]` could be valid — use `x is None` instead +- [ ] Consistent null-guard style across all validators in the changed files — no mixing of `if not x`, `if x is None`, `if x == None` + +### Error handling +- [ ] All methods calling the REST API are wrapped in `try/except Exception` that calls `handle_exception(e, logger)` or raises `SkyflowError` +- [ ] No silent error swallowing (`except: pass` or `except Exception: pass`) +- [ ] No no-op catch (`except Exception as e: raise e` with no added value — no logging, no wrapping) +- [ ] `SkyflowError` raised with the correct code for the failure type (`INVALID_INPUT` for validation, server codes for API failures) +- [ ] No bare `except:` — always `except Exception:` or a specific exception type + +### I/O and resource patterns +- [ ] `open()` always used as a context manager (`with open(...) 
as f:`) — never left open on exception +- [ ] Binary file reads use mode `"rb"` not `"r"` +- [ ] No module-level side effects that execute on import (network calls, file I/O, env reads at module load time) +- [ ] No blocking I/O or `time.sleep()` in a tight loop without a comment explaining the retry strategy + +### Python quality +- [ ] No mutable default arguments: `def f(x=[])`, `def f(x={})` — use `None` and initialise in the body +- [ ] No `from module import *` in non-`__init__.py` files +- [ ] No bare `except:` — always a specific exception type +- [ ] No `global` or `nonlocal` in controller or validator code +- [ ] `isinstance()` preferred over `type(x) == T` — covers subclasses +- [ ] `is None` / `is not None` preferred over `== None` / `!= None` +- [ ] f-strings preferred over `%` formatting and `.format()` for new code; flag inconsistency within the same file +- [ ] `Optional[T]` imported from `typing` — do not use `T | None` union syntax unless minimum Python version is ≥ 3.10 + +### Function size and complexity +- [ ] No function exceeds 50 lines — flag with actual line count +- [ ] No nesting deeper than 3 levels — suggest early returns or extracted helpers +- [ ] Long `if valid: ...entire body...` blocks replaced with inverted guard + early return +- [ ] No long `try` block spanning more than 10 lines — unclear which statement actually raised; split into one `try/except` per operation + +### State / side effects +- [ ] Controller instances are stateless per-call — no `self.xxx = ` inside method bodies; use local variables +- [ ] No mutable class-level variables shared across instances +- [ ] Cached state (e.g. 
`_cached_headers`) does not mix per-call and cross-call data + +--- + +## Step 4 — Exports and public API surface + +- [ ] New public types/classes exported from the appropriate `__init__.py` +- [ ] Internal helpers (prefixed `_`) not accidentally exported +- [ ] No circular imports introduced +- [ ] `skyflow/__init__.py` exports updated if new public surface was added + +--- + +## Step 5 — Tests + +- [ ] Test file exists at `tests/vault/controller/`, `tests/service_account/`, or `tests/utils/` +- [ ] Happy path covered +- [ ] One test per new error code / validation branch +- [ ] API error response handled +- [ ] Async / polling / retry logic tested if applicable +- [ ] No tests removed without explanation +- [ ] `.venv/bin/python -m pytest tests/ --import-mode=importlib -q` passes with no new failures +- [ ] **100% line coverage** on every new or changed file — flag any uncovered lines +- [ ] **Branch coverage ≥ 80%** on every new or changed file + +**Edge case coverage check — for every new/changed source file:** +Read the source and test file together. Flag any of these scenarios that exist in the code but have no test: +- `None` passed to public methods +- Empty string, empty list, zero, negative number inputs +- Wrong type passed where a specific type is expected +- Error paths inside `except` blocks that are never triggered +- Boundary conditions (exactly at a limit vs. 
one over) +- Credential fields missing, malformed, or using snake_case vs camelCase variants +- Same controller/client reused across two consecutive calls (state leakage) + +List each gap as: +``` +UNCOVERED EDGE CASE: : +``` + +--- + +## Step 6 — Sample code + +- [ ] Sample exists in `samples/vault_api/`, `samples/detect_api/`, or `samples/service_account/` +- [ ] Imports only from the public `skyflow` package — never from `skyflow.generated` +- [ ] Uses `XxxRequest` classes with keyword arguments +- [ ] Shows `try/except SkyflowError` handling with `http_code`, `message`, `details` +- [ ] Uses the fluent `Skyflow.builder()` pattern +- [ ] No hardcoded real credentials, tokens, or private keys — uses `''` strings +- [ ] Syntax-checks cleanly: `python -c "import ast; ast.parse(open('samples/...').read()); print('Syntax OK')"` + +--- + +## Step 7 — Documentation & logging + +- [ ] Docstring on new public-facing classes and methods where the behaviour is non-obvious +- [ ] No `print()` in SDK source — uses `log_info(message, logger)` for info and `log_error_log(message, logger)` for errors +- [ ] Sensitive values (tokens, credentials, private keys, PII) never appear in log messages or error strings +- [ ] Every public method entry and success path has a corresponding `log_info` call matching the `SkyflowMessages.Info` enum +- [ ] `CHANGELOG.md` updated if this is a user-visible change + +--- + +## Step 8 — Security spot-check + +- [ ] No credentials or tokens hardcoded or logged +- [ ] No new external HTTP calls outside the generated REST client +- [ ] No `eval()`, `exec()`, `pickle.loads()`, or `yaml.load()` without `Loader=yaml.SafeLoader` +- [ ] No `subprocess` calls with user-controlled input (shell injection risk) +- [ ] No `open()` with a user-supplied path that is not sanitised (path traversal risk) +- [ ] New dependencies in `requirements.txt` or `setup.py` are well-known, maintained packages — flag any unfamiliar ones +- [ ] No `verify=False` in 
`httpx`/`requests` calls (disables TLS verification) + +--- + +## Step 9 — Naming conventions (Python) + +| Identifier type | Required style | Example | +|---|---|---| +| Variable / parameter / method | `snake_case` | `vault_id`, `token_uri`, `get_records` | +| Constant / module-level value | `UPPER_SNAKE_CASE` | `SKY_META_DATA_HEADER`, `CTX_KEY_REGEX` | +| Class / Exception / Enum | `PascalCase` | `InsertRequest`, `SkyflowError`, `RedactionType` | +| Private method / attribute | `_snake_case` | `_validate_ctx`, `_cached_headers` | +| Source file | `snake_case.py` or `_snake_case.py` for internals | `_file_upload_request.py` | + +**Acronym rule — all-lowercase in snake_case identifiers:** +- `id` not `ID` (e.g. `skyflow_id`, not `skyflow_ID`) +- `uri` not `URI` (e.g. `token_uri`, not `token_URI`) +- `url` not `URL` (e.g. `callback_url`, not `callback_URL`) +- `api` not `API` (e.g. `api_key`, not `API_key`) +- Exception: class names use PascalCase title-casing (`SkyflowId`, `TokenUri`); standalone environment variable names follow OS convention (`SKYFLOW_ID`, `TOKEN_URI`) + +- [ ] No ALL-CAPS acronyms in public field, method, parameter, or variable names +- [ ] No `camelCase` fields or methods in the public API (external credential dicts are exempt) +- [ ] All classes, exceptions, and enums are `PascalCase` +- [ ] All module-level constants are `UPPER_SNAKE_CASE` +- [ ] No mixed conventions within the same class or module + +--- + +## Output format + +For each issue, output: + +``` +[SEVERITY] file:line — Short title + Problem: + Fix: +``` + +Severity levels: **BREAKING** | **CRITICAL** | **BUG** | **EDGE CASE** | **QUALITY** | **NITPICK** + +End with: + +**Summary table** +| Severity | Count | +|---|---| +| BREAKING | n | +| CRITICAL | n | +| BUG | n | +| EDGE CASE | n | +| QUALITY | n | +| NITPICK | n | +| Uncovered edge cases | n | + +**Verdict**: `APPROVE` | `APPROVE WITH FIXES` | `REQUEST CHANGES` +One sentence explaining the verdict. 
diff --git a/README.md b/README.md index b9d4ca86..6a980d3f 100644 --- a/README.md +++ b/README.md @@ -703,18 +703,65 @@ options = { Embed context values into a bearer token during generation so you can reference those values in your policies. This enables more flexible access controls, such as tracking end-user identity when making API calls using service accounts, and facilitates using signed data tokens during detokenization. -Generate bearer tokens containing context information using a service account with the context_id identifier. Context information is represented as a JWT claim in a Skyflow-generated bearer token. Tokens generated from such service accounts include a context_identifier claim, are valid for 60 minutes, and can be used to make API calls to the Data and Management APIs, depending on the service account's permissions. +Generate bearer tokens containing context information using a service account with the `context_id` identifier. Context information is represented as a JWT claim in a Skyflow-generated bearer token. Tokens generated from such service accounts include a `context_identifier` claim, are valid for 60 minutes, and can be used to make API calls to the Data and Management APIs, depending on the service account's permissions. + +The `ctx` parameter accepts either a **string** or a **dict**: + +**String context** — use when your policy references a single context value: + +```python +options = {'ctx': 'user_12345'} +token, _ = generate_bearer_token(filepath, options) +``` + +**Dict context** — use when your policy needs multiple context values for conditional data access. 
Each key in the dict maps to a Skyflow CEL policy variable under `request.context.*`: + +```python +options = { + 'ctx': { + 'role': 'admin', + 'department': 'finance', + 'user_id': 'user_12345', + } +} +token, _ = generate_bearer_token(filepath, options) +``` + +With the dict above, your Skyflow policies can reference `request.context.role`, `request.context.department`, and `request.context.user_id` to make conditional access decisions. + +Dict keys must contain only alphanumeric characters and underscores (`[a-zA-Z0-9_]`). Invalid keys will raise a `SkyflowError`. > [!TIP] -> See the full example in the samples directory: [token_generation_with_context_example.py](samples/service_account/token_generation_with_context_example.py) -> See [docs.skyflow.com](https://docs.skyflow.com) for more details on authentication, access control, and governance for Skyflow. +> See the full example in the samples directory: [token_generation_with_context_example.py](samples/service_account/token_generation_with_context_example.py) +> See Skyflow's [context-aware authorization](https://docs.skyflow.com) and [conditional data access](https://docs.skyflow.com) docs for policy variable syntax like `request.context.*`. #### Generate signed data tokens: `generate_signed_data_tokens(filepath, options)` Digitally sign data tokens with a service account's private key to add an extra layer of protection. Skyflow generates data tokens when sensitive data is inserted into the vault. Detokenize signed tokens only by providing the signed data token along with a bearer token generated from the service account's credentials. The service account must have the necessary permissions and context to successfully detokenize the signed data tokens. 
+The `ctx` parameter on signed data tokens also accepts either a **string** or a **dict**, using the same format as bearer tokens: + +```python +# String context +options = { + 'ctx': 'user_12345', + 'data_tokens': ['dataToken1', 'dataToken2'], + 'time_to_live': 90, +} + +# Dict context +options = { + 'ctx': { + 'role': 'analyst', + 'department': 'research', + }, + 'data_tokens': ['dataToken1', 'dataToken2'], + 'time_to_live': 90, +} +``` + > [!TIP] -> See the full example in the samples directory: [signed_token_generation_example.py](samples/service_account/signed_token_generation_example.py) +> See the full example in the samples directory: [signed_token_generation_example.py](samples/service_account/signed_token_generation_example.py) > See [docs.skyflow.com](https://docs.skyflow.com) for more details on authentication, access control, and governance for Skyflow. ## Logging diff --git a/samples/detect_api/deidentify_file_async.py b/samples/detect_api/deidentify_file_async.py new file mode 100644 index 00000000..579dab2e --- /dev/null +++ b/samples/detect_api/deidentify_file_async.py @@ -0,0 +1,124 @@ +from skyflow.error import SkyflowError +from skyflow import Env, Skyflow, LogLevel +from skyflow.utils.enums import DetectEntities, MaskingMethod, DetectOutputTranscriptions +from skyflow.vault.detect import DeidentifyFileRequest, TokenFormat, Transformations, DateTransformation, Bleep, FileInput +from concurrent.futures import ThreadPoolExecutor + +""" + * Skyflow Deidentify File Example + * + * This sample demonstrates how to use all available options for deidentifying files + * using an asynchronous approach. + * Supported file types: images (jpg, png, etc.), pdf, audio (mp3, wav), documents, + * spreadsheets, presentations, structured text. 
+""" + +def perform_file_deidentification_async(): + try: + # Step 1: Configure Credentials + credentials = { + 'path': '/path/to/credentials.json' # Path to credentials file + } + + # Step 2: Configure Vault + vault_config = { + 'vault_id': '', # Replace with your vault ID + 'cluster_id': '', # Replace with your cluster ID + 'env': Env.PROD, # Deployment environment + 'credentials': credentials + } + + # Step 3: Configure & Initialize Skyflow Client + skyflow_client = ( + Skyflow.builder() + .add_vault_config(vault_config) + .set_log_level(LogLevel.INFO) # Use LogLevel.ERROR in production + .build() + ) + + # Step 4: Create File Object + file_path = '' # Replace with your file path + + deidentify_request = DeidentifyFileRequest( + file=FileInput(file_path=file_path), # File to de-identify + # entities=[DetectEntities.SSN, DetectEntities.CREDIT_CARD], # Entities to detect + allow_regex_list=[''], # Optional: Patterns to allow + restrict_regex_list=[''], # Optional: Patterns to restrict + + # Token format configuration + token_format=TokenFormat( + vault_token=[DetectEntities.SSN], # Use vault tokens for these entities + ), + + # Optional: Custom transformations + # transformations=Transformations( + # shift_dates=DateTransformation( + # max_days=30, + # min_days=10, + # entities=[DetectEntities.DOB] + # ) + # ), + + # Output configuration + output_directory='', # Where to save processed file + wait_time=15, # Max wait time in seconds (max 64) + + # Image-specific options + output_processed_image=True, # Include processed image in output + output_ocr_text=True, # Include OCR text in response + masking_method=MaskingMethod.BLACKBOX, # Masking method for images + + # PDF-specific options + pixel_density=15, # Pixel density for PDF processing + max_resolution=2000, # Max resolution for PDF + + # Audio-specific options + output_processed_audio=True, # Include processed audio + output_transcription=DetectOutputTranscriptions.PLAINTEXT_TRANSCRIPTION, # Transcription type 
+ + # Audio bleep configuration + + # bleep=Bleep( + # gain=5, # Loudness in dB + # frequency=1000, # Pitch in Hz + # start_padding=0.1, # Padding at start (seconds) + # stop_padding=0.2 # Padding at end (seconds) + # ) + ) + + # Create a thread pool executor + executor = ThreadPoolExecutor(max_workers=1) + + future = executor.submit( + lambda: skyflow_client.detect().deidentify_file(deidentify_request) + ) + + def handle_response(future): + exception = future.exception() + if exception is not None: + if isinstance(exception, SkyflowError): + # Handle Skyflow-specific errors + print('\nSkyflow Error:', { + 'http_code': exception.http_code, + 'grpc_code': exception.grpc_code, + 'http_status': exception.http_status, + 'message': exception.message, + 'details': exception.details + }) + else: + # Handle unexpected errors + print('Unexpected Error:', exception) + return + + # Handle Successful Response + result = future.result() + print("\nDeidentify File Response:", result) + + future.add_done_callback(handle_response) + + executor.shutdown(wait=True) + + except Exception as error: + # Handle unexpected errors + print('Unexpected Error:', error) + diff --git a/samples/service_account/signed_token_generation_example.py b/samples/service_account/signed_token_generation_example.py index 32140ada..6ede1746 100644 --- a/samples/service_account/signed_token_generation_example.py +++ b/samples/service_account/signed_token_generation_example.py @@ -18,42 +18,54 @@ credentials_string = json.dumps(skyflow_credentials) -options = { - 'ctx': 'CONTEXT_ID', - 'data_tokens': ['DATA_TOKEN1', 'DATA_TOKEN2'], - 'time_to_live': 90, # in seconds -} +# Approach 1: Signed data tokens with string context +def get_signed_tokens_with_string_context(): + options = { + 'ctx': 'user_12345', + 'data_tokens': ['DATA_TOKEN1', 'DATA_TOKEN2'], + 'time_to_live': 90, # in seconds + } + try: + data_token, signed_data_token = generate_signed_data_tokens(file_path, options) + return data_token, 
signed_data_token + except Exception as e: + print(f'Error: {str(e)}') -def get_signed_bearer_token_from_file_path(): - # Generate signed bearer token from credentials file path. - global bearer_token +# Approach 2: Signed data tokens with JSON object context (dict) +# Each key maps to a Skyflow CEL policy variable under request.context.* +# For example: request.context.role == "analyst" and request.context.department == "research" +def get_signed_tokens_with_object_context(): + options = { + 'ctx': { + 'role': 'analyst', + 'department': 'research', + 'user_id': 'user_67890', + }, + 'data_tokens': ['DATA_TOKEN1', 'DATA_TOKEN2'], + 'time_to_live': 90, + } try: - if not is_expired(bearer_token): - return bearer_token - else: - data_token, signed_data_token = generate_signed_data_tokens(file_path, options) - return data_token, signed_data_token - + data_token, signed_data_token = generate_signed_data_tokens(file_path, options) + return data_token, signed_data_token except Exception as e: - print(f'Error generating token from file path: {str(e)}') + print(f'Error: {str(e)}') -def get_signed_bearer_token_from_credentials_string(): - # Generate signed bearer token from credentials string. 
- global bearer_token - +# Approach 3: Signed data tokens from credentials string +def get_signed_tokens_from_credentials_string(): + options = { + 'ctx': 'user_12345', + 'data_tokens': ['DATA_TOKEN1', 'DATA_TOKEN2'], + 'time_to_live': 90, + } try: - if not is_expired(bearer_token): - return bearer_token - else: - data_token, signed_data_token = generate_signed_data_tokens_from_creds(credentials_string, options) - return data_token, signed_data_token - + data_token, signed_data_token = generate_signed_data_tokens_from_creds(credentials_string, options) + return data_token, signed_data_token except Exception as e: - print(f'Error generating token from credentials string: {str(e)}') - + print(f'Error: {str(e)}') -print(get_signed_bearer_token_from_file_path()) -print(get_signed_bearer_token_from_credentials_string()) +print("String context:", get_signed_tokens_with_string_context()) +print("Object context:", get_signed_tokens_with_object_context()) +print("Creds string:", get_signed_tokens_from_credentials_string()) diff --git a/samples/service_account/token_generation_with_context_example.py b/samples/service_account/token_generation_with_context_example.py index a43a072a..03aa9f06 100644 --- a/samples/service_account/token_generation_with_context_example.py +++ b/samples/service_account/token_generation_with_context_example.py @@ -18,11 +18,13 @@ } credentials_string = json.dumps(skyflow_credentials) -options = {'ctx': ''} -def get_bearer_token_with_context_from_file_path(): - # Generate bearer token with context from credentials file path. +# Approach 1: Bearer token with string context +# Use a simple string identifier when your policy references a single context value. 
+# In your Skyflow policy, reference this as: request.context +def get_bearer_token_with_string_context(): global bearer_token + options = {'ctx': 'user_12345'} try: if not is_expired(bearer_token): @@ -31,14 +33,40 @@ def get_bearer_token_with_context_from_file_path(): token, _ = generate_bearer_token(file_path, options) bearer_token = token return bearer_token + except Exception as e: + print(f'Error generating token: {str(e)}') + + +# Approach 2: Bearer token with JSON object context (dict) +# Use a dict when your policy needs multiple context values for conditional data access. +# Each key maps to a Skyflow CEL policy variable under request.context.* +# For example: request.context.role == "admin" and request.context.department == "finance" +def get_bearer_token_with_object_context(): + global bearer_token + options = { + 'ctx': { + 'role': 'admin', + 'department': 'finance', + 'user_id': 'user_12345', + } + } + try: + if not is_expired(bearer_token): + return bearer_token + else: + token, _ = generate_bearer_token(file_path, options) + bearer_token = token + return bearer_token except Exception as e: - print(f'Error generating token from file path: {str(e)}') + print(f'Error generating token: {str(e)}') +# Approach 3: Bearer token with string context from credentials string def get_bearer_token_with_context_from_credentials_string(): - # Generate bearer token with context from credentials string. 
global bearer_token + options = {'ctx': 'user_12345'} + try: if not is_expired(bearer_token): return bearer_token @@ -47,9 +75,9 @@ def get_bearer_token_with_context_from_credentials_string(): bearer_token = token return bearer_token except Exception as e: - print(f"Error generating token from credentials string: {str(e)}") - + print(f"Error generating token: {str(e)}") -print(get_bearer_token_with_context_from_file_path()) -print(get_bearer_token_with_context_from_credentials_string()) \ No newline at end of file +print("String context:", get_bearer_token_with_string_context()) +print("Object context:", get_bearer_token_with_object_context()) +print("Creds string:", get_bearer_token_with_context_from_credentials_string()) diff --git a/setup.py b/setup.py index aa38463d..8f76225e 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ if sys.version_info < (3, 8): raise RuntimeError("skyflow requires Python 3.8+") -current_version = '2.0.0.dev0+f7d26df' +current_version = '2.0.2' setup( name='skyflow', diff --git a/skyflow/client/skyflow.py b/skyflow/client/skyflow.py index 0bfde34e..7b405f17 100644 --- a/skyflow/client/skyflow.py +++ b/skyflow/client/skyflow.py @@ -62,9 +62,6 @@ def set_log_level(self, log_level): def get_log_level(self): return self.__builder._Builder__log_level - def update_log_level(self, log_level): - self.__builder._Builder__set_log_level(log_level) - def vault(self, vault_id = None) -> Vault: vault_config = self.__builder.get_vault_config(vault_id) return vault_config.get(OptionField.VAULT_CONTROLLER) diff --git a/skyflow/error/_skyflow_error.py b/skyflow/error/_skyflow_error.py index fca43935..a63b5d68 100644 --- a/skyflow/error/_skyflow_error.py +++ b/skyflow/error/_skyflow_error.py @@ -1,5 +1,4 @@ from skyflow.utils import SkyflowMessages -from skyflow.utils.logger import log_error class SkyflowError(Exception): def __init__(self, diff --git a/skyflow/service_account/_utils.py b/skyflow/service_account/_utils.py index f4c98faf..f75b4411 
100644 --- a/skyflow/service_account/_utils.py +++ b/skyflow/service_account/_utils.py @@ -1,5 +1,6 @@ import json import datetime +import re import time import jwt from urllib.parse import urlparse @@ -10,11 +11,56 @@ from skyflow.utils.constants import JWT, CredentialField, JwtField, OptionField, ResponseField from skyflow.generated.rest.errors.unauthorized_error import UnauthorizedError from skyflow.utils import is_valid_url +from skyflow.utils.constants import CTX_KEY_REGEX invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value +_CTX_KEY_PATTERN = re.compile(CTX_KEY_REGEX) + +_SNAKE_TO_CAMEL_CRED_MAP = { + 'private_key': CredentialField.PRIVATE_KEY, + 'client_id': CredentialField.CLIENT_ID, + 'key_id': CredentialField.KEY_ID, + 'token_uri': CredentialField.TOKEN_URI, + 'client_name': CredentialField.CLIENT_NAME, +} + + +def _normalize_credentials(credentials): + return {_SNAKE_TO_CAMEL_CRED_MAP.get(k, k): v for k, v in credentials.items()} + + +def _validate_and_resolve_ctx(ctx): + """Validate ctx value and return resolved value for JWT claims. + Returns None if ctx should be omitted, the value if valid, or raises SkyflowError if invalid. 
+ """ + if ctx is None: + return None + if isinstance(ctx, str): + if ctx.strip() == '': + return None + return ctx + if isinstance(ctx, dict): + if len(ctx) == 0: + return None + for key in ctx: + if not isinstance(key, str) or not _CTX_KEY_PATTERN.match(key): + raise SkyflowError( + SkyflowMessages.Error.INVALID_CTX_MAP_KEY.value.format(key), + invalid_input_error_code + ) + return ctx + if isinstance(ctx, (bool, int, float)): + return ctx + raise SkyflowError( + SkyflowMessages.Error.INVALID_CTX_TYPE.value, + invalid_input_error_code + ) + def is_expired(token, logger = None): + if token is None: + return True if len(token) == 0: log_error_log(SkyflowMessages.ErrorLogs.INVALID_BEARER_TOKEN.value) return True @@ -34,20 +80,17 @@ def is_expired(token, logger = None): return True def generate_bearer_token(credentials_file_path, options = None, logger = None): + log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_TRIGGERED.value, logger) try: - log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_TRIGGERED.value, logger) - credentials_file =open(credentials_file_path, 'r') + credentials_file = open(credentials_file_path, 'r') except Exception: raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code) - - try: - credentials = json.load(credentials_file) - except Exception: - log_error_log(SkyflowMessages.ErrorLogs.INVALID_CREDENTIALS_FILE.value, logger = logger) - raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code) - - finally: - credentials_file.close() + with credentials_file: + try: + credentials = json.load(credentials_file) + except Exception: + log_error_log(SkyflowMessages.ErrorLogs.INVALID_CREDENTIALS_FILE.value, logger=logger) + raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code) result = get_service_account_token(credentials, options, logger) return result @@ -62,24 +105,25 @@ def 
generate_bearer_token_from_creds(credentials, options = None, logger = None) return result def get_service_account_token(credentials, options, logger): + credentials = _normalize_credentials(credentials) try: private_key = credentials[CredentialField.PRIVATE_KEY] - except: - log_error_log(SkyflowMessages.ErrorLogs.PRIVATE_KEY_IS_REQUIRED.value, logger = logger) + except KeyError: + log_error_log(SkyflowMessages.ErrorLogs.PRIVATE_KEY_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_PRIVATE_KEY.value, invalid_input_error_code) try: client_id = credentials[CredentialField.CLIENT_ID] - except: + except KeyError: log_error_log(SkyflowMessages.ErrorLogs.CLIENT_ID_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_CLIENT_ID.value, invalid_input_error_code) try: key_id = credentials[CredentialField.KEY_ID] - except: + except KeyError: log_error_log(SkyflowMessages.ErrorLogs.KEY_ID_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_KEY_ID.value, invalid_input_error_code) try: token_uri = credentials[CredentialField.TOKEN_URI] - except: + except KeyError: log_error_log(SkyflowMessages.ErrorLogs.TOKEN_URI_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_TOKEN_URI.value, invalid_input_error_code) @@ -87,9 +131,12 @@ def get_service_account_token(credentials, options, logger): log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) - if options and "token_uri" in options: - token_uri = options["token_uri"] - + if options and CredentialField.TOKEN_URI_OPTION in options: + token_uri = options[CredentialField.TOKEN_URI_OPTION] + if not isinstance(token_uri, str) or not is_valid_url(token_uri): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value, logger=logger) + raise 
SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) + signed_token = get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger) base_url = get_base_url(token_uri) auth_client = AuthClient(base_url) @@ -101,7 +148,7 @@ def get_service_account_token(credentials, options, logger): try: response = auth_api.authentication_service_get_auth_token(assertion = signed_token, - grant_type="urn:ietf:params:oauth:grant-type:jwt-bearer", + grant_type=JWT.GRANT_TYPE_JWT_BEARER, scope=formatted_scope) log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_SUCCESS.value, logger) except UnauthorizedError: @@ -120,8 +167,10 @@ def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger): JwtField.SUB: client_id, JwtField.EXP: datetime.datetime.utcnow() + datetime.timedelta(minutes=60) } - if options and JwtField.CTX in options: - payload[JwtField.CTX] = options.get(JwtField.CTX) + if options and OptionField.CTX in options: + resolved_ctx = _validate_and_resolve_ctx(options.get(OptionField.CTX)) + if resolved_ctx is not None: + payload[JwtField.CTX] = resolved_ctx try: return jwt.encode(payload=payload, key=private_key, algorithm=JWT.ALGORITHM_RS256) except Exception: @@ -130,18 +179,20 @@ def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger): def get_signed_tokens(credentials_obj, options): + credentials_obj = _normalize_credentials(credentials_obj) expiry_time = int(time.time()) + options.get(OptionField.TIME_TO_LIVE, 60) prefix = JWT.SIGNED_TOKEN_PREFIX - token_uri = credentials_obj.get("tokenURI") + token_uri = credentials_obj.get(CredentialField.TOKEN_URI) if not isinstance(token_uri, str) or not is_valid_url(token_uri): log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value) raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) - - if options and "token_uri" in options: - token_uri = options["token_uri"] + resolved_ctx = None + if OptionField.CTX in 
options: + resolved_ctx = _validate_and_resolve_ctx(options[OptionField.CTX]) + results = [] if options and options.get(OptionField.DATA_TOKENS): for token in options[OptionField.DATA_TOKENS]: claims = { @@ -152,37 +203,30 @@ def get_signed_tokens(credentials_obj, options): JwtField.TOK: token, JwtField.IAT: int(time.time()), } - - if JwtField.CTX in options: - claims[JwtField.CTX] = options[JwtField.CTX] - + if resolved_ctx is not None: + claims[JwtField.CTX] = resolved_ctx private_key = credentials_obj.get(CredentialField.PRIVATE_KEY) - try: + try: signed_jwt = jwt.encode(claims, private_key, algorithm=JWT.ALGORITHM_RS256) except Exception: raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) - - response_object = get_signed_data_token_response_object(prefix + signed_jwt, token) + results.append(get_signed_data_token_response_object(prefix + signed_jwt, token)) log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKEN_SUCCESS.value) - return response_object + return results def generate_signed_data_tokens(credentials_file_path, options): log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKENS_TRIGGERED.value) try: - credentials_file =open(credentials_file_path, 'r') + credentials_file = open(credentials_file_path, 'r') except Exception: raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code) - - try: - credentials = json.load(credentials_file) - except Exception: - raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), - invalid_input_error_code) - - finally: - credentials_file.close() - + with credentials_file: + try: + credentials = json.load(credentials_file) + except Exception: + raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), + invalid_input_error_code) return get_signed_tokens(credentials, options) def generate_signed_data_tokens_from_creds(credentials, options): diff --git 
a/skyflow/utils/_helpers.py b/skyflow/utils/_helpers.py index 090f3a2b..12ff1257 100644 --- a/skyflow/utils/_helpers.py +++ b/skyflow/utils/_helpers.py @@ -13,6 +13,6 @@ def format_scope(scopes): def is_valid_url(url): try: result = urlparse(url) - return all([result.scheme in ("http", "https"), result.netloc]) + return all([result.scheme == "https", result.netloc]) except Exception: return False \ No newline at end of file diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index 989aa298..79f4c7ec 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -61,6 +61,8 @@ class Error(Enum): EMPTY_CONTEXT = f"{error_prefix} Initialization failed. Invalid context provided. Specify context as type Context." INVALID_CONTEXT_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid context for {{}} with id {{}}. Specify a valid context." INVALID_CONTEXT = f"{error_prefix} Initialization failed. Invalid context. Specify a valid context." + INVALID_CTX_TYPE = f"{error_prefix} Initialization failed. Invalid ctx type. Specify ctx as a string or a dict." + INVALID_CTX_MAP_KEY = f"{error_prefix} Initialization failed. Invalid key '{{}}' in ctx dict. Keys must contain only alphanumeric characters and underscores." INVALID_LOG_LEVEL = f"{error_prefix} Initialization failed. Invalid log level. Specify a valid log level." EMPTY_LOG_LEVEL = f"{error_prefix} Initialization failed. Specify a valid log level." @@ -88,7 +90,7 @@ class Error(Enum): INVALID_TABLE_NAME_IN_INSERT = f"{error_prefix} Validation error. Invalid table name in insert request. Specify a valid table name." INVALID_TYPE_OF_DATA_IN_INSERT = f"{error_prefix} Validation error. Invalid type of data in insert request. Specify data as a object array." EMPTY_DATA_IN_INSERT = f"{error_prefix} Validation error. Data array cannot be empty. Specify data in insert request." - INVALID_UPSERT_OPTIONS_TYPE = f"{error_prefix} Validation error. 
'upsert' key cannot be empty in options. At least one object of table and column is required." + INVALID_UPSERT_OPTIONS_TYPE = f"{error_prefix} Validation error. Invalid 'upsert' value in options. Specify 'upsert' as a non-empty string containing the column name." INVALID_HOMOGENEOUS_TYPE = f"{error_prefix} Validation error. Invalid type of homogeneous. Specify homogeneous as a string." INVALID_TOKEN_MODE_TYPE = f"{error_prefix} Validation error. Invalid type of token mode. Specify token mode as a TokenMode enum." INVALID_RETURN_TOKENS_TYPE = f"{error_prefix} Validation error. Invalid type of return tokens. Specify return tokens as a boolean." diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index 5d83cbcc..31a2f24f 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -32,26 +32,18 @@ invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value def get_credentials(config_level_creds = None, common_skyflow_creds = None, logger = None): - dotenv.load_dotenv() + if config_level_creds is not None: + return config_level_creds + if common_skyflow_creds is not None: + return common_skyflow_creds dotenv_path = dotenv.find_dotenv(usecwd=True) if dotenv_path: load_dotenv(dotenv_path) env_skyflow_credentials = os.getenv("SKYFLOW_CREDENTIALS") - if config_level_creds: - return config_level_creds - if common_skyflow_creds: - return common_skyflow_creds if env_skyflow_credentials: - env_skyflow_credentials.strip() - try: - env_creds = env_skyflow_credentials.replace('\n', '\\n') - return { - CredentialField.CREDENTIALS_STRING: env_creds - } - except json.JSONDecodeError: - raise SkyflowError(SkyflowMessages.Error.INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV.value, invalid_input_error_code) - else: - raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) + env_creds = env_skyflow_credentials.strip().replace('\n', '\\n') + return {'credentials_string': env_creds} + raise 
SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) def validate_api_key(api_key: str, logger = None) -> bool: if len(api_key) != ApiKey.LENGTH: @@ -80,9 +72,9 @@ def parse_path_params(url, path_params): return result -def to_lowercase_keys(dict): +def to_lowercase_keys(data): result = {} - for key, value in dict.items(): + for key, value in data.items(): result[key.lower()] = value return result @@ -136,7 +128,7 @@ def construct_invoke_connection_request(request, connection_url, logger) -> Prep raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code) except SkyflowError: raise - except Exception as e: + except Exception: raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code) if files and header and content_type == ContentType.FORMDATA.value: @@ -194,7 +186,6 @@ def get_data_from_content_type(data, content_type): if content_type == ContentType.URLENCODED.value: converted_data = http_build_query(data) elif content_type == ContentType.FORMDATA.value: - print("Hello") converted_data = None files = {} for key, value in data.items(): @@ -239,8 +230,11 @@ def build_xml(d, tag='item'): return ''.join(xml_parts) +_CACHED_METRICS: dict = {} + def get_metrics(): - sdk_name_version = SdkPrefix.SKYFLOW_PYTHON + SDK_VERSION + if _CACHED_METRICS: + return _CACHED_METRICS try: sdk_client_device_model = platform.node() @@ -257,13 +251,13 @@ def get_metrics(): except Exception: sdk_runtime_details = "" - details_dic = { - SdkMetricsKey.SDK_NAME_VERSION: sdk_name_version, + _CACHED_METRICS.update({ + SdkMetricsKey.SDK_NAME_VERSION: SdkPrefix.SKYFLOW_PYTHON + SDK_VERSION, SdkMetricsKey.SDK_CLIENT_DEVICE_MODEL: sdk_client_device_model, SdkMetricsKey.SDK_CLIENT_OS_DETAILS: sdk_client_os_details, SdkMetricsKey.SDK_RUNTIME_DETAILS: SdkPrefix.PYTHON_RUNTIME + sdk_runtime_details, - } - return details_dic + }) + return _CACHED_METRICS def parse_insert_response(api_response, 
continue_on_error): # Retrieve the headers and data from the API response @@ -492,7 +486,7 @@ def handle_json_error(err, data, request_id, logger): description = data.dict() else: description = json.loads(data) - status_code = description.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_CODE, 500) # Default to 500 if not found + status_code = description.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_CODE, HttpStatusCode.INTERNAL_SERVER_ERROR) http_status = description.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_STATUS) grpc_code = description.get(ResponseField.ERROR, {}).get(ResponseField.GRPC_CODE) details = description.get(ResponseField.ERROR, {}).get(ResponseField.DETAILS, []) diff --git a/skyflow/utils/_version.py b/skyflow/utils/_version.py index bd8e63ec..bc50f210 100644 --- a/skyflow/utils/_version.py +++ b/skyflow/utils/_version.py @@ -1 +1 @@ -SDK_VERSION = '2.0.0.dev0+f7d26df' +SDK_VERSION = '2.0.2' diff --git a/skyflow/utils/constants.py b/skyflow/utils/constants.py index 401bffe5..ec6c7bab 100644 --- a/skyflow/utils/constants.py +++ b/skyflow/utils/constants.py @@ -1,6 +1,7 @@ OPTIONAL_TOKEN='token' PROTOCOL='https' SKY_META_DATA_HEADER='sky-metadata' +CTX_KEY_REGEX=r'^[a-zA-Z0-9_]+$' class SKYFLOW: SKYFLOW_ID = 'skyflowId' @@ -123,6 +124,8 @@ class CredentialField: CLIENT_ID = 'clientID' KEY_ID = 'keyID' TOKEN_URI = 'tokenURI' + TOKEN_URI_OPTION = 'token_uri' + CLIENT_NAME = 'clientName' CREDENTIALS_STRING = 'credentials_string' API_KEY = 'api_key' TOKEN = 'token' diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index 08d4905b..25968353 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -42,32 +42,32 @@ def validate_required_field(logger, config, field_name, expected_type, empty_err if field_name not in config or not isinstance(field_value, expected_type): if field_name == ConfigField.VAULT_ID: - 
logger.error(SkyflowMessages.ErrorLogs.VAULTID_IS_REQUIRED.value) + log_error_log(SkyflowMessages.ErrorLogs.VAULTID_IS_REQUIRED.value, logger) if field_name == ConfigField.CLUSTER_ID: - logger.error(SkyflowMessages.ErrorLogs.CLUSTER_ID_IS_REQUIRED.value) + log_error_log(SkyflowMessages.ErrorLogs.CLUSTER_ID_IS_REQUIRED.value, logger) if field_name == OptionField.CONNECTION_ID: - logger.error(SkyflowMessages.ErrorLogs.CONNECTION_ID_IS_REQUIRED.value) + log_error_log(SkyflowMessages.ErrorLogs.CONNECTION_ID_IS_REQUIRED.value, logger) if field_name == OptionField.CONNECTION_URL: - logger.error(SkyflowMessages.ErrorLogs.INVALID_CONNECTION_URL.value) + log_error_log(SkyflowMessages.ErrorLogs.INVALID_CONNECTION_URL.value, logger) raise SkyflowError(invalid_error, invalid_input_error_code) if isinstance(field_value, str) and not field_value.strip(): if field_name == ConfigField.VAULT_ID: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_VAULTID.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_VAULTID.value, logger) if field_name == ConfigField.CLUSTER_ID: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CLUSTER_ID.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CLUSTER_ID.value, logger) if field_name == OptionField.CONNECTION_ID: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_ID.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_ID.value, logger) if field_name == OptionField.CONNECTION_URL: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_URL.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_URL.value, logger) if field_name == CredentialField.PATH: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_PATH.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_PATH.value, logger) if field_name == CredentialField.CREDENTIALS_STRING: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_STRING.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_STRING.value, logger) if field_name 
== CredentialField.TOKEN: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_TOKEN_VALUE.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKEN_VALUE.value, logger) if field_name == CredentialField.API_KEY: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_API_KEY_VALUE.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_API_KEY_VALUE.value, logger) raise SkyflowError(empty_error, invalid_input_error_code) def validate_api_key(api_key: str, logger = None) -> bool: @@ -90,6 +90,7 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS.value ) + log_error_log(error_message, logger) raise SkyflowError(error_message, invalid_input_error_code) elif len(key_present) > 1: error_message = ( @@ -97,6 +98,7 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non if config_id_type and config_id else SkyflowMessages.Error.MULTIPLE_CREDENTIALS_PASSED.value ) + log_error_log(error_message, logger) raise SkyflowError(error_message, invalid_input_error_code) if CredentialField.ROLES in credentials: @@ -142,6 +144,7 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value ) if is_expired(credentials.get(CredentialField.TOKEN), logger): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_BEARER_TOKEN.value, logger) raise SkyflowError( SkyflowMessages.Error.EXPIRED_BEARER_TOKEN.value if config_id_type and config_id else SkyflowMessages.Error.EXPIRED_BEARER_TOKEN.value, @@ -160,8 +163,8 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non if config_id_type and config_id else SkyflowMessages.Error.INVALID_API_KEY.value, invalid_input_error_code) - if "token_uri" in credentials: - token_uri = credentials.get("token_uri") + if CredentialField.TOKEN_URI_OPTION in credentials: + token_uri = 
credentials.get(CredentialField.TOKEN_URI_OPTION) if ( token_uri is None or not isinstance(token_uri, str) @@ -171,10 +174,7 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non def validate_log_level(logger, log_level): if not isinstance(log_level, LogLevel): - raise SkyflowError( SkyflowMessages.Error.INVALID_LOG_LEVEL.value, invalid_input_error_code) - - if log_level is None: - raise SkyflowError(SkyflowMessages.Error.EMPTY_LOG_LEVEL.value, invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_LOG_LEVEL.value, invalid_input_error_code) def validate_keys(logger, config, config_keys): for key in config.keys(): @@ -208,7 +208,7 @@ def validate_vault_config(logger, config): # Validate env (optional, should be one of LogLevel values) if ConfigField.ENV in config and config.get(ConfigField.ENV) not in Env: - logger.error(SkyflowMessages.ErrorLogs.VAULTID_IS_REQUIRED.value) + log_error_log(SkyflowMessages.ErrorLogs.ENV_IS_REQUIRED.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value.format(vault_id), invalid_input_error_code) return True @@ -255,8 +255,8 @@ def validate_connection_config(logger, config): SkyflowMessages.Error.INVALID_CONNECTION_URL.value.format(connection_id) ) - if "credentials" in config: - validate_credentials(logger, config.get("credentials"), "connection", connection_id) + if ConfigField.CREDENTIALS in config: + validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.CONNECTION, connection_id) return True @@ -429,7 +429,7 @@ def validate_insert_request(logger, request): log_error_log(SkyflowMessages.ErrorLogs.VALUES_IS_REQUIRED.value.format(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code) - if not len(request.values): + if not request.values: log_error_log(SkyflowMessages.ErrorLogs.EMPTY_VALUES.value.format(RequestOperation.INSERT), logger=logger) raise 
SkyflowError(SkyflowMessages.Error.EMPTY_DATA_IN_INSERT.value, invalid_input_error_code) @@ -439,7 +439,7 @@ def validate_insert_request(logger, request): log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_VALUES.value.format(RequestOperation.INSERT), logger = logger) if request.upsert is not None and (not isinstance(request.upsert, str) or not request.upsert.strip()): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_UPSERT.value(RequestOperation.INSERT), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_UPSERT.value.format(RequestOperation.INSERT), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_UPSERT_OPTIONS_TYPE.value, invalid_input_error_code) if request.homogeneous is not None and not isinstance(request.homogeneous, bool): @@ -471,7 +471,7 @@ def validate_insert_request(logger, request): logger=logger) if not isinstance(request.tokens, list) or not request.tokens or not all( isinstance(t, dict) for t in request.tokens): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value(RequestOperation.INSERT), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format(RequestOperation.INSERT), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code) if request.token_mode == TokenMode.ENABLE and not request.tokens: @@ -503,21 +503,21 @@ def validate_delete_request(logger, request): raise SkyflowError(SkyflowMessages.Error.EMPTY_RECORD_IDS_IN_DELETE.value, invalid_input_error_code) def validate_query_request(logger, request): - if not request.query: - log_error_log(SkyflowMessages.ErrorLogs.QUERY_IS_REQUIRED.value.format(RequestOperation.QUERY), logger = logger) - raise SkyflowError(SkyflowMessages.Error.EMPTY_QUERY.value, invalid_input_error_code) - if not isinstance(request.query, str): query_type = str(type(request.query)) raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_TYPE.value.format(query_type), invalid_input_error_code) + 
if not request.query: + log_error_log(SkyflowMessages.ErrorLogs.QUERY_IS_REQUIRED.value.format(RequestOperation.QUERY), logger=logger) + raise SkyflowError(SkyflowMessages.Error.EMPTY_QUERY.value, invalid_input_error_code) + if not request.query.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_QUERY.value.format(RequestOperation.QUERY), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_QUERY.value.format(RequestOperation.QUERY), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_QUERY.value, invalid_input_error_code) if not request.query.upper().startswith(SqlCommand.SELECT): command = request.query - raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_COMMAND.value.format(command), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_COMMAND.value.format(command), invalid_input_error_code) def validate_get_request(logger, request): redaction_type = request.redaction_type @@ -565,13 +565,13 @@ def validate_get_request(logger, request): invalid_input_error_code) if offset is not None and not isinstance(offset, str): - raise SkyflowError(SkyflowMessages.Error.INVALID_OFF_SET_VALUE.value(type(offset)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_OFF_SET_VALUE.value.format(type(offset)), invalid_input_error_code) if limit is not None and not isinstance(limit, str): - raise SkyflowError(SkyflowMessages.Error.INVALID_LIMIT_VALUE.value(type(limit)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_LIMIT_VALUE.value.format(type(limit)), invalid_input_error_code) if download_url is not None and not isinstance(download_url, bool): - raise SkyflowError(SkyflowMessages.Error.INVALID_DOWNLOAD_URL_VALUE.value(type(download_url)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_DOWNLOAD_URL_VALUE.value.format(type(download_url)), invalid_input_error_code) if column_name is not None and (not isinstance(column_name, str) or not 
column_name.strip()): raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code) @@ -603,33 +603,30 @@ def validate_get_request(logger, request): raise SkyflowError(SkyflowMessages.Error.BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED.value, invalid_input_error_code) def validate_update_request(logger, request): - skyflow_id = "" + if not isinstance(request.data, dict): + raise SkyflowError(SkyflowMessages.Error.INVALID_FIELDS_TYPE.value.format(type(request.data)), invalid_input_error_code) + + if not len(request.data.items()): + raise SkyflowError(SkyflowMessages.Error.UPDATE_FIELD_KEY_ERROR.value, invalid_input_error_code) + field = {key: value for key, value in request.data.items() if key != ResponseField.SKYFLOW_ID} - try: - skyflow_id = request.data.get(ResponseField.SKYFLOW_ID) - except Exception: + skyflow_id = request.data.get(ResponseField.SKYFLOW_ID) + if skyflow_id is None: log_error_log(SkyflowMessages.ErrorLogs.SKYFLOW_ID_IS_REQUIRED.value.format(RequestOperation.UPDATE), logger=logger) - - if not skyflow_id.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_SKYFLOW_ID.value.format(RequestOperation.UPDATE), logger = logger) + elif not skyflow_id.strip(): + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_SKYFLOW_ID.value.format(RequestOperation.UPDATE), logger=logger) if not isinstance(request.table, str): log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) if not request.table.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.UPDATE), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code) if not isinstance(request.return_tokens, bool): 
raise SkyflowError(SkyflowMessages.Error.INVALID_RETURN_TOKENS_TYPE.value, invalid_input_error_code) - if not isinstance(request.data, dict): - raise SkyflowError(SkyflowMessages.Error.INVALID_FIELDS_TYPE.value(type(request.data)), invalid_input_error_code) - - if not len(request.data.items()): - raise SkyflowError(SkyflowMessages.Error.UPDATE_FIELD_KEY_ERROR.value, invalid_input_error_code) - if request.token_mode is not None: if not isinstance(request.token_mode, TokenMode): raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_MODE_TYPE.value, invalid_input_error_code) @@ -667,9 +664,9 @@ def validate_detokenize_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_CONTINUE_ON_ERROR_TYPE.value, invalid_input_error_code) if not isinstance(request.data, list): - raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENS_LIST_VALUE.value(type(request.data)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENS_LIST_VALUE.value.format(type(request.data)), invalid_input_error_code) - if not len(request.data): + if not request.data: log_error_log(SkyflowMessages.ErrorLogs.TOKENS_REQUIRED.value.format(RequestOperation.DETOKENIZE), logger = logger) log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format(RequestOperation.DETOKENIZE), logger = logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENS_LIST_VALUE.value, invalid_input_error_code) @@ -695,7 +692,7 @@ def validate_tokenize_request(logger, request): if not isinstance(parameters, list): raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETERS.value.format(type(parameters)), invalid_input_error_code) - if not len(parameters): + if not parameters: raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENIZE_PARAMETERS.value, invalid_input_error_code) for i, param in enumerate(parameters): @@ -728,9 +725,7 @@ def validate_file_upload_request(logger, request): # Skyflow ID skyflow_id = getattr(request, FileUploadField.SKYFLOW_ID, None) - if 
skyflow_id is None: - raise SkyflowError(SkyflowMessages.Error.IDS_KEY_ERROR.value, invalid_input_error_code) - elif skyflow_id.strip() == "": + if skyflow_id is not None and skyflow_id.strip() == "": raise SkyflowError(SkyflowMessages.Error.EMPTY_SKYFLOW_ID.value.format(RequestOperation.FILE_UPLOAD), invalid_input_error_code) # Column Name @@ -797,7 +792,7 @@ def validate_invoke_connection_params(logger, query_params, path_params): except TypeError: raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_PARAMS.value, invalid_input_error_code) -def validate_deidentify_text_request(self, request: DeidentifyTextRequest): +def validate_deidentify_text_request(logger, request: DeidentifyTextRequest): if not request.text or not isinstance(request.text, str) or not request.text.strip(): raise SkyflowError(SkyflowMessages.Error.INVALID_TEXT_IN_DEIDENTIFY.value, invalid_input_error_code) @@ -821,7 +816,7 @@ def validate_deidentify_text_request(self, request: DeidentifyTextRequest): if request.transformations is not None and not isinstance(request.transformations, Transformations): raise SkyflowError(SkyflowMessages.Error.INVALID_TRANSFORMATIONS.value, invalid_input_error_code) -def validate_reidentify_text_request(self, request: ReidentifyTextRequest): +def validate_reidentify_text_request(logger, request: ReidentifyTextRequest): if not request.text or not isinstance(request.text, str) or not request.text.strip(): raise SkyflowError(SkyflowMessages.Error.INVALID_TEXT_IN_REIDENTIFY.value, invalid_input_error_code) @@ -837,6 +832,6 @@ def validate_reidentify_text_request(self, request: ReidentifyTextRequest): if request.plain_text_entities is not None and not isinstance(request.plain_text_entities, list): raise SkyflowError(SkyflowMessages.Error.INVALID_PLAIN_TEXT_ENTITIES_IN_REIDENTIFY.value, invalid_input_error_code) -def validate_get_detect_run_request(self, request: GetDetectRunRequest): +def validate_get_detect_run_request(logger, request: GetDetectRunRequest): if not 
request.run_id or not isinstance(request.run_id, str) or not request.run_id.strip(): raise SkyflowError(SkyflowMessages.Error.INVALID_RUN_ID.value, invalid_input_error_code) diff --git a/skyflow/vault/client/client.py b/skyflow/vault/client/client.py index c64e8c6a..8023646c 100644 --- a/skyflow/vault/client/client.py +++ b/skyflow/vault/client/client.py @@ -16,6 +16,9 @@ def __init__(self, config): self.__logger = None self.__is_config_updated = False self.__bearer_token = None + self.__credentials = None + self.__vault_url = None + self.__is_static_token = None def set_common_skyflow_credentials(self, credentials): self.__common_skyflow_credentials = credentials @@ -25,16 +28,27 @@ def set_logger(self, log_level, logger): self.__logger = logger def initialize_client_configuration(self): - credentials = get_credentials(self.__config.get(ConfigField.CREDENTIALS), self.__common_skyflow_credentials, logger = self.__logger) - token = self.get_bearer_token(credentials) - vault_url = get_vault_url(self.__config.get(ConfigField.CLUSTER_ID), - self.__config.get(ConfigField.ENV), - self.__config.get(ConfigField.VAULT_ID), - logger = self.__logger) - self.initialize_api_client(vault_url, token) - - def initialize_api_client(self, vault_url, token): - self.__api_client = Skyflow(base_url=vault_url, token=token) + if self.__api_client is not None and not self.__is_config_updated: + if self.__is_static_token: + return + if self.__bearer_token is not None and not is_expired(self.__bearer_token): + return + + needs_reinit = self.__api_client is None or self.__is_config_updated + if needs_reinit: + self.__credentials = get_credentials(self.__config.get(ConfigField.CREDENTIALS), self.__common_skyflow_credentials, logger=self.__logger) + self.__vault_url = get_vault_url(self.__config.get(ConfigField.CLUSTER_ID), + self.__config.get(ConfigField.ENV), + self.__config.get(ConfigField.VAULT_ID), + logger=self.__logger) + self.__is_static_token = CredentialField.TOKEN in 
self.__credentials or CredentialField.API_KEY in self.__credentials + bearer_token = self.get_bearer_token(self.__credentials) + if needs_reinit: + self.initialize_api_client(self.__vault_url, bearer_token) + + def initialize_api_client(self, vault_url, bearer_token): + token_provider = lambda: self.__bearer_token if self.__bearer_token is not None else bearer_token # noqa: E731 + self.__api_client = Skyflow(base_url=vault_url, token=token_provider) def get_records_api(self): return self.__api_client.records @@ -64,14 +78,13 @@ def get_bearer_token(self, credentials): OptionField.ROLE_IDS: self.__config.get(OptionField.ROLES), OptionField.CTX: self.__config.get(OptionField.CTX) } - if "token_uri" in credentials and credentials.get("token_uri"): - options["token_uri"] = credentials.get("token_uri") + if CredentialField.TOKEN_URI_OPTION in credentials and credentials.get(CredentialField.TOKEN_URI_OPTION): + options[CredentialField.TOKEN_URI_OPTION] = credentials.get(CredentialField.TOKEN_URI_OPTION) - if self.__bearer_token is None or self.__is_config_updated: + if self.__bearer_token is None or self.__is_config_updated or is_expired(self.__bearer_token): if CredentialField.PATH in credentials: - path = credentials.get(CredentialField.PATH) self.__bearer_token, _ = generate_bearer_token( - path, + credentials.get(CredentialField.PATH), options, self.__logger ) @@ -87,10 +100,6 @@ def get_bearer_token(self, credentials): else: log_info(SkyflowMessages.Info.REUSE_BEARER_TOKEN.value, self.__logger) - if is_expired(self.__bearer_token): - self.__is_config_updated = True - raise SkyflowError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) - return self.__bearer_token def update_config(self, config): diff --git a/skyflow/vault/controller/_connections.py b/skyflow/vault/controller/_connections.py index 76dbfaeb..2ce0c104 100644 --- a/skyflow/vault/controller/_connections.py +++ b/skyflow/vault/controller/_connections.py @@ -5,7 
+5,7 @@ parse_invoke_connection_response from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.connection import InvokeConnectionRequest -from skyflow.utils.constants import SKY_META_DATA_HEADER, SKYFLOW, HttpHeader +from skyflow.utils.constants import SKY_META_DATA_HEADER, SKYFLOW, HttpHeader, OptionField, ConfigField from skyflow.utils import get_credentials @@ -16,11 +16,11 @@ def __init__(self, vault_client): def invoke(self, request: InvokeConnectionRequest): log_info(SkyflowMessages.Info.VALIDATING_INVOKE_CONNECTION_REQUEST.value, self.__vault_client.get_logger()) config = self.__vault_client.get_config() - connection_url = config.get("connection_url") + connection_url = config.get(OptionField.CONNECTION_URL) invoke_connection_request = construct_invoke_connection_request(request, connection_url, self.__vault_client.get_logger()) log_info(SkyflowMessages.Info.INVOKE_CONNECTION_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) - credentials = get_credentials(config.get("credentials"), self.__vault_client.get_common_skyflow_credentials(), self.__vault_client.get_logger()) + credentials = get_credentials(config.get(ConfigField.CREDENTIALS), self.__vault_client.get_common_skyflow_credentials(), self.__vault_client.get_logger()) bearer_token = self.__vault_client.get_bearer_token(credentials) diff --git a/skyflow/vault/controller/_detect.py b/skyflow/vault/controller/_detect.py index c6ef2fb1..f1abc20f 100644 --- a/skyflow/vault/controller/_detect.py +++ b/skyflow/vault/controller/_detect.py @@ -8,8 +8,8 @@ FileDataDeidentifyImage, Format, FileDataDeidentifyAudio, WordCharacterCount, DetectRunsResponse from skyflow.utils._skyflow_messages import SkyflowMessages from skyflow.utils._utils import get_attribute, get_metrics, handle_exception, parse_deidentify_text_response, parse_reidentify_text_response -from skyflow.utils.constants import (SKY_META_DATA_HEADER, DetectStatus, FileExtension, - FileProcessing, EncodingType, 
DeidentifyField, DeidentifyFileRequestField, FileUploadField, OptionField) +from skyflow.utils.constants import (SKY_META_DATA_HEADER, DetectStatus, FileExtension, + FileProcessing, EncodingType, DeidentifyField, DeidentifyFileRequestField, FileUploadField, OptionField, Detect as DetectConstants) from skyflow.utils.logger import log_info, log_error_log from skyflow.utils.validations import validate_deidentify_file_request, validate_get_detect_run_request from skyflow.utils.validations._validations import validate_deidentify_text_request, validate_reidentify_text_request @@ -30,7 +30,7 @@ def __get_headers(self): } return headers - def ___build_deidentify_text_body(self, request: DeidentifyTextRequest) -> Dict[str, Any]: + def __build_deidentify_text_body(self, request: DeidentifyTextRequest) -> Dict[str, Any]: deidentify_text_body = {} parsed_entity_types = request.entities @@ -43,7 +43,7 @@ def ___build_deidentify_text_body(self, request: DeidentifyTextRequest) -> Dict[ return deidentify_text_body - def ___build_reidentify_text_body(self, request: ReidentifyTextRequest) -> Dict[str, Any]: + def __build_reidentify_text_body(self, request: ReidentifyTextRequest) -> Dict[str, Any]: parsed_format = Format( redacted=request.redacted_entities, masked=request.masked_entities, @@ -57,13 +57,13 @@ def ___build_reidentify_text_body(self, request: ReidentifyTextRequest) -> Dict[ def _get_file_extension(self, filename: str): return filename.split('.')[-1].lower() if '.' 
in filename else '' - def __poll_for_processed_file(self, run_id, max_wait_time=64): - max_wait_time = 64 if max_wait_time is None else max_wait_time + def __poll_for_processed_file(self, run_id, max_wait_time=None): + max_wait_time = DetectConstants.WAIT_TIME if max_wait_time is None else max_wait_time files_api = self.__vault_client.get_detect_file_api().with_raw_response current_wait_time = 1 # Start with 1 second try: while True: - response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers()).data + response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options={'additional_headers': self.__get_headers()}).data status = response.status if status == DetectStatus.IN_PROGRESS: if current_wait_time >= max_wait_time: @@ -80,7 +80,7 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64): elif status == DetectStatus.SUCCESS or status == DetectStatus.FAILED: return response except Exception as e: - raise e + handle_exception(e, self.__vault_client.get_logger()) def __save_deidentify_file_response_output(self, response: DetectRunsResponse, output_directory: str, original_file_name: str, name_without_ext: str): if not response or not hasattr(response, DeidentifyField.OUTPUT) or not response.output or not output_directory: @@ -94,6 +94,7 @@ def __save_deidentify_file_response_output(self, response: DetectRunsResponse, o base_original_filename = os.path.basename(original_file_name) base_name_without_ext = os.path.splitext(base_original_filename)[0] + real_output_dir = os.path.realpath(output_directory) for idx, output in enumerate(output_list): try: @@ -105,14 +106,25 @@ def __save_deidentify_file_response_output(self, response: DetectRunsResponse, o continue decoded_data = base64.b64decode(processed_file) - + + # Sanitize extension from API response to prevent path traversal (CWE-22). + # Avoid os.path.basename here to keep basename mock-free in tests. 
+ safe_ext = None + if processed_file_extension: + raw_ext = str(processed_file_extension).replace('\\', '/').split('/')[-1].lstrip('.') + safe_ext = ''.join(c for c in raw_ext if c.isalnum() or c in ('-', '_')) or 'bin' + if idx == 0 or processed_file_type == DeidentifyField.REDACTED_FILE: output_file_name = os.path.join(output_directory, deidentify_file_prefix + base_original_filename) - if processed_file_extension: - output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{processed_file_extension}") + if safe_ext: + output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{safe_ext}") else: - output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{processed_file_extension}") - + output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{safe_ext or 'bin'}") + + if not os.path.realpath(output_file_name).startswith(real_output_dir + os.sep): + log_error_log(SkyflowMessages.ErrorLogs.SAVING_DEIDENTIFY_FILE_FAILED.value, self.__vault_client.get_logger()) + continue + with open(output_file_name, 'wb') as f: f.write(decoded_data) except Exception as e: @@ -166,9 +178,9 @@ def output_to_dict_list(output): extension = first_output.get(DeidentifyField.EXTENSION, None) if base64_string is not None: - file_bytes = base64.b64decode(base64_string) - file_obj = io.BytesIO(file_bytes) - file_obj.name = f"{FileProcessing.DEIDENTIFIED_PREFIX}{extension}" if extension else DeidentifyField.PROCESSED_FILE + file_bytes = base64.b64decode(base64_string) + file_obj = io.BytesIO(file_bytes) + file_obj.name = f"{FileProcessing.DEIDENTIFIED_PREFIX}{extension}" if extension else DeidentifyField.PROCESSED_FILE else: file_obj = None @@ -217,7 +229,7 @@ def deidentify_text(self, request: DeidentifyTextRequest) -> DeidentifyTextRespo log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_REQUEST_RESOLVED.value, 
self.__vault_client.get_logger()) self.__initialize() detect_api = self.__vault_client.get_detect_text_api() - deidentify_text_body = self.___build_deidentify_text_body(request) + deidentify_text_body = self.__build_deidentify_text_body(request) try: log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_TRIGGERED.value, self.__vault_client.get_logger()) @@ -229,7 +241,7 @@ def deidentify_text(self, request: DeidentifyTextRequest) -> DeidentifyTextRespo restrict_regex=deidentify_text_body[DeidentifyField.RESTRICT_REGEX], token_type=deidentify_text_body[DeidentifyField.TOKEN_TYPE], transformations=deidentify_text_body[DeidentifyField.TRANSFORMATIONS], - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) deidentify_text_response = parse_deidentify_text_response(api_response) log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger()) @@ -245,7 +257,7 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) self.__initialize() detect_api = self.__vault_client.get_detect_text_api() - reidentify_text_body = self.___build_reidentify_text_body(request) + reidentify_text_body = self.__build_reidentify_text_body(request) try: log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_TRIGGERED.value, self.__vault_client.get_logger()) @@ -253,7 +265,7 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo vault_id=self.__vault_client.get_vault_id(), text=reidentify_text_body[DeidentifyField.TEXT], format=reidentify_text_body[DeidentifyField.FORMAT], - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) reidentify_text_response = parse_reidentify_text_response(api_response) log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger()) @@ -265,14 +277,16 @@ def 
reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo def __get_file_from_request(self, request: DeidentifyFileRequest): file_input = request.file - - # Check for file + if hasattr(file_input, FileUploadField.FILE) and file_input.file is not None: return file_input.file - - # Check for file_path if file is not provided + if hasattr(file_input, FileUploadField.FILE_PATH) and file_input.file_path is not None: - return open(file_input.file_path, 'rb') + with open(file_input.file_path, 'rb') as f: + content = f.read() + bio = io.BytesIO(content) + bio.name = file_input.file_path + return bio def deidentify_file(self, request: DeidentifyFileRequest): log_info(SkyflowMessages.Info.DETECT_FILE_TRIGGERED.value, self.__vault_client.get_logger()) @@ -297,12 +311,13 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.ALLOW_REGEX: request.allow_regex_list, DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } elif file_extension in [FileExtension.MP3, FileExtension.WAV]: req_file = FileDataDeidentifyAudio(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_audio + bleep = request.bleep api_kwargs = { OptionField.VAULT_ID: self.__vault_client.get_vault_id(), DeidentifyField.FILE: req_file, @@ -313,11 +328,11 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), DeidentifyFileRequestField.OUTPUT_TRANSCRIPTION: getattr(request, DeidentifyFileRequestField.OUTPUT_TRANSCRIPTION, None), DeidentifyFileRequestField.OUTPUT_PROCESSED_AUDIO: getattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_AUDIO, None), - DeidentifyField.BLEEP_GAIN: getattr(request, DeidentifyFileRequestField.BLEEP, None).gain if 
getattr(request, DeidentifyFileRequestField.BLEEP, None) is not None else None, - DeidentifyField.BLEEP_FREQUENCY: getattr(request, DeidentifyFileRequestField.BLEEP, None).frequency if getattr(request, DeidentifyFileRequestField.BLEEP, None) is not None else None, - DeidentifyField.BLEEP_START_PADDING: getattr(request, DeidentifyFileRequestField.BLEEP, None).start_padding if getattr(request, DeidentifyFileRequestField.BLEEP, None) is not None else None, - DeidentifyField.BLEEP_STOP_PADDING: getattr(request, DeidentifyFileRequestField.BLEEP, None).stop_padding if getattr(request, DeidentifyFileRequestField.BLEEP, None) is not None else None, - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.BLEEP_GAIN: bleep.gain if bleep is not None else None, + DeidentifyField.BLEEP_FREQUENCY: bleep.frequency if bleep is not None else None, + DeidentifyField.BLEEP_START_PADDING: bleep.start_padding if bleep is not None else None, + DeidentifyField.BLEEP_STOP_PADDING: bleep.stop_padding if bleep is not None else None, + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } elif file_extension == FileExtension.PDF: @@ -332,7 +347,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, DeidentifyFileRequestField.MAX_RESOLUTION: getattr(request, DeidentifyFileRequestField.MAX_RESOLUTION, None), DeidentifyFileRequestField.PIXEL_DENSITY: getattr(request, DeidentifyFileRequestField.PIXEL_DENSITY, None), - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } elif file_extension in [FileExtension.JPEG, FileExtension.JPG, FileExtension.PNG, FileExtension.BMP, FileExtension.TIF, FileExtension.TIFF]: @@ -348,7 +363,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyFileRequestField.MASKING_METHOD: getattr(request, DeidentifyFileRequestField.MASKING_METHOD, None), 
DeidentifyFileRequestField.OUTPUT_OCR_TEXT: getattr(request, DeidentifyFileRequestField.OUTPUT_OCR_TEXT, None), DeidentifyFileRequestField.OUTPUT_PROCESSED_IMAGE: getattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_IMAGE, None), - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } elif file_extension in [FileExtension.PPT, FileExtension.PPTX]: @@ -361,7 +376,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), DeidentifyField.ALLOW_REGEX: request.allow_regex_list, DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } elif file_extension in [FileExtension.CSV, FileExtension.XLS, FileExtension.XLSX]: @@ -374,7 +389,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), DeidentifyField.ALLOW_REGEX: request.allow_regex_list, DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } elif file_extension in [FileExtension.DOC, FileExtension.DOCX]: @@ -387,7 +402,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), DeidentifyField.ALLOW_REGEX: request.allow_regex_list, DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } elif file_extension in [FileExtension.JSON, FileExtension.XML]: @@ -401,7 +416,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.ALLOW_REGEX: request.allow_regex_list, 
DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } else: @@ -415,7 +430,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.ALLOW_REGEX: request.allow_regex_list, DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } log_info(SkyflowMessages.Info.DETECT_FILE_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) @@ -424,7 +439,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): run_id = getattr(api_response.data, DeidentifyField.RUN_ID, None) processed_response = self.__poll_for_processed_file(run_id, request.wait_time) - if request.output_directory and processed_response.status == DetectStatus.SUCCESS: + if request.output_directory and processed_response.status == DetectStatus.SUCCESS and file_name: name_without_ext, _ = os.path.splitext(file_name) self.__save_deidentify_file_response_output(processed_response, request.output_directory, file_name, name_without_ext) @@ -449,10 +464,10 @@ def get_detect_run(self, request: GetDetectRunRequest): response = files_api.get_run( run_id, vault_id=self.__vault_client.get_vault_id(), - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) if response.data.status == DetectStatus.IN_PROGRESS: - parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS)) + parsed_response = DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS) else: parsed_response = self.__parse_deidentify_file_response(response.data, run_id, 
response.data.status) log_info(SkyflowMessages.Info.GET_DETECT_RUN_SUCCESS.value,self.__vault_client.get_logger()) diff --git a/skyflow/vault/controller/_vault.py b/skyflow/vault/controller/_vault.py index 856a1961..7d51ee83 100644 --- a/skyflow/vault/controller/_vault.py +++ b/skyflow/vault/controller/_vault.py @@ -89,10 +89,7 @@ def __get_file_for_file_upload(self, request: FileUploadRequest) -> Optional[Fil return None def __get_headers(self): - headers = { - SKY_META_DATA_HEADER: json.dumps(get_metrics()) - } - return headers + return {SKY_META_DATA_HEADER: json.dumps(get_metrics())} def insert(self, request: InsertRequest): log_info(SkyflowMessages.Info.VALIDATE_INSERT_REQUEST.value, self.__vault_client.get_logger()) @@ -106,11 +103,11 @@ def insert(self, request: InsertRequest): log_info(SkyflowMessages.Info.INSERT_TRIGGERED.value, self.__vault_client.get_logger()) if request.continue_on_error: api_response = records_api.record_service_batch_operation(self.__vault_client.get_vault_id(), - records=insert_body, continue_on_error=request.continue_on_error, byot=request.token_mode.value, request_options=self.__get_headers()) + records=insert_body, continue_on_error=request.continue_on_error, byot=request.token_mode.value, request_options={'additional_headers': self.__get_headers()}) else: api_response = records_api.record_service_insert_record(self.__vault_client.get_vault_id(), - request.table, records=insert_body,tokenization= request.return_tokens, upsert=request.upsert, homogeneous=request.homogeneous, byot=request.token_mode.value, request_options=self.__get_headers()) + request.table, records=insert_body,tokenization= request.return_tokens, upsert=request.upsert, homogeneous=request.homogeneous, byot=request.token_mode.value, request_options={'additional_headers': self.__get_headers()}) insert_response = parse_insert_response(api_response, request.continue_on_error) log_info(SkyflowMessages.Info.INSERT_SUCCESS.value, self.__vault_client.get_logger()) @@ 
-138,7 +135,7 @@ def update(self, request: UpdateRequest): record=record, tokenization=request.return_tokens, byot=request.token_mode.value, - request_options = self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.UPDATE_SUCCESS.value, self.__vault_client.get_logger()) update_response = parse_update_record_response(api_response) @@ -159,7 +156,7 @@ def delete(self, request: DeleteRequest): self.__vault_client.get_vault_id(), request.table, skyflow_ids=request.ids, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.DELETE_SUCCESS.value, self.__vault_client.get_logger()) delete_response = parse_delete_response(api_response) @@ -189,7 +186,7 @@ def get(self, request: GetRequest): download_url=request.download_url, column_name=request.column_name, column_values=request.column_values, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.GET_SUCCESS.value, self.__vault_client.get_logger()) get_response = parse_get_response(api_response) @@ -209,7 +206,7 @@ def query(self, request: QueryRequest): api_response = query_api.query_service_execute_query( self.__vault_client.get_vault_id(), query=request.query, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.QUERY_SUCCESS.value, self.__vault_client.get_logger()) query_response = parse_query_response(api_response) @@ -237,7 +234,7 @@ def detokenize(self, request: DetokenizeRequest): self.__vault_client.get_vault_id(), detokenization_parameters=tokens_list, continue_on_error = request.continue_on_error, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.DETOKENIZE_SUCCESS.value, self.__vault_client.get_logger()) detokenize_response = 
parse_detokenize_response(api_response) @@ -262,7 +259,7 @@ def tokenize(self, request: TokenizeRequest): api_response = tokens_api.record_service_tokenize( self.__vault_client.get_vault_id(), tokenization_parameters=records_list, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) tokenize_response = parse_tokenize_response(api_response) log_info(SkyflowMessages.Info.TOKENIZE_SUCCESS.value, self.__vault_client.get_logger()) @@ -285,7 +282,7 @@ def upload_file(self, request: FileUploadRequest): file=self.__get_file_for_file_upload(request), skyflow_id=request.skyflow_id, return_file_metadata= False, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.FILE_UPLOAD_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) log_info(SkyflowMessages.Info.FILE_UPLOAD_SUCCESS.value, self.__vault_client.get_logger()) diff --git a/skyflow/vault/data/_file_upload_request.py b/skyflow/vault/data/_file_upload_request.py index d1bd4a44..2f82c3e2 100644 --- a/skyflow/vault/data/_file_upload_request.py +++ b/skyflow/vault/data/_file_upload_request.py @@ -1,10 +1,10 @@ -from typing import BinaryIO +from typing import BinaryIO, Optional class FileUploadRequest: def __init__(self, table: str, - skyflow_id: str, column_name: str, + skyflow_id: Optional[str] = None, file_path: str= None, base64: str= None, file_object: BinaryIO= None, diff --git a/skyflow/vault/data/_get_response.py b/skyflow/vault/data/_get_response.py index cf1b0805..a1640254 100644 --- a/skyflow/vault/data/_get_response.py +++ b/skyflow/vault/data/_get_response.py @@ -1,6 +1,6 @@ class GetResponse: def __init__(self, data=None, errors = None): - self.data = data if data else [] + self.data = data if data is not None else [] self.errors = errors def __repr__(self): diff --git a/tests/client/test_skyflow.py b/tests/client/test_skyflow.py index 3e3681bb..c6589e87 100644 --- 
a/tests/client/test_skyflow.py +++ b/tests/client/test_skyflow.py @@ -319,9 +319,6 @@ def test_skyflow_add_and_update_log_level(self): self.assertEqual(LogLevel.INFO, builder._Builder__log_level) - skyflow_client.update_log_level(LogLevel.ERROR) - self.assertEqual(LogLevel.ERROR, builder._Builder__log_level) - @patch('skyflow.client.Skyflow.Builder.get_vault_config') def test_skyflow_vault_and_connection_method(self, mock_get_vault_config): diff --git a/tests/service_account/test__utils.py b/tests/service_account/test__utils.py index ca82527a..29c9e417 100644 --- a/tests/service_account/test__utils.py +++ b/tests/service_account/test__utils.py @@ -8,7 +8,11 @@ from skyflow.service_account import is_expired, generate_bearer_token, \ generate_bearer_token_from_creds from skyflow.utils import SkyflowMessages -from skyflow.service_account._utils import get_service_account_token, get_signed_jwt, generate_signed_data_tokens, get_signed_data_token_response_object, generate_signed_data_tokens_from_creds +from skyflow.service_account._utils import ( + get_service_account_token, get_signed_jwt, generate_signed_data_tokens, + get_signed_data_token_response_object, generate_signed_data_tokens_from_creds, + _validate_and_resolve_ctx, _normalize_credentials, get_signed_tokens, +) creds_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "credentials.json") with open(creds_path, 'r') as file: @@ -33,7 +37,29 @@ VALID_SERVICE_ACCOUNT_CREDS = credentials +# Snake-case version of the real credentials (keys remapped to snake_case) +SNAKE_CASE_CREDS = { + 'private_key': credentials['privateKey'], + 'client_id': credentials['clientID'], + 'key_id': credentials['keyID'], + 'token_uri': credentials['tokenURI'], +} + +SNAKE_CASE_CREDS_STRING = json.dumps({ + 'private_key': credentials['privateKey'], + 'client_id': credentials['clientID'], + 'key_id': credentials['keyID'], + 'token_uri': credentials['tokenURI'], +}) + + class 
TestServiceAccountUtils(unittest.TestCase): + + # ── is_expired ──────────────────────────────────────────────────────────── + + def test_is_expired_none_token(self): + self.assertTrue(is_expired(None)) + def test_is_expired_empty_token(self): self.assertTrue(is_expired("")) @@ -44,7 +70,7 @@ def test_is_expired_non_expired_token(self): def test_is_expired_expired_token(self): past_time = time.time() - 1000 - token = jwt.encode({"exp": past_time}, key="test", algorithm="HS256") + token = jwt.encode({"exp": past_time}, key="test", algorithm="HS256") self.assertTrue(is_expired(token)) @patch("skyflow.utils.logger._log_helpers.log_error_log") @@ -53,6 +79,8 @@ def test_is_expired_general_exception(self, mock_jwt_decode, mock_log_error): token = jwt.encode({"exp": time.time() + 1000}, key="test", algorithm="HS256") self.assertTrue(is_expired(token)) + # ── generate_bearer_token ───────────────────────────────────────────────── + @patch("builtins.open", side_effect=FileNotFoundError) def test_generate_bearer_token_invalid_file_path(self, mock_open): with self.assertRaises(SkyflowError) as context: @@ -72,6 +100,8 @@ def test_generate_bearer_token_valid_file_path(self, mock_generate_bearer_token) generate_bearer_token(creds_path) mock_generate_bearer_token.assert_called_once() + # ── generate_bearer_token_from_creds ────────────────────────────────────── + @patch("skyflow.service_account._utils.get_service_account_token") def test_generate_bearer_token_from_creds_with_valid_json_string(self, mock_generate_bearer_token): generate_bearer_token_from_creds(VALID_CREDENTIALS_STRING) @@ -82,10 +112,11 @@ def test_generate_bearer_token_from_creds_invalid_json(self): generate_bearer_token_from_creds("invalid_json") self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value) + # ── get_service_account_token ───────────────────────────────────────────── + def test_get_service_account_token_missing_private_key(self): - incomplete_credentials 
= {} with self.assertRaises(SkyflowError) as context: - get_service_account_token(incomplete_credentials, {}, None) + get_service_account_token({}, {}, None) self.assertEqual(context.exception.message, SkyflowMessages.Error.MISSING_PRIVATE_KEY.value) def test_get_service_account_token_missing_client_id_key(self): @@ -107,43 +138,42 @@ def test_get_service_account_token_with_valid_credentials(self): access_token, _ = get_service_account_token(VALID_SERVICE_ACCOUNT_CREDS, {}, None) self.assertTrue(access_token) + def test_get_service_account_token_with_snake_case_creds(self): + access_token, _ = get_service_account_token(SNAKE_CASE_CREDS, {}, None) + self.assertTrue(access_token) - @patch("jwt.encode", side_effect=Exception) - def test_get_signed_jwt_invalid_format(self, mock_jwt_encode): + def test_get_service_account_token_missing_private_key_snake(self): + creds = { + 'client_id': 'id', + 'key_id': 'kid', + 'token_uri': 'https://example.com', + } with self.assertRaises(SkyflowError) as context: - get_signed_jwt({}, "client_id", "key_id", "token_uri", "private_key", None) - self.assertEqual(context.exception.message, SkyflowMessages.Error.JWT_INVALID_FORMAT.value) - - def test_get_signed_data_token_response_object(self): - token = "sample_token" - signed_token = "signed_sample_token" - response = get_signed_data_token_response_object(signed_token, token) - self.assertEqual(response[0], token) - self.assertEqual(response[1], signed_token) - - def test_generate_signed_data_tokens_from_file_path(self): - creds_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "credentials.json") - options = {"data_tokens": ["token1", "token2"], "ctx": 'ctx'} - result = generate_signed_data_tokens(creds_path, options) - self.assertEqual(len(result), 2) + get_service_account_token(creds, {}, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.MISSING_PRIVATE_KEY.value) - def 
test_generate_signed_data_tokens_from_invalid_file_path(self): - options = {"data_tokens": ["token1", "token2"]} + def test_get_service_account_token_invalid_token_uri(self): + creds = { + 'privateKey': 'key', + 'clientID': 'id', + 'keyID': 'kid', + 'tokenURI': 'not-a-url', + } with self.assertRaises(SkyflowError) as context: - result = generate_signed_data_tokens('credentials1.json', options) - self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value) - - def test_generate_signed_data_tokens_from_creds(self): - options = {"data_tokens": ["token1", "token2"]} - result = generate_signed_data_tokens_from_creds(VALID_CREDENTIALS_STRING, options) - self.assertEqual(len(result), 2) + get_service_account_token(creds, {}, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) - def test_generate_signed_data_tokens_from_creds_with_invalid_string(self): - options = {"data_tokens": ["token1", "token2"]} - credentials_string = '{' + def test_get_service_account_token_invalid_token_uri_in_options(self): + creds = { + 'privateKey': 'key', + 'clientID': 'id', + 'keyID': 'kid', + 'tokenURI': 'https://valid-url.com', + } + options = {'token_uri': 'not-a-valid-url'} with self.assertRaises(SkyflowError) as context: - result = generate_signed_data_tokens_from_creds(credentials_string, options) - self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value) + get_service_account_token(creds, options, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) @patch("skyflow.service_account._utils.AuthClient") @patch("skyflow.service_account._utils.get_signed_jwt") @@ -157,8 +187,9 @@ def test_get_service_account_token_with_role_ids_formats_scope(self, mock_get_si options = {"role_ids": ["role1", "role2"]} mock_get_signed_jwt.return_value = "signed" mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value - 
mock_auth_api.authentication_service_get_auth_token.return_value = type("obj", (), {"access_token": "token", - "token_type": "bearer"}) + mock_auth_api.authentication_service_get_auth_token.return_value = type("obj", (), { + "access_token": "token", "token_type": "bearer" + }) access_token, token_type = get_service_account_token(creds, options, None) self.assertEqual(access_token, "token") self.assertEqual(token_type, "bearer") @@ -200,6 +231,46 @@ def test_get_service_account_token_generic_exception(self, mock_get_signed_jwt, get_service_account_token(creds, {}, None) self.assertEqual(context.exception.message, SkyflowMessages.Error.FAILED_TO_GET_BEARER_TOKEN.value) + # ── get_signed_jwt ──────────────────────────────────────────────────────── + + @patch("jwt.encode", side_effect=Exception) + def test_get_signed_jwt_invalid_format(self, mock_jwt_encode): + with self.assertRaises(SkyflowError) as context: + get_signed_jwt({}, "client_id", "key_id", "token_uri", "private_key", None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.JWT_INVALID_FORMAT.value) + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_jwt_with_valid_string_ctx(self, mock_jwt_encode): + mock_jwt_encode.return_value = "mock_token" + get_signed_jwt({"ctx": "valid_ctx"}, "client_id", "key_id", "token_uri", "private_key", None) + payload = mock_jwt_encode.call_args.kwargs["payload"] + self.assertEqual(payload["ctx"], "valid_ctx") + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_jwt_with_valid_dict_ctx(self, mock_jwt_encode): + mock_jwt_encode.return_value = "mock_token" + get_signed_jwt({"ctx": {"role": "admin"}}, "client_id", "key_id", "token_uri", "private_key", None) + payload = mock_jwt_encode.call_args.kwargs["payload"] + self.assertEqual(payload["ctx"], {"role": "admin"}) + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_jwt_with_empty_string_ctx_not_added(self, mock_jwt_encode): + 
mock_jwt_encode.return_value = "mock_token" + get_signed_jwt({"ctx": ""}, "client_id", "key_id", "token_uri", "private_key", None) + payload = mock_jwt_encode.call_args.kwargs["payload"] + self.assertNotIn("ctx", payload) + + # ── get_signed_data_token_response_object ───────────────────────────────── + + def test_get_signed_data_token_response_object(self): + token = "sample_token" + signed_token = "signed_sample_token" + response = get_signed_data_token_response_object(signed_token, token) + self.assertEqual(response[0], token) + self.assertEqual(response[1], signed_token) + + # ── get_signed_tokens ───────────────────────────────────────────────────── + @patch("jwt.encode", side_effect=Exception("jwt error")) def test_get_signed_tokens_jwt_encode_exception(self, mock_jwt_encode): creds = { @@ -210,6 +281,304 @@ def test_get_signed_tokens_jwt_encode_exception(self, mock_jwt_encode): } options = {"data_tokens": ["token1"]} with self.assertRaises(SkyflowError) as context: - from skyflow.service_account._utils import get_signed_tokens get_signed_tokens(creds, options) - self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value) \ No newline at end of file + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value) + + def test_get_signed_tokens_returns_list_one_per_token(self): + result = generate_signed_data_tokens(creds_path, {"data_tokens": ["token1", "token2"]}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + + def test_get_signed_tokens_returns_list_single_token(self): + result = generate_signed_data_tokens(creds_path, {"data_tokens": ["token1"]}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + + def test_get_signed_tokens_empty_data_tokens_returns_empty_list(self): + result = generate_signed_data_tokens(creds_path, {"data_tokens": []}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 0) + + 
@patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_tokens_with_string_ctx_in_claims(self, mock_jwt_encode): + mock_jwt_encode.return_value = "signed" + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + get_signed_tokens(creds, {"data_tokens": ["tok1"], "ctx": "my_ctx"}) + call_args = mock_jwt_encode.call_args + claims = call_args[0][0] if call_args[0] else call_args.kwargs.get("args", [None])[0] + # jwt.encode(claims, key, algorithm=...) — first positional arg is claims + claims_arg = mock_jwt_encode.call_args[0][0] + self.assertEqual(claims_arg["ctx"], "my_ctx") + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_tokens_with_dict_ctx_in_claims(self, mock_jwt_encode): + mock_jwt_encode.return_value = "signed" + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + ctx_dict = {"role": "admin", "dept": "eng"} + get_signed_tokens(creds, {"data_tokens": ["tok1"], "ctx": ctx_dict}) + claims_arg = mock_jwt_encode.call_args[0][0] + self.assertEqual(claims_arg["ctx"], ctx_dict) + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_tokens_with_empty_ctx_not_in_claims(self, mock_jwt_encode): + mock_jwt_encode.return_value = "signed" + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + get_signed_tokens(creds, {"data_tokens": ["tok1"], "ctx": ""}) + claims_arg = mock_jwt_encode.call_args[0][0] + self.assertNotIn("ctx", claims_arg) + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_tokens_with_none_ctx_not_in_claims(self, mock_jwt_encode): + mock_jwt_encode.return_value = "signed" + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + get_signed_tokens(creds, {"data_tokens": ["tok1"], "ctx": None}) + claims_arg = 
mock_jwt_encode.call_args[0][0] + self.assertNotIn("ctx", claims_arg) + + def test_get_signed_tokens_invalid_token_uri(self): + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "not-a-url", + } + with self.assertRaises(SkyflowError) as context: + get_signed_tokens(creds, {"data_tokens": ["tok1"]}) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_get_signed_tokens_missing_token_uri(self): + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + } + with self.assertRaises(SkyflowError) as context: + get_signed_tokens(creds, {"data_tokens": ["tok1"]}) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_get_signed_tokens_with_snake_case_creds(self): + result = get_signed_tokens(SNAKE_CASE_CREDS, {"data_tokens": ["token1", "token2"]}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + + # ── generate_signed_data_tokens (file path) ─────────────────────────────── + + def test_generate_signed_data_tokens_from_file_path(self): + options = {"data_tokens": ["token1", "token2"], "ctx": 'ctx'} + result = generate_signed_data_tokens(creds_path, options) + self.assertEqual(len(result), 2) + + def test_generate_signed_data_tokens_from_invalid_file_path(self): + options = {"data_tokens": ["token1", "token2"]} + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens('credentials1.json', options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value) + + def test_generate_signed_data_tokens_with_dict_ctx(self): + options = {"data_tokens": ["token1"], "ctx": {"role": "admin", "department": "finance"}} + result = generate_signed_data_tokens(creds_path, options) + self.assertEqual(len(result), 1) + + # ── generate_signed_data_tokens_from_creds (string) ────────────────────── + + def 
test_generate_signed_data_tokens_from_creds(self): + options = {"data_tokens": ["token1", "token2"]} + result = generate_signed_data_tokens_from_creds(VALID_CREDENTIALS_STRING, options) + self.assertEqual(len(result), 2) + + def test_generate_signed_data_tokens_from_creds_with_invalid_string(self): + options = {"data_tokens": ["token1", "token2"]} + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens_from_creds('{', options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value) + + def test_generate_signed_data_tokens_from_creds_with_dict_ctx(self): + options = {"data_tokens": ["token1"], "ctx": {"role": "admin", "level": 3}} + result = generate_signed_data_tokens_from_creds(VALID_CREDENTIALS_STRING, options) + self.assertEqual(len(result), 1) + + # ── snake_case end-to-end ───────────────────────────────────────────────── + + def test_generate_signed_data_tokens_with_snake_creds_file(self): + """generate_signed_data_tokens reads the file (camelCase) but the normalize fn is a no-op for camelCase.""" + options = {"data_tokens": ["token1", "token2"]} + result = generate_signed_data_tokens(creds_path, options) + self.assertEqual(len(result), 2) + + def test_generate_signed_data_tokens_from_creds_snake(self): + result = generate_signed_data_tokens_from_creds(SNAKE_CASE_CREDS_STRING, options={"data_tokens": ["t1"]}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + + # ── _normalize_credentials ──────────────────────────────────────────────── + + def test_normalize_credentials_snake_case(self): + snake = { + 'private_key': 'pk', + 'client_id': 'cid', + 'key_id': 'kid', + 'token_uri': 'https://uri', + 'client_name': 'name', + } + result = _normalize_credentials(snake) + self.assertEqual(result['privateKey'], 'pk') + self.assertEqual(result['clientID'], 'cid') + self.assertEqual(result['keyID'], 'kid') + self.assertEqual(result['tokenURI'], 'https://uri') + 
self.assertEqual(result['clientName'], 'name') + self.assertNotIn('private_key', result) + self.assertNotIn('client_id', result) + self.assertNotIn('key_id', result) + self.assertNotIn('token_uri', result) + self.assertNotIn('client_name', result) + + def test_normalize_credentials_camel_case_unchanged(self): + camel = { + 'privateKey': 'pk', + 'clientID': 'cid', + 'keyID': 'kid', + 'tokenURI': 'https://uri', + } + result = _normalize_credentials(camel) + self.assertEqual(result, camel) + + def test_normalize_credentials_mixed_keys(self): + mixed = { + 'private_key': 'pk', + 'clientID': 'cid', + 'key_id': 'kid', + 'tokenURI': 'https://uri', + } + result = _normalize_credentials(mixed) + self.assertEqual(result['privateKey'], 'pk') + self.assertEqual(result['clientID'], 'cid') + self.assertEqual(result['keyID'], 'kid') + self.assertEqual(result['tokenURI'], 'https://uri') + self.assertNotIn('private_key', result) + self.assertNotIn('key_id', result) + + def test_normalize_credentials_unknown_key_passes_through(self): + creds = {'unknown_field': 'value', 'anotherField': 'val2'} + result = _normalize_credentials(creds) + self.assertEqual(result['unknown_field'], 'value') + self.assertEqual(result['anotherField'], 'val2') + + def test_normalize_credentials_empty_dict(self): + self.assertEqual(_normalize_credentials({}), {}) + + # ── _validate_and_resolve_ctx ───────────────────────────────────────────── + + def test_validate_and_resolve_ctx_none(self): + self.assertIsNone(_validate_and_resolve_ctx(None)) + + def test_validate_and_resolve_ctx_empty_string(self): + self.assertIsNone(_validate_and_resolve_ctx('')) + self.assertIsNone(_validate_and_resolve_ctx(' ')) + + def test_validate_and_resolve_ctx_valid_string(self): + self.assertEqual(_validate_and_resolve_ctx('user_12345'), 'user_12345') + + def test_validate_and_resolve_ctx_empty_dict(self): + self.assertIsNone(_validate_and_resolve_ctx({})) + + def test_validate_and_resolve_ctx_valid_dict(self): + ctx = {"role": 
"admin", "department": "finance"} + self.assertEqual(_validate_and_resolve_ctx(ctx), ctx) + + def test_validate_and_resolve_ctx_dict_with_alphanumeric_keys(self): + ctx = {"role_1": "admin", "dept2": "finance", "ABC_123": "value"} + self.assertEqual(_validate_and_resolve_ctx(ctx), ctx) + + def test_validate_and_resolve_ctx_dict_with_invalid_key_hyphen(self): + with self.assertRaises(SkyflowError): + _validate_and_resolve_ctx({"valid_key": "value", "invalid-key": "value"}) + + def test_validate_and_resolve_ctx_dict_with_invalid_key_space(self): + with self.assertRaises(SkyflowError): + _validate_and_resolve_ctx({"invalid key": "value"}) + + def test_validate_and_resolve_ctx_dict_with_invalid_key_dot(self): + with self.assertRaises(SkyflowError): + _validate_and_resolve_ctx({"invalid.key": "value"}) + + def test_validate_and_resolve_ctx_valid_type_int(self): + self.assertEqual(_validate_and_resolve_ctx(42), 42) + + def test_validate_and_resolve_ctx_valid_type_float(self): + self.assertEqual(_validate_and_resolve_ctx(3.14), 3.14) + + def test_validate_and_resolve_ctx_valid_type_bool_true(self): + self.assertEqual(_validate_and_resolve_ctx(True), True) + + def test_validate_and_resolve_ctx_valid_type_bool_false(self): + self.assertEqual(_validate_and_resolve_ctx(False), False) + + def test_validate_and_resolve_ctx_invalid_type_list(self): + with self.assertRaises(SkyflowError): + _validate_and_resolve_ctx(["a", "b"]) + + def test_validate_and_resolve_ctx_dict_with_mixed_value_types(self): + ctx = {"role": "admin", "level": 3, "active": True, "timestamp": "2025-12-25T10:30:00Z"} + self.assertEqual(_validate_and_resolve_ctx(ctx), ctx) + + def test_validate_and_resolve_ctx_dict_with_nested_objects(self): + ctx = {"role": "admin", "metadata": {"level": 2, "tags": ["a", "b"]}} + self.assertEqual(_validate_and_resolve_ctx(ctx), ctx) + + # ── additional coverage gaps ────────────────────────────────────────────── + + @patch("skyflow.service_account._utils.jwt.decode", 
side_effect=jwt.ExpiredSignatureError) + def test_is_expired_expired_signature_error(self, mock_decode): + token = jwt.encode({"exp": time.time() + 1000}, key="test", algorithm="HS256") + self.assertTrue(is_expired(token)) + + @patch("skyflow.service_account._utils.AuthClient") + @patch("skyflow.service_account._utils.get_signed_jwt") + def test_get_service_account_token_with_token_uri_option_override(self, mock_get_signed_jwt, mock_auth_client): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", + } + override_uri = "https://override-url.com" + options = {"token_uri": override_uri} + mock_get_signed_jwt.return_value = "signed" + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + mock_auth_api.authentication_service_get_auth_token.return_value = type("obj", (), { + "access_token": "token", "token_type": "bearer" + }) + get_service_account_token(creds, options, None) + mock_get_signed_jwt.assert_called_once() + call_args = mock_get_signed_jwt.call_args + self.assertEqual(call_args[0][3], override_uri) + + @patch("json.load", side_effect=json.JSONDecodeError("bad json", "", 0)) + def test_generate_signed_data_tokens_from_file_invalid_json(self, mock_load): + invalid_path = os.path.join(os.path.dirname(__file__), "invalid_creds.json") + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens(invalid_path, {"data_tokens": ["t1"]}) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.FILE_INVALID_JSON.value.format(invalid_path), + ) diff --git a/tests/utils/test__helpers.py b/tests/utils/test__helpers.py index 6758b62e..6016c798 100644 --- a/tests/utils/test__helpers.py +++ b/tests/utils/test__helpers.py @@ -39,9 +39,10 @@ def test_format_scope_special_characters(self): def test_is_valid_url_valid(self): self.assertTrue(is_valid_url("https://example.com")) - self.assertTrue(is_valid_url("http://example.com/path")) + 
self.assertTrue(is_valid_url("https://example.com/path")) def test_is_valid_url_invalid(self): + self.assertFalse(is_valid_url("http://example.com")) self.assertFalse(is_valid_url("ftp://example.com")) self.assertFalse(is_valid_url("example.com")) self.assertFalse(is_valid_url("invalid-url")) diff --git a/tests/utils/test__utils.py b/tests/utils/test__utils.py index b0466498..a15128fe 100644 --- a/tests/utils/test__utils.py +++ b/tests/utils/test__utils.py @@ -1,7 +1,6 @@ import unittest -from unittest.mock import patch, Mock +from unittest.mock import patch, Mock, MagicMock, PropertyMock import os -from unittest.mock import MagicMock from urllib.parse import quote import tempfile, json from requests import PreparedRequest @@ -15,7 +14,7 @@ parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_invoke_connection_response, \ handle_exception, validate_api_key, encode_column_values, parse_deidentify_text_response, \ parse_reidentify_text_response, convert_detected_entity_to_entity_info -from skyflow.utils._utils import parse_path_params, to_lowercase_keys, get_metrics, handle_json_error +from skyflow.utils._utils import parse_path_params, to_lowercase_keys, get_metrics, handle_json_error, r_urlencode from skyflow.utils.enums import EnvUrls, Env, ContentType from skyflow.vault.connection import InvokeConnectionResponse from skyflow.vault.data import InsertResponse, DeleteResponse, GetResponse, QueryResponse @@ -36,6 +35,13 @@ def test_get_credentials_env_variable(self): credentials_string = credentials.get('credentials_string') self.assertEqual(credentials_string, json.dumps(VALID_ENV_CREDENTIALS).replace('\n', '\\n')) + @patch("skyflow.utils._utils.dotenv.find_dotenv", return_value=None) + @patch.dict(os.environ, {}, clear=True) + def test_get_credentials_no_credentials_raises(self, mock_find_dotenv): + with self.assertRaises(SkyflowError) as context: + get_credentials(config_level_creds=None, common_skyflow_creds=None) + 
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value) + def test_get_credentials_with_config_level_creds(self): test_creds = {"authToken": "test_token"} creds = get_credentials(config_level_creds=test_creds) @@ -140,6 +146,20 @@ def test_to_lowercase_keys(self): expected_output = {"key1": "value1", "key2": "value2"} self.assertEqual(to_lowercase_keys(input_dict), expected_output) + def test_r_urlencode_with_list_input(self): + pairs = {} + r_urlencode([], pairs, ["a", "b"]) + self.assertIn("[0]", pairs) + self.assertIn("[1]", pairs) + self.assertEqual(pairs["[0]"], "a") + self.assertEqual(pairs["[1]"], "b") + + def test_r_urlencode_with_tuple_input(self): + pairs = {} + r_urlencode([], pairs, ("x", "y")) + self.assertIn("[0]", pairs) + self.assertEqual(pairs["[0]"], "x") + def test_get_metrics(self): metrics = get_metrics() self.assertIn('sdk_name_version', metrics) @@ -147,6 +167,33 @@ def test_get_metrics(self): self.assertIn('sdk_client_os_details', metrics) self.assertIn('sdk_runtime_details', metrics) + def test_get_metrics_platform_node_exception(self): + import skyflow.utils._utils as utils_module + utils_module._CACHED_METRICS.clear() + with patch("skyflow.utils._utils.platform") as mock_platform: + mock_platform.node.side_effect = OSError("no node") + metrics = utils_module.get_metrics() + self.assertEqual(metrics["sdk_client_device_model"], "") + utils_module._CACHED_METRICS.clear() + + def test_get_metrics_sys_attribute_exception(self): + import skyflow.utils._utils as utils_module + utils_module._CACHED_METRICS.clear() + + class _RaisingSys: + @property + def platform(self): + raise RuntimeError("no platform") + @property + def version(self): + raise RuntimeError("no version") + + with patch("skyflow.utils._utils.sys", _RaisingSys()): + metrics = utils_module.get_metrics() + self.assertEqual(metrics["sdk_client_os_details"], "") + self.assertIn("sdk_runtime_details", metrics) + utils_module._CACHED_METRICS.clear() 
+ def test_construct_invoke_connection_request_valid(self): mock_connection_request = Mock() @@ -244,6 +291,16 @@ def test_construct_invoke_connection_request_with_form_date_content_type(self): self.assertIsInstance(result, PreparedRequest) + def test_parse_insert_response_with_tokens_continue_on_error(self): + api_response = Mock() + api_response.headers = {"x-request-id": "req-1"} + api_response.data = Mock(responses=[ + {"Status": 200, "Body": {"records": [{"skyflow_id": "id1", "tokens": {"col1": "tok1"}}]}}, + ]) + result = parse_insert_response(api_response, continue_on_error=True) + self.assertEqual(result.inserted_fields[0]["col1"], "tok1") + self.assertEqual(result.inserted_fields[0]["skyflow_id"], "id1") + def test_parse_insert_response(self): api_response = Mock() api_response.headers = {"x-request-id": "12345", "content-type": "application/json"} @@ -423,6 +480,31 @@ def test_parse_invoke_connection_response_http_error_with_json_error_message(sel self.assertEqual(context.exception.message, "Not Found") self.assertEqual(context.exception.request_id, "1234") + @patch("requests.Response") + def test_parse_invoke_connection_response_with_error_from_client_header(self, mock_response): + from requests.models import HTTPError + mock_response.status_code = 400 + mock_response.content = json.dumps({ + "error": { + "message": "Client error", + "http_code": 400, + "http_status": "Bad Request", + "grpc_code": 3, + "details": None, + } + }).encode("utf-8") + mock_response.headers = { + "x-request-id": "rid-1", + "error-from-client": "true", + } + mock_response.raise_for_status.side_effect = HTTPError("400") + with self.assertRaises(SkyflowError) as context: + parse_invoke_connection_response(mock_response) + err = context.exception + self.assertEqual(err.message, "Client error") + self.assertIsNotNone(err.details) + self.assertTrue(any(d.get("error_from_client") is True for d in err.details)) + @patch("requests.Response") def 
test_parse_invoke_connection_response_http_error_without_json_error_message(self, mock_response): mock_response.status_code = 500 @@ -1585,9 +1667,10 @@ def test_generate_signed_data_tokens_options_override_token_uri(self): with patch("jwt.encode") as mock_jwt_encode: mock_jwt_encode.return_value = "signed" result = generate_signed_data_tokens(tmp.name, options) - self.assertIsInstance(result, tuple) - self.assertEqual(result[0], "token1") - self.assertEqual(result[1], "signed_token_signed") + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + self.assertEqual(result[0][0], "token1") + self.assertEqual(result[0][1], "signed_token_signed") def test_generate_signed_data_tokens_from_creds_options_override_token_uri(self): creds = { @@ -1601,6 +1684,7 @@ def test_generate_signed_data_tokens_from_creds_options_override_token_uri(self) with patch("jwt.encode") as mock_jwt_encode: mock_jwt_encode.return_value = "signed" result = generate_signed_data_tokens_from_creds(creds_str, options) - self.assertIsInstance(result, tuple) - self.assertEqual(result[0], "token1") - self.assertEqual(result[1], "signed_token_signed") + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + self.assertEqual(result[0][0], "token1") + self.assertEqual(result[0][1], "signed_token_signed") diff --git a/tests/vault/client/test__client.py b/tests/vault/client/test__client.py index 6fa31e67..75826128 100644 --- a/tests/vault/client/test__client.py +++ b/tests/vault/client/test__client.py @@ -15,11 +15,19 @@ } CREDENTIALS_WITH_API_KEY = {"api_key": "dummy_api_key"} +CREDENTIALS_WITH_TOKEN = {"token": "dummy_static_token"} +CREDENTIALS_WITH_PATH = {"path": "/some/path/credentials.json"} +CREDENTIALS_WITH_STRING = {"credentials_string": '{"clientID": "x"}'} + class TestVaultClient(unittest.TestCase): def setUp(self): self.vault_client = VaultClient(CONFIG) + # ------------------------------------------------------------------ # + # Basic setters / getters # + # 
------------------------------------------------------------------ # + def test_set_common_skyflow_credentials(self): credentials = {"api_key": "dummy_api_key"} self.vault_client.set_common_skyflow_credentials(credentials) @@ -31,173 +39,289 @@ def test_set_logger(self): self.assertEqual(self.vault_client.get_log_level(), "INFO") self.assertEqual(self.vault_client.get_logger(), mock_logger) + def test_get_vault_id(self): + self.assertEqual(self.vault_client.get_vault_id(), CONFIG["vault_id"]) + + def test_get_config(self): + self.assertEqual(self.vault_client.get_config(), CONFIG) + + def test_get_common_skyflow_credentials(self): + credentials = {"api_key": "dummy_api_key"} + self.vault_client.set_common_skyflow_credentials(credentials) + self.assertEqual(self.vault_client.get_common_skyflow_credentials(), credentials) + + def test_get_log_level(self): + self.vault_client.set_logger("DEBUG", MagicMock()) + self.assertEqual(self.vault_client.get_log_level(), "DEBUG") + + def test_get_logger(self): + mock_logger = MagicMock() + self.vault_client.set_logger("INFO", mock_logger) + self.assertEqual(self.vault_client.get_logger(), mock_logger) + + # ------------------------------------------------------------------ # + # initialize_client_configuration — first call (slow path) # + # ------------------------------------------------------------------ # + @patch("skyflow.vault.client.client.get_credentials") @patch("skyflow.vault.client.client.get_vault_url") @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") - def test_initialize_client_configuration(self, mock_init_api_client, mock_get_vault_url, mock_get_credentials): - mock_get_credentials.return_value = (CREDENTIALS_WITH_API_KEY) + def test_initialize_client_configuration_first_call( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY mock_get_vault_url.return_value = "https://test-vault-url.com" 
self.vault_client.initialize_client_configuration() - mock_get_credentials.assert_called_once_with(CONFIG["credentials"], None, logger=None) - mock_get_vault_url.assert_called_once_with(CONFIG["cluster_id"], CONFIG["env"], CONFIG["vault_id"], logger=None) + mock_get_credentials.assert_called_once_with( + CONFIG["credentials"], None, logger=None + ) + mock_get_vault_url.assert_called_once_with( + CONFIG["cluster_id"], CONFIG["env"], CONFIG["vault_id"], logger=None + ) mock_init_api_client.assert_called_once() - @patch("skyflow.vault.client.client.Skyflow") - def test_initialize_api_client(self, mock_api_client): - self.vault_client.initialize_api_client("https://test-vault-url.com", "dummy_token") - mock_api_client.assert_called_once_with(base_url="https://test-vault-url.com", token="dummy_token") + # ------------------------------------------------------------------ # + # initialize_client_configuration — fast path (static token) # + # ------------------------------------------------------------------ # - def test_get_records_api(self): - self.vault_client._VaultClient__api_client = MagicMock() - self.vault_client._VaultClient__api_client.records = MagicMock() - records_api = self.vault_client.get_records_api() - self.assertIsNotNone(records_api) + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_fast_path_api_key( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + """Once initialized with api_key, subsequent calls skip all work.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY + mock_get_vault_url.return_value = "https://test-vault-url.com" + # Side-effect simulates initialize_api_client actually setting __api_client + mock_init_api_client.side_effect = lambda *_: setattr( + self.vault_client, "_VaultClient__api_client", MagicMock() + ) - def 
test_get_tokens_api(self): + self.vault_client.initialize_client_configuration() # first call — slow path + mock_get_credentials.reset_mock() + mock_get_vault_url.reset_mock() + mock_init_api_client.reset_mock() + + self.vault_client.initialize_client_configuration() # second call — fast path + + mock_get_credentials.assert_not_called() + mock_get_vault_url.assert_not_called() + mock_init_api_client.assert_not_called() + + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_fast_path_static_token( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + """Once initialized with a static token, subsequent calls skip all work.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_TOKEN + mock_get_vault_url.return_value = "https://test-vault-url.com" + mock_init_api_client.side_effect = lambda *_: setattr( + self.vault_client, "_VaultClient__api_client", MagicMock() + ) + + self.vault_client.initialize_client_configuration() + mock_get_credentials.reset_mock() + mock_get_vault_url.reset_mock() + mock_init_api_client.reset_mock() + + self.vault_client.initialize_client_configuration() + + mock_get_credentials.assert_not_called() + mock_get_vault_url.assert_not_called() + mock_init_api_client.assert_not_called() + + # ------------------------------------------------------------------ # + # initialize_client_configuration — fast path (service account) # + # ------------------------------------------------------------------ # + + @patch("skyflow.vault.client.client.is_expired", return_value=False) + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_fast_path_valid_sa_token( + self, mock_init_api_client, 
mock_get_vault_url, mock_get_credentials, mock_is_expired + ): + """Service account with a still-valid token skips get_bearer_token entirely.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_PATH + mock_get_vault_url.return_value = "https://test-vault-url.com" + + # Seed the cached bearer token as if first call already ran self.vault_client._VaultClient__api_client = MagicMock() - self.vault_client._VaultClient__api_client.tokens = MagicMock() - tokens_api = self.vault_client.get_tokens_api() - self.assertIsNotNone(tokens_api) + self.vault_client._VaultClient__is_static_token = False + self.vault_client._VaultClient__bearer_token = "cached_sa_token" + self.vault_client._VaultClient__credentials = CREDENTIALS_WITH_PATH - def test_get_query_api(self): + self.vault_client.initialize_client_configuration() + + mock_get_credentials.assert_not_called() + mock_get_vault_url.assert_not_called() + mock_init_api_client.assert_not_called() + + # ------------------------------------------------------------------ # + # initialize_client_configuration — token expiry (no client reinit) # + # ------------------------------------------------------------------ # + + @patch("skyflow.vault.client.client.generate_bearer_token", return_value=("new_sa_token", None)) + @patch("skyflow.vault.client.client.is_expired", return_value=True) + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_expired_token_no_reinit( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials, + mock_is_expired, mock_generate_bearer_token + ): + """Expired service account token is regenerated in-place; httpx client is NOT recreated.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_PATH + mock_get_vault_url.return_value = "https://test-vault-url.com" + + # Client already initialized — simulate warm state with an 
expired token self.vault_client._VaultClient__api_client = MagicMock() - self.vault_client._VaultClient__api_client.query = MagicMock() - query_api = self.vault_client.get_query_api() - self.assertIsNotNone(query_api) + self.vault_client._VaultClient__is_static_token = False + self.vault_client._VaultClient__bearer_token = "expired_sa_token" + self.vault_client._VaultClient__credentials = CREDENTIALS_WITH_PATH - def test_get_vault_id(self): - self.assertEqual(self.vault_client.get_vault_id(), CONFIG["vault_id"]) + self.vault_client.initialize_client_configuration() - @patch("skyflow.vault.client.client.generate_bearer_token") - @patch("skyflow.vault.client.client.generate_bearer_token_from_creds") - @patch("skyflow.vault.client.client.log_info") - def test_get_bearer_token_with_api_key(self, mock_log_info, mock_generate_bearer_token, - mock_generate_bearer_token_from_creds): - token = self.vault_client.get_bearer_token(CREDENTIALS_WITH_API_KEY) - self.assertEqual(token, CREDENTIALS_WITH_API_KEY["api_key"]) - - def test_update_config(self): - new_config = {"credentials": "new_credentials"} - self.vault_client.update_config(new_config) - self.assertTrue(self.vault_client._VaultClient__is_config_updated) - self.assertEqual(self.vault_client.get_config()["credentials"], "new_credentials") + # Token was regenerated + mock_generate_bearer_token.assert_called_once() + self.assertEqual( + self.vault_client._VaultClient__bearer_token, "new_sa_token" + ) + # httpx client was NOT recreated + mock_init_api_client.assert_not_called() - def test_get_config(self): - self.assertEqual(self.vault_client.get_config(), CONFIG) + # ------------------------------------------------------------------ # + # initialize_client_configuration — config update forces reinit # + # ------------------------------------------------------------------ # - def test_get_common_skyflow_credentials(self): - credentials = {"api_key": "dummy_api_key"} - 
self.vault_client.set_common_skyflow_credentials(credentials) - self.assertEqual(self.vault_client.get_common_skyflow_credentials(), credentials) + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_reinit_after_update_config( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + """update_config() marks the client stale; next call must recreate it.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY + mock_get_vault_url.return_value = "https://test-vault-url.com" - def test_get_log_level(self): - log_level = "DEBUG" - self.vault_client.set_logger(log_level, MagicMock()) - self.assertEqual(self.vault_client.get_log_level(), log_level) + # Simulate already-initialized client + self.vault_client._VaultClient__api_client = MagicMock() + self.vault_client._VaultClient__is_static_token = True - def test_get_logger(self): - mock_logger = MagicMock() - self.vault_client.set_logger("INFO", mock_logger) - self.assertEqual(self.vault_client.get_logger(), mock_logger) + self.vault_client.update_config({"cluster_id": "new_cluster"}) + self.vault_client.initialize_client_configuration() - @patch("skyflow.vault.client.client.is_expired") - @patch("skyflow.vault.client.client.generate_bearer_token") - def test_get_bearer_token_expired_token_raises_error(self, mock_generate_bearer_token, mock_is_expired): - """Test that expired token raises SkyflowError.""" - credentials = {"path": "/path/to/credentials.json"} - mock_generate_bearer_token.return_value = ("expired_token", None) - mock_is_expired.return_value = True + mock_get_credentials.assert_called_once() + mock_get_vault_url.assert_called_once() + mock_init_api_client.assert_called_once() - with self.assertRaises(SkyflowError) as context: - self.vault_client.get_bearer_token(credentials) + # 
------------------------------------------------------------------ # + # initialize_api_client — lambda token provider # + # ------------------------------------------------------------------ # - self.assertEqual(context.exception.message, SkyflowMessages.Error.EXPIRED_TOKEN.value) - self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) - self.assertTrue(self.vault_client._VaultClient__is_config_updated) + @patch("skyflow.vault.client.client.Skyflow") + def test_initialize_api_client_passes_callable_token(self, mock_skyflow): + """initialize_api_client must pass a callable (lambda) as token, not a string.""" + self.vault_client.initialize_api_client("https://test-vault-url.com", "initial_token") - @patch("skyflow.vault.client.client.is_expired") - @patch("skyflow.vault.client.client.generate_bearer_token_from_creds") - def test_get_bearer_token_expired_token_from_creds_string_raises_error(self, mock_generate_bearer_token_from_creds, mock_is_expired): - """Test that expired token from credentials string raises SkyflowError.""" - credentials = {"credentials_string": '{"key": "value"}'} - mock_generate_bearer_token_from_creds.return_value = ("expired_token", None) - mock_is_expired.return_value = True + args, kwargs = mock_skyflow.call_args + self.assertEqual(kwargs["base_url"], "https://test-vault-url.com") + self.assertTrue(callable(kwargs["token"]), "token must be a callable (lambda)") - with self.assertRaises(SkyflowError) as context: - self.vault_client.get_bearer_token(credentials) + @patch("skyflow.vault.client.client.Skyflow") + def test_initialize_api_client_lambda_returns_cached_bearer_token(self, mock_skyflow): + """Lambda returns __bearer_token when it is set (interceptor behaviour).""" + self.vault_client._VaultClient__bearer_token = "refreshed_token" + self.vault_client.initialize_api_client("https://test-vault-url.com", "initial_token") - self.assertEqual(context.exception.message, 
SkyflowMessages.Error.EXPIRED_TOKEN.value) - self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) - self.assertTrue(self.vault_client._VaultClient__is_config_updated) + _, kwargs = mock_skyflow.call_args + self.assertEqual(kwargs["token"](), "refreshed_token") - @patch("skyflow.vault.client.client.is_expired") - @patch("skyflow.vault.client.client.generate_bearer_token") - def test_get_bearer_token_reuses_valid_token(self, mock_generate_bearer_token, mock_is_expired): - """Test that valid bearer token is reused.""" - credentials = {"path": "/path/to/credentials.json"} - mock_generate_bearer_token.return_value = ("valid_token", None) - mock_is_expired.return_value = False - - token1 = self.vault_client.get_bearer_token(credentials) - self.assertEqual(token1, "valid_token") - mock_generate_bearer_token.assert_called_once() + @patch("skyflow.vault.client.client.Skyflow") + def test_initialize_api_client_lambda_falls_back_to_initial_token(self, mock_skyflow): + """Lambda falls back to the initial token when __bearer_token is None.""" + self.vault_client._VaultClient__bearer_token = None + self.vault_client.initialize_api_client("https://test-vault-url.com", "initial_token") + + _, kwargs = mock_skyflow.call_args + self.assertEqual(kwargs["token"](), "initial_token") + + # ------------------------------------------------------------------ # + # get_bearer_token # + # ------------------------------------------------------------------ # + + def test_get_bearer_token_with_api_key(self): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_API_KEY) + self.assertEqual(result, "dummy_api_key") + + def test_get_bearer_token_with_static_token(self): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_TOKEN) + self.assertEqual(result, "dummy_static_token") + + @patch("skyflow.vault.client.client.generate_bearer_token", return_value=("sa_token", None)) + def test_get_bearer_token_generates_from_path_on_first_call(self, 
mock_generate): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_PATH) + mock_generate.assert_called_once() + self.assertEqual(result, "sa_token") + self.assertEqual(self.vault_client._VaultClient__bearer_token, "sa_token") + + @patch("skyflow.vault.client.client.generate_bearer_token_from_creds", return_value=("sa_token_str", None)) + @patch("skyflow.vault.client.client.log_info") + def test_get_bearer_token_generates_from_credentials_string(self, mock_log, mock_generate): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_STRING) + mock_generate.assert_called_once() + self.assertEqual(result, "sa_token_str") - token2 = self.vault_client.get_bearer_token(credentials) - self.assertEqual(token2, "valid_token") - mock_generate_bearer_token.assert_called_once() + @patch("skyflow.vault.client.client.generate_bearer_token", return_value=("new_token", None)) + @patch("skyflow.vault.client.client.is_expired", return_value=True) + @patch("skyflow.vault.client.client.log_info") + def test_get_bearer_token_regenerates_on_expiry(self, mock_log, mock_is_expired, mock_generate): + """Expired token is regenerated silently — no exception raised.""" + self.vault_client._VaultClient__bearer_token = "expired_token" + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_PATH) + mock_generate.assert_called_once() + self.assertEqual(result, "new_token") - @patch("skyflow.vault.client.client.is_expired") @patch("skyflow.vault.client.client.generate_bearer_token") - def test_get_bearer_token_regenerates_after_config_update(self, mock_generate_bearer_token, mock_is_expired): - """Test that bearer token is regenerated after config update.""" - credentials = {"path": "/path/to/credentials.json"} - mock_generate_bearer_token.side_effect = [("first_token", None), ("second_token", None)] - mock_is_expired.return_value = False + @patch("skyflow.vault.client.client.is_expired", return_value=False) + @patch("skyflow.vault.client.client.log_info") + def 
test_get_bearer_token_reuses_valid_cached_token(self, mock_log, mock_is_expired, mock_generate): + """Valid cached token is reused without calling generate_bearer_token.""" + self.vault_client._VaultClient__bearer_token = "valid_token" + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_PATH) + mock_generate.assert_not_called() + self.assertEqual(result, "valid_token") + + # ------------------------------------------------------------------ # + # update_config # + # ------------------------------------------------------------------ # + + def test_update_config_sets_flag(self): + self.vault_client.update_config({"credentials": "new_credentials"}) + self.assertTrue(self.vault_client._VaultClient__is_config_updated) + self.assertEqual(self.vault_client.get_config()["credentials"], "new_credentials") + + # ------------------------------------------------------------------ # + # API accessor stubs # + # ------------------------------------------------------------------ # + + def test_get_records_api(self): + self.vault_client._VaultClient__api_client = MagicMock() + self.assertIsNotNone(self.vault_client.get_records_api()) - token1 = self.vault_client.get_bearer_token(credentials) - self.assertEqual(token1, "first_token") + def test_get_tokens_api(self): + self.vault_client._VaultClient__api_client = MagicMock() + self.assertIsNotNone(self.vault_client.get_tokens_api()) - self.vault_client.update_config({"new_key": "new_value"}) + def test_get_query_api(self): + self.vault_client._VaultClient__api_client = MagicMock() + self.assertIsNotNone(self.vault_client.get_query_api()) - token2 = self.vault_client.get_bearer_token(credentials) - self.assertEqual(token2, "second_token") - self.assertEqual(mock_generate_bearer_token.call_count, 2) - @patch("skyflow.vault.client.client.is_expired") - @patch("skyflow.vault.client.client.generate_bearer_token_from_creds") - @patch("skyflow.vault.client.client.log_info") - def 
test_get_bearer_token_with_credentials_string(self, mock_log_info, mock_generate_bearer_token_from_creds, mock_is_expired): - """Test get_bearer_token with credentials_string.""" - credentials = {"credentials_string": '{"clientID": "test", "clientName": "test"}'} - mock_generate_bearer_token_from_creds.return_value = ("token_from_creds", None) - mock_is_expired.return_value = False - - token = self.vault_client.get_bearer_token(credentials) - - self.assertEqual(token, "token_from_creds") - mock_generate_bearer_token_from_creds.assert_called_once() - mock_log_info.assert_called_with( - SkyflowMessages.Info.GENERATE_BEARER_TOKEN_FROM_CREDENTIALS_STRING_TRIGGERED.value, - None - ) - def test_get_bearer_token_with_token(self): - credentials = {"token": "dummy_token"} - token = self.vault_client.get_bearer_token(credentials) - self.assertEqual(token, "dummy_token") - - def test_get_bearer_token_with_token_uri_in_credentials(self): - credentials = { - "path": "dummy_path", - "token_uri": "https://valid-url.com" - } - with patch("skyflow.vault.client.client.generate_bearer_token") as mock_generate_bearer_token, \ - patch("skyflow.vault.client.client.is_expired", return_value=False): - mock_generate_bearer_token.return_value = ("bearer_token", "bearer") - token = self.vault_client.get_bearer_token(credentials) - mock_generate_bearer_token.assert_called_once() - args, kwargs = mock_generate_bearer_token.call_args - self.assertIn("token_uri", args[1]) - self.assertEqual(args[1]["token_uri"], "https://valid-url.com") - self.assertEqual(token, "bearer_token") +if __name__ == "__main__": + unittest.main() diff --git a/tests/vault/controller/test__detect.py b/tests/vault/controller/test__detect.py index c2f9a861..b86087f5 100644 --- a/tests/vault/controller/test__detect.py +++ b/tests/vault/controller/test__detect.py @@ -2,6 +2,7 @@ from unittest.mock import Mock, patch, MagicMock import base64 import os +import tempfile from skyflow.error import SkyflowError from 
skyflow.generated.rest import WordCharacterCount from skyflow.utils import SkyflowMessages @@ -513,16 +514,12 @@ def test_get_detect_run_in_progress_status(self, mock_validate): self.vault_client.get_detect_file_api.return_value = files_api - # Execute - with patch.object(self.detect, "_Detect__parse_deidentify_file_response") as mock_parse: - result = self.detect.get_detect_run(req) + # Execute — IN_PROGRESS is returned directly without going through the parser + result = self.detect.get_detect_run(req) - # Verify IN_PROGRESS handling - mock_parse.assert_called_once() - args = mock_parse.call_args[0][0] - self.assertIsInstance(args, DeidentifyFileResponse) - self.assertEqual(args.status, 'IN_PROGRESS') - self.assertEqual(args.run_id, run_id) + self.assertIsInstance(result, DeidentifyFileResponse) + self.assertEqual(result.status, 'IN_PROGRESS') + self.assertEqual(result.run_id, run_id) def test_get_transformations_with_shift_dates(self): @@ -711,3 +708,98 @@ def test_deidentify_file_using_file_path(self, mock_open, mock_basename, mock_ba self.assertIsNone(result.page_count) self.assertIsNone(result.slide_count) self.assertEqual(result.entities, []) + + def test_poll_for_processed_file_exception(self): + files_api = Mock() + files_api.with_raw_response = files_api + files_api.get_run.side_effect = Exception("poll error") + self.vault_client.get_detect_file_api.return_value = files_api + with self.assertRaises(Exception): + self.detect._Detect__poll_for_processed_file("runid", max_wait_time=5) + + def test_save_output_directory_not_exists(self): + output = Mock() + output.processedFile = base64.b64encode(b"data").decode() + output.processedFileType = "redacted_file" + output.processedFileExtension = "txt" + response = Mock() + response.output = [output] + with patch("skyflow.vault.controller._detect.os.path.exists", return_value=False): + self.detect._Detect__save_deidentify_file_response_output( + response, "/nonexistent_dir", "file.txt", "file" + ) + + def 
test_save_output_second_non_redacted_item(self): + with tempfile.TemporaryDirectory() as tmp_dir: + output1 = Mock() + output1.processedFile = base64.b64encode(b"data1").decode() + output1.processedFileType = "redacted_file" + output1.processedFileExtension = "txt" + output2 = Mock() + output2.processedFile = base64.b64encode(b"data2").decode() + output2.processedFileType = "entities" + output2.processedFileExtension = "json" + response = Mock() + response.output = [output1, output2] + self.detect._Detect__save_deidentify_file_response_output( + response, tmp_dir, "original.txt", "original" + ) + + def test_save_output_path_traversal_blocked(self): + output = Mock() + output.processedFile = base64.b64encode(b"data").decode() + output.processedFileType = "redacted_file" + output.processedFileExtension = "txt" + response = Mock() + response.output = [output] + call_count = [0] + + def fake_realpath(p): + call_count[0] += 1 + if call_count[0] == 1: + return "/safe_dir" + return "/outside/path" + + with patch("skyflow.vault.controller._detect.os.path.exists", return_value=True), \ + patch("skyflow.vault.controller._detect.os.path.realpath", side_effect=fake_realpath): + self.detect._Detect__save_deidentify_file_response_output( + response, "/safe_dir", "file.txt", "file" + ) + + def test_save_output_write_exception(self): + with tempfile.TemporaryDirectory() as tmp_dir: + output = Mock() + output.processedFile = base64.b64encode(b"data").decode() + output.processedFileType = "redacted_file" + output.processedFileExtension = "txt" + response = Mock() + response.output = [output] + with patch("skyflow.vault.controller._detect.base64.b64decode", + side_effect=Exception("decode error")), \ + self.assertRaises(Exception): + self.detect._Detect__save_deidentify_file_response_output( + response, tmp_dir, "file.txt", "file" + ) + + @patch("skyflow.vault.controller._detect.validate_deidentify_file_request") + @patch("skyflow.vault.controller._detect.base64") + def 
test_deidentify_file_api_error_inside_try(self, mock_base64, mock_validate): + file_content = b"test content" + file_obj = Mock() + file_obj.read.return_value = file_content + file_obj.name = "test.txt" + mock_base64.b64encode.return_value.decode.return_value = "encoded" + req = DeidentifyFileRequest(file=FileInput(file=file_obj)) + req.entities = [] + req.token_format = None + req.allow_regex_list = [] + req.restrict_regex_list = [] + req.transformations = None + req.output_directory = None + req.wait_time = None + files_api = Mock() + files_api.with_raw_response = files_api + files_api.deidentify_text.side_effect = Exception("API error inside try") + self.vault_client.get_detect_file_api.return_value = files_api + with self.assertRaises(Exception): + self.detect.deidentify_file(req) diff --git a/tests/vault/controller/test__vault.py b/tests/vault/controller/test__vault.py index 4e1a0dda..993cd72a 100644 --- a/tests/vault/controller/test__vault.py +++ b/tests/vault/controller/test__vault.py @@ -722,6 +722,26 @@ def test_upload_file_with_missing_file_source(self, mock_validate): self.assertEqual(error.exception.message, SkyflowMessages.Error.MISSING_FILE_SOURCE.value) mock_validate.assert_called_once_with(self.vault_client.get_logger(), request) + @patch("skyflow.vault.controller._vault.validate_file_upload_request") + def test_upload_file_without_skyflow_id_successful(self, mock_validate): + """Test upload_file succeeds when skyflow_id is None (it is optional).""" + request = FileUploadRequest( + table="test_table", + column_name="file_column", + file_path="/path/to/test.txt", + ) + mocked_open = mock_open_func(read_data=b"test file content") + mock_api_response = Mock() + mock_api_response.data = Mock(skyflow_id="generated-id-123") + records_api = self.vault_client.get_records_api.return_value + records_api.with_raw_response.upload_file_v_2.return_value = mock_api_response + with patch('builtins.open', mocked_open): + result = self.vault.upload_file(request) + 
mock_validate.assert_called_once_with(self.vault_client.get_logger(), request) + self.assertIsNone(request.skyflow_id) + self.assertEqual(result.skyflow_id, "generated-id-123") + self.assertIsNone(result.errors) + class TestFileUploadValidation(unittest.TestCase): def setUp(self): self.logger = Mock() @@ -874,3 +894,38 @@ def test_validate_missing_file_source(self): with self.assertRaises(SkyflowError) as error: validate_file_upload_request(self.logger, request) self.assertEqual(error.exception.message, SkyflowMessages.Error.MISSING_FILE_SOURCE.value) + + def test_validate_none_skyflow_id_is_allowed(self): + """Test that skyflow_id=None passes validation (it is optional).""" + request = FileUploadRequest( + table="test_table", + column_name="file_column", + base64="dGVzdCBmaWxlIGNvbnRlbnQ=", + file_name="test.txt" + ) + self.assertIsNone(request.skyflow_id) + validate_file_upload_request(self.logger, request) + + @patch('os.path.exists') + @patch('os.path.isfile') + def test_validate_file_path_without_skyflow_id(self, mock_isfile, mock_exists): + """Test validation succeeds with file_path and no skyflow_id.""" + mock_exists.return_value = True + mock_isfile.return_value = True + request = FileUploadRequest( + table="test_table", + column_name="file_column", + file_path="/path/to/file.txt" + ) + validate_file_upload_request(self.logger, request) + + def test_validate_file_object_without_skyflow_id(self): + """Test validation succeeds with file_object and no skyflow_id.""" + mock_file = Mock() + mock_file.seek = Mock() + request = FileUploadRequest( + table="test_table", + column_name="file_column", + file_object=mock_file + ) + validate_file_upload_request(self.logger, request)