Skip to content

Commit b4fe40e

Browse files
committed
Started on a pyro class
1 parent 9de6e24 commit b4fe40e

1 file changed

Lines changed: 209 additions & 144 deletions

File tree

pyro.py

Lines changed: 209 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -12,173 +12,238 @@
1212
from util import msg, profile, runparams, io
1313

1414

15-
def doit(solver_name, problem_name, param_file,
16-
other_commands=None,
17-
comp_bench=False, reset_bench_on_fail=False, make_bench=False):
18-
"""The main driver to run pyro"""
15+
class Pyro(object):
16+
"""
17+
The main driver to run pyro.
18+
19+
Notes / TODO:
20+
* Further decouple some of the benchmarking stuff?
21+
* Should it be possible to pass in a problem function and initialise a Pyro/
22+
Simulation object using that? That would also require some modifications
23+
of the Simulation class, as we currently pass the problem name to its
24+
constructor. I think this would make sense though if we were moving
25+
towards something that worked better as a Jupyter Notebook.
26+
"""
27+
28+
def __init__(self, solver_name, comp_bench=False,
29+
reset_bench_on_fail=False, make_bench=False):
30+
"""
31+
Constructor
32+
33+
Parameters
34+
----------
35+
solver_name : str
36+
Name of solver to use
37+
comp_bench : bool
38+
Are we comparing to a benchmark?
39+
reset_bench_on_fail : bool
40+
Do we reset the benchmark on fail?
41+
make_bench : bool
42+
Are we storing a benchmark?
43+
"""
44+
msg.bold('pyro ...')
45+
46+
# import desired solver under "solver" namespace
47+
self.solver = importlib.import_module(solver_name)
48+
self.solver_name = solver_name
49+
self.comp_bench = comp_bench
50+
self.reset_bench_on_fail = reset_bench_on_fail
51+
self.make_bench = make_bench
52+
53+
#-------------------------------------------------------------------------
54+
# runtime parameters
55+
#-------------------------------------------------------------------------
56+
57+
# parameter defaults
58+
self.rp = runparams.RuntimeParameters()
59+
self.rp.load_params("_defaults")
60+
self.rp.load_params(solver_name + "/_defaults")
61+
62+
self.tc = profile.TimerCollection()
63+
64+
def initialize_problem(self, problem_name, param_file=None, param_dict=None,
65+
other_commands=None):
66+
"""
67+
Initialize the specific problem
68+
69+
Parameters
70+
----------
71+
problem_name : str
72+
Name of the problem
73+
param_file : str
74+
Filename containing problem's runtime parameters
75+
param_dict : dict
76+
Dictionary containing extra runtime parameters
77+
other_commands : str
78+
Other command line parameter options
79+
"""
80+
81+
# problem-specific runtime parameters
82+
self.rp.load_params(self.solver_name + "/problems/_" + problem_name + ".defaults")
83+
84+
# now read in the inputs file
85+
if param_dict is not None:
86+
for k, v in param_dict.items():
87+
self.rp.params[k] = v
88+
if param_file is not None:
89+
if not os.path.isfile(param_file):
90+
# check if the param file lives in the solver's problems directory
91+
param_file = self.solver_name + "/problems/" + param_file
92+
if not os.path.isfile(param_file):
93+
msg.fail("ERROR: inputs file does not exist")
94+
95+
self.rp.load_params(param_file, no_new=1)
96+
97+
# and any commandline overrides
98+
if other_commands is not None:
99+
self.rp.command_line_params(other_commands)
100+
101+
# write out the inputs.auto
102+
self.rp.print_paramfile()
103+
104+
self.verbose = self.rp.get_param("driver.verbose")
105+
self.dovis = self.rp.get_param("vis.dovis")
106+
107+
#-------------------------------------------------------------------------
108+
# initialization
109+
#-------------------------------------------------------------------------
110+
111+
# initialize the Simulation object -- this will hold the grid and
112+
# data and know about the runtime parameters and which problem we
113+
# are running
114+
self.sim = self.solver.Simulation(self.solver_name, problem_name, self.rp, timers=self.tc)
115+
116+
self.sim.initialize()
117+
self.sim.preevolve()
118+
119+
def run_sim(self):
120+
"""
121+
Evolve entire simulation
122+
"""
123+
124+
tm_main = self.tc.timer("main")
125+
tm_main.begin()
126+
127+
plt.ion()
128+
129+
self.sim.cc_data.t = 0.0
130+
131+
# output the 0th data
132+
basename = self.rp.get_param("io.basename")
133+
self.sim.write("{}{:04d}".format(basename, self.sim.n))
134+
135+
if self.dovis:
136+
plt.figure(num=1, figsize=(8, 6), dpi=100, facecolor='w')
137+
self.sim.dovis()
138+
139+
while not self.sim.finished():
140+
141+
self.single_step()
142+
143+
# final output
144+
if self.verbose > 0:
145+
msg.warning("outputting...")
146+
basename = self.rp.get_param("io.basename")
147+
self.sim.write("{}{:04d}".format(basename, self.sim.n))
148+
149+
tm_main.end()
19150

20-
msg.bold('pyro ...')
151+
result = self.compare_to_benchmark()
152+
self.make_bench(result)
153+
154+
#-------------------------------------------------------------------------
155+
# final reports
156+
#-------------------------------------------------------------------------
157+
if self.verbose > 0:
158+
self.rp.print_unused_params()
159+
self.tc.report()
21160

22-
tc = profile.TimerCollection()
161+
self.sim.finalize()
23162

24-
tm_main = tc.timer("main")
25-
tm_main.begin()
26-
27-
# import desired solver under "solver" namespace
28-
solver = importlib.import_module(solver_name)
29-
30-
#-------------------------------------------------------------------------
31-
# runtime parameters
32-
#-------------------------------------------------------------------------
33-
34-
# parameter defaults
35-
rp = runparams.RuntimeParameters()
36-
rp.load_params("_defaults")
37-
rp.load_params(solver_name + "/_defaults")
38-
39-
# problem-specific runtime parameters
40-
rp.load_params(solver_name + "/problems/_" + problem_name + ".defaults")
41-
42-
# now read in the inputs file
43-
if not os.path.isfile(param_file):
44-
# check if the param file lives in the solver's problems directory
45-
param_file = solver_name + "/problems/" + param_file
46-
if not os.path.isfile(param_file):
47-
msg.fail("ERROR: inputs file does not exist")
48-
49-
rp.load_params(param_file, no_new=1)
50-
51-
# and any commandline overrides
52-
if other_commands is not None:
53-
rp.command_line_params(other_commands)
54-
55-
# write out the inputs.auto
56-
rp.print_paramfile()
57-
58-
#-------------------------------------------------------------------------
59-
# initialization
60-
#-------------------------------------------------------------------------
61-
62-
# initialize the Simulation object -- this will hold the grid and
63-
# data and know about the runtime parameters and which problem we
64-
# are running
65-
sim = solver.Simulation(solver_name, problem_name, rp, timers=tc)
66-
67-
sim.initialize()
68-
sim.preevolve()
69-
70-
#-------------------------------------------------------------------------
71-
# evolve
72-
#-------------------------------------------------------------------------
73-
verbose = rp.get_param("driver.verbose")
74-
75-
plt.ion()
76-
77-
sim.cc_data.t = 0.0
78-
79-
# output the 0th data
80-
basename = rp.get_param("io.basename")
81-
sim.write("{}{:04d}".format(basename, sim.n))
82-
83-
dovis = rp.get_param("vis.dovis")
84-
if dovis:
85-
plt.figure(num=1, figsize=(8, 6), dpi=100, facecolor='w')
86-
sim.dovis()
87-
88-
while not sim.finished():
163+
if self.comp_bench:
164+
return result
165+
else:
166+
return self.sim
89167

168+
def single_step(self):
169+
"""
170+
Do a single step
171+
"""
90172
# fill boundary conditions
91-
sim.cc_data.fill_BC_all()
173+
self.sim.cc_data.fill_BC_all()
92174

93175
# get the timestep
94-
sim.compute_timestep()
176+
self.sim.compute_timestep()
95177

96178
# evolve for a single timestep
97-
sim.evolve()
179+
self.sim.evolve()
98180

99-
if verbose > 0:
100-
print("%5d %10.5f %10.5f" % (sim.n, sim.cc_data.t, sim.dt))
181+
if self.verbose > 0:
182+
print("%5d %10.5f %10.5f" % (self.sim.n, self.sim.cc_data.t, self.sim.dt))
101183

102184
# output
103-
if sim.do_output():
104-
if verbose > 0:
185+
if self.sim.do_output():
186+
if self.verbose > 0:
105187
msg.warning("outputting...")
106-
basename = rp.get_param("io.basename")
107-
sim.write("{}{:04d}".format(basename, sim.n))
188+
basename = self.rp.get_param("io.basename")
189+
self.sim.write("{}{:04d}".format(basename, self.sim.n))
108190

109191
# visualization
110-
if dovis:
111-
tm_vis = tc.timer("vis")
192+
if self.dovis:
193+
tm_vis = self.tc.timer("vis")
112194
tm_vis.begin()
113195

114-
sim.dovis()
115-
store = rp.get_param("vis.store_images")
196+
self.sim.dovis()
197+
store = self.rp.get_param("vis.store_images")
116198

117199
if store == 1:
118-
basename = rp.get_param("io.basename")
119-
plt.savefig("{}{:04d}.png".format(basename, sim.n))
200+
basename = self.rp.get_param("io.basename")
201+
plt.savefig("{}{:04d}.png".format(basename, self.sim.n))
120202

121203
tm_vis.end()
122204

123-
# final output
124-
if verbose > 0:
125-
msg.warning("outputting...")
126-
basename = rp.get_param("io.basename")
127-
sim.write("{}{:04d}".format(basename, sim.n))
128-
129-
tm_main.end()
130-
131-
#-------------------------------------------------------------------------
132-
# benchmarks (for regression testing)
133-
#-------------------------------------------------------------------------
134-
result = 0
135-
# are we comparing to a benchmark?
136-
if comp_bench:
137-
compare_file = "{}/tests/{}{:04d}".format(
138-
solver_name, basename, sim.n)
139-
msg.warning("comparing to: {} ".format(compare_file))
140-
try:
141-
sim_bench = io.read(compare_file)
142-
except IOError:
143-
msg.warning("ERROR openning compare file")
144-
return "ERROR openning compare file"
145-
146-
result = compare.compare(sim.cc_data, sim_bench.cc_data)
147-
148-
if result == 0:
149-
msg.success("results match benchmark\n")
150-
else:
151-
msg.warning("ERROR: " + compare.errors[result] + "\n")
205+
def compare_to_benchmark(self):
206+
""" Are we comparing to a benchmark? """
152207

153-
# are we storing a benchmark?
154-
if make_bench or (result != 0 and reset_bench_on_fail):
155-
if not os.path.isdir(solver_name + "/tests/"):
156-
try:
157-
os.mkdir(solver_name + "/tests/")
158-
except (FileNotFoundError, PermissionError):
159-
msg.fail("ERROR: unable to create the solver's tests/ directory")
208+
result = 0
160209

161-
bench_file = solver_name + "/tests/" + basename + "%4.4d" % (sim.n)
162-
msg.warning("storing new benchmark: {}\n".format(bench_file))
163-
sim.write(bench_file)
210+
if self.comp_bench:
211+
basename = self.rp.get_param("io.basename")
212+
compare_file = "{}/tests/{}{:04d}".format(
213+
self.solver_name, basename, self.sim.n)
214+
msg.warning("comparing to: {} ".format(compare_file))
215+
try:
216+
sim_bench = io.read(compare_file)
217+
except IOError:
218+
msg.warning("ERROR openning compare file")
219+
return "ERROR openning compare file"
164220

165-
#-------------------------------------------------------------------------
166-
# final reports
167-
#-------------------------------------------------------------------------
168-
if verbose > 0:
169-
rp.print_unused_params()
170-
tc.report()
221+
result = compare.compare(self.sim.cc_data, sim_bench.cc_data)
171222

172-
sim.finalize()
223+
if result == 0:
224+
msg.success("results match benchmark\n")
225+
else:
226+
msg.warning("ERROR: " + compare.errors[result] + "\n")
173227

174-
if comp_bench:
175228
return result
176-
else:
177-
return sim
178229

230+
def store_as_benchmark(self, result):
231+
""" Are we storing a benchmark? """
232+
if self.make_bench or (result != 0 and self.reset_bench_on_fail):
233+
if not os.path.isdir(self.solver_name + "/tests/"):
234+
try:
235+
os.mkdir(self.solver_name + "/tests/")
236+
except (FileNotFoundError, PermissionError):
237+
msg.fail("ERROR: unable to create the solver's tests/ directory")
179238

180-
def parse_and_run():
181-
"""Parse the runtime parameters and run a pyro instance"""
239+
basename = self.rp.get_param("io.basename")
240+
bench_file = self.solver_name + "/tests/" + basename + "%4.4d" % (self.sim.n)
241+
msg.warning("storing new benchmark: {}\n".format(bench_file))
242+
self.sim.write(bench_file)
243+
244+
245+
def parse_args():
246+
"""Parse the runtime parameters"""
182247

183248
valid_solvers = ["advection",
184249
"advection_rk",
@@ -214,13 +279,13 @@ def parse_and_run():
214279
help="additional runtime parameters that override the inputs file "
215280
"in the format section.option=value")
216281

217-
args = p.parse_args()
218-
219-
doit(args.solver[0], args.problem[0], args.param[0],
220-
other_commands=args.other,
221-
comp_bench=args.compare_benchmark,
222-
make_bench=args.make_benchmark)
282+
return p.parse_args()
223283

224284

225285
if __name__ == "__main__":
226-
parse_and_run()
286+
args = parse_args()
287+
pyro = Pyro(args.solver[0], comp_bench=args.compare_benchmark,
288+
make_bench=args.make_benchmark)
289+
pyro.initialize_problem(problem_name=args.problem[0], param_file=args.param[0],
290+
other_commands=args.other)
291+
pyro.run_sim()

0 commit comments

Comments
 (0)