Skip to content

Commit ba1ff23

Browse files
committed
Add GitHub CI for Python and Octane tests
Python tests run on 3.14 with pytest -n auto. Octane tests run cargo test --release, skipping airbus tests (need PTY). Golden test tolerance: strict by default (rtol=1e-4), CI uses HIGHJAX_TESTS_LOOSE=1 for loose tolerance (rtol=1e-2) to handle CPU/GPU numerical differences.
1 parent 832e674 commit ba1ff23

2 files changed

Lines changed: 73 additions & 1 deletion

File tree

.github/workflows/tests.yml

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
name: Tests
2+
3+
on:
4+
push:
5+
branches: [master]
6+
pull_request:
7+
branches: [master]
8+
9+
jobs:
10+
python-tests:
11+
runs-on: ubuntu-latest
12+
strategy:
13+
matrix:
14+
python-version: ['3.14']
15+
steps:
16+
- uses: actions/checkout@v4
17+
18+
- name: Set up Python ${{ matrix.python-version }}
19+
uses: actions/setup-python@v5
20+
with:
21+
python-version: ${{ matrix.python-version }}
22+
allow-prereleases: true
23+
24+
- name: Install dependencies
25+
run: |
26+
pip install -e ".[trainer,tests]"
27+
28+
- name: Run Python tests
29+
env:
30+
HIGHJAX_TESTS_LOOSE: '1'
31+
run: |
32+
pytest -n auto
33+
34+
octane-tests:
35+
runs-on: ubuntu-latest
36+
steps:
37+
- uses: actions/checkout@v4
38+
39+
- name: Set up Rust
40+
uses: dtolnay/rust-toolchain@stable
41+
42+
- name: Cache Rust build
43+
uses: actions/cache@v4
44+
with:
45+
path: |
46+
octane/target
47+
~/.cargo/registry
48+
~/.cargo/git
49+
key: ${{ runner.os }}-cargo-${{ hashFiles('octane/Cargo.lock') }}
50+
restore-keys: |
51+
${{ runner.os }}-cargo-
52+
53+
- name: Set up Python for test data
54+
uses: actions/setup-python@v5
55+
with:
56+
python-version: '3.14'
57+
allow-prereleases: true
58+
59+
- name: Install Python package (for test data generation)
60+
run: |
61+
pip install -e ".[trainer]"
62+
63+
- name: Run Octane tests
64+
run: |
65+
cd octane
66+
cargo test --release -- --skip airbus

tests/test_highjax_trainer/test_golden_runs/common.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
'''Shared logic for golden run tests.'''
22
from __future__ import annotations
33

4+
import os
5+
46
import jax.numpy as jnp
57

8+
_LOOSE = os.environ.get('HIGHJAX_TESTS_LOOSE') == '1'
9+
_RTOL = 1e-02 if _LOOSE else 1e-04
10+
_ATOL = 1e-05 if _LOOSE else 1e-07
11+
612

713
def print_golden_data(epoch_metrics: list[dict]) -> None:
814
n = len(epoch_metrics)
@@ -38,7 +44,7 @@ def check_golden_run(epoch_metrics: list[dict], golden_data: dict) -> None:
3844
expected = expected_values[i]
3945
actual = float(metrics[key])
4046

41-
if not jnp.isclose(actual, expected, rtol=1e-04, atol=1e-07):
47+
if not jnp.isclose(actual, expected, rtol=_RTOL, atol=_ATOL):
4248
bad_fields.append((key, actual, expected))
4349

4450
if bad_fields:

0 commit comments

Comments
 (0)