Skip to content

Commit a6ad357

Browse files
authored
Allow running regression tests in parallel (#145)
test.py can be run with --nproc 0 to utilize all cores. If running in parallel, the output from each test case is collected and is printed all at once when that case finishes, to minimize interleaving.
1 parent 92d9366 commit a6ad357

2 files changed

Lines changed: 82 additions & 25 deletions

File tree

.github/workflows/regtest.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,5 @@ jobs:
4444
run: pip install .
4545

4646
- name: Run tests via test.py
47-
run: ./pyro/test.py
47+
run: ./pyro/test.py --nproc 0
4848

pyro/test.py

Lines changed: 81 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22

33

44
import argparse
5+
import contextlib
56
import datetime
7+
import io
68
import os
79
import sys
10+
from multiprocessing import Pool
11+
from pathlib import Path
812

913
import pyro.pyro_sim as pyro
1014
from pyro.multigrid.examples import (mg_test_general_inhomogeneous,
@@ -23,9 +27,64 @@ def __str__(self):
2327
return f"{self.solver}-{self.problem}"
2428

2529

30+
@contextlib.contextmanager
def avoid_interleaved_output(nproc):
    """Buffer stdout/stderr and emit everything at once to avoid interleaving.

    When running serially (nproc == 1) this is a no-op; otherwise all output
    produced inside the context is captured and printed in a single call at
    exit, so lines from concurrent test cases don't get mixed together.
    """
    if nproc != 1:
        buf = io.StringIO()
        try:
            # send both streams to the same in-memory buffer
            with contextlib.redirect_stdout(buf), contextlib.redirect_stderr(buf):
                yield
        finally:
            # one print call is unlikely to interleave with other processes
            print(buf.getvalue(), end="", flush=True)
    else:
        # serial execution: nothing to protect against
        yield
45+
46+
47+
def run_test(t, reset_fails, store_all_benchmarks, rtol, nproc):
    """Run one benchmark test case inside its own working directory.

    Each test gets a private directory because output file names can collide
    between tests, and h5py needs exclusive access when writing.

    Returns a ``(test name, error code)`` tuple; ``err == 0`` means the test
    passed (presumably -- the convention comes from PyroBenchmark.run_sim).
    """
    start_dir = Path.cwd()
    work_dir = start_dir / "test_outputs" / str(t)
    work_dir.mkdir(parents=True, exist_ok=True)
    try:
        os.chdir(work_dir)
        with avoid_interleaved_output(nproc):
            bench = pyro.PyroBenchmark(t.solver, comp_bench=True,
                                       reset_bench_on_fail=reset_fails,
                                       make_bench=store_all_benchmarks)
            bench.initialize_problem(t.problem, t.inputs, t.options)
            first_step = bench.sim.n
            err = bench.run_sim(rtol)
    finally:
        # always restore the original working directory, even on failure
        os.chdir(start_dir)

    if err == 0:
        # the test passed; remove its output files so the tree stays clean
        base = bench.rp.get_param("io.basename")
        for leftover in (f"{base}{first_step:04d}.h5",
                         f"{base}{bench.sim.n:04d}.h5",
                         "inputs.auto"):
            (work_dir / leftover).unlink()
        work_dir.rmdir()
        # also drop the top-level output directory if it is now empty
        try:
            work_dir.parent.rmdir()
        except OSError:
            # another test's output is still in there -- leave it alone
            pass

    return str(t), err
78+
79+
80+
def run_test_star(args):
    """Unpack an argument tuple and forward it to run_test.

    multiprocessing can't pickle lambdas, so this has to be a real
    module-level function.
    """
    t, reset_fails, store_all_benchmarks, rtol, nproc = args
    return run_test(t, reset_fails, store_all_benchmarks, rtol, nproc)
83+
84+
2685
def do_tests(out_file,
2786
reset_fails=False, store_all_benchmarks=False,
28-
single=None, solver=None, rtol=1e-12):
87+
single=None, solver=None, rtol=1e-12, nproc=1):
2988

3089
opts = {"driver.verbose": 0, "vis.dovis": 0, "io.do_io": 0}
3190

@@ -59,13 +118,16 @@ def do_tests(out_file,
59118
else:
60119
tests_to_run = tests
61120

62-
for t in tests_to_run:
63-
p = pyro.PyroBenchmark(t.solver, comp_bench=True,
64-
reset_bench_on_fail=reset_fails, make_bench=store_all_benchmarks)
65-
p.initialize_problem(t.problem, t.inputs, t.options)
66-
err = p.run_sim(rtol)
67-
68-
results[str(t)] = err
121+
if nproc == 0:
122+
nproc = os.cpu_count()
123+
# don't create more processes than needed
124+
nproc = min(nproc, len(tests_to_run))
125+
with Pool(processes=nproc) as pool:
126+
tasks = ((t, reset_fails, store_all_benchmarks, rtol, nproc) for t in tests_to_run)
127+
imap_it = pool.imap_unordered(run_test_star, tasks)
128+
# collect run results
129+
for name, err in imap_it:
130+
results[name] = err
69131

70132
# standalone tests
71133
if single is None:
@@ -120,9 +182,9 @@ def do_tests(out_file,
120182

121183
p = argparse.ArgumentParser()
122184

123-
p.add_argument("-o",
124-
help="name of file to output the report to (otherwise output to the screen",
125-
type=str, nargs=1)
185+
p.add_argument("--outfile", "-o",
186+
help="name of file to output the report to (in addition to the screen)",
187+
type=str, default=None)
126188

127189
p.add_argument("--single",
128190
help="name of a single test (solver-problem) to run",
@@ -142,23 +204,18 @@ def do_tests(out_file,
142204

143205
p.add_argument("--rtol",
144206
help="relative tolerance to use when comparing data to benchmarks",
145-
type=float, nargs=1)
207+
type=float, default=1.e-12)
146208

147-
args = p.parse_args()
148-
149-
try:
150-
outfile = args.o[0]
151-
except TypeError:
152-
outfile = None
209+
p.add_argument("--nproc", "-n",
210+
help="maximum number of parallel processes to run, or 0 to use all cores",
211+
type=int, default=1)
153212

154-
try:
155-
rtol = args.rtol[0]
156-
except TypeError:
157-
rtol = 1.e-12
213+
args = p.parse_args()
158214

159-
failed = do_tests(outfile,
215+
failed = do_tests(args.outfile,
160216
reset_fails=args.reset_failures,
161217
store_all_benchmarks=args.store_all_benchmarks,
162-
single=args.single, solver=args.solver, rtol=rtol)
218+
single=args.single, solver=args.solver, rtol=args.rtol,
219+
nproc=args.nproc)
163220

164221
sys.exit(failed)

0 commit comments

Comments
 (0)