diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 614e60a..00c53b8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,19 +2,52 @@ name: Code CI on: push: - branches: [ main, dev ] + # Fire on every branch so the `gate` job below can decide whether to run. + # GitHub cannot filter on commit message at the event level, so we trigger + # broadly and gate execution: main/dev always run, other branches run only + # when the head commit message contains "[force ci]". + branches: [ '**' ] tags: [ 'v*' ] pull_request: types: [ opened, reopened, ready_for_review ] workflow_dispatch: jobs: - # Required gate: a lint violation here FAILS the CI run. This job has no - # `needs` and nothing `needs` it, so it runs in parallel with install/test/ - # build — its failure turns the run red but does NOT stop the other jobs from - # running (GitHub Actions does not cancel sibling jobs on failure). + # Single source of truth for "should the core CI jobs run?". Reproduces the + # previous trigger rules (PR / manual dispatch / main / dev run; tags skip) + # and adds a manual escape hatch for custom branches via "[force ci]". + gate: + runs-on: ubuntu-latest + outputs: + run_ci: ${{ steps.decide.outputs.run_ci }} + steps: + - name: Decide whether to run CI + id: decide + env: + # Evaluated by GitHub (not the shell) so the raw commit message is + # never interpolated into bash — avoids quoting/injection issues. + FORCE_CI: ${{ contains(github.event.head_commit.message, '[force ci]') }} + run: | + if [ "${{ github.event_name }}" != "push" ]; then + run_ci=true # PRs and manual dispatch always run + elif [[ "${{ github.ref }}" == refs/tags/* ]]; then + run_ci=false # tag pushes are handled by the release workflow + elif [ "${{ github.ref }}" = "refs/heads/main" ] || [ "${{ github.ref }}" = "refs/heads/dev" ]; then + run_ci=true # default branches always run + elif [ "${FORCE_CI}" = "true" ]; then + run_ci=true # custom branch opted in via [force ci] + else + run_ci=false # custom branch without [force ci] → skip + fi + echo "run_ci=${run_ci}" >> "$GITHUB_OUTPUT" + echo "Decision: run_ci=${run_ci} (event=${{ github.event_name }}, ref=${{ github.ref }}, force_ci=${FORCE_CI})" + # Required gate: a lint violation here FAILS the CI run. Nothing `needs` this + # job, so it runs in parallel with install/test/build — its failure turns the + # run red but does NOT stop the other jobs from running (GitHub Actions does + # not cancel sibling jobs on failure). code-quality: - if: ${{ github.event_name != 'push' || !startsWith(github.ref, 'refs/tags/') }} + needs: gate + if: ${{ needs.gate.outputs.run_ci == 'true' }} runs-on: ubuntu-latest steps: - name: Checkout repository @@ -99,7 +132,8 @@ jobs: # Single-version install on 3.11 — always runs so `test` has a gate to # depend on regardless of branch. install: - if: ${{ github.event_name != 'push' || !startsWith(github.ref, 'refs/tags/') }} + needs: gate + if: ${{ needs.gate.outputs.run_ci == 'true' }} runs-on: ubuntu-latest name: install (Python 3.11) steps: @@ -160,10 +194,11 @@ jobs: python -c "import weightslab; print(f'weightslab imported successfully on Python ${{ matrix.python-version }}')" test: - if: ${{ github.event_name != 'push' || !startsWith(github.ref, 'refs/tags/') }} + if: ${{ needs.gate.outputs.run_ci == 'true' }} runs-on: ubuntu-latest - # Depends on the fast 3.11 gate only — the matrix runs independently in parallel. - needs: install + # Depends on the fast 3.11 install gate; `gate` is also a direct need so this + # job can read run_ci (the matrix and main-only jobs run independently). + needs: [ gate, install ] steps: - name: Checkout repository uses: actions/checkout@v4 diff --git a/AGENTS.md b/AGENTS.md index 92be371..50773f7 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,209 +1,268 @@ -# AGENTS.md — coding-agent & contributor onboarding +# WeightsLab — agent context for users & debugging -This file is the shared context for **AI coding agents** (Claude Code, etc.) and -**human contributors** working on WeightsLab. It captures the cross-repo -architecture and the conventions that aren't obvious from any single file, so a -newcomer (or an agent with no prior memory of the project) can orient quickly -and debug confidently. +This file is a **portable context for AI coding agents** (Claude Code, etc.) and +the humans driving them. Its job is to let you — or an agent helping you — +**install, configure, run, and debug WeightsLab and Weights Studio** without +having to reverse-engineer the system first. -> File/line references drift as the code evolves — treat them as starting -> points and verify against the current source before relying on them. +It deliberately covers only the two shipped repositories: ---- +- **weightslab** — the Python backend / core (training instrumentation, data + ledger, gRPC service, the shared proto). +- **weights_studio** — the browser frontend (the studio UI that inspects and + edits a *running* experiment). -## 1. Working with AI coding agents here - -WeightsLab is developed with coding agents in the loop. A few things make that -productive and safe: - -- **This repo is one of three that ship together** (see §2). Many changes are - *cross-repo* — a proto edit touches the Python backend *and* the TypeScript - frontend. An agent that edits only one side leaves the build broken. Always - ask "does this change cross the proto boundary?" -- **Keep changes additive and verify.** The default expectation is that - existing usecases keep working. Run the relevant test suites (§7) before - declaring done; prefer adding a new branch/flag over changing a shared path. -- **Agent memory ≠ this file.** Agents may keep private, point-in-time "memory" - notes outside the repo. This file is the *committed, reviewed* distillation of - that knowledge — the part we want every contributor and every future agent to - start from. When you learn something load-bearing and non-obvious, add it - here (in a PR) rather than leaving it only in private memory. -- **Onboarding flow:** read §2–§6 top to bottom, then skim §7–§8 before your - first change. For a feature, find the closest existing example in - `weightslab/weightslab/examples/` and mirror its structure. +> File/line references drift as the code evolves — treat them as starting points +> and verify against the current source before relying on them. Environment +> variable names and defaults are the most stable thing here; when in doubt the +> authoritative reference is `weightslab/docs/configuration.rst`. --- -## 2. The workspace: three repos, side by side +## 0. How to load this guide into Claude Code -These must be checked out as **sibling directories** — codegen and proxy -configs reach across by relative path. +So an agent actually *has* this context when you ask it for help: -| Repo | Role | Stack | -|---|---|---| -| **weightslab** (this repo) | Backend / core: training instrumentation, data ledger, gRPC service, the shared proto | Python | -| **weights_studio** | Frontend: the studio UI that inspects/edits a running experiment | TypeScript + Vite | -| **weightslab_kitchen** | Private examples / reference material | mixed | +- **Working inside a checkout of the repo** (`git clone`): this guide is + committed as `AGENTS.md`; the repo keeps a gitignored `CLAUDE.md` copy of it at + the root so Claude Code auto-loads it every session. Nothing to do. (Claude + Code also loads `~/.claude/CLAUDE.md` global memory and any parent-dir + `CLAUDE.md`.) +- **You only ran `pip install weightslab`** (no checkout — the package lives in + `site-packages`): absolute `@import` paths are fragile because the path + changes per venv/OS. The robust pattern is a small **skill** that locates the + installed file at runtime. Create `~/.claude/skills/weightslab/SKILL.md`: -Layout assumption: `…/Codes/weightslab`, `…/Codes/weights_studio`, -`…/Codes/weightslab_kitchen`. + ```yaml + --- + name: weightslab + description: Load the WeightsLab debugging & configuration guide when helping with weightslab or weights_studio problems (connection, TLS, env vars, training hangs, rendering). + --- + !`python -c "import weightslab, os; print(open(os.path.join(os.path.dirname(weightslab.__file__), 'AGENTS.md')).read())"` ---- + Use the guide above to diagnose the user's weightslab / weights_studio issue. + ``` -## 3. Runtime integration (how backend ↔ frontend connect) - -One proto is the single source of truth: -`weightslab/weightslab/proto/experiment_service.proto` (`service -ExperimentService`, ~20 RPCs). - -- **Backend** implements it in `weightslab/trainer/services/experiment_service.py` - (servicer), delegating to `data_service.py`, `model_service.py`, - `agent_service.py`. The servicer and the in-process training loop run in the - **same process, different threads**, coordinated by the locks in - `weightslab/components/global_monitoring.py`. -- **Frontend** consumes a generated TS client - (`weights_studio/src/experiment_service.client.ts`) produced by - `npm run generate-proto:data`, which runs `protoc` against the sibling - weightslab proto. -- **Wire path:** browser → `:8080` Envoy (grpc-web ↔ grpc transcoding) → Python - gRPC servicer → training loop. The browser can't speak raw gRPC, so Envoy - translates. - -**Editing the proto is a three-step, cross-repo operation — do all three or the -build breaks:** -1. Edit `experiment_service.proto`. -2. Regenerate Python stubs **from the repo root** (import style matters): - `python -m grpc_tools.protoc --proto_path=. --python_out=. --grpc_python_out=. weightslab/proto/experiment_service.proto` -3. Regenerate the TS client in weights_studio: `npm run generate-proto:data`. + Then run `/weightslab` (or let Claude auto-invoke it). This requires the guide + to be **shipped as package data** inside the installed package (see §7); the + copy at the repo root is for contributors working in a checkout. +- **Quick-and-dirty:** copy this file to `~/.claude/WEIGHTSLAB.md` and add + `@~/.claude/WEIGHTSLAB.md` to your `~/.claude/CLAUDE.md`. --- -## 4. Backend module map (`weightslab/weightslab/`) - -Public API is re-exported from `__init__.py` (← `src.py`); users do -`import weightslab as wl`. - -- **`src.py`** — the public verbs: `watch_or_edit`, `serve`, `keep_serving`, - `save_signals`, `tag_samples`, `query_*`, decorators (`eval_fn`, - `pointcloud_thumbnail`, `pointcloud_boxes`), etc. -- **`trainer/`** — orchestration. `services/experiment_service.py` is the gRPC - servicer; `services/{model,data,agent}_service.py` are per-domain delegates; - `services/data_image_utils.py` handles preview/mask encoding. -- **`components/`** — cross-cutting runtime: `global_monitoring.py` - (`guard_training_context` / `guard_testing_context`, pause controller, global - rlock), `checkpoint_manager.py`, `evaluation_controller.py`, `tracking.py`. -- **`models/`** — `model_with_ops.py` (the watched/op-able model wrapper). -- **`data/`** — the dataframe + storage backbone: `dataframe_manager.py`, - `data_samples_with_ops.py`, `sample_stats.py` (`SampleStatsEx`), `data_utils.py`, - `point_cloud_utils.py`; storage in `h5_dataframe_store.py`, `h5_array_store.py`, - `array_proxy.py`. -- **`backend/`** — primitives: `ledgers.py` (`GLOBAL_LEDGER`, the watch/edit - substrate + hyperparameter registry), the watched-object interfaces, `logger.py`, - `audit_logger.py`, `cli.py`. -- **`proto/`**, **`security/`** (`CertAuthManager`), **`ui/`** (bundled - Docker/Envoy assets), **`examples/`**. - -**Fan-in:** `ledgers.GLOBAL_LEDGER` is the hub — `watch_or_edit` registers -objects there; the servicer reads/mutates through it. +## 1. What it is and how the pieces connect + +A user wraps their own PyTorch training script with WeightsLab so a running +experiment becomes inspectable/editable; Weights Studio is the UI for that. + +**Wire path (the thing that breaks most often):** + +``` +Browser (Vite app) → Envoy :8080 (grpc-web ↔ grpc) → Python gRPC servicer → training loop +``` + +- The browser cannot speak raw gRPC, so **Envoy** transcodes grpc-web ↔ gRPC. + If Envoy is down or misconfigured, the UI loads but no data appears. +- The gRPC servicer and the training loop run in the **same process, different + threads**, coordinated by locks in + `weightslab/weightslab/components/global_monitoring.py`. +- One proto is the single source of truth: + `weightslab/weightslab/proto/experiment_service.proto`. --- -## 5. Frontend module map (`weights_studio/src/`) +## 2. Install & run (the happy path) + +```bash +pip install weightslab +``` -Vite + TypeScript; entry `index.html` → `src/main.ts`. +In your training script: -- **`main.ts`** — bootstrap: infers host/port (default `:8080`), builds the - grpc-web transport, wires panels and the sample modal. -- **`experiment_service.client.ts` / `experiment_service.ts`** — generated - client/types (do **not** hand-edit; regenerate via `npm run generate-proto:data`). -- **`grid_data/`** — sample grid + modal rendering (`GridCell.ts`, - `DataImageService.ts`, `BboxRenderer.ts`, `SegmentationRenderer.ts`, - `PointCloudViewer.ts` / `PointCloudService.ts`). -- **`left_panel/`, `plots/`, `agent/`, `helpers.ts`, `ui/`, `utils/`** — controls, - Chart.js plots, the LLM agent panel, shared helpers/reconnection. -- Tests in `tests/` (vitest unit + Playwright E2E). +```python +import weightslab as wl +# wrap your objects so the studio can see/edit them (see §3), then: +wl.serve(serving_grpc=True, serving_cli=True) # background threads, same process +# ... your training loop ... +wl.keep_serving() # keep the process alive for the UI +``` -Build/proto scripts read the **sibling weightslab repo** at codegen time. +Then start the studio stack (Envoy + frontend) and open it in a browser. +Working starting points live in +`weightslab/weightslab/examples/{PyTorch,Lightning,Usecases}//` +(each is a `main.py` + `config.yaml`) — find the closest example and mirror it. + +Studio deployment details (Docker compose, Envoy, ports, certs) are in +`weights_studio/docker/` and documented in `weightslab/docs/weights_studio.rst`. --- -## 6. User-integration API (`import weightslab as wl`) +## 3. The integration API (`import weightslab as wl`) -How a user's own PyTorch script plugs in so the studio can inspect/edit it live. -Examples: `weightslab/weightslab/examples/{PyTorch,Lightning,Usecases}//` -(each a `main.py` + `config.yaml`). +How a user's script plugs in. Wrap each training object with +`wl.watch_or_edit(obj, flag=...)`; the returned tracked proxy is registered in +the global ledger (`weightslab/weightslab/backend/ledgers.py`, +`GLOBAL_LEDGER` — the hub everything reads/mutates through). -Wrap each training object with `wl.watch_or_edit(obj, flag=...)`; the returned -tracked proxy is registered in the global ledger: - `flag="hyperparameters"` (dict), `flag="model"` (nn.Module, `device=…`), `flag="optimizer"`, `flag="data"` (Dataset → tracked DataLoader: `loader_name`, `batch_size`, `is_training`, `collate_fn`, …), `flag="loss"` (a - `reduction="none"` criterion; called with `(preds_raw, targets, batch_ids=ids, + `reduction="none"` criterion, called with `(preds_raw, targets, batch_ids=ids, preds=preds)`), `flag="metric"`. -Conventions: +Conventions that matter for correctness: + - Wrap the train step in `with guard_training_context:` and eval in - `with guard_testing_context:` (from `weightslab.components.global_monitoring`) - — this is how pause/resume and train/test separation work. + `with guard_testing_context:` (from + `weightslab.components.global_monitoring`). This is how pause/resume and + train/test separation work — **skip it and pause/resume or stats will misbehave.** - Use `model.get_age()` (steps actually trained; survives checkpoint reloads), not the raw loop counter. -- `wl.serve(serving_grpc=…, serving_cli=…)` starts background serving threads in - the same process; end the script with `wl.keep_serving()`. - `task_type` on the dataset/model selects rendering: `classification`, `segmentation`, `detection`, `detection_pointcloud`. +- **Hyperparameter handle access:** the registered hyperparameters proxy + supports both `hp.get("lr")` and `hp["lr"]` (subscript == `.get`), and stays + live — reads reflect in-place updates and re-registration. --- -## 7. Tests & verification +## 4. Configuration (environment variables) + +WeightsLab and Weights Studio are configured almost entirely through env vars. +**Authoritative reference: `weightslab/docs/configuration.rst`.** The high-signal +ones when debugging: + +**Backend (Python):** -- **Backend (Python):** `python -m pytest weightslab/tests/...` — domains under - `tests/{data,trainer,gRPC,...}`. -- **Frontend unit (vitest):** `npm run test` in weights_studio - (`tests/utests/**`). -- **E2E / user-simulation (Playwright): lives in weights_studio, not here** — - the tests drive the real UI against a backend the harness spins up - (`test:realtime:*`, `test:e2e:*`). +| Variable | Default | Why you touch it | +|---|---|---| +| `WEIGHTSLAB_LOG_LEVEL` | `INFO` | Set `DEBUG` to see what's happening. (`WATCHDOG` level sits between WARNING/ERROR.) | +| `GRPC_BACKEND_HOST` / `GRPC_BACKEND_PORT` | `0.0.0.0` / `50051` | Backend must listen where Envoy expects it. | +| `GRPC_TLS_ENABLED` | `1` | TLS on the gRPC socket. Set `0` **only** for isolated local debugging. | +| `GRPC_TLS_REQUIRE_CLIENT_AUTH` | `1` | mTLS. Must match what Envoy presents. | +| `GRPC_TLS_CERT_DIR` | `~/certs` | Where default cert files are looked up. | +| `GRPC_AUTH_TOKEN` | *(unset)* | Optional metadata-token auth on top of mTLS. | +| `GRPC_MAX_MESSAGE_BYTES` | `268435456` (256 MB) | Raise it if large tensors/image batches fail. | +| `WEIGHTSLAB_DISABLE_WATCHDOGS` | `0` | Set `1` when debugging with breakpoints (see §5). | +| `GRPC_WATCHDOG_STUCK_SECONDS` | `60` | Lock/RPC stuck threshold + lock-acquire timeout. | + +**Frontend (Weights Studio) — runtime-injected `window.*` globals:** + +| Variable | Default | Why you touch it | +|---|---|---| +| `WS_SERVER_HOST` / `WS_SERVER_PORT` / `WS_SERVER_PROTOCOL` | `localhost` / `8080` / `https` | How the browser reaches the backend (via Envoy). The #1 connection-issue knob. | +| `WS_HISTOGRAM_MAX_BINS` | `512` | Cap on metadata histogram bars. | +| `BB_THUMB_RENDER` | `10` | Max bounding boxes drawn per **thumbnail**, per overlay (GT and PRED capped independently). | +| `BB_MODAL_RENDER` | `100` | Max bounding boxes drawn per **modal** image, per overlay. A `?` button in the modal shows the active limit. | -Before declaring a cross-repo change done: regenerate protos (both sides), -build the frontend (`npm run build`), and run the affected unit suites. +> **VITE_ vs WS_/BB_:** `VITE_*` variables are baked at **build time** (changing +> them needs a rebuild). `WS_*` / `BB_*` are injected at **container start** into +> `config.js` and read as `window.*` globals — changing them needs only a +> container restart + browser reload (see the caching note in §5). --- -## 8. Point-cloud detection pipeline (`detection_pointcloud`) - -A worked example of a non-trivial cross-repo feature; see -`weightslab/weightslab/examples/Usecases/ws-3d-lidar-detection/` and -`weightslab/weightslab/data/point_cloud_utils.py`. - -- **Task type** `detection_pointcloud` (covers 2D and 3D; box-row column count - decides dimensionality; legacy alias `detection_3d`). -- **Dataset yields** `(cloud [M,F], uid, boxes, metadata)`. Cloud columns: - `x,y,z,intensity` (model reads the first 4) + optional viz-only channels - (`nx,ny,nz` normals, `r,g,b` colour) named via `point_feature_names`. Boxes: - 3D `[cx,cy,cz,dx,dy,dz,yaw,cls?,conf?]` or 2D `[cx,cy,dx,dy,cls?,conf?]`; - predictions use the **same** schema. -- **Previews** (thumbnail/grid/modal image) are **server-rendered 2D** (BEV or - range projection) with boxes projected on; the raw cloud streams to the - browser only for the interactive **three.js 3D viewer** (modal "3D" toggle), - via the `GetPointCloud` streaming RPC. Colour modes are data-driven (height / - distance / intensity / camera-RGB / normal shading). -- **Override hooks:** dataset methods `render_thumbnail_2d` / `project_boxes_2d`, - or global decorators `@wl.pointcloud_thumbnail` / `@wl.pointcloud_boxes`. -- **Real data:** `source: kitti_raw` downloads a KITTI raw drive (sync + - tracklets for GT boxes + calibration for RGB) to a temp dir; falls back to - synthetic scenes offline. +## 5. Troubleshooting — symptom → cause → fix + +This is the core of the guide. Each entry is a real failure mode (several are +distilled from issues hit in development). + +**UI loads but the sample grid is empty / "failed to fetch" / gRPC errors.** +The wire path (§1) is broken somewhere. Check in order: (1) backend actually +serving on `0.0.0.0:50051`; (2) Envoy running and reachable on `:8080`; +(3) frontend `WS_SERVER_HOST/PORT/PROTOCOL` point at Envoy, not the raw backend; +(4) **TLS mismatch** — `WS_SERVER_PROTOCOL=https` vs `http`, Envoy server certs, +and Envoy→backend mTLS certs all consistent. For local debugging you can drop +TLS end-to-end (`GRPC_TLS_ENABLED=0` + `VITE_SERVER_PROTOCOL=http`). + +**Changed an env var, restarted, but the UI still uses the old value.** +- `VITE_*` is build-time → you must **rebuild** the frontend, not just restart. +- `WS_*` / `BB_*` are read once per page load → you must **reload the tab**. +- Historically `config.js` was served `Cache-Control: immutable` so even a + restart needed a **hard refresh**; current builds serve `/config.js` with + `no-store`, so a container restart + normal reload is enough. On an older + deployment, hard-refresh (Ctrl+Shift+R) or clear cache. + +**Sample grid flashes empty cells when auto-refresh fires.** +An auto-refresh (timer or manual) that lands while a `GetDataSamples` grid fetch +is still in flight used to clear the cache mid-render. The fix in +`weights_studio/src/grid_data/gridDataManager.ts` is `isFetchInProgress()`: +refreshes are skipped while a grid fetch is ongoing. If you see this, confirm +you're on a build that has that guard. + +**Detection overlays are slow or unreadably cluttered.** +Dense detection samples can carry hundreds of boxes. Cap rendering with +`BB_THUMB_RENDER` (thumbnails) and `BB_MODAL_RENDER` (modal); each is applied +separately to GT and to predictions. Render-only — no sample data is dropped. + +**Training appears hung; RPCs return `RESOURCE_EXHAUSTED`; server "restarts".** +A watchdog monitors the global rlock and in-flight RPCs. If a lock/RPC is held +longer than `GRPC_WATCHDOG_STUCK_SECONDS` (60s) it's flagged; locks get +interrupted, and after `GRPC_WATCHDOG_RESTART_THRESHOLD` unhealthy polls the +gRPC server restarts. When **debugging with breakpoints** that intentionally +pause longer than that, set `WEIGHTSLAB_DISABLE_WATCHDOGS=1`. If RPCs fail with +`RESOURCE_EXHAUSTED`, a handler couldn't acquire the lock within the window — +something else is holding it; check for a long/blocking train or eval step. + +**Pause/resume doesn't work, or train vs test stats are mixed up.** +The train step isn't wrapped in `guard_training_context` (or eval in +`guard_testing_context`). See §3 — these context managers are how the system +gates and separates phases. + +**Large weights/images fail to transfer.** Raise `GRPC_MAX_MESSAGE_BYTES`. + +**The agent bar says it's unconfigured.** The LLM agent needs a provider: a +local **Ollama** server (`provider: ollama`, available immediately) or **cloud +OpenRouter** initialized from the UI via `/init` (then `/model` to switch, +`/reset` to clear). See `weightslab/docs/weights_studio.rst`. --- -## 9. Contributor checklist & gotchas +## 6. Where things live (for deeper digging) + +**Backend (`weightslab/weightslab/`):** +- `src.py` — the public verbs (`watch_or_edit`, `serve`, `keep_serving`, + `tag_samples`, `query_*`, decorators) re-exported from `__init__.py`. +- `trainer/services/` — `experiment_service.py` (gRPC servicer) delegating to + `{model,data,agent}_service.py`; `data_image_utils.py` (preview/mask encoding). +- `components/` — `global_monitoring.py` (locks, `guard_*` contexts, pause), + `checkpoint_manager.py`, `evaluation_controller.py`. +- `data/` — `dataframe_manager.py`, `data_samples_with_ops.py`, `sample_stats.py`, + H5 storage (`h5_dataframe_store.py`, `h5_array_store.py`, `array_proxy.py`). +- `backend/` — `ledgers.py` (`GLOBAL_LEDGER`), `logger.py`, `audit_logger.py`, `cli.py`. +- `security/` (`CertAuthManager`), `proto/`, `examples/`, `docs/`. + +**Frontend (`weights_studio/src/`):** +- `main.ts` — bootstrap; builds the grpc-web transport from `WS_SERVER_*`. +- `experiment_service.client.ts` / `experiment_service.ts` — generated client + (regenerate with `npm run generate-proto:data`; do not hand-edit). +- `grid_data/` — grid + modal rendering (`GridCell.ts`, `DataImageService.ts`, + `gridDataManager.ts`, `BboxRenderer.ts`, `SegmentationRenderer.ts`, + `PointCloudViewer.ts`). +- `docker/` — compose, `nginx-entrypoint.sh` (injects `config.js`), Envoy assets. + +**Docs:** `weightslab/docs/` (Sphinx) — `configuration.rst` (all env vars), +`weights_studio.rst` (studio deploy + agent), `quickstart.rst`, `grpc/`. + +--- -- Proto change → regen Python (from repo root) **and** TS, or runtime breaks. -- The three repos must sit side by side; codegen uses relative paths. -- Training + serving share a process — respect the `global_monitoring` locks and - the `guard_*` context managers. -- TLS/auth in the bundled UI is decided by the presence of certs under +## 7. For contributors (working in a checkout) + +- **The two repos must sit side by side** (`…/weightslab`, `…/weights_studio`); + codegen and Envoy configs reach across by relative path. +- **Editing the proto is cross-repo** — do all of: edit + `experiment_service.proto`; regenerate Python stubs from the repo root; run + `npm run generate-proto:data` in weights_studio. Editing one side only leaves + the build broken. +- **Tests:** backend `python -m pytest weightslab/tests/...`; frontend unit + `npm run test` (vitest); E2E/user-simulation Playwright lives in + **weights_studio** (`test:realtime:*`, `test:e2e:*`), not here. +- **CI on a custom branch:** pushes to non-`main`/`dev` branches only run CI when + the commit message contains `[force ci]` (both repos). +- **TLS/auth in the bundled UI** is decided by cert presence under `WEIGHTSLAB_CERTS_DIR` (single source of truth) — don't hardcode secure/insecure. -- Per-instance data (detection/segmentation) uses a MultiIndex - `(sample_id, annotation_id)` through the dataframe and H5 store. -- Playwright/E2E tests belong in weights_studio, not here. -- Don't commit large datasets; real-data downloads go to a temp dir. +- **To make this guide available to pip users**, ship it as package data inside + the installed package (e.g. as `weightslab/weightslab/AGENTS.md`) so the §0 + skill can locate it; keep the root `AGENTS.md` (mirrored as the gitignored + `CLAUDE.md`) as the contributor-facing source. diff --git a/weightslab/backend/ledgers.py b/weightslab/backend/ledgers.py index 2931b1a..eb188dd 100644 --- a/weightslab/backend/ledgers.py +++ b/weightslab/backend/ledgers.py @@ -29,6 +29,15 @@ DEFAULT_NAME = "main" +def _debug_logging_enabled() -> bool: + """True only when WEIGHTSLAB_LOG_LEVEL is 'debug' (case-insensitive). + + Used to gate verbose diagnostics (traceback dumps) so they stay quiet in + normal runs and only surface when the user explicitly opts into debug. + """ + return os.environ.get("WEIGHTSLAB_LOG_LEVEL", "").strip().lower() == "debug" + + def _rebuild_proxy(obj: Any) -> "Proxy": proxy = Proxy() proxy._obj = obj @@ -39,6 +48,36 @@ def _rebuild_value_proxy(parent: "Proxy", key: Any, default: Any = None) -> "Pro return Proxy._ValueProxy(parent, key, default) +# Sentinel: the resolved value should stay a live proxy (no plain-value coercion). +_KEEP_AS_PROXY = object() + + +def _is_torch_device(value: Any) -> bool: + """True if ``value`` is a ``torch.device``, matched by duck typing. + + Checked via the type's module/name so this backend module never imports + torch (it must stay usable without a torch install). + """ + t = type(value) + return t.__name__ == "device" and t.__module__ == "torch" + + +def _plain_get_value(value: Any) -> Any: + """Plain value a ``get`` should hand back for types that must NOT be returned + as a live proxy, else the ``_KEEP_AS_PROXY`` sentinel. + + - ``str``: immutable; as a live proxy it breaks os.PathLike consumers + (os.stat / os.path.isdir see ``__index__`` and treat it as a file descriptor). + - ``torch.device``: torch's C-level ``torch.device()`` / ``Tensor.to()`` reject + a proxy, so we return its string form (which torch accepts everywhere). + """ + if isinstance(value, str): + return value + if _is_torch_device(value): + return str(value) + return _KEEP_AS_PROXY + + class Proxy: """A small forwarding proxy that holds a mutable reference to an object. @@ -115,7 +154,9 @@ def get(self, *args, **kwargs) -> Any: """ v = self._resolve() if not args and not kwargs: - return v + # Hand back the plain form for str / torch.device; raw value otherwise. + plain = _plain_get_value(v) + return v if plain is _KEEP_AS_PROXY else plain if hasattr(v, 'get'): return v.get(*args, **kwargs) # Fallback for the proxy's own resolve-with-default logic @@ -357,17 +398,50 @@ def __len__(self) -> int: return 0 return len(v) + def __iter__(self): + """Iterate over the resolved value (e.g. a list/tuple hyperparameter). + + Without this, ``for x in proxy`` / ``tuple(proxy)`` / unpacking raise + ``'_ValueProxy' object is not iterable`` even though the wrapped value + is iterable. A None target iterates as empty. + """ + v = self._resolve() + if v is None: + return iter(()) + return iter(v) + + def __getitem__(self, item: Any) -> Any: + """Index, slice, or key into the resolved value. + + Works for lists (``proxy[0]``, ``proxy[1:3]``), dicts (``proxy[key]``) + and strings. Nested mappings are re-wrapped in a live proxy to mirror + __getattr__, so chained subscripting keeps tracking edits. + """ + v = self._resolve() + if v is None: + raise TypeError("ValueProxy target not set (resolved to None)") + value = v[item] + if isinstance(value, dict): + return Proxy._ValueProxy(Proxy(v), item) + return value + def get(self, ref=None, default=None, proxy: bool = True) -> Any: """Get wrapped object or a key from the wrapped mapping. Args: ref: Mapping key when provided. default: Fallback value when key/target is missing. - proxy: When True and ref is provided, return a live key proxy. + proxy: When True and ref is provided, return a live key proxy — + except for ``str`` values, which are returned raw (see below). """ if ref is not None: if proxy: - return Proxy._ValueProxy(self, ref, default) + vp = Proxy._ValueProxy(self, ref, default) + # str and torch.device are handed back as plain values (see + # _plain_get_value); other types (dict/list/int/float/...) stay + # live proxies so studio edits keep tracking. + plain = _plain_get_value(vp._resolve()) + return vp if plain is _KEEP_AS_PROXY else plain return self._obj.get(ref, default) return self._obj if self._obj is not None else default @@ -457,11 +531,14 @@ def __next__(self): self._it.is_a_loop = False # Loop ends here raise except KeyError: - traceback.print_exc() - logger.error( - "KeyError during Proxy iteration. This may indicate the underlying object was modified during iteration. Returning StopIteration to end iteration gracefully." + - "\nOtherwise there is a missmatch between data metadata returned, e.g., some metadata has augmentation parameters and other not. Please initialize all metadata with the same keys and types to avoid this error." - ) + # Quiet by default; only surface this diagnostic when the user + # opted into debug logging (WEIGHTSLAB_LOG_LEVEL=debug). + if _debug_logging_enabled(): + traceback.print_exc() + logger.error( + "KeyError during Proxy iteration. This may indicate the underlying object was modified during iteration. Returning StopIteration to end iteration gracefully." + + "\nOtherwise there is a missmatch between data metadata returned, e.g., some metadata has augmentation parameters and other not. Please initialize all metadata with the same keys and types to avoid this error." + ) raise StopIteration return _ProxyIterator(underlying_iter) @@ -590,13 +667,15 @@ def __next__(self): # Let StopIteration propagate naturally to signal end of iteration raise except Exception: - traceback.print_exc() + if _debug_logging_enabled(): + traceback.print_exc() # clear cached iterator so future next(proxy) restarts try: if '_iterator' in self.__dict__: object.__delattr__(self, '_iterator') except Exception: - traceback.print_exc() + if _debug_logging_enabled(): + traceback.print_exc() raise StopIteration # Context manager support so `with proxy as x:` works when the proxy diff --git a/weightslab/components/experiment_hash.py b/weightslab/components/experiment_hash.py index 8d27534..5f7ac62 100644 --- a/weightslab/components/experiment_hash.py +++ b/weightslab/components/experiment_hash.py @@ -237,6 +237,8 @@ def _hash_config(self, config: Dict[str, Any]) -> str: config_cp.pop('is_training', None) config_cp.pop('pause_at_step', None) # config_cp.pop('auditor_mode', None) # Audit should be another state + if 'auditor_mode' not in config_cp: + config_cp['auditor_mode'] = False try: # Sort keys for deterministic hashing diff --git a/weightslab/data/dataframe_manager.py b/weightslab/data/dataframe_manager.py index 0764419..fc6c8fd 100644 --- a/weightslab/data/dataframe_manager.py +++ b/weightslab/data/dataframe_manager.py @@ -4,6 +4,7 @@ import threading import logging import traceback +import warnings import numpy as np import pandas as pd import torch @@ -60,6 +61,16 @@ def _safe_update(target: pd.DataFrame, source: pd.DataFrame) -> None: mask = src.notna() if not mask.any(): continue + # Widen the target to object before assigning object (arrays/masks) or + # bool values into a numeric (e.g. NaN -> float64) column. Otherwise + # pandas >= 2.1 emits an incompatible-dtype FutureWarning (a future hard + # error). Compatible numeric assignments keep their dtype (fast path). + try: + if src.dtype.kind in ("O", "b") and target[col].dtype != object \ + and target[col].dtype.kind != src.dtype.kind: + target[col] = target[col].astype(object) + except Exception: + pass try: target.loc[common_idx[mask.values], col] = src[common_idx].values except Exception: @@ -765,7 +776,7 @@ def _is_bbox_array(value: Any) -> bool: arr = np.asanyarray(value) except Exception: return False - return arr.ndim == 2 and arr.shape[0] >= 0 and arr.shape[-1] in (4, 5, 6) + return arr.ndim == 2 and arr.shape[0] >= 0 and arr.shape[-1] in range(3, 9+1) @staticmethod def _is_empty(value: Any) -> bool: @@ -1624,17 +1635,31 @@ def _apply_updates_frame_locked(self, df_updates: pd.DataFrame, broadcast: bool) if self._df.index.has_duplicates: self._df = self._df[~self._df.index.duplicated(keep='last')] - # Widen categorical target columns so a new value doesn't raise - # "Cannot setitem on a Categorical with a new category". + # Categorical target columns must be widened up front: assigning a value + # outside the category list raises an UNCATCHABLE AssertionError, so it + # cannot be handled by the try/except below. for col in df_updates.columns: if col in self._df.columns and isinstance(self._df[col].dtype, pd.CategoricalDtype): self._df[col] = self._df[col].astype(object) # Masked update: only overwrite where df_updates is non-NaN. mask = df_updates.notna() - self._df.loc[df_updates.index, df_updates.columns] = ( - self._df.loc[df_updates.index, df_updates.columns].where(~mask, df_updates) - ) + combined = self._df.loc[df_updates.index, df_updates.columns].where(~mask, df_updates) + try: + with warnings.catch_warnings(): + # pandas >= 2.1 only *warns* when a partial assignment changes a + # column's dtype (object arrays/masks or bool into a numeric column), + # but will raise in a future version. Promote it so we widen the + # affected columns to object and retry. Compatible assignments take + # this fast path unchanged. + warnings.filterwarnings( + "error", message=".*incompatible dtype.*", category=FutureWarning) + self._df.loc[df_updates.index, df_updates.columns] = combined + except Exception: + for col in df_updates.columns: + if col in self._df.columns and self._df[col].dtype != object: + self._df[col] = self._df[col].astype(object) + self._df.loc[df_updates.index, df_updates.columns] = combined return df_updates.index def _merge_buffer_frames_into(self, df: pd.DataFrame, sample_df: pd.DataFrame, instance_df: pd.DataFrame) -> pd.DataFrame: @@ -1748,7 +1773,8 @@ def _apply_buffer_records_nonblocking(self, records: List[Dict[str, Any]]): applied_index = written_s.append(written_i) if len(written_i) else written_s update_cols = sample_df.columns.union(instance_df.columns) # Keep newly-added signal columns float32 and empty object cells as None. - self._df = self._optimize_dataframe_memory(self._df) + _df = self._optimize_dataframe_memory(self._df) + self._df = _df finally: self._lock.release() diff --git a/weightslab/data/h5_dataframe_store.py b/weightslab/data/h5_dataframe_store.py index e227199..72c13a8 100644 --- a/weightslab/data/h5_dataframe_store.py +++ b/weightslab/data/h5_dataframe_store.py @@ -29,6 +29,27 @@ pass +def _align_col_dtype_for_assign(existing: pd.DataFrame, source: pd.DataFrame, col: str) -> None: + """Upcast ``existing[col]`` to object before a partial ``.loc`` assignment when + the source column holds object data (numpy arrays, stringified masks) or bool + flags that don't fit the target's dtype. + + pandas >= 2.1 raises a FutureWarning (a future hard error) when a partial + ``.loc`` assignment would change a column's dtype — e.g. assigning object or + bool values into a column that was just initialized as float64 via ``np.nan``. + Upcasting the target column to object first makes the assignment dtype-stable. + Compatible numeric assignments (e.g. int into float) are left untouched. + Best-effort: dtype alignment must never break a merge. + """ + try: + src_kind = source[col].dtype.kind # 'O' object, 'b' bool, 'i'/'u'/'f' numeric + tgt_dtype = existing[col].dtype + if src_kind in ("O", "b") and tgt_dtype != object and tgt_dtype.kind != src_kind: + existing[col] = existing[col].astype(object) + except Exception: + pass + + class _InterProcessFileLock: """Lightweight cross-platform file lock. @@ -694,6 +715,7 @@ def upsert(self, origin: str, df: pd.DataFrame) -> int: if col not in existing.columns: is_categorical = col.startswith("tag") or col.startswith("TAG") or col == "discarded" existing[col] = False if is_categorical else np.nan + _align_col_dtype_for_assign(existing, df_norm, col) existing.loc[matching_rows.index, col] = df_norm.loc[idx, col] elif isinstance(matching_rows, pd.Series): # Single row matched @@ -701,6 +723,7 @@ def upsert(self, origin: str, df: pd.DataFrame) -> int: if col not in existing.columns: is_categorical = col.startswith("tag") or col.startswith("TAG") or col == "discarded" existing[col] = False if is_categorical else np.nan + _align_col_dtype_for_assign(existing, df_norm, col) existing.loc[matching_rows.name, col] = df_norm.loc[idx, col] else: # Normal case: same index structure @@ -717,6 +740,7 @@ def upsert(self, origin: str, df: pd.DataFrame) -> int: is_categorical = col.startswith("tag") or col.startswith("TAG") or col == "discarded" existing[col] = False if is_categorical else np.nan # Update values for common rows + _align_col_dtype_for_assign(existing, df_norm, col) existing.loc[common_idx, col] = df_norm.loc[common_idx, col] # 2. Append strictly new rows diff --git a/weightslab/examples/PyTorch/ws-classification/main.py b/weightslab/examples/PyTorch/ws-classification/main.py index d999033..8f2ec0d 100644 --- a/weightslab/examples/PyTorch/ws-classification/main.py +++ b/weightslab/examples/PyTorch/ws-classification/main.py @@ -236,6 +236,14 @@ def test(loader, model, criterion_mlt, metric_mlt, device, test_loader_len): # Experiment name exp_name = parameters["experiment_name"] + # Hyperparameters (must use 'hyperparameters' flag for trainer services / UI) + wl.watch_or_edit( + parameters, + flag="hyperparameters", + defaults=parameters, + poll_interval=1.0, + ) + # Device selection if parameters.get("device", "auto") == "auto": parameters["device"] = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -256,15 +264,6 @@ def test(loader, model, criterion_mlt, metric_mlt, device, test_loader_len): enable_h5_persistence = parameters.get("enable_h5_persistence", True) training_steps_to_do = parameters.get("training_steps_to_do", 1000) - # Hyperparameters (must use 'hyperparameters' flag for trainer services / UI) - hp = ledgers.get_hyperparams() - wl.watch_or_edit( - parameters, - flag="hyperparameters", - defaults=parameters, - poll_interval=1.0, - ) - # Model _model = CNN().to(device) model = wl.watch_or_edit(_model, flag="model", device=device) diff --git a/weightslab/examples/PyTorch/ws-clustering/main.py b/weightslab/examples/PyTorch/ws-clustering/main.py index 9d93a0e..d7c0361 100644 --- a/weightslab/examples/PyTorch/ws-clustering/main.py +++ b/weightslab/examples/PyTorch/ws-clustering/main.py @@ -208,6 +208,14 @@ def train( parameters.setdefault("eval_full_to_train_steps_ratio", 50) parameters.setdefault("log_every", 10) + # ---- Hyperparameter tracking ---- + wl.watch_or_edit( + parameters, + flag="hyperparameters", + defaults=parameters, + poll_interval=1.0, + ) + # ---- Device ---- if parameters.get("device", "auto") == "auto": parameters["device"] = "cuda" if torch.cuda.is_available() else "cpu" @@ -224,14 +232,6 @@ def train( log_every = int(parameters.get("log_every", 10)) enable_h5 = parameters.get("enable_h5_persistence", True) - # ---- Hyperparameter tracking ---- - wl.watch_or_edit( - parameters, - flag="hyperparameters", - defaults=parameters, - poll_interval=1.0, - ) - # ---- Datasets ---- data_cfg = parameters.get("data", {}) dataset_type = data_cfg.get("dataset_type", "olivetti") diff --git a/weightslab/examples/PyTorch/ws-detection/main.py b/weightslab/examples/PyTorch/ws-detection/main.py index 7da1d95..63710bd 100644 --- a/weightslab/examples/PyTorch/ws-detection/main.py +++ b/weightslab/examples/PyTorch/ws-detection/main.py @@ -119,40 +119,40 @@ def test(loader, model, sig, device, grid_size, conf_thresh, test_loader_len): parameters.setdefault("freeze_backbone", True) parameters.setdefault("compute_natural_sort", True) + # --- 2) Register hyperparameters --- + wl.watch_or_edit( + parameters, + flag="hyperparameters", + name=exp_name, + defaults=parameters, + poll_interval=1.0, + ) + exp_name = parameters["experiment_name"] num_classes = int(parameters["num_classes"]) image_size = int(parameters["image_size"]) grid_size = int(parameters["grid_size"]) conf_thresh = float(parameters["conf_thresh"]) - # --- 2) Device selection --- + # --- 3) Device selection --- if parameters.get("device", "auto") == "auto": parameters["device"] = torch.device( "cuda" if torch.cuda.is_available() else "cpu" ) device = parameters["device"] - # --- 3) Logging directory --- + # --- 4) Logging directory --- if not parameters.get("root_log_dir"): tmp_dir = tempfile.mkdtemp() parameters["root_log_dir"] = tmp_dir print(f"No root_log_dir specified, using temporary directory: {parameters['root_log_dir']}") os.makedirs(parameters["root_log_dir"], exist_ok=True) - log_dir = parameters["root_log_dir"] max_steps = parameters["training_steps_to_do"] eval_full_to_train_steps_ratio = parameters["eval_full_to_train_steps_ratio"] verbose = parameters.get("verbose", True) tqdm_display = parameters.get("tqdm_display", True) - # --- 4) Register hyperparameters --- - wl.watch_or_edit( - parameters, - flag="hyperparameters", - name=exp_name, - defaults=parameters, - poll_interval=1.0, - ) # --- 5) Data (Penn-Fudan pedestrians, downloaded on first run) --- default_data_root = os.path.abspath( diff --git a/weightslab/examples/PyTorch/ws-generation/main.py b/weightslab/examples/PyTorch/ws-generation/main.py index 1a16718..b1890d5 100644 --- a/weightslab/examples/PyTorch/ws-generation/main.py +++ b/weightslab/examples/PyTorch/ws-generation/main.py @@ -407,6 +407,9 @@ def evaluate_all(loader, model, cls_criterion, contrastive_criterion, metric, de parameters.setdefault("is_training", False) + # WeightsLab initialization + wl.watch_or_edit(parameters, flag="hyperparameters", defaults=parameters) + if parameters["device"] == "auto": device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: @@ -414,9 +417,6 @@ def evaluate_all(loader, model, cls_criterion, contrastive_criterion, metric, de os.makedirs(parameters["root_log_dir"], exist_ok=True) - # WeightsLab initialization - wl.watch_or_edit(parameters, flag="hyperparameters", defaults=parameters) - # 1.5. Serving logger.info("Starting WeightsLab servers...") wl.serve(serving_grpc=True, serving_cli=True) diff --git a/weightslab/examples/PyTorch/ws-segmentation/main.py b/weightslab/examples/PyTorch/ws-segmentation/main.py index f3a94d4..c3a4eb4 100644 --- a/weightslab/examples/PyTorch/ws-segmentation/main.py +++ b/weightslab/examples/PyTorch/ws-segmentation/main.py @@ -164,7 +164,15 @@ def test(loader, model, sig, device, test_loader_len): parameters.setdefault("image_size", 256) parameters.setdefault("compute_natural_sort", True) + # --- 4) Register hyperparameters --- exp_name = parameters["experiment_name"] + wl.watch_or_edit( + parameters, + flag="hyperparameters", + name=exp_name, + defaults=parameters, + poll_interval=1.0, + ) num_classes = int(parameters["num_classes"]) ignore_index = int(parameters["ignore_index"]) image_size = int(parameters["image_size"]) @@ -189,15 +197,6 @@ def test(loader, model, sig, device, test_loader_len): verbose = parameters.get("verbose", True) tqdm_display = parameters.get("tqdm_display", True) - # --- 4) Register hyperparameters --- - wl.watch_or_edit( - parameters, - flag="hyperparameters", - name=exp_name, - defaults=parameters, - poll_interval=1.0, - ) - # --- 5) Data (BDD100k reduced) --- default_data_root = os.path.abspath( os.path.join(os.path.dirname(__file__), "..", "..", "data", "BDD100k_reduced") diff --git a/weightslab/examples/Ultralytics/ws-detection/main.py b/weightslab/examples/Ultralytics/ws-detection/main.py index 130ca10..d2ba1a6 100644 --- a/weightslab/examples/Ultralytics/ws-detection/main.py +++ b/weightslab/examples/Ultralytics/ws-detection/main.py @@ -39,6 +39,7 @@ def main(): if not cfg.get("root_log_dir"): cfg["root_log_dir"] = tempfile.mkdtemp() os.makedirs(cfg["root_log_dir"], exist_ok=True) + wl.watch_or_edit(cfg, flag="hyperparameters", defaults=cfg, poll_interval=1.0) # Read raw config values BEFORE wrapping so YOLO.train kwargs are plain # Python (avoids ProxyValue.__gt__ during max()/comparisons). @@ -53,7 +54,6 @@ def main(): name = cfg["experiment_name"] signals_cfg = cfg.get('signals_cfg', {}) - wl.watch_or_edit(cfg, flag="hyperparameters", defaults=cfg, poll_interval=1.0) wl.serve(serving_grpc=serving_grpc, serving_cli=serving_cli) # ================ diff --git a/weightslab/examples/Usecases/ws-2d-lidar-detection/main.py b/weightslab/examples/Usecases/ws-2d-lidar-detection/main.py index 548f118..7b48ec5 100644 --- a/weightslab/examples/Usecases/ws-2d-lidar-detection/main.py +++ b/weightslab/examples/Usecases/ws-2d-lidar-detection/main.py @@ -83,6 +83,9 @@ def test(loader, model, sig, device, grid_size, pc_range, conf_thresh, test_load parameters.setdefault("compute_natural_sort", True) exp_name = parameters["experiment_name"] + wl.watch_or_edit(parameters, flag="hyperparameters", name=exp_name, + defaults=parameters, poll_interval=1.0) + num_classes = int(parameters["num_classes"]) pc_range = tuple(float(v) for v in parameters["point_cloud_range"]) voxel_size = float(parameters["voxel_size"]) @@ -102,9 +105,6 @@ def test(loader, model, sig, device, grid_size, pc_range, conf_thresh, test_load verbose = parameters.get("verbose", True) tqdm_display = parameters.get("tqdm_display", True) - wl.watch_or_edit(parameters, flag="hyperparameters", name=exp_name, - defaults=parameters, poll_interval=1.0) - data_cfg = parameters.get("data", {}) train_cfg = data_cfg.get("train_loader", {}) test_cfg = data_cfg.get("test_loader", {}) diff --git a/weightslab/examples/Usecases/ws-3d-lidar-detection/main.py b/weightslab/examples/Usecases/ws-3d-lidar-detection/main.py index ce0065b..d0d41eb 100644 --- a/weightslab/examples/Usecases/ws-3d-lidar-detection/main.py +++ b/weightslab/examples/Usecases/ws-3d-lidar-detection/main.py @@ -177,6 +177,15 @@ def test(loader, model, sig, device, grid_size, pc_range, conf_thresh, test_load parameters.setdefault("compute_natural_sort", True) exp_name = parameters["experiment_name"] + # --- Register hyperparameters --- + wl.watch_or_edit( + parameters, + flag="hyperparameters", + name=exp_name, + defaults=parameters, + poll_interval=1.0, + ) + num_classes = int(parameters["num_classes"]) pc_range = tuple(float(v) for v in parameters["point_cloud_range"]) voxel_size = float(parameters["voxel_size"]) @@ -203,15 +212,6 @@ def test(loader, model, sig, device, grid_size, pc_range, conf_thresh, test_load verbose = parameters.get("verbose", True) tqdm_display = parameters.get("tqdm_display", True) - # --- 4) Register hyperparameters --- - wl.watch_or_edit( - parameters, - flag="hyperparameters", - name=exp_name, - defaults=parameters, - poll_interval=1.0, - ) - # --- 5) Data (KITTI if present under data_root, else synthetic scenes) --- default_data_root = os.path.abspath( os.path.join(os.path.dirname(__file__), "data") diff --git a/weightslab/src.py b/weightslab/src.py index f9da636..d3600fc 100644 --- a/weightslab/src.py +++ b/weightslab/src.py @@ -6,11 +6,12 @@ import gc import os import sys -import ctypes import time import types +import ctypes import logging import inspect +import tempfile import functools import threading import traceback @@ -931,7 +932,7 @@ def new_forward(*a, **kw): # If obj is a string, treat as a file path and start watcher try: # Initialize CheckpointManager if we have a root dir (fallback to default root) - root_log_dir = obj.get('root_log_dir') or os.path.join('.', 'root_log_dir') + root_log_dir = obj.get('root_log_dir') or tempfile.mkdtemp() try: # Check if a checkpoint manager is already registered in ledger try: diff --git a/weightslab/tests/backend/test_ledgers.py b/weightslab/tests/backend/test_ledgers.py index f6b2f73..07f679d 100644 --- a/weightslab/tests/backend/test_ledgers.py +++ b/weightslab/tests/backend/test_ledgers.py @@ -181,17 +181,21 @@ def test_hyperparams_default_name_proxy_pattern(self): self.assertEqual(GLOBAL_LEDGER.list_hyperparams(), [DEFAULT_NAME]) def test_proxy_get_key_default_mode_returns_live_proxy(self): - """Default key get returns a live ValueProxy for dict-backed targets.""" + """Non-str keys return a live ValueProxy; str keys return the raw string.""" hp_handle = GLOBAL_LEDGER.get_hyperparams() - GLOBAL_LEDGER.register_hyperparams(params={"data_root": "C:/data/v1"}) + GLOBAL_LEDGER.register_hyperparams(params={"data_root": "C:/data/v1", "lr": 0.01}) + # Strings come back raw (immutable; avoids os.PathLike/__index__ breakage). data_root = hp_handle.get("data_root") + self.assertIsInstance(data_root, str) self.assertEqual(data_root, "C:/data/v1") - self.assertTrue(hasattr(data_root, "set")) + self.assertFalse(hasattr(data_root, "set")) - # Updating the underlying mapping is reflected through the live proxy. - hp_handle["data_root"] = "C:/data/v2" - self.assertEqual(data_root.get(), "C:/data/v2") + # Non-string values stay live proxies and track underlying updates. + lr = hp_handle.get("lr") + self.assertTrue(hasattr(lr, "set")) + hp_handle["lr"] = 0.02 + self.assertEqual(lr.get(), 0.02) def test_proxy_pickles_and_restores(self): proxy = Proxy({"flag": True, "count": 3}) @@ -279,48 +283,106 @@ def test_value_proxy_numeric_comparisons(self): s = {lr, 0.01} self.assertEqual(len(s), 1) + def test_value_proxy_iteration_and_subscript(self): + """ValueProxy is iterable and subscriptable (lists, slices, dict keys, str).""" + hp_handle = GLOBAL_LEDGER.get_hyperparams() + GLOBAL_LEDGER.register_hyperparams(params={ + "point_cloud_range": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + "data": {"train_loader": {"batch_size": 8}}, + "name": "lidar", + }) + + pcr = hp_handle.get("point_cloud_range") + + # Iteration — the reported bug was "'_ValueProxy' object is not iterable". + self.assertEqual(tuple(float(v) for v in pcr), (1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + self.assertEqual(list(pcr), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + first, second, *_ = pcr + self.assertEqual((first, second), (1.0, 2.0)) + + # Subscript: list index, slice, and string char. + self.assertEqual(pcr[0], 1.0) + self.assertEqual(pcr[1:3], [2.0, 3.0]) + self.assertEqual(hp_handle.get("name")[0], "l") + + # Dict subscript; nested mappings stay live proxies that resolve through. + data = hp_handle.get("data") + self.assertEqual(data["train_loader"]["batch_size"], 8) + + # A None target iterates as empty and reports length 0. + missing = hp_handle.get("missing_key") + self.assertEqual(list(missing), []) + self.assertEqual(len(missing), 0) + + def test_proxy_get_unwraps_torch_device_to_str(self): + """torch.device values are handed back as plain strings, not live proxies.""" + try: + import torch + except ImportError: + self.skipTest("torch not installed") + + hp_handle = GLOBAL_LEDGER.get_hyperparams() + GLOBAL_LEDGER.register_hyperparams(params={"device": torch.device("cpu")}) + + # Proxy.get returns a plain string (not a _ValueProxy), so torch's + # C-level torch.device()/.to() accept it instead of rejecting the proxy. + device = hp_handle.get("device") + self.assertIsInstance(device, str) + self.assertEqual(device, "cpu") + self.assertEqual(torch.device(device), torch.device("cpu")) + + # _ValueProxy.get() applies the same coercion when called directly. + vp = Proxy._ValueProxy(Proxy({"d": torch.device("cpu")}), "d") + self.assertIsInstance(vp.get(), str) + self.assertEqual(vp.get(), "cpu") + def test_proxy_get_key_proxy_mode_live_read_and_write(self): - """proxy=True returns a live key proxy that tracks and updates parent mapping.""" + """proxy=True returns a live key proxy (non-str) that tracks and updates parent mapping.""" hp_handle = GLOBAL_LEDGER.get_hyperparams() - GLOBAL_LEDGER.register_hyperparams(params={"data_root": "C:/data/v1"}) + GLOBAL_LEDGER.register_hyperparams(params={"lr": 0.01, "data_root": "C:/data/v1"}) - data_root_proxy = hp_handle.get("data_root", proxy=True) - self.assertEqual(data_root_proxy.get(), "C:/data/v1") + lr_proxy = hp_handle.get("lr", proxy=True) + self.assertEqual(lr_proxy.get(), 0.01) # External mapping update is visible from the key proxy. - hp_handle["data_root"] = "C:/data/v2" - self.assertEqual(data_root_proxy.get(), "C:/data/v2") + hp_handle["lr"] = 0.02 + self.assertEqual(lr_proxy.get(), 0.02) # Key proxy update writes back into parent mapping. - data_root_proxy.set("C:/data/v3") - self.assertEqual(hp_handle.get("data_root"), "C:/data/v3") + lr_proxy.set(0.03) + self.assertEqual(hp_handle.get("lr"), 0.03) + + # Strings are handed back raw even with proxy=True (not a live proxy). + self.assertIsInstance(hp_handle.get("data_root", proxy=True), str) def test_proxy_get_key_proxy_mode_survives_parent_replacement(self): """Live key proxy resolves against current parent object after re-registration.""" hp_handle = GLOBAL_LEDGER.get_hyperparams() - GLOBAL_LEDGER.register_hyperparams(params={"data_root": "C:/data/v1"}) + GLOBAL_LEDGER.register_hyperparams(params={"lr": 0.01}) - data_root_proxy = hp_handle.get("data_root", proxy=True) - self.assertEqual(data_root_proxy.get(), "C:/data/v1") + lr_proxy = hp_handle.get("lr", proxy=True) + self.assertEqual(lr_proxy.get(), 0.01) # Re-register hyperparams under same name; existing parent proxy is updated. - GLOBAL_LEDGER.register_hyperparams(params={"data_root": "C:/data/v4", "lr": 0.01}) - self.assertEqual(data_root_proxy.get(), "C:/data/v4") + GLOBAL_LEDGER.register_hyperparams(params={"lr": 0.04, "batch_size": 32}) + self.assertEqual(lr_proxy.get(), 0.04) # Updates through key proxy should target the latest registered mapping. - data_root_proxy.set("C:/data/v5") - self.assertEqual(hp_handle.get("data_root"), "C:/data/v5") + lr_proxy.set(0.05) + self.assertEqual(hp_handle.get("lr"), 0.05) def test_proxy_get_key_proxy_mode_default_and_late_set(self): """Missing-key live proxy returns default and can later populate the key.""" hp_handle = GLOBAL_LEDGER.get_hyperparams() GLOBAL_LEDGER.register_hyperparams(params={}) - data_root_proxy = hp_handle.get("data_root", default="C:/fallback", proxy=True) - self.assertEqual(data_root_proxy.get(), "C:/fallback") + # Non-str default + missing key -> stays a live proxy (str defaults would + # resolve to a raw string instead, by design). + lr_proxy = hp_handle.get("lr", default=0.01, proxy=True) + self.assertEqual(lr_proxy.get(), 0.01) - data_root_proxy.set("C:/data/live") - self.assertEqual(hp_handle.get("data_root"), "C:/data/live") + lr_proxy.set(0.02) + self.assertEqual(hp_handle.get("lr"), 0.02) def test_watch_or_edit_hyperparams_rebinds_caller_variable(self): """wl.watch_or_edit(parameters, flag='hyperparameters') rebinds the caller's