Skip to content

Commit f68c768

Browse files
committed
Golden test tolerance: strict by default, HIGHJAX_TESTS_LOOSE=1 for CI
1 parent 62d0f32 commit f68c768

2 files changed

Lines changed: 9 additions & 1 deletion

File tree

.github/workflows/tests.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ jobs:
2626
pip install -e ".[trainer,tests]"
2727
2828
- name: Run Python tests
29+
env:
30+
HIGHJAX_TESTS_LOOSE: '1'
2931
run: |
3032
pytest -n auto
3133

tests/test_highjax_trainer/test_golden_runs/common.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
'''Shared logic for golden run tests.'''
22
from __future__ import annotations
33

4+
import os
5+
46
import jax.numpy as jnp
57

8+
_LOOSE = os.environ.get('HIGHJAX_TESTS_LOOSE') == '1'
9+
_RTOL = 1e-02 if _LOOSE else 1e-04
10+
_ATOL = 1e-05 if _LOOSE else 1e-07
11+
612

713
def print_golden_data(epoch_metrics: list[dict]) -> None:
814
n = len(epoch_metrics)
@@ -38,7 +44,7 @@ def check_golden_run(epoch_metrics: list[dict], golden_data: dict) -> None:
3844
expected = expected_values[i]
3945
actual = float(metrics[key])
4046

41-
if not jnp.isclose(actual, expected, rtol=1e-02, atol=1e-05):
47+
if not jnp.isclose(actual, expected, rtol=_RTOL, atol=_ATOL):
4248
bad_fields.append((key, actual, expected))
4349

4450
if bad_fields:

0 commit comments

Comments
 (0)