22
33
44import argparse
5+ import contextlib
56import datetime
7+ import io
68import os
79import sys
10+ from multiprocessing import Pool
11+ from pathlib import Path
812
913import pyro .pyro_sim as pyro
1014from pyro .multigrid .examples import (mg_test_general_inhomogeneous ,
@@ -23,9 +27,64 @@ def __str__(self):
2327 return f"{ self .solver } -{ self .problem } "
2428
2529
30+ @contextlib .contextmanager
31+ def avoid_interleaved_output (nproc ):
32+ """Collect all the printed output and print it all at once to avoid interleaving."""
33+ if nproc == 1 :
34+ # not running in parallel, so we don't have to worry about interleaving
35+ yield
36+ else :
37+ output_buffer = io .StringIO ()
38+ try :
39+ with contextlib .redirect_stdout (output_buffer ), \
40+ contextlib .redirect_stderr (output_buffer ):
41+ yield
42+ finally :
43+ # a single print call probably won't get interleaved
44+ print (output_buffer .getvalue (), end = "" , flush = True )
45+
46+
47+ def run_test (t , reset_fails , store_all_benchmarks , rtol , nproc ):
48+ orig_cwd = Path .cwd ()
49+ # run each test in its own directory, since some of the output file names
50+ # overlap between tests, and h5py needs exclusive access when writing
51+ test_dir = orig_cwd / f"test_outputs/{ t } "
52+ test_dir .mkdir (parents = True , exist_ok = True )
53+ try :
54+ os .chdir (test_dir )
55+ with avoid_interleaved_output (nproc ):
56+ p = pyro .PyroBenchmark (t .solver , comp_bench = True ,
57+ reset_bench_on_fail = reset_fails ,
58+ make_bench = store_all_benchmarks )
59+ p .initialize_problem (t .problem , t .inputs , t .options )
60+ start_n = p .sim .n
61+ err = p .run_sim (rtol )
62+ finally :
63+ os .chdir (orig_cwd )
64+ if err == 0 :
65+ # the test passed; clean up the output files for developer use
66+ basename = p .rp .get_param ("io.basename" )
67+ (test_dir / f"{ basename } { start_n :04d} .h5" ).unlink ()
68+ (test_dir / f"{ basename } { p .sim .n :04d} .h5" ).unlink ()
69+ (test_dir / "inputs.auto" ).unlink ()
70+ test_dir .rmdir ()
71+ # try removing the top-level output directory
72+ try :
73+ test_dir .parent .rmdir ()
74+ except OSError :
75+ pass
76+
77+ return str (t ), err
78+
79+
80+ def run_test_star (args ):
81+ """multiprocessing doesn't like lambdas, so this needs to be a full function"""
82+ return run_test (* args )
83+
84+
2685def do_tests (out_file ,
2786 reset_fails = False , store_all_benchmarks = False ,
28- single = None , solver = None , rtol = 1e-12 ):
87+ single = None , solver = None , rtol = 1e-12 , nproc = 1 ):
2988
3089 opts = {"driver.verbose" : 0 , "vis.dovis" : 0 , "io.do_io" : 0 }
3190
@@ -59,13 +118,16 @@ def do_tests(out_file,
59118 else :
60119 tests_to_run = tests
61120
62- for t in tests_to_run :
63- p = pyro .PyroBenchmark (t .solver , comp_bench = True ,
64- reset_bench_on_fail = reset_fails , make_bench = store_all_benchmarks )
65- p .initialize_problem (t .problem , t .inputs , t .options )
66- err = p .run_sim (rtol )
67-
68- results [str (t )] = err
121+ if nproc == 0 :
122+ nproc = os .cpu_count ()
123+ # don't create more processes than needed
124+ nproc = min (nproc , len (tests_to_run ))
125+ with Pool (processes = nproc ) as pool :
126+ tasks = ((t , reset_fails , store_all_benchmarks , rtol , nproc ) for t in tests_to_run )
127+ imap_it = pool .imap_unordered (run_test_star , tasks )
128+ # collect run results
129+ for name , err in imap_it :
130+ results [name ] = err
69131
70132 # standalone tests
71133 if single is None :
@@ -120,9 +182,9 @@ def do_tests(out_file,
120182
121183 p = argparse .ArgumentParser ()
122184
123- p .add_argument ("-o" ,
124- help = "name of file to output the report to (otherwise output to the screen" ,
125- type = str , nargs = 1 )
185+ p .add_argument ("--outfile" , "- o" ,
186+ help = "name of file to output the report to (in addition to the screen) " ,
187+ type = str , default = None )
126188
127189 p .add_argument ("--single" ,
128190 help = "name of a single test (solver-problem) to run" ,
@@ -142,23 +204,18 @@ def do_tests(out_file,
142204
143205 p .add_argument ("--rtol" ,
144206 help = "relative tolerance to use when comparing data to benchmarks" ,
145- type = float , nargs = 1 )
207+ type = float , default = 1.e-12 )
146208
147- args = p .parse_args ()
148-
149- try :
150- outfile = args .o [0 ]
151- except TypeError :
152- outfile = None
209+ p .add_argument ("--nproc" , "-n" ,
210+ help = "maximum number of parallel processes to run, or 0 to use all cores" ,
211+ type = int , default = 1 )
153212
154- try :
155- rtol = args .rtol [0 ]
156- except TypeError :
157- rtol = 1.e-12
213+ args = p .parse_args ()
158214
159- failed = do_tests (outfile ,
215+ failed = do_tests (args . outfile ,
160216 reset_fails = args .reset_failures ,
161217 store_all_benchmarks = args .store_all_benchmarks ,
162- single = args .single , solver = args .solver , rtol = rtol )
218+ single = args .single , solver = args .solver , rtol = args .rtol ,
219+ nproc = args .nproc )
163220
164221 sys .exit (failed )
0 commit comments