File tree Expand file tree Collapse file tree
tests/test_highjax_trainer/test_golden_runs Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1+ name : Tests
2+
3+ on :
4+ push :
5+ branches : [master]
6+ pull_request :
7+ branches : [master]
8+
9+ jobs :
10+ python-tests :
11+ runs-on : ubuntu-latest
12+ strategy :
13+ matrix :
14+ python-version : ['3.14']
15+ steps :
16+ - uses : actions/checkout@v4
17+
18+ - name : Set up Python ${{ matrix.python-version }}
19+ uses : actions/setup-python@v5
20+ with :
21+ python-version : ${{ matrix.python-version }}
22+ allow-prereleases : true
23+
24+ - name : Install dependencies
25+ run : |
26+ pip install -e ".[trainer,tests]"
27+
28+ - name : Run Python tests
29+ env :
30+ HIGHJAX_TESTS_LOOSE : ' 1'
31+ run : |
32+ pytest -n auto
33+
34+ octane-tests :
35+ runs-on : ubuntu-latest
36+ steps :
37+ - uses : actions/checkout@v4
38+
39+ - name : Set up Rust
40+ uses : dtolnay/rust-toolchain@stable
41+
42+ - name : Cache Rust build
43+ uses : actions/cache@v4
44+ with :
45+ path : |
46+ octane/target
47+ ~/.cargo/registry
48+ ~/.cargo/git
49+ key : ${{ runner.os }}-cargo-${{ hashFiles('octane/Cargo.lock') }}
50+ restore-keys : |
51+ ${{ runner.os }}-cargo-
52+
53+ - name : Set up Python for test data
54+ uses : actions/setup-python@v5
55+ with :
56+ python-version : ' 3.14'
57+ allow-prereleases : true
58+
59+ - name : Install Python package (for test data generation)
60+ run : |
61+ pip install -e ".[trainer]"
62+
63+ - name : Run Octane tests
64+ run : |
65+ cd octane
66+ cargo test --release -- --skip airbus
Original file line number Diff line number Diff line change 11'''Shared logic for golden run tests.'''
22from __future__ import annotations
33
4+ import os
5+
46import jax .numpy as jnp
57
8+ _LOOSE = os .environ .get ('HIGHJAX_TESTS_LOOSE' ) == '1'
9+ _RTOL = 1e-02 if _LOOSE else 1e-04
10+ _ATOL = 1e-05 if _LOOSE else 1e-07
11+
612
713def print_golden_data (epoch_metrics : list [dict ]) -> None :
814 n = len (epoch_metrics )
@@ -38,7 +44,7 @@ def check_golden_run(epoch_metrics: list[dict], golden_data: dict) -> None:
3844 expected = expected_values [i ]
3945 actual = float (metrics [key ])
4046
41- if not jnp .isclose (actual , expected , rtol = 1e-04 , atol = 1e-07 ):
47+ if not jnp .isclose (actual , expected , rtol = _RTOL , atol = _ATOL ):
4248 bad_fields .append ((key , actual , expected ))
4349
4450 if bad_fields :
You can’t perform that action at this time.
0 commit comments